Max length of "print" outputs as a parameter of an agent (#209)

* Print outputs max length as a parameter
This commit is contained in:
Ilya Gusev 2025-01-17 19:16:47 +01:00 committed by GitHub
parent 116b12e93a
commit d1b8a78783
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 19 additions and 16 deletions

View File

@ -891,6 +891,7 @@ class CodeAgent(MultiStepAgent):
additional_authorized_imports: Optional[List[str]] = None, additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None, planning_interval: Optional[int] = None,
use_e2b_executor: bool = False, use_e2b_executor: bool = False,
max_print_outputs_length: Optional[int] = None,
**kwargs, **kwargs,
): ):
if system_prompt is None: if system_prompt is None:
@ -937,6 +938,7 @@ class CodeAgent(MultiStepAgent):
self.python_executor = LocalPythonInterpreter( self.python_executor = LocalPythonInterpreter(
self.additional_authorized_imports, self.additional_authorized_imports,
all_tools, all_tools,
max_print_outputs_length=max_print_outputs_length,
) )
def initialize_system_prompt(self): def initialize_system_prompt(self):

View File

@ -46,7 +46,7 @@ ERRORS = {
and issubclass(getattr(builtins, name), BaseException) and issubclass(getattr(builtins, name), BaseException)
} }
PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000 PRINT_OUTPUTS, DEFAULT_MAX_LEN_OUTPUT = "", 50000
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000 OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
@ -1435,15 +1435,6 @@ def evaluate_ast(
raise InterpreterError(f"{expression.__class__.__name__} is not supported.") raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
def truncate_print_outputs(
print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT
) -> str:
if len(print_outputs) < max_len_outputs:
return print_outputs
else:
return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n"
class FinalAnswerException(Exception): class FinalAnswerException(Exception):
def __init__(self, value): def __init__(self, value):
self.value = value self.value = value
@ -1455,6 +1446,7 @@ def evaluate_python_code(
custom_tools: Optional[Dict[str, Callable]] = None, custom_tools: Optional[Dict[str, Callable]] = None,
state: Optional[Dict[str, Any]] = None, state: Optional[Dict[str, Any]] = None,
authorized_imports: List[str] = BASE_BUILTIN_MODULES, authorized_imports: List[str] = BASE_BUILTIN_MODULES,
max_print_outputs_length: int = DEFAULT_MAX_LEN_OUTPUT,
): ):
""" """
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
@ -1500,26 +1492,34 @@ def evaluate_python_code(
node, state, static_tools, custom_tools, authorized_imports node, state, static_tools, custom_tools, authorized_imports
) )
state["print_outputs"] = truncate_content( state["print_outputs"] = truncate_content(
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT PRINT_OUTPUTS, max_length=max_print_outputs_length
) )
is_final_answer = False is_final_answer = False
return result, is_final_answer return result, is_final_answer
except FinalAnswerException as e: except FinalAnswerException as e:
state["print_outputs"] = truncate_content( state["print_outputs"] = truncate_content(
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT PRINT_OUTPUTS, max_length=max_print_outputs_length
) )
is_final_answer = True is_final_answer = True
return e.value, is_final_answer return e.value, is_final_answer
except InterpreterError as e: except InterpreterError as e:
msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT) msg = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
raise InterpreterError(msg) raise InterpreterError(msg)
class LocalPythonInterpreter: class LocalPythonInterpreter:
def __init__(self, additional_authorized_imports: List[str], tools: Dict): def __init__(
self,
additional_authorized_imports: List[str],
tools: Dict,
max_print_outputs_length: Optional[int] = None,
):
self.custom_tools = {} self.custom_tools = {}
self.state = {} self.state = {}
self.max_print_outputs_length = max_print_outputs_length
if max_print_outputs_length is None:
self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT
self.additional_authorized_imports = additional_authorized_imports self.additional_authorized_imports = additional_authorized_imports
self.authorized_imports = list( self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports) set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
@ -1541,6 +1541,7 @@ class LocalPythonInterpreter:
custom_tools=self.custom_tools, custom_tools=self.custom_tools,
state=self.state, state=self.state,
authorized_imports=self.authorized_imports, authorized_imports=self.authorized_imports,
max_print_outputs_length=self.max_print_outputs_length,
) )
logs = self.state["print_outputs"] logs = self.state["print_outputs"]
return output, logs, is_final_answer return output, logs, is_final_answer

View File

@ -169,9 +169,9 @@ def truncate_content(
return content return content
else: else:
return ( return (
content[: MAX_LENGTH_TRUNCATE_CONTENT // 2] content[: max_length // 2]
+ f"\n..._This content has been truncated to stay below {max_length} characters_...\n" + f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
+ content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :] + content[-max_length // 2 :]
) )