Consolidate pushing Tools to Hub
This commit is contained in:
parent
00b9a71453
commit
584ce8f363
|
@ -1,5 +1,5 @@
|
|||
# Base Python image
|
||||
FROM python:3.12-slim
|
||||
FROM python:3.9-slim
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
@ -25,4 +25,7 @@ RUN pip install -e .
|
|||
|
||||
COPY server.py /app/server.py
|
||||
|
||||
# Expose the port your server will run on
|
||||
EXPOSE 65432
|
||||
|
||||
CMD ["python", "/app/server.py"]
|
||||
|
|
17
README.md
17
README.md
|
@ -27,3 +27,20 @@ limitations under the License.
|
|||
<h3 align="center">
|
||||
<p>Run agents!
|
||||
</h3>
|
||||
|
||||
W
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img
|
||||
class="block dark:hidden"
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Agent_ManimCE.gif"
|
||||
/>
|
||||
<img
|
||||
class="hidden dark:block"
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Agent_ManimCE.gif"
|
||||
/>
|
||||
</div>
|
||||
|
||||
To run Docker, run `docker build -t pyrunner:latest .`
|
||||
|
||||
This will use the Local Dockerfile to create your Docker image!
|
|
@ -23,11 +23,11 @@ Here, we're going to see advanced tool usage.
|
|||
> If you're new to `agents`, make sure to first read the main [agents documentation](./agents).
|
||||
|
||||
|
||||
### Directly define a tool by subclassing Tool, and share it to the Hub
|
||||
### Directly define a tool by subclassing Tool
|
||||
|
||||
Let's take again the tool example from main documentation, for which we had implemented a `tool` decorator.
|
||||
Let's take again the tool example from the [quicktour](../quicktour), for which we had implemented a `@tool` decorator. The `tool` decorator is the standard format, but sometimes you need more: use several methods in a class for more clarity, or using additional class attributes.
|
||||
|
||||
If you need to add variation, like custom attributes for your tool, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass.
|
||||
In this case, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass.
|
||||
|
||||
The custom tool needs:
|
||||
- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name it `model_download_counter`.
|
||||
|
@ -67,19 +67,28 @@ tool = HFModelDownloadsTool()
|
|||
|
||||
Now the custom `HfModelDownloadsTool` class is ready.
|
||||
|
||||
### Share your tool to the Hub
|
||||
|
||||
You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access.
|
||||
|
||||
```python
|
||||
tool.push_to_hub("m-ric/hf-model-downloads", token="<YOUR_HUGGINGFACEHUB_API_TOKEN>")
|
||||
tool.push_to_hub("{your_username}/hf-model-downloads", token="<YOUR_HUGGINGFACEHUB_API_TOKEN>")
|
||||
```
|
||||
|
||||
Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
|
||||
For the push to Hub to work, your tool will need to respect some rules:
|
||||
- All method are self-contained, e.g. use variables that come either from their args,
|
||||
- If you subclass the `__init__` method, you can give it no other argument than `self`. This is because arguments set during a specific tool instance's initialization are hard to track, which prevents from sharing them properly to the hub. And anyway, the idea of making a specific class is that you can already set class attributes for anything you need to hard-code (just set `your_variable=(...)` directly under the `class YourTool(Tool):` line). And of course you can still create a class attribute anywhere in your code by assigning stuff to `self.your_variable`.
|
||||
|
||||
Once your tool is pushed to Hub, you can load it with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
|
||||
Since running tools means running custom code, you need to make sure you trust the repository, and pass `trust_remote_code=True`.
|
||||
|
||||
```python
|
||||
from agents import load_tool, CodeAgent
|
||||
|
||||
model_download_tool = load_tool("m-ric/hf-model-downloads", trust_remote_code=True)
|
||||
model_download_tool = load_tool(
|
||||
"{your_username}/hf-model-downloads",
|
||||
trust_remote_code=True
|
||||
)
|
||||
```
|
||||
|
||||
### Import a Space as a tool 🚀
|
||||
|
@ -214,8 +223,4 @@ agent = CodeAgent(tools=[*image_tool_collection.tools], llm_engine=llm_engine, a
|
|||
agent.run("Please draw me a picture of rivers and lakes.")
|
||||
```
|
||||
|
||||
To speed up the start, tools are loaded only if called by the agent.
|
||||
|
||||
This gets you this image:
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png">
|
||||
To speed up the start, tools are loaded only if called by the agent.
|
|
@ -1,9 +1,8 @@
|
|||
from agents.search import DuckDuckGoSearchTool
|
||||
from agents.tools.search import DuckDuckGoSearchTool
|
||||
from agents.docker_alternative import DockerPythonInterpreter
|
||||
|
||||
|
||||
test = """
|
||||
from agents.tools import Tool
|
||||
from agents.tool import Tool
|
||||
|
||||
class DummyTool(Tool):
|
||||
name = "echo"
|
||||
|
@ -17,7 +16,6 @@ class DummyTool(Tool):
|
|||
def forward(self, cmd: str) -> str:
|
||||
return cmd
|
||||
|
||||
"""
|
||||
|
||||
container = DockerPythonInterpreter()
|
||||
|
||||
|
@ -30,10 +28,8 @@ breakpoint()
|
|||
|
||||
print("---------")
|
||||
|
||||
output = container.execute(test)
|
||||
print(output)
|
||||
|
||||
output = container.execute("res = DummyTool(cmd='echo this'); print(res)")
|
||||
output = container.execute("res = DummyTool(cmd='echo this'); print(res())")
|
||||
print(output)
|
||||
|
||||
container.stop()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from agents.tools import Tool
|
||||
from agents.tool import Tool
|
||||
|
||||
|
||||
class DummyTool(Tool):
|
||||
|
|
|
@ -30,11 +30,12 @@ if TYPE_CHECKING:
|
|||
from .local_python_executor import *
|
||||
from .monitoring import *
|
||||
from .prompts import *
|
||||
from .search import *
|
||||
from .tools import *
|
||||
from .tools.search import *
|
||||
from .tool import *
|
||||
from .types import *
|
||||
from .utils import *
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ from .prompts import (
|
|||
SYSTEM_PROMPT_PLAN,
|
||||
)
|
||||
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
|
||||
from .tools import (
|
||||
from .tool import (
|
||||
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
Tool,
|
||||
get_tool_description_with_args,
|
||||
|
|
|
@ -24,7 +24,7 @@ from huggingface_hub import hf_hub_download, list_spaces
|
|||
|
||||
from transformers.utils import is_offline_mode
|
||||
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
|
||||
from .tools import TOOL_CONFIG_FILE, Tool
|
||||
from .tool import TOOL_CONFIG_FILE, Tool
|
||||
|
||||
|
||||
def custom_print(*args):
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import List, Optional
|
|||
import warnings
|
||||
import socket
|
||||
|
||||
from agents.tools import Tool
|
||||
from agents.tool import Tool
|
||||
|
||||
class DockerPythonInterpreter:
|
||||
def __init__(self):
|
||||
|
|
|
@ -47,8 +47,10 @@ from transformers.utils import (
|
|||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.dynamic_module_utils import get_imports
|
||||
from .types import ImageType, handle_agent_inputs, handle_agent_outputs
|
||||
from .utils import ImportFinder
|
||||
from .utils import instance_to_source
|
||||
from .tool_validation import validate_tool_attributes, MethodChecker
|
||||
|
||||
import logging
|
||||
|
||||
|
@ -97,14 +99,6 @@ def setup_default_tools():
|
|||
return default_tools
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
|
||||
from tool import {class_name}
|
||||
|
||||
launch_gradio_demo({class_name})
|
||||
"""
|
||||
|
||||
|
||||
def validate_after_init(cls, do_validate_forward: bool = True):
|
||||
original_init = cls.__init__
|
||||
|
||||
|
@ -117,112 +111,6 @@ def validate_after_init(cls, do_validate_forward: bool = True):
|
|||
return cls
|
||||
|
||||
|
||||
def validate_args_are_self_contained(source_code):
|
||||
"""Validates that all names in forward method are properly defined.
|
||||
In particular it will check that all imports are done within the function."""
|
||||
print("CODDDD", source_code)
|
||||
tree = ast.parse(textwrap.dedent(source_code))
|
||||
|
||||
# Get function arguments
|
||||
func_node = tree.body[0]
|
||||
arg_names = {arg.arg for arg in func_node.args.args} | {"kwargs"}
|
||||
|
||||
builtin_names = set(vars(builtins))
|
||||
|
||||
class NameChecker(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
self.undefined_names = set()
|
||||
self.imports = {}
|
||||
self.from_imports = {}
|
||||
self.assigned_names = set()
|
||||
|
||||
def visit_Import(self, node):
|
||||
"""Handle simple imports like 'import datetime'."""
|
||||
for name in node.names:
|
||||
actual_name = name.asname or name.name
|
||||
self.imports[actual_name] = (name.name, actual_name)
|
||||
|
||||
def visit_ImportFrom(self, node):
|
||||
"""Handle from imports like 'from datetime import datetime'."""
|
||||
module = node.module or ""
|
||||
for name in node.names:
|
||||
actual_name = name.asname or name.name
|
||||
self.from_imports[actual_name] = (module, name.name, actual_name)
|
||||
|
||||
def visit_Assign(self, node):
|
||||
"""Track variable assignments."""
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
self.assigned_names.add(target.id)
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_AnnAssign(self, node):
|
||||
"""Track annotated assignments."""
|
||||
if isinstance(node.target, ast.Name):
|
||||
self.assigned_names.add(node.target.id)
|
||||
if node.value:
|
||||
self.visit(node.value)
|
||||
|
||||
def _handle_for_target(self, target) -> Set[str]:
|
||||
"""Extract all names from a for loop target."""
|
||||
names = set()
|
||||
if isinstance(target, ast.Name):
|
||||
names.add(target.id)
|
||||
elif isinstance(target, ast.Tuple):
|
||||
for elt in target.elts:
|
||||
if isinstance(elt, ast.Name):
|
||||
names.add(elt.id)
|
||||
return names
|
||||
|
||||
def visit_For(self, node):
|
||||
"""Track for-loop target variables and handle enumerate specially."""
|
||||
# Add names from the target
|
||||
target_names = self._handle_for_target(node.target)
|
||||
self.assigned_names.update(target_names)
|
||||
|
||||
# Special handling for enumerate
|
||||
if (
|
||||
isinstance(node.iter, ast.Call)
|
||||
and isinstance(node.iter.func, ast.Name)
|
||||
and node.iter.func.id == "enumerate"
|
||||
):
|
||||
# For enumerate, if we have "for i, x in enumerate(...)",
|
||||
# both i and x should be marked as assigned
|
||||
if isinstance(node.target, ast.Tuple):
|
||||
for elt in node.target.elts:
|
||||
if isinstance(elt, ast.Name):
|
||||
self.assigned_names.add(elt.id)
|
||||
|
||||
# Visit the rest of the node
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Name(self, node):
|
||||
if isinstance(node.ctx, ast.Load) and not (
|
||||
node.id == "tool"
|
||||
or node.id in builtin_names
|
||||
or node.id in arg_names
|
||||
or node.id == "self"
|
||||
or node.id in self.assigned_names
|
||||
):
|
||||
if node.id not in self.from_imports and node.id not in self.imports:
|
||||
self.undefined_names.add(node.id)
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
# Skip self.something
|
||||
if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
|
||||
self.generic_visit(node)
|
||||
|
||||
checker = NameChecker()
|
||||
checker.visit(tree)
|
||||
|
||||
if checker.undefined_names:
|
||||
raise ValueError(
|
||||
f"""The following names in forward method are not defined: {', '.join(checker.undefined_names)}.
|
||||
Make sure all imports and variables are self-contained within the method.
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
AUTHORIZED_TYPES = [
|
||||
"string",
|
||||
"boolean",
|
||||
|
@ -339,64 +227,79 @@ class Tool:
|
|||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
class_name = self.__class__.__name__
|
||||
tool_file = os.path.join(output_dir, "tool.py")
|
||||
|
||||
# Save tool file
|
||||
forward_source_code = inspect.getsource(self.forward)
|
||||
validate_args_are_self_contained(forward_source_code)
|
||||
tool_code = f"""
|
||||
from agents import Tool
|
||||
if type(self).__name__ == "SimpleTool":
|
||||
# Check that imports are self-contained
|
||||
forward_node = ast.parse(textwrap.dedent(inspect.getsource(self.forward)))
|
||||
# If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code
|
||||
method_checker = MethodChecker(set())
|
||||
method_checker.visit(forward_node)
|
||||
if len(method_checker.errors) > 0:
|
||||
raise(ValueError("\n".join(method_checker.errors)))
|
||||
|
||||
class {class_name}(Tool):
|
||||
name = "{self.name}"
|
||||
description = \"\"\"{self.description}\"\"\"
|
||||
inputs = {json.dumps(self.inputs, separators=(',', ':'))}
|
||||
output_type = "{self.output_type}"
|
||||
""".strip()
|
||||
forward_source_code = inspect.getsource(self.forward)
|
||||
tool_code = textwrap.dedent(f"""
|
||||
from agents import Tool
|
||||
|
||||
def add_self_argument(source_code: str) -> str:
|
||||
"""Add 'self' as first argument to a function definition if not present."""
|
||||
pattern = r"def forward\(((?!self)[^)]*)\)"
|
||||
class {class_name}(Tool):
|
||||
name = "{self.name}"
|
||||
description = "{self.description}"
|
||||
inputs = {json.dumps(self.inputs, separators=(',', ':'))}
|
||||
output_type = "{self.output_type}"
|
||||
""").strip()
|
||||
import re
|
||||
def add_self_argument(source_code: str) -> str:
|
||||
"""Add 'self' as first argument to a function definition if not present."""
|
||||
pattern = r'def forward\(((?!self)[^)]*)\)'
|
||||
|
||||
def replacement(match):
|
||||
args = match.group(1).strip()
|
||||
if args: # If there are other arguments
|
||||
return f'def forward(self, {args})'
|
||||
return 'def forward(self)'
|
||||
|
||||
return re.sub(pattern, replacement, source_code)
|
||||
|
||||
def replacement(match):
|
||||
args = match.group(1).strip()
|
||||
if args: # If there are other arguments
|
||||
return f"def forward(self, {args})"
|
||||
return "def forward(self)"
|
||||
forward_source_code = forward_source_code.replace(self.name, "forward")
|
||||
forward_source_code = add_self_argument(forward_source_code)
|
||||
forward_source_code = forward_source_code.replace("@tool", "").strip()
|
||||
tool_code += "\n\n" + textwrap.indent(forward_source_code, " ")
|
||||
|
||||
with open(tool_file, "w", encoding="utf-8") as f:
|
||||
f.write(tool_code)
|
||||
else: # If the tool was not created by the @tool decorator, it was made by subclassing Tool
|
||||
if type(self).__name__ in ["SpaceToolWrapper", "LangChainToolWrapper", "GradioToolWrapper"]:
|
||||
raise ValueError(
|
||||
f"Cannot save objects created with from_space, from_langchain or from_gradio, as this would create errors."
|
||||
)
|
||||
|
||||
return re.sub(pattern, replacement, source_code)
|
||||
validate_tool_attributes(self.__class__)
|
||||
|
||||
forward_source_code = forward_source_code.replace(self.name, "forward")
|
||||
forward_source_code = add_self_argument(forward_source_code)
|
||||
forward_source_code = forward_source_code.replace("@tool", "").strip()
|
||||
tool_code += "\n\n" + textwrap.indent(forward_source_code, " ")
|
||||
with open(os.path.join(output_dir, "tool.py"), "w", encoding="utf-8") as f:
|
||||
f.write(tool_code)
|
||||
|
||||
# Save config file
|
||||
config_file = os.path.join(output_dir, "tool_config.json")
|
||||
tool_config = {
|
||||
"tool_class": self.__class__.__name__,
|
||||
"description": self.description,
|
||||
"name": self.name,
|
||||
"inputs": self.inputs,
|
||||
"output_type": str(self.output_type),
|
||||
}
|
||||
with open(config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")
|
||||
tool_code = instance_to_source(self, base_cls=Tool)
|
||||
with open(tool_file, "w", encoding="utf-8") as f:
|
||||
f.write(tool_code)
|
||||
|
||||
# Save app file
|
||||
app_file = os.path.join(output_dir, "app.py")
|
||||
with open(app_file, "w", encoding="utf-8") as f:
|
||||
f.write(APP_FILE_TEMPLATE.format(class_name=class_name))
|
||||
f.write(textwrap.dedent(f"""
|
||||
from agents import launch_gradio_demo
|
||||
from tool import {class_name}
|
||||
|
||||
tool = {class_name}()
|
||||
|
||||
launch_gradio_demo(tool)
|
||||
""").lstrip())
|
||||
|
||||
# Save requirements file
|
||||
requirements_file = os.path.join(output_dir, "requirements.txt")
|
||||
|
||||
tree = ast.parse(forward_source_code)
|
||||
import_finder = ImportFinder()
|
||||
import_finder.visit(tree)
|
||||
|
||||
imports = list(set(import_finder.packages))
|
||||
imports = []
|
||||
for module in [tool_file]:
|
||||
imports.extend(get_imports(module))
|
||||
imports = list(set(imports))
|
||||
with open(requirements_file, "w", encoding="utf-8") as f:
|
||||
f.write("agents_package\n" + "\n".join(imports) + "\n")
|
||||
|
||||
|
@ -612,7 +515,6 @@ class {class_name}(Tool):
|
|||
```
|
||||
"""
|
||||
from gradio_client import Client, handle_file
|
||||
from gradio_client.utils import is_http_url_like
|
||||
|
||||
class SpaceToolWrapper(Tool):
|
||||
def __init__(
|
||||
|
@ -665,6 +567,7 @@ class {class_name}(Tool):
|
|||
self.is_initialized = True
|
||||
|
||||
def sanitize_argument_for_prediction(self, arg):
|
||||
from gradio_client.utils import is_http_url_like
|
||||
if isinstance(arg, ImageType):
|
||||
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
||||
arg.save(temp_file.name)
|
||||
|
@ -793,13 +696,13 @@ def compile_jinja_template(template):
|
|||
return jinja_env.from_string(template)
|
||||
|
||||
|
||||
def launch_gradio_demo(tool_class: Tool):
|
||||
def launch_gradio_demo(tool: Tool):
|
||||
"""
|
||||
Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
|
||||
`inputs` and `output_type`.
|
||||
|
||||
Args:
|
||||
tool_class (`type`): The class of the tool for which to launch the demo.
|
||||
tool (`type`): The tool for which to launch the demo.
|
||||
"""
|
||||
try:
|
||||
import gradio as gr
|
||||
|
@ -808,11 +711,6 @@ def launch_gradio_demo(tool_class: Tool):
|
|||
"Gradio should be installed in order to launch a gradio demo."
|
||||
)
|
||||
|
||||
tool = tool_class()
|
||||
|
||||
def fn(*args, **kwargs):
|
||||
return tool(*args, **kwargs)
|
||||
|
||||
TYPE_TO_COMPONENT_CLASS_MAPPING = {
|
||||
"image": gr.Image,
|
||||
"audio": gr.Audio,
|
||||
|
@ -822,7 +720,7 @@ def launch_gradio_demo(tool_class: Tool):
|
|||
}
|
||||
|
||||
gradio_inputs = []
|
||||
for input_name, input_details in tool_class.inputs.items():
|
||||
for input_name, input_details in tool.inputs.items():
|
||||
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
|
||||
input_details["type"]
|
||||
]
|
||||
|
@ -830,15 +728,15 @@ def launch_gradio_demo(tool_class: Tool):
|
|||
gradio_inputs.append(new_component)
|
||||
|
||||
output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[
|
||||
tool_class.output_type
|
||||
tool.output_type
|
||||
]
|
||||
gradio_output = output_gradio_componentclass(label=input_name)
|
||||
gradio_output = output_gradio_componentclass(label="Output")
|
||||
|
||||
gr.Interface(
|
||||
fn=fn,
|
||||
fn=tool, # This works because `tool` has a __call__ method
|
||||
inputs=gradio_inputs,
|
||||
outputs=gradio_output,
|
||||
title=tool_class.__name__,
|
||||
title=tool.name,
|
||||
article=tool.description,
|
||||
).launch()
|
||||
|
||||
|
@ -1027,29 +925,34 @@ def tool(tool_function: Callable) -> Tool:
|
|||
raise TypeHintParsingException(
|
||||
"Tool return type not found: make sure your function has a return type hint!"
|
||||
)
|
||||
class_name = "".join([el.title() for el in parameters["name"].split("_")])
|
||||
|
||||
if parameters["return"]["type"] == "object":
|
||||
parameters["return"]["type"] = "any"
|
||||
|
||||
class SpecificTool(Tool):
|
||||
name = parameters["name"]
|
||||
description = parameters["description"]
|
||||
inputs = parameters["parameters"]["properties"]
|
||||
output_type = parameters["return"]["type"]
|
||||
|
||||
@wraps(tool_function)
|
||||
def forward(self, *args, **kwargs):
|
||||
return tool_function(*args, **kwargs)
|
||||
class SimpleTool(Tool):
|
||||
def __init__(self, name, description, inputs, output_type, function):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.inputs = inputs
|
||||
self.output_type = output_type
|
||||
self.forward = function
|
||||
self.is_initialized = True
|
||||
|
||||
simple_tool = SimpleTool(
|
||||
parameters["name"],
|
||||
parameters["description"],
|
||||
parameters["parameters"]["properties"],
|
||||
parameters["return"]["type"],
|
||||
function=tool_function
|
||||
)
|
||||
original_signature = inspect.signature(tool_function)
|
||||
new_parameters = [
|
||||
inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
||||
] + list(original_signature.parameters.values())
|
||||
new_signature = original_signature.replace(parameters=new_parameters)
|
||||
SpecificTool.forward.__signature__ = new_signature
|
||||
SpecificTool.__name__ = class_name
|
||||
return SpecificTool()
|
||||
simple_tool.forward.__signature__ = new_signature
|
||||
# SimpleTool.__name__ = "".join([el.title() for el in parameters["name"].split("_")])
|
||||
return simple_tool
|
||||
|
||||
|
||||
HUGGINGFACE_DEFAULT_TOOLS = {}
|
|
@ -0,0 +1,191 @@
|
|||
import ast
|
||||
import inspect
|
||||
import importlib.util
|
||||
import builtins
|
||||
from pathlib import Path
|
||||
from typing import List, Set, Dict
|
||||
import textwrap
|
||||
|
||||
_BUILTIN_NAMES = set(vars(builtins))
|
||||
|
||||
def is_local_import(module_name: str) -> bool:
|
||||
"""
|
||||
Check if an import is from a local file or a package.
|
||||
Returns True if it's a local file import.
|
||||
"""
|
||||
try:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
if spec is None:
|
||||
return True # If we can't find the module, assume it's local
|
||||
|
||||
# If the module is found and has a file path, check if it's in site-packages
|
||||
if spec.origin and 'site-packages' not in spec.origin:
|
||||
# Check if it's a .py file in the current directory or subdirectories
|
||||
return spec.origin.endswith('.py')
|
||||
|
||||
return False
|
||||
except ImportError:
|
||||
return True # If there's an import error, assume it's local
|
||||
|
||||
class MethodChecker(ast.NodeVisitor):
|
||||
"""
|
||||
Checks that a method
|
||||
- only uses defined names
|
||||
- contains no local imports (e.g. numpy is ok but local_script is not)
|
||||
"""
|
||||
def __init__(self, class_attributes: Set[str]):
|
||||
self.undefined_names = set()
|
||||
self.imports = {}
|
||||
self.from_imports = {}
|
||||
self.assigned_names = set()
|
||||
self.arg_names = set()
|
||||
self.class_attributes = class_attributes
|
||||
self.errors = []
|
||||
|
||||
def visit_arguments(self, node):
|
||||
"""Collect function arguments"""
|
||||
self.arg_names = {arg.arg for arg in node.args}
|
||||
if node.kwarg:
|
||||
self.arg_names.add(node.kwarg.arg)
|
||||
if node.vararg:
|
||||
self.arg_names.add(node.vararg.arg)
|
||||
|
||||
def visit_Import(self, node):
|
||||
for name in node.names:
|
||||
actual_name = name.asname or name.name
|
||||
if is_local_import(actual_name):
|
||||
self.errors.append(f"Local import '{actual_name}'")
|
||||
self.imports[actual_name] = name.name
|
||||
|
||||
def visit_ImportFrom(self, node):
|
||||
module = node.module or ""
|
||||
for name in node.names:
|
||||
actual_name = name.asname or name.name
|
||||
if is_local_import(module):
|
||||
self.errors.append(f"Local import '{module}'")
|
||||
self.from_imports[actual_name] = (module, name.name)
|
||||
|
||||
def visit_Assign(self, node):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
self.assigned_names.add(target.id)
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_AnnAssign(self, node):
|
||||
"""Track annotated assignments."""
|
||||
if isinstance(node.target, ast.Name):
|
||||
self.assigned_names.add(node.target.id)
|
||||
if node.value:
|
||||
self.visit(node.value)
|
||||
|
||||
def visit_For(self, node):
|
||||
target = node.target
|
||||
if isinstance(target, ast.Name):
|
||||
self.assigned_names.add(target.id)
|
||||
elif isinstance(target, ast.Tuple):
|
||||
for elt in target.elts:
|
||||
if isinstance(elt, ast.Name):
|
||||
self.assigned_names.add(elt.id)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
# Skip self.something
|
||||
if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Name(self, node):
|
||||
if isinstance(node.ctx, ast.Load):
|
||||
if not (
|
||||
node.id in _BUILTIN_NAMES
|
||||
or node.id in self.arg_names
|
||||
or node.id == "self"
|
||||
or node.id in self.class_attributes
|
||||
or node.id in self.imports
|
||||
or node.id in self.from_imports
|
||||
or node.id in self.assigned_names
|
||||
):
|
||||
self.errors.append(f"Name '{node.id}' is undefined.")
|
||||
|
||||
def visit_Call(self, node):
|
||||
if isinstance(node.func, ast.Name):
|
||||
if not (
|
||||
node.func.id in _BUILTIN_NAMES
|
||||
or node.func.id in self.arg_names
|
||||
or node.func.id == "self"
|
||||
or node.func.id in self.class_attributes
|
||||
or node.func.id in self.imports
|
||||
or node.func.id in self.from_imports
|
||||
or node.func.id in self.assigned_names
|
||||
):
|
||||
self.errors.append(f"Name '{node.func.id}' is undefined.")
|
||||
self.generic_visit(node)
|
||||
|
||||
def validate_tool_attributes(cls) -> None:
|
||||
"""
|
||||
Validates that a Tool class follows the proper patterns:
|
||||
0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!).
|
||||
1. About the class:
|
||||
- Class attributes should only be strings or dicts
|
||||
- Class attributes cannot be complex attributes
|
||||
2. About all class methods:
|
||||
- Imports must be from packages, not local files
|
||||
- All methods must be self-contained
|
||||
|
||||
Raises all errors encountered, if no error returns None.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
source = textwrap.dedent(inspect.getsource(cls))
|
||||
|
||||
tree = ast.parse(source)
|
||||
|
||||
if not isinstance(tree.body[0], ast.ClassDef):
|
||||
raise ValueError("Source code must define a class")
|
||||
|
||||
# Check that __init__ method takes no arguments
|
||||
if not cls.__init__.__qualname__ == 'Tool.__init__':
|
||||
sig = inspect.signature(cls.__init__)
|
||||
non_self_params = list([arg_name for arg_name in sig.parameters.keys() if arg_name != "self"])
|
||||
if len(non_self_params) > 0:
|
||||
errors.append(f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!")
|
||||
|
||||
class_node = tree.body[0]
|
||||
|
||||
class ClassLevelChecker(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
self.imported_names = set()
|
||||
self.complex_attributes = set()
|
||||
self.class_attributes = set()
|
||||
|
||||
def visit_Assign(self, node):
|
||||
# Track class attributes
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
self.class_attributes.add(target.id)
|
||||
|
||||
# Check if the assignment is more complex than simple literals
|
||||
if not all(isinstance(val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set))
|
||||
for val in ast.walk(node.value)):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
self.complex_attributes.add(target.id)
|
||||
|
||||
class_level_checker = ClassLevelChecker()
|
||||
class_level_checker.visit(class_node)
|
||||
|
||||
if class_level_checker.complex_attributes:
|
||||
errors.append(
|
||||
f"Complex attributes should be defined in __init__, not as class attributes: "
|
||||
f"{', '.join(class_level_checker.complex_attributes)}"
|
||||
)
|
||||
|
||||
# Run checks on all methods
|
||||
for node in class_node.body:
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
method_checker = MethodChecker(class_level_checker.class_attributes)
|
||||
method_checker.visit(node)
|
||||
errors += [f"- {node.name}: {error}" for error in method_checker.errors]
|
||||
|
||||
if errors:
|
||||
raise ValueError("Tool validation failed:\n" + "\n".join(errors))
|
||||
return
|
|
@ -19,7 +19,7 @@ import re
|
|||
import requests
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
from .tools import Tool
|
||||
from ..tools import Tool
|
||||
|
||||
|
||||
class DuckDuckGoSearchTool(Tool):
|
|
@ -19,6 +19,9 @@ import re
|
|||
from typing import Tuple, Dict, Union
|
||||
import ast
|
||||
from rich.console import Console
|
||||
import ast
|
||||
import inspect
|
||||
import types
|
||||
|
||||
from transformers.utils.import_utils import _is_package_available
|
||||
|
||||
|
@ -127,5 +130,114 @@ class ImportFinder(ast.NodeVisitor):
|
|||
base_package = node.module.split(".")[0]
|
||||
self.packages.add(base_package)
|
||||
|
||||
import ast
|
||||
import builtins
|
||||
from typing import Set, Dict, List
|
||||
|
||||
|
||||
def get_method_source(method):
|
||||
"""Get source code for a method, including bound methods."""
|
||||
if isinstance(method, types.MethodType):
|
||||
method = method.__func__
|
||||
return inspect.getsource(method).strip()
|
||||
|
||||
|
||||
def is_same_method(method1, method2):
|
||||
"""Compare two methods by their source code."""
|
||||
try:
|
||||
source1 = get_method_source(method1)
|
||||
source2 = get_method_source(method2)
|
||||
|
||||
# Remove method decorators if any
|
||||
source1 = '\n'.join(line for line in source1.split('\n')
|
||||
if not line.strip().startswith('@'))
|
||||
source2 = '\n'.join(line for line in source2.split('\n')
|
||||
if not line.strip().startswith('@'))
|
||||
|
||||
return source1 == source2
|
||||
except (TypeError, OSError):
|
||||
return False
|
||||
|
||||
def is_same_item(item1, item2):
|
||||
"""Compare two class items (methods or attributes) for equality."""
|
||||
if callable(item1) and callable(item2):
|
||||
return is_same_method(item1, item2)
|
||||
else:
|
||||
return item1 == item2
|
||||
|
||||
def instance_to_source(instance, base_cls=None):
|
||||
"""Convert an instance to its class source code representation."""
|
||||
cls = instance.__class__
|
||||
class_name = cls.__name__
|
||||
|
||||
# Start building class lines
|
||||
class_lines = []
|
||||
if base_cls:
|
||||
class_lines.append(f"class {class_name}({base_cls.__name__}):")
|
||||
else:
|
||||
class_lines.append(f"class {class_name}:")
|
||||
|
||||
# Add docstring if it exists and differs from base
|
||||
if cls.__doc__ and (not base_cls or cls.__doc__ != base_cls.__doc__):
|
||||
class_lines.append(f' """{cls.__doc__}"""')
|
||||
|
||||
# Add class-level attributes
|
||||
class_attrs = {
|
||||
name: value for name, value in cls.__dict__.items()
|
||||
if not name.startswith('__') and not callable(value) and
|
||||
not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value)
|
||||
}
|
||||
|
||||
for name, value in class_attrs.items():
|
||||
if isinstance(value, str):
|
||||
class_lines.append(f' {name} = "{value}"')
|
||||
else:
|
||||
class_lines.append(f' {name} = {repr(value)}')
|
||||
|
||||
if class_attrs:
|
||||
class_lines.append("")
|
||||
|
||||
# Add methods
|
||||
methods = {
|
||||
name: func for name, func in cls.__dict__.items()
|
||||
if callable(func) and
|
||||
not (base_cls and hasattr(base_cls, name) and
|
||||
getattr(base_cls, name).__code__.co_code == func.__code__.co_code)
|
||||
}
|
||||
|
||||
for name, method in methods.items():
|
||||
method_source = inspect.getsource(method)
|
||||
# Clean up the indentation
|
||||
method_lines = method_source.split('\n')
|
||||
first_line = method_lines[0]
|
||||
indent = len(first_line) - len(first_line.lstrip())
|
||||
method_lines = [line[indent:] for line in method_lines]
|
||||
method_source = '\n'.join([' ' + line if line.strip() else line
|
||||
for line in method_lines])
|
||||
class_lines.append(method_source)
|
||||
class_lines.append("")
|
||||
|
||||
# Find required imports using ImportFinder
|
||||
import_finder = ImportFinder()
|
||||
import_finder.visit(ast.parse('\n'.join(class_lines)))
|
||||
required_imports = import_finder.packages
|
||||
|
||||
# Build final code with imports
|
||||
final_lines = []
|
||||
|
||||
# Add base class import if needed
|
||||
if base_cls:
|
||||
final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}")
|
||||
|
||||
# Add discovered imports
|
||||
final_lines.extend(required_imports)
|
||||
|
||||
if final_lines: # Add empty line after imports
|
||||
final_lines.append("")
|
||||
|
||||
# Add the class code
|
||||
final_lines.extend(class_lines)
|
||||
|
||||
return '\n'.join(final_lines)
|
||||
|
||||
__all__ = []
|
||||
|
|
|
@ -29,7 +29,7 @@ from agents.agents import (
|
|||
Toolbox,
|
||||
ToolCall,
|
||||
)
|
||||
from agents.tools import tool
|
||||
from agents.tool import tool
|
||||
from agents.default_tools import PythonInterpreterTool
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
|
|
|
@ -125,7 +125,7 @@ class TestDocs:
|
|||
"from_langchain",
|
||||
]
|
||||
code_blocks = [
|
||||
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token)
|
||||
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", self.hf_token).replace("{your_username}", "m-ric")
|
||||
for block in code_blocks
|
||||
if not any(
|
||||
[snippet in block for snippet in excluded_snippets]
|
||||
|
|
|
@ -26,7 +26,7 @@ from agents.types import (
|
|||
AgentImage,
|
||||
AgentText,
|
||||
)
|
||||
from agents.tools import Tool, tool, AUTHORIZED_TYPES
|
||||
from agents.tool import Tool, tool, AUTHORIZED_TYPES
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
|
||||
|
@ -174,6 +174,8 @@ class ToolTests(unittest.TestCase):
|
|||
Gets the current time.
|
||||
"""
|
||||
return str(datetime.now())
|
||||
|
||||
get_current_time.save("output")
|
||||
|
||||
assert "datetime" in str(e)
|
||||
|
||||
|
@ -188,6 +190,9 @@ class ToolTests(unittest.TestCase):
|
|||
|
||||
def forward(self):
|
||||
return str(datetime.now())
|
||||
|
||||
get_current_time = GetCurrentTimeTool()
|
||||
get_current_time.save("output")
|
||||
|
||||
assert "datetime" in str(e)
|
||||
|
||||
|
@ -210,3 +215,63 @@ class ToolTests(unittest.TestCase):
|
|||
def forward(self):
|
||||
from datetime import datetime
|
||||
return str(datetime.now())
|
||||
|
||||
def test_saving_tool_allows_no_arg_in_init(self):
|
||||
# Test one cannot save tool with additional args in init
|
||||
class FailTool(Tool):
|
||||
name = "specific"
|
||||
description = "test description"
|
||||
inputs = {"input_str": {"type": "string", "description": "input description"}}
|
||||
output_type = "string"
|
||||
|
||||
def __init__(self, url):
|
||||
super().__init__(self)
|
||||
self.url = "none"
|
||||
|
||||
def forward(self, string_input):
|
||||
return self.url + string_input
|
||||
|
||||
fail_tool = FailTool("dummy_url")
|
||||
with pytest.raises(Exception) as e:
|
||||
fail_tool.save('output')
|
||||
assert '__init__' in str(e)
|
||||
|
||||
def test_saving_tool_allows_no_imports_from_outside_methods(self):
|
||||
# Test that using imports from outside functions fails
|
||||
from numpy import random
|
||||
class FailTool2(Tool):
|
||||
name = "specific"
|
||||
description = "test description"
|
||||
inputs = {"input_str": {"type": "string", "description": "input description"}}
|
||||
output_type = "string"
|
||||
|
||||
def useless_method(self):
|
||||
self.client = random.random()
|
||||
return ""
|
||||
|
||||
def forward(self, string_input):
|
||||
return self.useless_method() + string_input
|
||||
|
||||
fail_tool_2 = FailTool2()
|
||||
with pytest.raises(Exception) as e:
|
||||
fail_tool_2.save('output')
|
||||
assert 'random' in str(e)
|
||||
|
||||
# Test that putting these imports inside functions works
|
||||
|
||||
class FailTool3(Tool):
|
||||
name = "specific"
|
||||
description = "test description"
|
||||
inputs = {"input_str": {"type": "string", "description": "input description"}}
|
||||
output_type = "string"
|
||||
|
||||
def useless_method(self):
|
||||
from numpy import random
|
||||
self.client = random.random()
|
||||
return ""
|
||||
|
||||
def forward(self, string_input):
|
||||
return self.useless_method() + string_input
|
||||
|
||||
fail_tool_3 = FailTool3()
|
||||
fail_tool_3.save('output')
|
||||
|
|
Loading…
Reference in New Issue