Nicer console outputs with rich

This commit is contained in:
Aymeric 2024-12-09 15:42:17 +01:00
parent f3dcf1f013
commit 146ee3dd32
20 changed files with 224 additions and 194 deletions

View File

@ -16,7 +16,7 @@
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ..utils import ( from transformers.utils import (
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule, _LazyModule,
is_torch_available, is_torch_available,

View File

@ -19,10 +19,11 @@ import uuid
import numpy as np import numpy as np
from ..utils import is_soundfile_availble, is_torch_available, is_vision_available, logging from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
import logging
logger = logging.get_logger(__name__) logger = logging.getLogger(__name__)
if is_vision_available(): if is_vision_available():
from PIL import Image from PIL import Image
@ -159,7 +160,7 @@ class AgentImage(AgentType, ImageType):
return self._path return self._path
def save(self, output_bytes, format, **params): def save(self, output_bytes, format = None, **params):
""" """
Saves the image to a file. Saves the image to a file.
Args: Args:
@ -168,7 +169,7 @@ class AgentImage(AgentType, ImageType):
**params: Additional parameters to pass to PIL.Image.save. **params: Additional parameters to pass to PIL.Image.save.
""" """
img = self.to_raw() img = self.to_raw()
img.save(output_bytes, format, **params) img.save(output_bytes, format=format, **params)
class AgentAudio(AgentType, str): class AgentAudio(AgentType, str):

View File

@ -19,10 +19,12 @@ import logging
import re import re
import time import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import rich
from rich import markdown as rich_markdown
from .. import is_torch_available from transformers.utils import is_torch_available
from ..utils import logging as transformers_logging import logging
from ..utils.import_utils import is_pygments_available from .utils import console
from .agent_types import AgentAudio, AgentImage from .agent_types import AgentAudio, AgentImage
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
from .llm_engine import HfApiEngine, MessageRole from .llm_engine import HfApiEngine, MessageRole
@ -48,52 +50,6 @@ from .tools import (
) )
if is_pygments_available():
from pygments import highlight
from pygments.formatters import Terminal256Formatter
from pygments.lexers import PythonLexer
class CustomFormatter(logging.Formatter):
    """Formatter that wraps each log message in an ANSI color chosen by level.

    Besides the standard logging levels, the custom numeric levels 31-36 are
    mapped to dedicated colors (see ``FORMATS``). A level that has no entry in
    ``FORMATS`` is rendered with the default ``logging.Formatter`` format,
    i.e. the bare message without any color codes.
    """

    # ANSI escape sequences used to build the per-level format strings.
    grey = "\x1b[38;20m"
    bold_yellow = "\x1b[33;1m"
    red = "\x1b[31;20m"
    green = "\x1b[32;20m"
    bold_green = "\x1b[32;20;1m"
    bold_red = "\x1b[31;1m"
    bold_white = "\x1b[37;1m"
    orange = "\x1b[38;5;214m"
    bold_orange = "\x1b[38;5;214;1m"
    reset = "\x1b[0m"

    # Message template shared by every level. Deliberately NOT named `format`:
    # that name is taken by the `format()` method defined below, which would
    # silently overwrite a same-named string attribute.
    _message_format = "%(message)s"

    # Level number -> format string (color prefix + message + reset suffix).
    FORMATS = {
        logging.DEBUG: grey + _message_format + reset,
        logging.INFO: _message_format,
        logging.WARNING: bold_yellow + _message_format + reset,
        logging.ERROR: red + _message_format + reset,
        logging.CRITICAL: bold_red + _message_format + reset,
        31: reset + _message_format + reset,
        32: green + _message_format + reset,
        33: bold_green + _message_format + reset,
        34: bold_white + _message_format + reset,
        35: orange + _message_format + reset,
        36: bold_orange + _message_format + reset,
    }

    # Format string -> ready-made Formatter, filled lazily so that a Formatter
    # is built once per distinct format instead of once per log record.
    # Keyed by the format string (not the level) so later edits to FORMATS
    # are still honored.
    _formatter_cache = {}

    def format(self, record):
        """Return ``record`` rendered with the color assigned to its level.

        Args:
            record: The ``logging.LogRecord`` to render.

        Returns:
            The formatted message string, including ANSI color codes when the
            record's level has an entry in ``FORMATS``.
        """
        # `None` for unknown levels: logging.Formatter(None) uses its default format.
        log_fmt = self.FORMATS.get(record.levelno)
        formatter = self._formatter_cache.get(log_fmt)
        if formatter is None:
            formatter = logging.Formatter(log_fmt)
            self._formatter_cache[log_fmt] = formatter
        return formatter.format(record)
# Module-level logger, obtained through the transformers logging wrapper.
logger = transformers_logging.get_logger(__name__)
# Do not pass records on to ancestor loggers; only the handler attached below emits them.
logger.propagate = False
# Stream handler (stderr by default) whose output is colored per level by CustomFormatter.
ch = logging.StreamHandler()
ch.setFormatter(CustomFormatter())
logger.addHandler(ch)
def parse_json_blob(json_blob: str) -> Dict[str, str]: def parse_json_blob(json_blob: str) -> Dict[str, str]:
try: try:
first_accolade_index = json_blob.find("{") first_accolade_index = json_blob.find("{")
@ -142,9 +98,10 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
elif "action" in tool_call: elif "action" in tool_call:
return tool_call["action"], None return tool_call["action"], None
else: else:
raise ValueError( missing_keys = [key for key in ['action', 'action_input'] if key not in tool_call]
f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}" error_msg = f"Missing keys: {missing_keys} in blob {tool_call}"
) console.print(f"[bold red]{error_msg}[/bold red]")
raise ValueError(error_msg)
def parse_text_tool_call(text: str) -> Tuple[str, Union[str, Dict[str, str]]]: def parse_text_tool_call(text: str) -> Tuple[str, Union[str, Dict[str, str]]]:
@ -197,18 +154,18 @@ class Toolbox:
self._tools = {tool.name: tool for tool in tools} self._tools = {tool.name: tool for tool in tools}
if add_base_tools: if add_base_tools:
self.add_base_tools() self.add_base_tools()
self._load_tools_if_needed() # self._load_tools_if_needed()
def add_base_tools(self, add_python_interpreter: bool = False): def add_base_tools(self, add_python_interpreter: bool = False):
global _tools_are_initialized global _tools_are_initialized
global HUGGINGFACE_DEFAULT_TOOLS global HUGGINGFACE_DEFAULT_TOOLS
if not _tools_are_initialized: if not _tools_are_initialized:
HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools(logger) HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools()
_tools_are_initialized = True _tools_are_initialized = True
for tool in HUGGINGFACE_DEFAULT_TOOLS.values(): for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
if tool.name != "python_interpreter" or add_python_interpreter: if tool.name != "python_interpreter" or add_python_interpreter:
self.add_tool(tool) self.add_tool(tool)
self._load_tools_if_needed() # self._load_tools_if_needed()
@property @property
def tools(self) -> Dict[str, Tool]: def tools(self) -> Dict[str, Tool]:
@ -271,11 +228,11 @@ class Toolbox:
"""Clears the toolbox""" """Clears the toolbox"""
self._tools = {} self._tools = {}
def _load_tools_if_needed(self): # def _load_tools_if_needed(self):
for name, tool in self._tools.items(): # for name, tool in self._tools.items():
if not isinstance(tool, Tool): # if not isinstance(tool, Tool):
task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id # task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
self._tools[name] = load_tool(task_or_repo_id) # self._tools[name] = load_tool(task_or_repo_id)
def __repr__(self): def __repr__(self):
toolbox_description = "Toolbox contents:\n" toolbox_description = "Toolbox contents:\n"
@ -290,6 +247,8 @@ class AgentError(Exception):
def __init__(self, message): def __init__(self, message):
super().__init__(message) super().__init__(message)
self.message = message self.message = message
console.print(f"[bold red]{message}[/bold red]")
class AgentParsingError(AgentError): class AgentParsingError(AgentError):
@ -362,7 +321,7 @@ class Agent:
max_iterations: int = 6, max_iterations: int = 6,
tool_parser: Optional[Callable] = None, tool_parser: Optional[Callable] = None,
add_base_tools: bool = False, add_base_tools: bool = False,
verbose: int = 0, verbose: bool = False,
grammar: Optional[Dict[str, str]] = None, grammar: Optional[Dict[str, str]] = None,
managed_agents: Optional[List] = None, managed_agents: Optional[List] = None,
step_callbacks: Optional[List[Callable]] = None, step_callbacks: Optional[List[Callable]] = None,
@ -380,7 +339,6 @@ class Agent:
) )
self.additional_args = additional_args self.additional_args = additional_args
self.max_iterations = max_iterations self.max_iterations = max_iterations
self.logger = logger
self.tool_parser = tool_parser self.tool_parser = tool_parser
self.grammar = grammar self.grammar = grammar
@ -406,13 +364,7 @@ class Agent:
self.prompt = None self.prompt = None
self.logs = [] self.logs = []
self.task = None self.task = None
self.verbose = verbose
if verbose == 0:
logger.setLevel(logging.WARNING)
elif verbose == 1:
logger.setLevel(logging.INFO)
elif verbose == 2:
logger.setLevel(logging.DEBUG)
# Initialize step callbacks # Initialize step callbacks
self.step_callbacks = step_callbacks if step_callbacks is not None else [] self.step_callbacks = step_callbacks if step_callbacks is not None else []
@ -441,10 +393,8 @@ class Agent:
self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports)) self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
) )
self.logs = [{"system_prompt": self.system_prompt, "task": self.task}] self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
self.logger.log(33, "======== New task ========") console.rule("New task", characters='=')
self.logger.log(34, self.task) console.print(self.task)
self.logger.debug("System prompt is as follows:")
self.logger.debug(self.system_prompt)
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]: def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
""" """
@ -521,7 +471,6 @@ class Agent:
split[-1], split[-1],
) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output ) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
except Exception as e: except Exception as e:
self.logger.error(e, exc_info=1)
raise AgentParsingError( raise AgentParsingError(
f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!" f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
) )
@ -541,7 +490,7 @@ class Agent:
available_tools = {**available_tools, **self.managed_agents} available_tools = {**available_tools, **self.managed_agents}
if tool_name not in available_tools: if tool_name not in available_tools:
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
self.logger.error(error_msg, exc_info=1) console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg)
try: try:
@ -549,38 +498,39 @@ class Agent:
observation = available_tools[tool_name](arguments) observation = available_tools[tool_name](arguments)
elif isinstance(arguments, dict): elif isinstance(arguments, dict):
for key, value in arguments.items(): for key, value in arguments.items():
# if the value is the name of a state variable like "image.png", replace it with the actual value
if isinstance(value, str) and value in self.state: if isinstance(value, str) and value in self.state:
arguments[key] = self.state[value] arguments[key] = self.state[value]
observation = available_tools[tool_name](**arguments) observation = available_tools[tool_name](**arguments)
else: else:
raise AgentExecutionError( error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." console.print(f"[bold red]{error_msg}")
) raise AgentExecutionError(error_msg)
return observation return observation
except Exception as e: except Exception as e:
if tool_name in self.toolbox.tools: if tool_name in self.toolbox.tools:
raise AgentExecutionError( tool_description = get_tool_description_with_args(available_tools[tool_name])
error_msg = (
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(available_tools[tool_name])}" f"As a reminder, this tool's description is the following:\n{tool_description}"
) )
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
elif tool_name in self.managed_agents: elif tool_name in self.managed_agents:
raise AgentExecutionError( error_msg = (
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
) )
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
def log_rationale_code_action(self, rationale: str, code_action: str) -> None: def log_rationale_code_action(self, rationale: str, code_action: str) -> None:
self.logger.warning("=== Agent thoughts:") if self.verbose:
self.logger.log(31, rationale) console.rule("Agent thoughts")
self.logger.warning(">>> Agent is executing the code below:") console.print(rationale)
if is_pygments_available(): console.rule("Agent is executing the code below:", align="left")
self.logger.log( console.print(code_action)
31, highlight(code_action, PythonLexer(ensurenl=False), Terminal256Formatter(style="nord")) console.rule("", align="left")
)
else:
self.logger.log(31, code_action)
self.logger.warning("====")
def run(self, **kwargs): def run(self, **kwargs):
"""To be implemented in the child class""" """To be implemented in the child class"""
@ -617,13 +567,6 @@ class CodeAgent(Agent):
**kwargs, **kwargs,
) )
if not is_pygments_available():
transformers_logging.warning_once(
logger,
"pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
"CodeAgent.",
)
self.python_evaluator = evaluate_python_code self.python_evaluator = evaluate_python_code
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
@ -669,29 +612,30 @@ class CodeAgent(Agent):
} }
self.prompt = [prompt_message, task_message] self.prompt = [prompt_message, task_message]
self.logger.info("====Executing with this prompt====")
self.logger.info(self.prompt) if self.verbose:
console.rule("Executing with this prompt")
console.print(self.prompt)
additional_args = {"grammar": self.grammar} if self.grammar is not None else {} additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"], **additional_args) llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"], **additional_args)
if return_generated_code:
return llm_output
# Parse # Parse
try: try:
rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:") rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
except Exception as e: except Exception as e:
self.logger.debug( if self.verbose:
f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}" console.print(
) f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
)
rationale, code_action = "", llm_output rationale, code_action = "", llm_output
try: try:
code_action = self.parse_code_blob(code_action) code_action = self.parse_code_blob(code_action)
except Exception as e: except Exception as e:
error_msg = f"Error in code parsing: {e}. Be sure to provide correct code" error_msg = f"Error in code parsing: {e}. Be sure to provide correct code"
self.logger.error(error_msg, exc_info=1) console.print(f"[bold red]{error_msg}[/bold red]")
return error_msg return error_msg
# Execute # Execute
@ -705,11 +649,12 @@ class CodeAgent(Agent):
state=self.state, state=self.state,
authorized_imports=self.authorized_imports, authorized_imports=self.authorized_imports,
) )
self.logger.info(self.state["print_outputs"]) if self.verbose:
console.print(self.state["print_outputs"])
return output return output
except Exception as e: except Exception as e:
error_msg = f"Error in execution: {e}. Be sure to provide correct code." error_msg = f"Error in execution: {e}. Be sure to provide correct code."
self.logger.error(error_msg, exc_info=1) console.print(f"[bold red]{error_msg}[/bold red]")
return error_msg return error_msg
@ -759,7 +704,7 @@ class ReactAgent(Agent):
self.prompt = [ self.prompt = [
{ {
"role": MessageRole.SYSTEM, "role": MessageRole.SYSTEM,
"content": "An agent tried to answer an user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:", "content": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
} }
] ]
self.prompt += self.write_inner_memory_from_logs()[1:] self.prompt += self.write_inner_memory_from_logs()[1:]
@ -772,7 +717,9 @@ class ReactAgent(Agent):
try: try:
return self.llm_engine(self.prompt) return self.llm_engine(self.prompt)
except Exception as e: except Exception as e:
return f"Error in generating final llm output: {e}." error_msg = f"Error in generating final LLM output: {e}."
console.print(f"[bold red]{error_msg}[/bold red]")
return error_msg
def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs): def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
""" """
@ -815,7 +762,6 @@ class ReactAgent(Agent):
if "final_answer" in step_log_entry: if "final_answer" in step_log_entry:
final_answer = step_log_entry["final_answer"] final_answer = step_log_entry["final_answer"]
except AgentError as e: except AgentError as e:
self.logger.error(e, exc_info=1)
step_log_entry["error"] = e step_log_entry["error"] = e
finally: finally:
step_end_time = time.time() step_end_time = time.time()
@ -831,7 +777,7 @@ class ReactAgent(Agent):
error_message = "Reached max iterations." error_message = "Reached max iterations."
final_step_log = {"error": AgentMaxIterationsError(error_message)} final_step_log = {"error": AgentMaxIterationsError(error_message)}
self.logs.append(final_step_log) self.logs.append(final_step_log)
self.logger.error(error_message, exc_info=1) console.print(f"[bold red]{error_message}")
final_answer = self.provide_final_answer(task) final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer final_step_log["final_answer"] = final_answer
final_step_log["step_duration"] = 0 final_step_log["step_duration"] = 0
@ -857,7 +803,6 @@ class ReactAgent(Agent):
if "final_answer" in step_log_entry: if "final_answer" in step_log_entry:
final_answer = step_log_entry["final_answer"] final_answer = step_log_entry["final_answer"]
except AgentError as e: except AgentError as e:
self.logger.error(e, exc_info=1)
step_log_entry["error"] = e step_log_entry["error"] = e
finally: finally:
step_end_time = time.time() step_end_time = time.time()
@ -872,7 +817,7 @@ class ReactAgent(Agent):
error_message = "Reached max iterations." error_message = "Reached max iterations."
final_step_log = {"error": AgentMaxIterationsError(error_message)} final_step_log = {"error": AgentMaxIterationsError(error_message)}
self.logs.append(final_step_log) self.logs.append(final_step_log)
self.logger.error(error_message, exc_info=1) console.print(f"[bold red]{error_message}")
final_answer = self.provide_final_answer(task) final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer final_step_log["final_answer"] = final_answer
final_step_log["step_duration"] = 0 final_step_log["step_duration"] = 0
@ -931,8 +876,8 @@ Now begin!""",
{answer_facts} {answer_facts}
```""".strip() ```""".strip()
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction}) self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
self.logger.log(36, "===== Initial plan =====") console.rule("[orange]Initial plan")
self.logger.log(35, final_plan_redaction) console.print(final_plan_redaction)
else: # update plan else: # update plan
agent_memory = self.write_inner_memory_from_logs( agent_memory = self.write_inner_memory_from_logs(
summary_mode=False summary_mode=False
@ -977,8 +922,8 @@ Now begin!""",
{facts_update} {facts_update}
```""" ```"""
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction}) self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
self.logger.log(36, "===== Updated plan =====") console.rule("Updated plan")
self.logger.log(35, final_plan_redaction) console.print(final_plan_redaction)
class ReactJsonAgent(ReactAgent): class ReactJsonAgent(ReactAgent):
@ -1022,13 +967,14 @@ class ReactJsonAgent(ReactAgent):
agent_memory = self.write_inner_memory_from_logs() agent_memory = self.write_inner_memory_from_logs()
self.prompt = agent_memory self.prompt = agent_memory
self.logger.debug("===== New step =====") console.rule("New step")
# Add new step in logs # Add new step in logs
log_entry["agent_memory"] = agent_memory.copy() log_entry["agent_memory"] = agent_memory.copy()
self.logger.info("===== Calling LLM with this last message: =====") if self.verbose:
self.logger.info(self.prompt[-1]) console.rule("Calling LLM with this last message:")
console.print(self.prompt[-1])
try: try:
additional_args = {"grammar": self.grammar} if self.grammar is not None else {} additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
@ -1037,12 +983,12 @@ class ReactJsonAgent(ReactAgent):
) )
except Exception as e: except Exception as e:
raise AgentGenerationError(f"Error in generating llm output: {e}.") raise AgentGenerationError(f"Error in generating llm output: {e}.")
self.logger.debug("===== Output message of the LLM: =====") console.rule("===== Output message of the LLM: =====")
self.logger.debug(llm_output) console.print(llm_output)
log_entry["llm_output"] = llm_output log_entry["llm_output"] = llm_output
# Parse # Parse
self.logger.debug("===== Extracting action =====") console.rule("===== Extracting action =====")
rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:") rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")
try: try:
@ -1054,9 +1000,9 @@ class ReactJsonAgent(ReactAgent):
log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments} log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
# Execute # Execute
self.logger.warning("=== Agent thoughts:") console.print("=== Agent thoughts:")
self.logger.log(31, rationale) console.print(rationale)
self.logger.warning(f">>> Calling tool: '{tool_name}' with arguments: {arguments}") console.print(f">>> Calling tool: '{tool_name}' with arguments: {arguments}")
if tool_name == "final_answer": if tool_name == "final_answer":
if isinstance(arguments, dict): if isinstance(arguments, dict):
if "answer" in arguments: if "answer" in arguments:
@ -1087,7 +1033,6 @@ class ReactJsonAgent(ReactAgent):
updated_information = f"Stored '{observation_name}' in memory." updated_information = f"Stored '{observation_name}' in memory."
else: else:
updated_information = str(observation).strip() updated_information = str(observation).strip()
self.logger.info(updated_information)
log_entry["observation"] = updated_information log_entry["observation"] = updated_information
return log_entry return log_entry
@ -1126,13 +1071,6 @@ class ReactCodeAgent(ReactAgent):
**kwargs, **kwargs,
) )
if not is_pygments_available():
transformers_logging.warning_once(
logger,
"pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
"ReactCodeAgent.",
)
self.python_evaluator = evaluate_python_code self.python_evaluator = evaluate_python_code
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
@ -1147,13 +1085,14 @@ class ReactCodeAgent(ReactAgent):
agent_memory = self.write_inner_memory_from_logs() agent_memory = self.write_inner_memory_from_logs()
self.prompt = agent_memory.copy() self.prompt = agent_memory.copy()
self.logger.debug("===== New step =====") console.rule("New step")
# Add new step in logs # Add new step in logs
log_entry["agent_memory"] = agent_memory.copy() log_entry["agent_memory"] = agent_memory.copy()
self.logger.info("===== Calling LLM with these last messages: =====") if self.verbose:
self.logger.info(self.prompt[-2:]) console.print("===== Calling LLM with these last messages: =====")
console.print(self.prompt[-2:])
try: try:
additional_args = {"grammar": self.grammar} if self.grammar is not None else {} additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
@ -1163,16 +1102,16 @@ class ReactCodeAgent(ReactAgent):
except Exception as e: except Exception as e:
raise AgentGenerationError(f"Error in generating llm output: {e}.") raise AgentGenerationError(f"Error in generating llm output: {e}.")
self.logger.debug("=== Output message of the LLM:") if self.verbose:
self.logger.debug(llm_output) console.rule("Output message of the LLM:")
console.print(llm_output)
log_entry["llm_output"] = llm_output log_entry["llm_output"] = llm_output
# Parse # Parse
self.logger.debug("=== Extracting action ===")
try: try:
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:") rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
except Exception as e: except Exception as e:
self.logger.debug(f"Error in extracting action, trying to parse the whole output. Error trace: {e}") console.print(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
rationale, raw_code_action = llm_output, llm_output rationale, raw_code_action = llm_output, llm_output
try: try:
@ -1200,12 +1139,12 @@ class ReactCodeAgent(ReactAgent):
state=self.state, state=self.state,
authorized_imports=self.authorized_imports, authorized_imports=self.authorized_imports,
) )
self.logger.warning("Print outputs:") console.print("Print outputs:")
self.logger.log(32, self.state["print_outputs"]) console.print(self.state["print_outputs"])
observation = "Print outputs:\n" + self.state["print_outputs"] observation = "Print outputs:\n" + self.state["print_outputs"]
if result is not None: if result is not None:
self.logger.warning("Last output from code snippet:") console.print("Last output from code snippet:")
self.logger.log(32, str(result)) console.print(str(result))
observation += "Last output from code snippet:\n" + str(result)[:100000] observation += "Last output from code snippet:\n" + str(result)[:100000]
log_entry["observation"] = observation log_entry["observation"] = observation
except Exception as e: except Exception as e:
@ -1215,8 +1154,8 @@ class ReactCodeAgent(ReactAgent):
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg)
for line in code_action.split("\n"): for line in code_action.split("\n"):
if line[: len("final_answer")] == "final_answer": if line[: len("final_answer")] == "final_answer":
self.logger.log(33, "Final answer:") console.print("Final answer:")
self.logger.log(32, result) console.print(f"[bold]{result}")
log_entry["final_answer"] = result log_entry["final_answer"] = result
return result return result

