diff --git a/src/smolagents/__init__.py b/src/smolagents/__init__.py index 1507c70..7cbf191 100644 --- a/src/smolagents/__init__.py +++ b/src/smolagents/__init__.py @@ -21,7 +21,6 @@ from .default_tools import * from .e2b_executor import * from .gradio_ui import * from .local_python_executor import * -from .logger import * from .memory import * from .models import * from .monitoring import * diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 2377590..6114601 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -20,18 +20,17 @@ from collections import deque from logging import getLogger from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union -from rich import box from rich.console import Group from rich.panel import Panel from rich.rule import Rule -from rich.syntax import Syntax from rich.text import Text -from smolagents.logger import ( +from smolagents.memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall +from smolagents.monitoring import ( + YELLOW_HEX, AgentLogger, LogLevel, ) -from smolagents.memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall from smolagents.types import AgentAudio, AgentImage, handle_agent_output_types from smolagents.utils import ( AgentError, @@ -123,9 +122,6 @@ def format_prompt_with_managed_agents_descriptions( return prompt_template.replace(agent_descriptions_placeholder, "") -YELLOW_HEX = "#d4b702" - - class MultiStepAgent: """ Agent class that solves the given task step by step, using the ReAct framework: @@ -399,14 +395,9 @@ You have been provided with these additional arguments, that you can access usin self.memory.reset() self.monitor.reset() - self.logger.log( - Panel( - f"\n[bold]{self.task.strip()}\n", - title="[bold]New run", - subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, 'model_id') else '')}", - border_style=YELLOW_HEX, - subtitle_align="left", - ), 
+ self.logger.log_task( + content=self.task.strip(), + subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, 'model_id') else '')}", level=LogLevel.INFO, ) @@ -451,14 +442,7 @@ You have been provided with these additional arguments, that you can access usin is_first_step=(self.step_number == 0), step=self.step_number, ) - self.logger.log( - Rule( - f"[bold]Step {self.step_number}", - characters="━", - style=YELLOW_HEX, - ), - level=LogLevel.INFO, - ) + self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO) # Run one step! final_answer = self.step(memory_step) @@ -520,8 +504,9 @@ You have been provided with these additional arguments, that you can access usin ``` Now begin!""", } + input_messages = [message_prompt_facts, message_prompt_task] - chat_message_facts: ChatMessage = self.model([message_prompt_facts, message_prompt_task]) + chat_message_facts: ChatMessage = self.model(input_messages) answer_facts = chat_message_facts.content message_system_prompt_plan = { @@ -553,6 +538,7 @@ Now begin!""", ```""".strip() self.memory.steps.append( PlanningStep( + model_input_messages=input_messages, plan=final_plan_redaction, facts=final_facts_redaction, model_output_message_plan=chat_message_plan, @@ -578,9 +564,8 @@ Now begin!""", "role": MessageRole.USER, "content": [{"type": "text", "text": USER_PROMPT_FACTS_UPDATE}], } - chat_message_facts: ChatMessage = self.model( - [facts_update_system_prompt] + memory_messages + [facts_update_message] - ) + input_messages = [facts_update_system_prompt] + memory_messages + [facts_update_message] + chat_message_facts: ChatMessage = self.model(input_messages) facts_update = chat_message_facts.content # Redact updated plan @@ -618,6 +603,7 @@ Now begin!""", ```""" self.memory.steps.append( PlanningStep( + model_input_messages=input_messages, plan=final_plan_redaction, facts=final_facts_redaction, model_output_message_plan=chat_message_plan, @@ -630,6 +616,15 @@ Now begin!""", 
level=LogLevel.INFO, ) + def replay(self, detailed: bool = False): + """Prints a pretty replay of the agent's steps. + + Args: + detailed (bool, optional): If True, also displays the memory at each step. Defaults to False. + Careful: will increase log length exponentially. Use only for debugging. + """ + self.memory.replay(self.logger, detailed=detailed) + class ToolCallingAgent(MultiStepAgent): """ @@ -851,20 +846,9 @@ class CodeAgent(MultiStepAgent): except Exception as e: raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger) from e - self.logger.log( - Group( - Rule( - "[italic]Output message of the LLM:", - align="left", - style="orange", - ), - Syntax( - model_output, - lexer="markdown", - theme="github-dark", - word_wrap=True, - ), - ), + self.logger.log_markdown( + content=model_output, + title="Output message of the LLM:", level=LogLevel.DEBUG, ) @@ -884,20 +868,7 @@ class CodeAgent(MultiStepAgent): ] # Execute - self.logger.log( - Panel( - Syntax( - code_action, - lexer="python", - theme="monokai", - word_wrap=True, - ), - title="[bold]Executing this code:", - title_align="left", - box=box.HORIZONTALS, - ), - level=LogLevel.INFO, - ) + self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO) observation = "" is_final_answer = False try: diff --git a/src/smolagents/logger.py b/src/smolagents/logger.py deleted file mode 100644 index c74f6b1..0000000 --- a/src/smolagents/logger.py +++ /dev/null @@ -1,85 +0,0 @@ -import json -from enum import IntEnum -from typing import TYPE_CHECKING - -from rich.console import Console -from rich.rule import Rule -from rich.syntax import Syntax - - -if TYPE_CHECKING: - from smolagents.memory import AgentMemory - - -class LogLevel(IntEnum): - ERROR = 0 # Only errors - INFO = 1 # Normal output (default) - DEBUG = 2 # Detailed output - - -class AgentLogger: - def __init__(self, level: LogLevel = LogLevel.INFO): - self.level = level - self.console = Console() - - 
def log(self, *args, level: str | LogLevel = LogLevel.INFO, **kwargs): - """Logs a message to the console. - - Args: - level (LogLevel, optional): Defaults to LogLevel.INFO. - """ - if isinstance(level, str): - level = LogLevel[level.upper()] - if level <= self.level: - self.console.print(*args, **kwargs) - - def replay(self, agent_memory: "AgentMemory", full: bool = False): - """Prints a pretty replay of the agent's steps. - - Args: - with_memory (bool, optional): If True, also displays the memory at each step. Defaults to False. - Careful: will increase log length exponentially. Use only for debugging. - """ - memory = [] - for step_log in agent_memory.steps: - memory.extend(step_log.to_messages(show_model_input_messages=full)) - - self.console.log("Replaying the agent's steps:") - ix = 0 - for step in memory: - role = step["role"].strip() - if ix > 0 and role == "system": - role == "memory" - theme = "default" - match role: - case "assistant": - theme = "monokai" - ix += 1 - case "system": - theme = "monokai" - case "tool-response": - theme = "github_dark" - - content = step["content"] - try: - content = eval(content) - except Exception: - content = [step["content"]] - - for substep_ix, item in enumerate(content): - self.console.log( - Rule( - f"{role.upper()}, STEP {ix}, SUBSTEP {substep_ix + 1}/{len(content)}", - align="center", - style="orange", - ), - Syntax( - json.dumps(item, indent=4) if isinstance(item, dict) else str(item), - lexer="json", - theme=theme, - word_wrap=True, - ), - ) - - -__all__ = ["AgentLogger"] diff --git a/src/smolagents/memory.py b/src/smolagents/memory.py index 4043a5c..d80bceb 100644 --- a/src/smolagents/memory.py +++ b/src/smolagents/memory.py @@ -3,10 +3,12 @@ from logging import getLogger from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, Union from smolagents.models import ChatMessage, MessageRole +from smolagents.monitoring import AgentLogger from smolagents.utils import AgentError, make_json_serializable if 
TYPE_CHECKING:
+    from smolagents.monitoring import AgentLogger
     from smolagents.models import ChatMessage
 
@@ -47,7 +49,7 @@ class MemoryStep:
 
 @dataclass
 class ActionStep(MemoryStep):
-    model_input_messages: List[Dict[str, str]] | None = None
+    model_input_messages: List[Message] | None = None
     tool_calls: List[ToolCall] | None = None
     start_time: float | None = None
     end_time: float | None = None
@@ -76,7 +78,7 @@ class ActionStep(MemoryStep):
             "action_output": make_json_serializable(self.action_output),
         }
 
-    def to_messages(self, summary_mode: bool = False, show_model_input_messages: bool = False) -> List[Dict[str, Any]]:
+    def to_messages(self, summary_mode: bool = False, show_model_input_messages: bool = False) -> List[Message]:
         messages = []
         if self.model_input_messages is not None and show_model_input_messages:
             messages.append(Message(role=MessageRole.SYSTEM, content=self.model_input_messages))
@@ -142,17 +144,26 @@ class ActionStep(MemoryStep):
 
 @dataclass
 class PlanningStep(MemoryStep):
+    model_input_messages: List[Message]
     model_output_message_facts: ChatMessage
     facts: str
-    model_output_message_facts: ChatMessage
+    model_output_message_plan: ChatMessage
     plan: str
 
-    def to_messages(self, summary_mode: bool, **kwargs) -> List[Dict[str, str]]:
+    def to_messages(self, summary_mode: bool, **kwargs) -> List[Message]:
         messages = []
-        messages.append(Message(role=MessageRole.ASSISTANT, content=f"[FACTS LIST]:\n{self.facts.strip()}"))
+        messages.append(
+            Message(
+                role=MessageRole.ASSISTANT, content=[{"type": "text", "text": f"[FACTS LIST]:\n{self.facts.strip()}"}]
+            )
+        )
         if not summary_mode:
-            messages.append(Message(role=MessageRole.ASSISTANT, content=f"[PLAN]:\n{self.plan.strip()}"))
+            messages.append(
+                Message(
+                    role=MessageRole.ASSISTANT, content=[{"type": "text", "text": f"[PLAN]:\n{self.plan.strip()}"}]
+                )
+            )
         return messages
 
 
@@ -161,7 +172,7 @@ class TaskStep(MemoryStep):
     task: str
     task_images: List[str] | None = None
 
-    def to_messages(self, summary_mode: bool, 
**kwargs) -> List[Dict[str, str]]: + def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Message]: content = [{"type": "text", "text": f"New task:\n{self.task}"}] if self.task_images: for image in self.task_images: @@ -174,7 +185,7 @@ class TaskStep(MemoryStep): class SystemPromptStep(MemoryStep): system_prompt: str - def to_messages(self, summary_mode: bool, **kwargs) -> List[Dict[str, str]]: + def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Message]: if summary_mode: return [] return [Message(role=MessageRole.SYSTEM, content=[{"type": "text", "text": self.system_prompt.strip()}])] @@ -196,5 +207,30 @@ class AgentMemory: def get_full_steps(self) -> list[dict]: return [step.dict() for step in self.steps] + def replay(self, logger: AgentLogger, detailed: bool = False): + """Prints a pretty replay of the agent's steps. + + Args: + logger (AgentLogger): The logger to print replay logs to. + detailed (bool, optional): If True, also displays the memory at each step. Defaults to False. + Careful: will increase log length exponentially. Use only for debugging. 
+ """ + logger.console.log("Replaying the agent's steps:") + for step in self.steps: + if isinstance(step, SystemPromptStep) and detailed: + logger.log_markdown(title="System prompt", content=step.system_prompt) + elif isinstance(step, TaskStep): + logger.log_task(step.task, "", 2) + elif isinstance(step, ActionStep): + logger.log_rule(f"Step {step.step_number}") + if detailed: + logger.log_messages(step.model_input_messages) + logger.log_markdown(title="Agent output:", content=step.model_output) + elif isinstance(step, PlanningStep): + logger.log_rule("Planning step") + if detailed: + logger.log_messages(step.model_input_messages) + logger.log_markdown(title="Agent output:", content=step.facts + "\n" + step.plan) + __all__ = ["AgentMemory"] diff --git a/src/smolagents/monitoring.py b/src/smolagents/monitoring.py index d56eef4..53c433f 100644 --- a/src/smolagents/monitoring.py +++ b/src/smolagents/monitoring.py @@ -14,6 +14,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json +from enum import IntEnum +from typing import List, Optional + +from rich import box +from rich.console import Console, Group +from rich.panel import Panel +from rich.rule import Rule +from rich.syntax import Syntax from rich.text import Text @@ -57,4 +66,101 @@ class Monitor: self.logger.log(Text(console_outputs, style="dim"), level=1) -__all__ = ["Monitor"] +class LogLevel(IntEnum): + ERROR = 0 # Only errors + INFO = 1 # Normal output (default) + DEBUG = 2 # Detailed output + + +YELLOW_HEX = "#d4b702" + + +class AgentLogger: + def __init__(self, level: LogLevel = LogLevel.INFO): + self.level = level + self.console = Console() + + def log(self, *args, level: str | LogLevel = LogLevel.INFO, **kwargs) -> None: + """Logs a message to the console. + + Args: + level (LogLevel, optional): Defaults to LogLevel.INFO. 
+ """ + if isinstance(level, str): + level = LogLevel[level.upper()] + if level <= self.level: + self.console.print(*args, **kwargs) + + def log_markdown(self, content: str, title: Optional[str] = None, level=LogLevel.INFO, style=YELLOW_HEX) -> None: + markdown_content = Syntax( + content, + lexer="markdown", + theme="github-dark", + word_wrap=True, + ) + if title: + self.log( + Group( + Rule( + "[bold italic]" + title, + align="left", + style=style, + ), + markdown_content, + ), + level=level, + ) + else: + self.log(markdown_content, level=level) + + def log_code(self, title: str, content: str, level: int = LogLevel.INFO) -> None: + self.log( + Panel( + Syntax( + content, + lexer="python", + theme="monokai", + word_wrap=True, + ), + title="[bold]" + title, + title_align="left", + box=box.HORIZONTALS, + ), + level=level, + ) + + def log_rule(self, title: str, level: int = LogLevel.INFO) -> None: + self.log( + Rule( + "[bold]" + title, + characters="━", + style=YELLOW_HEX, + ), + level=LogLevel.INFO, + ) + + def log_task(self, content: str, subtitle: str, level: int = LogLevel.INFO) -> None: + self.log( + Panel( + f"\n[bold]{content}\n", + title="[bold]New run", + subtitle=subtitle, + border_style=YELLOW_HEX, + subtitle_align="left", + ), + level=level, + ) + + def log_messages(self, messages: List) -> None: + messages_as_string = "\n".join([json.dumps(dict(message), indent=4) for message in messages]) + self.log( + Syntax( + messages_as_string, + lexer="markdown", + theme="github-dark", + word_wrap=True, + ) + ) + + +__all__ = ["AgentLogger", "Monitor"] diff --git a/tests/test_monitoring.py b/tests/test_monitoring.py index 1cf8d52..7483214 100644 --- a/tests/test_monitoring.py +++ b/tests/test_monitoring.py @@ -22,12 +22,12 @@ from smolagents import ( ToolCallingAgent, stream_to_gradio, ) -from smolagents.logger import AgentLogger, LogLevel from smolagents.models import ( ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition, ) +from smolagents.monitoring 
import AgentLogger, LogLevel class FakeLLMModel: