Max length of "print" outputs as a parameter of an agent (#209)
* Print outputs max length as a parameter
This commit is contained in:
parent
116b12e93a
commit
d1b8a78783
|
@ -891,6 +891,7 @@ class CodeAgent(MultiStepAgent):
|
|||
additional_authorized_imports: Optional[List[str]] = None,
|
||||
planning_interval: Optional[int] = None,
|
||||
use_e2b_executor: bool = False,
|
||||
max_print_outputs_length: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if system_prompt is None:
|
||||
|
@ -937,6 +938,7 @@ class CodeAgent(MultiStepAgent):
|
|||
self.python_executor = LocalPythonInterpreter(
|
||||
self.additional_authorized_imports,
|
||||
all_tools,
|
||||
max_print_outputs_length=max_print_outputs_length,
|
||||
)
|
||||
|
||||
def initialize_system_prompt(self):
|
||||
|
|
|
@ -46,7 +46,7 @@ ERRORS = {
|
|||
and issubclass(getattr(builtins, name), BaseException)
|
||||
}
|
||||
|
||||
PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
|
||||
PRINT_OUTPUTS, DEFAULT_MAX_LEN_OUTPUT = "", 50000
|
||||
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
|
||||
|
||||
|
||||
|
@ -1435,15 +1435,6 @@ def evaluate_ast(
|
|||
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
||||
|
||||
|
||||
def truncate_print_outputs(
|
||||
print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT
|
||||
) -> str:
|
||||
if len(print_outputs) < max_len_outputs:
|
||||
return print_outputs
|
||||
else:
|
||||
return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n"
|
||||
|
||||
|
||||
class FinalAnswerException(Exception):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
@ -1455,6 +1446,7 @@ def evaluate_python_code(
|
|||
custom_tools: Optional[Dict[str, Callable]] = None,
|
||||
state: Optional[Dict[str, Any]] = None,
|
||||
authorized_imports: List[str] = BASE_BUILTIN_MODULES,
|
||||
max_print_outputs_length: int = DEFAULT_MAX_LEN_OUTPUT,
|
||||
):
|
||||
"""
|
||||
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
|
||||
|
@ -1500,26 +1492,34 @@ def evaluate_python_code(
|
|||
node, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
state["print_outputs"] = truncate_content(
|
||||
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
|
||||
PRINT_OUTPUTS, max_length=max_print_outputs_length
|
||||
)
|
||||
is_final_answer = False
|
||||
return result, is_final_answer
|
||||
except FinalAnswerException as e:
|
||||
state["print_outputs"] = truncate_content(
|
||||
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
|
||||
PRINT_OUTPUTS, max_length=max_print_outputs_length
|
||||
)
|
||||
is_final_answer = True
|
||||
return e.value, is_final_answer
|
||||
except InterpreterError as e:
|
||||
msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
|
||||
msg = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
|
||||
msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
|
||||
raise InterpreterError(msg)
|
||||
|
||||
|
||||
class LocalPythonInterpreter:
|
||||
def __init__(self, additional_authorized_imports: List[str], tools: Dict):
|
||||
def __init__(
|
||||
self,
|
||||
additional_authorized_imports: List[str],
|
||||
tools: Dict,
|
||||
max_print_outputs_length: Optional[int] = None,
|
||||
):
|
||||
self.custom_tools = {}
|
||||
self.state = {}
|
||||
self.max_print_outputs_length = max_print_outputs_length
|
||||
if max_print_outputs_length is None:
|
||||
self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT
|
||||
self.additional_authorized_imports = additional_authorized_imports
|
||||
self.authorized_imports = list(
|
||||
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
|
||||
|
@ -1541,6 +1541,7 @@ class LocalPythonInterpreter:
|
|||
custom_tools=self.custom_tools,
|
||||
state=self.state,
|
||||
authorized_imports=self.authorized_imports,
|
||||
max_print_outputs_length=self.max_print_outputs_length,
|
||||
)
|
||||
logs = self.state["print_outputs"]
|
||||
return output, logs, is_final_answer
|
||||
|
|
|
@ -169,9 +169,9 @@ def truncate_content(
|
|||
return content
|
||||
else:
|
||||
return (
|
||||
content[: MAX_LENGTH_TRUNCATE_CONTENT // 2]
|
||||
content[: max_length // 2]
|
||||
+ f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
|
||||
+ content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :]
|
||||
+ content[-max_length // 2 :]
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue