Add transcriber tool and PipelineTool

This commit is contained in:
Aymeric 2024-12-23 22:49:32 +01:00
parent cb7e68f2f0
commit d389f11e37
13 changed files with 218 additions and 256 deletions

View File

@ -108,4 +108,8 @@ Code is just a better way to express actions on a computer. It has better:
- **Generality:** code is built to simply express anything you can have a computer do.
- **Representation in LLM training corpuses:** why not leverage the fact that plenty of quality code actions are already included in LLM training corpuses?
This is illustrated in the figure below, taken from [Executable Code Actions Elicit Better LLM Agents](https://huggingface.co/papers/2402.01030).
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/code_vs_json_actions.png">
This is why we put emphasis on proposing code agents, in this case Python agents, which meant putting higher effort on building secure Python interpreters.

View File

@ -17,18 +17,16 @@ rendered properly in your Markdown viewer.
[[open-in-colab]]
In this guided tour, you will learn how to build an agent, how to run it, and how to customize it to make it work better for your use case.
### Building your agent
To initialize a minimal agent, you need at least these two arguments:
- An LLM to power your agent - because the agent is different from a simple LLM, it is a system that uses an LLM as its engine.
- A list of tools from which the agent can pick tools to execute
For defining your LLM, you can write an `llm_engine` method which accepts a list of [messages](./chat_templating) and returns text. This callable also needs to accept a `stop_sequences` argument that indicates when to stop generating.
```python
from huggingface_hub import login, InferenceClient
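# The rest of this snippet is elided by the diff view; completed below as a
# minimal sketch (the model id is an assumption, pick any chat model):
login("<YOUR_HUGGINGFACEHUB_API_TOKEN>")

model_id = "meta-llama/Meta-Llama-3-70B-Instruct"
client = InferenceClient(model=model_id)

def llm_engine(messages, stop_sequences=["Task"]) -> str:
    # Query the Inference API, stopping generation at any of the stop sequences
    response = client.chat_completion(messages, stop=stop_sequences, max_tokens=1000)
    return response.choices[0].message.content
```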
@ -51,10 +49,14 @@ You could use any `llm_engine` method as long as:
Additionally, `llm_engine` can also take a `grammar` argument. If you specify a `grammar` upon agent initialization, this argument will be passed to the calls to `llm_engine` with the `grammar` you defined, to allow [constrained generation](https://huggingface.co/docs/text-generation-inference/conceptual/guidance) that forces properly-formatted agent outputs.
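For instance, a hypothetical sketch of passing a grammar at initialization (the exact schema accepted depends on the backend serving your model; the regex value below is illustrative):

```py
# Hypothetical sketch: constrain the agent's outputs with a regex grammar
agent = CodeAgent(
    tools=[],
    llm_engine=llm_engine,
    grammar={"type": "regex", "value": r"Thought: .+\nCode:\n```(?:py|python)?\n[\s\S]+\n```"},
)
```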
For convenience, we provide pre-built classes for your LLM engine:
- [`TransformersEngine`] takes a pre-initialized `transformers` pipeline to run inference on your local machine using `transformers`.
- [`HfApiEngine`] leverages a `huggingface_hub.InferenceClient` under the hood.
- We also provide [`OpenAIEngine`] and [`AnthropicEngine`], but you could use anything!
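For instance, a minimal sketch of a local engine (the checkpoint and the exact constructor call are assumptions; adapt them to your hardware and the class's signature):

```py
from transformers import pipeline
from agents import TransformersEngine

# Build a local text-generation pipeline and wrap it as an agent engine
pipe = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-1.7B-Instruct")
llm_engine = TransformersEngine(pipe)
```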
You will also need a `tools` argument which accepts a list of `Tools` - it can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`.
Once you have these two arguments, `tools` and `llm_engine`, you can create an agent and run it.
```python
from agents import CodeAgent, HfApiEngine
@ -63,12 +65,10 @@ llm_engine = HfApiEngine(model=model_id)
agent = CodeAgent(tools=[], llm_engine=llm_engine, add_base_tools=True)

agent.run(
    "Could you give me the 118th number in the Fibonacci sequence?",
)
```
You can even leave the argument `llm_engine` undefined, and an [`HfApiEngine`] will be created by default.
```python
@ -77,24 +77,23 @@ from agents import CodeAgent
agent = CodeAgent(tools=[], add_base_tools=True)

agent.run(
    "Could you give me the 118th number in the Fibonacci sequence?",
    additional_detail="We adopt the convention where the first two numbers are 0 and 1.",
)
```
Note that we used an additional `additional_detail` argument: you can pass additional kwargs to `agent.run()`, and they will be baked into the prompt as text.
You can use this to indicate the path to local or remote files for the model to use:
```py
from agents import CodeAgent, SpeechToTextTool

agent = CodeAgent(tools=[SpeechToTextTool()], add_base_tools=True)

agent.run("Why does Mike not know many people in New York?", audio="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3")
```
It's important to explain as clearly as possible the task you want to perform.
Since an agent is powered by an LLM, minor variations in your task formulation might yield completely different results.
You can also run an agent consecutively for different tasks: if you leave the default option of `True` for the flag `reset` when calling `agent.run(task)`, the agent's memory will be erased before starting the new task.
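For example, a minimal sketch of two consecutive runs where the second keeps the first's memory (the tasks are illustrative):

```py
agent.run("What is the result of 2 power 3.7384?")
# reset=False keeps the memory from the previous run, so the agent can refer back to it
agent.run("Now multiply that result by 2.", reset=False)
```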
@ -117,14 +116,18 @@ This gives you at the end of the agent run:
```text
'Hugging Face Blog'
```
The execution will stop at any code trying to perform an illegal operation or if there is a regular Python error with the code generated by the agent. You can also use the E2B code executor instead of a local Python interpreter by passing `use_e2b_executor=True` upon agent initialization.
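For instance, a minimal sketch of opting into remote execution (this assumes an E2B account with the corresponding API key configured in your environment):

```py
from agents import CodeAgent

# The agent's generated code now runs in a remote E2B sandbox instead of locally
agent = CodeAgent(tools=[], use_e2b_executor=True)
```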
> [!WARNING]
> The LLM can generate arbitrary code that will then be executed: do not add any unsafe imports!
### The system prompt
Upon initialization of the agent system, a system prompt (attribute `system_prompt`) is built automatically by turning the description extracted from the tools into a predefined system prompt template.
But you can customize it!
Let's see how it works. For example, check the system prompt for the [`CodeAgent`] (the version below is slightly simplified).
The prompt and output parser were automatically defined, but you can easily inspect them by checking the `system_prompt_template` attribute of your agent.
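For instance, a quick sketch (assuming `agent` is the agent created above):

```py
print(agent.system_prompt_template)
```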
@ -207,6 +210,7 @@ Transformers comes with a default toolbox for empowering agents, that you can ad
- **DuckDuckGo web search**: performs a web search using DuckDuckGo.
- **Python code interpreter**: runs the LLM-generated Python code in a secure environment. This tool will only be added to [`JsonAgent`] if you initialize it with `add_base_tools=True`, since a code-based agent can already natively execute Python code.
- **Transcriber**: a speech-to-text pipeline built on Whisper-Turbo that transcribes audio to text.
You can manually use a tool by calling the [`load_tool`] function with a tool name, then calling the tool with a task to perform.
@ -217,7 +221,6 @@ search_tool = load_tool("web_search")
print(search_tool("Who's the current president of Russia?"))
```
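Similarly, here is a sketch of trying the new transcriber tool, reusing the sample recording from earlier in this guide:

```python
from agents import load_tool

# Load the transcriber added by this commit and run it on a remote audio file
transcriber = load_tool("transcriber")
print(transcriber("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3"))
```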
### Create a new tool
You can create your own tool for use cases not covered by the default tools from Hugging Face.
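For example, here is a sketch of a tool that returns the most downloaded model for a given task on the Hub, written by subclassing `Tool` (the field layout mirrors the default tools in this repo; the exact class body is illustrative):

```py
from agents import Tool
from huggingface_hub import list_models

class ModelDownloadsTool(Tool):
    name = "model_download_counter"
    description = "Returns the id of the most downloaded model for a given task on the Hugging Face Hub."
    inputs = {
        "task": {"type": "string", "description": "The task, for instance 'text-to-image'."}
    }
    output_type = "string"

    def forward(self, task: str) -> str:
        # Sort Hub models for this task by download count and return the top one
        most_downloaded = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
        return most_downloaded.id
```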
@ -320,7 +323,7 @@ manager_agent.run("Who is the CEO of Hugging Face?")
> For an in-depth example of an efficient multi-agent implementation, see [how we pushed our multi-agent system to the top of the GAIA leaderboard](https://huggingface.co/blog/beating-gaia).
## Talk with your agent and visualize its thoughts in a cool Gradio interface
You can use `GradioUI` to interactively submit tasks to your agent and observe its thought and execution process; here is an example:
@ -344,4 +347,11 @@ GradioUI(agent).launch()
```
Under the hood, when the user types a new answer, the agent is launched with `agent.run(user_request, reset=False)`.
The `reset=False` flag means the agent's memory is not flushed before launching this new task, which lets the conversation go on.
## Next steps
For more in-depth usage, you will want to check out our tutorials:
- [the explanation of how our code agents work](./tutorials/secure_code_execution)
- [this guide on how to build good agents](./tutorials/building_good_agents)
- [the in-depth guide for tool usage](./tutorials/tools)

View File

@ -29,6 +29,10 @@ Code is just a better way to express actions on a computer. It has better:
- **Generality:** code is built to simply express anything you can have a computer do.
- **Representation in LLM training corpuses:** why not leverage the fact that plenty of quality code actions are already included in LLM training corpuses?
This is illustrated in the figure below, taken from [Executable Code Actions Elicit Better LLM Agents](https://huggingface.co/papers/2402.01030).
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/code_vs_json_actions.png">
This is why we put emphasis on proposing code agents, in this case Python agents, which meant putting higher effort on building secure Python interpreters.
### Local Python interpreter

View File

@ -91,7 +91,7 @@ model_download_tool = load_tool(
)
```
### Import a Space as a tool
You can directly import a Space from the Hub as a tool using the [`Tool.from_space`] method!
@ -103,7 +103,8 @@ For instance, let's import the [FLUX.1-dev](https://huggingface.co/black-forest-
image_generation_tool = Tool.from_space(
    "black-forest-labs/FLUX.1-schnell",
    name="image_generator",
    description="Generate an image from a prompt"
)
image_generation_tool("A sunny beach")
```

View File

@ -1,9 +1,8 @@
from agents import Tool, CodeAgent
from agents.default_tools.search import VisitWebpageTool
from dotenv import load_dotenv

load_dotenv()
class GetCatImageTool(Tool):
    name="get_cat_image"
@ -24,6 +23,8 @@ class GetCatImageTool(Tool):
        return Image.open(BytesIO(response.content))
LAUNCH_GRADIO = False
get_cat_image = GetCatImageTool()

View File

@ -24,17 +24,16 @@ from transformers.utils.import_utils import define_import_structure
if TYPE_CHECKING:
    from .agents import *
    from .default_tools import *
    from .gradio_ui import *
    from .llm_engines import *
    from .local_python_executor import *
    from .e2b_executor import *
    from .monitoring import *
    from .prompts import *
    from .tools import *
    from .types import *
    from .utils import *
else:

View File

@ -25,7 +25,7 @@ from rich.text import Text
from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content
from .types import AgentAudio, AgentImage
from .default_tools import FinalAnswerTool
from .llm_engines import HfApiEngine, MessageRole
from .monitoring import Monitor
from .prompts import (

View File

@ -1,136 +0,0 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from dataclasses import dataclass
from typing import Dict
from huggingface_hub import hf_hub_download, list_spaces
from transformers.utils import is_offline_mode
from ..local_python_executor import (
    BASE_BUILTIN_MODULES,
    BASE_PYTHON_TOOLS,
    evaluate_python_code,
)
from ..tools import TOOL_CONFIG_FILE, Tool
@dataclass
class PreTool:
    name: str
    inputs: Dict[str, str]
    output_type: type
    task: str
    description: str
    repo_id: str
def get_remote_tools(logger, organization="huggingface-tools"):
    if is_offline_mode():
        logger.info("You are in offline mode, so remote tools are not available.")
        return {}

    spaces = list_spaces(author=organization)
    tools = {}
    for space_info in spaces:
        repo_id = space_info.id
        resolved_config_file = hf_hub_download(
            repo_id, TOOL_CONFIG_FILE, repo_type="space"
        )
        with open(resolved_config_file, encoding="utf-8") as reader:
            config = json.load(reader)
        task = repo_id.split("/")[-1]
        tools[config["name"]] = PreTool(
            task=task,
            description=config["description"],
            repo_id=repo_id,
            name=task,
            inputs=config["inputs"],
            output_type=config["output_type"],
        )

    return tools
class PythonInterpreterTool(Tool):
    name = "python_interpreter"
    description = "This is a tool that evaluates python code. It can be used to perform calculations."
    inputs = {
        "code": {
            "type": "string",
            "description": "The python code to run in interpreter",
        }
    }
    output_type = "string"

    def __init__(self, *args, authorized_imports=None, **kwargs):
        if authorized_imports is None:
            self.authorized_imports = list(set(BASE_BUILTIN_MODULES))
        else:
            self.authorized_imports = list(
                set(BASE_BUILTIN_MODULES) | set(authorized_imports)
            )
        self.inputs = {
            "code": {
                "type": "string",
                "description": (
                    "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
                    f"else you will get an error. This code can only import the following python libraries: {authorized_imports}."
                ),
            }
        }
        self.base_python_tool = BASE_PYTHON_TOOLS
        self.python_evaluator = evaluate_python_code
        super().__init__(*args, **kwargs)

    def forward(self, code: str) -> str:
        output = str(
            self.python_evaluator(
                code,
                static_tools=self.base_python_tool,
                authorized_imports=self.authorized_imports,
            )
        )
        return output
class FinalAnswerTool(Tool):
    name = "final_answer"
    description = "Provides a final answer to the given problem."
    inputs = {
        "answer": {"type": "object", "description": "The final answer to the problem"}
    }
    output_type = "object"

    def forward(self, answer):
        return answer


class UserInputTool(Tool):
    name = "user_input"
    description = "Asks for user's input on a specific question"
    inputs = {
        "question": {"type": "string", "description": "The question to ask the user"}
    }
    output_type = "string"

    def forward(self, question):
        user_input = input(f"{question} => ")
        return user_input


__all__ = ["PythonInterpreterTool", "FinalAnswerTool", "UserInputTool"]

View File

@ -1,81 +0,0 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from ..tools import Tool
class DuckDuckGoSearchTool(Tool):
    name = "web_search"
    description = """Performs a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
    Each result has keys 'title', 'href' and 'body'."""
    inputs = {
        "query": {"type": "string", "description": "The search query to perform."}
    }
    output_type = "any"

    def forward(self, query: str) -> str:
        try:
            from duckduckgo_search import DDGS
        except ImportError:
            raise ImportError(
                "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
            )
        results = DDGS().text(query, max_results=7)
        return results
class VisitWebpageTool(Tool):
    name = "visit_webpage"
    description = "Visits a webpage at the given url and returns its content as a markdown string."
    inputs = {
        "url": {
            "type": "string",
            "description": "The url of the webpage to visit.",
        }
    }
    output_type = "string"

    def forward(self, url: str) -> str:
        try:
            from markdownify import markdownify
            import requests
            from requests.exceptions import RequestException
        except ImportError:
            raise ImportError(
                "You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
            )
        try:
            # Send a GET request to the URL
            response = requests.get(url)
            response.raise_for_status()  # Raise an exception for bad status codes

            # Convert the HTML content to Markdown
            markdown_content = markdownify(response.text).strip()

            # Remove multiple line breaks
            markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)

            return markdown_content

        except RequestException as e:
            return f"Error fetching the webpage: {str(e)}"
        except Exception as e:
            return f"An unexpected error occurred: {str(e)}"


__all__ = ["DuckDuckGoSearchTool", "VisitWebpageTool"]

View File

@ -237,20 +237,20 @@ class HfApiEngine(HfEngine):
        # Send messages to the Hugging Face Inference API
        if grammar is not None:
            output = self.client.chat_completion(
                messages,
                stop=stop_sequences,
                response_format=grammar,
                max_tokens=max_tokens,
            )
        else:
            output = self.client.chat.completions.create(
                messages, stop=stop_sequences, max_tokens=max_tokens
            )

        response = output.choices[0].message.content
        self.last_input_token_count = output.usage.prompt_tokens
        self.last_output_token_count = output.usage.completion_tokens
        return response

    def get_tool_call(

View File

@ -753,6 +753,7 @@ def launch_gradio_demo(tool: Tool):
TOOL_MAPPING = {
    "python_interpreter": "PythonInterpreterTool",
    "web_search": "DuckDuckGoSearchTool",
    "transcriber": "SpeechToTextTool",
}
@ -1003,6 +1004,160 @@ class Toolbox:
            toolbox_description += f"\t{tool.name}: {tool.description}\n"
        return toolbox_description
from transformers import AutoProcessor

from .types import handle_agent_input_types, handle_agent_output_types


class PipelineTool(Tool):
    """
    A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
    need to specify:
    - **model_class** (`type`) -- The class to use to load the model in this tool.
    - **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
    - **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
      pre-processor.
    - **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
      post-processor (when different from the pre-processor).

    Args:
        model (`str` or [`PreTrainedModel`], *optional*):
            The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
            value of the class attribute `default_checkpoint`.
        pre_processor (`str` or `Any`, *optional*):
            The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a
            tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if
            unset.
        post_processor (`str` or `Any`, *optional*):
            The name of the checkpoint to use for the post-processor, or the instantiated post-processor (can be a
            tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if
            unset.
        device (`int`, `str` or `torch.device`, *optional*):
            The device on which to execute the model. Will default to any accelerator available (GPU, MPS etc...), the
            CPU otherwise.
        device_map (`str` or `dict`, *optional*):
            If passed along, will be used to instantiate the model.
        model_kwargs (`dict`, *optional*):
            Any keyword argument to send to the model instantiation.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
            running `huggingface-cli login` (stored in `~/.huggingface`).
        hub_kwargs (additional keyword arguments, *optional*):
            Any additional keyword argument to send to the methods that will load the data from the Hub.
    """

    pre_processor_class = AutoProcessor
    model_class = None
    post_processor_class = AutoProcessor
    default_checkpoint = None
    description = "This is a pipeline tool"
    name = "pipeline"
    inputs = {"prompt": str}
    output_type = str
    def __init__(
        self,
        model=None,
        pre_processor=None,
        post_processor=None,
        device=None,
        device_map=None,
        model_kwargs=None,
        token=None,
        **hub_kwargs,
    ):
        if not is_torch_available():
            raise ImportError("Please install torch in order to use this tool.")

        if not is_accelerate_available():
            raise ImportError("Please install accelerate in order to use this tool.")

        if model is None:
            if self.default_checkpoint is None:
                raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
            model = self.default_checkpoint
        if pre_processor is None:
            pre_processor = model

        self.model = model
        self.pre_processor = pre_processor
        self.post_processor = post_processor
        self.device = device
        self.device_map = device_map
        self.model_kwargs = {} if model_kwargs is None else model_kwargs
        if device_map is not None:
            self.model_kwargs["device_map"] = device_map
        self.hub_kwargs = hub_kwargs
        self.hub_kwargs["token"] = token

        super().__init__()
    def setup(self):
        """
        Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
        """
        from accelerate import PartialState

        if isinstance(self.pre_processor, str):
            self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)

        if isinstance(self.model, str):
            self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)

        if self.post_processor is None:
            self.post_processor = self.pre_processor
        elif isinstance(self.post_processor, str):
            self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)

        if self.device is None:
            if self.device_map is not None:
                self.device = list(self.model.hf_device_map.values())[0]
            else:
                self.device = PartialState().default_device

        if self.device_map is None:
            self.model.to(self.device)

        super().setup()

    def encode(self, raw_inputs):
        """
        Uses the `pre_processor` to prepare the inputs for the `model`.
        """
        return self.pre_processor(raw_inputs)

    def forward(self, inputs):
        """
        Sends the inputs through the `model`.
        """
        with torch.no_grad():
            return self.model(**inputs)

    def decode(self, outputs):
        """
        Uses the `post_processor` to decode the model output.
        """
        return self.post_processor(outputs)

    def __call__(self, *args, **kwargs):
        args, kwargs = handle_agent_input_types(*args, **kwargs)

        if not self.is_initialized:
            self.setup()

        encoded_inputs = self.encode(*args, **kwargs)

        import torch
        from accelerate.utils import send_to_device

        tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)}
        non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)}

        encoded_inputs = send_to_device(tensor_inputs, self.device)
        outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
        outputs = send_to_device(outputs, "cpu")
        decoded_outputs = self.decode(outputs)

        return handle_agent_output_types(decoded_outputs, self.output_type)
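# Illustrative sketch (not part of this file): a PipelineTool subclass in the spirit
# of the transcriber added by this commit. The checkpoint and model classes below are
# assumptions; see the actual SpeechToTextTool for the real implementation.
#
# from transformers import WhisperForConditionalGeneration, WhisperProcessor
#
# class ExampleSpeechToTextTool(PipelineTool):
#     default_checkpoint = "openai/whisper-large-v3-turbo"
#     description = "Transcribes an audio file into text."
#     name = "example_transcriber"
#     pre_processor_class = WhisperProcessor
#     model_class = WhisperForConditionalGeneration
#     inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}}
#     output_type = "string"
#
#     def encode(self, audio):
#         return self.pre_processor(audio, return_tensors="pt")
#
#     def forward(self, inputs):
#         return self.model.generate(inputs["input_features"])
#
#     def decode(self, outputs):
#         return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]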
__all__ = [
    "AUTHORIZED_TYPES",

View File

@ -17,7 +17,7 @@ import pathlib
import tempfile
import uuid
from io import BytesIO

import requests
import numpy as np
from transformers.utils import (
@ -224,7 +224,12 @@ class AgentAudio(AgentType, str):
            return self._tensor

        if self._path is not None:
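            # A URL scheme in the path means the audio must be downloaded before decoding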
            if "://" in str(self._path):
                response = requests.get(self._path)
                response.raise_for_status()
                tensor, self.samplerate = sf.read(BytesIO(response.content))
            else:
                tensor, self.samplerate = sf.read(self._path)
            self._tensor = torch.tensor(tensor)
            return self._tensor