View File

@ -23,7 +23,7 @@ from typing import Dict
from huggingface_hub import hf_hub_download, list_spaces from huggingface_hub import hf_hub_download, list_spaces
from ..utils import is_offline_mode from transformers.utils import is_offline_mode
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
@ -128,7 +128,7 @@ def get_remote_tools(logger, organization="huggingface-tools"):
return tools return tools
def setup_default_tools(logger): def setup_default_tools():
default_tools = {} default_tools = {}
main_module = importlib.import_module("transformers") main_module = importlib.import_module("transformers")
tools_module = main_module.agents tools_module = main_module.agents

View File

@ -19,9 +19,8 @@ import re
import numpy as np import numpy as np
import torch import torch
from ..models.auto import AutoProcessor from transformers import AutoProcessor, VisionEncoderDecoderModel
from ..models.vision_encoder_decoder import VisionEncoderDecoderModel from transformers.utils import is_vision_available
from ..utils import is_vision_available
from .tools import PipelineTool from .tools import PipelineTool

View File

@ -18,8 +18,8 @@
import torch import torch
from PIL import Image from PIL import Image
from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor from transformers import AutoModelForVisualQuestionAnswering, AutoProcessor
from ..utils import requires_backends from transformers.utils import requires_backends
from .tools import PipelineTool from .tools import PipelineTool

