Improve code execution error logging

This commit is contained in:
Aymeric 2024-12-25 23:28:57 +01:00
parent 8005d6f21d
commit c4f38850b2
8 changed files with 107 additions and 90 deletions

View File

@ -61,7 +61,7 @@ Once you have these two arguments, `tools` and `model`, you can create an agent
```python
from smolagents import CodeAgent, HfApiModel
model = HfApiModel(model=model_id)
model = HfApiModel(model_id=model_id)
agent = CodeAgent(tools=[], model=model, add_base_tools=True)
agent.run(
@ -74,9 +74,7 @@ Note that we used an additional `additional_detail` argument: you can additional
You can use this to indicate the path to local or remote files for the model to use:
```py
from smolagents import CodeAgent, Tool, SpeechToTextTool
agent = CodeAgent(tools=[SpeechToTextTool()], model=model, add_base_tools=True)
agent = CodeAgent(tools=[], model=model, add_base_tools=True)
agent.run("Why does Mike not know many people in New York?", audio="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3")
```
@ -123,35 +121,39 @@ print(agent.system_prompt_template)
```
Here is what you get:
```text
You will be given a task to solve as best you can.
You have access to the following tools:
{{tool_descriptions}}
You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.
To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code.
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task, then the tools that you want to use.
Then in the 'Code:' sequence, you should write the code in simple Python. The code sequence must end with '/End code' sequence.
At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use.
Then in the 'Code:' sequence, you should write the code in simple Python. The code sequence must end with '<end_code>' sequence.
During each intermediate step, you can use 'print()' to save whatever important information you will then need.
These print outputs will then be available in the 'Observation:' field, for using this information as input for the next step.
These print outputs will then appear in the 'Observation:' field, which will be available as input for the next step.
In the end you have to return a final answer using the `final_answer` tool.
Here are a few examples using notional tools:
---
{examples}
Above example were using notional tools that might not exist for you. You only have acces to those tools:
{{tool_names}}
You also can perform computations in the python code you generate.
Above example were using notional tools that might not exist for you. On top of performing computations in the Python code snippets that you create, you only have access to these tools:
Always provide a 'Thought:' and a 'Code:\n```py' sequence ending with '```<end_code>' sequence. You MUST provide at least the 'Code:' sequence to move forward.
{{tool_descriptions}}
Remember to not perform too many operations in a single code block! You should split the task into intermediate code blocks.
Print results at the end of each step to save the intermediate results. Then use final_answer() to return the final result.
{{managed_agents_descriptions}}
Remember to make sure that variables you use are all defined.
Here are the rules you should always follow to solve your task:
1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```<end_code>' sequence, else you will fail.
2. Use only variables that you have defined!
3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = wiki({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = wiki(query="What is the place where James Bond lives?")'.
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
8. You can use imports in your code, but only from the following list of modules: {{authorized_imports}}
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
Now Begin!
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
```
The system prompt includes:

View File

@ -51,6 +51,9 @@ We provide two types of agents, based on the main [`Agent`] class.
[[autodoc]] stream_to_gradio
### GradioUI
[[autodoc]] GradioUI
## Models
@ -61,19 +64,17 @@ These engines have the following specification:
### TransformersModel
For convenience, we have added a `TransformersModel` that implements the points above, taking a pre-initialized `Pipeline` as input.
For convenience, we have added a `TransformersModel` that implements the points above by building a local `transformers` pipeline for the model_id given at initialization.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TransformersModel
from smolagents import TransformersModel
model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = TransformersModel(model_id="HuggingFaceTB/SmolLM-135M-Instruct")
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
engine = TransformersModel(pipe)
engine([{"role": "user", "content": "Ok!"}], stop_sequences=["great"])
print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
```
```text
>>> What a
```
[[autodoc]] TransformersModel
@ -83,7 +84,7 @@ engine([{"role": "user", "content": "Ok!"}], stop_sequences=["great"])
The `HfApiModel` is an engine that wraps an [HF Inference API](https://huggingface.co/docs/api-inference/index) client for the execution of the LLM.
```python
from transformers import HfApiModel
from smolagents import HfApiModel
messages = [
{"role": "user", "content": "Hello, how are you?"},
@ -91,7 +92,10 @@ messages = [
{"role": "user", "content": "No need to help, take it easy."},
]
HfApiModel()(messages)
model = HfApiModel()
print(model(messages))
```
```text
>>> Of course! If you change your mind, feel free to reach out. Take care!
```
[[autodoc]] HfApiModel

View File

@ -34,7 +34,7 @@ from .utils import (
AgentGenerationError,
AgentMaxIterationsError,
)
from .types import AgentAudio, AgentImage
from .types import AgentAudio, AgentImage, handle_agent_output_types
from .default_tools import FinalAnswerTool
from .models import MessageRole
from .monitoring import Monitor
@ -345,7 +345,7 @@ class MultiStepAgent:
) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
except Exception:
raise AgentParsingError(
f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
f"No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
)
return rationale.strip(), action.strip()
@ -384,7 +384,7 @@ class MultiStepAgent:
"""
available_tools = {**self.toolbox.tools, **self.managed_agents}
if tool_name not in available_tools:
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
@ -546,7 +546,7 @@ class MultiStepAgent:
callback(final_step_log)
yield final_step_log
yield final_answer
yield handle_agent_output_types(final_answer)
def direct_run(self, task: str):
"""
@ -593,7 +593,7 @@ class MultiStepAgent:
for callback in self.step_callbacks:
callback(final_step_log)
return final_answer
return handle_agent_output_types(final_answer)
def planning_step(self, task, is_first_step: bool, iteration: int):
"""
@ -934,7 +934,7 @@ class CodeAgent(MultiStepAgent):
title_align="left",
)
)
observation = ""
try:
output, execution_logs = self.python_executor(
code_action,
@ -945,12 +945,17 @@ class CodeAgent(MultiStepAgent):
Text("Execution logs:", style="bold"),
Text(execution_logs),
]
observation = "Execution logs:\n" + execution_logs
observation += "Execution logs:\n" + execution_logs
except Exception as e:
console.print_exception()
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
if "'dict' object has no attribute 'read'" in str(e):
error_msg += "\nYou get this error because you passed a dict as input for one of the arguments instead of a string."
if isinstance(e, SyntaxError):
error_msg = (
f"Code execution failed on line {e.lineno} due to: {type(e).__name__}\n"
f"{e.text}"
f"{' ' * (e.offset or 0)}^\n"
f"Error: {str(e)}"
)
else:
error_msg = f"Code execution failed: {str(e)}"
raise AgentExecutionError(error_msg)
truncated_output = truncate_content(str(output))

View File

@ -1016,10 +1016,7 @@ def evaluate_python_code(
updated by this function to contain all variables as they are evaluated.
The print outputs will be stored in the state under the key 'print_outputs'.
"""
try:
expression = ast.parse(code)
except SyntaxError as e:
raise SyntaxError(f"The code generated by the agent is not valid.\n{e}")
if state is None:
state = {}
if static_tools is None:
@ -1042,7 +1039,7 @@ def evaluate_python_code(
return result
except InterpreterError as e:
msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
raise InterpreterError(msg)

View File

@ -297,19 +297,23 @@ class TransformersModel(HfModel):
self.model = AutoModelForCausalLM.from_pretrained(default_model_id)
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
class StopOnTokens(StoppingCriteria):
def __init__(self, stop_token_ids):
self.stop_token_ids = stop_token_ids
class StopOnStrings(StoppingCriteria):
def __init__(self, stop_strings: List[str], tokenizer):
self.stop_strings = stop_strings
self.tokenizer = tokenizer
self.stream = ""
def __call__(self, input_ids, scores):
for stop_ids in self.stop_token_ids:
if input_ids[0][-len(stop_ids) :].tolist() == stop_ids:
def reset(self):
self.stream = ""
def __call__(self, input_ids, scores, **kwargs):
generated = self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
self.stream += generated
if any([self.stream.endswith(stop_string) for stop_string in self.stop_strings]):
return True
return False
stop_token_ids = [self.tokenizer.encode("Observation:")[1:]] # Remove BOS token
stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_token_ids)])
return stopping_criteria
return StoppingCriteriaList([StopOnStrings(stop_sequences, self.tokenizer)])
def generate(
self,

View File

@ -11,26 +11,6 @@ _BUILTIN_NAMES = set(vars(builtins))
IMPORTED_PACKAGES = BASE_BUILTIN_MODULES
def is_installed_package(module_name: str) -> bool:
"""
Check if an import is from an installed package.
Returns False if it's not found or a local file import.
"""
try:
spec = importlib.util.find_spec(module_name)
if spec is None:
return False # If we can't find the module, assume it's local
# If the module is found and has a file path, check if it's in site-packages
if spec.origin and "site-packages" not in spec.origin:
# Check if it's a .py file in the current directory or subdirectories
return not spec.origin.endswith(".py")
return False
except ImportError:
return False # If there's an import error, assume it's local
class MethodChecker(ast.NodeVisitor):
"""
Checks that a method
@ -59,20 +39,12 @@ class MethodChecker(ast.NodeVisitor):
def visit_Import(self, node):
for name in node.names:
actual_name = name.asname or name.name
if not is_installed_package(actual_name) and self.check_imports:
self.errors.append(
f"Package not found in importlib, might be a local install: '{actual_name}'"
)
self.imports[actual_name] = name.name
def visit_ImportFrom(self, node):
module = node.module or ""
for name in node.names:
actual_name = name.asname or name.name
if not is_installed_package(module) and self.check_imports:
self.errors.append(
f"Package not found in importlib, might be a local install: '{module}'"
)
self.from_imports[actual_name] = (module, name.name)
def visit_Assign(self, node):

View File

@ -20,7 +20,7 @@ import pytest
from pathlib import Path
from smolagents.types import AgentText
from smolagents.types import AgentText, AgentImage
from smolagents.agents import (
AgentMaxIterationsError,
ManagedAgent,
@ -91,7 +91,32 @@ def fake_code_model_error(messages, stop_sequences=None) -> str:
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
a = 2
b = a * 2
print = 2
print("Ok, calculation done!")
```<end_code>
"""
else: # We're at step 2
return """
Thought: I can now answer the initial question
Code:
```py
final_answer("got an error")
```<end_code>
"""
def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
a = 2
b = a * 2
print("Failing due to unexpected indent")
print("Ok, calculation done!")
```<end_code>
"""
else: # We're at step 2
@ -187,7 +212,7 @@ class AgentTests(unittest.TestCase):
tools=[fake_image_generation_tool], model=FakeToolCallModelImage()
)
output = agent.run("Make me an image.")
assert isinstance(output, Image.Image)
assert isinstance(output, AgentImage)
assert isinstance(agent.state["image.png"], Image.Image)
def test_fake_code_agent(self):
@ -227,7 +252,15 @@ class AgentTests(unittest.TestCase):
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText)
assert output == "got an error"
assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)
assert "Code execution failed at line 'print = 2' because of" in str(agent.logs)
def test_code_agent_syntax_error_show_offending_lines(self):
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText)
assert output == "got an error"
assert " print(\"Failing due to unexpected indent\")" in str(agent.logs)
def test_setup_agent_with_empty_toolbox(self):
ToolCallingAgent(model=FakeToolCallModel(), tools=[])

View File

@ -702,7 +702,7 @@ counts += 1
b += 1"""
with pytest.raises(InterpreterError) as e:
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert "Evaluation stopped at line 'counts += 1" in str(e)
assert "Code execution failed at line 'counts += 1" in str(e)
def test_assert(self):
code = """