Nicer console outputs with rich

This commit is contained in:
Aymeric 2024-12-09 15:42:17 +01:00
parent f3dcf1f013
commit 146ee3dd32
20 changed files with 224 additions and 194 deletions

View File

@ -16,7 +16,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ..utils import (
from transformers.utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,

View File

@ -19,10 +19,11 @@ import uuid
import numpy as np
from ..utils import is_soundfile_availble, is_torch_available, is_vision_available, logging
from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
import logging
logger = logging.get_logger(__name__)
logger = logging.getLogger(__name__)
if is_vision_available():
from PIL import Image
@ -159,7 +160,7 @@ class AgentImage(AgentType, ImageType):
return self._path
def save(self, output_bytes, format, **params):
def save(self, output_bytes, format = None, **params):
"""
Saves the image to a file.
Args:
@ -168,7 +169,7 @@ class AgentImage(AgentType, ImageType):
**params: Additional parameters to pass to PIL.Image.save.
"""
img = self.to_raw()
img.save(output_bytes, format, **params)
img.save(output_bytes, format=format, **params)
class AgentAudio(AgentType, str):

View File

@ -19,10 +19,12 @@ import logging
import re
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import rich
from rich import markdown as rich_markdown
from .. import is_torch_available
from ..utils import logging as transformers_logging
from ..utils.import_utils import is_pygments_available
from transformers.utils import is_torch_available
import logging
from .utils import console
from .agent_types import AgentAudio, AgentImage
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
from .llm_engine import HfApiEngine, MessageRole
@ -48,52 +50,6 @@ from .tools import (
)
if is_pygments_available():
from pygments import highlight
from pygments.formatters import Terminal256Formatter
from pygments.lexers import PythonLexer
class CustomFormatter(logging.Formatter):
grey = "\x1b[38;20m"
bold_yellow = "\x1b[33;1m"
red = "\x1b[31;20m"
green = "\x1b[32;20m"
bold_green = "\x1b[32;20;1m"
bold_red = "\x1b[31;1m"
bold_white = "\x1b[37;1m"
orange = "\x1b[38;5;214m"
bold_orange = "\x1b[38;5;214;1m"
reset = "\x1b[0m"
format = "%(message)s"
FORMATS = {
logging.DEBUG: grey + format + reset,
logging.INFO: format,
logging.WARNING: bold_yellow + format + reset,
logging.ERROR: red + format + reset,
logging.CRITICAL: bold_red + format + reset,
31: reset + format + reset,
32: green + format + reset,
33: bold_green + format + reset,
34: bold_white + format + reset,
35: orange + format + reset,
36: bold_orange + format + reset,
}
def format(self, record):
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt)
return formatter.format(record)
logger = transformers_logging.get_logger(__name__)
logger.propagate = False
ch = logging.StreamHandler()
ch.setFormatter(CustomFormatter())
logger.addHandler(ch)
def parse_json_blob(json_blob: str) -> Dict[str, str]:
try:
first_accolade_index = json_blob.find("{")
@ -142,9 +98,10 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
elif "action" in tool_call:
return tool_call["action"], None
else:
raise ValueError(
f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}"
)
missing_keys = [key for key in ['action', 'action_input'] if key not in tool_call]
error_msg = f"Missing keys: {missing_keys} in blob {tool_call}"
console.print(f"[bold red]{error_msg}[/bold red]")
raise ValueError(error_msg)
def parse_text_tool_call(text: str) -> Tuple[str, Union[str, Dict[str, str]]]:
@ -197,18 +154,18 @@ class Toolbox:
self._tools = {tool.name: tool for tool in tools}
if add_base_tools:
self.add_base_tools()
self._load_tools_if_needed()
# self._load_tools_if_needed()
def add_base_tools(self, add_python_interpreter: bool = False):
global _tools_are_initialized
global HUGGINGFACE_DEFAULT_TOOLS
if not _tools_are_initialized:
HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools(logger)
HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools()
_tools_are_initialized = True
for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
if tool.name != "python_interpreter" or add_python_interpreter:
self.add_tool(tool)
self._load_tools_if_needed()
# self._load_tools_if_needed()
@property
def tools(self) -> Dict[str, Tool]:
@ -271,11 +228,11 @@ class Toolbox:
"""Clears the toolbox"""
self._tools = {}
def _load_tools_if_needed(self):
for name, tool in self._tools.items():
if not isinstance(tool, Tool):
task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
self._tools[name] = load_tool(task_or_repo_id)
# def _load_tools_if_needed(self):
# for name, tool in self._tools.items():
# if not isinstance(tool, Tool):
# task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
# self._tools[name] = load_tool(task_or_repo_id)
def __repr__(self):
toolbox_description = "Toolbox contents:\n"
@ -290,6 +247,8 @@ class AgentError(Exception):
def __init__(self, message):
super().__init__(message)
self.message = message
console.print(f"[bold red]{message}[/bold red]")
class AgentParsingError(AgentError):
@ -362,7 +321,7 @@ class Agent:
max_iterations: int = 6,
tool_parser: Optional[Callable] = None,
add_base_tools: bool = False,
verbose: int = 0,
verbose: bool = False,
grammar: Optional[Dict[str, str]] = None,
managed_agents: Optional[List] = None,
step_callbacks: Optional[List[Callable]] = None,
@ -380,7 +339,6 @@ class Agent:
)
self.additional_args = additional_args
self.max_iterations = max_iterations
self.logger = logger
self.tool_parser = tool_parser
self.grammar = grammar
@ -406,13 +364,7 @@ class Agent:
self.prompt = None
self.logs = []
self.task = None
if verbose == 0:
logger.setLevel(logging.WARNING)
elif verbose == 1:
logger.setLevel(logging.INFO)
elif verbose == 2:
logger.setLevel(logging.DEBUG)
self.verbose = verbose
# Initialize step callbacks
self.step_callbacks = step_callbacks if step_callbacks is not None else []
@ -441,10 +393,8 @@ class Agent:
self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
)
self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
self.logger.log(33, "======== New task ========")
self.logger.log(34, self.task)
self.logger.debug("System prompt is as follows:")
self.logger.debug(self.system_prompt)
console.rule("New task", characters='=')
console.print(self.task)
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
"""
@ -521,7 +471,6 @@ class Agent:
split[-1],
) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
except Exception as e:
self.logger.error(e, exc_info=1)
raise AgentParsingError(
f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
)
@ -541,7 +490,7 @@ class Agent:
available_tools = {**available_tools, **self.managed_agents}
if tool_name not in available_tools:
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
self.logger.error(error_msg, exc_info=1)
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
try:
@ -549,38 +498,39 @@ class Agent:
observation = available_tools[tool_name](arguments)
elif isinstance(arguments, dict):
for key, value in arguments.items():
# if the value is the name of a state variable like "image.png", replace it with the actual value
if isinstance(value, str) and value in self.state:
arguments[key] = self.state[value]
observation = available_tools[tool_name](**arguments)
else:
raise AgentExecutionError(
f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
)
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
return observation
except Exception as e:
if tool_name in self.toolbox.tools:
raise AgentExecutionError(
tool_description = get_tool_description_with_args(available_tools[tool_name])
error_msg = (
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(available_tools[tool_name])}"
f"As a reminder, this tool's description is the following:\n{tool_description}"
)
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
elif tool_name in self.managed_agents:
raise AgentExecutionError(
error_msg = (
f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n"
f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}"
)
console.print(f"[bold red]{error_msg}")
raise AgentExecutionError(error_msg)
def log_rationale_code_action(self, rationale: str, code_action: str) -> None:
self.logger.warning("=== Agent thoughts:")
self.logger.log(31, rationale)
self.logger.warning(">>> Agent is executing the code below:")
if is_pygments_available():
self.logger.log(
31, highlight(code_action, PythonLexer(ensurenl=False), Terminal256Formatter(style="nord"))
)
else:
self.logger.log(31, code_action)
self.logger.warning("====")
if self.verbose:
console.rule("Agent thoughts")
console.print(rationale)
console.rule("Agent is executing the code below:", align="left")
console.print(code_action)
console.rule("", align="left")
def run(self, **kwargs):
"""To be implemented in the child class"""
@ -617,13 +567,6 @@ class CodeAgent(Agent):
**kwargs,
)
if not is_pygments_available():
transformers_logging.warning_once(
logger,
"pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
"CodeAgent.",
)
self.python_evaluator = evaluate_python_code
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
@ -669,20 +612,21 @@ class CodeAgent(Agent):
}
self.prompt = [prompt_message, task_message]
self.logger.info("====Executing with this prompt====")
self.logger.info(self.prompt)
if self.verbose:
console.rule("Executing with this prompt")
console.print(self.prompt)
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"], **additional_args)
if return_generated_code:
return llm_output
# Parse
try:
rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
except Exception as e:
self.logger.debug(
if self.verbose:
console.print(
f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
)
rationale, code_action = "", llm_output
@ -691,7 +635,7 @@ class CodeAgent(Agent):
code_action = self.parse_code_blob(code_action)
except Exception as e:
error_msg = f"Error in code parsing: {e}. Be sure to provide correct code"
self.logger.error(error_msg, exc_info=1)
console.print(f"[bold red]{error_msg}[/bold red]")
return error_msg
# Execute
@ -705,11 +649,12 @@ class CodeAgent(Agent):
state=self.state,
authorized_imports=self.authorized_imports,
)
self.logger.info(self.state["print_outputs"])
if self.verbose:
console.print(self.state["print_outputs"])
return output
except Exception as e:
error_msg = f"Error in execution: {e}. Be sure to provide correct code."
self.logger.error(error_msg, exc_info=1)
console.print(f"[bold red]{error_msg}[/bold red]")
return error_msg
@ -759,7 +704,7 @@ class ReactAgent(Agent):
self.prompt = [
{
"role": MessageRole.SYSTEM,
"content": "An agent tried to answer an user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
"content": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
}
]
self.prompt += self.write_inner_memory_from_logs()[1:]
@ -772,7 +717,9 @@ class ReactAgent(Agent):
try:
return self.llm_engine(self.prompt)
except Exception as e:
return f"Error in generating final llm output: {e}."
error_msg = f"Error in generating final LLM output: {e}."
console.print(f"[bold red]{error_msg}[/bold red]")
return error_msg
def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
"""
@ -815,7 +762,6 @@ class ReactAgent(Agent):
if "final_answer" in step_log_entry:
final_answer = step_log_entry["final_answer"]
except AgentError as e:
self.logger.error(e, exc_info=1)
step_log_entry["error"] = e
finally:
step_end_time = time.time()
@ -831,7 +777,7 @@ class ReactAgent(Agent):
error_message = "Reached max iterations."
final_step_log = {"error": AgentMaxIterationsError(error_message)}
self.logs.append(final_step_log)
self.logger.error(error_message, exc_info=1)
console.print(f"[bold red]{error_message}")
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer
final_step_log["step_duration"] = 0
@ -857,7 +803,6 @@ class ReactAgent(Agent):
if "final_answer" in step_log_entry:
final_answer = step_log_entry["final_answer"]
except AgentError as e:
self.logger.error(e, exc_info=1)
step_log_entry["error"] = e
finally:
step_end_time = time.time()
@ -872,7 +817,7 @@ class ReactAgent(Agent):
error_message = "Reached max iterations."
final_step_log = {"error": AgentMaxIterationsError(error_message)}
self.logs.append(final_step_log)
self.logger.error(error_message, exc_info=1)
console.print(f"[bold red]{error_message}")
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer
final_step_log["step_duration"] = 0
@ -931,8 +876,8 @@ Now begin!""",
{answer_facts}
```""".strip()
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
self.logger.log(36, "===== Initial plan =====")
self.logger.log(35, final_plan_redaction)
console.rule("[orange]Initial plan")
console.print(final_plan_redaction)
else: # update plan
agent_memory = self.write_inner_memory_from_logs(
summary_mode=False
@ -977,8 +922,8 @@ Now begin!""",
{facts_update}
```"""
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
self.logger.log(36, "===== Updated plan =====")
self.logger.log(35, final_plan_redaction)
console.rule("Updated plan")
console.print(final_plan_redaction)
class ReactJsonAgent(ReactAgent):
@ -1022,13 +967,14 @@ class ReactJsonAgent(ReactAgent):
agent_memory = self.write_inner_memory_from_logs()
self.prompt = agent_memory
self.logger.debug("===== New step =====")
console.rule("New step")
# Add new step in logs
log_entry["agent_memory"] = agent_memory.copy()
self.logger.info("===== Calling LLM with this last message: =====")
self.logger.info(self.prompt[-1])
if self.verbose:
console.rule("Calling LLM with this last message:")
console.print(self.prompt[-1])
try:
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
@ -1037,12 +983,12 @@ class ReactJsonAgent(ReactAgent):
)
except Exception as e:
raise AgentGenerationError(f"Error in generating llm output: {e}.")
self.logger.debug("===== Output message of the LLM: =====")
self.logger.debug(llm_output)
console.rule("===== Output message of the LLM: =====")
console.print(llm_output)
log_entry["llm_output"] = llm_output
# Parse
self.logger.debug("===== Extracting action =====")
console.rule("===== Extracting action =====")
rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")
try:
@ -1054,9 +1000,9 @@ class ReactJsonAgent(ReactAgent):
log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
# Execute
self.logger.warning("=== Agent thoughts:")
self.logger.log(31, rationale)
self.logger.warning(f">>> Calling tool: '{tool_name}' with arguments: {arguments}")
console.print("=== Agent thoughts:")
console.print(rationale)
console.print(f">>> Calling tool: '{tool_name}' with arguments: {arguments}")
if tool_name == "final_answer":
if isinstance(arguments, dict):
if "answer" in arguments:
@ -1087,7 +1033,6 @@ class ReactJsonAgent(ReactAgent):
updated_information = f"Stored '{observation_name}' in memory."
else:
updated_information = str(observation).strip()
self.logger.info(updated_information)
log_entry["observation"] = updated_information
return log_entry
@ -1126,13 +1071,6 @@ class ReactCodeAgent(ReactAgent):
**kwargs,
)
if not is_pygments_available():
transformers_logging.warning_once(
logger,
"pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
"ReactCodeAgent.",
)
self.python_evaluator = evaluate_python_code
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
@ -1147,13 +1085,14 @@ class ReactCodeAgent(ReactAgent):
agent_memory = self.write_inner_memory_from_logs()
self.prompt = agent_memory.copy()
self.logger.debug("===== New step =====")
console.rule("New step")
# Add new step in logs
log_entry["agent_memory"] = agent_memory.copy()
self.logger.info("===== Calling LLM with these last messages: =====")
self.logger.info(self.prompt[-2:])
if self.verbose:
console.print("===== Calling LLM with these last messages: =====")
console.print(self.prompt[-2:])
try:
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
@ -1163,16 +1102,16 @@ class ReactCodeAgent(ReactAgent):
except Exception as e:
raise AgentGenerationError(f"Error in generating llm output: {e}.")
self.logger.debug("=== Output message of the LLM:")
self.logger.debug(llm_output)
if self.verbose:
console.rule("Output message of the LLM:")
console.print(llm_output)
log_entry["llm_output"] = llm_output
# Parse
self.logger.debug("=== Extracting action ===")
try:
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
except Exception as e:
self.logger.debug(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
console.print(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
rationale, raw_code_action = llm_output, llm_output
try:
@ -1200,12 +1139,12 @@ class ReactCodeAgent(ReactAgent):
state=self.state,
authorized_imports=self.authorized_imports,
)
self.logger.warning("Print outputs:")
self.logger.log(32, self.state["print_outputs"])
console.print("Print outputs:")
console.print(self.state["print_outputs"])
observation = "Print outputs:\n" + self.state["print_outputs"]
if result is not None:
self.logger.warning("Last output from code snippet:")
self.logger.log(32, str(result))
console.print("Last output from code snippet:")
console.print(str(result))
observation += "Last output from code snippet:\n" + str(result)[:100000]
log_entry["observation"] = observation
except Exception as e:
@ -1215,8 +1154,8 @@ class ReactCodeAgent(ReactAgent):
raise AgentExecutionError(error_msg)
for line in code_action.split("\n"):
if line[: len("final_answer")] == "final_answer":
self.logger.log(33, "Final answer:")
self.logger.log(32, result)
console.print("Final answer:")
console.print(f"[bold]{result}")
log_entry["final_answer"] = result
return result

View File

@ -23,7 +23,7 @@ from typing import Dict
from huggingface_hub import hf_hub_download, list_spaces
from ..utils import is_offline_mode
from transformers.utils import is_offline_mode
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
@ -128,7 +128,7 @@ def get_remote_tools(logger, organization="huggingface-tools"):
return tools
def setup_default_tools(logger):
def setup_default_tools():
default_tools = {}
main_module = importlib.import_module("transformers")
tools_module = main_module.agents

View File

@ -19,9 +19,8 @@ import re
import numpy as np
import torch
from ..models.auto import AutoProcessor
from ..models.vision_encoder_decoder import VisionEncoderDecoderModel
from ..utils import is_vision_available
from transformers import AutoProcessor, VisionEncoderDecoderModel
from transformers.utils import is_vision_available
from .tools import PipelineTool

View File

@ -18,8 +18,8 @@
import torch
from PIL import Image
from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor
from ..utils import requires_backends
from transformers import AutoModelForVisualQuestionAnswering, AutoProcessor
from transformers.utils import requires_backends
from .tools import PipelineTool

View File

@ -20,12 +20,11 @@ from typing import Dict, List, Optional
from huggingface_hub import InferenceClient
from .. import AutoTokenizer
from ..pipelines.base import Pipeline
from ..utils import logging
from transformers import AutoTokenizer, Pipeline
import logging
logger = logging.get_logger(__name__)
logger = logging.getLogger(__name__)
class MessageRole(str, Enum):

View File

@ -14,11 +14,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import logging
from .agent_types import AgentAudio, AgentImage, AgentText
logger = logging.get_logger(__name__)
from .utils import console
def pull_message(step_log: dict, test_mode: bool = True):
@ -107,11 +104,12 @@ class Monitor:
def update_metrics(self, step_log):
step_duration = step_log["step_duration"]
self.step_durations.append(step_duration)
logger.info(f"Step {len(self.step_durations)}:")
logger.info(f"- Time taken: {step_duration:.2f} seconds (valid only if step succeeded)")
console.print(f"Step {len(self.step_durations)}:")
console.print(f"- Time taken: {step_duration:.2f} seconds")
if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None:
self.total_input_token_count += self.tracked_llm_engine.last_input_token_count
self.total_output_token_count += self.tracked_llm_engine.last_output_token_count
logger.info(f"- Input tokens: {self.total_input_token_count}")
logger.info(f"- Output tokens: {self.total_output_token_count}")
console.print(f"- Input tokens: {self.total_input_token_count}")
console.print(f"- Output tokens: {self.total_output_token_count}")

View File

@ -16,7 +16,7 @@
# limitations under the License.
import re
from ..utils import cached_file
from transformers.utils import cached_file
# docstyle-ignore

View File

@ -22,12 +22,7 @@ from importlib import import_module
from typing import Any, Callable, Dict, List, Optional
import numpy as np
from ..utils import is_pandas_available
if is_pandas_available():
import pandas as pd
import pandas as pd
class InterpreterError(ValueError):

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from .tools import PipelineTool

View File

@ -17,8 +17,8 @@
import torch
from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
from ..utils import is_datasets_available
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
from transformers.utils import is_datasets_available
from .tools import PipelineTool

View File

@ -30,13 +30,13 @@ from huggingface_hub import create_repo, get_collection, hf_hub_download, metada
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
from packaging import version
from ..dynamic_module_utils import (
from transformers.dynamic_module_utils import (
custom_object_save,
get_class_from_dynamic_module,
get_imports,
)
from ..models.auto import AutoProcessor
from ..utils import (
from transformers import AutoProcessor
from transformers.utils import (
CONFIG_NAME,
TypeHintParsingException,
cached_file,
@ -44,12 +44,11 @@ from ..utils import (
is_accelerate_available,
is_torch_available,
is_vision_available,
logging,
)
from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs
import logging
logger = logging.get_logger(__name__)
logger = logging.getLogger(__name__)
if is_torch_available():

View File

@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from .tools import PipelineTool

10
agents/utils.py Normal file
View File

@ -0,0 +1,10 @@
from transformers.utils.import_utils import _is_package_available
_pygments_available = _is_package_available("pygments")
def is_pygments_available():
return _pygments_available
from rich.console import Console
console = Console()

View File

@ -47,10 +47,10 @@ TRANSFORMERS_NO_ADVISORY_WARNINGS=1 ./myprogram.py
Here is an example of how to use the same logger as the library in your own module or script:
```python
from transformers.utils import logging
import logging
logging.set_verbosity_info()
logger = logging.get_logger("transformers")
logger = logging.getLogger(__name__)("transformers")
logger.info("INFO")
logger.warning("WARN")
```
@ -104,7 +104,7 @@ See reference of the `captureWarnings` method below.
[[autodoc]] logging.set_verbosity
[[autodoc]] logging.get_logger
[[autodoc]] logging.getLogger(__name__)
[[autodoc]] logging.enable_default_handler

View File

@ -0,0 +1,20 @@
from agents import load_tool, ReactCodeAgent, HfApiEngine
# Import tool from Hub
image_generation_tool = load_tool("m-ric/text-to-image", cache=False)
# Import tool from LangChain
from agents.search import DuckDuckGoSearchTool
search_tool = DuckDuckGoSearchTool()
llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct")
# Initialize the agent with both tools
agent = ReactCodeAgent(tools=[image_generation_tool, search_tool], llm_engine=llm_engine)
# Run it!
result = agent.run(
"Return me a photo of the car that James bond drove in the latest movie.",
)
print(result)

View File

@ -3,3 +3,4 @@ evaluate
datasets==2.3.2
schedulefree
huggingface_hub>=0.20.0
duckduckgo-search

70
poetry.lock generated
View File

@ -264,6 +264,41 @@ files = [
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
]
[[package]]
name = "markdown-it-py"
version = "3.0.0"
description = "Python port of markdown-it. Markdown parsing, done right!"
optional = false
python-versions = ">=3.8"
files = [
{file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
{file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
]
[package.dependencies]
mdurl = ">=0.1,<1.0"
[package.extras]
benchmarking = ["psutil", "pytest", "pytest-benchmark"]
code-style = ["pre-commit (>=3.0,<4.0)"]
compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
linkify = ["linkify-it-py (>=1,<3)"]
plugins = ["mdit-py-plugins"]
profiling = ["gprof2dot"]
rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
[[package]]
name = "mdurl"
version = "0.1.2"
description = "Markdown URL utilities"
optional = false
python-versions = ">=3.7"
files = [
{file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
]
[[package]]
name = "numpy"
version = "2.1.3"
@ -354,6 +389,20 @@ files = [
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "pygments"
version = "2.18.0"
description = "Pygments is a syntax highlighting package written in Python."
optional = false
python-versions = ">=3.8"
files = [
{file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"},
{file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"},
]
[package.extras]
windows-terminal = ["colorama (>=0.4.6)"]
[[package]]
name = "pytest"
version = "8.3.4"
@ -562,6 +611,25 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "rich"
version = "13.9.4"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
optional = false
python-versions = ">=3.8.0"
files = [
{file = "rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90"},
{file = "rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098"},
]
[package.dependencies]
markdown-it-py = ">=2.2.0"
pygments = ">=2.13.0,<3.0.0"
typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""}
[package.extras]
jupyter = ["ipywidgets (>=7.5.1,<9)"]
[[package]]
name = "safetensors"
version = "0.4.5"
@ -888,4 +956,4 @@ zstd = ["zstandard (>=0.18.0)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "de2c60fd8f7c54521b2204b3af144b974059361d7892c70eeb325a7fe52bd489"
content-hash = "985797a96ac1b58e4892479d96bbd1c696b452d760fa421a28b91ea8b3dbf977"

View File

@ -62,6 +62,7 @@ python = ">=3.10,<3.13"
transformers = ">=4.0.0"
pytest = {version = ">=8.1.0", optional = true}
requests = "^2.32.3"
rich = "^13.9.4"
[build-system]