View File

@ -20,12 +20,11 @@ from typing import Dict, List, Optional
from huggingface_hub import InferenceClient from huggingface_hub import InferenceClient
from .. import AutoTokenizer from transformers import AutoTokenizer, Pipeline
from ..pipelines.base import Pipeline import logging
from ..utils import logging
logger = logging.get_logger(__name__) logger = logging.getLogger(__name__)
class MessageRole(str, Enum): class MessageRole(str, Enum):

View File

@ -14,11 +14,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ..utils import logging
from .agent_types import AgentAudio, AgentImage, AgentText from .agent_types import AgentAudio, AgentImage, AgentText
from .utils import console
logger = logging.get_logger(__name__)
def pull_message(step_log: dict, test_mode: bool = True): def pull_message(step_log: dict, test_mode: bool = True):
@ -107,11 +104,12 @@ class Monitor:
def update_metrics(self, step_log): def update_metrics(self, step_log):
step_duration = step_log["step_duration"] step_duration = step_log["step_duration"]
self.step_durations.append(step_duration) self.step_durations.append(step_duration)
logger.info(f"Step {len(self.step_durations)}:") console.print(f"Step {len(self.step_durations)}:")
logger.info(f"- Time taken: {step_duration:.2f} seconds (valid only if step succeeded)") console.print(f"- Time taken: {step_duration:.2f} seconds")
if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None: if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
self.total_input_token_count += self.tracked_llm_engine.last_input_token_count self.total_input_token_count += self.tracked_llm_engine.last_input_token_count
self.total_output_token_count += self.tracked_llm_engine.last_output_token_count self.total_output_token_count += self.tracked_llm_engine.last_output_token_count
logger.info(f"- Input tokens: {self.total_input_token_count}") console.print(f"- Input tokens: {self.total_input_token_count}")
logger.info(f"- Output tokens: {self.total_output_token_count}") console.print(f"- Output tokens: {self.total_output_token_count}")

