Max length of "print" outputs as a parameter of an agent (#209)
* Print outputs max length as a parameter
This commit is contained in:
parent
116b12e93a
commit
d1b8a78783
|
@ -891,6 +891,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
additional_authorized_imports: Optional[List[str]] = None,
|
additional_authorized_imports: Optional[List[str]] = None,
|
||||||
planning_interval: Optional[int] = None,
|
planning_interval: Optional[int] = None,
|
||||||
use_e2b_executor: bool = False,
|
use_e2b_executor: bool = False,
|
||||||
|
max_print_outputs_length: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if system_prompt is None:
|
if system_prompt is None:
|
||||||
|
@ -937,6 +938,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
self.python_executor = LocalPythonInterpreter(
|
self.python_executor = LocalPythonInterpreter(
|
||||||
self.additional_authorized_imports,
|
self.additional_authorized_imports,
|
||||||
all_tools,
|
all_tools,
|
||||||
|
max_print_outputs_length=max_print_outputs_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
def initialize_system_prompt(self):
|
def initialize_system_prompt(self):
|
||||||
|
|
|
@ -46,7 +46,7 @@ ERRORS = {
|
||||||
and issubclass(getattr(builtins, name), BaseException)
|
and issubclass(getattr(builtins, name), BaseException)
|
||||||
}
|
}
|
||||||
|
|
||||||
PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
|
PRINT_OUTPUTS, DEFAULT_MAX_LEN_OUTPUT = "", 50000
|
||||||
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
|
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
|
||||||
|
|
||||||
|
|
||||||
|
@ -1435,15 +1435,6 @@ def evaluate_ast(
|
||||||
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
||||||
|
|
||||||
|
|
||||||
def truncate_print_outputs(
|
|
||||||
print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT
|
|
||||||
) -> str:
|
|
||||||
if len(print_outputs) < max_len_outputs:
|
|
||||||
return print_outputs
|
|
||||||
else:
|
|
||||||
return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n"
|
|
||||||
|
|
||||||
|
|
||||||
class FinalAnswerException(Exception):
|
class FinalAnswerException(Exception):
|
||||||
def __init__(self, value):
|
def __init__(self, value):
|
||||||
self.value = value
|
self.value = value
|
||||||
|
@ -1455,6 +1446,7 @@ def evaluate_python_code(
|
||||||
custom_tools: Optional[Dict[str, Callable]] = None,
|
custom_tools: Optional[Dict[str, Callable]] = None,
|
||||||
state: Optional[Dict[str, Any]] = None,
|
state: Optional[Dict[str, Any]] = None,
|
||||||
authorized_imports: List[str] = BASE_BUILTIN_MODULES,
|
authorized_imports: List[str] = BASE_BUILTIN_MODULES,
|
||||||
|
max_print_outputs_length: int = DEFAULT_MAX_LEN_OUTPUT,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
|
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
|
||||||
|
@ -1500,26 +1492,34 @@ def evaluate_python_code(
|
||||||
node, state, static_tools, custom_tools, authorized_imports
|
node, state, static_tools, custom_tools, authorized_imports
|
||||||
)
|
)
|
||||||
state["print_outputs"] = truncate_content(
|
state["print_outputs"] = truncate_content(
|
||||||
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
|
PRINT_OUTPUTS, max_length=max_print_outputs_length
|
||||||
)
|
)
|
||||||
is_final_answer = False
|
is_final_answer = False
|
||||||
return result, is_final_answer
|
return result, is_final_answer
|
||||||
except FinalAnswerException as e:
|
except FinalAnswerException as e:
|
||||||
state["print_outputs"] = truncate_content(
|
state["print_outputs"] = truncate_content(
|
||||||
PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
|
PRINT_OUTPUTS, max_length=max_print_outputs_length
|
||||||
)
|
)
|
||||||
is_final_answer = True
|
is_final_answer = True
|
||||||
return e.value, is_final_answer
|
return e.value, is_final_answer
|
||||||
except InterpreterError as e:
|
except InterpreterError as e:
|
||||||
msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
|
msg = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
|
||||||
msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
|
msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
|
||||||
raise InterpreterError(msg)
|
raise InterpreterError(msg)
|
||||||
|
|
||||||
|
|
||||||
class LocalPythonInterpreter:
|
class LocalPythonInterpreter:
|
||||||
def __init__(self, additional_authorized_imports: List[str], tools: Dict):
|
def __init__(
|
||||||
|
self,
|
||||||
|
additional_authorized_imports: List[str],
|
||||||
|
tools: Dict,
|
||||||
|
max_print_outputs_length: Optional[int] = None,
|
||||||
|
):
|
||||||
self.custom_tools = {}
|
self.custom_tools = {}
|
||||||
self.state = {}
|
self.state = {}
|
||||||
|
self.max_print_outputs_length = max_print_outputs_length
|
||||||
|
if max_print_outputs_length is None:
|
||||||
|
self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT
|
||||||
self.additional_authorized_imports = additional_authorized_imports
|
self.additional_authorized_imports = additional_authorized_imports
|
||||||
self.authorized_imports = list(
|
self.authorized_imports = list(
|
||||||
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
|
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
|
||||||
|
@ -1541,6 +1541,7 @@ class LocalPythonInterpreter:
|
||||||
custom_tools=self.custom_tools,
|
custom_tools=self.custom_tools,
|
||||||
state=self.state,
|
state=self.state,
|
||||||
authorized_imports=self.authorized_imports,
|
authorized_imports=self.authorized_imports,
|
||||||
|
max_print_outputs_length=self.max_print_outputs_length,
|
||||||
)
|
)
|
||||||
logs = self.state["print_outputs"]
|
logs = self.state["print_outputs"]
|
||||||
return output, logs, is_final_answer
|
return output, logs, is_final_answer
|
||||||
|
|
|
@ -169,9 +169,9 @@ def truncate_content(
|
||||||
return content
|
return content
|
||||||
else:
|
else:
|
||||||
return (
|
return (
|
||||||
content[: MAX_LENGTH_TRUNCATE_CONTENT // 2]
|
content[: max_length // 2]
|
||||||
+ f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
|
+ f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
|
||||||
+ content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :]
|
+ content[-max_length // 2 :]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue