From 33b38e6cb7a619d229ec5ef73aaa3caa896e021d Mon Sep 17 00:00:00 2001 From: Parteek Date: Tue, 28 Jan 2025 15:11:32 +0530 Subject: [PATCH] fix final_answer issue in e2b_executor (#319) --- src/smolagents/e2b_executor.py | 9 +++++++-- src/smolagents/tool_validation.py | 4 ++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/smolagents/e2b_executor.py b/src/smolagents/e2b_executor.py index 404a8e2..5e007b5 100644 --- a/src/smolagents/e2b_executor.py +++ b/src/smolagents/e2b_executor.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import re import base64 import pickle import textwrap @@ -45,6 +46,8 @@ class E2BExecutor: ) self.custom_tools = {} + self.final_answer = False + self.final_answer_pattern = re.compile(r'^final_answer\((.*)\)$') self.sbx = Sandbox() # "qywp2ctmu2q7jzprcf4j") # TODO: validate installing agents package or not # print("Installing agents package on remote executor...") @@ -85,6 +88,8 @@ class E2BExecutor: self.logger.log(tool_definition_execution.logs) def run_code_raise_errors(self, code: str): + if self.final_answer_pattern.match(code): + self.final_answer = True execution = self.sbx.run_code( code, ) @@ -122,7 +127,7 @@ locals().update({key: value for key, value in pickle_dict.items()}) execution = self.run_code_raise_errors(code_action) execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) if not execution.results: - return None, execution_logs + return None, execution_logs, self.final_answer else: for result in execution.results: if result.is_main_result: @@ -144,7 +149,7 @@ locals().update({key: value for key, value in pickle_dict.items()}) "text", ]: if getattr(result, attribute_name) is not None: - return getattr(result, attribute_name), execution_logs + return getattr(result, attribute_name), execution_logs, self.final_answer raise ValueError("No main result returned by executor!") diff --git a/src/smolagents/tool_validation.py b/src/smolagents/tool_validation.py index 9ac157c..d8e6daa 100644 --- a/src/smolagents/tool_validation.py +++ b/src/smolagents/tool_validation.py @@ -25,6 +25,9 @@ class MethodChecker(ast.NodeVisitor): self.class_attributes = class_attributes self.errors = [] self.check_imports = check_imports + self.typing_names = { + 'Any' + } def visit_arguments(self, node): """Collect function arguments""" @@ -97,6 +100,7 @@ class MethodChecker(ast.NodeVisitor): or node.id in self.imports or node.id in self.from_imports or node.id in self.assigned_names + or node.id in self.typing_names ): self.errors.append(f"Name '{node.id}' is undefined.")