More flexible verbosity level (#150)

This commit is contained in:
Aymeric Roucher 2025-01-10 23:46:22 +01:00 committed by GitHub
parent 82a2fe5bb4
commit fec65e154a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 86 additions and 71 deletions

View File

@ -137,7 +137,7 @@ _Note:_ The Inference API hosts models based on various criteria, and deployed m
from smolagents import HfApiModel, CodeAgent
agent = CodeAgent(
tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_steps=4, verbose=True
tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_steps=4, verbosity_level=2
)
```

View File

@ -21,7 +21,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@ -172,7 +172,7 @@
"[132 rows x 4 columns]"
]
},
"execution_count": 21,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@ -195,7 +195,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@ -398,23 +398,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluating 'meta-llama/Llama-3.3-70B-Instruct'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 132/132 [00:00<00:00, 27836.90it/s]\n",
" 16%|█▌ | 21/132 [02:18<07:35, 4.11s/it]"
]
}
],
"outputs": [],
"source": [
"open_model_ids = [\n",
" \"meta-llama/Llama-3.3-70B-Instruct\",\n",
@ -423,6 +407,7 @@
" \"Qwen/Qwen2.5-Coder-32B-Instruct\",\n",
" \"meta-llama/Llama-3.2-3B-Instruct\",\n",
" \"meta-llama/Llama-3.1-8B-Instruct\",\n",
" \"mistralai/Mistral-Nemo-Instruct-2407\",\n",
" # \"HuggingFaceTB/SmolLM2-1.7B-Instruct\",\n",
" # \"meta-llama/Llama-3.1-70B-Instruct\",\n",
"]\n",
@ -1010,7 +995,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "test",
"display_name": "compare-agents",
"language": "python",
"name": "python3"
},

View File

@ -60,7 +60,7 @@ from smolagents import HfApiModel, CodeAgent
retriever_tool = RetrieverTool(docs_processed)
agent = CodeAgent(
tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_steps=4, verbose=True
tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_steps=4, verbosity_level=2
)
agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")

View File

@ -18,12 +18,14 @@ import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from enum import IntEnum
from rich import box
from rich.console import Group
from rich.panel import Panel
from rich.rule import Rule
from rich.syntax import Syntax
from rich.text import Text
from rich.console import Console
from .default_tools import FinalAnswerTool, TOOL_MAPPING
from .e2b_executor import E2BExecutor
@ -164,6 +166,22 @@ def format_prompt_with_managed_agents_descriptions(
YELLOW_HEX = "#d4b702"
class LogLevel(IntEnum):
ERROR = 0 # Only errors
INFO = 1 # Normal output (default)
DEBUG = 2 # Detailed output
class AgentLogger:
def __init__(self, level: LogLevel = LogLevel.INFO):
self.level = level
self.console = Console()
def log(self, *args, level: LogLevel = LogLevel.INFO, **kwargs):
if level <= self.level:
console.print(*args, **kwargs)
class MultiStepAgent:
"""
Agent class that solves the given task step by step, using the ReAct framework:
@ -179,7 +197,7 @@ class MultiStepAgent:
max_steps: int = 6,
tool_parser: Optional[Callable] = None,
add_base_tools: bool = False,
verbose: bool = False,
verbosity_level: int = 1,
grammar: Optional[Dict[str, str]] = None,
managed_agents: Optional[List] = None,
step_callbacks: Optional[List[Callable]] = None,
@ -205,7 +223,6 @@ class MultiStepAgent:
self.managed_agents = {}
if managed_agents is not None:
print("NOTNONE")
self.managed_agents = {agent.name: agent for agent in managed_agents}
self.tools = {tool.name: tool for tool in tools}
@ -222,8 +239,8 @@ class MultiStepAgent:
self.input_messages = None
self.logs = []
self.task = None
self.verbose = verbose
self.monitor = Monitor(self.model)
self.logger = AgentLogger(level=verbosity_level)
self.monitor = Monitor(self.model, self.logger)
self.step_callbacks = step_callbacks if step_callbacks is not None else []
self.step_callbacks.append(self.monitor.update_metrics)
@ -485,14 +502,15 @@ You have been provided with these additional arguments, that you can access usin
else:
self.logs.append(system_prompt_step)
console.print(
self.logger.log(
Panel(
f"\n[bold]{self.task.strip()}\n",
title="[bold]New run",
subtitle=f"{type(self.model).__name__} - {(self.model.model_id if hasattr(self.model, 'model_id') else '')}",
border_style=YELLOW_HEX,
subtitle_align="left",
)
),
level=LogLevel.INFO,
)
self.logs.append(TaskStep(task=self.task))
@ -531,12 +549,13 @@ You have been provided with these additional arguments, that you can access usin
is_first_step=(self.step_number == 0),
step=self.step_number,
)
console.print(
self.logger.log(
Rule(
f"[bold]Step {self.step_number}",
characters="",
style=YELLOW_HEX,
)
),
level=LogLevel.INFO,
)
# Run one step!
@ -557,7 +576,7 @@ You have been provided with these additional arguments, that you can access usin
final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
self.logs.append(final_step_log)
final_answer = self.provide_final_answer(task)
console.print(Text(f"Final answer: {final_answer}"))
self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO)
final_step_log.action_output = final_answer
final_step_log.end_time = time.time()
final_step_log.duration = step_log.end_time - step_start_time
@ -586,12 +605,13 @@ You have been provided with these additional arguments, that you can access usin
is_first_step=(self.step_number == 0),
step=self.step_number,
)
console.print(
self.logger.log(
Rule(
f"[bold]Step {self.step_number}",
characters="",
style=YELLOW_HEX,
)
),
level=LogLevel.INFO,
)
# Run one step!
@ -613,7 +633,7 @@ You have been provided with these additional arguments, that you can access usin
final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
self.logs.append(final_step_log)
final_answer = self.provide_final_answer(task)
console.print(Text(f"Final answer: {final_answer}"))
self.logger.log(Text(f"Final answer: {final_answer}"), level=LogLevel.INFO)
final_step_log.action_output = final_answer
final_step_log.duration = 0
for callback in self.step_callbacks:
@ -679,8 +699,10 @@ Now begin!""",
self.logs.append(
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
)
console.print(
Rule("[bold]Initial plan", style="orange"), Text(final_plan_redaction)
self.logger.log(
Rule("[bold]Initial plan", style="orange"),
Text(final_plan_redaction),
level=LogLevel.INFO,
)
else: # update plan
agent_memory = self.write_inner_memory_from_logs(
@ -735,8 +757,10 @@ Now begin!""",
self.logs.append(
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
)
console.print(
Rule("[bold]Updated plan", style="orange"), Text(final_plan_redaction)
self.logger.log(
Rule("[bold]Updated plan", style="orange"),
Text(final_plan_redaction),
level=LogLevel.INFO,
)
@ -795,8 +819,11 @@ class ToolCallingAgent(MultiStepAgent):
)
# Execute
console.print(
Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}"))
self.logger.log(
Panel(
Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")
),
level=LogLevel.INFO,
)
if tool_name == "final_answer":
if isinstance(tool_arguments, dict):
@ -810,13 +837,15 @@ class ToolCallingAgent(MultiStepAgent):
isinstance(answer, str) and answer in self.state.keys()
): # if the answer is a state variable, return the value
final_answer = self.state[answer]
console.print(
f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{answer}' from state to return value '{final_answer}'."
self.logger.log(
f"[bold {YELLOW_HEX}]Final answer:[/bold {YELLOW_HEX}] Extracting key '{answer}' from state to return value '{final_answer}'.",
level=LogLevel.INFO,
)
else:
final_answer = answer
console.print(
Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}")
self.logger.log(
Text(f"Final answer: {final_answer}", style=f"bold {YELLOW_HEX}"),
level=LogLevel.INFO,
)
log_entry.action_output = final_answer
@ -837,7 +866,7 @@ class ToolCallingAgent(MultiStepAgent):
updated_information = f"Stored '{observation_name}' in memory."
else:
updated_information = str(observation).strip()
console.print(f"Observations: {updated_information}")
self.logger.log(f"Observations: {updated_information}", level=LogLevel.INFO)
log_entry.observations = updated_information
return None
@ -922,8 +951,7 @@ class CodeAgent(MultiStepAgent):
except Exception as e:
raise AgentGenerationError(f"Error in generating model output:\n{e}")
if self.verbose:
console.print(
self.logger.log(
Group(
Rule(
"[italic]Output message of the LLM:",
@ -936,7 +964,8 @@ class CodeAgent(MultiStepAgent):
theme="github-dark",
word_wrap=True,
),
)
),
level=LogLevel.DEBUG,
)
# Parse
@ -955,7 +984,7 @@ class CodeAgent(MultiStepAgent):
)
# Execute
console.print(
self.logger.log(
Panel(
Syntax(
code_action,
@ -966,7 +995,8 @@ class CodeAgent(MultiStepAgent):
title="[bold]Executing this code:",
title_align="left",
box=box.HORIZONTALS,
)
),
level=LogLevel.INFO,
)
observation = ""
is_final_answer = False
@ -993,8 +1023,9 @@ class CodeAgent(MultiStepAgent):
else:
error_msg = str(e)
if "Import of " in str(e) and " is not allowed" in str(e):
console.print(
"[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent."
self.logger.log(
"[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.",
level=LogLevel.INFO,
)
raise AgentExecutionError(error_msg)
@ -1008,7 +1039,7 @@ class CodeAgent(MultiStepAgent):
style=(f"bold {YELLOW_HEX}" if is_final_answer else ""),
),
]
console.print(Group(*execution_outputs_console))
self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
log_entry.action_output = output
return output if is_final_answer else None

View File

@ -16,13 +16,12 @@
# limitations under the License.
from rich.text import Text
from .utils import console
class Monitor:
def __init__(self, tracked_model):
def __init__(self, tracked_model, logger):
self.step_durations = []
self.tracked_model = tracked_model
self.logger = logger
if (
getattr(self.tracked_model, "last_input_token_count", "Not found")
!= "Not found"
@ -53,7 +52,7 @@ class Monitor:
self.total_output_token_count += self.tracked_model.last_output_token_count
console_outputs += f"| Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}"
console_outputs += "]"
console.print(Text(console_outputs, style="dim"))
self.logger.log(Text(console_outputs, style="dim"), level=1)
__all__ = ["Monitor"]