minor fix for console in AgentLogger (#303)

* minor fix for console in AgentLogger
This commit is contained in:
Nicholas Broad 2025-01-22 01:41:05 -08:00 committed by GitHub
parent 2c43546d3c
commit ec45d6766a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 39 additions and 36 deletions

View File

@ -17,11 +17,10 @@
import time import time
from collections import deque from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from rich import box from rich import box
from rich.console import Console, Group from rich.console import Group
from rich.panel import Panel from rich.panel import Panel
from rich.rule import Rule from rich.rule import Rule
from rich.syntax import Syntax from rich.syntax import Syntax
@ -62,9 +61,10 @@ from .utils import (
AgentError, AgentError,
AgentExecutionError, AgentExecutionError,
AgentGenerationError, AgentGenerationError,
AgentLogger,
AgentMaxStepsError, AgentMaxStepsError,
AgentParsingError, AgentParsingError,
console, LogLevel,
parse_code_blobs, parse_code_blobs,
parse_json_tool_call, parse_json_tool_call,
truncate_content, truncate_content,
@ -158,22 +158,6 @@ def format_prompt_with_managed_agents_descriptions(
YELLOW_HEX = "#d4b702" YELLOW_HEX = "#d4b702"
class LogLevel(IntEnum):
ERROR = 0 # Only errors
INFO = 1 # Normal output (default)
DEBUG = 2 # Detailed output
class AgentLogger:
def __init__(self, level: LogLevel = LogLevel.INFO):
self.level = level
self.console = Console()
def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs):
if level <= self.level:
console.print(*args, **kwargs)
class MultiStepAgent: class MultiStepAgent:
""" """
Agent class that solves the given task step by step, using the ReAct framework: Agent class that solves the given task step by step, using the ReAct framework:
@ -353,7 +337,8 @@ class MultiStepAgent:
) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output ) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
except Exception: except Exception:
raise AgentParsingError( raise AgentParsingError(
f"No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!" f"No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!",
self.logger,
) )
return rationale.strip(), action.strip() return rationale.strip(), action.strip()
@ -391,7 +376,7 @@ class MultiStepAgent:
available_tools = {**self.tools, **self.managed_agents} available_tools = {**self.tools, **self.managed_agents}
if tool_name not in available_tools: if tool_name not in available_tools:
error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg, self.logger)
try: try:
if isinstance(arguments, str): if isinstance(arguments, str):
@ -409,7 +394,7 @@ class MultiStepAgent:
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True) observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
else: else:
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg, self.logger, self.logger)
return observation return observation
except Exception as e: except Exception as e:
if tool_name in self.tools: if tool_name in self.tools:
@ -418,13 +403,13 @@ class MultiStepAgent:
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
f"As a reminder, this tool's description is the following:\n{tool_description}" f"As a reminder, this tool's description is the following:\n{tool_description}"
) )
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg, self.logger)
elif tool_name in self.managed_agents: elif tool_name in self.managed_agents:
error_msg = ( error_msg = (
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
) )
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg, self.logger)
def step(self, log_entry: ActionStep) -> Union[None, Any]: def step(self, log_entry: ActionStep) -> Union[None, Any]:
"""To be implemented in children classes. Should return either None if the step is not final.""" """To be implemented in children classes. Should return either None if the step is not final."""
@ -547,7 +532,7 @@ You have been provided with these additional arguments, that you can access usin
if final_answer is None and self.step_number == self.max_steps: if final_answer is None and self.step_number == self.max_steps:
error_message = "Reached max steps." error_message = "Reached max steps."
final_step_log = ActionStep(error=AgentMaxStepsError(error_message)) final_step_log = ActionStep(error=AgentMaxStepsError(error_message, self.logger))
self.logs.append(final_step_log) self.logs.append(final_step_log)
final_answer = self.provide_final_answer(task) final_answer = self.provide_final_answer(task)
self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO) self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO)
@ -715,7 +700,7 @@ class ToolCallingAgent(MultiStepAgent):
tool_arguments = tool_call.function.arguments tool_arguments = tool_call.function.arguments
except Exception as e: except Exception as e:
raise AgentGenerationError(f"Error in generating tool call with model:\n{e}") raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger)
log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)] log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
@ -796,7 +781,7 @@ class CodeAgent(MultiStepAgent):
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)) self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
if "{{authorized_imports}}" not in system_prompt: if "{{authorized_imports}}" not in system_prompt:
raise AgentError("Tag '{{authorized_imports}}' should be provided in the prompt.") raise ValueError("Tag '{{authorized_imports}}' should be provided in the prompt.")
super().__init__( super().__init__(
tools=tools, tools=tools,
model=model, model=model,
@ -861,7 +846,7 @@ class CodeAgent(MultiStepAgent):
).content ).content
log_entry.llm_output = llm_output log_entry.llm_output = llm_output
except Exception as e: except Exception as e:
raise AgentGenerationError(f"Error in generating model output:\n{e}") raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger)
self.logger.log( self.logger.log(
Group( Group(
@ -885,7 +870,7 @@ class CodeAgent(MultiStepAgent):
code_action = fix_final_answer_code(parse_code_blobs(llm_output)) code_action = fix_final_answer_code(parse_code_blobs(llm_output))
except Exception as e: except Exception as e:
error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs." error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
raise AgentParsingError(error_msg) raise AgentParsingError(error_msg, self.logger)
log_entry.tool_calls = [ log_entry.tool_calls = [
ToolCall( ToolCall(
@ -931,7 +916,7 @@ class CodeAgent(MultiStepAgent):
"[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.", "[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.",
level=LogLevel.INFO, level=LogLevel.INFO,
) )
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg, self.logger)
truncated_output = truncate_content(str(output)) truncated_output = truncate_content(str(output))
observation += "Last output from code snippet:\n" + truncated_output observation += "Last output from code snippet:\n" + truncated_output

View File

@ -21,6 +21,7 @@ import inspect
import json import json
import re import re
import types import types
from enum import IntEnum
from functools import lru_cache from functools import lru_cache
from typing import Dict, Tuple, Union from typing import Dict, Tuple, Union
@ -58,13 +59,29 @@ BASE_BUILTIN_MODULES = [
] ]
class LogLevel(IntEnum):
ERROR = 0 # Only errors
INFO = 1 # Normal output (default)
DEBUG = 2 # Detailed output
class AgentLogger:
def __init__(self, level: LogLevel = LogLevel.INFO):
self.level = level
self.console = Console()
def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs):
if level <= self.level:
self.console.print(*args, **kwargs)
class AgentError(Exception): class AgentError(Exception):
"""Base class for other agent-related exceptions""" """Base class for other agent-related exceptions"""
def __init__(self, message): def __init__(self, message, logger: AgentLogger):
super().__init__(message) super().__init__(message)
self.message = message self.message = message
console.print(f"[bold red]{message}[/bold red]") logger.log(f"[bold red]{message}[/bold red]", level=LogLevel.ERROR)
class AgentParsingError(AgentError): class AgentParsingError(AgentError):

View File

@ -421,9 +421,7 @@ class AgentTests(unittest.TestCase):
def test_code_agent_missing_import_triggers_advice_in_error_log(self): def test_code_agent_missing_import_triggers_advice_in_error_log(self):
agent = CodeAgent(tools=[], model=fake_code_model_import) agent = CodeAgent(tools=[], model=fake_code_model_import)
from smolagents.agents import console with agent.logger.console.capture() as capture:
with console.capture() as capture:
agent.run("Count to 3") agent.run("Count to 3")
str_output = capture.get() str_output = capture.get()
assert "Consider passing said import under" in str_output.replace("\n", "") assert "Consider passing said import under" in str_output.replace("\n", "")

View File

@ -27,6 +27,7 @@ from smolagents.models import (
ChatMessageToolCall, ChatMessageToolCall,
ChatMessageToolCallDefinition, ChatMessageToolCallDefinition,
) )
from smolagents.utils import AgentLogger, LogLevel
class FakeLLMModel: class FakeLLMModel:
@ -162,8 +163,10 @@ class MonitoringTester(unittest.TestCase):
self.assertEqual(final_message.content["mime_type"], "image/png") self.assertEqual(final_message.content["mime_type"], "image/png")
def test_streaming_with_agent_error(self): def test_streaming_with_agent_error(self):
logger = AgentLogger(level=LogLevel.INFO)
def dummy_model(prompt, **kwargs): def dummy_model(prompt, **kwargs):
raise AgentError("Simulated agent error") raise AgentError("Simulated agent error", logger)
agent = CodeAgent( agent = CodeAgent(
tools=[], tools=[],