From ec45d6766ace781ea306979e19e81f2b28f9cc1f Mon Sep 17 00:00:00 2001 From: Nicholas Broad Date: Wed, 22 Jan 2025 01:41:05 -0800 Subject: [PATCH] minor fix for console in AgentLogger (#303) * minor fix for console in AgentLogger --- src/smolagents/agents.py | 45 ++++++++++++++-------------------------- src/smolagents/utils.py | 21 +++++++++++++++++-- tests/test_agents.py | 4 +--- tests/test_monitoring.py | 5 ++++- 4 files changed, 39 insertions(+), 36 deletions(-) diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 8c13784..d3d3e47 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -17,11 +17,10 @@ import time from collections import deque from dataclasses import dataclass -from enum import IntEnum from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union from rich import box -from rich.console import Console, Group +from rich.console import Group from rich.panel import Panel from rich.rule import Rule from rich.syntax import Syntax @@ -62,9 +61,10 @@ from .utils import ( AgentError, AgentExecutionError, AgentGenerationError, + AgentLogger, AgentMaxStepsError, AgentParsingError, - console, + LogLevel, parse_code_blobs, parse_json_tool_call, truncate_content, @@ -158,22 +158,6 @@ def format_prompt_with_managed_agents_descriptions( YELLOW_HEX = "#d4b702" -class LogLevel(IntEnum): - ERROR = 0 # Only errors - INFO = 1 # Normal output (default) - DEBUG = 2 # Detailed output - - -class AgentLogger: - def __init__(self, level: LogLevel = LogLevel.INFO): - self.level = level - self.console = Console() - - def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs): - if level <= self.level: - console.print(*args, **kwargs) - - class MultiStepAgent: """ Agent class that solves the given task step by step, using the ReAct framework: @@ -353,7 +337,8 @@ class MultiStepAgent: ) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output except Exception: raise AgentParsingError( - f"No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!" + f"No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!", + self.logger, ) return rationale.strip(), action.strip() @@ -391,7 +376,7 @@ class MultiStepAgent: available_tools = {**self.tools, **self.managed_agents} if tool_name not in available_tools: error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." - raise AgentExecutionError(error_msg) + raise AgentExecutionError(error_msg, self.logger) try: if isinstance(arguments, str): @@ -409,7 +394,7 @@ class MultiStepAgent: observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True) else: error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." - raise AgentExecutionError(error_msg) + raise AgentExecutionError(error_msg, self.logger, self.logger) return observation except Exception as e: if tool_name in self.tools: @@ -418,13 +403,13 @@ class MultiStepAgent: f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" f"As a reminder, this tool's description is the following:\n{tool_description}" ) - raise AgentExecutionError(error_msg) + raise AgentExecutionError(error_msg, self.logger) elif tool_name in self.managed_agents: error_msg = ( f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" ) - raise AgentExecutionError(error_msg) + raise AgentExecutionError(error_msg, self.logger) def step(self, log_entry: ActionStep) -> Union[None, Any]: """To be implemented in children classes. Should return either None if the step is not final.""" @@ -547,7 +532,7 @@ You have been provided with these additional arguments, that you can access usin if final_answer is None and self.step_number == self.max_steps: error_message = "Reached max steps." - final_step_log = ActionStep(error=AgentMaxStepsError(error_message)) + final_step_log = ActionStep(error=AgentMaxStepsError(error_message, self.logger)) self.logs.append(final_step_log) final_answer = self.provide_final_answer(task) self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO) @@ -715,7 +700,7 @@ class ToolCallingAgent(MultiStepAgent): tool_arguments = tool_call.function.arguments except Exception as e: - raise AgentGenerationError(f"Error in generating tool call with model:\n{e}") + raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger) log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)] @@ -796,7 +781,7 @@ class CodeAgent(MultiStepAgent): self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)) if "{{authorized_imports}}" not in system_prompt: - raise AgentError("Tag '{{authorized_imports}}' should be provided in the prompt.") + raise ValueError("Tag '{{authorized_imports}}' should be provided in the prompt.") super().__init__( tools=tools, model=model, @@ -861,7 +846,7 @@ class CodeAgent(MultiStepAgent): ).content log_entry.llm_output = llm_output except Exception as e: - raise AgentGenerationError(f"Error in generating model output:\n{e}") + raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger) self.logger.log( Group( @@ -885,7 +870,7 @@ class CodeAgent(MultiStepAgent): code_action = fix_final_answer_code(parse_code_blobs(llm_output)) except Exception as e: error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs." - raise AgentParsingError(error_msg) + raise AgentParsingError(error_msg, self.logger) log_entry.tool_calls = [ ToolCall( @@ -931,7 +916,7 @@ class CodeAgent(MultiStepAgent): "[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.", level=LogLevel.INFO, ) - raise AgentExecutionError(error_msg) + raise AgentExecutionError(error_msg, self.logger) truncated_output = truncate_content(str(output)) observation += "Last output from code snippet:\n" + truncated_output diff --git a/src/smolagents/utils.py b/src/smolagents/utils.py index 30bcf61..ad80b9d 100644 --- a/src/smolagents/utils.py +++ b/src/smolagents/utils.py @@ -21,6 +21,7 @@ import inspect import json import re import types +from enum import IntEnum from functools import lru_cache from typing import Dict, Tuple, Union @@ -58,13 +59,29 @@ BASE_BUILTIN_MODULES = [ ] +class LogLevel(IntEnum): + ERROR = 0 # Only errors + INFO = 1 # Normal output (default) + DEBUG = 2 # Detailed output + + +class AgentLogger: + def __init__(self, level: LogLevel = LogLevel.INFO): + self.level = level + self.console = Console() + + def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs): + if level <= self.level: + self.console.print(*args, **kwargs) + + class AgentError(Exception): """Base class for other agent-related exceptions""" - def __init__(self, message): + def __init__(self, message, logger: AgentLogger): super().__init__(message) self.message = message - console.print(f"[bold red]{message}[/bold red]") + logger.log(f"[bold red]{message}[/bold red]", level=LogLevel.ERROR) class AgentParsingError(AgentError): diff --git a/tests/test_agents.py b/tests/test_agents.py index 7bc2704..e76d35c 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -421,9 +421,7 @@ class AgentTests(unittest.TestCase): def test_code_agent_missing_import_triggers_advice_in_error_log(self): agent = CodeAgent(tools=[], model=fake_code_model_import) - from smolagents.agents import console - - with console.capture() as capture: + with agent.logger.console.capture() as capture: agent.run("Count to 3") str_output = capture.get() assert "Consider passing said import under" in str_output.replace("\n", "") diff --git a/tests/test_monitoring.py b/tests/test_monitoring.py index cfb566f..9fa30bb 100644 --- a/tests/test_monitoring.py +++ b/tests/test_monitoring.py @@ -27,6 +27,7 @@ from smolagents.models import ( ChatMessageToolCall, ChatMessageToolCallDefinition, ) +from smolagents.utils import AgentLogger, LogLevel class FakeLLMModel: @@ -162,8 +163,10 @@ class MonitoringTester(unittest.TestCase): self.assertEqual(final_message.content["mime_type"], "image/png") def test_streaming_with_agent_error(self): + logger = AgentLogger(level=LogLevel.INFO) + def dummy_model(prompt, **kwargs): - raise AgentError("Simulated agent error") + raise AgentError("Simulated agent error", logger) agent = CodeAgent( tools=[],