View File

@ -16,7 +16,7 @@
# limitations under the License. # limitations under the License.
import re import re
from ..utils import cached_file from transformers.utils import cached_file
# docstyle-ignore # docstyle-ignore

View File

@ -22,12 +22,7 @@ from importlib import import_module
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import numpy as np import numpy as np
import pandas as pd
from ..utils import is_pandas_available
if is_pandas_available():
import pandas as pd
class InterpreterError(ValueError): class InterpreterError(ValueError):

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor from transformers import WhisperForConditionalGeneration, WhisperProcessor
from .tools import PipelineTool from .tools import PipelineTool

View File

@ -17,8 +17,8 @@
import torch import torch
from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
from ..utils import is_datasets_available from transformers.utils import is_datasets_available
from .tools import PipelineTool from .tools import PipelineTool

View File

@ -30,13 +30,13 @@ from huggingface_hub import create_repo, get_collection, hf_hub_download, metada
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
from packaging import version from packaging import version
from ..dynamic_module_utils import ( from transformers.dynamic_module_utils import (
custom_object_save, custom_object_save,
get_class_from_dynamic_module, get_class_from_dynamic_module,
get_imports, get_imports,
) )
from ..models.auto import AutoProcessor from transformers import AutoProcessor
from ..utils import ( from transformers.utils import (
CONFIG_NAME, CONFIG_NAME,
TypeHintParsingException, TypeHintParsingException,
cached_file, cached_file,
@ -44,12 +44,11 @@ from ..utils import (
is_accelerate_available, is_accelerate_available,
is_torch_available, is_torch_available,
is_vision_available, is_vision_available,
logging,
) )
from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs
import logging
logger = logging.getLogger(__name__)
logger = logging.get_logger(__name__)
if is_torch_available(): if is_torch_available():

View File

@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from .tools import PipelineTool from .tools import PipelineTool

10
agents/utils.py Normal file
View File

@ -0,0 +1,10 @@
from rich.console import Console
from transformers.utils.import_utils import _is_package_available

# Probe for pygments a single time, when this module is first imported, so
# is_pygments_available() only has to return the stored result.
# NOTE(review): _is_package_available is a private transformers helper;
# presumably it returns a bool here -- confirm against the pinned version.
_pygments_available = _is_package_available("pygments")

# Module-level rich console, importable by other modules as `console`.
console = Console()


def is_pygments_available():
    """Return the result of the import-time check for the `pygments` package."""
    return _pygments_available

View File

@ -47,10 +47,10 @@ TRANSFORMERS_NO_ADVISORY_WARNINGS=1 ./myprogram.py
Here is an example of how to use the same logger as the library in your own module or script: Here is an example of how to use the same logger as the library in your own module or script:
```python ```python
from transformers.utils import logging from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger("transformers") logger = logging.get_logger("transformers")
logger.info("INFO") logger.info("INFO")
logger.warning("WARN") logger.warning("WARN")
``` ```
@ -104,7 +104,7 @@ See reference of the `captureWarnings` method below.
[[autodoc]] logging.set_verbosity [[autodoc]] logging.set_verbosity
[[autodoc]] logging.get_logger [[autodoc]] logging.get_logger
[[autodoc]] logging.enable_default_handler [[autodoc]] logging.enable_default_handler

View File

@ -0,0 +1,20 @@
# Example: build a ReactCodeAgent that has an image-generation tool and a web
# search tool, then ask it a question that needs both.
from agents import load_tool, ReactCodeAgent, HfApiEngine
from agents.search import DuckDuckGoSearchTool

# Load the text-to-image tool from the Hub
image_generation_tool = load_tool("m-ric/text-to-image", cache=False)

# Instantiate the DuckDuckGo search tool provided by agents.search
search_tool = DuckDuckGoSearchTool()

llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct")

# Initialize the agent with both tools
agent = ReactCodeAgent(tools=[image_generation_tool, search_tool], llm_engine=llm_engine)

# Run it!
result = agent.run(
    "Return me a photo of the car that James bond drove in the latest movie.",
)
print(result)

View File

@ -3,3 +3,4 @@ evaluate
datasets==2.3.2 datasets==2.3.2
schedulefree schedulefree
huggingface_hub>=0.20.0 huggingface_hub>=0.20.0
duckduckgo-search

70
poetry.lock generated
View File

@ -264,6 +264,41 @@ files = [
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
] ]
[[package]]
name = "markdown-it-py"
version = "3.0.0"
description = "Python port of markdown-it. Markdown parsing, done right!"
optional = false
python-versions = ">=3.8"
files = [
{file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
{file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
]
[package.dependencies]
mdurl = ">=0.1,<1.0"
[package.extras]
benchmarking = ["psutil", "pytest", "pytest-benchmark"]
code-style = ["pre-commit (>=3.0,<4.0)"]
compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
linkify = ["linkify-it-py (>=1,<3)"]
plugins = ["mdit-py-plugins"]
profiling = ["gprof2dot"]
rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
[[package]]
name = "mdurl"
version = "0.1.2"
description = "Markdown URL utilities"
optional = false
python-versions = ">=3.7"
files = [
{file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
]
[[package]] [[package]]
name = "numpy" name = "numpy"
version = "2.1.3" version = "2.1.3"
@ -354,6 +389,20 @@ files = [
dev = ["pre-commit", "tox"] dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"] testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "pygments"
version = "2.18.0"
description = "Pygments is a syntax highlighting package written in Python."
optional = false
python-versions = ">=3.8"
files = [
{file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"},
{file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"},
]
[package.extras]
windows-terminal = ["colorama (>=0.4.6)"]
[[package]] [[package]]
name = "pytest" name = "pytest"
version = "8.3.4" version = "8.3.4"
@ -562,6 +611,25 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"] socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "rich"
version = "13.9.4"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
optional = false
python-versions = ">=3.8.0"
files = [
{file = "rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90"},
{file = "rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098"},
]
[package.dependencies]
markdown-it-py = ">=2.2.0"
pygments = ">=2.13.0,<3.0.0"
typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""}
[package.extras]
jupyter = ["ipywidgets (>=7.5.1,<9)"]
[[package]] [[package]]
name = "safetensors" name = "safetensors"
version = "0.4.5" version = "0.4.5"
@ -888,4 +956,4 @@ zstd = ["zstandard (>=0.18.0)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.10,<3.13" python-versions = ">=3.10,<3.13"
content-hash = "de2c60fd8f7c54521b2204b3af144b974059361d7892c70eeb325a7fe52bd489" content-hash = "985797a96ac1b58e4892479d96bbd1c696b452d760fa421a28b91ea8b3dbf977"

View File

@ -62,6 +62,7 @@ python = ">=3.10,<3.13"
transformers = ">=4.0.0" transformers = ">=4.0.0"
pytest = {version = ">=8.1.0", optional = true} pytest = {version = ">=8.1.0", optional = true}
requests = "^2.32.3" requests = "^2.32.3"
rich = "^13.9.4"
[build-system] [build-system]