Get rid of necessity to declare tools in their own .py file

This commit is contained in:
Aymeric 2024-12-15 19:45:41 +01:00
parent aef0510e68
commit 0eb582bdba
3 changed files with 182 additions and 151 deletions

View File

@ -38,10 +38,10 @@ The custom tool needs:
The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema), they can be either of these: [`~AUTHORIZED_TYPES`]. The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema), they can be either of these: [`~AUTHORIZED_TYPES`].
Also, all imports must be placed inside the tool's `forward` method; otherwise you will get an error.
```python ```python
from transformers import Tool from transformers import Tool
from huggingface_hub import list_models
class HFModelDownloadsTool(Tool): class HFModelDownloadsTool(Tool):
name = "model_download_counter" name = "model_download_counter"
@ -58,26 +58,27 @@ class HFModelDownloadsTool(Tool):
output_type = "string" output_type = "string"
def forward(self, task: str): def forward(self, task: str):
from huggingface_hub import list_models
model = next(iter(list_models(filter=task, sort="downloads", direction=-1))) model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
return model.id return model.id
```
Now that the custom `HfModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use.
```python
from model_downloads import HFModelDownloadsTool
tool = HFModelDownloadsTool() tool = HFModelDownloadsTool()
``` ```
Now the custom `HFModelDownloadsTool` class is ready.
You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with write access. You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with write access.
```python ```python
tool.push_to_hub("{your_username}/hf-model-downloads") from dotenv import load_dotenv
load_dotenv()
tool.push_to_hub("m-ric/hf-model-downloads", token=os.getenv("HF_TOKEN"))
``` ```
Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent. Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
Since running tools means running custom code, you need to make sure you trust the repository, and pass `trust_remote_code=True`.
```python ```python
from transformers import load_tool, CodeAgent from transformers import load_tool, CodeAgent
@ -159,7 +160,7 @@ We love Langchain and think it has a very compelling suite of tools.
To import a tool from LangChain, use the `from_langchain()` method. To import a tool from LangChain, use the `from_langchain()` method.
Here is how you can use it to recreate the intro's search result using a LangChain web search tool. Here is how you can use it to recreate the intro's search result using a LangChain web search tool.
This tool will need `pip install google-search-results` to work properly. This tool will need `pip install langchain google-search-results -q` to work properly.
```python ```python
from langchain.agents import load_tools from langchain.agents import load_tools
from agents import Tool, CodeAgent from agents import Tool, CodeAgent
@ -191,7 +192,6 @@ agent.run(
) )
``` ```
| **Audio** | | **Audio** |
|------------------------------------------------------------------------------------------------------------------------------------------------------| |------------------------------------------------------------------------------------------------------------------------------------------------------|
| <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/damo.wav" type="audio/wav"/> | | <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/damo.wav" type="audio/wav"/> |

View File

@ -44,7 +44,6 @@ from transformers.dynamic_module_utils import (
) )
from transformers import AutoProcessor from transformers import AutoProcessor
from transformers.utils import ( from transformers.utils import (
CONFIG_NAME,
TypeHintParsingException, TypeHintParsingException,
cached_file, cached_file,
get_json_schema, get_json_schema,
@ -53,6 +52,8 @@ from transformers.utils import (
is_vision_available, is_vision_available,
) )
from .types import ImageType, handle_agent_inputs, handle_agent_outputs from .types import ImageType, handle_agent_inputs, handle_agent_outputs
from .utils import ImportFinder
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -104,7 +105,7 @@ def setup_default_tools():
# docstyle-ignore # docstyle-ignore
APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
from {module_name} import {class_name} from tool import {class_name}
launch_gradio_demo({class_name}) launch_gradio_demo({class_name})
""" """
@ -304,28 +305,44 @@ class Tool:
output_dir (`str`): The folder in which you want to save your tool. output_dir (`str`): The folder in which you want to save your tool.
""" """
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
# Save module file class_name = self.__class__.__name__
if self.__module__ == "__main__":
raise ValueError(
f"We can't save the code defining {self} in {output_dir} as it's been defined in __main__. You "
"have to put this code in a separate module so we can include it in the saved folder."
)
module_files = custom_object_save(self, output_dir)
module_name = self.__class__.__module__ # Save tool file
last_module = module_name.split(".")[-1] forward_source_code = inspect.getsource(self.forward)
full_name = f"{last_module}.{self.__class__.__name__}" tool_code = textwrap.dedent(f"""
from agents import Tool
class {class_name}(Tool):
name = "{self.name}"
description = "{self.description}"
inputs = {json.dumps(self.inputs, separators=(',', ':'))}
output_type = "{self.output_type}"
""").strip()
import re
def add_self_argument(source_code: str) -> str:
    """Add 'self' as first argument to a `forward` definition if not present.

    Args:
        source_code: Source text containing a `def forward(...)` signature.

    Returns:
        The source text where every `def forward(...)` whose first parameter
        is not `self` has `self` prepended to its parameter list. Signatures
        that already start with `self` are returned unchanged.
    """
    # `\s*self\b` (rather than a bare `self` prefix test) so that:
    #   - a parameter that merely starts with "self" (e.g. `selfie`) still
    #     gets `self` prepended, and
    #   - `def forward( self, x)` with leading whitespace is recognised as
    #     already having `self` instead of receiving a second one.
    pattern = r"def forward\(((?!\s*self\b)[^)]*)\)"

    def replacement(match):
        args = match.group(1).strip()
        if args:  # Keep the existing parameters after the inserted `self`
            return f"def forward(self, {args})"
        return "def forward(self)"

    return re.sub(pattern, replacement, source_code)
forward_source_code = forward_source_code.replace(self.name, "forward")
forward_source_code = add_self_argument(forward_source_code)
forward_source_code = forward_source_code.replace("@tool", "").strip()
tool_code += "\n\n" + textwrap.indent(forward_source_code, " ")
with open(os.path.join(output_dir, "tool.py"), "w", encoding="utf-8") as f:
f.write(tool_code)
# Save config file # Save config file
config_file = os.path.join(output_dir, "tool_config.json") config_file = os.path.join(output_dir, "tool_config.json")
if os.path.isfile(config_file):
with open(config_file, "r", encoding="utf-8") as f:
tool_config = json.load(f)
else:
tool_config = {}
tool_config = { tool_config = {
"tool_class": full_name, "tool_class": self.__class__.__name__,
"description": self.description, "description": self.description,
"name": self.name, "name": self.name,
"inputs": self.inputs, "inputs": self.inputs,
@ -339,131 +356,20 @@ class Tool:
with open(app_file, "w", encoding="utf-8") as f: with open(app_file, "w", encoding="utf-8") as f:
f.write( f.write(
APP_FILE_TEMPLATE.format( APP_FILE_TEMPLATE.format(
module_name=last_module, class_name=self.__class__.__name__ class_name=class_name
) )
) )
# Save requirements file # Save requirements file
requirements_file = os.path.join(output_dir, "requirements.txt") requirements_file = os.path.join(output_dir, "requirements.txt")
imports = []
for module in module_files: tree = ast.parse(forward_source_code)
imports.extend(get_imports(module)) import_finder = ImportFinder()
imports = list(set(imports)) import_finder.visit(tree)
imports = list(set(import_finder.packages))
with open(requirements_file, "w", encoding="utf-8") as f: with open(requirements_file, "w", encoding="utf-8") as f:
f.write("\n".join(imports) + "\n") f.write("agents_package\n" + "\n".join(imports) + "\n")
@classmethod
def from_hub(
cls,
repo_id: str,
token: Optional[str] = None,
**kwargs,
):
"""
Loads a tool defined on the Hub.
<Tip warning={true}>
Loading a tool from the Hub means that you'll download the tool and execute it locally.
ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
installing a package using pip/npm/apt.
</Tip>
Args:
repo_id (`str`):
The name of the repo on the Hub where your tool is defined.
token (`str`, *optional*):
The token to identify you on hf.co. If unset, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
kwargs (additional keyword arguments, *optional*):
Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
others will be passed along to its init.
"""
hub_kwargs_names = [
"cache_dir",
"force_download",
"resume_download",
"proxies",
"revision",
"repo_type",
"subfolder",
"local_files_only",
]
hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}
# Try to get the tool config first.
hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
resolved_config_file = cached_file(
repo_id,
TOOL_CONFIG_FILE,
token=token,
**hub_kwargs,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
is_tool_config = resolved_config_file is not None
if resolved_config_file is None:
resolved_config_file = cached_file(
repo_id,
CONFIG_NAME,
token=token,
**hub_kwargs,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
if resolved_config_file is None:
raise EnvironmentError(
f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
)
with open(resolved_config_file, encoding="utf-8") as reader:
config = json.load(reader)
if not is_tool_config:
if "custom_tool" not in config:
raise EnvironmentError(
f"{repo_id} does not provide a mapping to custom tools in its configuration `config.json`."
)
custom_tool = config["custom_tool"]
else:
custom_tool = config
tool_class = custom_tool["tool_class"]
tool_class = get_class_from_dynamic_module(
tool_class, repo_id, token=token, **hub_kwargs
)
if len(tool_class.name) == 0:
tool_class.name = custom_tool["name"]
if tool_class.name != custom_tool["name"]:
logger.warning(
f"{tool_class.__name__} implements a different name in its configuration and class. Using the tool "
"configuration name."
)
tool_class.name = custom_tool["name"]
if len(tool_class.description) == 0:
tool_class.description = custom_tool["description"]
if tool_class.description != custom_tool["description"]:
logger.warning(
f"{tool_class.__name__} implements a different description in its configuration and class. Using the "
"tool configuration description."
)
tool_class.description = custom_tool["description"]
if tool_class.inputs != custom_tool["inputs"]:
tool_class.inputs = custom_tool["inputs"]
if tool_class.output_type != custom_tool["output_type"]:
tool_class.output_type = custom_tool["output_type"]
if not isinstance(tool_class.inputs, dict):
tool_class.inputs = ast.literal_eval(tool_class.inputs)
return tool_class(**kwargs)
def push_to_hub( def push_to_hub(
self, self,
@ -512,6 +418,9 @@ class Tool:
with tempfile.TemporaryDirectory() as work_dir: with tempfile.TemporaryDirectory() as work_dir:
# Save all files. # Save all files.
self.save(work_dir) self.save(work_dir)
print(work_dir)
with open(work_dir + "/tool.py", "r") as f:
print('\n'.join(f.readlines()))
logger.info( logger.info(
f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}" f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
) )
@ -524,6 +433,110 @@ class Tool:
repo_type="space", repo_type="space",
) )
@classmethod
def from_hub(
    cls,
    repo_id: str,
    token: Optional[str] = None,
    trust_remote_code: bool = False,
    **kwargs,
):
    """
    Loads a tool defined on the Hub.

    <Tip warning={true}>

    Loading a tool from the Hub means that you'll download the tool and execute it locally.
    ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
    installing a package using pip/npm/apt.

    </Tip>

    Args:
        repo_id (`str`):
            The name of the repo on the Hub where your tool is defined.
        token (`str`, *optional*):
            The token to identify you on hf.co. If unset, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
        trust_remote_code (`bool`, *optional*, defaults to `False`):
            Must be set to `True` to load the tool: loading executes the `tool.py` file downloaded from the
            repo.
        kwargs (additional keyword arguments, *optional*):
            Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
            `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
            others will be passed along to its init.

    Raises:
        ValueError: If `trust_remote_code` is not set, or if the downloaded `tool.py` defines no `Tool`
            subclass.
        EnvironmentError: If `tool.py` cannot be retrieved from `repo_id`.
    """
    # Raise explicitly instead of asserting: asserts are stripped under `python -O`, which would
    # silently disable this safety gate.
    if not trust_remote_code:
        raise ValueError(
            "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
        )

    hub_kwargs_names = [
        "cache_dir",
        "force_download",
        "resume_download",
        "proxies",
        "revision",
        "repo_type",
        "subfolder",
        "local_files_only",
    ]
    hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}

    tool_file = "tool.py"

    # Get the tool's tool.py file.
    hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
    resolved_tool_file = cached_file(
        repo_id,
        tool_file,
        token=token,
        **hub_kwargs,
        _raise_exceptions_for_gated_repo=False,
        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,
    )
    if resolved_tool_file is None:
        raise EnvironmentError(
            f"{repo_id} does not appear to provide a valid tool definition in `{tool_file}`."
        )

    with open(resolved_tool_file, encoding="utf-8") as reader:
        tool_code = reader.read()

    with tempfile.TemporaryDirectory() as temp_dir:
        # Save the code to a file, using the same encoding it was read with
        module_path = os.path.join(temp_dir, "tool.py")
        with open(module_path, "w", encoding="utf-8") as f:
            f.write(tool_code)

        # Load module from file path (this executes the downloaded code)
        spec = importlib.util.spec_from_file_location("custom_tool", module_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        # Find the first Tool subclass exposed by the module (in `dir()` order)
        tool_class = None
        for item_name in dir(module):
            item = getattr(module, item_name)
            if isinstance(item, type) and issubclass(item, Tool) and item is not Tool:
                tool_class = item
                break

        if tool_class is None:
            raise ValueError("No Tool subclass found in the code")

    if not isinstance(tool_class.inputs, dict):
        tool_class.inputs = ast.literal_eval(tool_class.inputs)

    return tool_class(**kwargs)
@staticmethod @staticmethod
def from_space( def from_space(
space_id: str, space_id: str,
@ -967,7 +980,8 @@ def tool(tool_function: Callable) -> Tool:
raise TypeHintParsingException( raise TypeHintParsingException(
"Tool return type not found: make sure your function has a return type hint!" "Tool return type not found: make sure your function has a return type hint!"
) )
class_name = f"{parameters['name'].capitalize()}Tool" class_name = ''.join([el.title() for el in parameters['name'].split('_')])
if parameters["return"]["type"] == "object": if parameters["return"]["type"] == "object":
parameters["return"]["type"] = "any" parameters["return"]["type"] = "any"

View File

@ -17,6 +17,7 @@
import json import json
import re import re
from typing import Tuple, Dict, Union from typing import Tuple, Dict, Union
import ast
from transformers.utils.import_utils import _is_package_available from transformers.utils.import_utils import _is_package_available
@ -109,5 +110,21 @@ def truncate_content(
+ f"\n..._This content has been truncated to stay below {max_length} characters_...\n" + f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
+ content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :] + content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :]
) )
class ImportFinder(ast.NodeVisitor):
    """AST visitor that collects the top-level packages imported by some code.

    After `visit()` has been called on a parsed tree, `packages` holds the base
    name (the part before the first dot) of every absolutely-imported module,
    including imports nested inside function or class bodies.
    """

    def __init__(self):
        # Base package names found so far
        self.packages = set()

    def visit_Import(self, node):
        """Record the base package of each name in an `import a.b, c` statement."""
        for alias in node.names:
            # Get the base package name (before any dots)
            self.packages.add(alias.name.split(".")[0])

    def visit_ImportFrom(self, node):
        """Record the base package of an absolute `from a.b import c` statement."""
        # Relative imports (`from . import x`, `from .mod import y`) have
        # level > 0 and point inside the current package, not at an external
        # package, so they are not recorded.
        if node.level or not node.module:
            return
        # Get the base package name (before any dots)
        self.packages.add(node.module.split(".")[0])
__all__ = [] __all__ = []