diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index e54e059..1b9b35d 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -591,7 +591,11 @@ def evaluate_call( custom_tools: Dict[str, Callable], authorized_imports: List[str], ) -> Any: - if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)): + if not ( + isinstance(call.func, ast.Attribute) + or isinstance(call.func, ast.Name) + or isinstance(call.func, ast.Subscript) + ): raise InterpreterError(f"This is not a correct function: {call.func}).") if isinstance(call.func, ast.Attribute): obj = evaluate_ast( @@ -617,6 +621,23 @@ def evaluate_call( f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})." ) + elif isinstance(call.func, ast.Subscript): + value = evaluate_ast( + call.func.value, state, static_tools, custom_tools, authorized_imports + ) + index = evaluate_ast( + call.func.slice, state, static_tools, custom_tools, authorized_imports + ) + if isinstance(value, (list, tuple)): + func = value[index] + else: + raise InterpreterError( + f"Cannot subscript object of type {type(value).__name__}" + ) + + if not callable(func): + raise InterpreterError(f"This is not a correct function: {call.func}).") + func_name = None args = [] for arg in call.args: if isinstance(arg, ast.Starred): @@ -726,6 +747,8 @@ def evaluate_name( return state[name.id] elif name.id in static_tools: return static_tools[name.id] + elif name.id in custom_tools: + return custom_tools[name.id] elif name.id in ERRORS: return ERRORS[name.id] close_matches = difflib.get_close_matches(name.id, list(state.keys())) diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py index 58f250c..6e5b3c2 100644 --- a/tests/test_python_interpreter.py +++ b/tests/test_python_interpreter.py @@ -60,6 +60,14 @@ class PythonInterpreterTester(unittest.TestCase): in str(e) ) + def test_subscript_call(self): + code = """def foo(x,y):return x*y\n\ndef boo(y):\n\treturn y**3\nfun = [foo, boo]\nresult_foo = fun[0](4,2)\nresult_boo = fun[1](4)""" + state = {} + result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) + assert result == 64 + assert state["result_foo"] == 8 + assert state["result_boo"] == 64 + def test_evaluate_call(self): code = "y = add_two(x)" state = {"x": 3}