Fix regressions, refactor logging and improve replay function (#419)

* Refactor logging and improve replay function

* Fix test

* Remove unnecessary dict method

* Remove logger import

---------

Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
This commit is contained in:
Aymeric Roucher 2025-01-30 11:32:28 +01:00 committed by GitHub
parent c0abd2134e
commit 20383f0914
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 178 additions and 151 deletions

View File

@ -21,7 +21,6 @@ from .default_tools import *
from .e2b_executor import *
from .gradio_ui import *
from .local_python_executor import *
from .logger import *
from .memory import *
from .models import *
from .monitoring import *

View File

@ -20,18 +20,17 @@ from collections import deque
from logging import getLogger
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from rich import box
from rich.console import Group
from rich.panel import Panel
from rich.rule import Rule
from rich.syntax import Syntax
from rich.text import Text
from smolagents.logger import (
from smolagents.memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall
from smolagents.monitoring import (
YELLOW_HEX,
AgentLogger,
LogLevel,
)
from smolagents.memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall
from smolagents.types import AgentAudio, AgentImage, handle_agent_output_types
from smolagents.utils import (
AgentError,
@ -123,9 +122,6 @@ def format_prompt_with_managed_agents_descriptions(
return prompt_template.replace(agent_descriptions_placeholder, "")
YELLOW_HEX = "#d4b702"
class MultiStepAgent:
"""
Agent class that solves the given task step by step, using the ReAct framework:
@ -399,14 +395,9 @@ You have been provided with these additional arguments, that you can access usin
self.memory.reset()
self.monitor.reset()
self.logger.log(
Panel(
f"\n[bold]{self.task.strip()}\n",
title="[bold]New run",
subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, 'model_id') else '')}",
border_style=YELLOW_HEX,
subtitle_align="left",
),
self.logger.log_task(
content=self.task.strip(),
subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, 'model_id') else '')}",
level=LogLevel.INFO,
)
@ -451,14 +442,7 @@ You have been provided with these additional arguments, that you can access usin
is_first_step=(self.step_number == 0),
step=self.step_number,
)
self.logger.log(
Rule(
f"[bold]Step {self.step_number}",
characters="",
style=YELLOW_HEX,
),
level=LogLevel.INFO,
)
self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
# Run one step!
final_answer = self.step(memory_step)
@ -520,8 +504,9 @@ You have been provided with these additional arguments, that you can access usin
```
Now begin!""",
}
input_messages = [message_prompt_facts, message_prompt_task]
chat_message_facts: ChatMessage = self.model([message_prompt_facts, message_prompt_task])
chat_message_facts: ChatMessage = self.model(input_messages)
answer_facts = chat_message_facts.content
message_system_prompt_plan = {
@ -553,6 +538,7 @@ Now begin!""",
```""".strip()
self.memory.steps.append(
PlanningStep(
model_input_messages=input_messages,
plan=final_plan_redaction,
facts=final_facts_redaction,
model_output_message_plan=chat_message_plan,
@ -578,9 +564,8 @@ Now begin!""",
"role": MessageRole.USER,
"content": [{"type": "text", "text": USER_PROMPT_FACTS_UPDATE}],
}
chat_message_facts: ChatMessage = self.model(
[facts_update_system_prompt] + memory_messages + [facts_update_message]
)
input_messages = [facts_update_system_prompt] + memory_messages + [facts_update_message]
chat_message_facts: ChatMessage = self.model(input_messages)
facts_update = chat_message_facts.content
# Redact updated plan
@ -618,6 +603,7 @@ Now begin!""",
```"""
self.memory.steps.append(
PlanningStep(
model_input_messages=input_messages,
plan=final_plan_redaction,
facts=final_facts_redaction,
model_output_message_plan=chat_message_plan,
@ -630,6 +616,15 @@ Now begin!""",
level=LogLevel.INFO,
)
def replay(self, detailed: bool = False):
    """Replay the steps recorded in this agent's memory as pretty console output.

    Delegates to ``AgentMemory.replay`` using this agent's own logger.

    Args:
        detailed (bool, optional): If True, also displays the memory at each step. Defaults to False.
            Careful: will increase log length exponentially. Use only for debugging.
    """
    self.memory.replay(self.logger, detailed=detailed)
class ToolCallingAgent(MultiStepAgent):
"""
@ -851,20 +846,9 @@ class CodeAgent(MultiStepAgent):
except Exception as e:
raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger) from e
self.logger.log(
Group(
Rule(
"[italic]Output message of the LLM:",
align="left",
style="orange",
),
Syntax(
model_output,
lexer="markdown",
theme="github-dark",
word_wrap=True,
),
),
self.logger.log_markdown(
content=model_output,
title="Output message of the LLM:",
level=LogLevel.DEBUG,
)
@ -884,20 +868,7 @@ class CodeAgent(MultiStepAgent):
]
# Execute
self.logger.log(
Panel(
Syntax(
code_action,
lexer="python",
theme="monokai",
word_wrap=True,
),
title="[bold]Executing this code:",
title_align="left",
box=box.HORIZONTALS,
),
level=LogLevel.INFO,
)
self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO)
observation = ""
is_final_answer = False
try:

View File

@ -1,85 +0,0 @@
import ast
import json
from enum import IntEnum
from typing import TYPE_CHECKING

from rich.console import Console
from rich.rule import Rule
from rich.syntax import Syntax
if TYPE_CHECKING:
from smolagents.memory import AgentMemory
class LogLevel(IntEnum):
    """Verbosity levels for agent console logging (lower value = more severe)."""

    ERROR = 0  # only errors
    INFO = 1  # normal output (default)
    DEBUG = 2  # detailed output
class AgentLogger:
def __init__(self, level: LogLevel = LogLevel.INFO):
self.level = level
self.console = Console()
def log(self, *args, level: str | LogLevel = LogLevel.INFO, **kwargs):
"""Logs a message to the console.
Args:
level (LogLevel, optional): Defaults to LogLevel.INFO.
"""
if isinstance(level, str):
level = LogLevel[level.upper()]
if level <= self.level:
self.console.print(*args, **kwargs)
def replay(self, agent_memory: "AgentMemory", full: bool = False):
"""Prints a pretty replay of the agent's steps.
Args:
with_memory (bool, optional): If True, also displays the memory at each step. Defaults to False.
Careful: will increase log length exponentially. Use only for debugging.
"""
memory = []
for step_log in agent_memory.steps:
memory.extend(step_log.to_messages(show_model_input_messages=full))
self.console.log("Replaying the agent's steps:")
ix = 0
for step in memory:
role = step["role"].strip()
if ix > 0 and role == "system":
role == "memory"
theme = "default"
match role:
case "assistant":
theme = "monokai"
ix += 1
case "system":
theme = "monokai"
case "tool-response":
theme = "github_dark"
content = step["content"]
try:
content = eval(content)
except Exception:
content = [step["content"]]
for substep_ix, item in enumerate(content):
self.console.log(
Rule(
f"{role.upper()}, STEP {ix}, SUBSTEP {substep_ix + 1}/{len(content)}",
align="center",
style="orange",
),
Syntax(
json.dumps(item, indent=4) if isinstance(item, dict) else str(item),
lexer="json",
theme=theme,
word_wrap=True,
),
)
__all__ = ["AgentLogger"]

View File

@ -3,10 +3,12 @@ from logging import getLogger
from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, Union
from smolagents.models import ChatMessage, MessageRole
from smolagents.monitoring import AgentLogger
from smolagents.utils import AgentError, make_json_serializable
if TYPE_CHECKING:
from smolagents.logger import AgentLogger
from smolagents.models import ChatMessage
@ -47,7 +49,7 @@ class MemoryStep:
@dataclass
class ActionStep(MemoryStep):
model_input_messages: List[Dict[str, str]] | None = None
model_input_messages: List[Message] | None = None
tool_calls: List[ToolCall] | None = None
start_time: float | None = None
end_time: float | None = None
@ -76,7 +78,7 @@ class ActionStep(MemoryStep):
"action_output": make_json_serializable(self.action_output),
}
def to_messages(self, summary_mode: bool = False, show_model_input_messages: bool = False) -> List[Dict[str, Any]]:
def to_messages(self, summary_mode: bool = False, show_model_input_messages: bool = False) -> List[Message]:
messages = []
if self.model_input_messages is not None and show_model_input_messages:
messages.append(Message(role=MessageRole.SYSTEM, content=self.model_input_messages))
@ -142,17 +144,26 @@ class ActionStep(MemoryStep):
@dataclass
class PlanningStep(MemoryStep):
model_input_messages: List[Message]
model_output_message_facts: ChatMessage
facts: str
model_output_message_facts: ChatMessage
model_output_message_plan: ChatMessage
plan: str
def to_messages(self, summary_mode: bool, **kwargs) -> List[Dict[str, str]]:
def to_messages(self, summary_mode: bool, **kwargs) -> List[Message]:
messages = []
messages.append(Message(role=MessageRole.ASSISTANT, content=f"[FACTS LIST]:\n{self.facts.strip()}"))
messages.append(
Message(
role=MessageRole.ASSISTANT, content=[{"type": "text", "text": f"[FACTS LIST]:\n{self.facts.strip()}"}]
)
)
if not summary_mode:
messages.append(Message(role=MessageRole.ASSISTANT, content=f"[PLAN]:\n{self.plan.strip()}"))
messages.append(
Message(
role=MessageRole.ASSISTANT, content=[{"type": "text", "text": f"[PLAN]:\n{self.plan.strip()}"}]
)
)
return messages
@ -161,7 +172,7 @@ class TaskStep(MemoryStep):
task: str
task_images: List[str] | None = None
def to_messages(self, summary_mode: bool, **kwargs) -> List[Dict[str, str]]:
def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Message]:
content = [{"type": "text", "text": f"New task:\n{self.task}"}]
if self.task_images:
for image in self.task_images:
@ -174,7 +185,7 @@ class TaskStep(MemoryStep):
class SystemPromptStep(MemoryStep):
system_prompt: str
def to_messages(self, summary_mode: bool, **kwargs) -> List[Dict[str, str]]:
def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Message]:
if summary_mode:
return []
return [Message(role=MessageRole.SYSTEM, content=[{"type": "text", "text": self.system_prompt.strip()}])]
@ -196,5 +207,30 @@ class AgentMemory:
def get_full_steps(self) -> list[dict]:
return [step.dict() for step in self.steps]
def replay(self, logger: AgentLogger, detailed: bool = False):
    """Prints a pretty replay of the agent's steps.

    Args:
        logger (AgentLogger): The logger to print replay logs to.
        detailed (bool, optional): If True, also displays the memory at each step. Defaults to False.
            Careful: will increase log length exponentially. Use only for debugging.
    """
    logger.console.log("Replaying the agent's steps:")
    for memory_step in self.steps:
        if isinstance(memory_step, SystemPromptStep):
            # The system prompt is only worth showing in detailed mode.
            if detailed:
                logger.log_markdown(title="System prompt", content=memory_step.system_prompt)
        elif isinstance(memory_step, TaskStep):
            logger.log_task(memory_step.task, "", 2)
        elif isinstance(memory_step, ActionStep):
            logger.log_rule(f"Step {memory_step.step_number}")
            if detailed:
                logger.log_messages(memory_step.model_input_messages)
            logger.log_markdown(title="Agent output:", content=memory_step.model_output)
        elif isinstance(memory_step, PlanningStep):
            logger.log_rule("Planning step")
            if detailed:
                logger.log_messages(memory_step.model_input_messages)
            logger.log_markdown(title="Agent output:", content=memory_step.facts + "\n" + memory_step.plan)
__all__ = ["AgentMemory"]

View File

@ -14,6 +14,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from enum import IntEnum
from typing import List, Optional
from rich import box
from rich.console import Console, Group
from rich.panel import Panel
from rich.rule import Rule
from rich.syntax import Syntax
from rich.text import Text
@ -57,4 +66,101 @@ class Monitor:
self.logger.log(Text(console_outputs, style="dim"), level=1)
__all__ = ["Monitor"]
class LogLevel(IntEnum):
    """Verbosity levels for agent console logging.

    A message is printed when its level does not exceed the logger's configured level.
    """

    ERROR = 0  # only errors
    INFO = 1  # normal output (default)
    DEBUG = 2  # detailed output


# Brand yellow used for rules, panels and task banners across the agent logs.
YELLOW_HEX = "#d4b702"
class AgentLogger:
    """Rich-console logger for agent runs, with helpers for common layouts.

    Messages are filtered by verbosity: a message is printed only when its level does not
    exceed the logger's configured ``level``.
    """

    def __init__(self, level: LogLevel = LogLevel.INFO):
        self.level = level
        self.console = Console()

    def log(self, *args, level: str | LogLevel = LogLevel.INFO, **kwargs) -> None:
        """Logs a message to the console.

        Args:
            level (LogLevel, optional): Verbosity of this message; printed only when it does
                not exceed the logger's configured level. Defaults to LogLevel.INFO.
        """
        if isinstance(level, str):
            level = LogLevel[level.upper()]
        if level <= self.level:
            self.console.print(*args, **kwargs)

    def log_markdown(
        self, content: str, title: Optional[str] = None, level: LogLevel = LogLevel.INFO, style: str = YELLOW_HEX
    ) -> None:
        """Logs markdown-highlighted content, optionally preceded by a titled rule."""
        markdown_content = Syntax(
            content,
            lexer="markdown",
            theme="github-dark",
            word_wrap=True,
        )
        if title:
            self.log(
                Group(
                    Rule(
                        "[bold italic]" + title,
                        align="left",
                        style=style,
                    ),
                    markdown_content,
                ),
                level=level,
            )
        else:
            self.log(markdown_content, level=level)

    def log_code(self, title: str, content: str, level: LogLevel = LogLevel.INFO) -> None:
        """Logs Python source code inside a titled panel with syntax highlighting."""
        self.log(
            Panel(
                Syntax(
                    content,
                    lexer="python",
                    theme="monokai",
                    word_wrap=True,
                ),
                title="[bold]" + title,
                title_align="left",
                box=box.HORIZONTALS,
            ),
            level=level,
        )

    def log_rule(self, title: str, level: LogLevel = LogLevel.INFO) -> None:
        """Logs a horizontal rule with a bold title."""
        self.log(
            Rule(
                "[bold]" + title,
                characters="",
                style=YELLOW_HEX,
            ),
            # Fix: was hard-coded to LogLevel.INFO, silently ignoring the caller-supplied level.
            level=level,
        )

    def log_task(self, content: str, subtitle: str, level: LogLevel = LogLevel.INFO) -> None:
        """Logs the start of a new run as a highlighted panel showing the task."""
        self.log(
            Panel(
                f"\n[bold]{content}\n",
                title="[bold]New run",
                subtitle=subtitle,
                border_style=YELLOW_HEX,
                subtitle_align="left",
            ),
            level=level,
        )

    def log_messages(self, messages: List, level: LogLevel = LogLevel.INFO) -> None:
        """Logs a list of chat messages as pretty-printed JSON.

        Args:
            messages (List): Messages convertible to dict for JSON serialization.
            level (LogLevel, optional): Verbosity of this log entry. Defaults to LogLevel.INFO,
                which preserves the previous behavior of logging without an explicit level.
        """
        messages_as_string = "\n".join([json.dumps(dict(message), indent=4) for message in messages])
        self.log(
            Syntax(
                messages_as_string,
                lexer="markdown",
                theme="github-dark",
                word_wrap=True,
            ),
            level=level,
        )
__all__ = ["AgentLogger", "Monitor"]

View File

@ -22,12 +22,12 @@ from smolagents import (
ToolCallingAgent,
stream_to_gradio,
)
from smolagents.logger import AgentLogger, LogLevel
from smolagents.models import (
ChatMessage,
ChatMessageToolCall,
ChatMessageToolCallDefinition,
)
from smolagents.monitoring import AgentLogger, LogLevel
class FakeLLMModel: