Nicer console outputs with rich
This commit is contained in: parent f3dcf1f013, commit 146ee3dd32
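In short: the commit replaces the hand-rolled `logging`-based console output (custom numeric levels plus an ANSI-escape `CustomFormatter`) with the `rich` library, switches the package's relative `..utils` imports to absolute `transformers.*` imports, and changes the agents' `verbose` option from an int level to a bool. A minimal sketch of the recurring before/after pattern (illustrative, not code from the diff itself):

```python
# rich is the new dependency declared in pyproject.toml; the shared console
# object is defined once in agents/utils.py.
from rich.console import Console

console = Console()

# before: logger.log(33, "======== New task ========")
console.rule("New task", characters="=")

# before: logger.error(error_msg, exc_info=1)
console.print("[bold red]Error in code parsing[/bold red]")
```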
@@ -16,7 +16,7 @@
 # limitations under the License.
 from typing import TYPE_CHECKING
 
-from ..utils import (
+from transformers.utils import (
     OptionalDependencyNotAvailable,
     _LazyModule,
     is_torch_available,

@@ -19,10 +19,11 @@ import uuid
 import numpy as np
 
-from ..utils import is_soundfile_availble, is_torch_available, is_vision_available, logging
+from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
+import logging
 
 
-logger = logging.get_logger(__name__)
+logger = logging.getLogger(__name__)
 
 if is_vision_available():
     from PIL import Image

@@ -159,7 +160,7 @@ class AgentImage(AgentType, ImageType):
         return self._path
 
-    def save(self, output_bytes, format, **params):
+    def save(self, output_bytes, format = None, **params):
         """
         Saves the image to a file.
         Args:
@@ -168,7 +169,7 @@ class AgentImage(AgentType, ImageType):
             **params: Additional parameters to pass to PIL.Image.save.
         """
         img = self.to_raw()
-        img.save(output_bytes, format, **params)
+        img.save(output_bytes, format=format, **params)
 
 
 class AgentAudio(AgentType, str):

agents/agents.py (233 changed lines)

@@ -19,10 +19,12 @@ import logging
 import re
 import time
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import rich
+from rich import markdown as rich_markdown
 
-from .. import is_torch_available
-from ..utils import logging as transformers_logging
-from ..utils.import_utils import is_pygments_available
+from transformers.utils import is_torch_available
+import logging
+from .utils import console
 from .agent_types import AgentAudio, AgentImage
 from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
 from .llm_engine import HfApiEngine, MessageRole

@@ -48,52 +50,6 @@ from .tools import (
 )
 
 
-if is_pygments_available():
-    from pygments import highlight
-    from pygments.formatters import Terminal256Formatter
-    from pygments.lexers import PythonLexer
-
-
-class CustomFormatter(logging.Formatter):
-    grey = "\x1b[38;20m"
-    bold_yellow = "\x1b[33;1m"
-    red = "\x1b[31;20m"
-    green = "\x1b[32;20m"
-    bold_green = "\x1b[32;20;1m"
-    bold_red = "\x1b[31;1m"
-    bold_white = "\x1b[37;1m"
-    orange = "\x1b[38;5;214m"
-    bold_orange = "\x1b[38;5;214;1m"
-    reset = "\x1b[0m"
-    format = "%(message)s"
-
-    FORMATS = {
-        logging.DEBUG: grey + format + reset,
-        logging.INFO: format,
-        logging.WARNING: bold_yellow + format + reset,
-        logging.ERROR: red + format + reset,
-        logging.CRITICAL: bold_red + format + reset,
-        31: reset + format + reset,
-        32: green + format + reset,
-        33: bold_green + format + reset,
-        34: bold_white + format + reset,
-        35: orange + format + reset,
-        36: bold_orange + format + reset,
-    }
-
-    def format(self, record):
-        log_fmt = self.FORMATS.get(record.levelno)
-        formatter = logging.Formatter(log_fmt)
-        return formatter.format(record)
-
-
-logger = transformers_logging.get_logger(__name__)
-logger.propagate = False
-ch = logging.StreamHandler()
-ch.setFormatter(CustomFormatter())
-logger.addHandler(ch)
-
-
 def parse_json_blob(json_blob: str) -> Dict[str, str]:
     try:
         first_accolade_index = json_blob.find("{")

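The removed `CustomFormatter` mapped both standard levels and the custom numeric levels 31-36 to raw ANSI escape codes. With rich, the same styling is expressed as inline markup; a rough correspondence, assuming standard rich style names (this mapping is illustrative, not part of the commit):

```python
from rich.console import Console

console = Console()

console.print("[dim]debug message[/dim]")             # was grey: "\x1b[38;20m"
console.print("[bold yellow]warning[/bold yellow]")   # was bold_yellow: "\x1b[33;1m"
console.print("[bold red]critical error[/bold red]")  # was bold_red: "\x1b[31;1m"
```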
@@ -142,9 +98,10 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
     elif "action" in tool_call:
         return tool_call["action"], None
     else:
-        raise ValueError(
-            f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}"
-        )
+        missing_keys = [key for key in ['action', 'action_input'] if key not in tool_call]
+        error_msg = f"Missing keys: {missing_keys} in blob {tool_call}"
+        console.print(f"[bold red]{error_msg}[/bold red]")
+        raise ValueError(error_msg)
 
 
 def parse_text_tool_call(text: str) -> Tuple[str, Union[str, Dict[str, str]]]:

@@ -197,18 +154,18 @@ class Toolbox:
         self._tools = {tool.name: tool for tool in tools}
         if add_base_tools:
             self.add_base_tools()
-        self._load_tools_if_needed()
+        # self._load_tools_if_needed()
 
     def add_base_tools(self, add_python_interpreter: bool = False):
         global _tools_are_initialized
         global HUGGINGFACE_DEFAULT_TOOLS
         if not _tools_are_initialized:
-            HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools(logger)
+            HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools()
             _tools_are_initialized = True
         for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
             if tool.name != "python_interpreter" or add_python_interpreter:
                 self.add_tool(tool)
-        self._load_tools_if_needed()
+        # self._load_tools_if_needed()
 
     @property
     def tools(self) -> Dict[str, Tool]:

@@ -271,11 +228,11 @@ class Toolbox:
         """Clears the toolbox"""
         self._tools = {}
 
-    def _load_tools_if_needed(self):
-        for name, tool in self._tools.items():
-            if not isinstance(tool, Tool):
-                task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
-                self._tools[name] = load_tool(task_or_repo_id)
+    # def _load_tools_if_needed(self):
+    #     for name, tool in self._tools.items():
+    #         if not isinstance(tool, Tool):
+    #             task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
+    #             self._tools[name] = load_tool(task_or_repo_id)
 
     def __repr__(self):
         toolbox_description = "Toolbox contents:\n"

@@ -290,6 +247,8 @@ class AgentError(Exception):
     def __init__(self, message):
         super().__init__(message)
         self.message = message
+        console.print(f"[bold red]{message}[/bold red]")
+
 
 
 class AgentParsingError(AgentError):

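With this change, every `AgentError` (and the subclasses below) reports itself in red at the moment it is raised. A self-contained sketch of the new behavior, mirroring the lines added above:

```python
from rich.console import Console

console = Console()

class AgentError(Exception):
    def __init__(self, message):
        super().__init__(message)
        self.message = message
        # New in this commit: echo the error to the console before propagating.
        console.print(f"[bold red]{message}[/bold red]")

try:
    raise AgentError("Error: unknown tool 'foo'")
except AgentError as e:
    print(e.message)  # the message was already printed in bold red
```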
@@ -362,7 +321,7 @@ class Agent:
         max_iterations: int = 6,
         tool_parser: Optional[Callable] = None,
         add_base_tools: bool = False,
-        verbose: int = 0,
+        verbose: bool = False,
         grammar: Optional[Dict[str, str]] = None,
         managed_agents: Optional[List] = None,
         step_callbacks: Optional[List[Callable]] = None,

@@ -380,7 +339,6 @@ class Agent:
         )
         self.additional_args = additional_args
         self.max_iterations = max_iterations
-        self.logger = logger
         self.tool_parser = tool_parser
         self.grammar = grammar
 

@@ -406,13 +364,7 @@ class Agent:
         self.prompt = None
         self.logs = []
         self.task = None
-
-        if verbose == 0:
-            logger.setLevel(logging.WARNING)
-        elif verbose == 1:
-            logger.setLevel(logging.INFO)
-        elif verbose == 2:
-            logger.setLevel(logging.DEBUG)
+        self.verbose = verbose
 
         # Initialize step callbacks
         self.step_callbacks = step_callbacks if step_callbacks is not None else []

@@ -441,10 +393,8 @@ class Agent:
                 self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
             )
         self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
-        self.logger.log(33, "======== New task ========")
-        self.logger.log(34, self.task)
-        self.logger.debug("System prompt is as follows:")
-        self.logger.debug(self.system_prompt)
+        console.rule("New task", characters='=')
+        console.print(self.task)
 
     def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
         """

@@ -521,7 +471,6 @@ class Agent:
                 split[-1],
             )  # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
         except Exception as e:
-            self.logger.error(e, exc_info=1)
             raise AgentParsingError(
                 f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
             )

@@ -541,7 +490,7 @@ class Agent:
             available_tools = {**available_tools, **self.managed_agents}
         if tool_name not in available_tools:
             error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
-            self.logger.error(error_msg, exc_info=1)
+            console.print(f"[bold red]{error_msg}")
             raise AgentExecutionError(error_msg)
 
         try:

@@ -549,38 +498,39 @@ class Agent:
                 observation = available_tools[tool_name](arguments)
             elif isinstance(arguments, dict):
                 for key, value in arguments.items():
                     # if the value is the name of a state variable like "image.png", replace it with the actual value
                     if isinstance(value, str) and value in self.state:
                         arguments[key] = self.state[value]
                 observation = available_tools[tool_name](**arguments)
             else:
-                raise AgentExecutionError(
-                    f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
-                )
+                error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
+                console.print(f"[bold red]{error_msg}")
+                raise AgentExecutionError(error_msg)
             return observation
         except Exception as e:
             if tool_name in self.toolbox.tools:
-                raise AgentExecutionError(
+                tool_description = get_tool_description_with_args(available_tools[tool_name])
+                error_msg = (
                     f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
-                    f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(available_tools[tool_name])}"
+                    f"As a reminder, this tool's description is the following:\n{tool_description}"
                 )
+                console.print(f"[bold red]{error_msg}")
+                raise AgentExecutionError(error_msg)
             elif tool_name in self.managed_agents:
-                raise AgentExecutionError(
+                error_msg = (
                     f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
                     f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
                 )
+                console.print(f"[bold red]{error_msg}")
+                raise AgentExecutionError(error_msg)
 
     def log_rationale_code_action(self, rationale: str, code_action: str) -> None:
-        self.logger.warning("=== Agent thoughts:")
-        self.logger.log(31, rationale)
-        self.logger.warning(">>> Agent is executing the code below:")
-        if is_pygments_available():
-            self.logger.log(
-                31, highlight(code_action, PythonLexer(ensurenl=False), Terminal256Formatter(style="nord"))
-            )
-        else:
-            self.logger.log(31, code_action)
-        self.logger.warning("====")
+        if self.verbose:
+            console.rule("Agent thoughts")
+            console.print(rationale)
+            console.rule("Agent is executing the code below:", align="left")
+            console.print(code_action)
+            console.rule("", align="left")
 
     def run(self, **kwargs):
         """To be implemented in the child class"""

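Note that pygments-based highlighting of `code_action` is dropped here in favor of a plain `console.print`. If highlighted output is wanted later, rich wraps pygments itself; a possible variant, not part of this commit:

```python
from rich.console import Console
from rich.syntax import Syntax

console = Console()
code_action = "final_answer(result)"  # stand-in for the agent's generated code

# rich.syntax.Syntax reuses pygments lexers and themes, so the old
# Terminal256Formatter(style="nord") setup maps onto theme="nord".
console.print(Syntax(code_action, "python", theme="nord"))
```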
@@ -617,13 +567,6 @@ class CodeAgent(Agent):
             **kwargs,
         )
 
-        if not is_pygments_available():
-            transformers_logging.warning_once(
-                logger,
-                "pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
-                "CodeAgent.",
-            )
-
         self.python_evaluator = evaluate_python_code
         self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
         self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))

@@ -669,20 +612,21 @@ class CodeAgent(Agent):
         }
 
         self.prompt = [prompt_message, task_message]
-        self.logger.info("====Executing with this prompt====")
-        self.logger.info(self.prompt)
+        if self.verbose:
+            console.rule("Executing with this prompt")
+            console.print(self.prompt)
 
         additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
         llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"], **additional_args)
 
         if return_generated_code:
             return llm_output
 
         # Parse
         try:
             rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
         except Exception as e:
-            self.logger.debug(
-                f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
-            )
+            if self.verbose:
+                console.print(
+                    f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
+                )
             rationale, code_action = "", llm_output

@@ -691,7 +635,7 @@ class CodeAgent(Agent):
             code_action = self.parse_code_blob(code_action)
         except Exception as e:
             error_msg = f"Error in code parsing: {e}. Be sure to provide correct code"
-            self.logger.error(error_msg, exc_info=1)
+            console.print(f"[bold red]{error_msg}[/bold red]")
             return error_msg
 
         # Execute

@@ -705,11 +649,12 @@ class CodeAgent(Agent):
                 state=self.state,
                 authorized_imports=self.authorized_imports,
             )
-            self.logger.info(self.state["print_outputs"])
+            if self.verbose:
+                console.print(self.state["print_outputs"])
             return output
         except Exception as e:
             error_msg = f"Error in execution: {e}. Be sure to provide correct code."
-            self.logger.error(error_msg, exc_info=1)
+            console.print(f"[bold red]{error_msg}[/bold red]")
             return error_msg
 

@@ -759,7 +704,7 @@ class ReactAgent(Agent):
         self.prompt = [
             {
                 "role": MessageRole.SYSTEM,
-                "content": "An agent tried to answer an user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
+                "content": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
             }
         ]
         self.prompt += self.write_inner_memory_from_logs()[1:]

@@ -772,7 +717,9 @@ class ReactAgent(Agent):
         try:
             return self.llm_engine(self.prompt)
         except Exception as e:
-            return f"Error in generating final llm output: {e}."
+            error_msg = f"Error in generating final LLM output: {e}."
+            console.print(f"[bold red]{error_msg}[/bold red]")
+            return error_msg
 
     def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
         """

@@ -815,7 +762,6 @@ class ReactAgent(Agent):
                 if "final_answer" in step_log_entry:
                     final_answer = step_log_entry["final_answer"]
             except AgentError as e:
-                self.logger.error(e, exc_info=1)
                 step_log_entry["error"] = e
             finally:
                 step_end_time = time.time()

@@ -831,7 +777,7 @@ class ReactAgent(Agent):
             error_message = "Reached max iterations."
             final_step_log = {"error": AgentMaxIterationsError(error_message)}
             self.logs.append(final_step_log)
-            self.logger.error(error_message, exc_info=1)
+            console.print(f"[bold red]{error_message}")
             final_answer = self.provide_final_answer(task)
             final_step_log["final_answer"] = final_answer
             final_step_log["step_duration"] = 0

@@ -857,7 +803,6 @@ class ReactAgent(Agent):
                 if "final_answer" in step_log_entry:
                     final_answer = step_log_entry["final_answer"]
             except AgentError as e:
-                self.logger.error(e, exc_info=1)
                 step_log_entry["error"] = e
             finally:
                 step_end_time = time.time()

@@ -872,7 +817,7 @@ class ReactAgent(Agent):
             error_message = "Reached max iterations."
             final_step_log = {"error": AgentMaxIterationsError(error_message)}
             self.logs.append(final_step_log)
-            self.logger.error(error_message, exc_info=1)
+            console.print(f"[bold red]{error_message}")
             final_answer = self.provide_final_answer(task)
             final_step_log["final_answer"] = final_answer
             final_step_log["step_duration"] = 0

@@ -931,8 +876,8 @@ Now begin!""",
 {answer_facts}
 ```""".strip()
             self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
-            self.logger.log(36, "===== Initial plan =====")
-            self.logger.log(35, final_plan_redaction)
+            console.rule("[orange]Initial plan")
+            console.print(final_plan_redaction)
         else:  # update plan
             agent_memory = self.write_inner_memory_from_logs(
                 summary_mode=False

@@ -977,8 +922,8 @@ Now begin!""",
 {facts_update}
 ```"""
             self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
-            self.logger.log(36, "===== Updated plan =====")
-            self.logger.log(35, final_plan_redaction)
+            console.rule("Updated plan")
+            console.print(final_plan_redaction)
 
 
 class ReactJsonAgent(ReactAgent):

@@ -1022,13 +967,14 @@ class ReactJsonAgent(ReactAgent):
         agent_memory = self.write_inner_memory_from_logs()
 
         self.prompt = agent_memory
-        self.logger.debug("===== New step =====")
+        console.rule("New step")
 
         # Add new step in logs
         log_entry["agent_memory"] = agent_memory.copy()
 
-        self.logger.info("===== Calling LLM with this last message: =====")
-        self.logger.info(self.prompt[-1])
+        if self.verbose:
+            console.rule("Calling LLM with this last message:")
+            console.print(self.prompt[-1])
 
         try:
             additional_args = {"grammar": self.grammar} if self.grammar is not None else {}

@@ -1037,12 +983,12 @@ class ReactJsonAgent(ReactAgent):
             )
         except Exception as e:
             raise AgentGenerationError(f"Error in generating llm output: {e}.")
-        self.logger.debug("===== Output message of the LLM: =====")
-        self.logger.debug(llm_output)
+        console.rule("===== Output message of the LLM: =====")
+        console.print(llm_output)
         log_entry["llm_output"] = llm_output
 
         # Parse
-        self.logger.debug("===== Extracting action =====")
+        console.rule("===== Extracting action =====")
         rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")
 
         try:

@@ -1054,9 +1000,9 @@ class ReactJsonAgent(ReactAgent):
         log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
 
         # Execute
-        self.logger.warning("=== Agent thoughts:")
-        self.logger.log(31, rationale)
-        self.logger.warning(f">>> Calling tool: '{tool_name}' with arguments: {arguments}")
+        console.print("=== Agent thoughts:")
+        console.print(rationale)
+        console.print(f">>> Calling tool: '{tool_name}' with arguments: {arguments}")
         if tool_name == "final_answer":
             if isinstance(arguments, dict):
                 if "answer" in arguments:

@@ -1087,7 +1033,6 @@ class ReactJsonAgent(ReactAgent):
             updated_information = f"Stored '{observation_name}' in memory."
         else:
             updated_information = str(observation).strip()
-        self.logger.info(updated_information)
         log_entry["observation"] = updated_information
         return log_entry
 

@@ -1126,13 +1071,6 @@ class ReactCodeAgent(ReactAgent):
             **kwargs,
         )
 
-        if not is_pygments_available():
-            transformers_logging.warning_once(
-                logger,
-                "pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
-                "ReactCodeAgent.",
-            )
-
         self.python_evaluator = evaluate_python_code
         self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
         self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))

@@ -1147,13 +1085,14 @@ class ReactCodeAgent(ReactAgent):
         agent_memory = self.write_inner_memory_from_logs()
 
         self.prompt = agent_memory.copy()
-        self.logger.debug("===== New step =====")
+        console.rule("New step")
 
         # Add new step in logs
         log_entry["agent_memory"] = agent_memory.copy()
 
-        self.logger.info("===== Calling LLM with these last messages: =====")
-        self.logger.info(self.prompt[-2:])
+        if self.verbose:
+            console.print("===== Calling LLM with these last messages: =====")
+            console.print(self.prompt[-2:])
 
         try:
             additional_args = {"grammar": self.grammar} if self.grammar is not None else {}

@@ -1163,16 +1102,16 @@ class ReactCodeAgent(ReactAgent):
         except Exception as e:
             raise AgentGenerationError(f"Error in generating llm output: {e}.")
 
-        self.logger.debug("=== Output message of the LLM:")
-        self.logger.debug(llm_output)
+        if self.verbose:
+            console.rule("Output message of the LLM:")
+            console.print(llm_output)
         log_entry["llm_output"] = llm_output
 
         # Parse
-        self.logger.debug("=== Extracting action ===")
         try:
             rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
         except Exception as e:
-            self.logger.debug(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
+            console.print(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
             rationale, raw_code_action = llm_output, llm_output
 
         try:

@@ -1200,12 +1139,12 @@ class ReactCodeAgent(ReactAgent):
                 state=self.state,
                 authorized_imports=self.authorized_imports,
             )
-            self.logger.warning("Print outputs:")
-            self.logger.log(32, self.state["print_outputs"])
+            console.print("Print outputs:")
+            console.print(self.state["print_outputs"])
             observation = "Print outputs:\n" + self.state["print_outputs"]
             if result is not None:
-                self.logger.warning("Last output from code snippet:")
-                self.logger.log(32, str(result))
+                console.print("Last output from code snippet:")
+                console.print(str(result))
                 observation += "Last output from code snippet:\n" + str(result)[:100000]
             log_entry["observation"] = observation
         except Exception as e:

@@ -1215,8 +1154,8 @@ class ReactCodeAgent(ReactAgent):
             raise AgentExecutionError(error_msg)
         for line in code_action.split("\n"):
             if line[: len("final_answer")] == "final_answer":
-                self.logger.log(33, "Final answer:")
-                self.logger.log(32, result)
+                console.print("Final answer:")
+                console.print(f"[bold]{result}")
                 log_entry["final_answer"] = result
                 return result
 

@@ -23,7 +23,7 @@ from typing import Dict
 
 from huggingface_hub import hf_hub_download, list_spaces
 
-from ..utils import is_offline_mode
+from transformers.utils import is_offline_mode
 from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
 from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
 

@@ -128,7 +128,7 @@ def get_remote_tools(logger, organization="huggingface-tools"):
     return tools
 
 
-def setup_default_tools(logger):
+def setup_default_tools():
     default_tools = {}
     main_module = importlib.import_module("transformers")
     tools_module = main_module.agents

@@ -19,9 +19,8 @@ import re
 import numpy as np
 import torch
 
-from ..models.auto import AutoProcessor
-from ..models.vision_encoder_decoder import VisionEncoderDecoderModel
-from ..utils import is_vision_available
+from transformers import AutoProcessor, VisionEncoderDecoderModel
+from transformers.utils import is_vision_available
 from .tools import PipelineTool
 

@@ -18,8 +18,8 @@
 import torch
 from PIL import Image
 
-from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor
-from ..utils import requires_backends
+from transformers import AutoModelForVisualQuestionAnswering, AutoProcessor
+from transformers.utils import requires_backends
 from .tools import PipelineTool
 

@@ -20,12 +20,11 @@ from typing import Dict, List, Optional
 
 from huggingface_hub import InferenceClient
 
-from .. import AutoTokenizer
-from ..pipelines.base import Pipeline
-from ..utils import logging
+from transformers import AutoTokenizer, Pipeline
+import logging
 
 
-logger = logging.get_logger(__name__)
+logger = logging.getLogger(__name__)
 
 
 class MessageRole(str, Enum):

@@ -14,11 +14,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ..utils import logging
 from .agent_types import AgentAudio, AgentImage, AgentText
-
-
-logger = logging.get_logger(__name__)
+from .utils import console
 
 
 def pull_message(step_log: dict, test_mode: bool = True):

@@ -107,11 +104,12 @@ class Monitor:
     def update_metrics(self, step_log):
         step_duration = step_log["step_duration"]
         self.step_durations.append(step_duration)
-        logger.info(f"Step {len(self.step_durations)}:")
-        logger.info(f"- Time taken: {step_duration:.2f} seconds (valid only if step succeeded)")
+        console.print(f"Step {len(self.step_durations)}:")
+        console.print(f"- Time taken: {step_duration:.2f} seconds")
 
         if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
             self.total_input_token_count += self.tracked_llm_engine.last_input_token_count
             self.total_output_token_count += self.tracked_llm_engine.last_output_token_count
-            logger.info(f"- Input tokens: {self.total_input_token_count}")
-            logger.info(f"- Output tokens: {self.total_output_token_count}")
+            console.print(f"- Input tokens: {self.total_input_token_count}")
+            console.print(f"- Output tokens: {self.total_output_token_count}")
+

@@ -16,7 +16,7 @@
 # limitations under the License.
 import re
 
-from ..utils import cached_file
+from transformers.utils import cached_file
 
 
 # docstyle-ignore

@@ -22,11 +22,6 @@ from importlib import import_module
 from typing import Any, Callable, Dict, List, Optional
 
-import numpy as np
-
-from ..utils import is_pandas_available
-
-
 if is_pandas_available():
     import pandas as pd
 

@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor
+from transformers import WhisperForConditionalGeneration, WhisperProcessor
 from .tools import PipelineTool
 

@@ -17,8 +17,8 @@
 
 import torch
 
-from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
-from ..utils import is_datasets_available
+from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
+from transformers.utils import is_datasets_available
 from .tools import PipelineTool
 

@@ -30,13 +30,13 @@ from huggingface_hub import create_repo, get_collection, hf_hub_download, metada
 from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
 from packaging import version
 
-from ..dynamic_module_utils import (
+from transformers.dynamic_module_utils import (
     custom_object_save,
     get_class_from_dynamic_module,
     get_imports,
 )
-from ..models.auto import AutoProcessor
-from ..utils import (
+from transformers import AutoProcessor
+from transformers.utils import (
     CONFIG_NAME,
     TypeHintParsingException,
     cached_file,

@@ -44,12 +44,11 @@ from ..utils import (
     is_accelerate_available,
     is_torch_available,
     is_vision_available,
-    logging,
 )
 from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs
+import logging
 
-
-logger = logging.get_logger(__name__)
+logger = logging.getLogger(__name__)
 
 
 if is_torch_available():

@@ -14,7 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from .tools import PipelineTool
 

@@ -0,0 +1,10 @@
+from transformers.utils.import_utils import _is_package_available
+
+_pygments_available = _is_package_available("pygments")
+
+
+def is_pygments_available():
+    return _pygments_available
+
+from rich.console import Console
+console = Console()

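This new module gives the package one shared `Console` plus a local pygments check. Consumers inside the package pick the console up with a relative import, as the `from .utils import console` lines elsewhere in this diff do:

```python
# e.g. inside agents/agents.py (relative import, so this runs within the package)
from .utils import console

console.rule("New step")
console.print("[bold red]Reached max iterations.[/bold red]")
```

Keeping a single `Console` instance means width, color detection, and any future output redirection are configured in exactly one place.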
@@ -47,10 +47,10 @@ TRANSFORMERS_NO_ADVISORY_WARNINGS=1 ./myprogram.py
 Here is an example of how to use the same logger as the library in your own module or script:
 
 ```python
-from transformers.utils import logging
+import logging
 
 logging.set_verbosity_info()
-logger = logging.get_logger("transformers")
+logger = logging.getLogger("transformers")
 logger.info("INFO")
 logger.warning("WARN")
 ```

@@ -104,7 +104,7 @@ See reference of the `captureWarnings` method below.
 
 [[autodoc]] logging.set_verbosity
 
-[[autodoc]] logging.get_logger
+[[autodoc]] logging.getLogger
 
 [[autodoc]] logging.enable_default_handler
 

@@ -0,0 +1,20 @@
+from agents import load_tool, ReactCodeAgent, HfApiEngine
+
+# Import tool from Hub
+image_generation_tool = load_tool("m-ric/text-to-image", cache=False)
+
+# Import a web search tool from the package
+from agents.search import DuckDuckGoSearchTool
+
+search_tool = DuckDuckGoSearchTool()
+
+llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct")
+# Initialize the agent with both tools
+agent = ReactCodeAgent(tools=[image_generation_tool, search_tool], llm_engine=llm_engine)
+
+# Run it!
+result = agent.run(
+    "Return me a photo of the car that James Bond drove in the latest movie.",
+)
+
+print(result)

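Running this example needs the `duckduckgo-search` package (added to the requirements just below), and since `HfApiEngine` talks to the Hugging Face Inference API, presumably a configured Hugging Face token as well; the `m-ric/text-to-image` tool is downloaded from the Hub at run time.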
@@ -3,3 +3,4 @@ evaluate
 datasets==2.3.2
 schedulefree
 huggingface_hub>=0.20.0
+duckduckgo-search

@@ -264,6 +264,41 @@ files = [
     {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
 ]
 
+[[package]]
+name = "markdown-it-py"
+version = "3.0.0"
+description = "Python port of markdown-it. Markdown parsing, done right!"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
+    {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
+]
+
+[package.dependencies]
+mdurl = ">=0.1,<1.0"
+
+[package.extras]
+benchmarking = ["psutil", "pytest", "pytest-benchmark"]
+code-style = ["pre-commit (>=3.0,<4.0)"]
+compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
+linkify = ["linkify-it-py (>=1,<3)"]
+plugins = ["mdit-py-plugins"]
+profiling = ["gprof2dot"]
+rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
+testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
+
+[[package]]
+name = "mdurl"
+version = "0.1.2"
+description = "Markdown URL utilities"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
+    {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
+]
+
 [[package]]
 name = "numpy"
 version = "2.1.3"

@@ -354,6 +389,20 @@ files = [
 dev = ["pre-commit", "tox"]
 testing = ["pytest", "pytest-benchmark"]
 
+[[package]]
+name = "pygments"
+version = "2.18.0"
+description = "Pygments is a syntax highlighting package written in Python."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"},
+    {file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"},
+]
+
+[package.extras]
+windows-terminal = ["colorama (>=0.4.6)"]
+
 [[package]]
 name = "pytest"
 version = "8.3.4"

@@ -562,6 +611,25 @@ urllib3 = ">=1.21.1,<3"
 socks = ["PySocks (>=1.5.6,!=1.5.7)"]
 use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
 
+[[package]]
+name = "rich"
+version = "13.9.4"
+description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
+optional = false
+python-versions = ">=3.8.0"
+files = [
+    {file = "rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90"},
+    {file = "rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098"},
+]
+
+[package.dependencies]
+markdown-it-py = ">=2.2.0"
+pygments = ">=2.13.0,<3.0.0"
+typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""}
+
+[package.extras]
+jupyter = ["ipywidgets (>=7.5.1,<9)"]
+
 [[package]]
 name = "safetensors"
 version = "0.4.5"

@@ -888,4 +956,4 @@ zstd = ["zstandard (>=0.18.0)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "de2c60fd8f7c54521b2204b3af144b974059361d7892c70eeb325a7fe52bd489"
+content-hash = "985797a96ac1b58e4892479d96bbd1c696b452d760fa421a28b91ea8b3dbf977"

@@ -62,6 +62,7 @@ python = ">=3.10,<3.13"
 transformers = ">=4.0.0"
 pytest = {version = ">=8.1.0", optional = true}
 requests = "^2.32.3"
+rich = "^13.9.4"
 
 
 [build-system]