Nicer console outputs with rich
This commit is contained in:
parent
f3dcf1f013
commit
146ee3dd32
|
@ -16,7 +16,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ..utils import (
|
from transformers.utils import (
|
||||||
OptionalDependencyNotAvailable,
|
OptionalDependencyNotAvailable,
|
||||||
_LazyModule,
|
_LazyModule,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
|
|
|
@ -19,10 +19,11 @@ import uuid
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ..utils import is_soundfile_availble, is_torch_available, is_vision_available, logging
|
from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -159,7 +160,7 @@ class AgentImage(AgentType, ImageType):
|
||||||
|
|
||||||
return self._path
|
return self._path
|
||||||
|
|
||||||
def save(self, output_bytes, format, **params):
|
def save(self, output_bytes, format = None, **params):
|
||||||
"""
|
"""
|
||||||
Saves the image to a file.
|
Saves the image to a file.
|
||||||
Args:
|
Args:
|
||||||
|
@ -168,7 +169,7 @@ class AgentImage(AgentType, ImageType):
|
||||||
**params: Additional parameters to pass to PIL.Image.save.
|
**params: Additional parameters to pass to PIL.Image.save.
|
||||||
"""
|
"""
|
||||||
img = self.to_raw()
|
img = self.to_raw()
|
||||||
img.save(output_bytes, format, **params)
|
img.save(output_bytes, format=format, **params)
|
||||||
|
|
||||||
|
|
||||||
class AgentAudio(AgentType, str):
|
class AgentAudio(AgentType, str):
|
||||||
|
|
237
agents/agents.py
237
agents/agents.py
|
@ -19,10 +19,12 @@ import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
import rich
|
||||||
|
from rich import markdown as rich_markdown
|
||||||
|
|
||||||
from .. import is_torch_available
|
from transformers.utils import is_torch_available
|
||||||
from ..utils import logging as transformers_logging
|
import logging
|
||||||
from ..utils.import_utils import is_pygments_available
|
from .utils import console
|
||||||
from .agent_types import AgentAudio, AgentImage
|
from .agent_types import AgentAudio, AgentImage
|
||||||
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
|
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
|
||||||
from .llm_engine import HfApiEngine, MessageRole
|
from .llm_engine import HfApiEngine, MessageRole
|
||||||
|
@ -48,52 +50,6 @@ from .tools import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if is_pygments_available():
|
|
||||||
from pygments import highlight
|
|
||||||
from pygments.formatters import Terminal256Formatter
|
|
||||||
from pygments.lexers import PythonLexer
|
|
||||||
|
|
||||||
|
|
||||||
class CustomFormatter(logging.Formatter):
|
|
||||||
grey = "\x1b[38;20m"
|
|
||||||
bold_yellow = "\x1b[33;1m"
|
|
||||||
red = "\x1b[31;20m"
|
|
||||||
green = "\x1b[32;20m"
|
|
||||||
bold_green = "\x1b[32;20;1m"
|
|
||||||
bold_red = "\x1b[31;1m"
|
|
||||||
bold_white = "\x1b[37;1m"
|
|
||||||
orange = "\x1b[38;5;214m"
|
|
||||||
bold_orange = "\x1b[38;5;214;1m"
|
|
||||||
reset = "\x1b[0m"
|
|
||||||
format = "%(message)s"
|
|
||||||
|
|
||||||
FORMATS = {
|
|
||||||
logging.DEBUG: grey + format + reset,
|
|
||||||
logging.INFO: format,
|
|
||||||
logging.WARNING: bold_yellow + format + reset,
|
|
||||||
logging.ERROR: red + format + reset,
|
|
||||||
logging.CRITICAL: bold_red + format + reset,
|
|
||||||
31: reset + format + reset,
|
|
||||||
32: green + format + reset,
|
|
||||||
33: bold_green + format + reset,
|
|
||||||
34: bold_white + format + reset,
|
|
||||||
35: orange + format + reset,
|
|
||||||
36: bold_orange + format + reset,
|
|
||||||
}
|
|
||||||
|
|
||||||
def format(self, record):
|
|
||||||
log_fmt = self.FORMATS.get(record.levelno)
|
|
||||||
formatter = logging.Formatter(log_fmt)
|
|
||||||
return formatter.format(record)
|
|
||||||
|
|
||||||
|
|
||||||
logger = transformers_logging.get_logger(__name__)
|
|
||||||
logger.propagate = False
|
|
||||||
ch = logging.StreamHandler()
|
|
||||||
ch.setFormatter(CustomFormatter())
|
|
||||||
logger.addHandler(ch)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
||||||
try:
|
try:
|
||||||
first_accolade_index = json_blob.find("{")
|
first_accolade_index = json_blob.find("{")
|
||||||
|
@ -142,9 +98,10 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
|
||||||
elif "action" in tool_call:
|
elif "action" in tool_call:
|
||||||
return tool_call["action"], None
|
return tool_call["action"], None
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
missing_keys = [key for key in ['action', 'action_input'] if key not in tool_call]
|
||||||
f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}"
|
error_msg = f"Missing keys: {missing_keys} in blob {tool_call}"
|
||||||
)
|
console.print(f"[bold red]{error_msg}[/bold red]")
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
def parse_text_tool_call(text: str) -> Tuple[str, Union[str, Dict[str, str]]]:
|
def parse_text_tool_call(text: str) -> Tuple[str, Union[str, Dict[str, str]]]:
|
||||||
|
@ -197,18 +154,18 @@ class Toolbox:
|
||||||
self._tools = {tool.name: tool for tool in tools}
|
self._tools = {tool.name: tool for tool in tools}
|
||||||
if add_base_tools:
|
if add_base_tools:
|
||||||
self.add_base_tools()
|
self.add_base_tools()
|
||||||
self._load_tools_if_needed()
|
# self._load_tools_if_needed()
|
||||||
|
|
||||||
def add_base_tools(self, add_python_interpreter: bool = False):
|
def add_base_tools(self, add_python_interpreter: bool = False):
|
||||||
global _tools_are_initialized
|
global _tools_are_initialized
|
||||||
global HUGGINGFACE_DEFAULT_TOOLS
|
global HUGGINGFACE_DEFAULT_TOOLS
|
||||||
if not _tools_are_initialized:
|
if not _tools_are_initialized:
|
||||||
HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools(logger)
|
HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools()
|
||||||
_tools_are_initialized = True
|
_tools_are_initialized = True
|
||||||
for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
|
for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
|
||||||
if tool.name != "python_interpreter" or add_python_interpreter:
|
if tool.name != "python_interpreter" or add_python_interpreter:
|
||||||
self.add_tool(tool)
|
self.add_tool(tool)
|
||||||
self._load_tools_if_needed()
|
# self._load_tools_if_needed()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tools(self) -> Dict[str, Tool]:
|
def tools(self) -> Dict[str, Tool]:
|
||||||
|
@ -271,11 +228,11 @@ class Toolbox:
|
||||||
"""Clears the toolbox"""
|
"""Clears the toolbox"""
|
||||||
self._tools = {}
|
self._tools = {}
|
||||||
|
|
||||||
def _load_tools_if_needed(self):
|
# def _load_tools_if_needed(self):
|
||||||
for name, tool in self._tools.items():
|
# for name, tool in self._tools.items():
|
||||||
if not isinstance(tool, Tool):
|
# if not isinstance(tool, Tool):
|
||||||
task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
|
# task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
|
||||||
self._tools[name] = load_tool(task_or_repo_id)
|
# self._tools[name] = load_tool(task_or_repo_id)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
toolbox_description = "Toolbox contents:\n"
|
toolbox_description = "Toolbox contents:\n"
|
||||||
|
@ -290,6 +247,8 @@ class AgentError(Exception):
|
||||||
def __init__(self, message):
|
def __init__(self, message):
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
self.message = message
|
self.message = message
|
||||||
|
console.print(f"[bold red]{message}[/bold red]")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AgentParsingError(AgentError):
|
class AgentParsingError(AgentError):
|
||||||
|
@ -362,7 +321,7 @@ class Agent:
|
||||||
max_iterations: int = 6,
|
max_iterations: int = 6,
|
||||||
tool_parser: Optional[Callable] = None,
|
tool_parser: Optional[Callable] = None,
|
||||||
add_base_tools: bool = False,
|
add_base_tools: bool = False,
|
||||||
verbose: int = 0,
|
verbose: bool = False,
|
||||||
grammar: Optional[Dict[str, str]] = None,
|
grammar: Optional[Dict[str, str]] = None,
|
||||||
managed_agents: Optional[List] = None,
|
managed_agents: Optional[List] = None,
|
||||||
step_callbacks: Optional[List[Callable]] = None,
|
step_callbacks: Optional[List[Callable]] = None,
|
||||||
|
@ -380,7 +339,6 @@ class Agent:
|
||||||
)
|
)
|
||||||
self.additional_args = additional_args
|
self.additional_args = additional_args
|
||||||
self.max_iterations = max_iterations
|
self.max_iterations = max_iterations
|
||||||
self.logger = logger
|
|
||||||
self.tool_parser = tool_parser
|
self.tool_parser = tool_parser
|
||||||
self.grammar = grammar
|
self.grammar = grammar
|
||||||
|
|
||||||
|
@ -406,13 +364,7 @@ class Agent:
|
||||||
self.prompt = None
|
self.prompt = None
|
||||||
self.logs = []
|
self.logs = []
|
||||||
self.task = None
|
self.task = None
|
||||||
|
self.verbose = verbose
|
||||||
if verbose == 0:
|
|
||||||
logger.setLevel(logging.WARNING)
|
|
||||||
elif verbose == 1:
|
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
elif verbose == 2:
|
|
||||||
logger.setLevel(logging.DEBUG)
|
|
||||||
|
|
||||||
# Initialize step callbacks
|
# Initialize step callbacks
|
||||||
self.step_callbacks = step_callbacks if step_callbacks is not None else []
|
self.step_callbacks = step_callbacks if step_callbacks is not None else []
|
||||||
|
@ -441,10 +393,8 @@ class Agent:
|
||||||
self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
|
self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
|
||||||
)
|
)
|
||||||
self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
|
self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
|
||||||
self.logger.log(33, "======== New task ========")
|
console.rule("New task", characters='=')
|
||||||
self.logger.log(34, self.task)
|
console.print(self.task)
|
||||||
self.logger.debug("System prompt is as follows:")
|
|
||||||
self.logger.debug(self.system_prompt)
|
|
||||||
|
|
||||||
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
|
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
|
@ -521,7 +471,6 @@ class Agent:
|
||||||
split[-1],
|
split[-1],
|
||||||
) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
|
) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(e, exc_info=1)
|
|
||||||
raise AgentParsingError(
|
raise AgentParsingError(
|
||||||
f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
|
f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
|
||||||
)
|
)
|
||||||
|
@ -541,7 +490,7 @@ class Agent:
|
||||||
available_tools = {**available_tools, **self.managed_agents}
|
available_tools = {**available_tools, **self.managed_agents}
|
||||||
if tool_name not in available_tools:
|
if tool_name not in available_tools:
|
||||||
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
||||||
self.logger.error(error_msg, exc_info=1)
|
console.print(f"[bold red]{error_msg}")
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -549,38 +498,39 @@ class Agent:
|
||||||
observation = available_tools[tool_name](arguments)
|
observation = available_tools[tool_name](arguments)
|
||||||
elif isinstance(arguments, dict):
|
elif isinstance(arguments, dict):
|
||||||
for key, value in arguments.items():
|
for key, value in arguments.items():
|
||||||
# if the value is the name of a state variable like "image.png", replace it with the actual value
|
|
||||||
if isinstance(value, str) and value in self.state:
|
if isinstance(value, str) and value in self.state:
|
||||||
arguments[key] = self.state[value]
|
arguments[key] = self.state[value]
|
||||||
observation = available_tools[tool_name](**arguments)
|
observation = available_tools[tool_name](**arguments)
|
||||||
else:
|
else:
|
||||||
raise AgentExecutionError(
|
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
|
||||||
f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
|
console.print(f"[bold red]{error_msg}")
|
||||||
)
|
raise AgentExecutionError(error_msg)
|
||||||
return observation
|
return observation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if tool_name in self.toolbox.tools:
|
if tool_name in self.toolbox.tools:
|
||||||
raise AgentExecutionError(
|
tool_description = get_tool_description_with_args(available_tools[tool_name])
|
||||||
|
error_msg = (
|
||||||
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
|
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
|
||||||
f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(available_tools[tool_name])}"
|
f"As a reminder, this tool's description is the following:\n{tool_description}"
|
||||||
)
|
)
|
||||||
|
console.print(f"[bold red]{error_msg}")
|
||||||
|
raise AgentExecutionError(error_msg)
|
||||||
elif tool_name in self.managed_agents:
|
elif tool_name in self.managed_agents:
|
||||||
raise AgentExecutionError(
|
error_msg = (
|
||||||
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
|
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
|
||||||
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
|
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
|
||||||
)
|
)
|
||||||
|
console.print(f"[bold red]{error_msg}")
|
||||||
|
raise AgentExecutionError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
def log_rationale_code_action(self, rationale: str, code_action: str) -> None:
|
def log_rationale_code_action(self, rationale: str, code_action: str) -> None:
|
||||||
self.logger.warning("=== Agent thoughts:")
|
if self.verbose:
|
||||||
self.logger.log(31, rationale)
|
console.rule("Agent thoughts")
|
||||||
self.logger.warning(">>> Agent is executing the code below:")
|
console.print(rationale)
|
||||||
if is_pygments_available():
|
console.rule("Agent is executing the code below:", align="left")
|
||||||
self.logger.log(
|
console.print(code_action)
|
||||||
31, highlight(code_action, PythonLexer(ensurenl=False), Terminal256Formatter(style="nord"))
|
console.rule("", align="left")
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.logger.log(31, code_action)
|
|
||||||
self.logger.warning("====")
|
|
||||||
|
|
||||||
def run(self, **kwargs):
|
def run(self, **kwargs):
|
||||||
"""To be implemented in the child class"""
|
"""To be implemented in the child class"""
|
||||||
|
@ -617,13 +567,6 @@ class CodeAgent(Agent):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not is_pygments_available():
|
|
||||||
transformers_logging.warning_once(
|
|
||||||
logger,
|
|
||||||
"pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
|
|
||||||
"CodeAgent.",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.python_evaluator = evaluate_python_code
|
self.python_evaluator = evaluate_python_code
|
||||||
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
||||||
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
|
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
|
||||||
|
@ -669,29 +612,30 @@ class CodeAgent(Agent):
|
||||||
}
|
}
|
||||||
|
|
||||||
self.prompt = [prompt_message, task_message]
|
self.prompt = [prompt_message, task_message]
|
||||||
self.logger.info("====Executing with this prompt====")
|
|
||||||
self.logger.info(self.prompt)
|
if self.verbose:
|
||||||
|
console.rule("Executing with this prompt")
|
||||||
|
console.print(self.prompt)
|
||||||
|
|
||||||
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
|
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
|
||||||
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"], **additional_args)
|
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"], **additional_args)
|
||||||
|
|
||||||
if return_generated_code:
|
|
||||||
return llm_output
|
|
||||||
|
|
||||||
# Parse
|
# Parse
|
||||||
try:
|
try:
|
||||||
rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.debug(
|
if self.verbose:
|
||||||
f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
|
console.print(
|
||||||
)
|
f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
|
||||||
|
)
|
||||||
rationale, code_action = "", llm_output
|
rationale, code_action = "", llm_output
|
||||||
|
|
||||||
try:
|
try:
|
||||||
code_action = self.parse_code_blob(code_action)
|
code_action = self.parse_code_blob(code_action)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error in code parsing: {e}. Be sure to provide correct code"
|
error_msg = f"Error in code parsing: {e}. Be sure to provide correct code"
|
||||||
self.logger.error(error_msg, exc_info=1)
|
console.print(f"[bold red]{error_msg}[/bold red]")
|
||||||
return error_msg
|
return error_msg
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
|
@ -705,11 +649,12 @@ class CodeAgent(Agent):
|
||||||
state=self.state,
|
state=self.state,
|
||||||
authorized_imports=self.authorized_imports,
|
authorized_imports=self.authorized_imports,
|
||||||
)
|
)
|
||||||
self.logger.info(self.state["print_outputs"])
|
if self.verbose:
|
||||||
|
console.print(self.state["print_outputs"])
|
||||||
return output
|
return output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error in execution: {e}. Be sure to provide correct code."
|
error_msg = f"Error in execution: {e}. Be sure to provide correct code."
|
||||||
self.logger.error(error_msg, exc_info=1)
|
console.print(f"[bold red]{error_msg}[/bold red]")
|
||||||
return error_msg
|
return error_msg
|
||||||
|
|
||||||
|
|
||||||
|
@ -759,7 +704,7 @@ class ReactAgent(Agent):
|
||||||
self.prompt = [
|
self.prompt = [
|
||||||
{
|
{
|
||||||
"role": MessageRole.SYSTEM,
|
"role": MessageRole.SYSTEM,
|
||||||
"content": "An agent tried to answer an user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
|
"content": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
self.prompt += self.write_inner_memory_from_logs()[1:]
|
self.prompt += self.write_inner_memory_from_logs()[1:]
|
||||||
|
@ -772,7 +717,9 @@ class ReactAgent(Agent):
|
||||||
try:
|
try:
|
||||||
return self.llm_engine(self.prompt)
|
return self.llm_engine(self.prompt)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error in generating final llm output: {e}."
|
error_msg = f"Error in generating final LLM output: {e}."
|
||||||
|
console.print(f"[bold red]{error_msg}[/bold red]")
|
||||||
|
return error_msg
|
||||||
|
|
||||||
def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
|
def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -815,7 +762,6 @@ class ReactAgent(Agent):
|
||||||
if "final_answer" in step_log_entry:
|
if "final_answer" in step_log_entry:
|
||||||
final_answer = step_log_entry["final_answer"]
|
final_answer = step_log_entry["final_answer"]
|
||||||
except AgentError as e:
|
except AgentError as e:
|
||||||
self.logger.error(e, exc_info=1)
|
|
||||||
step_log_entry["error"] = e
|
step_log_entry["error"] = e
|
||||||
finally:
|
finally:
|
||||||
step_end_time = time.time()
|
step_end_time = time.time()
|
||||||
|
@ -831,7 +777,7 @@ class ReactAgent(Agent):
|
||||||
error_message = "Reached max iterations."
|
error_message = "Reached max iterations."
|
||||||
final_step_log = {"error": AgentMaxIterationsError(error_message)}
|
final_step_log = {"error": AgentMaxIterationsError(error_message)}
|
||||||
self.logs.append(final_step_log)
|
self.logs.append(final_step_log)
|
||||||
self.logger.error(error_message, exc_info=1)
|
console.print(f"[bold red]{error_message}")
|
||||||
final_answer = self.provide_final_answer(task)
|
final_answer = self.provide_final_answer(task)
|
||||||
final_step_log["final_answer"] = final_answer
|
final_step_log["final_answer"] = final_answer
|
||||||
final_step_log["step_duration"] = 0
|
final_step_log["step_duration"] = 0
|
||||||
|
@ -857,7 +803,6 @@ class ReactAgent(Agent):
|
||||||
if "final_answer" in step_log_entry:
|
if "final_answer" in step_log_entry:
|
||||||
final_answer = step_log_entry["final_answer"]
|
final_answer = step_log_entry["final_answer"]
|
||||||
except AgentError as e:
|
except AgentError as e:
|
||||||
self.logger.error(e, exc_info=1)
|
|
||||||
step_log_entry["error"] = e
|
step_log_entry["error"] = e
|
||||||
finally:
|
finally:
|
||||||
step_end_time = time.time()
|
step_end_time = time.time()
|
||||||
|
@ -872,7 +817,7 @@ class ReactAgent(Agent):
|
||||||
error_message = "Reached max iterations."
|
error_message = "Reached max iterations."
|
||||||
final_step_log = {"error": AgentMaxIterationsError(error_message)}
|
final_step_log = {"error": AgentMaxIterationsError(error_message)}
|
||||||
self.logs.append(final_step_log)
|
self.logs.append(final_step_log)
|
||||||
self.logger.error(error_message, exc_info=1)
|
console.print(f"[bold red]{error_message}")
|
||||||
final_answer = self.provide_final_answer(task)
|
final_answer = self.provide_final_answer(task)
|
||||||
final_step_log["final_answer"] = final_answer
|
final_step_log["final_answer"] = final_answer
|
||||||
final_step_log["step_duration"] = 0
|
final_step_log["step_duration"] = 0
|
||||||
|
@ -931,8 +876,8 @@ Now begin!""",
|
||||||
{answer_facts}
|
{answer_facts}
|
||||||
```""".strip()
|
```""".strip()
|
||||||
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
|
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
|
||||||
self.logger.log(36, "===== Initial plan =====")
|
console.rule("[orange]Initial plan")
|
||||||
self.logger.log(35, final_plan_redaction)
|
console.print(final_plan_redaction)
|
||||||
else: # update plan
|
else: # update plan
|
||||||
agent_memory = self.write_inner_memory_from_logs(
|
agent_memory = self.write_inner_memory_from_logs(
|
||||||
summary_mode=False
|
summary_mode=False
|
||||||
|
@ -977,8 +922,8 @@ Now begin!""",
|
||||||
{facts_update}
|
{facts_update}
|
||||||
```"""
|
```"""
|
||||||
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
|
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
|
||||||
self.logger.log(36, "===== Updated plan =====")
|
console.rule("Updated plan")
|
||||||
self.logger.log(35, final_plan_redaction)
|
console.print(final_plan_redaction)
|
||||||
|
|
||||||
|
|
||||||
class ReactJsonAgent(ReactAgent):
|
class ReactJsonAgent(ReactAgent):
|
||||||
|
@ -1022,13 +967,14 @@ class ReactJsonAgent(ReactAgent):
|
||||||
agent_memory = self.write_inner_memory_from_logs()
|
agent_memory = self.write_inner_memory_from_logs()
|
||||||
|
|
||||||
self.prompt = agent_memory
|
self.prompt = agent_memory
|
||||||
self.logger.debug("===== New step =====")
|
console.rule("New step")
|
||||||
|
|
||||||
# Add new step in logs
|
# Add new step in logs
|
||||||
log_entry["agent_memory"] = agent_memory.copy()
|
log_entry["agent_memory"] = agent_memory.copy()
|
||||||
|
|
||||||
self.logger.info("===== Calling LLM with this last message: =====")
|
if self.verbose:
|
||||||
self.logger.info(self.prompt[-1])
|
console.rule("Calling LLM with this last message:")
|
||||||
|
console.print(self.prompt[-1])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
|
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
|
||||||
|
@ -1037,12 +983,12 @@ class ReactJsonAgent(ReactAgent):
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
||||||
self.logger.debug("===== Output message of the LLM: =====")
|
console.rule("===== Output message of the LLM: =====")
|
||||||
self.logger.debug(llm_output)
|
console.print(llm_output)
|
||||||
log_entry["llm_output"] = llm_output
|
log_entry["llm_output"] = llm_output
|
||||||
|
|
||||||
# Parse
|
# Parse
|
||||||
self.logger.debug("===== Extracting action =====")
|
console.rule("===== Extracting action =====")
|
||||||
rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")
|
rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -1054,9 +1000,9 @@ class ReactJsonAgent(ReactAgent):
|
||||||
log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
|
log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
self.logger.warning("=== Agent thoughts:")
|
console.print("=== Agent thoughts:")
|
||||||
self.logger.log(31, rationale)
|
console.print(rationale)
|
||||||
self.logger.warning(f">>> Calling tool: '{tool_name}' with arguments: {arguments}")
|
console.print(f">>> Calling tool: '{tool_name}' with arguments: {arguments}")
|
||||||
if tool_name == "final_answer":
|
if tool_name == "final_answer":
|
||||||
if isinstance(arguments, dict):
|
if isinstance(arguments, dict):
|
||||||
if "answer" in arguments:
|
if "answer" in arguments:
|
||||||
|
@ -1087,7 +1033,6 @@ class ReactJsonAgent(ReactAgent):
|
||||||
updated_information = f"Stored '{observation_name}' in memory."
|
updated_information = f"Stored '{observation_name}' in memory."
|
||||||
else:
|
else:
|
||||||
updated_information = str(observation).strip()
|
updated_information = str(observation).strip()
|
||||||
self.logger.info(updated_information)
|
|
||||||
log_entry["observation"] = updated_information
|
log_entry["observation"] = updated_information
|
||||||
return log_entry
|
return log_entry
|
||||||
|
|
||||||
|
@ -1126,13 +1071,6 @@ class ReactCodeAgent(ReactAgent):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not is_pygments_available():
|
|
||||||
transformers_logging.warning_once(
|
|
||||||
logger,
|
|
||||||
"pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
|
|
||||||
"ReactCodeAgent.",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.python_evaluator = evaluate_python_code
|
self.python_evaluator = evaluate_python_code
|
||||||
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
||||||
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
|
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
|
||||||
|
@ -1147,13 +1085,14 @@ class ReactCodeAgent(ReactAgent):
|
||||||
agent_memory = self.write_inner_memory_from_logs()
|
agent_memory = self.write_inner_memory_from_logs()
|
||||||
|
|
||||||
self.prompt = agent_memory.copy()
|
self.prompt = agent_memory.copy()
|
||||||
self.logger.debug("===== New step =====")
|
console.rule("New step")
|
||||||
|
|
||||||
# Add new step in logs
|
# Add new step in logs
|
||||||
log_entry["agent_memory"] = agent_memory.copy()
|
log_entry["agent_memory"] = agent_memory.copy()
|
||||||
|
|
||||||
self.logger.info("===== Calling LLM with these last messages: =====")
|
if self.verbose:
|
||||||
self.logger.info(self.prompt[-2:])
|
console.print("===== Calling LLM with these last messages: =====")
|
||||||
|
console.print(self.prompt[-2:])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
|
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
|
||||||
|
@ -1163,16 +1102,16 @@ class ReactCodeAgent(ReactAgent):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
||||||
|
|
||||||
self.logger.debug("=== Output message of the LLM:")
|
if self.verbose:
|
||||||
self.logger.debug(llm_output)
|
console.rule("Output message of the LLM:")
|
||||||
|
console.print(llm_output)
|
||||||
log_entry["llm_output"] = llm_output
|
log_entry["llm_output"] = llm_output
|
||||||
|
|
||||||
# Parse
|
# Parse
|
||||||
self.logger.debug("=== Extracting action ===")
|
|
||||||
try:
|
try:
|
||||||
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.debug(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
|
console.print(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
|
||||||
rationale, raw_code_action = llm_output, llm_output
|
rationale, raw_code_action = llm_output, llm_output
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -1200,12 +1139,12 @@ class ReactCodeAgent(ReactAgent):
|
||||||
state=self.state,
|
state=self.state,
|
||||||
authorized_imports=self.authorized_imports,
|
authorized_imports=self.authorized_imports,
|
||||||
)
|
)
|
||||||
self.logger.warning("Print outputs:")
|
console.print("Print outputs:")
|
||||||
self.logger.log(32, self.state["print_outputs"])
|
console.print(self.state["print_outputs"])
|
||||||
observation = "Print outputs:\n" + self.state["print_outputs"]
|
observation = "Print outputs:\n" + self.state["print_outputs"]
|
||||||
if result is not None:
|
if result is not None:
|
||||||
self.logger.warning("Last output from code snippet:")
|
console.print("Last output from code snippet:")
|
||||||
self.logger.log(32, str(result))
|
console.print(str(result))
|
||||||
observation += "Last output from code snippet:\n" + str(result)[:100000]
|
observation += "Last output from code snippet:\n" + str(result)[:100000]
|
||||||
log_entry["observation"] = observation
|
log_entry["observation"] = observation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -1215,8 +1154,8 @@ class ReactCodeAgent(ReactAgent):
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg)
|
||||||
for line in code_action.split("\n"):
|
for line in code_action.split("\n"):
|
||||||
if line[: len("final_answer")] == "final_answer":
|
if line[: len("final_answer")] == "final_answer":
|
||||||
self.logger.log(33, "Final answer:")
|
console.print("Final answer:")
|
||||||
self.logger.log(32, result)
|
console.print(f"[bold]{result}")
|
||||||
log_entry["final_answer"] = result
|
log_entry["final_answer"] = result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,7 @@ from typing import Dict
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download, list_spaces
|
from huggingface_hub import hf_hub_download, list_spaces
|
||||||
|
|
||||||
from ..utils import is_offline_mode
|
from transformers.utils import is_offline_mode
|
||||||
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
|
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
|
||||||
from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
|
from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
|
||||||
|
|
||||||
|
@ -128,7 +128,7 @@ def get_remote_tools(logger, organization="huggingface-tools"):
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|
||||||
def setup_default_tools(logger):
|
def setup_default_tools():
|
||||||
default_tools = {}
|
default_tools = {}
|
||||||
main_module = importlib.import_module("transformers")
|
main_module = importlib.import_module("transformers")
|
||||||
tools_module = main_module.agents
|
tools_module = main_module.agents
|
||||||
|
|
|
@ -19,9 +19,8 @@ import re
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..models.auto import AutoProcessor
|
from transformers import AutoProcessor, VisionEncoderDecoderModel
|
||||||
from ..models.vision_encoder_decoder import VisionEncoderDecoderModel
|
from transformers.utils import is_vision_available
|
||||||
from ..utils import is_vision_available
|
|
||||||
from .tools import PipelineTool
|
from .tools import PipelineTool
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,8 +18,8 @@
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor
|
from transformers import AutoModelForVisualQuestionAnswering, AutoProcessor
|
||||||
from ..utils import requires_backends
|
from transformers.utils import requires_backends
|
||||||
from .tools import PipelineTool
|
from .tools import PipelineTool
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,12 +20,11 @@ from typing import Dict, List, Optional
|
||||||
|
|
||||||
from huggingface_hub import InferenceClient
|
from huggingface_hub import InferenceClient
|
||||||
|
|
||||||
from .. import AutoTokenizer
|
from transformers import AutoTokenizer, Pipeline
|
||||||
from ..pipelines.base import Pipeline
|
import logging
|
||||||
from ..utils import logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MessageRole(str, Enum):
|
class MessageRole(str, Enum):
|
||||||
|
|
|
@ -14,11 +14,8 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from ..utils import logging
|
|
||||||
from .agent_types import AgentAudio, AgentImage, AgentText
|
from .agent_types import AgentAudio, AgentImage, AgentText
|
||||||
|
from .utils import console
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def pull_message(step_log: dict, test_mode: bool = True):
|
def pull_message(step_log: dict, test_mode: bool = True):
|
||||||
|
@ -107,11 +104,12 @@ class Monitor:
|
||||||
def update_metrics(self, step_log):
|
def update_metrics(self, step_log):
|
||||||
step_duration = step_log["step_duration"]
|
step_duration = step_log["step_duration"]
|
||||||
self.step_durations.append(step_duration)
|
self.step_durations.append(step_duration)
|
||||||
logger.info(f"Step {len(self.step_durations)}:")
|
console.print(f"Step {len(self.step_durations)}:")
|
||||||
logger.info(f"- Time taken: {step_duration:.2f} seconds (valid only if step succeeded)")
|
console.print(f"- Time taken: {step_duration:.2f} seconds")
|
||||||
|
|
||||||
if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
|
if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
|
||||||
self.total_input_token_count += self.tracked_llm_engine.last_input_token_count
|
self.total_input_token_count += self.tracked_llm_engine.last_input_token_count
|
||||||
self.total_output_token_count += self.tracked_llm_engine.last_output_token_count
|
self.total_output_token_count += self.tracked_llm_engine.last_output_token_count
|
||||||
logger.info(f"- Input tokens: {self.total_input_token_count}")
|
console.print(f"- Input tokens: {self.total_input_token_count}")
|
||||||
logger.info(f"- Output tokens: {self.total_output_token_count}")
|
console.print(f"- Output tokens: {self.total_output_token_count}")
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from ..utils import cached_file
|
from transformers.utils import cached_file
|
||||||
|
|
||||||
|
|
||||||
# docstyle-ignore
|
# docstyle-ignore
|
||||||
|
|
|
@ -22,12 +22,7 @@ from importlib import import_module
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
from ..utils import is_pandas_available
|
|
||||||
|
|
||||||
|
|
||||||
if is_pandas_available():
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
|
|
||||||
class InterpreterError(ValueError):
|
class InterpreterError(ValueError):
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor
|
from transformers import WhisperForConditionalGeneration, WhisperProcessor
|
||||||
from .tools import PipelineTool
|
from .tools import PipelineTool
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,8 +17,8 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
|
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
|
||||||
from ..utils import is_datasets_available
|
from transformers.utils import is_datasets_available
|
||||||
from .tools import PipelineTool
|
from .tools import PipelineTool
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -30,13 +30,13 @@ from huggingface_hub import create_repo, get_collection, hf_hub_download, metada
|
||||||
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
|
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from ..dynamic_module_utils import (
|
from transformers.dynamic_module_utils import (
|
||||||
custom_object_save,
|
custom_object_save,
|
||||||
get_class_from_dynamic_module,
|
get_class_from_dynamic_module,
|
||||||
get_imports,
|
get_imports,
|
||||||
)
|
)
|
||||||
from ..models.auto import AutoProcessor
|
from transformers import AutoProcessor
|
||||||
from ..utils import (
|
from transformers.utils import (
|
||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
TypeHintParsingException,
|
TypeHintParsingException,
|
||||||
cached_file,
|
cached_file,
|
||||||
|
@ -44,12 +44,11 @@ from ..utils import (
|
||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
logging,
|
|
||||||
)
|
)
|
||||||
from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs
|
from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||||
from .tools import PipelineTool
|
from .tools import PipelineTool
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
|
||||||
|
from transformers.utils.import_utils import _is_package_available
|
||||||
|
|
||||||
|
_pygments_available = _is_package_available("pygments")
|
||||||
|
|
||||||
|
def is_pygments_available():
|
||||||
|
return _pygments_available
|
||||||
|
|
||||||
|
from rich.console import Console
|
||||||
|
console = Console()
|
|
@ -47,10 +47,10 @@ TRANSFORMERS_NO_ADVISORY_WARNINGS=1 ./myprogram.py
|
||||||
Here is an example of how to use the same logger as the library in your own module or script:
|
Here is an example of how to use the same logger as the library in your own module or script:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from transformers.utils import logging
|
import logging
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
logger = logging.get_logger("transformers")
|
logger = logging.getLogger(__name__)("transformers")
|
||||||
logger.info("INFO")
|
logger.info("INFO")
|
||||||
logger.warning("WARN")
|
logger.warning("WARN")
|
||||||
```
|
```
|
||||||
|
@ -104,7 +104,7 @@ See reference of the `captureWarnings` method below.
|
||||||
|
|
||||||
[[autodoc]] logging.set_verbosity
|
[[autodoc]] logging.set_verbosity
|
||||||
|
|
||||||
[[autodoc]] logging.get_logger
|
[[autodoc]] logging.getLogger(__name__)
|
||||||
|
|
||||||
[[autodoc]] logging.enable_default_handler
|
[[autodoc]] logging.enable_default_handler
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
from agents import load_tool, ReactCodeAgent, HfApiEngine
|
||||||
|
|
||||||
|
# Import tool from Hub
|
||||||
|
image_generation_tool = load_tool("m-ric/text-to-image", cache=False)
|
||||||
|
|
||||||
|
# Import tool from LangChain
|
||||||
|
from agents.search import DuckDuckGoSearchTool
|
||||||
|
|
||||||
|
search_tool = DuckDuckGoSearchTool()
|
||||||
|
|
||||||
|
llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct")
|
||||||
|
# Initialize the agent with both tools
|
||||||
|
agent = ReactCodeAgent(tools=[image_generation_tool, search_tool], llm_engine=llm_engine)
|
||||||
|
|
||||||
|
# Run it!
|
||||||
|
result = agent.run(
|
||||||
|
"Return me a photo of the car that James bond drove in the latest movie.",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(result)
|
|
@ -3,3 +3,4 @@ evaluate
|
||||||
datasets==2.3.2
|
datasets==2.3.2
|
||||||
schedulefree
|
schedulefree
|
||||||
huggingface_hub>=0.20.0
|
huggingface_hub>=0.20.0
|
||||||
|
duckduckgo-search
|
|
@ -264,6 +264,41 @@ files = [
|
||||||
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
|
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "markdown-it-py"
|
||||||
|
version = "3.0.0"
|
||||||
|
description = "Python port of markdown-it. Markdown parsing, done right!"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
|
||||||
|
{file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
mdurl = ">=0.1,<1.0"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
benchmarking = ["psutil", "pytest", "pytest-benchmark"]
|
||||||
|
code-style = ["pre-commit (>=3.0,<4.0)"]
|
||||||
|
compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
|
||||||
|
linkify = ["linkify-it-py (>=1,<3)"]
|
||||||
|
plugins = ["mdit-py-plugins"]
|
||||||
|
profiling = ["gprof2dot"]
|
||||||
|
rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
|
||||||
|
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mdurl"
|
||||||
|
version = "0.1.2"
|
||||||
|
description = "Markdown URL utilities"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
|
||||||
|
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "numpy"
|
name = "numpy"
|
||||||
version = "2.1.3"
|
version = "2.1.3"
|
||||||
|
@ -354,6 +389,20 @@ files = [
|
||||||
dev = ["pre-commit", "tox"]
|
dev = ["pre-commit", "tox"]
|
||||||
testing = ["pytest", "pytest-benchmark"]
|
testing = ["pytest", "pytest-benchmark"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pygments"
|
||||||
|
version = "2.18.0"
|
||||||
|
description = "Pygments is a syntax highlighting package written in Python."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"},
|
||||||
|
{file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
windows-terminal = ["colorama (>=0.4.6)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pytest"
|
name = "pytest"
|
||||||
version = "8.3.4"
|
version = "8.3.4"
|
||||||
|
@ -562,6 +611,25 @@ urllib3 = ">=1.21.1,<3"
|
||||||
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
||||||
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rich"
|
||||||
|
version = "13.9.4"
|
||||||
|
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8.0"
|
||||||
|
files = [
|
||||||
|
{file = "rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90"},
|
||||||
|
{file = "rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
markdown-it-py = ">=2.2.0"
|
||||||
|
pygments = ">=2.13.0,<3.0.0"
|
||||||
|
typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""}
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
jupyter = ["ipywidgets (>=7.5.1,<9)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "safetensors"
|
name = "safetensors"
|
||||||
version = "0.4.5"
|
version = "0.4.5"
|
||||||
|
@ -888,4 +956,4 @@ zstd = ["zstandard (>=0.18.0)"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.10,<3.13"
|
python-versions = ">=3.10,<3.13"
|
||||||
content-hash = "de2c60fd8f7c54521b2204b3af144b974059361d7892c70eeb325a7fe52bd489"
|
content-hash = "985797a96ac1b58e4892479d96bbd1c696b452d760fa421a28b91ea8b3dbf977"
|
||||||
|
|
|
@ -62,6 +62,7 @@ python = ">=3.10,<3.13"
|
||||||
transformers = ">=4.0.0"
|
transformers = ">=4.0.0"
|
||||||
pytest = {version = ">=8.1.0", optional = true}
|
pytest = {version = ">=8.1.0", optional = true}
|
||||||
requests = "^2.32.3"
|
requests = "^2.32.3"
|
||||||
|
rich = "^13.9.4"
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|
Loading…
Reference in New Issue