call.func parameter (#194)

This commit is contained in:
NeverLucky 2025-01-15 15:58:52 +03:00 committed by GitHub
parent 450934ce79
commit a22c221fa7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 1 deletions

View File

@ -591,7 +591,11 @@ def evaluate_call(
custom_tools: Dict[str, Callable], custom_tools: Dict[str, Callable],
authorized_imports: List[str], authorized_imports: List[str],
) -> Any: ) -> Any:
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)): if not (
isinstance(call.func, ast.Attribute)
or isinstance(call.func, ast.Name)
or isinstance(call.func, ast.Subscript)
):
raise InterpreterError(f"This is not a correct function: {call.func}).") raise InterpreterError(f"This is not a correct function: {call.func}).")
if isinstance(call.func, ast.Attribute): if isinstance(call.func, ast.Attribute):
obj = evaluate_ast( obj = evaluate_ast(
@ -617,6 +621,23 @@ def evaluate_call(
f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})." f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
) )
elif isinstance(call.func, ast.Subscript):
value = evaluate_ast(
call.func.value, state, static_tools, custom_tools, authorized_imports
)
index = evaluate_ast(
call.func.slice, state, static_tools, custom_tools, authorized_imports
)
if isinstance(value, (list, tuple)):
func = value[index]
else:
raise InterpreterError(
f"Cannot subscript object of type {type(value).__name__}"
)
if not callable(func):
raise InterpreterError(f"This is not a correct function: {call.func}).")
func_name = None
args = [] args = []
for arg in call.args: for arg in call.args:
if isinstance(arg, ast.Starred): if isinstance(arg, ast.Starred):
@ -726,6 +747,8 @@ def evaluate_name(
return state[name.id] return state[name.id]
elif name.id in static_tools: elif name.id in static_tools:
return static_tools[name.id] return static_tools[name.id]
elif name.id in custom_tools:
return custom_tools[name.id]
elif name.id in ERRORS: elif name.id in ERRORS:
return ERRORS[name.id] return ERRORS[name.id]
close_matches = difflib.get_close_matches(name.id, list(state.keys())) close_matches = difflib.get_close_matches(name.id, list(state.keys()))

View File

@ -60,6 +60,14 @@ class PythonInterpreterTester(unittest.TestCase):
in str(e) in str(e)
) )
def test_subscript_call(self):
code = """def foo(x,y):return x*y\n\ndef boo(y):\n\treturn y**3\nfun = [foo, boo]\nresult_foo = fun[0](4,2)\nresult_boo = fun[1](4)"""
state = {}
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
assert result == 64
assert state["result_foo"] == 8
assert state["result_boo"] == 64
def test_evaluate_call(self): def test_evaluate_call(self):
code = "y = add_two(x)" code = "y = add_two(x)"
state = {"x": 3} state = {"x": 3}