minor fix for console in AgentLogger (#303)
* minor fix for console in AgentLogger
This commit is contained in:
parent
2c43546d3c
commit
ec45d6766a
|
@ -17,11 +17,10 @@
|
|||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
||||
|
||||
from rich import box
|
||||
from rich.console import Console, Group
|
||||
from rich.console import Group
|
||||
from rich.panel import Panel
|
||||
from rich.rule import Rule
|
||||
from rich.syntax import Syntax
|
||||
|
@ -62,9 +61,10 @@ from .utils import (
|
|||
AgentError,
|
||||
AgentExecutionError,
|
||||
AgentGenerationError,
|
||||
AgentLogger,
|
||||
AgentMaxStepsError,
|
||||
AgentParsingError,
|
||||
console,
|
||||
LogLevel,
|
||||
parse_code_blobs,
|
||||
parse_json_tool_call,
|
||||
truncate_content,
|
||||
|
@ -158,22 +158,6 @@ def format_prompt_with_managed_agents_descriptions(
|
|||
YELLOW_HEX = "#d4b702"
|
||||
|
||||
|
||||
class LogLevel(IntEnum):
|
||||
ERROR = 0 # Only errors
|
||||
INFO = 1 # Normal output (default)
|
||||
DEBUG = 2 # Detailed output
|
||||
|
||||
|
||||
class AgentLogger:
|
||||
def __init__(self, level: LogLevel = LogLevel.INFO):
|
||||
self.level = level
|
||||
self.console = Console()
|
||||
|
||||
def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs):
|
||||
if level <= self.level:
|
||||
console.print(*args, **kwargs)
|
||||
|
||||
|
||||
class MultiStepAgent:
|
||||
"""
|
||||
Agent class that solves the given task step by step, using the ReAct framework:
|
||||
|
@ -353,7 +337,8 @@ class MultiStepAgent:
|
|||
) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
|
||||
except Exception:
|
||||
raise AgentParsingError(
|
||||
f"No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
|
||||
f"No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!",
|
||||
self.logger,
|
||||
)
|
||||
return rationale.strip(), action.strip()
|
||||
|
||||
|
@ -391,7 +376,7 @@ class MultiStepAgent:
|
|||
available_tools = {**self.tools, **self.managed_agents}
|
||||
if tool_name not in available_tools:
|
||||
error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
||||
raise AgentExecutionError(error_msg)
|
||||
raise AgentExecutionError(error_msg, self.logger)
|
||||
|
||||
try:
|
||||
if isinstance(arguments, str):
|
||||
|
@ -409,7 +394,7 @@ class MultiStepAgent:
|
|||
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
|
||||
else:
|
||||
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
|
||||
raise AgentExecutionError(error_msg)
|
||||
raise AgentExecutionError(error_msg, self.logger, self.logger)
|
||||
return observation
|
||||
except Exception as e:
|
||||
if tool_name in self.tools:
|
||||
|
@ -418,13 +403,13 @@ class MultiStepAgent:
|
|||
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
|
||||
f"As a reminder, this tool's description is the following:\n{tool_description}"
|
||||
)
|
||||
raise AgentExecutionError(error_msg)
|
||||
raise AgentExecutionError(error_msg, self.logger)
|
||||
elif tool_name in self.managed_agents:
|
||||
error_msg = (
|
||||
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
|
||||
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
|
||||
)
|
||||
raise AgentExecutionError(error_msg)
|
||||
raise AgentExecutionError(error_msg, self.logger)
|
||||
|
||||
def step(self, log_entry: ActionStep) -> Union[None, Any]:
|
||||
"""To be implemented in children classes. Should return either None if the step is not final."""
|
||||
|
@ -547,7 +532,7 @@ You have been provided with these additional arguments, that you can access usin
|
|||
|
||||
if final_answer is None and self.step_number == self.max_steps:
|
||||
error_message = "Reached max steps."
|
||||
final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
|
||||
final_step_log = ActionStep(error=AgentMaxStepsError(error_message, self.logger))
|
||||
self.logs.append(final_step_log)
|
||||
final_answer = self.provide_final_answer(task)
|
||||
self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO)
|
||||
|
@ -715,7 +700,7 @@ class ToolCallingAgent(MultiStepAgent):
|
|||
tool_arguments = tool_call.function.arguments
|
||||
|
||||
except Exception as e:
|
||||
raise AgentGenerationError(f"Error in generating tool call with model:\n{e}")
|
||||
raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger)
|
||||
|
||||
log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
|
||||
|
||||
|
@ -796,7 +781,7 @@ class CodeAgent(MultiStepAgent):
|
|||
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
||||
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
|
||||
if "{{authorized_imports}}" not in system_prompt:
|
||||
raise AgentError("Tag '{{authorized_imports}}' should be provided in the prompt.")
|
||||
raise ValueError("Tag '{{authorized_imports}}' should be provided in the prompt.")
|
||||
super().__init__(
|
||||
tools=tools,
|
||||
model=model,
|
||||
|
@ -861,7 +846,7 @@ class CodeAgent(MultiStepAgent):
|
|||
).content
|
||||
log_entry.llm_output = llm_output
|
||||
except Exception as e:
|
||||
raise AgentGenerationError(f"Error in generating model output:\n{e}")
|
||||
raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger)
|
||||
|
||||
self.logger.log(
|
||||
Group(
|
||||
|
@ -885,7 +870,7 @@ class CodeAgent(MultiStepAgent):
|
|||
code_action = fix_final_answer_code(parse_code_blobs(llm_output))
|
||||
except Exception as e:
|
||||
error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
|
||||
raise AgentParsingError(error_msg)
|
||||
raise AgentParsingError(error_msg, self.logger)
|
||||
|
||||
log_entry.tool_calls = [
|
||||
ToolCall(
|
||||
|
@ -931,7 +916,7 @@ class CodeAgent(MultiStepAgent):
|
|||
"[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.",
|
||||
level=LogLevel.INFO,
|
||||
)
|
||||
raise AgentExecutionError(error_msg)
|
||||
raise AgentExecutionError(error_msg, self.logger)
|
||||
|
||||
truncated_output = truncate_content(str(output))
|
||||
observation += "Last output from code snippet:\n" + truncated_output
|
||||
|
|
|
@ -21,6 +21,7 @@ import inspect
|
|||
import json
|
||||
import re
|
||||
import types
|
||||
from enum import IntEnum
|
||||
from functools import lru_cache
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
|
@ -58,13 +59,29 @@ BASE_BUILTIN_MODULES = [
|
|||
]
|
||||
|
||||
|
||||
class LogLevel(IntEnum):
|
||||
ERROR = 0 # Only errors
|
||||
INFO = 1 # Normal output (default)
|
||||
DEBUG = 2 # Detailed output
|
||||
|
||||
|
||||
class AgentLogger:
|
||||
def __init__(self, level: LogLevel = LogLevel.INFO):
|
||||
self.level = level
|
||||
self.console = Console()
|
||||
|
||||
def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs):
|
||||
if level <= self.level:
|
||||
self.console.print(*args, **kwargs)
|
||||
|
||||
|
||||
class AgentError(Exception):
|
||||
"""Base class for other agent-related exceptions"""
|
||||
|
||||
def __init__(self, message):
|
||||
def __init__(self, message, logger: AgentLogger):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
console.print(f"[bold red]{message}[/bold red]")
|
||||
logger.log(f"[bold red]{message}[/bold red]", level=LogLevel.ERROR)
|
||||
|
||||
|
||||
class AgentParsingError(AgentError):
|
||||
|
|
|
@ -421,9 +421,7 @@ class AgentTests(unittest.TestCase):
|
|||
def test_code_agent_missing_import_triggers_advice_in_error_log(self):
|
||||
agent = CodeAgent(tools=[], model=fake_code_model_import)
|
||||
|
||||
from smolagents.agents import console
|
||||
|
||||
with console.capture() as capture:
|
||||
with agent.logger.console.capture() as capture:
|
||||
agent.run("Count to 3")
|
||||
str_output = capture.get()
|
||||
assert "Consider passing said import under" in str_output.replace("\n", "")
|
||||
|
|
|
@ -27,6 +27,7 @@ from smolagents.models import (
|
|||
ChatMessageToolCall,
|
||||
ChatMessageToolCallDefinition,
|
||||
)
|
||||
from smolagents.utils import AgentLogger, LogLevel
|
||||
|
||||
|
||||
class FakeLLMModel:
|
||||
|
@ -162,8 +163,10 @@ class MonitoringTester(unittest.TestCase):
|
|||
self.assertEqual(final_message.content["mime_type"], "image/png")
|
||||
|
||||
def test_streaming_with_agent_error(self):
|
||||
logger = AgentLogger(level=LogLevel.INFO)
|
||||
|
||||
def dummy_model(prompt, **kwargs):
|
||||
raise AgentError("Simulated agent error")
|
||||
raise AgentError("Simulated agent error", logger)
|
||||
|
||||
agent = CodeAgent(
|
||||
tools=[],
|
||||
|
|
Loading…
Reference in New Issue