From d1b8a7878319cca019a0a8a862dc62f706d55a27 Mon Sep 17 00:00:00 2001 From: Ilya Gusev Date: Fri, 17 Jan 2025 19:16:47 +0100 Subject: [PATCH] Max length of "print" outputs as a parameter of an agent (#209) * Print outputs max length as a parameter --- src/smolagents/agents.py | 2 ++ src/smolagents/local_python_executor.py | 29 +++++++++++++------------ src/smolagents/utils.py | 4 ++-- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 20c67b4..da791aa 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -891,6 +891,7 @@ class CodeAgent(MultiStepAgent): additional_authorized_imports: Optional[List[str]] = None, planning_interval: Optional[int] = None, use_e2b_executor: bool = False, + max_print_outputs_length: Optional[int] = None, **kwargs, ): if system_prompt is None: @@ -937,6 +938,7 @@ class CodeAgent(MultiStepAgent): self.python_executor = LocalPythonInterpreter( self.additional_authorized_imports, all_tools, + max_print_outputs_length=max_print_outputs_length, ) def initialize_system_prompt(self): diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index c06bacf..0c7b5bc 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -46,7 +46,7 @@ ERRORS = { and issubclass(getattr(builtins, name), BaseException) } -PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000 +PRINT_OUTPUTS, DEFAULT_MAX_LEN_OUTPUT = "", 50000 OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000 @@ -1435,15 +1435,6 @@ def evaluate_ast( raise InterpreterError(f"{expression.__class__.__name__} is not supported.") -def truncate_print_outputs( - print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT -) -> str: - if len(print_outputs) < max_len_outputs: - return print_outputs - else: - return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n" - - class FinalAnswerException(Exception): def __init__(self, value): self.value = value @@ -1455,6 +1446,7 @@ def evaluate_python_code( custom_tools: Optional[Dict[str, Callable]] = None, state: Optional[Dict[str, Any]] = None, authorized_imports: List[str] = BASE_BUILTIN_MODULES, + max_print_outputs_length: int = DEFAULT_MAX_LEN_OUTPUT, ): """ Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set @@ -1500,26 +1492,34 @@ def evaluate_python_code( node, state, static_tools, custom_tools, authorized_imports ) state["print_outputs"] = truncate_content( - PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT + PRINT_OUTPUTS, max_length=max_print_outputs_length ) is_final_answer = False return result, is_final_answer except FinalAnswerException as e: state["print_outputs"] = truncate_content( - PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT + PRINT_OUTPUTS, max_length=max_print_outputs_length ) is_final_answer = True return e.value, is_final_answer except InterpreterError as e: - msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT) + msg = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length) msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" raise InterpreterError(msg) class LocalPythonInterpreter: - def __init__(self, additional_authorized_imports: List[str], tools: Dict): + def __init__( + self, + additional_authorized_imports: List[str], + tools: Dict, + max_print_outputs_length: Optional[int] = None, + ): self.custom_tools = {} self.state = {} + self.max_print_outputs_length = max_print_outputs_length + if max_print_outputs_length is None: + self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT self.additional_authorized_imports = additional_authorized_imports self.authorized_imports = list( set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports) @@ -1541,6 +1541,7 @@ class LocalPythonInterpreter: custom_tools=self.custom_tools, state=self.state, authorized_imports=self.authorized_imports, + max_print_outputs_length=self.max_print_outputs_length, ) logs = self.state["print_outputs"] return output, logs, is_final_answer diff --git a/src/smolagents/utils.py b/src/smolagents/utils.py index 9aeba8a..ac4565f 100644 --- a/src/smolagents/utils.py +++ b/src/smolagents/utils.py @@ -169,9 +169,9 @@ def truncate_content( return content else: return ( - content[: MAX_LENGTH_TRUNCATE_CONTENT // 2] + content[: max_length // 2] + f"\n..._This content has been truncated to stay below {max_length} characters_...\n" - + content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :] + + content[-max_length // 2 :] )