minor fix for console in AgentLogger (#303)

* minor fix for console in AgentLogger
This commit is contained in:
Nicholas Broad 2025-01-22 01:41:05 -08:00 committed by GitHub
parent 2c43546d3c
commit ec45d6766a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 39 additions and 36 deletions

View File

@ -17,11 +17,10 @@
import time
from collections import deque
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from rich import box
from rich.console import Console, Group
from rich.console import Group
from rich.panel import Panel
from rich.rule import Rule
from rich.syntax import Syntax
@ -62,9 +61,10 @@ from .utils import (
AgentError,
AgentExecutionError,
AgentGenerationError,
AgentLogger,
AgentMaxStepsError,
AgentParsingError,
console,
LogLevel,
parse_code_blobs,
parse_json_tool_call,
truncate_content,
@ -158,22 +158,6 @@ def format_prompt_with_managed_agents_descriptions(
YELLOW_HEX = "#d4b702"
class LogLevel(IntEnum):
ERROR = 0 # Only errors
INFO = 1 # Normal output (default)
DEBUG = 2 # Detailed output
class AgentLogger:
def __init__(self, level: LogLevel = LogLevel.INFO):
self.level = level
self.console = Console()
def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs):
if level <= self.level:
console.print(*args, **kwargs)
class MultiStepAgent:
"""
Agent class that solves the given task step by step, using the ReAct framework:
@ -353,7 +337,8 @@ class MultiStepAgent:
) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
except Exception:
raise AgentParsingError(
f"No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
f"No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!",
self.logger,
)
return rationale.strip(), action.strip()
@ -391,7 +376,7 @@ class MultiStepAgent:
available_tools = {**self.tools, **self.managed_agents}
if tool_name not in available_tools:
error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
raise AgentExecutionError(error_msg)
raise AgentExecutionError(error_msg, self.logger)
try:
if isinstance(arguments, str):
@ -409,7 +394,7 @@ class MultiStepAgent:
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
else:
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
raise AgentExecutionError(error_msg)
raise AgentExecutionError(error_msg, self.logger, self.logger)
return observation
except Exception as e:
if tool_name in self.tools:
@ -418,13 +403,13 @@ class MultiStepAgent:
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
f"As a reminder, this tool's description is the following:\n{tool_description}"
)
raise AgentExecutionError(error_msg)
raise AgentExecutionError(error_msg, self.logger)
elif tool_name in self.managed_agents:
error_msg = (
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
)
raise AgentExecutionError(error_msg)
raise AgentExecutionError(error_msg, self.logger)
def step(self, log_entry: ActionStep) -> Union[None, Any]:
"""To be implemented in children classes. Should return either None if the step is not final."""
@ -547,7 +532,7 @@ You have been provided with these additional arguments, that you can access usin
if final_answer is None and self.step_number == self.max_steps:
error_message = "Reached max steps."
final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
final_step_log = ActionStep(error=AgentMaxStepsError(error_message, self.logger))
self.logs.append(final_step_log)
final_answer = self.provide_final_answer(task)
self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO)
@ -715,7 +700,7 @@ class ToolCallingAgent(MultiStepAgent):
tool_arguments = tool_call.function.arguments
except Exception as e:
raise AgentGenerationError(f"Error in generating tool call with model:\n{e}")
raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger)
log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
@ -796,7 +781,7 @@ class CodeAgent(MultiStepAgent):
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
if "{{authorized_imports}}" not in system_prompt:
raise AgentError("Tag '{{authorized_imports}}' should be provided in the prompt.")
raise ValueError("Tag '{{authorized_imports}}' should be provided in the prompt.")
super().__init__(
tools=tools,
model=model,
@ -861,7 +846,7 @@ class CodeAgent(MultiStepAgent):
).content
log_entry.llm_output = llm_output
except Exception as e:
raise AgentGenerationError(f"Error in generating model output:\n{e}")
raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger)
self.logger.log(
Group(
@ -885,7 +870,7 @@ class CodeAgent(MultiStepAgent):
code_action = fix_final_answer_code(parse_code_blobs(llm_output))
except Exception as e:
error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
raise AgentParsingError(error_msg)
raise AgentParsingError(error_msg, self.logger)
log_entry.tool_calls = [
ToolCall(
@ -931,7 +916,7 @@ class CodeAgent(MultiStepAgent):
"[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.",
level=LogLevel.INFO,
)
raise AgentExecutionError(error_msg)
raise AgentExecutionError(error_msg, self.logger)
truncated_output = truncate_content(str(output))
observation += "Last output from code snippet:\n" + truncated_output

View File

@ -21,6 +21,7 @@ import inspect
import json
import re
import types
from enum import IntEnum
from functools import lru_cache
from typing import Dict, Tuple, Union
@ -58,13 +59,29 @@ BASE_BUILTIN_MODULES = [
]
class LogLevel(IntEnum):
ERROR = 0 # Only errors
INFO = 1 # Normal output (default)
DEBUG = 2 # Detailed output
class AgentLogger:
def __init__(self, level: LogLevel = LogLevel.INFO):
self.level = level
self.console = Console()
def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs):
if level <= self.level:
self.console.print(*args, **kwargs)
class AgentError(Exception):
"""Base class for other agent-related exceptions"""
def __init__(self, message):
def __init__(self, message, logger: AgentLogger):
super().__init__(message)
self.message = message
console.print(f"[bold red]{message}[/bold red]")
logger.log(f"[bold red]{message}[/bold red]", level=LogLevel.ERROR)
class AgentParsingError(AgentError):

View File

@ -421,9 +421,7 @@ class AgentTests(unittest.TestCase):
def test_code_agent_missing_import_triggers_advice_in_error_log(self):
agent = CodeAgent(tools=[], model=fake_code_model_import)
from smolagents.agents import console
with console.capture() as capture:
with agent.logger.console.capture() as capture:
agent.run("Count to 3")
str_output = capture.get()
assert "Consider passing said import under" in str_output.replace("\n", "")

View File

@ -27,6 +27,7 @@ from smolagents.models import (
ChatMessageToolCall,
ChatMessageToolCallDefinition,
)
from smolagents.utils import AgentLogger, LogLevel
class FakeLLMModel:
@ -162,8 +163,10 @@ class MonitoringTester(unittest.TestCase):
self.assertEqual(final_message.content["mime_type"], "image/png")
def test_streaming_with_agent_error(self):
logger = AgentLogger(level=LogLevel.INFO)
def dummy_model(prompt, **kwargs):
raise AgentError("Simulated agent error")
raise AgentError("Simulated agent error", logger)
agent = CodeAgent(
tools=[],