minor fix for console in AgentLogger (#303)
* minor fix for console in AgentLogger
This commit is contained in:
parent
2c43546d3c
commit
ec45d6766a
|
@ -17,11 +17,10 @@
|
||||||
import time
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum
|
|
||||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from rich import box
|
from rich import box
|
||||||
from rich.console import Console, Group
|
from rich.console import Group
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.rule import Rule
|
from rich.rule import Rule
|
||||||
from rich.syntax import Syntax
|
from rich.syntax import Syntax
|
||||||
|
@ -62,9 +61,10 @@ from .utils import (
|
||||||
AgentError,
|
AgentError,
|
||||||
AgentExecutionError,
|
AgentExecutionError,
|
||||||
AgentGenerationError,
|
AgentGenerationError,
|
||||||
|
AgentLogger,
|
||||||
AgentMaxStepsError,
|
AgentMaxStepsError,
|
||||||
AgentParsingError,
|
AgentParsingError,
|
||||||
console,
|
LogLevel,
|
||||||
parse_code_blobs,
|
parse_code_blobs,
|
||||||
parse_json_tool_call,
|
parse_json_tool_call,
|
||||||
truncate_content,
|
truncate_content,
|
||||||
|
@ -158,22 +158,6 @@ def format_prompt_with_managed_agents_descriptions(
|
||||||
YELLOW_HEX = "#d4b702"
|
YELLOW_HEX = "#d4b702"
|
||||||
|
|
||||||
|
|
||||||
class LogLevel(IntEnum):
|
|
||||||
ERROR = 0 # Only errors
|
|
||||||
INFO = 1 # Normal output (default)
|
|
||||||
DEBUG = 2 # Detailed output
|
|
||||||
|
|
||||||
|
|
||||||
class AgentLogger:
|
|
||||||
def __init__(self, level: LogLevel = LogLevel.INFO):
|
|
||||||
self.level = level
|
|
||||||
self.console = Console()
|
|
||||||
|
|
||||||
def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs):
|
|
||||||
if level <= self.level:
|
|
||||||
console.print(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class MultiStepAgent:
|
class MultiStepAgent:
|
||||||
"""
|
"""
|
||||||
Agent class that solves the given task step by step, using the ReAct framework:
|
Agent class that solves the given task step by step, using the ReAct framework:
|
||||||
|
@ -353,7 +337,8 @@ class MultiStepAgent:
|
||||||
) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
|
) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
|
||||||
except Exception:
|
except Exception:
|
||||||
raise AgentParsingError(
|
raise AgentParsingError(
|
||||||
f"No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
|
f"No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!",
|
||||||
|
self.logger,
|
||||||
)
|
)
|
||||||
return rationale.strip(), action.strip()
|
return rationale.strip(), action.strip()
|
||||||
|
|
||||||
|
@ -391,7 +376,7 @@ class MultiStepAgent:
|
||||||
available_tools = {**self.tools, **self.managed_agents}
|
available_tools = {**self.tools, **self.managed_agents}
|
||||||
if tool_name not in available_tools:
|
if tool_name not in available_tools:
|
||||||
error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg, self.logger)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if isinstance(arguments, str):
|
if isinstance(arguments, str):
|
||||||
|
@ -409,7 +394,7 @@ class MultiStepAgent:
|
||||||
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
|
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
|
||||||
else:
|
else:
|
||||||
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
|
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg, self.logger, self.logger)
|
||||||
return observation
|
return observation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if tool_name in self.tools:
|
if tool_name in self.tools:
|
||||||
|
@ -418,13 +403,13 @@ class MultiStepAgent:
|
||||||
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
|
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
|
||||||
f"As a reminder, this tool's description is the following:\n{tool_description}"
|
f"As a reminder, this tool's description is the following:\n{tool_description}"
|
||||||
)
|
)
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg, self.logger)
|
||||||
elif tool_name in self.managed_agents:
|
elif tool_name in self.managed_agents:
|
||||||
error_msg = (
|
error_msg = (
|
||||||
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
|
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
|
||||||
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
|
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
|
||||||
)
|
)
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg, self.logger)
|
||||||
|
|
||||||
def step(self, log_entry: ActionStep) -> Union[None, Any]:
|
def step(self, log_entry: ActionStep) -> Union[None, Any]:
|
||||||
"""To be implemented in children classes. Should return either None if the step is not final."""
|
"""To be implemented in children classes. Should return either None if the step is not final."""
|
||||||
|
@ -547,7 +532,7 @@ You have been provided with these additional arguments, that you can access usin
|
||||||
|
|
||||||
if final_answer is None and self.step_number == self.max_steps:
|
if final_answer is None and self.step_number == self.max_steps:
|
||||||
error_message = "Reached max steps."
|
error_message = "Reached max steps."
|
||||||
final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
|
final_step_log = ActionStep(error=AgentMaxStepsError(error_message, self.logger))
|
||||||
self.logs.append(final_step_log)
|
self.logs.append(final_step_log)
|
||||||
final_answer = self.provide_final_answer(task)
|
final_answer = self.provide_final_answer(task)
|
||||||
self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO)
|
self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO)
|
||||||
|
@ -715,7 +700,7 @@ class ToolCallingAgent(MultiStepAgent):
|
||||||
tool_arguments = tool_call.function.arguments
|
tool_arguments = tool_call.function.arguments
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentGenerationError(f"Error in generating tool call with model:\n{e}")
|
raise AgentGenerationError(f"Error in generating tool call with model:\n{e}", self.logger)
|
||||||
|
|
||||||
log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
|
log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
|
||||||
|
|
||||||
|
@ -796,7 +781,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
||||||
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
|
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
|
||||||
if "{{authorized_imports}}" not in system_prompt:
|
if "{{authorized_imports}}" not in system_prompt:
|
||||||
raise AgentError("Tag '{{authorized_imports}}' should be provided in the prompt.")
|
raise ValueError("Tag '{{authorized_imports}}' should be provided in the prompt.")
|
||||||
super().__init__(
|
super().__init__(
|
||||||
tools=tools,
|
tools=tools,
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -861,7 +846,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
).content
|
).content
|
||||||
log_entry.llm_output = llm_output
|
log_entry.llm_output = llm_output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentGenerationError(f"Error in generating model output:\n{e}")
|
raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger)
|
||||||
|
|
||||||
self.logger.log(
|
self.logger.log(
|
||||||
Group(
|
Group(
|
||||||
|
@ -885,7 +870,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
code_action = fix_final_answer_code(parse_code_blobs(llm_output))
|
code_action = fix_final_answer_code(parse_code_blobs(llm_output))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
|
error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
|
||||||
raise AgentParsingError(error_msg)
|
raise AgentParsingError(error_msg, self.logger)
|
||||||
|
|
||||||
log_entry.tool_calls = [
|
log_entry.tool_calls = [
|
||||||
ToolCall(
|
ToolCall(
|
||||||
|
@ -931,7 +916,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
"[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.",
|
"[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.",
|
||||||
level=LogLevel.INFO,
|
level=LogLevel.INFO,
|
||||||
)
|
)
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg, self.logger)
|
||||||
|
|
||||||
truncated_output = truncate_content(str(output))
|
truncated_output = truncate_content(str(output))
|
||||||
observation += "Last output from code snippet:\n" + truncated_output
|
observation += "Last output from code snippet:\n" + truncated_output
|
||||||
|
|
|
@ -21,6 +21,7 @@ import inspect
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import types
|
import types
|
||||||
|
from enum import IntEnum
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Dict, Tuple, Union
|
from typing import Dict, Tuple, Union
|
||||||
|
|
||||||
|
@ -58,13 +59,29 @@ BASE_BUILTIN_MODULES = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class LogLevel(IntEnum):
|
||||||
|
ERROR = 0 # Only errors
|
||||||
|
INFO = 1 # Normal output (default)
|
||||||
|
DEBUG = 2 # Detailed output
|
||||||
|
|
||||||
|
|
||||||
|
class AgentLogger:
|
||||||
|
def __init__(self, level: LogLevel = LogLevel.INFO):
|
||||||
|
self.level = level
|
||||||
|
self.console = Console()
|
||||||
|
|
||||||
|
def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs):
|
||||||
|
if level <= self.level:
|
||||||
|
self.console.print(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class AgentError(Exception):
|
class AgentError(Exception):
|
||||||
"""Base class for other agent-related exceptions"""
|
"""Base class for other agent-related exceptions"""
|
||||||
|
|
||||||
def __init__(self, message):
|
def __init__(self, message, logger: AgentLogger):
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
self.message = message
|
self.message = message
|
||||||
console.print(f"[bold red]{message}[/bold red]")
|
logger.log(f"[bold red]{message}[/bold red]", level=LogLevel.ERROR)
|
||||||
|
|
||||||
|
|
||||||
class AgentParsingError(AgentError):
|
class AgentParsingError(AgentError):
|
||||||
|
|
|
@ -421,9 +421,7 @@ class AgentTests(unittest.TestCase):
|
||||||
def test_code_agent_missing_import_triggers_advice_in_error_log(self):
|
def test_code_agent_missing_import_triggers_advice_in_error_log(self):
|
||||||
agent = CodeAgent(tools=[], model=fake_code_model_import)
|
agent = CodeAgent(tools=[], model=fake_code_model_import)
|
||||||
|
|
||||||
from smolagents.agents import console
|
with agent.logger.console.capture() as capture:
|
||||||
|
|
||||||
with console.capture() as capture:
|
|
||||||
agent.run("Count to 3")
|
agent.run("Count to 3")
|
||||||
str_output = capture.get()
|
str_output = capture.get()
|
||||||
assert "Consider passing said import under" in str_output.replace("\n", "")
|
assert "Consider passing said import under" in str_output.replace("\n", "")
|
||||||
|
|
|
@ -27,6 +27,7 @@ from smolagents.models import (
|
||||||
ChatMessageToolCall,
|
ChatMessageToolCall,
|
||||||
ChatMessageToolCallDefinition,
|
ChatMessageToolCallDefinition,
|
||||||
)
|
)
|
||||||
|
from smolagents.utils import AgentLogger, LogLevel
|
||||||
|
|
||||||
|
|
||||||
class FakeLLMModel:
|
class FakeLLMModel:
|
||||||
|
@ -162,8 +163,10 @@ class MonitoringTester(unittest.TestCase):
|
||||||
self.assertEqual(final_message.content["mime_type"], "image/png")
|
self.assertEqual(final_message.content["mime_type"], "image/png")
|
||||||
|
|
||||||
def test_streaming_with_agent_error(self):
|
def test_streaming_with_agent_error(self):
|
||||||
|
logger = AgentLogger(level=LogLevel.INFO)
|
||||||
|
|
||||||
def dummy_model(prompt, **kwargs):
|
def dummy_model(prompt, **kwargs):
|
||||||
raise AgentError("Simulated agent error")
|
raise AgentError("Simulated agent error", logger)
|
||||||
|
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
|
|
Loading…
Reference in New Issue