call.func parameter (#194)
This commit is contained in:
parent
450934ce79
commit
a22c221fa7
|
@ -591,7 +591,11 @@ def evaluate_call(
|
||||||
custom_tools: Dict[str, Callable],
|
custom_tools: Dict[str, Callable],
|
||||||
authorized_imports: List[str],
|
authorized_imports: List[str],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
|
if not (
|
||||||
|
isinstance(call.func, ast.Attribute)
|
||||||
|
or isinstance(call.func, ast.Name)
|
||||||
|
or isinstance(call.func, ast.Subscript)
|
||||||
|
):
|
||||||
raise InterpreterError(f"This is not a correct function: {call.func}).")
|
raise InterpreterError(f"This is not a correct function: {call.func}).")
|
||||||
if isinstance(call.func, ast.Attribute):
|
if isinstance(call.func, ast.Attribute):
|
||||||
obj = evaluate_ast(
|
obj = evaluate_ast(
|
||||||
|
@ -617,6 +621,23 @@ def evaluate_call(
|
||||||
f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
|
f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif isinstance(call.func, ast.Subscript):
|
||||||
|
value = evaluate_ast(
|
||||||
|
call.func.value, state, static_tools, custom_tools, authorized_imports
|
||||||
|
)
|
||||||
|
index = evaluate_ast(
|
||||||
|
call.func.slice, state, static_tools, custom_tools, authorized_imports
|
||||||
|
)
|
||||||
|
if isinstance(value, (list, tuple)):
|
||||||
|
func = value[index]
|
||||||
|
else:
|
||||||
|
raise InterpreterError(
|
||||||
|
f"Cannot subscript object of type {type(value).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not callable(func):
|
||||||
|
raise InterpreterError(f"This is not a correct function: {call.func}).")
|
||||||
|
func_name = None
|
||||||
args = []
|
args = []
|
||||||
for arg in call.args:
|
for arg in call.args:
|
||||||
if isinstance(arg, ast.Starred):
|
if isinstance(arg, ast.Starred):
|
||||||
|
@ -726,6 +747,8 @@ def evaluate_name(
|
||||||
return state[name.id]
|
return state[name.id]
|
||||||
elif name.id in static_tools:
|
elif name.id in static_tools:
|
||||||
return static_tools[name.id]
|
return static_tools[name.id]
|
||||||
|
elif name.id in custom_tools:
|
||||||
|
return custom_tools[name.id]
|
||||||
elif name.id in ERRORS:
|
elif name.id in ERRORS:
|
||||||
return ERRORS[name.id]
|
return ERRORS[name.id]
|
||||||
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
|
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
|
||||||
|
|
|
@ -60,6 +60,14 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||||
in str(e)
|
in str(e)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_subscript_call(self):
|
||||||
|
code = """def foo(x,y):return x*y\n\ndef boo(y):\n\treturn y**3\nfun = [foo, boo]\nresult_foo = fun[0](4,2)\nresult_boo = fun[1](4)"""
|
||||||
|
state = {}
|
||||||
|
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
||||||
|
assert result == 64
|
||||||
|
assert state["result_foo"] == 8
|
||||||
|
assert state["result_boo"] == 64
|
||||||
|
|
||||||
def test_evaluate_call(self):
|
def test_evaluate_call(self):
|
||||||
code = "y = add_two(x)"
|
code = "y = add_two(x)"
|
||||||
state = {"x": 3}
|
state = {"x": 3}
|
||||||
|
|
Loading…
Reference in New Issue