Consolidate pushing Tools to Hub

This commit is contained in:
Aymeric 2024-12-19 16:05:17 +01:00
parent 00b9a71453
commit 584ce8f363
17 changed files with 506 additions and 213 deletions

View File

@ -1,5 +1,5 @@
# Base Python image
FROM python:3.12-slim
FROM python:3.9-slim
# Set working directory
WORKDIR /app
@ -25,4 +25,7 @@ RUN pip install -e .
COPY server.py /app/server.py
# Expose the port your server will run on
EXPOSE 65432
CMD ["python", "/app/server.py"]

View File

@ -27,3 +27,20 @@ limitations under the License.
<h3 align="center">
<p>Run agents!
</h3>
W
<div class="flex justify-center">
<img
class="block dark:hidden"
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Agent_ManimCE.gif"
/>
<img
class="hidden dark:block"
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Agent_ManimCE.gif"
/>
</div>
To run Docker, run `docker build -t pyrunner:latest .`
This will use the Local Dockerfile to create your Docker image!

View File

@ -23,11 +23,11 @@ Here, we're going to see advanced tool usage.
> If you're new to `agents`, make sure to first read the main [agents documentation](./agents).
### Directly define a tool by subclassing Tool, and share it to the Hub
### Directly define a tool by subclassing Tool
Let's take again the tool example from main documentation, for which we had implemented a `tool` decorator.
Let's take again the tool example from the [quicktour](../quicktour), for which we had implemented a `@tool` decorator. The `tool` decorator is the standard format, but sometimes you need more: use several methods in a class for more clarity, or using additional class attributes.
If you need to add variation, like custom attributes for your tool, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass.
In this case, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass.
The custom tool needs:
- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name it `model_download_counter`.
@ -67,19 +67,28 @@ tool = HFModelDownloadsTool()
Now the custom `HfModelDownloadsTool` class is ready.
### Share your tool to the Hub
You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access.
```python
tool.push_to_hub("m-ric/hf-model-downloads", token="<YOUR_HUGGINGFACEHUB_API_TOKEN>")
tool.push_to_hub("{your_username}/hf-model-downloads", token="<YOUR_HUGGINGFACEHUB_API_TOKEN>")
```
Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
For the push to Hub to work, your tool will need to respect some rules:
- All method are self-contained, e.g. use variables that come either from their args,
- If you subclass the `__init__` method, you can give it no other argument than `self`. This is because arguments set during a specific tool instance's initialization are hard to track, which prevents from sharing them properly to the hub. And anyway, the idea of making a specific class is that you can already set class attributes for anything you need to hard-code (just set `your_variable=(...)` directly under the `class YourTool(Tool):` line). And of course you can still create a class attribute anywhere in your code by assigning stuff to `self.your_variable`.
Once your tool is pushed to Hub, you can load it with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
Since running tools means running custom code, you need to make sure you trust the repository, and pass `trust_remote_code=True`.
```python
from agents import load_tool, CodeAgent
model_download_tool = load_tool("m-ric/hf-model-downloads", trust_remote_code=True)
model_download_tool = load_tool(
"{your_username}/hf-model-downloads",
trust_remote_code=True
)
```
### Import a Space as a tool 🚀
@ -214,8 +223,4 @@ agent = CodeAgent(tools=[*image_tool_collection.tools], llm_engine=llm_engine, a
agent.run("Please draw me a picture of rivers and lakes.")
```
To speed up the start, tools are loaded only if called by the agent.
This gets you this image:
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png">
To speed up the start, tools are loaded only if called by the agent.

View File

@ -1,9 +1,8 @@
from agents.search import DuckDuckGoSearchTool
from agents.tools.search import DuckDuckGoSearchTool
from agents.docker_alternative import DockerPythonInterpreter
test = """
from agents.tools import Tool
from agents.tool import Tool
class DummyTool(Tool):
name = "echo"
@ -17,7 +16,6 @@ class DummyTool(Tool):
def forward(self, cmd: str) -> str:
return cmd
"""
container = DockerPythonInterpreter()
@ -30,10 +28,8 @@ breakpoint()
print("---------")
output = container.execute(test)
print(output)
output = container.execute("res = DummyTool(cmd='echo this'); print(res)")
output = container.execute("res = DummyTool(cmd='echo this'); print(res())")
print(output)
container.stop()

View File

@ -1,4 +1,4 @@
from agents.tools import Tool
from agents.tool import Tool
class DummyTool(Tool):

View File

@ -30,11 +30,12 @@ if TYPE_CHECKING:
from .local_python_executor import *
from .monitoring import *
from .prompts import *
from .search import *
from .tools import *
from .tools.search import *
from .tool import *
from .types import *
from .utils import *
else:
import sys

View File

@ -43,7 +43,7 @@ from .prompts import (
SYSTEM_PROMPT_PLAN,
)
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
from .tools import (
from .tool import (
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
Tool,
get_tool_description_with_args,

View File

@ -24,7 +24,7 @@ from huggingface_hub import hf_hub_download, list_spaces
from transformers.utils import is_offline_mode
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
from .tools import TOOL_CONFIG_FILE, Tool
from .tool import TOOL_CONFIG_FILE, Tool
def custom_print(*args):

View File

@ -3,7 +3,7 @@ from typing import List, Optional
import warnings
import socket
from agents.tools import Tool
from agents.tool import Tool
class DockerPythonInterpreter:
def __init__(self):

View File

@ -47,8 +47,10 @@ from transformers.utils import (
is_torch_available,
is_vision_available,
)
from transformers.dynamic_module_utils import get_imports
from .types import ImageType, handle_agent_inputs, handle_agent_outputs
from .utils import ImportFinder
from .utils import instance_to_source
from .tool_validation import validate_tool_attributes, MethodChecker
import logging
@ -97,14 +99,6 @@ def setup_default_tools():
return default_tools
# docstyle-ignore
APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
from tool import {class_name}
launch_gradio_demo({class_name})
"""
def validate_after_init(cls, do_validate_forward: bool = True):
original_init = cls.__init__
@ -117,112 +111,6 @@ def validate_after_init(cls, do_validate_forward: bool = True):
return cls
def validate_args_are_self_contained(source_code):
"""Validates that all names in forward method are properly defined.
In particular it will check that all imports are done within the function."""
print("CODDDD", source_code)
tree = ast.parse(textwrap.dedent(source_code))
# Get function arguments
func_node = tree.body[0]
arg_names = {arg.arg for arg in func_node.args.args} | {"kwargs"}
builtin_names = set(vars(builtins))
class NameChecker(ast.NodeVisitor):
def __init__(self):
self.undefined_names = set()
self.imports = {}
self.from_imports = {}
self.assigned_names = set()
def visit_Import(self, node):
"""Handle simple imports like 'import datetime'."""
for name in node.names:
actual_name = name.asname or name.name
self.imports[actual_name] = (name.name, actual_name)
def visit_ImportFrom(self, node):
"""Handle from imports like 'from datetime import datetime'."""
module = node.module or ""
for name in node.names:
actual_name = name.asname or name.name
self.from_imports[actual_name] = (module, name.name, actual_name)
def visit_Assign(self, node):
"""Track variable assignments."""
for target in node.targets:
if isinstance(target, ast.Name):
self.assigned_names.add(target.id)
self.visit(node.value)
def visit_AnnAssign(self, node):
"""Track annotated assignments."""
if isinstance(node.target, ast.Name):
self.assigned_names.add(node.target.id)
if node.value:
self.visit(node.value)
def _handle_for_target(self, target) -> Set[str]:
"""Extract all names from a for loop target."""
names = set()
if isinstance(target, ast.Name):
names.add(target.id)
elif isinstance(target, ast.Tuple):
for elt in target.elts:
if isinstance(elt, ast.Name):
names.add(elt.id)
return names
def visit_For(self, node):
"""Track for-loop target variables and handle enumerate specially."""
# Add names from the target
target_names = self._handle_for_target(node.target)
self.assigned_names.update(target_names)
# Special handling for enumerate
if (
isinstance(node.iter, ast.Call)
and isinstance(node.iter.func, ast.Name)
and node.iter.func.id == "enumerate"
):
# For enumerate, if we have "for i, x in enumerate(...)",
# both i and x should be marked as assigned
if isinstance(node.target, ast.Tuple):
for elt in node.target.elts:
if isinstance(elt, ast.Name):
self.assigned_names.add(elt.id)
# Visit the rest of the node
self.generic_visit(node)
def visit_Name(self, node):
if isinstance(node.ctx, ast.Load) and not (
node.id == "tool"
or node.id in builtin_names
or node.id in arg_names
or node.id == "self"
or node.id in self.assigned_names
):
if node.id not in self.from_imports and node.id not in self.imports:
self.undefined_names.add(node.id)
def visit_Attribute(self, node):
# Skip self.something
if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
self.generic_visit(node)
checker = NameChecker()
checker.visit(tree)
if checker.undefined_names:
raise ValueError(
f"""The following names in forward method are not defined: {', '.join(checker.undefined_names)}.
Make sure all imports and variables are self-contained within the method.
"""
)
AUTHORIZED_TYPES = [
"string",
"boolean",
@ -339,64 +227,79 @@ class Tool:
"""
os.makedirs(output_dir, exist_ok=True)
class_name = self.__class__.__name__
tool_file = os.path.join(output_dir, "tool.py")
# Save tool file
forward_source_code = inspect.getsource(self.forward)
validate_args_are_self_contained(forward_source_code)
tool_code = f"""
from agents import Tool
if type(self).__name__ == "SimpleTool":
# Check that imports are self-contained
forward_node = ast.parse(textwrap.dedent(inspect.getsource(self.forward)))
# If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code
method_checker = MethodChecker(set())
method_checker.visit(forward_node)
if len(method_checker.errors) > 0:
raise(ValueError("\n".join(method_checker.errors)))
class {class_name}(Tool):
name = "{self.name}"
description = \"\"\"{self.description}\"\"\"
inputs = {json.dumps(self.inputs, separators=(',', ':'))}
output_type = "{self.output_type}"
""".strip()
forward_source_code = inspect.getsource(self.forward)
tool_code = textwrap.dedent(f"""
from agents import Tool
def add_self_argument(source_code: str) -> str:
"""Add 'self' as first argument to a function definition if not present."""
pattern = r"def forward\(((?!self)[^)]*)\)"
class {class_name}(Tool):
name = "{self.name}"
description = "{self.description}"
inputs = {json.dumps(self.inputs, separators=(',', ':'))}
output_type = "{self.output_type}"
""").strip()
import re
def add_self_argument(source_code: str) -> str:
"""Add 'self' as first argument to a function definition if not present."""
pattern = r'def forward\(((?!self)[^)]*)\)'
def replacement(match):
args = match.group(1).strip()
if args: # If there are other arguments
return f'def forward(self, {args})'
return 'def forward(self)'
return re.sub(pattern, replacement, source_code)
def replacement(match):
args = match.group(1).strip()
if args: # If there are other arguments
return f"def forward(self, {args})"
return "def forward(self)"
forward_source_code = forward_source_code.replace(self.name, "forward")
forward_source_code = add_self_argument(forward_source_code)
forward_source_code = forward_source_code.replace("@tool", "").strip()
tool_code += "\n\n" + textwrap.indent(forward_source_code, " ")
with open(tool_file, "w", encoding="utf-8") as f:
f.write(tool_code)
else: # If the tool was not created by the @tool decorator, it was made by subclassing Tool
if type(self).__name__ in ["SpaceToolWrapper", "LangChainToolWrapper", "GradioToolWrapper"]:
raise ValueError(
f"Cannot save objects created with from_space, from_langchain or from_gradio, as this would create errors."
)
return re.sub(pattern, replacement, source_code)
validate_tool_attributes(self.__class__)
forward_source_code = forward_source_code.replace(self.name, "forward")
forward_source_code = add_self_argument(forward_source_code)
forward_source_code = forward_source_code.replace("@tool", "").strip()
tool_code += "\n\n" + textwrap.indent(forward_source_code, " ")
with open(os.path.join(output_dir, "tool.py"), "w", encoding="utf-8") as f:
f.write(tool_code)
# Save config file
config_file = os.path.join(output_dir, "tool_config.json")
tool_config = {
"tool_class": self.__class__.__name__,
"description": self.description,
"name": self.name,
"inputs": self.inputs,
"output_type": str(self.output_type),
}
with open(config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")
tool_code = instance_to_source(self, base_cls=Tool)
with open(tool_file, "w", encoding="utf-8") as f:
f.write(tool_code)
# Save app file
app_file = os.path.join(output_dir, "app.py")
with open(app_file, "w", encoding="utf-8") as f:
f.write(APP_FILE_TEMPLATE.format(class_name=class_name))
f.write(textwrap.dedent(f"""
from agents import launch_gradio_demo
from tool import {class_name}
tool = {class_name}()
launch_gradio_demo(tool)
""").lstrip())
# Save requirements file
requirements_file = os.path.join(output_dir, "requirements.txt")
tree = ast.parse(forward_source_code)
import_finder = ImportFinder()
import_finder.visit(tree)
imports = list(set(import_finder.packages))
imports = []
for module in [tool_file]:
imports.extend(get_imports(module))
imports = list(set(imports))
with open(requirements_file, "w", encoding="utf-8") as f:
f.write("agents_package\n" + "\n".join(imports) + "\n")
@ -612,7 +515,6 @@ class {class_name}(Tool):
```
"""
from gradio_client import Client, handle_file
from gradio_client.utils import is_http_url_like
class SpaceToolWrapper(Tool):
def __init__(
@ -665,6 +567,7 @@ class {class_name}(Tool):
self.is_initialized = True
def sanitize_argument_for_prediction(self, arg):
from gradio_client.utils import is_http_url_like
if isinstance(arg, ImageType):
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
arg.save(temp_file.name)
@ -793,13 +696,13 @@ def compile_jinja_template(template):
return jinja_env.from_string(template)
def launch_gradio_demo(tool_class: Tool):
def launch_gradio_demo(tool: Tool):
"""
Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
`inputs` and `output_type`.
Args:
tool_class (`type`): The class of the tool for which to launch the demo.
tool (`type`): The tool for which to launch the demo.
"""
try:
import gradio as gr
@ -808,11 +711,6 @@ def launch_gradio_demo(tool_class: Tool):
"Gradio should be installed in order to launch a gradio demo."
)
tool = tool_class()
def fn(*args, **kwargs):
return tool(*args, **kwargs)
TYPE_TO_COMPONENT_CLASS_MAPPING = {
"image": gr.Image,
"audio": gr.Audio,
@ -822,7 +720,7 @@ def launch_gradio_demo(tool_class: Tool):
}
gradio_inputs = []
for input_name, input_details in tool_class.inputs.items():
for input_name, input_details in tool.inputs.items():
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
input_details["type"]
]
@ -830,15 +728,15 @@ def launch_gradio_demo(tool_class: Tool):
gradio_inputs.append(new_component)
output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[
tool_class.output_type
tool.output_type
]
gradio_output = output_gradio_componentclass(label=input_name)
gradio_output = output_gradio_componentclass(label="Output")
gr.Interface(
fn=fn,
fn=tool, # This works because `tool` has a __call__ method
inputs=gradio_inputs,
outputs=gradio_output,
title=tool_class.__name__,
title=tool.name,
article=tool.description,
).launch()
@ -1027,29 +925,34 @@ def tool(tool_function: Callable) -> Tool:
raise TypeHintParsingException(
"Tool return type not found: make sure your function has a return type hint!"
)
class_name = "".join([el.title() for el in parameters["name"].split("_")])
if parameters["return"]["type"] == "object":
parameters["return"]["type"] = "any"
class SpecificTool(Tool):
name = parameters["name"]
description = parameters["description"]
inputs = parameters["parameters"]["properties"]
output_type = parameters["return"]["type"]
@wraps(tool_function)
def forward(self, *args, **kwargs):
return tool_function(*args, **kwargs)
class SimpleTool(Tool):
def __init__(self, name, description, inputs, output_type, function):
self.name = name
self.description = description
self.inputs = inputs
self.output_type = output_type
self.forward = function
self.is_initialized = True
simple_tool = SimpleTool(
parameters["name"],
parameters["description"],
parameters["parameters"]["properties"],
parameters["return"]["type"],
function=tool_function
)
original_signature = inspect.signature(tool_function)
new_parameters = [
inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
] + list(original_signature.parameters.values())
new_signature = original_signature.replace(parameters=new_parameters)
SpecificTool.forward.__signature__ = new_signature
SpecificTool.__name__ = class_name
return SpecificTool()
simple_tool.forward.__signature__ = new_signature
# SimpleTool.__name__ = "".join([el.title() for el in parameters["name"].split("_")])
return simple_tool
HUGGINGFACE_DEFAULT_TOOLS = {}

View File

@ -0,0 +1,191 @@
import ast
import inspect
import importlib.util
import builtins
from pathlib import Path
from typing import List, Set, Dict
import textwrap
_BUILTIN_NAMES = set(vars(builtins))
def is_local_import(module_name: str) -> bool:
"""
Check if an import is from a local file or a package.
Returns True if it's a local file import.
"""
try:
spec = importlib.util.find_spec(module_name)
if spec is None:
return True # If we can't find the module, assume it's local
# If the module is found and has a file path, check if it's in site-packages
if spec.origin and 'site-packages' not in spec.origin:
# Check if it's a .py file in the current directory or subdirectories
return spec.origin.endswith('.py')
return False
except ImportError:
return True # If there's an import error, assume it's local
class MethodChecker(ast.NodeVisitor):
"""
Checks that a method
- only uses defined names
- contains no local imports (e.g. numpy is ok but local_script is not)
"""
def __init__(self, class_attributes: Set[str]):
self.undefined_names = set()
self.imports = {}
self.from_imports = {}
self.assigned_names = set()
self.arg_names = set()
self.class_attributes = class_attributes
self.errors = []
def visit_arguments(self, node):
"""Collect function arguments"""
self.arg_names = {arg.arg for arg in node.args}
if node.kwarg:
self.arg_names.add(node.kwarg.arg)
if node.vararg:
self.arg_names.add(node.vararg.arg)
def visit_Import(self, node):
for name in node.names:
actual_name = name.asname or name.name
if is_local_import(actual_name):
self.errors.append(f"Local import '{actual_name}'")
self.imports[actual_name] = name.name
def visit_ImportFrom(self, node):
module = node.module or ""
for name in node.names:
actual_name = name.asname or name.name
if is_local_import(module):
self.errors.append(f"Local import '{module}'")
self.from_imports[actual_name] = (module, name.name)
def visit_Assign(self, node):
for target in node.targets:
if isinstance(target, ast.Name):
self.assigned_names.add(target.id)
self.visit(node.value)
def visit_AnnAssign(self, node):
"""Track annotated assignments."""
if isinstance(node.target, ast.Name):
self.assigned_names.add(node.target.id)
if node.value:
self.visit(node.value)
def visit_For(self, node):
target = node.target
if isinstance(target, ast.Name):
self.assigned_names.add(target.id)
elif isinstance(target, ast.Tuple):
for elt in target.elts:
if isinstance(elt, ast.Name):
self.assigned_names.add(elt.id)
self.generic_visit(node)
def visit_Attribute(self, node):
# Skip self.something
if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
self.generic_visit(node)
def visit_Name(self, node):
if isinstance(node.ctx, ast.Load):
if not (
node.id in _BUILTIN_NAMES
or node.id in self.arg_names
or node.id == "self"
or node.id in self.class_attributes
or node.id in self.imports
or node.id in self.from_imports
or node.id in self.assigned_names
):
self.errors.append(f"Name '{node.id}' is undefined.")
def visit_Call(self, node):
if isinstance(node.func, ast.Name):
if not (
node.func.id in _BUILTIN_NAMES
or node.func.id in self.arg_names
or node.func.id == "self"
or node.func.id in self.class_attributes
or node.func.id in self.imports
or node.func.id in self.from_imports
or node.func.id in self.assigned_names
):
self.errors.append(f"Name '{node.func.id}' is undefined.")
self.generic_visit(node)
def validate_tool_attributes(cls) -> None:
"""
Validates that a Tool class follows the proper patterns:
0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!).
1. About the class:
- Class attributes should only be strings or dicts
- Class attributes cannot be complex attributes
2. About all class methods:
- Imports must be from packages, not local files
- All methods must be self-contained
Raises all errors encountered, if no error returns None.
"""
errors = []
source = textwrap.dedent(inspect.getsource(cls))
tree = ast.parse(source)
if not isinstance(tree.body[0], ast.ClassDef):
raise ValueError("Source code must define a class")
# Check that __init__ method takes no arguments
if not cls.__init__.__qualname__ == 'Tool.__init__':
sig = inspect.signature(cls.__init__)
non_self_params = list([arg_name for arg_name in sig.parameters.keys() if arg_name != "self"])
if len(non_self_params) > 0:
errors.append(f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!")
class_node = tree.body[0]
class ClassLevelChecker(ast.NodeVisitor):
def __init__(self):
self.imported_names = set()
self.complex_attributes = set()
self.class_attributes = set()
def visit_Assign(self, node):
# Track class attributes
for target in node.targets:
if isinstance(target, ast.Name):
self.class_attributes.add(target.id)
# Check if the assignment is more complex than simple literals
if not all(isinstance(val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set))
for val in ast.walk(node.value)):
for target in node.targets:
if isinstance(target, ast.Name):
self.complex_attributes.add(target.id)
class_level_checker = ClassLevelChecker()
class_level_checker.visit(class_node)
if class_level_checker.complex_attributes:
errors.append(
f"Complex attributes should be defined in __init__, not as class attributes: "
f"{', '.join(class_level_checker.complex_attributes)}"
)
# Run checks on all methods
for node in class_node.body:
if isinstance(node, ast.FunctionDef):
method_checker = MethodChecker(class_level_checker.class_attributes)
method_checker.visit(node)
errors += [f"- {node.name}: {error}" for error in method_checker.errors]
if errors:
raise ValueError("Tool validation failed:\n" + "\n".join(errors))
return

View File

View File

@ -19,7 +19,7 @@ import re
import requests
from requests.exceptions import RequestException
from .tools import Tool
from ..tools import Tool
class DuckDuckGoSearchTool(Tool):

View File

@ -19,6 +19,9 @@ import re
from typing import Tuple, Dict, Union
import ast
from rich.console import Console
import ast
import inspect
import types
from transformers.utils.import_utils import _is_package_available
@ -127,5 +130,114 @@ class ImportFinder(ast.NodeVisitor):
base_package = node.module.split(".")[0]
self.packages.add(base_package)
import ast
import builtins
from typing import Set, Dict, List
def get_method_source(method):
"""Get source code for a method, including bound methods."""
if isinstance(method, types.MethodType):
method = method.__func__
return inspect.getsource(method).strip()
def is_same_method(method1, method2):
"""Compare two methods by their source code."""
try:
source1 = get_method_source(method1)
source2 = get_method_source(method2)
# Remove method decorators if any
source1 = '\n'.join(line for line in source1.split('\n')
if not line.strip().startswith('@'))
source2 = '\n'.join(line for line in source2.split('\n')
if not line.strip().startswith('@'))
return source1 == source2
except (TypeError, OSError):
return False
def is_same_item(item1, item2):
"""Compare two class items (methods or attributes) for equality."""
if callable(item1) and callable(item2):
return is_same_method(item1, item2)
else:
return item1 == item2
def instance_to_source(instance, base_cls=None):
"""Convert an instance to its class source code representation."""
cls = instance.__class__
class_name = cls.__name__
# Start building class lines
class_lines = []
if base_cls:
class_lines.append(f"class {class_name}({base_cls.__name__}):")
else:
class_lines.append(f"class {class_name}:")
# Add docstring if it exists and differs from base
if cls.__doc__ and (not base_cls or cls.__doc__ != base_cls.__doc__):
class_lines.append(f' """{cls.__doc__}"""')
# Add class-level attributes
class_attrs = {
name: value for name, value in cls.__dict__.items()
if not name.startswith('__') and not callable(value) and
not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value)
}
for name, value in class_attrs.items():
if isinstance(value, str):
class_lines.append(f' {name} = "{value}"')
else:
class_lines.append(f' {name} = {repr(value)}')
if class_attrs:
class_lines.append("")
# Add methods
methods = {
name: func for name, func in cls.__dict__.items()
if callable(func) and
not (base_cls and hasattr(base_cls, name) and
getattr(base_cls, name).__code__.co_code == func.__code__.co_code)
}
for name, method in methods.items():
method_source = inspect.getsource(method)
# Clean up the indentation
method_lines = method_source.split('\n')
first_line = method_lines[0]
indent = len(first_line) - len(first_line.lstrip())
method_lines = [line[indent:] for line in method_lines]
method_source = '\n'.join([' ' + line if line.strip() else line
for line in method_lines])
class_lines.append(method_source)
class_lines.append("")
# Find required imports using ImportFinder
import_finder = ImportFinder()
import_finder.visit(ast.parse('\n'.join(class_lines)))
required_imports = import_finder.packages
# Build final code with imports
final_lines = []
# Add base class import if needed
if base_cls:
final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}")
# Add discovered imports
final_lines.extend(required_imports)
if final_lines: # Add empty line after imports
final_lines.append("")
# Add the class code
final_lines.extend(class_lines)
return '\n'.join(final_lines)
__all__ = []

View File

@ -29,7 +29,7 @@ from agents.agents import (
Toolbox,
ToolCall,
)
from agents.tools import tool
from agents.tool import tool
from agents.default_tools import PythonInterpreterTool
from transformers.testing_utils import get_tests_dir

View File

@ -125,7 +125,7 @@ class TestDocs:
"from_langchain",
]
code_blocks = [
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token)
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token).replace("{your_username}", "m-ric")
for block in code_blocks
if not any(
[snippet in block for snippet in excluded_snippets]

View File

@ -26,7 +26,7 @@ from agents.types import (
AgentImage,
AgentText,
)
from agents.tools import Tool, tool, AUTHORIZED_TYPES
from agents.tool import Tool, tool, AUTHORIZED_TYPES
from transformers.testing_utils import get_tests_dir
@ -174,6 +174,8 @@ class ToolTests(unittest.TestCase):
Gets the current time.
"""
return str(datetime.now())
get_current_time.save("output")
assert "datetime" in str(e)
@ -188,6 +190,9 @@ class ToolTests(unittest.TestCase):
def forward(self):
return str(datetime.now())
get_current_time = GetCurrentTimeTool()
get_current_time.save("output")
assert "datetime" in str(e)
@ -210,3 +215,63 @@ class ToolTests(unittest.TestCase):
def forward(self):
from datetime import datetime
return str(datetime.now())
def test_saving_tool_allows_no_arg_in_init(self):
# Test one cannot save tool with additional args in init
class FailTool(Tool):
name = "specific"
description = "test description"
inputs = {"input_str": {"type": "string", "description": "input description"}}
output_type = "string"
def __init__(self, url):
super().__init__(self)
self.url = "none"
def forward(self, string_input):
return self.url + string_input
fail_tool = FailTool("dummy_url")
with pytest.raises(Exception) as e:
fail_tool.save('output')
assert '__init__' in str(e)
def test_saving_tool_allows_no_imports_from_outside_methods(self):
# Test that using imports from outside functions fails
from numpy import random
class FailTool2(Tool):
name = "specific"
description = "test description"
inputs = {"input_str": {"type": "string", "description": "input description"}}
output_type = "string"
def useless_method(self):
self.client = random.random()
return ""
def forward(self, string_input):
return self.useless_method() + string_input
fail_tool_2 = FailTool2()
with pytest.raises(Exception) as e:
fail_tool_2.save('output')
assert 'random' in str(e)
# Test that putting these imports inside functions works
class FailTool3(Tool):
name = "specific"
description = "test description"
inputs = {"input_str": {"type": "string", "description": "input description"}}
output_type = "string"
def useless_method(self):
from numpy import random
self.client = random.random()
return ""
def forward(self, string_input):
return self.useless_method() + string_input
fail_tool_3 = FailTool3()
fail_tool_3.save('output')