Improve python executor's error logging (#275)

* Improve python executor's error logging
This commit is contained in:
Aymeric Roucher 2025-01-20 15:57:16 +01:00 committed by GitHub
parent 3c18d4d588
commit 7a91123729
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 86 additions and 37 deletions

View File

@ -972,16 +972,8 @@ class CodeAgent(MultiStepAgent):
]
observation += "Execution logs:\n" + execution_logs
except Exception as e:
if isinstance(e, SyntaxError):
error_msg = (
f"Code execution failed on line {e.lineno} due to: {type(e).__name__}\n"
f"{e.text}"
f"{' ' * (e.offset or 0)}^\n"
f"Error: {str(e)}"
)
else:
error_msg = str(e)
if "Import of " in str(e) and " is not allowed" in str(e):
error_msg = str(e)
if "Import of " in error_msg and " is not allowed" in error_msg:
self.logger.log(
"[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.",
level=LogLevel.INFO,

View File

@ -554,7 +554,7 @@ def evaluate_call(
func = ERRORS[func_name]
else:
raise InterpreterError(
f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
f"It is not permitted to evaluate other functions than the provided tools or functions defined/imported in previous code (tried to execute {call.func.id})."
)
elif isinstance(call.func, ast.Subscript):
@ -1245,7 +1245,16 @@ def evaluate_python_code(
updated by this function to contain all variables as they are evaluated.
The print outputs will be stored in the state under the key 'print_outputs'.
"""
expression = ast.parse(code)
try:
expression = ast.parse(code)
except SyntaxError as e:
raise InterpreterError(
f"Code execution failed on line {e.lineno} due to: {type(e).__name__}\n"
f"{e.text}"
f"{' ' * (e.offset or 0)}^\n"
f"Error: {str(e)}"
)
if state is None:
state = {}
if static_tools is None:
@ -1273,10 +1282,13 @@ def evaluate_python_code(
state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
is_final_answer = True
return e.value, is_final_answer
except InterpreterError as e:
msg = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
raise InterpreterError(msg)
except Exception as e:
exception_type = type(e).__name__
error_msg = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
error_msg = (
f"Code execution failed at line '{ast.get_source_segment(code, node)}' due to: {exception_type}:{str(e)}"
)
raise InterpreterError(error_msg)
class LocalPythonInterpreter:

View File

@ -168,7 +168,9 @@ class Tool:
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
)
json_schema = _convert_type_hints_to_json_schema(self.forward)
json_schema = _convert_type_hints_to_json_schema(
self.forward
) # This function will raise an error on missing docstrings, contrary to get_json_schema
for key, value in self.inputs.items():
if "nullable" in value:
assert key in json_schema and "nullable" in json_schema[key], (
@ -885,6 +887,16 @@ class ToolCollection:
yield cls(tools)
def get_tool_json_schema(tool_function):
tool_json_schema = get_json_schema(tool_function)["function"]
tool_parameters = tool_json_schema["parameters"]
inputs_schema = tool_parameters["properties"]
for input_name in inputs_schema:
if "required" not in tool_parameters or input_name not in tool_parameters["required"]:
inputs_schema[input_name]["nullable"] = True
return tool_json_schema
def tool(tool_function: Callable) -> Tool:
"""
Converts a function into an instance of a Tool subclass.
@ -893,12 +905,19 @@ def tool(tool_function: Callable) -> Tool:
tool_function: Your function. Should have type hints for each input and a type hint for the output.
Should also have a docstring description including an 'Args:' part where each argument is described.
"""
parameters = get_json_schema(tool_function)["function"]
if "return" not in parameters:
tool_json_schema = get_tool_json_schema(tool_function)
if "return" not in tool_json_schema:
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
class SimpleTool(Tool):
def __init__(self, name, description, inputs, output_type, function):
def __init__(
self,
name: str,
description: str,
inputs: Dict[str, Dict[str, str]],
output_type: str,
function: Callable,
):
self.name = name
self.description = description
self.inputs = inputs
@ -907,10 +926,10 @@ def tool(tool_function: Callable) -> Tool:
self.is_initialized = True
simple_tool = SimpleTool(
parameters["name"],
parameters["description"],
parameters["parameters"]["properties"],
parameters["return"]["type"],
name=tool_json_schema["name"],
description=tool_json_schema["description"],
inputs=tool_json_schema["parameters"]["properties"],
output_type=tool_json_schema["return"]["type"],
function=tool_function,
)
original_signature = inspect.signature(tool_function)

View File

@ -332,7 +332,7 @@ class AgentTests(unittest.TestCase):
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText)
assert output == "got an error"
assert "Code execution failed at line 'print = 2' because of" in str(agent.logs)
assert "Code execution failed at line 'print = 2' due to: InterpreterError" in str(agent.logs)
def test_code_agent_syntax_error_show_offending_lines(self):
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error)
@ -426,7 +426,7 @@ class AgentTests(unittest.TestCase):
with console.capture() as capture:
agent.run("Count to 3")
str_output = capture.get()
assert "import under `additional_authorized_imports`" in str_output
assert "Consider passing said import under" in str_output.replace("\n", "")
def test_multiagents(self):
class FakeModelMultiagentsManagerAgent:

View File

@ -630,12 +630,9 @@ counts += 1"""
assert "Cannot add non-list value 1 to a list." in str(e)
def test_error_highlights_correct_line_of_code(self):
code = """# Ok this is a very long code
# It has many commented lines
a = 1
code = """a = 1
b = 2
# Here is another piece
counts = [1, 2, 3]
counts += 1
b += 1"""
@ -643,12 +640,22 @@ b += 1"""
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert "Code execution failed at line 'counts += 1" in str(e)
def test_error_type_returned_in_function_call(self):
code = """def error_function():
raise ValueError("error")
error_function()"""
with pytest.raises(InterpreterError) as e:
evaluate_python_code(code)
assert "error" in str(e)
assert "ValueError" in str(e)
def test_assert(self):
code = """
assert 1 == 1
assert 1 == 2
"""
with pytest.raises(AssertionError) as e:
with pytest.raises(InterpreterError) as e:
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert "1 == 2" in str(e) and "1 == 1" not in str(e)
@ -845,6 +852,13 @@ shift_intervals
result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={})
assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}
def test_syntax_error_points_error(self):
code = "a = ;"
with pytest.raises(InterpreterError) as e:
evaluate_python_code(code)
assert "SyntaxError" in str(e)
assert " ^" in str(e)
def test_fix_final_answer_code(self):
test_cases = [
(
@ -890,18 +904,16 @@ shift_intervals
# Import of whitelisted modules should succeed but dangerous submodules should not exist
code = "import random;random._os.system('echo bad command passed')"
with pytest.raises(AttributeError) as e:
with pytest.raises(InterpreterError) as e:
evaluate_python_code(code)
assert "module 'random' has no attribute '_os'" in str(e)
assert "AttributeError:module 'random' has no attribute '_os'" in str(e)
code = "import doctest;doctest.inspect.os.system('echo bad command passed')"
with pytest.raises(AttributeError):
with pytest.raises(InterpreterError):
evaluate_python_code(code, authorized_imports=["doctest"])
def test_close_matches_subscript(self):
code = 'capitals = {"Czech Republic": "Prague", "Monaco": "Monaco", "Bhutan": "Thimphu"};capitals["Butan"]'
with pytest.raises(Exception) as e:
evaluate_python_code(code)
assert "Maybe you meant one of these indexes instead" in str(
e
) and "['Bhutan']" in str(e).replace("\\", "")
assert "Maybe you meant one of these indexes instead" in str(e) and "['Bhutan']" in str(e).replace("\\", "")

View File

@ -374,6 +374,20 @@ class ToolTests(unittest.TestCase):
GetWeatherTool3()
assert "Nullable" in str(e)
def test_tool_default_parameters_is_nullable(self):
@tool
def get_weather(location: str, celsius: bool = False) -> str:
"""
Get weather in the next days at given location.
Args:
location: the location
celsius: is the temperature given in celsius
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
assert get_weather.inputs["celsius"]["nullable"]
@pytest.fixture
def mock_server_parameters():