Untangling Logging (#316)

* Deep refactor that disentangles the logging logic into two concerns: logging (informing the user via console outputs) and memory (storing what happened in the agent).

---------

Co-authored-by: benediktstroebl <50178209+benediktstroebl@users.noreply.github.com>
Co-authored-by: Aymeric <aymeric.roucher@gmail.com>
Clémentine Fourrier 2025-01-29 16:27:04 +01:00 committed by GitHub
parent 882b134f35
commit c6eb9526f1
15 changed files with 499 additions and 301 deletions


@@ -27,7 +27,7 @@ Initialization: the system prompt is stored in a `SystemPromptStep`, and the use
 While loop (ReAct loop):
-- Use `agent.write_inner_memory_from_logs()` to write the agent logs into a list of LLM-readable [chat messages](https://huggingface.co/docs/transformers/en/chat_templating).
+- Use `agent.write_memory_to_messages()` to write the agent logs into a list of LLM-readable [chat messages](https://huggingface.co/docs/transformers/en/chat_templating).
 - Send these messages to a `Model` object to get its completion. Parse the completion to get the action (a JSON blob for `ToolCallingAgent`, a code snippet for `CodeAgent`).
 - Execute the action and logs result into memory (an `ActionStep`).
 - At the end of each step, we run all callback functions defined in `agent.step_callbacks`.

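For orientation, here is a minimal sketch of the loop this hunk describes, using the renamed `write_memory_to_messages`. It is illustrative only: `parse_action` and `execute` are hypothetical stand-ins for the parsing and execution logic.

```python
# Sketch of the ReAct loop described above, not the actual smolagents implementation.
def react_loop(agent, task, max_steps=6):
    agent.memory.steps.append(TaskStep(task=task))        # the task is stored in memory
    for step_number in range(1, max_steps + 1):
        messages = agent.write_memory_to_messages()       # memory -> LLM-readable chat messages
        chat_message = agent.model(messages)              # get the model completion
        action = parse_action(chat_message.content)       # JSON blob or code snippet (hypothetical)
        observations, final_answer = execute(action)      # run the action (hypothetical)
        agent.memory.steps.append(                        # log the result into memory
            ActionStep(
                step_number=step_number,
                model_output=chat_message.content,
                observations=observations,
            )
        )
        for callback in agent.step_callbacks:             # run all step callbacks
            callback(agent.memory.steps[-1])
        if final_answer is not None:
            return final_answer
```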

@@ -189,7 +189,7 @@ agent.run("Could you get me the title of the page at url 'https://huggingface.co
 Here are a few useful attributes to inspect what happened after a run:
 - `agent.logs` stores the fine-grained logs of the agent. At every step of the agent's run, everything gets stored in a dictionary that then is appended to `agent.logs`.
-- Running `agent.write_inner_memory_from_logs()` creates an inner memory of the agent's logs for the LLM to view, as a list of chat messages. This method goes over each step of the log and only stores what it's interested in as a message: for instance, it will save the system prompt and task in separate messages, then for each step it will store the LLM output as a message, and the tool call output as another message. Use this if you want a higher-level view of what has happened - but not every log will be transcripted by this method.
+- Running `agent.write_memory_to_messages()` writes the agent's memory as list of chat messages for the Model to view. This method goes over each step of the log and only stores what it's interested in as a message: for instance, it will save the system prompt and task in separate messages, then for each step it will store the LLM output as a message, and the tool call output as another message. Use this if you want a higher-level view of what has happened - but not every log will be transcripted by this method.
 ## Tools

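A short usage sketch of the two inspection paths this hunk documents (assumes a configured `model`; everything else follows the renamed API):

```python
from smolagents import CodeAgent

agent = CodeAgent(tools=[], model=model, add_base_tools=True)  # `model` assumed defined
agent.run("What is 2 multiplied by 3.6452?")

# Fine-grained view: one memory step object per agent step
for step in agent.memory.steps:
    print(type(step).__name__)

# Higher-level view: the same memory rendered as LLM-readable chat messages
for message in agent.write_memory_to_messages():
    print(message["role"], str(message["content"])[:80])
```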

@@ -142,7 +142,7 @@ agent.run("Could you get me the title of the page at url 'https://huggingface.co
 Here are a few useful attributes to inspect what happened after a run:
 - `agent.logs` stores the agent's fine-grained logs. At every step of the agent's run, everything is stored in a dictionary that is then appended to `agent.logs`.
-- Running `agent.write_inner_memory_from_logs()` creates an inner memory of the agent's logs for the LLM, as a list of chat messages. This method goes over each step of the log and stores only what it is interested in as a message: for instance, it will save the system prompt and the task as separate messages, then for each step it will store the LLM output as one message and the tool call output as another.
+- Running `agent.write_memory_to_messages()` creates an inner memory of the agent's logs for the LLM, as a list of chat messages. This method goes over each step of the log and stores only what it is interested in as a message: for instance, it will save the system prompt and the task as separate messages, then for each step it will store the LLM output as one message and the tool call output as another.
 ## Tools


@@ -152,7 +152,7 @@ agent.run("Could you get me the title of the page at url 'https://huggingface.co
 Here are a few useful attributes to inspect what happened after a run:
 - `agent.logs` stores the agent's fine-grained logs. At every step of the agent's run, everything is stored in a dictionary that is then appended to `agent.logs`.
-- Running `agent.write_inner_memory_from_logs()` creates an inner memory of the agent's logs for the LLM, as a list of chat messages. This method goes over each step of the log and stores only what it is interested in as a message: for instance, it stores the system prompt and the task as separate messages, then for each step it stores the LLM output as one message and the tool call output as another. Use this if you want a higher-level view of what happened - but not every log will be transcribed by this method.
+- Running `agent.write_memory_to_messages()` creates an inner memory of the agent's logs for the LLM, as a list of chat messages. This method goes over each step of the log and stores only what it is interested in as a message: for instance, it stores the system prompt and the task as separate messages, then for each step it stores the LLM output as one message and the tool call output as another. Use this if you want a higher-level view of what happened - but not every log will be transcribed by this method.
 ## Tools


@@ -53,11 +53,17 @@ mcp = [
 openai = [
     "openai>=1.58.1"
 ]
+telemetry = [
+    "arize-phoenix",
+    "opentelemetry-sdk",
+    "opentelemetry-exporter-otlp",
+    "openinference-instrumentation-smolagents>=0.1.1"
+]
 quality = [
     "ruff>=0.9.0",
 ]
 all = [
-    "smolagents[accelerate,audio,e2b,gradio,litellm,mcp,openai,transformers]",
+    "smolagents[accelerate,audio,e2b,gradio,litellm,mcp,openai,transformers,telemetry]",
 ]
 test = [
     "ipython>=8.31.0", # for interactive environment tests

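The new `telemetry` extra groups the OpenTelemetry and Phoenix instrumentation dependencies, so they can be installed with the standard extras syntax, e.g. `pip install "smolagents[telemetry]"`, or pulled in transitively via the updated `all` extra.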

@@ -17,7 +17,7 @@
 import inspect
 import time
 from collections import deque
-from dataclasses import dataclass
+from logging import getLogger
 from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

 from rich import box
@@ -27,6 +27,23 @@ from rich.rule import Rule
 from rich.syntax import Syntax
 from rich.text import Text

+from smolagents.logger import (
+    AgentLogger,
+    LogLevel,
+)
+from smolagents.memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall
+from smolagents.types import AgentAudio, AgentImage, handle_agent_output_types
+from smolagents.utils import (
+    AgentError,
+    AgentExecutionError,
+    AgentGenerationError,
+    AgentMaxStepsError,
+    AgentParsingError,
+    parse_code_blobs,
+    parse_json_tool_call,
+    truncate_content,
+)
+
 from .default_tools import TOOL_MAPPING, FinalAnswerTool
 from .e2b_executor import E2BExecutor
 from .local_python_executor import (
@@ -57,62 +74,10 @@ from .tools import (
     Tool,
     get_tool_description_with_args,
 )
-from .types import AgentAudio, AgentImage, AgentType, handle_agent_output_types
-from .utils import (
-    AgentError,
-    AgentExecutionError,
-    AgentGenerationError,
-    AgentLogger,
-    AgentMaxStepsError,
-    AgentParsingError,
-    LogLevel,
-    parse_code_blobs,
-    parse_json_tool_call,
-    truncate_content,
-)
+from .types import AgentType

+logger = getLogger(__name__)

-@dataclass
-class ToolCall:
-    name: str
-    arguments: Any
-    id: str
-
-
-class AgentStepLog:
-    pass
-
-
-@dataclass
-class ActionStep(AgentStepLog):
-    agent_memory: List[Dict[str, str]] | None = None
-    tool_calls: List[ToolCall] | None = None
-    start_time: float | None = None
-    end_time: float | None = None
-    step_number: int | None = None
-    error: AgentError | None = None
-    duration: float | None = None
-    llm_output: str | None = None
-    observations: str | None = None
-    observations_images: List[str] | None = None
-    action_output: Any = None
-
-
-@dataclass
-class PlanningStep(AgentStepLog):
-    plan: str
-    facts: str
-
-
-@dataclass
-class TaskStep(AgentStepLog):
-    task: str
-    task_images: List[str] | None = None
-
-
-@dataclass
-class SystemPromptStep(AgentStepLog):
-    system_prompt: str


 def get_tool_descriptions(tools: Dict[str, Tool], tool_description_template: str) -> str:
@@ -227,157 +192,60 @@ class MultiStepAgent:
         self.system_prompt = self.initialize_system_prompt()
         self.input_messages = None
-        self.logs = []
         self.task = None
+        self.memory = AgentMemory(system_prompt)
         self.logger = AgentLogger(level=verbosity_level)
         self.monitor = Monitor(self.model, self.logger)
         self.step_callbacks = step_callbacks if step_callbacks is not None else []
         self.step_callbacks.append(self.monitor.update_metrics)

+    @property
+    def logs(self):
+        logger.warning(
+            "The 'logs' attribute is deprecated and will soon be removed. Please use 'self.memory.steps' instead."
+        )
+        return [self.memory.system_prompt] + self.memory.steps
+
     def initialize_system_prompt(self):
-        self.system_prompt = format_prompt_with_tools(
+        system_prompt = format_prompt_with_tools(
             self.tools,
             self.system_prompt_template,
             self.tool_description_template,
         )
-        self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
-
-        return self.system_prompt
+        system_prompt = format_prompt_with_managed_agents_descriptions(system_prompt, self.managed_agents)
+        return system_prompt

-    def write_inner_memory_from_logs(self, summary_mode: bool = False) -> List[Dict[str, str]]:
+    def write_memory_to_messages(
+        self,
+        summary_mode: Optional[bool] = False,
+    ) -> List[Dict[str, str]]:
         """
-        Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
-        that can be used as input to the LLM.
-
-        Args:
-            summary_mode (`bool`): Whether to write a summary of the logs or the full logs.
+        Reads past llm_outputs, actions, and observations or errors from the memory into a series of messages
+        that can be used as input to the LLM. Adds a number of keywords (such as PLAN, error, etc) to help
+        the LLM.
         """
-        memory = []
-        for i, step_log in enumerate(self.logs):
-            if isinstance(step_log, SystemPromptStep):
-                if not summary_mode:
-                    thought_message = {
-                        "role": MessageRole.SYSTEM,
-                        "content": [{"type": "text", "text": step_log.system_prompt.strip()}],
-                    }
-                    memory.append(thought_message)
-            elif isinstance(step_log, PlanningStep):
-                thought_message = {
-                    "role": MessageRole.ASSISTANT,
-                    "content": [{"type": "text", "text": "[FACTS LIST]:\n" + step_log.facts.strip()}],
-                }
-                memory.append(thought_message)
-                if not summary_mode:
-                    thought_message = {
-                        "role": MessageRole.ASSISTANT,
-                        "content": [{"type": "text", "text": "[PLAN]:\n" + step_log.plan.strip()}],
-                    }
-                    memory.append(thought_message)
-            elif isinstance(step_log, TaskStep):
-                task_message = {
-                    "role": MessageRole.USER,
-                    "content": [{"type": "text", "text": f"New task:\n{step_log.task}"}],
-                }
-                if step_log.task_images:
-                    for image in step_log.task_images:
-                        task_message["content"].append({"type": "image", "image": image})
-                memory.append(task_message)
-            elif isinstance(step_log, ActionStep):
-                if step_log.llm_output is not None and not summary_mode:
-                    thought_message = {
-                        "role": MessageRole.ASSISTANT,
-                        "content": [{"type": "text", "text": step_log.llm_output.strip()}],
-                    }
-                    memory.append(thought_message)
-                if step_log.tool_calls is not None:
-                    tool_call_message = {
-                        "role": MessageRole.ASSISTANT,
-                        "content": [
-                            {
-                                "type": "text",
-                                "text": str(
-                                    [
-                                        {
-                                            "id": tool_call.id,
-                                            "type": "function",
-                                            "function": {
-                                                "name": tool_call.name,
-                                                "arguments": tool_call.arguments,
-                                            },
-                                        }
-                                        for tool_call in step_log.tool_calls
-                                    ]
-                                ),
-                            }
-                        ],
-                    }
-                    memory.append(tool_call_message)
-                if step_log.error is not None:
-                    error_message = {
-                        "role": MessageRole.ASSISTANT,
-                        "content": [
-                            {
-                                "type": "text",
-                                "text": (
-                                    "Error:\n"
-                                    + str(step_log.error)
-                                    + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
-                                ),
-                            }
-                        ],
-                    }
-                    memory.append(error_message)
-                if step_log.observations is not None:
-                    if step_log.tool_calls:
-                        tool_call_reference = f"Call id: {(step_log.tool_calls[0].id if getattr(step_log.tool_calls[0], 'id') else 'call_0')}\n"
-                    else:
-                        tool_call_reference = ""
-                    text_observations = f"Observation:\n{step_log.observations}"
-                    tool_response_message = {
-                        "role": MessageRole.TOOL_RESPONSE,
-                        "content": [{"type": "text", "text": tool_call_reference + text_observations}],
-                    }
-                    memory.append(tool_response_message)
-                if step_log.observations_images:
-                    thought_message_image = {
-                        "role": MessageRole.USER,
-                        "content": [{"type": "text", "text": "Here are the observed images:"}]
-                        + [
-                            {
-                                "type": "image",
-                                "image": image,
-                            }
-                            for image in step_log.observations_images
-                        ],
-                    }
-                    memory.append(thought_message_image)
-        return memory
-
-    def get_succinct_logs(self):
-        return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs]
+        messages = self.memory.system_prompt.to_messages(summary_mode=summary_mode)
+        for step_log in self.memory.steps:
+            messages.extend(step_log.to_messages(summary_mode=summary_mode))
+        return messages

-    def extract_action(self, llm_output: str, split_token: str) -> Tuple[str, str]:
+    def extract_action(self, model_output: str, split_token: str) -> Tuple[str, str]:
         """
         Parse action from the LLM output
         Args:
-            llm_output (`str`): Output of the LLM
+            model_output (`str`): Output of the LLM
             split_token (`str`): Separator for the action. Should match the example in the system prompt.
         """
         try:
-            split = llm_output.split(split_token)
+            split = model_output.split(split_token)
             rationale, action = (
                 split[-2],
                 split[-1],
             )  # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
         except Exception:
             raise AgentParsingError(
-                f"No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!",
+                f"No '{split_token}' token provided in your output.\nYour output:\n{model_output}\n. Be sure to include an action, prefaced with '{split_token}'!",
                 self.logger,
             )
         return rationale.strip(), action.strip()
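Illustrative behavior of the renamed `extract_action` (hypothetical values; the assertions follow the split-from-the-end logic above):

```python
model_output = "Thought: I need the page title.\nAction:\n{\"name\": \"visit_webpage\"}"
# Splitting on the LAST occurrence of the token tolerates earlier duplicates in the output.
rationale, action = agent.extract_action(model_output, split_token="Action:")
assert rationale == "Thought: I need the page title."
assert action == '{"name": "visit_webpage"}'
```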
@@ -393,16 +261,17 @@ class MultiStepAgent:
         Returns:
             `str`: Final answer to the task.
         """
+        messages = [{"role": MessageRole.SYSTEM, "content": []}]
         if images:
-            self.input_messages[0]["content"] = [
+            messages[0]["content"] = [
                 {
                     "type": "text",
                     "text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
                 }
             ]
-            self.input_messages[0]["content"].append({"type": "image"})
-            self.input_messages += self.write_inner_memory_from_logs()[1:]
-            self.input_messages += [
+            messages[0]["content"].append({"type": "image"})
+            messages += self.write_memory_to_messages()[1:]
+            messages += [
                 {
                     "role": MessageRole.USER,
                     "content": [
@@ -414,14 +283,14 @@
                 }
             ]
         else:
-            self.input_messages[0]["content"] = [
+            messages[0]["content"] = [
                 {
                     "type": "text",
                     "text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
                 }
             ]
-            self.input_messages += self.write_inner_memory_from_logs()[1:]
-            self.input_messages += [
+            messages += self.write_memory_to_messages()[1:]
+            messages += [
                 {
                     "role": MessageRole.USER,
                     "content": [
@@ -433,7 +302,8 @@
                 }
             ]
         try:
-            return self.model(self.input_messages).content
+            chat_message: ChatMessage = self.model(messages)
+            return chat_message.content
         except Exception as e:
             return f"Error in generating final LLM output:\n{e}"
@@ -523,18 +393,11 @@
 You have been provided with these additional arguments, that you can access using the keys as variables in your python code:
 {str(additional_args)}."""

-        self.initialize_system_prompt()
-        system_prompt_step = SystemPromptStep(system_prompt=self.system_prompt)
+        self.system_prompt = self.initialize_system_prompt()
+        self.memory.system_prompt = SystemPromptStep(system_prompt=self.system_prompt)
         if reset:
-            self.logs = []
-            self.logs.append(system_prompt_step)
+            self.memory.reset()
             self.monitor.reset()
-        else:
-            if len(self.logs) > 0:
-                self.logs[0] = system_prompt_step
-            else:
-                self.logs.append(system_prompt_step)

         self.logger.log(
             Panel(
@@ -547,7 +410,7 @@
             level=LogLevel.INFO,
         )

-        self.logs.append(TaskStep(task=self.task, task_images=images))
+        self.memory.steps.append(TaskStep(task=self.task, task_images=images))

         if single_step:
             step_start_time = time.time()
             step_log = ActionStep(start_time=step_start_time, observations_images=images)
@@ -604,7 +467,7 @@
             finally:
                 step_log.end_time = time.time()
                 step_log.duration = step_log.end_time - step_start_time
-                self.logs.append(step_log)
+                self.memory.steps.append(step_log)
                 for callback in self.step_callbacks:
                     # For compatibility with old callbacks that don't take the agent as an argument
                     if len(inspect.signature(callback).parameters) == 1:
@@ -616,15 +479,16 @@
         if final_answer is None and self.step_number == self.max_steps:
             error_message = "Reached max steps."
-            final_step_log = ActionStep(
-                step_number=self.step_number, error=AgentMaxStepsError(error_message, self.logger)
-            )
-            self.logs.append(final_step_log)
-            final_answer = self.provide_final_answer(task, images)
+            final_answer = self.provide_final_answer(task, images)
+            final_step_log = ActionStep(error=AgentMaxStepsError(error_message, self.logger))
             self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO)
             final_step_log.action_output = final_answer
             final_step_log.end_time = time.time()
             final_step_log.duration = step_log.end_time - step_start_time
+            self.memory.steps.append(final_step_log)
             for callback in self.step_callbacks:
                 # For compatibility with old callbacks that don't take the agent as an argument
                 if len(inspect.signature(callback).parameters) == 1:
@@ -658,7 +522,8 @@
 Now begin!""",
             }

-            answer_facts = self.model([message_prompt_facts, message_prompt_task]).content
+            chat_message_facts: ChatMessage = self.model([message_prompt_facts, message_prompt_task])
+            answer_facts = chat_message_facts.content

             message_system_prompt_plan = {
                 "role": MessageRole.SYSTEM,
@@ -673,10 +538,11 @@ Now begin!""",
                     answer_facts=answer_facts,
                 ),
             }
-            answer_plan = self.model(
+            chat_message_plan: ChatMessage = self.model(
                 [message_system_prompt_plan, message_user_prompt_plan],
                 stop_sequences=["<end_plan>"],
-            ).content
+            )
+            answer_plan = chat_message_plan.content

             final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
 ```
@@ -686,14 +552,21 @@ Now begin!""",
 ```
 {answer_facts}
 ```""".strip()
-            self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
+            self.memory.steps.append(
+                PlanningStep(
+                    plan=final_plan_redaction,
+                    facts=final_facts_redaction,
+                    model_output_message_plan=chat_message_plan,
+                    model_output_message_facts=chat_message_facts,
+                )
+            )
             self.logger.log(
                 Rule("[bold]Initial plan", style="orange"),
                 Text(final_plan_redaction),
                 level=LogLevel.INFO,
             )
         else:  # update plan
-            agent_memory = self.write_inner_memory_from_logs(
+            memory_messages = self.write_memory_to_messages(
                 summary_mode=False
             )  # This will not log the plan but will log facts
@@ -706,7 +579,10 @@ Now begin!""",
                 "role": MessageRole.USER,
                 "content": [{"type": "text", "text": USER_PROMPT_FACTS_UPDATE}],
             }
-            facts_update = self.model([facts_update_system_prompt] + agent_memory + [facts_update_message]).content
+            chat_message_facts: ChatMessage = self.model(
+                [facts_update_system_prompt] + memory_messages + [facts_update_message]
+            )
+            facts_update = chat_message_facts.content

             # Redact updated plan
             plan_update_message = {
@@ -728,18 +604,27 @@ Now begin!""",
                     }
                 ],
             }
-            plan_update = self.model(
-                [plan_update_message] + agent_memory + [plan_update_message_user],
+            chat_message_plan: ChatMessage = self.model(
+                [plan_update_message] + memory_messages + [plan_update_message_user],
                 stop_sequences=["<end_plan>"],
-            ).content
+            )
             # Log final facts and plan
-            final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
+            final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(
+                task=task, plan_update=chat_message_plan.content
+            )
             final_facts_redaction = f"""Here is the updated list of the facts that I know:
 ```
 {facts_update}
 ```"""
-            self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
+            self.memory.steps.append(
+                PlanningStep(
+                    plan=final_plan_redaction,
+                    facts=final_facts_redaction,
+                    model_output_message_plan=chat_message_plan,
+                    model_output_message_facts=chat_message_facts,
+                )
+            )
             self.logger.log(
                 Rule("[bold]Updated plan", style="orange"),
                 Text(final_plan_redaction),
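These planning branches only run when planning is enabled; a hedged usage sketch (assuming the `planning_interval` constructor parameter and a configured `model`):

```python
from smolagents import CodeAgent

agent = CodeAgent(tools=[], model=model, planning_interval=3)  # re-plan every 3 steps
agent.run("Some multi-step task")
# PlanningStep entries now land in agent.memory.steps next to Task/Action steps,
# each carrying the raw facts/plan chat messages recorded above.
planning_steps = [s for s in agent.memory.steps if type(s).__name__ == "PlanningStep"]
```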
@@ -783,19 +668,20 @@ class ToolCallingAgent(MultiStepAgent):
         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
         Returns None if the step is not final.
         """
-        agent_memory = self.write_inner_memory_from_logs()
-        self.input_messages = agent_memory
+        memory_messages = self.write_memory_to_messages()
+        self.input_messages = memory_messages

         # Add new step in logs
-        log_entry.agent_memory = agent_memory.copy()
+        log_entry.model_input_messages = memory_messages.copy()

         try:
-            model_message = self.model(
-                self.input_messages,
+            model_message: ChatMessage = self.model(
+                memory_messages,
                 tools_to_call_from=list(self.tools.values()),
                 stop_sequences=["Observation:"],
             )
+            log_entry.model_output_message = model_message
             if model_message.tool_calls is None or len(model_message.tool_calls) == 0:
                 raise Exception("Model did not call any tools. Call `final_answer` tool to return a final answer.")
             tool_call = model_message.tool_calls[0]
@@ -931,7 +817,7 @@ class CodeAgent(MultiStepAgent):
         )

     def initialize_system_prompt(self):
-        super().initialize_system_prompt()
+        self.system_prompt = super().initialize_system_prompt()
         self.system_prompt = self.system_prompt.replace(
             "{{authorized_imports}}",
             (
@@ -947,20 +833,22 @@ class CodeAgent(MultiStepAgent):
         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
         Returns None if the step is not final.
         """
-        agent_memory = self.write_inner_memory_from_logs()
-        self.input_messages = agent_memory.copy()
+        memory_messages = self.write_memory_to_messages()
+        self.input_messages = memory_messages.copy()

         # Add new step in logs
-        log_entry.agent_memory = agent_memory.copy()
+        log_entry.model_input_messages = memory_messages.copy()

         try:
             additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
-            llm_output = self.model(
+            chat_message: ChatMessage = self.model(
                 self.input_messages,
                 stop_sequences=["<end_code>", "Observation:"],
                 **additional_args,
-            ).content
-            log_entry.llm_output = llm_output
+            )
+            log_entry.model_output_message = chat_message
+            model_output = chat_message.content
+            log_entry.model_output = model_output
         except Exception as e:
             raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger) from e
@@ -972,7 +860,7 @@ class CodeAgent(MultiStepAgent):
                     style="orange",
                 ),
                 Syntax(
-                    llm_output,
+                    model_output,
                     lexer="markdown",
                     theme="github-dark",
                     word_wrap=True,
@@ -983,7 +871,7 @@ class CodeAgent(MultiStepAgent):
         # Parse
         try:
-            code_action = fix_final_answer_code(parse_code_blobs(llm_output))
+            code_action = fix_final_answer_code(parse_code_blobs(model_output))
         except Exception as e:
             error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
             raise AgentParsingError(error_msg, self.logger)
@@ -992,7 +880,7 @@ class CodeAgent(MultiStepAgent):
             ToolCall(
                 name="python_interpreter",
                 arguments=code_action,
-                id=f"call_{len(self.logs)}",
+                id=f"call_{len(self.memory.steps)}",
             )
         ]
@@ -1095,7 +983,7 @@ class ManagedAgent:
             answer = f"Here is the final answer from your managed agent '{self.name}':\n"
             answer += str(output)
             answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
-            for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
+            for message in self.agent.write_memory_to_messages(summary_mode=True):
                 content = message["content"]
                 answer += "\n" + truncate_content(str(content)) + "\n---"
             answer += f"\nEND OF SUMMARY OF WORK FROM AGENT '{self.name}'."
@@ -1104,9 +992,4 @@ class ManagedAgent:
             return output

-__all__ = [
-    "ManagedAgent",
-    "MultiStepAgent",
-    "CodeAgent",
-    "ToolCallingAgent",
-]
+__all__ = ["ManagedAgent", "MultiStepAgent", "CodeAgent", "ToolCallingAgent", "AgentMemory"]

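After this file's changes, `agent.logs` survives only as the deprecated property shown above; a quick sketch of the new behavior (assumes a configured `model`):

```python
agent = CodeAgent(tools=[], model=model)
agent.run("What is 2 multiplied by 3.6452?")

legacy_view = agent.logs  # warns: "The 'logs' attribute is deprecated..."
assert legacy_view == [agent.memory.system_prompt] + agent.memory.steps
```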

@@ -74,14 +74,16 @@ class E2BExecutor:
             tool_codes.append(tool_code)

         tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES])
-        tool_definition_code += textwrap.dedent("""
+        tool_definition_code += textwrap.dedent(
+            """
         class Tool:
             def __call__(self, *args, **kwargs):
                 return self.forward(*args, **kwargs)

             def forward(self, *args, **kwargs):
                 pass # to be implemented in child class
-        """)
+        """
+        )
         tool_definition_code += "\n\n".join(tool_codes)

         tool_definition_execution = self.run_code_raise_errors(tool_definition_code)


@@ -19,13 +19,14 @@ import re
 import shutil
 from typing import Optional

-from .agents import ActionStep, AgentStepLog, MultiStepAgent
-from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
-from .utils import _is_package_available
+from smolagents.agents import ActionStep, MultiStepAgent
+from smolagents.memory import MemoryStep
+from smolagents.types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
+from smolagents.utils import _is_package_available


 def pull_messages_from_step(
-    step_log: AgentStepLog,
+    step_log: MemoryStep,
 ):
     """Extract ChatMessage objects from agent steps with proper nesting"""
     import gradio as gr
@@ -36,15 +37,15 @@ def pull_messages_from_step(
         yield gr.ChatMessage(role="assistant", content=f"**{step_number}**")

         # First yield the thought/reasoning from the LLM
-        if hasattr(step_log, "llm_output") and step_log.llm_output is not None:
+        if hasattr(step_log, "model_output") and step_log.model_output is not None:
             # Clean up the LLM output
-            llm_output = step_log.llm_output.strip()
+            model_output = step_log.model_output.strip()
             # Remove any trailing <end_code> and extra backticks, handling multiple possible formats
-            llm_output = re.sub(r"```\s*<end_code>", "```", llm_output)  # handles ```<end_code>
-            llm_output = re.sub(r"<end_code>\s*```", "```", llm_output)  # handles <end_code>```
-            llm_output = re.sub(r"```\s*\n\s*<end_code>", "```", llm_output)  # handles ```\n<end_code>
-            llm_output = llm_output.strip()
-            yield gr.ChatMessage(role="assistant", content=llm_output)
+            model_output = re.sub(r"```\s*<end_code>", "```", model_output)  # handles ```<end_code>
+            model_output = re.sub(r"<end_code>\s*```", "```", model_output)  # handles <end_code>```
+            model_output = re.sub(r"```\s*\n\s*<end_code>", "```", model_output)  # handles ```\n<end_code>
+            model_output = model_output.strip()
+            yield gr.ChatMessage(role="assistant", content=model_output)

         # For tool calls, create a parent message
         if hasattr(step_log, "tool_calls") and step_log.tool_calls is not None:

src/smolagents/logger.py (new file)

@@ -0,0 +1,85 @@
+import json
+from enum import IntEnum
+from typing import TYPE_CHECKING
+
+from rich.console import Console
+from rich.rule import Rule
+from rich.syntax import Syntax
+
+if TYPE_CHECKING:
+    from smolagents.memory import AgentMemory
+
+
+class LogLevel(IntEnum):
+    ERROR = 0  # Only errors
+    INFO = 1  # Normal output (default)
+    DEBUG = 2  # Detailed output
+
+
+class AgentLogger:
+    def __init__(self, level: LogLevel = LogLevel.INFO):
+        self.level = level
+        self.console = Console()
+
+    def log(self, *args, level: str | LogLevel = LogLevel.INFO, **kwargs):
+        """Logs a message to the console.
+
+        Args:
+            level (LogLevel, optional): Defaults to LogLevel.INFO.
+        """
+        if isinstance(level, str):
+            level = LogLevel[level.upper()]
+        if level <= self.level:
+            self.console.print(*args, **kwargs)
+
+    def replay(self, agent_memory: "AgentMemory", full: bool = False):
+        """Prints a pretty replay of the agent's steps.
+
+        Args:
+            full (bool, optional): If True, also displays the memory at each step. Defaults to False.
+                Careful: will increase log length exponentially. Use only for debugging.
+        """
+        memory = []
+        for step_log in agent_memory.steps:
+            memory.extend(step_log.to_messages(show_model_input_messages=full))
+        self.console.log("Replaying the agent's steps:")
+        ix = 0
+        for step in memory:
+            role = step["role"].strip()
+            if ix > 0 and role == "system":
+                role == "memory"
+            theme = "default"
+            match role:
+                case "assistant":
+                    theme = "monokai"
+                    ix += 1
+                case "system":
+                    theme = "monokai"
+                case "tool-response":
+                    theme = "github_dark"
+            content = step["content"]
+            try:
+                content = eval(content)
+            except Exception:
+                content = [step["content"]]
+            for substep_ix, item in enumerate(content):
+                self.console.log(
+                    Rule(
+                        f"{role.upper()}, STEP {ix}, SUBSTEP {substep_ix + 1}/{len(content)}",
+                        align="center",
+                        style="orange",
+                    ),
+                    Syntax(
+                        json.dumps(item, indent=4) if isinstance(item, dict) else str(item),
+                        lexer="json",
+                        theme=theme,
+                        word_wrap=True,
+                    ),
+                )
+
+
+__all__ = ["AgentLogger"]

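A small usage sketch for the new `AgentLogger` (level filtering and string-to-level coercion as implemented above):

```python
from smolagents.logger import AgentLogger, LogLevel

logger = AgentLogger(level=LogLevel.INFO)
logger.log("visible at INFO", level=LogLevel.INFO)       # printed
logger.log("detailed output", level=LogLevel.DEBUG)      # filtered out at INFO
logger.log("[bold red]error[/bold red]", level="ERROR")  # strings coerced via LogLevel[...]

# After a run, replay an agent's stored steps (assuming `agent` has already run):
# logger.replay(agent.memory, full=True)
```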
src/smolagents/memory.py (new file)

@@ -0,0 +1,194 @@
+from dataclasses import asdict, dataclass
+from logging import getLogger
+from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, Union
+
+from smolagents.models import ChatMessage, MessageRole
+from smolagents.utils import AgentError, make_json_serializable
+
+if TYPE_CHECKING:
+    from smolagents.models import ChatMessage
+
+logger = getLogger(__name__)
+
+
+class Message(TypedDict):
+    role: MessageRole
+    content: str | list[dict]
+
+
+@dataclass
+class ToolCall:
+    name: str
+    arguments: Any
+    id: str
+
+    def dict(self):
+        return {
+            "id": self.id,
+            "type": "function",
+            "function": {
+                "name": self.name,
+                "arguments": make_json_serializable(self.arguments),
+            },
+        }
+
+
+class MemoryStep:
+    raw: Any  # This is a placeholder for the raw data that the agent logs
+
+    def dict(self):
+        return asdict(self)
+
+    def to_messages(self, **kwargs) -> List[Dict[str, Any]]:
+        raise NotImplementedError
+
+
+@dataclass
+class ActionStep(MemoryStep):
+    model_input_messages: List[Dict[str, str]] | None = None
+    tool_calls: List[ToolCall] | None = None
+    start_time: float | None = None
+    end_time: float | None = None
+    step_number: int | None = None
+    error: AgentError | None = None
+    duration: float | None = None
+    model_output_message: ChatMessage = None
+    model_output: str | None = None
+    observations: str | None = None
+    observations_images: List[str] | None = None
+    action_output: Any = None
+
+    def dict(self):
+        # We overwrite the method to parse the tool_calls and action_output manually
+        return {
+            "model_input_messages": self.model_input_messages,
+            "tool_calls": [tc.dict() for tc in self.tool_calls] if self.tool_calls else [],
+            "start_time": self.start_time,
+            "end_time": self.end_time,
+            "step": self.step_number,
+            "error": self.error.dict() if self.error else None,
+            "duration": self.duration,
+            "model_output_message": self.model_output_message,
+            "model_output": self.model_output,
+            "observations": self.observations,
+            "action_output": make_json_serializable(self.action_output),
+        }
+
+    def to_messages(self, summary_mode: bool = False, show_model_input_messages: bool = False) -> List[Dict[str, Any]]:
+        messages = []
+        if self.model_input_messages is not None and show_model_input_messages:
+            messages.append(Message(role=MessageRole.SYSTEM, content=self.model_input_messages))
+        if self.model_output is not None and not summary_mode:
+            messages.append(
+                Message(role=MessageRole.ASSISTANT, content=[{"type": "text", "text": self.model_output.strip()}])
+            )
+        if self.tool_calls is not None:
+            messages.append(
+                Message(
+                    role=MessageRole.ASSISTANT,
+                    content=[{"type": "text", "text": str([tc.dict() for tc in self.tool_calls])}],
+                )
+            )
+        if self.error is not None:
+            message_content = (
+                "Error:\n"
+                + str(self.error)
+                + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
+            )
+            if self.tool_calls is None:
+                tool_response_message = Message(
+                    role=MessageRole.ASSISTANT, content=[{"type": "text", "text": message_content}]
+                )
+            else:
+                tool_response_message = Message(
+                    role=MessageRole.TOOL_RESPONSE, content=f"Call id: {self.tool_calls[0].id}\n{message_content}"
+                )
+            messages.append(tool_response_message)
+        else:
+            if self.observations is not None and self.tool_calls is not None:
+                messages.append(
+                    Message(
+                        role=MessageRole.TOOL_RESPONSE,
+                        content=f"Call id: {self.tool_calls[0].id}\nObservation:\n{self.observations}",
+                    )
+                )
+        if self.observations_images:
+            messages.append(
+                Message(
+                    role=MessageRole.USER,
+                    content=[{"type": "text", "text": "Here are the observed images:"}]
+                    + [
+                        {
+                            "type": "image",
+                            "image": image,
+                        }
+                        for image in self.observations_images
+                    ],
+                )
+            )
+        return messages
+
+
+@dataclass
+class PlanningStep(MemoryStep):
+    model_output_message_facts: ChatMessage
+    facts: str
+    model_output_message_plan: ChatMessage
+    plan: str
+
+    def to_messages(self, summary_mode: bool, **kwargs) -> List[Dict[str, str]]:
+        messages = []
+        messages.append(Message(role=MessageRole.ASSISTANT, content=f"[FACTS LIST]:\n{self.facts.strip()}"))
+
+        if not summary_mode:
+            messages.append(Message(role=MessageRole.ASSISTANT, content=f"[PLAN]:\n{self.plan.strip()}"))
+        return messages
+
+
+@dataclass
+class TaskStep(MemoryStep):
+    task: str
+    task_images: List[str] | None = None
+
+    def to_messages(self, summary_mode: bool, **kwargs) -> List[Dict[str, str]]:
+        content = [{"type": "text", "text": f"New task:\n{self.task}"}]
+        if self.task_images:
+            for image in self.task_images:
+                content.append({"type": "image", "image": image})
+        return [Message(role=MessageRole.USER, content=content)]
+
+
+@dataclass
+class SystemPromptStep(MemoryStep):
+    system_prompt: str
+
+    def to_messages(self, summary_mode: bool, **kwargs) -> List[Dict[str, str]]:
+        if summary_mode:
+            return []
+        return [Message(role=MessageRole.SYSTEM, content=[{"type": "text", "text": self.system_prompt.strip()}])]
+
+
+class AgentMemory:
+    def __init__(self, system_prompt: str):
+        self.system_prompt = SystemPromptStep(system_prompt=system_prompt)
+        self.steps: List[Union[TaskStep, ActionStep, PlanningStep]] = []
+
+    def reset(self):
+        self.steps = []
+
+    def get_succinct_steps(self) -> list[dict]:
+        return [
+            {key: value for key, value in step.dict().items() if key != "model_input_messages"} for step in self.steps
+        ]
+
+    def get_full_steps(self) -> list[dict]:
+        return [step.dict() for step in self.steps]
+
+
+__all__ = ["AgentMemory"]

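How the new memory classes compose, as a sketch (field names follow the dataclasses above; values are illustrative):

```python
from smolagents.memory import ActionStep, AgentMemory, TaskStep

memory = AgentMemory(system_prompt="You are a helpful agent.")
memory.steps.append(TaskStep(task="What is 2 * 3.6452?"))
memory.steps.append(
    ActionStep(step_number=1, model_output="Thought: multiply.", observations="7.2904")
)

# Serialize for the LLM: system prompt first, then each step's messages
messages = memory.system_prompt.to_messages(summary_mode=False)
for step in memory.steps:
    messages.extend(step.to_messages(summary_mode=False))

# JSON-friendly dumps, with or without the heavy model_input_messages field
succinct = memory.get_succinct_steps()
full = memory.get_full_steps()
```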

@@ -47,10 +47,10 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
 }


-def get_dict_from_nested_dataclasses(obj):
+def get_dict_from_nested_dataclasses(obj, ignore_key=None):
     def convert(obj):
         if hasattr(obj, "__dataclass_fields__"):
-            return {k: convert(v) for k, v in asdict(obj).items()}
+            return {k: convert(v) for k, v in asdict(obj).items() if k != ignore_key}
         return obj

     return convert(obj)
@@ -78,7 +78,7 @@ class ChatMessageToolCall:
     type: str

     @classmethod
-    def from_hf_api(cls, tool_call) -> "ChatMessageToolCall":
+    def from_hf_api(cls, tool_call, raw) -> "ChatMessageToolCall":
         return cls(
             function=ChatMessageToolCallDefinition.from_hf_api(tool_call.function),
             id=tool_call.id,
@@ -91,16 +91,17 @@ class ChatMessage:
     role: str
     content: Optional[str] = None
     tool_calls: Optional[List[ChatMessageToolCall]] = None
+    raw: Optional[Any] = None  # Stores the raw output from the API

     def model_dump_json(self):
-        return json.dumps(get_dict_from_nested_dataclasses(self))
+        return json.dumps(get_dict_from_nested_dataclasses(self, ignore_key="raw"))

     @classmethod
-    def from_hf_api(cls, message) -> "ChatMessage":
+    def from_hf_api(cls, message, raw) -> "ChatMessage":
         tool_calls = None
         if getattr(message, "tool_calls", None) is not None:
             tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls]
-        return cls(role=message.role, content=message.content, tool_calls=tool_calls)
+        return cls(role=message.role, content=message.content, tool_calls=tool_calls, raw=raw)

     @classmethod
     def from_dict(cls, data: dict) -> "ChatMessage":
@@ -114,6 +115,9 @@ class ChatMessage:
             data["tool_calls"] = tool_calls
         return cls(**data)

+    def dict(self):
+        return json.dumps(get_dict_from_nested_dataclasses(self))
+

 def parse_json_if_needed(arguments: Union[str, dict]) -> Union[str, dict]:
     if isinstance(arguments, dict):
@@ -395,7 +399,7 @@ class HfApiModel(Model):
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
-        message = ChatMessage.from_hf_api(response.choices[0].message)
+        message = ChatMessage.from_hf_api(response.choices[0].message, raw=response)
         if tools_to_call_from is not None:
             return parse_tool_args_if_needed(message)
         return message
@@ -587,7 +591,11 @@ class TransformersModel(Model):
         output = remove_stop_sequences(output, stop_sequences)
         if tools_to_call_from is None:
-            return ChatMessage(role="assistant", content=output)
+            return ChatMessage(
+                role="assistant",
+                content=output,
+                raw={"out": out, "completion_kwargs": completion_kwargs},
+            )
         else:
             if "Action:" in output:
                 output = output.split("Action:", 1)[1].strip()
@@ -614,6 +622,7 @@ class TransformersModel(Model):
                     function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
                 )
             ],
+            raw={"out": out, "completion_kwargs": completion_kwargs},
         )
@@ -678,10 +687,10 @@ class LiteLLMModel(Model):
         self.last_input_token_count = response.usage.prompt_tokens
         self.last_output_token_count = response.usage.completion_tokens
         message = ChatMessage.from_dict(
             response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
         )
+        message.raw = response
         if tools_to_call_from is not None:
             return parse_tool_args_if_needed(message)
@@ -762,6 +771,7 @@ class OpenAIServerModel(Model):
         message = ChatMessage.from_dict(
             response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
         )
+        message.raw = response
         if tools_to_call_from is not None:
             return parse_tool_args_if_needed(message)
         return message

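Every backend now attaches the provider's raw response to the returned message; a sketch of how that surfaces (assumes `model` and `messages` are already defined):

```python
chat_message = model(messages)   # any Model subclass
chat_message.content             # parsed text, as before
chat_message.raw                 # full provider response object, newly attached

# `raw` is deliberately excluded from the serialized dump via ignore_key="raw":
serialized = chat_message.model_dump_json()
```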

@@ -41,7 +41,7 @@ class Monitor:
         """Update the metrics of the monitor.

         Args:
-            step_log ([`AgentStepLog`]): Step log to update the monitor with.
+            step_log ([`MemoryStep`]): Step log to update the monitor with.
         """
         step_duration = step_log.duration
         self.step_durations.append(step_duration)


@@ -23,12 +23,13 @@ import json
 import re
 import textwrap
 import types
-from enum import IntEnum
 from functools import lru_cache
 from io import BytesIO
-from typing import Dict, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Tuple, Union

-from rich.console import Console
+if TYPE_CHECKING:
+    from smolagents.memory import AgentLogger

 __all__ = ["AgentError"]
@@ -48,8 +49,6 @@ def _is_pillow_available():
     return importlib.util.find_spec("PIL") is not None

-console = Console()
-
 BASE_BUILTIN_MODULES = [
     "collections",
     "datetime",
@@ -65,29 +64,16 @@ BASE_BUILTIN_MODULES = [
 ]

-class LogLevel(IntEnum):
-    ERROR = 0  # Only errors
-    INFO = 1  # Normal output (default)
-    DEBUG = 2  # Detailed output
-
-
-class AgentLogger:
-    def __init__(self, level: LogLevel = LogLevel.INFO):
-        self.level = level
-        self.console = Console()
-
-    def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs):
-        if level <= self.level:
-            self.console.print(*args, **kwargs)
-
-
 class AgentError(Exception):
     """Base class for other agent-related exceptions"""

-    def __init__(self, message, logger: AgentLogger):
+    def __init__(self, message, logger: "AgentLogger"):
         super().__init__(message)
         self.message = message
-        logger.log(f"[bold red]{message}[/bold red]", level=LogLevel.ERROR)
+        logger.log(f"[bold red]{message}[/bold red]", level="ERROR")
+
+    def dict(self) -> Dict[str, str]:
+        return {"type": self.__class__.__name__, "message": str(self.message)}


 class AgentParsingError(AgentError):
@@ -114,6 +100,32 @@ class AgentGenerationError(AgentError):
     pass


+def make_json_serializable(obj: Any) -> Any:
+    """Recursive function to make objects JSON serializable"""
+    if obj is None:
+        return None
+    elif isinstance(obj, (str, int, float, bool)):
+        # Try to parse string as JSON if it looks like a JSON object/array
+        if isinstance(obj, str):
+            try:
+                if (obj.startswith("{") and obj.endswith("}")) or (obj.startswith("[") and obj.endswith("]")):
+                    parsed = json.loads(obj)
+                    return make_json_serializable(parsed)
+            except json.JSONDecodeError:
+                pass
+        return obj
+    elif isinstance(obj, (list, tuple)):
+        return [make_json_serializable(item) for item in obj]
+    elif isinstance(obj, dict):
+        return {str(k): make_json_serializable(v) for k, v in obj.items()}
+    elif hasattr(obj, "__dict__"):
+        # For custom objects, convert their __dict__ to a serializable format
+        return {"_type": obj.__class__.__name__, **{k: make_json_serializable(v) for k, v in obj.__dict__.items()}}
+    else:
+        # For any other type, convert to string
+        return str(obj)
+
+
 def parse_json_blob(json_blob: str) -> Dict[str, str]:
     try:
         first_accolade_index = json_blob.find("{")

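Behavior sketch for the new `make_json_serializable` helper (values chosen for illustration):

```python
from smolagents.utils import make_json_serializable

make_json_serializable('{"a": 1}')        # -> {'a': 1}: JSON-looking strings get parsed
make_json_serializable([1, "x", (2, 3)])  # -> [1, 'x', [2, 3]]: tuples become lists

class Point:
    def __init__(self):
        self.x, self.y = 1, 2

make_json_serializable(Point())           # -> {'_type': 'Point', 'x': 1, 'y': 2}
```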

@@ -303,9 +303,9 @@ class AgentTests(unittest.TestCase):
         output = agent.run("What is 2 multiplied by 3.6452?")
         assert isinstance(output, str)
         assert "7.2904" in output
-        assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
-        assert "7.2904" in agent.logs[2].observations
-        assert agent.logs[3].llm_output is None
+        assert agent.memory.steps[0].task == "What is 2 multiplied by 3.6452?"
+        assert "7.2904" in agent.memory.steps[1].observations
+        assert agent.memory.steps[2].model_output is None

     def test_toolcalling_agent_handles_image_tool_outputs(self):
         from PIL import Image
@@ -348,9 +348,9 @@ class AgentTests(unittest.TestCase):
         output = agent.run("What is 2 multiplied by 3.6452?")
         assert isinstance(output, float)
         assert output == 7.2904
-        assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
-        assert agent.logs[3].tool_calls == [
-            ToolCall(name="python_interpreter", arguments="final_answer(7.2904)", id="call_3")
+        assert agent.memory.steps[0].task == "What is 2 multiplied by 3.6452?"
+        assert agent.memory.steps[2].tool_calls == [
+            ToolCall(name="python_interpreter", arguments="final_answer(7.2904)", id="call_2")
         ]

     def test_additional_args_added_to_task(self):
@@ -366,30 +366,30 @@ class AgentTests(unittest.TestCase):
         agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model)
         output = agent.run("What is 2 multiplied by 3.6452?", reset=True)
         assert output == 7.2904
-        assert len(agent.logs) == 4
+        assert len(agent.memory.steps) == 3

         output = agent.run("What is 2 multiplied by 3.6452?", reset=False)
         assert output == 7.2904
-        assert len(agent.logs) == 6
+        assert len(agent.memory.steps) == 5

         output = agent.run("What is 2 multiplied by 3.6452?", reset=True)
         assert output == 7.2904
-        assert len(agent.logs) == 4
+        assert len(agent.memory.steps) == 3

     def test_code_agent_code_errors_show_offending_line_and_error(self):
         agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_error)
         output = agent.run("What is 2 multiplied by 3.6452?")
         assert isinstance(output, AgentText)
         assert output == "got an error"
-        assert "Code execution failed at line 'error_function()'" in str(agent.logs[2].error)
-        assert "ValueError" in str(agent.logs)
+        assert "Code execution failed at line 'error_function()'" in str(agent.memory.steps[1].error)
+        assert "ValueError" in str(agent.memory.steps)

     def test_code_agent_syntax_error_show_offending_lines(self):
         agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error)
         output = agent.run("What is 2 multiplied by 3.6452?")
         assert isinstance(output, AgentText)
         assert output == "got an error"
-        assert '    print("Failing due to unexpected indent")' in str(agent.logs)
+        assert '    print("Failing due to unexpected indent")' in str(agent.memory.steps)

     def test_setup_agent_with_empty_toolbox(self):
         ToolCallingAgent(model=FakeToolCallModel(), tools=[])
@@ -401,8 +401,8 @@ class AgentTests(unittest.TestCase):
             max_steps=5,
         )
         answer = agent.run("What is 2 multiplied by 3.6452?")
-        assert len(agent.logs) == 8
-        assert type(agent.logs[-1].error) is AgentMaxStepsError
+        assert len(agent.memory.steps) == 7
+        assert type(agent.memory.steps[-1].error) is AgentMaxStepsError
         assert isinstance(answer, str)

     def test_tool_descriptions_get_baked_in_system_prompt(self):
@@ -638,4 +638,9 @@ nested_answer()
     )
     agent = ToolCallingAgent(model=model, tools=[get_weather], max_steps=1)
     agent.run("What's the weather in Paris?")
-    assert agent.logs[2].tool_calls[0].name == "get_weather"
+    assert agent.memory.steps[0].task == "What's the weather in Paris?"
+    assert agent.memory.steps[1].tool_calls[0].name == "get_weather"
+    step_memory_dict = agent.memory.get_succinct_steps()[1]
+    assert step_memory_dict["model_output_message"].tool_calls[0].function.name == "get_weather"
+    assert step_memory_dict["model_output_message"].raw["completion_kwargs"]["max_new_tokens"] == 100
+    assert "model_input_messages" in agent.memory.get_full_steps()[1]


@@ -22,12 +22,12 @@ from smolagents import (
     ToolCallingAgent,
     stream_to_gradio,
 )
+from smolagents.logger import AgentLogger, LogLevel
 from smolagents.models import (
     ChatMessage,
     ChatMessageToolCall,
     ChatMessageToolCallDefinition,
 )
-from smolagents.utils import AgentLogger, LogLevel


 class FakeLLMModel: