diff --git a/agents/__init__.py b/agents/__init__.py index 70762c2..99db95d 100644 --- a/agents/__init__.py +++ b/agents/__init__.py @@ -16,7 +16,7 @@ # limitations under the License. from typing import TYPE_CHECKING -from ..utils import ( +from transformers.utils import ( OptionalDependencyNotAvailable, _LazyModule, is_torch_available, diff --git a/agents/agent_types.py b/agents/agent_types.py index f5be746..5e3169c 100644 --- a/agents/agent_types.py +++ b/agents/agent_types.py @@ -19,10 +19,11 @@ import uuid import numpy as np -from ..utils import is_soundfile_availble, is_torch_available, is_vision_available, logging +from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available +import logging -logger = logging.get_logger(__name__) +logger = logging.getLogger(__name__) if is_vision_available(): from PIL import Image @@ -159,7 +160,7 @@ class AgentImage(AgentType, ImageType): return self._path - def save(self, output_bytes, format, **params): + def save(self, output_bytes, format=None, **params): """ Saves the image to a file. Args: @@ -168,7 +169,7 @@ class AgentImage(AgentType, ImageType): **params: Additional parameters to pass to PIL.Image.save. """ img = self.to_raw() - img.save(output_bytes, format, **params) + img.save(output_bytes, format=format, **params) class AgentAudio(AgentType, str): diff --git a/agents/agents.py b/agents/agents.py index 08c30d5..68db965 100644 --- a/agents/agents.py +++ b/agents/agents.py @@ -19,10 +19,12 @@ import logging import re import time from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import rich +from rich import markdown as rich_markdown -from .. import is_torch_available -from ..utils import logging as transformers_logging -from ..utils.import_utils import is_pygments_available +from transformers.utils import is_torch_available +import logging +from .utils import console from .agent_types import AgentAudio, AgentImage from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools from .llm_engine import HfApiEngine, MessageRole @@ -48,52 +50,6 @@ from .tools import ( ) -if is_pygments_available(): - from pygments import highlight - from pygments.formatters import Terminal256Formatter - from pygments.lexers import PythonLexer - - -class CustomFormatter(logging.Formatter): grey = "\x1b[38;20m" bold_yellow = "\x1b[33;1m" red = "\x1b[31;20m" green = "\x1b[32;20m" bold_green = "\x1b[32;20;1m" bold_red = "\x1b[31;1m" bold_white = "\x1b[37;1m" orange = "\x1b[38;5;214m" bold_orange = "\x1b[38;5;214;1m" reset = "\x1b[0m" format = "%(message)s" FORMATS = { logging.DEBUG: grey + format + reset, logging.INFO: format, logging.WARNING: bold_yellow + format + reset, logging.ERROR: red + format + reset, logging.CRITICAL: bold_red + format + reset, 31: reset + format + reset, 32: green + format + reset, 33: bold_green + format + reset, 34: bold_white + format + reset, 35: orange + format + reset, 36: bold_orange + format + reset, } def format(self, record): log_fmt = self.FORMATS.get(record.levelno) formatter = logging.Formatter(log_fmt) return formatter.format(record) -logger = transformers_logging.get_logger(__name__) -logger.propagate = False ch = logging.StreamHandler() ch.setFormatter(CustomFormatter()) logger.addHandler(ch) - - def parse_json_blob(json_blob: str) -> Dict[str, str]: try: first_accolade_index = json_blob.find("{") @@ -142,9 +98,10 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str,
str]]: elif "action" in tool_call: return tool_call["action"], None else: - raise ValueError( - f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}" - ) + missing_keys = [key for key in ['action', 'action_input'] if key not in tool_call] + error_msg = f"Missing keys: {missing_keys} in blob {tool_call}" + console.print(f"[bold red]{error_msg}[/bold red]") + raise ValueError(error_msg) def parse_text_tool_call(text: str) -> Tuple[str, Union[str, Dict[str, str]]]: @@ -197,18 +154,18 @@ class Toolbox: self._tools = {tool.name: tool for tool in tools} if add_base_tools: self.add_base_tools() - self._load_tools_if_needed() + # self._load_tools_if_needed() def add_base_tools(self, add_python_interpreter: bool = False): global _tools_are_initialized global HUGGINGFACE_DEFAULT_TOOLS if not _tools_are_initialized: - HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools(logger) + HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools() _tools_are_initialized = True for tool in HUGGINGFACE_DEFAULT_TOOLS.values(): if tool.name != "python_interpreter" or add_python_interpreter: self.add_tool(tool) - self._load_tools_if_needed() + # self._load_tools_if_needed() @property def tools(self) -> Dict[str, Tool]: @@ -271,11 +228,11 @@ class Toolbox: """Clears the toolbox""" self._tools = {} - def _load_tools_if_needed(self): - for name, tool in self._tools.items(): - if not isinstance(tool, Tool): - task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id - self._tools[name] = load_tool(task_or_repo_id) + # def _load_tools_if_needed(self): + # for name, tool in self._tools.items(): + # if not isinstance(tool, Tool): + # task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id + # self._tools[name] = load_tool(task_or_repo_id) def __repr__(self): toolbox_description = "Toolbox contents:\n" @@ -290,6 +247,8 @@ class AgentError(Exception): def __init__(self, message): super().__init__(message) self.message = message + console.print(f"[bold red]{message}[/bold red]") + class AgentParsingError(AgentError): @@ -362,7 +321,7 @@ class Agent: max_iterations: int = 6, tool_parser: Optional[Callable] = None, add_base_tools: bool = False, - verbose: int = 0, + verbose: bool = False, grammar: Optional[Dict[str, str]] = None, managed_agents: Optional[List] = None, step_callbacks: Optional[List[Callable]] = None, @@ -380,7 +339,6 @@ class Agent: ) self.additional_args = additional_args self.max_iterations = max_iterations - self.logger = logger self.tool_parser = tool_parser self.grammar = grammar @@ -406,13 +364,7 @@ class Agent: self.prompt = None self.logs = [] self.task = None - - if verbose == 0: - logger.setLevel(logging.WARNING) - elif verbose == 1: - logger.setLevel(logging.INFO) - elif verbose == 2: - logger.setLevel(logging.DEBUG) + self.verbose = verbose # Initialize step callbacks self.step_callbacks = step_callbacks if step_callbacks is not None else [] @@ -441,10 +393,8 @@ class Agent: self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports)) ) self.logs = [{"system_prompt": self.system_prompt, "task": self.task}] - self.logger.log(33, "======== New task ========") - self.logger.log(34, self.task) - self.logger.debug("System prompt is as follows:") - self.logger.debug(self.system_prompt) + console.rule("New task", characters='=') + console.print(self.task) def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]: """ @@ -521,7 +471,6 @@ class Agent: split[-1], ) # NOTE: using 
indexes starting from the end solves for when you have more than one split_token in the output except Exception as e: - self.logger.error(e, exc_info=1) raise AgentParsingError( f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!" ) @@ -541,7 +490,7 @@ class Agent: available_tools = {**available_tools, **self.managed_agents} if tool_name not in available_tools: error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." - self.logger.error(error_msg, exc_info=1) + console.print(f"[bold red]{error_msg}") raise AgentExecutionError(error_msg) try: @@ -549,38 +498,39 @@ class Agent: observation = available_tools[tool_name](arguments) elif isinstance(arguments, dict): for key, value in arguments.items(): - # if the value is the name of a state variable like "image.png", replace it with the actual value if isinstance(value, str) and value in self.state: arguments[key] = self.state[value] observation = available_tools[tool_name](**arguments) else: - raise AgentExecutionError( - f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." - ) + error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." + console.print(f"[bold red]{error_msg}") + raise AgentExecutionError(error_msg) return observation except Exception as e: if tool_name in self.toolbox.tools: - raise AgentExecutionError( + tool_description = get_tool_description_with_args(available_tools[tool_name]) + error_msg = ( f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" - f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(available_tools[tool_name])}" + f"As a reminder, this tool's description is the following:\n{tool_description}" ) + console.print(f"[bold red]{error_msg}") + raise AgentExecutionError(error_msg) elif tool_name in self.managed_agents: - raise AgentExecutionError( + error_msg = ( f"Error in calling team member: {e}\nYou should only ask this team member with a correct request.\n" f"As a reminder, this team member's description is the following:\n{available_tools[tool_name]}" ) + console.print(f"[bold red]{error_msg}") + raise AgentExecutionError(error_msg) + def log_rationale_code_action(self, rationale: str, code_action: str) -> None: - self.logger.warning("=== Agent thoughts:") - self.logger.log(31, rationale) - self.logger.warning(">>> Agent is executing the code below:") - if is_pygments_available(): - self.logger.log( - 31, highlight(code_action, PythonLexer(ensurenl=False), Terminal256Formatter(style="nord")) - ) - else: - self.logger.log(31, code_action) - self.logger.warning("====") + if self.verbose: + console.rule("Agent thoughts") + console.print(rationale) + console.rule("Agent is executing the code below:", align="left") + console.print(code_action) + console.rule("", align="left") def run(self, **kwargs): """To be implemented in the child class""" @@ -617,13 +567,6 @@ class CodeAgent(Agent): **kwargs, ) - if not is_pygments_available(): - transformers_logging.warning_once( - logger, - "pygments isn't installed. 
Installing pygments will enable color syntax highlighting in the " "CodeAgent.", ) - self.python_evaluator = evaluate_python_code self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) @@ -669,29 +612,30 @@ class CodeAgent(Agent): } self.prompt = [prompt_message, task_message] - self.logger.info("====Executing with this prompt====") - self.logger.info(self.prompt) + + if self.verbose: + console.rule("Executing with this prompt") + console.print(self.prompt) additional_args = {"grammar": self.grammar} if self.grammar is not None else {} llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"], **additional_args) - if return_generated_code: - return llm_output # Parse try: rationale, code_action = self.extract_action(llm_output=llm_output, split_token="Code:") except Exception as e: - self.logger.debug( - f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}" - ) + if self.verbose: + console.print( + f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}" + ) rationale, code_action = "", llm_output try: code_action = self.parse_code_blob(code_action) except Exception as e: error_msg = f"Error in code parsing: {e}. Be sure to provide correct code" - self.logger.error(error_msg, exc_info=1) + console.print(f"[bold red]{error_msg}[/bold red]") return error_msg # Execute @@ -705,11 +649,12 @@ class CodeAgent(Agent): state=self.state, authorized_imports=self.authorized_imports, ) - self.logger.info(self.state["print_outputs"]) + if self.verbose: + console.print(self.state["print_outputs"]) return output except Exception as e: error_msg = f"Error in execution: {e}. Be sure to provide correct code." - self.logger.error(error_msg, exc_info=1) + console.print(f"[bold red]{error_msg}[/bold red]") return error_msg @@ -759,7 +704,7 @@ class ReactAgent(Agent): self.prompt = [ { "role": MessageRole.SYSTEM, - "content": "An agent tried to answer an user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:", + "content": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:", } ] self.prompt += self.write_inner_memory_from_logs()[1:] @@ -772,7 +717,9 @@ class ReactAgent(Agent): try: return self.llm_engine(self.prompt) except Exception as e: - return f"Error in generating final llm output: {e}." + error_msg = f"Error in generating final LLM output: {e}." + console.print(f"[bold red]{error_msg}[/bold red]") + return error_msg def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs): """ @@ -815,7 +762,6 @@ class ReactAgent(Agent): if "final_answer" in step_log_entry: final_answer = step_log_entry["final_answer"] except AgentError as e: - self.logger.error(e, exc_info=1) step_log_entry["error"] = e finally: step_end_time = time.time() @@ -831,7 +777,7 @@ class ReactAgent(Agent): error_message = "Reached max iterations."
final_step_log = {"error": AgentMaxIterationsError(error_message)} self.logs.append(final_step_log) - self.logger.error(error_message, exc_info=1) + console.print(f"[bold red]{error_message}") final_answer = self.provide_final_answer(task) final_step_log["final_answer"] = final_answer final_step_log["step_duration"] = 0 @@ -857,7 +803,6 @@ class ReactAgent(Agent): if "final_answer" in step_log_entry: final_answer = step_log_entry["final_answer"] except AgentError as e: - self.logger.error(e, exc_info=1) step_log_entry["error"] = e finally: step_end_time = time.time() @@ -872,7 +817,7 @@ class ReactAgent(Agent): error_message = "Reached max iterations." final_step_log = {"error": AgentMaxIterationsError(error_message)} self.logs.append(final_step_log) - self.logger.error(error_message, exc_info=1) + console.print(f"[bold red]{error_message}") final_answer = self.provide_final_answer(task) final_step_log["final_answer"] = final_answer final_step_log["step_duration"] = 0 @@ -931,8 +876,8 @@ Now begin!""", {answer_facts} ```""".strip() self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction}) - self.logger.log(36, "===== Initial plan =====") - self.logger.log(35, final_plan_redaction) + console.rule("Initial plan") + console.print(final_plan_redaction) else: # update plan agent_memory = self.write_inner_memory_from_logs( summary_mode=False ) @@ -977,8 +922,8 @@ Now begin!""", {facts_update} ```""" self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction}) - self.logger.log(36, "===== Updated plan =====") - self.logger.log(35, final_plan_redaction) + console.rule("Updated plan") + console.print(final_plan_redaction) class ReactJsonAgent(ReactAgent): @@ -1022,13 +967,14 @@ class ReactJsonAgent(ReactAgent): agent_memory = self.write_inner_memory_from_logs() self.prompt = agent_memory - self.logger.debug("===== New step =====") + console.rule("New step") # Add new step in logs log_entry["agent_memory"] = agent_memory.copy() - self.logger.info("===== Calling LLM with this last message: =====") - self.logger.info(self.prompt[-1]) + if self.verbose: + console.rule("Calling LLM with this last message:") + console.print(self.prompt[-1]) try: additional_args = {"grammar": self.grammar} if self.grammar is not None else {} llm_output = self.llm_engine( self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args ) except Exception as e: raise AgentGenerationError(f"Error in generating llm output: {e}.") - self.logger.debug("===== Output message of the LLM: =====") - self.logger.debug(llm_output) + if self.verbose: + console.rule("Output message of the LLM:") + console.print(llm_output) log_entry["llm_output"] = llm_output # Parse - self.logger.debug("===== Extracting action =====") rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:") try: @@ -1054,9 +1000,9 @@ class ReactJsonAgent(ReactAgent): log_entry["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments} # Execute - self.logger.warning("=== Agent thoughts:") - self.logger.log(31, rationale) - self.logger.warning(f">>> Calling tool: '{tool_name}' with arguments: {arguments}") + console.rule("Agent thoughts") + console.print(rationale) + console.print(f">>> Calling tool: '{tool_name}' with arguments: {arguments}") if tool_name == "final_answer": if isinstance(arguments, dict): if "answer" in arguments: @@ -1087,7 +1033,6 @@ class ReactJsonAgent(ReactAgent): updated_information = f"Stored '{observation_name}' in memory."
else: updated_information = str(observation).strip() - self.logger.info(updated_information) log_entry["observation"] = updated_information return log_entry @@ -1126,13 +1071,6 @@ class ReactCodeAgent(ReactAgent): **kwargs, ) - if not is_pygments_available(): - transformers_logging.warning_once( - logger, - "pygments isn't installed. Installing pygments will enable color syntax highlighting in the " - "ReactCodeAgent.", - ) - self.python_evaluator = evaluate_python_code self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) @@ -1147,13 +1085,14 @@ class ReactCodeAgent(ReactAgent): agent_memory = self.write_inner_memory_from_logs() self.prompt = agent_memory.copy() - self.logger.debug("===== New step =====") + console.rule("New step") # Add new step in logs log_entry["agent_memory"] = agent_memory.copy() - self.logger.info("===== Calling LLM with these last messages: =====") - self.logger.info(self.prompt[-2:]) + if self.verbose: + console.rule("Calling LLM with these last messages:") + console.print(self.prompt[-2:]) try: additional_args = {"grammar": self.grammar} if self.grammar is not None else {} llm_output = self.llm_engine( self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args @@ -1163,16 +1102,16 @@ class ReactCodeAgent(ReactAgent): ) except Exception as e: raise AgentGenerationError(f"Error in generating llm output: {e}.") - self.logger.debug("=== Output message of the LLM:") - self.logger.debug(llm_output) + if self.verbose: + console.rule("Output message of the LLM:") + console.print(llm_output) log_entry["llm_output"] = llm_output # Parse - self.logger.debug("=== Extracting action ===") try: rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:") except Exception as e: - self.logger.debug(f"Error in extracting action, trying to parse the whole output. Error trace: {e}") + console.print(f"Error in extracting action, trying to parse the whole output. 
Error trace: {e}") rationale, raw_code_action = llm_output, llm_output try: @@ -1200,12 +1139,12 @@ class ReactCodeAgent(ReactAgent): state=self.state, authorized_imports=self.authorized_imports, ) - self.logger.warning("Print outputs:") - self.logger.log(32, self.state["print_outputs"]) + console.print("Print outputs:") + console.print(self.state["print_outputs"]) observation = "Print outputs:\n" + self.state["print_outputs"] if result is not None: - self.logger.warning("Last output from code snippet:") - self.logger.log(32, str(result)) + console.print("Last output from code snippet:") + console.print(str(result)) observation += "Last output from code snippet:\n" + str(result)[:100000] log_entry["observation"] = observation except Exception as e: @@ -1215,8 +1154,8 @@ class ReactCodeAgent(ReactAgent): raise AgentExecutionError(error_msg) for line in code_action.split("\n"): if line[: len("final_answer")] == "final_answer": - self.logger.log(33, "Final answer:") - self.logger.log(32, result) + console.print("Final answer:") + console.print(f"[bold]{result}") log_entry["final_answer"] = result return result diff --git a/agents/default_tools.py b/agents/default_tools.py index 3946aa9..5a267b9 100644 --- a/agents/default_tools.py +++ b/agents/default_tools.py @@ -23,7 +23,7 @@ from typing import Dict from huggingface_hub import hf_hub_download, list_spaces -from ..utils import is_offline_mode +from transformers.utils import is_offline_mode from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool @@ -128,7 +128,7 @@ def get_remote_tools(logger, organization="huggingface-tools"): return tools -def setup_default_tools(logger): +def setup_default_tools(): default_tools = {} main_module = importlib.import_module("transformers") tools_module = main_module.agents diff --git a/agents/document_question_answering.py b/agents/document_question_answering.py index 23ae5b0..fd77e07 100644 --- a/agents/document_question_answering.py +++ b/agents/document_question_answering.py @@ -19,9 +19,8 @@ import re import numpy as np import torch -from ..models.auto import AutoProcessor -from ..models.vision_encoder_decoder import VisionEncoderDecoderModel -from ..utils import is_vision_available +from transformers import AutoProcessor, VisionEncoderDecoderModel +from transformers.utils import is_vision_available from .tools import PipelineTool diff --git a/agents/image_question_answering.py b/agents/image_question_answering.py index de0efb7..5101684 100644 --- a/agents/image_question_answering.py +++ b/agents/image_question_answering.py @@ -18,8 +18,8 @@ import torch from PIL import Image -from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor -from ..utils import requires_backends +from transformers import AutoModelForVisualQuestionAnswering, AutoProcessor +from transformers.utils import requires_backends from .tools import PipelineTool diff --git a/agents/llm_engine.py b/agents/llm_engine.py index afa4d62..cb4fb0b 100644 --- a/agents/llm_engine.py +++ b/agents/llm_engine.py @@ -20,12 +20,11 @@ from typing import Dict, List, Optional from huggingface_hub import InferenceClient -from .. 
import AutoTokenizer -from ..pipelines.base import Pipeline -from ..utils import logging +from transformers import AutoTokenizer, Pipeline +import logging -logger = logging.get_logger(__name__) +logger = logging.getLogger(__name__) class MessageRole(str, Enum): diff --git a/agents/monitoring.py b/agents/monitoring.py index 7126e72..8675c2d 100644 --- a/agents/monitoring.py +++ b/agents/monitoring.py @@ -14,11 +14,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ..utils import logging from .agent_types import AgentAudio, AgentImage, AgentText - - -logger = logging.get_logger(__name__) +from .utils import console def pull_message(step_log: dict, test_mode: bool = True): @@ -107,11 +104,12 @@ class Monitor: def update_metrics(self, step_log): step_duration = step_log["step_duration"] self.step_durations.append(step_duration) - logger.info(f"Step {len(self.step_durations)}:") - logger.info(f"- Time taken: {step_duration:.2f} seconds (valid only if step succeeded)") + console.print(f"Step {len(self.step_durations)}:") + console.print(f"- Time taken: {step_duration:.2f} seconds") if getattr(self.tracked_llm_engine, "last_input_token_count", None) is not None: self.total_input_token_count += self.tracked_llm_engine.last_input_token_count self.total_output_token_count += self.tracked_llm_engine.last_output_token_count - logger.info(f"- Input tokens: {self.total_input_token_count}") - logger.info(f"- Output tokens: {self.total_output_token_count}") + console.print(f"- Input tokens: {self.total_input_token_count}") + console.print(f"- Output tokens: {self.total_output_token_count}") + diff --git a/agents/prompts.py b/agents/prompts.py index 7a84b1d..74849dc 100644 --- a/agents/prompts.py +++ b/agents/prompts.py @@ -16,7 +16,7 @@ # limitations under the License. import re -from ..utils import cached_file +from transformers.utils import cached_file # docstyle-ignore diff --git a/agents/python_interpreter.py b/agents/python_interpreter.py index 6e90f35..9844774 100644 --- a/agents/python_interpreter.py +++ b/agents/python_interpreter.py @@ -22,12 +22,7 @@ from importlib import import_module from typing import Any, Callable, Dict, List, Optional import numpy as np - -from ..utils import is_pandas_available - - -if is_pandas_available(): - import pandas as pd +import pandas as pd class InterpreterError(ValueError): diff --git a/agents/speech_to_text.py b/agents/speech_to_text.py index 8061651..bfa8aeb 100644 --- a/agents/speech_to_text.py +++ b/agents/speech_to_text.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor +from transformers import WhisperForConditionalGeneration, WhisperProcessor from .tools import PipelineTool diff --git a/agents/text_to_speech.py b/agents/text_to_speech.py index ed41ef6..415f5df 100644 --- a/agents/text_to_speech.py +++ b/agents/text_to_speech.py @@ -17,8 +17,8 @@ import torch -from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor -from ..utils import is_datasets_available +from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor +from transformers.utils import is_datasets_available from .tools import PipelineTool diff --git a/agents/tools.py b/agents/tools.py index 7597046..1c2fab7 100644 --- a/agents/tools.py +++ b/agents/tools.py @@ -30,13 +30,13 @@ from huggingface_hub import create_repo, get_collection, hf_hub_download, metada from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session from packaging import version -from ..dynamic_module_utils import ( +from transformers.dynamic_module_utils import ( custom_object_save, get_class_from_dynamic_module, get_imports, ) -from ..models.auto import AutoProcessor -from ..utils import ( +from transformers import AutoProcessor +from transformers.utils import ( CONFIG_NAME, TypeHintParsingException, cached_file, @@ -44,12 +44,11 @@ from ..utils import ( is_accelerate_available, is_torch_available, is_vision_available, - logging, ) from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs +import logging - -logger = logging.get_logger(__name__) +logger = logging.getLogger(__name__) if is_torch_available(): diff --git a/agents/translation.py b/agents/translation.py index 7ae61f9..b03c804 100644 --- a/agents/translation.py +++ b/agents/translation.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer from .tools import PipelineTool diff --git a/agents/utils.py b/agents/utils.py new file mode 100644 index 0000000..38f2f78 --- /dev/null +++ b/agents/utils.py @@ -0,0 +1,10 @@ +from rich.console import Console + +from transformers.utils.import_utils import _is_package_available + +_pygments_available = _is_package_available("pygments") +console = Console() + + +def is_pygments_available(): + return _pygments_available diff --git a/docs/source/main_classes/logging.md b/docs/source/main_classes/logging.md index 5cbdf9a..e71eaf7 100644 --- a/docs/source/main_classes/logging.md +++ b/docs/source/main_classes/logging.md @@ -47,10 +47,10 @@ TRANSFORMERS_NO_ADVISORY_WARNINGS=1 ./myprogram.py Here is an example of how to use the same logger as the library in your own module or script: ```python -from transformers.utils import logging +import logging -logging.set_verbosity_info() +logging.basicConfig(level=logging.INFO) -logger = logging.get_logger("transformers") +logger = logging.getLogger("transformers") logger.info("INFO") logger.warning("WARN") ``` @@ -104,7 +104,7 @@ See reference of the `captureWarnings` method below. 
[[autodoc]] logging.set_verbosity -[[autodoc]] logging.get_logger +[[autodoc]] logging.getLogger [[autodoc]] logging.enable_default_handler diff --git a/examples/agent_with_tools.py b/examples/agent_with_tools.py new file mode 100644 index 0000000..dbb7a24 --- /dev/null +++ b/examples/agent_with_tools.py @@ -0,0 +1,20 @@ +from agents import load_tool, ReactCodeAgent, HfApiEngine + +# Import tool from Hub +image_generation_tool = load_tool("m-ric/text-to-image", cache=False) + +# Import the built-in DuckDuckGo search tool +from agents.search import DuckDuckGoSearchTool + +search_tool = DuckDuckGoSearchTool() + +llm_engine = HfApiEngine("Qwen/Qwen2.5-72B-Instruct") +# Initialize the agent with both tools +agent = ReactCodeAgent(tools=[image_generation_tool, search_tool], llm_engine=llm_engine) + +# Run it! +result = agent.run( + "Return me a photo of the car that James Bond drove in the latest movie.", +) + +print(result) \ No newline at end of file diff --git a/examples/requirements.txt b/examples/requirements.txt index fd571f2..b43f5ed 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -3,3 +3,4 @@ evaluate datasets==2.3.2 schedulefree huggingface_hub>=0.20.0 +duckduckgo-search \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index facbf6a..2350048 100644 --- a/poetry.lock +++ b/poetry.lock @@ -264,6 +264,41 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "markdown-it-py" +version = "3.0.0" +description = "Python port of markdown-it. Markdown parsing, done right!" +optional = false +python-versions = ">=3.8" +files = [ + {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, + {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, +] + +[package.dependencies] +mdurl = ">=0.1,<1.0" + +[package.extras] +benchmarking = ["psutil", "pytest", "pytest-benchmark"] +code-style = ["pre-commit (>=3.0,<4.0)"] +compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"] +linkify = ["linkify-it-py (>=1,<3)"] +plugins = ["mdit-py-plugins"] +profiling = ["gprof2dot"] +rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] +testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] + +[[package]] +name = "mdurl" +version = "0.1.2" +description = "Markdown URL utilities" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, + {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, +] + [[package]] name = "numpy" version = "2.1.3" @@ -354,6 +389,20 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "pygments" +version = "2.18.0" +description = "Pygments is a syntax highlighting package written in Python." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"}, + {file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"}, +] + +[package.extras] +windows-terminal = ["colorama (>=0.4.6)"] + [[package]] name = "pytest" version = "8.3.4" @@ -562,6 +611,25 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rich" +version = "13.9.4" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90"}, + {file = "rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098"}, +] + +[package.dependencies] +markdown-it-py = ">=2.2.0" +pygments = ">=2.13.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""} + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<9)"] + [[package]] name = "safetensors" version = "0.4.5" @@ -888,4 +956,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "de2c60fd8f7c54521b2204b3af144b974059361d7892c70eeb325a7fe52bd489" +content-hash = "985797a96ac1b58e4892479d96bbd1c696b452d760fa421a28b91ea8b3dbf977" diff --git a/pyproject.toml b/pyproject.toml index 6e26980..f3e6d91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ python = ">=3.10,<3.13" transformers = ">=4.0.0" pytest = {version = ">=8.1.0", optional = true} requests = "^2.32.3" +rich = "^13.9.4" [build-system]
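
A note on the `AgentImage.save` change in agents/agent_types.py: the new signature matches PIL's own `Image.save(fp, format=None, **params)`, where the format is inferred from a filename extension when left as `None`. A minimal sketch of the two cases this enables, using only Pillow:

```python
import io

from PIL import Image

img = Image.new("RGB", (64, 64), "red")

# Saving to a path: format=None lets PIL infer PNG from the extension.
img.save("out.png")

# Saving to a stream: there is no extension to inspect, so the format must
# be explicit -- which is why forwarding it as a keyword (format=format) matters.
buffer = io.BytesIO()
img.save(buffer, format="PNG")
```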
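The deleted `CustomFormatter` in agents/agents.py mapped custom log levels (31-36) to raw ANSI escape codes. The `rich` console that replaces it covers the same ground declaratively; a minimal sketch of the equivalences the patch relies on:

```python
from rich.console import Console

console = Console()

# Replaces banners such as logger.log(33, "======== New task ========").
console.rule("New task", characters="=")

# Replaces hand-rolled escape codes like "\x1b[31;20m" + message + "\x1b[0m".
console.print("[bold red]Error: unknown tool[/bold red]")

# Rich pretty-prints and highlights Python objects out of the box, which is
# why the explicit pygments wiring could be dropped entirely (rich itself
# depends on pygments, per the poetry.lock additions).
console.print({"action": "final_answer", "action_input": "done"})
```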
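`parse_json_tool_call` keeps its `(tool_name, arguments)` return contract; only the missing-key failure path changes, printing on the shared console before raising. A sketch assuming the happy path mirrors the visible `elif` branch (import path per this repo's layout, blob contents hypothetical):

```python
from agents.agents import parse_json_tool_call

blob = '{"action": "image_generator", "action_input": {"prompt": "a red car"}}'
tool_name, arguments = parse_json_tool_call(blob)
# -> ("image_generator", {"prompt": "a red car"})

# An "action" without an "action_input" yields None for the arguments.
parse_json_tool_call('{"action": "final_answer"}')
# -> ("final_answer", None)

# Missing keys now print in bold red on the console, then raise:
# ValueError: Missing keys: ['action', 'action_input'] in blob {'tool': 'oops'}
parse_json_tool_call('{"tool": "oops"}')
```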
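The `Toolbox` surface exercised in the patch (`add_tool`, `add_base_tools`, `clear_toolbox`, the `tools` mapping, `__repr__`) suggests the following usage. The `Tool` attribute schema shown here is an assumption carried over from the transformers agents API, not something this diff establishes:

```python
from agents.agents import Toolbox
from agents.tools import Tool


class EchoTool(Tool):
    # Hypothetical tool: name/description/inputs/output_type follow the
    # transformers agents convention; adjust to this repo's actual schema.
    name = "echo"
    description = "Returns its input text unchanged."
    inputs = {"text": {"type": "string", "description": "Text to echo back."}}
    output_type = "string"

    def forward(self, text: str) -> str:
        return text


toolbox = Toolbox([EchoTool()], add_base_tools=False)
print(toolbox.tools.keys())  # dict_keys(['echo'])
print(toolbox)               # "Toolbox contents:" plus one line per tool
```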
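Every LLM call in the patch goes through `self.llm_engine(messages, stop_sequences=..., **kwargs)`, where `messages` is a list of `{"role": ..., "content": ...}` dicts and generation is bounded by `<end_action>` or `"Observation:"`. Any callable with that shape can stand in for `HfApiEngine`; a hypothetical stub for offline testing:

```python
from typing import Dict, List, Optional

from agents import ReactCodeAgent


def stub_engine(messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, **kwargs) -> str:
    # A real engine forwards `messages` to a model and truncates the completion
    # at the first stop sequence; this stub answers immediately. The code fence
    # is assembled piecewise only to keep this snippet paste-safe.
    fence = "`" * 3
    return (
        "Thought: nothing to compute, answering directly.\n"
        f"Code:\n{fence}py\nfinal_answer('stub')\n{fence}<end_action>"
    )


agent = ReactCodeAgent(tools=[], llm_engine=stub_engine)
```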
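`Monitor.update_metrics` reads `step_log["step_duration"]` plus the engine's `last_input_token_count`/`last_output_token_count`, and now reports through the shared console. A sketch, assuming the constructor simply stores the tracked engine (it is not shown in the hunk):

```python
from agents.monitoring import Monitor


class DummyEngine:
    # Only the two attributes Monitor inspects.
    last_input_token_count = 128
    last_output_token_count = 64


monitor = Monitor(DummyEngine())
monitor.update_metrics({"step_duration": 1.42})
# Console output:
# Step 1:
# - Time taken: 1.42 seconds
# - Input tokens: 128
# - Output tokens: 64
```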
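Several modules swap `transformers.utils.logging.get_logger` for the stdlib pattern, and the corrected docs snippet reflects one behavioral difference worth keeping in mind: stdlib logging emits nothing below WARNING until the application configures a handler. The usual module/application split:

```python
import logging

# Library module (as in llm_engine.py and tools.py): name the logger,
# attach no handlers, and leave configuration to the application.
logger = logging.getLogger(__name__)

# Application entry point: opt in to the desired verbosity once.
logging.basicConfig(level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s")

logger.info("visible only after basicConfig(level=logging.INFO) has run")
```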