Get rid of necessity to declare tools in their own .py file
This commit is contained in:
parent
aef0510e68
commit
0eb582bdba
|
@ -38,10 +38,10 @@ The custom tool needs:
|
|||
|
||||
The types for both `inputs` and `output_type` should be [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema); they can be any of these: [`~AUTHORIZED_TYPES`].
|
||||
|
||||
Also, all imports should be put within the tool's `forward` function; otherwise you will get an error.
|
||||
|
||||
```python
|
||||
from transformers import Tool
|
||||
from huggingface_hub import list_models
|
||||
|
||||
class HFModelDownloadsTool(Tool):
|
||||
name = "model_download_counter"
|
||||
|
@ -58,26 +58,27 @@ class HFModelDownloadsTool(Tool):
|
|||
output_type = "string"
|
||||
|
||||
def forward(self, task: str):
|
||||
from huggingface_hub import list_models
|
||||
|
||||
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
|
||||
return model.id
|
||||
```
|
||||
|
||||
Now that the custom `HFModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use.
|
||||
|
||||
|
||||
```python
|
||||
from model_downloads import HFModelDownloadsTool
|
||||
|
||||
tool = HFModelDownloadsTool()
|
||||
```
|
||||
|
||||
Now the custom `HFModelDownloadsTool` class is ready.
|
||||
|
||||
You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with write access.
|
||||
|
||||
```python
|
||||
tool.push_to_hub("{your_username}/hf-model-downloads")
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
tool.push_to_hub("m-ric/hf-model-downloads", token=os.getenv("HF_TOKEN"))
|
||||
```
|
||||
|
||||
Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
|
||||
Since running tools means running custom code, you need to make sure you trust the repository, and pass `trust_remote_code=True`.
|
||||
|
||||
```python
|
||||
from transformers import load_tool, CodeAgent
|
||||
|
@ -159,7 +160,7 @@ We love Langchain and think it has a very compelling suite of tools.
|
|||
To import a tool from LangChain, use the `from_langchain()` method.
|
||||
|
||||
Here is how you can use it to recreate the intro's search result using a LangChain web search tool.
|
||||
This tool will need `pip install google-search-results` to work properly.
|
||||
This tool will need `pip install langchain google-search-results -q` to work properly.
|
||||
```python
|
||||
from langchain.agents import load_tools
|
||||
from agents import Tool, CodeAgent
|
||||
|
@ -191,7 +192,6 @@ agent.run(
|
|||
)
|
||||
```
|
||||
|
||||
|
||||
| **Audio** |
|
||||
|------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/damo.wav" type="audio/wav"/> |
|
||||
|
|
|
@ -44,7 +44,6 @@ from transformers.dynamic_module_utils import (
|
|||
)
|
||||
from transformers import AutoProcessor
|
||||
from transformers.utils import (
|
||||
CONFIG_NAME,
|
||||
TypeHintParsingException,
|
||||
cached_file,
|
||||
get_json_schema,
|
||||
|
@ -53,6 +52,8 @@ from transformers.utils import (
|
|||
is_vision_available,
|
||||
)
|
||||
from .types import ImageType, handle_agent_inputs, handle_agent_outputs
|
||||
from .utils import ImportFinder
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -104,7 +105,7 @@ def setup_default_tools():
|
|||
|
||||
# docstyle-ignore
|
||||
APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
|
||||
from {module_name} import {class_name}
|
||||
from tool import {class_name}
|
||||
|
||||
launch_gradio_demo({class_name})
|
||||
"""
|
||||
|
@ -304,28 +305,44 @@ class Tool:
|
|||
output_dir (`str`): The folder in which you want to save your tool.
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
# Save module file
|
||||
if self.__module__ == "__main__":
|
||||
raise ValueError(
|
||||
f"We can't save the code defining {self} in {output_dir} as it's been defined in __main__. You "
|
||||
"have to put this code in a separate module so we can include it in the saved folder."
|
||||
)
|
||||
module_files = custom_object_save(self, output_dir)
|
||||
class_name = self.__class__.__name__
|
||||
|
||||
module_name = self.__class__.__module__
|
||||
last_module = module_name.split(".")[-1]
|
||||
full_name = f"{last_module}.{self.__class__.__name__}"
|
||||
# Save tool file
|
||||
forward_source_code = inspect.getsource(self.forward)
|
||||
tool_code = textwrap.dedent(f"""
|
||||
from agents import Tool
|
||||
|
||||
class {class_name}(Tool):
|
||||
name = "{self.name}"
|
||||
description = "{self.description}"
|
||||
inputs = {json.dumps(self.inputs, separators=(',', ':'))}
|
||||
output_type = "{self.output_type}"
|
||||
""").strip()
|
||||
|
||||
import re
|
||||
def add_self_argument(source_code: str) -> str:
    """Add 'self' as first argument to `def forward(...)` if not present.

    Args:
        source_code: Source text containing `def forward(...)` signatures.

    Returns:
        The source text where every `def forward(` signature whose first
        parameter is not `self` has `self` inserted as first parameter.
        Signatures that already start with `self` are left untouched.
    """
    # The lookahead rejects only a first parameter that is exactly `self`:
    # `\s*` tolerates whitespace after the opening parenthesis (otherwise
    # `def forward( self, x)` would get a second `self`), and `\b` keeps
    # parameters that merely start with "self" (e.g. `selfie`) eligible.
    pattern = r'def forward\(((?!\s*self\b)[^)]*)\)'

    def replacement(match):
        args = match.group(1).strip()
        if args:  # If there are other arguments
            return f'def forward(self, {args})'
        return 'def forward(self)'

    return re.sub(pattern, replacement, source_code)
|
||||
|
||||
forward_source_code = forward_source_code.replace(self.name, "forward")
|
||||
forward_source_code = add_self_argument(forward_source_code)
|
||||
forward_source_code = forward_source_code.replace("@tool", "").strip()
|
||||
tool_code += "\n\n" + textwrap.indent(forward_source_code, " ")
|
||||
with open(os.path.join(output_dir, "tool.py"), "w", encoding="utf-8") as f:
|
||||
f.write(tool_code)
|
||||
|
||||
# Save config file
|
||||
config_file = os.path.join(output_dir, "tool_config.json")
|
||||
if os.path.isfile(config_file):
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
tool_config = json.load(f)
|
||||
else:
|
||||
tool_config = {}
|
||||
|
||||
tool_config = {
|
||||
"tool_class": full_name,
|
||||
"tool_class": self.__class__.__name__,
|
||||
"description": self.description,
|
||||
"name": self.name,
|
||||
"inputs": self.inputs,
|
||||
|
@ -339,131 +356,20 @@ class Tool:
|
|||
with open(app_file, "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
APP_FILE_TEMPLATE.format(
|
||||
module_name=last_module, class_name=self.__class__.__name__
|
||||
class_name=class_name
|
||||
)
|
||||
)
|
||||
|
||||
# Save requirements file
|
||||
requirements_file = os.path.join(output_dir, "requirements.txt")
|
||||
imports = []
|
||||
for module in module_files:
|
||||
imports.extend(get_imports(module))
|
||||
imports = list(set(imports))
|
||||
|
||||
tree = ast.parse(forward_source_code)
|
||||
import_finder = ImportFinder()
|
||||
import_finder.visit(tree)
|
||||
|
||||
imports = list(set(import_finder.packages))
|
||||
with open(requirements_file, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(imports) + "\n")
|
||||
|
||||
@classmethod
def from_hub(
    cls,
    repo_id: str,
    token: Optional[str] = None,
    **kwargs,
):
    """
    Loads a tool defined on the Hub from its configuration file.

    <Tip warning={true}>

    Loading a tool from the Hub means that you'll download the tool and execute it locally.
    ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
    installing a package using pip/npm/apt.

    </Tip>

    Args:
        repo_id (`str`):
            The name of the repo on the Hub where your tool is defined.
        token (`str`, *optional*):
            The token to identify you on hf.co. If unset, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
        kwargs (additional keyword arguments, *optional*):
            Hub-related arguments (such as `cache_dir`, `revision`, `subfolder`) are used when downloading
            the files for your tool; the full set of keyword arguments is then handed to the tool's init.

    Raises:
        EnvironmentError: If neither `tool_config.json` nor `config.json` can be resolved, or if the
            fallback `config.json` has no `custom_tool` entry.
    """
    hub_option_names = {
        "cache_dir",
        "force_download",
        "resume_download",
        "proxies",
        "revision",
        "repo_type",
        "subfolder",
        "local_files_only",
    }
    hub_kwargs = {key: kwargs[key] for key in kwargs if key in hub_option_names}
    hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)

    def _resolve(filename):
        # Resolution failures come back as None instead of raising.
        return cached_file(
            repo_id,
            filename,
            token=token,
            **hub_kwargs,
            _raise_exceptions_for_gated_repo=False,
            _raise_exceptions_for_missing_entries=False,
            _raise_exceptions_for_connection_errors=False,
        )

    # Prefer the dedicated tool config; fall back to the generic config file.
    resolved_config_file = _resolve(TOOL_CONFIG_FILE)
    is_tool_config = resolved_config_file is not None
    if not is_tool_config:
        resolved_config_file = _resolve(CONFIG_NAME)
    if resolved_config_file is None:
        raise EnvironmentError(
            f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
        )

    with open(resolved_config_file, encoding="utf-8") as reader:
        config = json.load(reader)

    if is_tool_config:
        custom_tool = config
    else:
        if "custom_tool" not in config:
            raise EnvironmentError(
                f"{repo_id} does not provide a mapping to custom tools in its configuration `config.json`."
            )
        custom_tool = config["custom_tool"]

    tool_class = get_class_from_dynamic_module(
        custom_tool["tool_class"], repo_id, token=token, **hub_kwargs
    )

    # The configuration wins over the class attributes: empty values are filled
    # in silently, conflicting values are overridden with a warning.
    for attribute in ("name", "description"):
        configured = custom_tool[attribute]
        if len(getattr(tool_class, attribute)) == 0:
            setattr(tool_class, attribute, configured)
        if getattr(tool_class, attribute) != configured:
            logger.warning(
                f"{tool_class.__name__} implements a different {attribute} in its configuration and class. "
                f"Using the tool configuration {attribute}."
            )
            setattr(tool_class, attribute, configured)

    for attribute in ("inputs", "output_type"):
        configured = custom_tool[attribute]
        if getattr(tool_class, attribute) != configured:
            setattr(tool_class, attribute, configured)

    # `inputs` that is not already a dict is parsed as a Python literal.
    if not isinstance(tool_class.inputs, dict):
        tool_class.inputs = ast.literal_eval(tool_class.inputs)

    return tool_class(**kwargs)
|
||||
f.write("agents_package\n" + "\n".join(imports) + "\n")
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
|
@ -512,6 +418,9 @@ class Tool:
|
|||
with tempfile.TemporaryDirectory() as work_dir:
|
||||
# Save all files.
|
||||
self.save(work_dir)
|
||||
print(work_dir)
|
||||
with open(work_dir + "/tool.py", "r") as f:
|
||||
print('\n'.join(f.readlines()))
|
||||
logger.info(
|
||||
f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
|
||||
)
|
||||
|
@ -524,6 +433,110 @@ class Tool:
|
|||
repo_type="space",
|
||||
)
|
||||
|
||||
@classmethod
def from_hub(
    cls,
    repo_id: str,
    token: Optional[str] = None,
    trust_remote_code: bool = False,
    **kwargs,
):
    """
    Loads a tool defined on the Hub.

    <Tip warning={true}>

    Loading a tool from the Hub means that you'll download the tool and execute it locally.
    ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
    installing a package using pip/npm/apt.

    </Tip>

    Args:
        repo_id (`str`):
            The name of the repo on the Hub where your tool is defined.
        token (`str`, *optional*):
            The token to identify you on hf.co. If unset, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
        trust_remote_code (`bool`, *optional*, defaults to `False`):
            Must be set to `True` to load the tool, since the repo's `tool.py` is executed locally.
        kwargs (additional keyword arguments, *optional*):
            Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
            `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
            others will be passed along to its init.

    Raises:
        ValueError: If `trust_remote_code` is not `True`, or if `tool.py` defines no `Tool` subclass.
        EnvironmentError: If `tool.py` cannot be resolved in the repo or cannot be loaded as a module.
    """
    # Imported locally so this method does not depend on a module-level import.
    import importlib.util

    # Explicit raise rather than `assert`: assertions are stripped under `python -O`,
    # which would silently disable this safety gate.
    if not trust_remote_code:
        raise ValueError(
            "Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
        )

    hub_kwargs_names = [
        "cache_dir",
        "force_download",
        "resume_download",
        "proxies",
        "revision",
        "repo_type",
        "subfolder",
        "local_files_only",
    ]
    hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}

    tool_file = "tool.py"

    # Get the tool's tool.py file. Resolution failures come back as None.
    hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
    resolved_tool_file = cached_file(
        repo_id,
        tool_file,
        token=token,
        **hub_kwargs,
        _raise_exceptions_for_gated_repo=False,
        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,
    )
    if resolved_tool_file is None:
        raise EnvironmentError(
            f"{repo_id} does not appear to provide a valid tool definition in `{tool_file}`."
        )

    # Load the module straight from the resolved file: no temporary copy is needed,
    # and the module's source file still exists after this method returns.
    spec = importlib.util.spec_from_file_location("custom_tool", resolved_tool_file)
    if spec is None or spec.loader is None:
        raise EnvironmentError(
            f"Could not load `{tool_file}` from {repo_id} as a Python module."
        )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Find the first Tool subclass defined or imported in the module
    # (`dir()` yields names in alphabetical order).
    tool_class = None
    for item_name in dir(module):
        item = getattr(module, item_name)
        if isinstance(item, type) and issubclass(item, Tool) and item is not Tool:
            tool_class = item
            break

    if tool_class is None:
        raise ValueError("No Tool subclass found in the code")

    # `inputs` that is not already a dict is parsed as a Python literal.
    if not isinstance(tool_class.inputs, dict):
        tool_class.inputs = ast.literal_eval(tool_class.inputs)

    return tool_class(**kwargs)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_space(
|
||||
space_id: str,
|
||||
|
@ -967,7 +980,8 @@ def tool(tool_function: Callable) -> Tool:
|
|||
raise TypeHintParsingException(
|
||||
"Tool return type not found: make sure your function has a return type hint!"
|
||||
)
|
||||
class_name = f"{parameters['name'].capitalize()}Tool"
|
||||
class_name = ''.join([el.title() for el in parameters['name'].split('_')])
|
||||
|
||||
if parameters["return"]["type"] == "object":
|
||||
parameters["return"]["type"] = "any"
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
import json
|
||||
import re
|
||||
from typing import Tuple, Dict, Union
|
||||
import ast
|
||||
|
||||
from transformers.utils.import_utils import _is_package_available
|
||||
|
||||
|
@ -110,4 +111,20 @@ def truncate_content(
|
|||
+ content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :]
|
||||
)
|
||||
|
||||
class ImportFinder(ast.NodeVisitor):
    """Collect the top-level package names imported anywhere in an AST.

    After `visit(tree)`, `packages` holds the base name (the part before the
    first dot) of every absolute `import x` / `from x import y` statement
    found, including those nested inside functions or classes. Relative
    imports are skipped because they refer to sibling modules, not packages.
    """

    def __init__(self):
        # Base package names seen so far.
        self.packages = set()

    def visit_Import(self, node):
        for alias in node.names:
            # Get the base package name (before any dots)
            base_package = alias.name.split('.')[0]
            self.packages.add(base_package)

    def visit_ImportFrom(self, node):
        # `level > 0` marks a relative import (`from .x import y`); `module`
        # is None for `from . import y`. Neither names an installable package.
        if node.module and node.level == 0:
            # Get the base package name (before any dots)
            base_package = node.module.split('.')[0]
            self.packages.add(base_package)
|
||||
|
||||
__all__ = []
|
Loading…
Reference in New Issue