call.func parameter (#194)
This commit is contained in:
parent
450934ce79
commit
a22c221fa7
|
@ -591,7 +591,11 @@ def evaluate_call(
|
|||
custom_tools: Dict[str, Callable],
|
||||
authorized_imports: List[str],
|
||||
) -> Any:
|
||||
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
|
||||
if not (
|
||||
isinstance(call.func, ast.Attribute)
|
||||
or isinstance(call.func, ast.Name)
|
||||
or isinstance(call.func, ast.Subscript)
|
||||
):
|
||||
raise InterpreterError(f"This is not a correct function: {call.func}).")
|
||||
if isinstance(call.func, ast.Attribute):
|
||||
obj = evaluate_ast(
|
||||
|
@ -617,6 +621,23 @@ def evaluate_call(
|
|||
f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
|
||||
)
|
||||
|
||||
elif isinstance(call.func, ast.Subscript):
|
||||
value = evaluate_ast(
|
||||
call.func.value, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
index = evaluate_ast(
|
||||
call.func.slice, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
if isinstance(value, (list, tuple)):
|
||||
func = value[index]
|
||||
else:
|
||||
raise InterpreterError(
|
||||
f"Cannot subscript object of type {type(value).__name__}"
|
||||
)
|
||||
|
||||
if not callable(func):
|
||||
raise InterpreterError(f"This is not a correct function: {call.func}).")
|
||||
func_name = None
|
||||
args = []
|
||||
for arg in call.args:
|
||||
if isinstance(arg, ast.Starred):
|
||||
|
@ -726,6 +747,8 @@ def evaluate_name(
|
|||
return state[name.id]
|
||||
elif name.id in static_tools:
|
||||
return static_tools[name.id]
|
||||
elif name.id in custom_tools:
|
||||
return custom_tools[name.id]
|
||||
elif name.id in ERRORS:
|
||||
return ERRORS[name.id]
|
||||
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
|
||||
|
|
|
@ -60,6 +60,14 @@ class PythonInterpreterTester(unittest.TestCase):
|
|||
in str(e)
|
||||
)
|
||||
|
||||
def test_subscript_call(self):
|
||||
code = """def foo(x,y):return x*y\n\ndef boo(y):\n\treturn y**3\nfun = [foo, boo]\nresult_foo = fun[0](4,2)\nresult_boo = fun[1](4)"""
|
||||
state = {}
|
||||
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
||||
assert result == 64
|
||||
assert state["result_foo"] == 8
|
||||
assert state["result_boo"] == 64
|
||||
|
||||
def test_evaluate_call(self):
|
||||
code = "y = add_two(x)"
|
||||
state = {"x": 3}
|
||||
|
|
Loading…
Reference in New Issue