diff --git a/Dockerfile b/Dockerfile
index 321f455..67d2b42 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,5 @@
# Base Python image
-FROM python:3.12-slim
+FROM python:3.9-slim
# Set working directory
WORKDIR /app
@@ -25,4 +25,7 @@ RUN pip install -e .
COPY server.py /app/server.py
+# Expose the port your server will run on
+EXPOSE 65432
+
CMD ["python", "/app/server.py"]
diff --git a/README.md b/README.md
index 86b7adf..f7abcbc 100644
--- a/README.md
+++ b/README.md
@@ -27,3 +27,20 @@ limitations under the License.
Run agents!
+
+W
+
+
+

+

+
+
+To run Docker, run `docker build -t pyrunner:latest .`
+
+This will use the Local Dockerfile to create your Docker image!
\ No newline at end of file
diff --git a/docs/source/tutorials/tools.md b/docs/source/tutorials/tools.md
index 113fb71..7ea6619 100644
--- a/docs/source/tutorials/tools.md
+++ b/docs/source/tutorials/tools.md
@@ -23,11 +23,11 @@ Here, we're going to see advanced tool usage.
> If you're new to `agents`, make sure to first read the main [agents documentation](./agents).
-### Directly define a tool by subclassing Tool, and share it to the Hub
+### Directly define a tool by subclassing Tool
-Let's take again the tool example from main documentation, for which we had implemented a `tool` decorator.
+Let's take again the tool example from the [quicktour](../quicktour), for which we had implemented a `@tool` decorator. The `tool` decorator is the standard format, but sometimes you need more: use several methods in a class for more clarity, or using additional class attributes.
-If you need to add variation, like custom attributes for your tool, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass.
+In this case, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass.
The custom tool needs:
- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name it `model_download_counter`.
@@ -67,19 +67,28 @@ tool = HFModelDownloadsTool()
Now the custom `HfModelDownloadsTool` class is ready.
+### Share your tool to the Hub
+
You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access.
```python
-tool.push_to_hub("m-ric/hf-model-downloads", token="")
+tool.push_to_hub("{your_username}/hf-model-downloads", token="")
```
-Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
+For the push to Hub to work, your tool will need to respect some rules:
+- All method are self-contained, e.g. use variables that come either from their args,
+- If you subclass the `__init__` method, you can give it no other argument than `self`. This is because arguments set during a specific tool instance's initialization are hard to track, which prevents from sharing them properly to the hub. And anyway, the idea of making a specific class is that you can already set class attributes for anything you need to hard-code (just set `your_variable=(...)` directly under the `class YourTool(Tool):` line). And of course you can still create a class attribute anywhere in your code by assigning stuff to `self.your_variable`.
+
+Once your tool is pushed to Hub, you can load it with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
Since running tools means running custom code, you need to make sure you trust the repository, and pass `trust_remote_code=True`.
```python
from agents import load_tool, CodeAgent
-model_download_tool = load_tool("m-ric/hf-model-downloads", trust_remote_code=True)
+model_download_tool = load_tool(
+ "{your_username}/hf-model-downloads",
+ trust_remote_code=True
+)
```
### Import a Space as a tool 🚀
@@ -214,8 +223,4 @@ agent = CodeAgent(tools=[*image_tool_collection.tools], llm_engine=llm_engine, a
agent.run("Please draw me a picture of rivers and lakes.")
```
-To speed up the start, tools are loaded only if called by the agent.
-
-This gets you this image:
-
-
+To speed up the start, tools are loaded only if called by the agent.
\ No newline at end of file
diff --git a/examples/docker_example.py b/examples/docker_example.py
index f229bb9..3938409 100644
--- a/examples/docker_example.py
+++ b/examples/docker_example.py
@@ -1,9 +1,8 @@
-from agents.search import DuckDuckGoSearchTool
+from agents.tools.search import DuckDuckGoSearchTool
from agents.docker_alternative import DockerPythonInterpreter
-test = """
-from agents.tools import Tool
+from agents.tool import Tool
class DummyTool(Tool):
name = "echo"
@@ -17,7 +16,6 @@ class DummyTool(Tool):
def forward(self, cmd: str) -> str:
return cmd
-"""
container = DockerPythonInterpreter()
@@ -30,10 +28,8 @@ breakpoint()
print("---------")
-output = container.execute(test)
-print(output)
-output = container.execute("res = DummyTool(cmd='echo this'); print(res)")
+output = container.execute("res = DummyTool(cmd='echo this'); print(res())")
print(output)
container.stop()
diff --git a/examples/dummytool.py b/examples/dummytool.py
index 75edfd9..f917629 100644
--- a/examples/dummytool.py
+++ b/examples/dummytool.py
@@ -1,4 +1,4 @@
-from agents.tools import Tool
+from agents.tool import Tool
class DummyTool(Tool):
diff --git a/src/agents/__init__.py b/src/agents/__init__.py
index f496ba8..40932d6 100644
--- a/src/agents/__init__.py
+++ b/src/agents/__init__.py
@@ -30,11 +30,12 @@ if TYPE_CHECKING:
from .local_python_executor import *
from .monitoring import *
from .prompts import *
- from .search import *
- from .tools import *
+ from .tools.search import *
+ from .tool import *
from .types import *
from .utils import *
+
else:
import sys
diff --git a/src/agents/agents.py b/src/agents/agents.py
index ea11390..acfdc01 100644
--- a/src/agents/agents.py
+++ b/src/agents/agents.py
@@ -43,7 +43,7 @@ from .prompts import (
SYSTEM_PROMPT_PLAN,
)
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
-from .tools import (
+from .tool import (
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
Tool,
get_tool_description_with_args,
diff --git a/src/agents/default_tools.py b/src/agents/default_tools.py
index 60704bd..ca41136 100644
--- a/src/agents/default_tools.py
+++ b/src/agents/default_tools.py
@@ -24,7 +24,7 @@ from huggingface_hub import hf_hub_download, list_spaces
from transformers.utils import is_offline_mode
from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code
-from .tools import TOOL_CONFIG_FILE, Tool
+from .tool import TOOL_CONFIG_FILE, Tool
def custom_print(*args):
diff --git a/src/agents/docker_alternative.py b/src/agents/docker_alternative.py
index b035c7e..711d2c8 100644
--- a/src/agents/docker_alternative.py
+++ b/src/agents/docker_alternative.py
@@ -3,7 +3,7 @@ from typing import List, Optional
import warnings
import socket
-from agents.tools import Tool
+from agents.tool import Tool
class DockerPythonInterpreter:
def __init__(self):
diff --git a/src/agents/tools.py b/src/agents/tool.py
similarity index 82%
rename from src/agents/tools.py
rename to src/agents/tool.py
index fe589f0..7671394 100644
--- a/src/agents/tools.py
+++ b/src/agents/tool.py
@@ -47,8 +47,10 @@ from transformers.utils import (
is_torch_available,
is_vision_available,
)
+from transformers.dynamic_module_utils import get_imports
from .types import ImageType, handle_agent_inputs, handle_agent_outputs
-from .utils import ImportFinder
+from .utils import instance_to_source
+from .tool_validation import validate_tool_attributes, MethodChecker
import logging
@@ -97,14 +99,6 @@ def setup_default_tools():
return default_tools
-# docstyle-ignore
-APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
-from tool import {class_name}
-
-launch_gradio_demo({class_name})
-"""
-
-
def validate_after_init(cls, do_validate_forward: bool = True):
original_init = cls.__init__
@@ -117,112 +111,6 @@ def validate_after_init(cls, do_validate_forward: bool = True):
return cls
-def validate_args_are_self_contained(source_code):
- """Validates that all names in forward method are properly defined.
- In particular it will check that all imports are done within the function."""
- print("CODDDD", source_code)
- tree = ast.parse(textwrap.dedent(source_code))
-
- # Get function arguments
- func_node = tree.body[0]
- arg_names = {arg.arg for arg in func_node.args.args} | {"kwargs"}
-
- builtin_names = set(vars(builtins))
-
- class NameChecker(ast.NodeVisitor):
- def __init__(self):
- self.undefined_names = set()
- self.imports = {}
- self.from_imports = {}
- self.assigned_names = set()
-
- def visit_Import(self, node):
- """Handle simple imports like 'import datetime'."""
- for name in node.names:
- actual_name = name.asname or name.name
- self.imports[actual_name] = (name.name, actual_name)
-
- def visit_ImportFrom(self, node):
- """Handle from imports like 'from datetime import datetime'."""
- module = node.module or ""
- for name in node.names:
- actual_name = name.asname or name.name
- self.from_imports[actual_name] = (module, name.name, actual_name)
-
- def visit_Assign(self, node):
- """Track variable assignments."""
- for target in node.targets:
- if isinstance(target, ast.Name):
- self.assigned_names.add(target.id)
- self.visit(node.value)
-
- def visit_AnnAssign(self, node):
- """Track annotated assignments."""
- if isinstance(node.target, ast.Name):
- self.assigned_names.add(node.target.id)
- if node.value:
- self.visit(node.value)
-
- def _handle_for_target(self, target) -> Set[str]:
- """Extract all names from a for loop target."""
- names = set()
- if isinstance(target, ast.Name):
- names.add(target.id)
- elif isinstance(target, ast.Tuple):
- for elt in target.elts:
- if isinstance(elt, ast.Name):
- names.add(elt.id)
- return names
-
- def visit_For(self, node):
- """Track for-loop target variables and handle enumerate specially."""
- # Add names from the target
- target_names = self._handle_for_target(node.target)
- self.assigned_names.update(target_names)
-
- # Special handling for enumerate
- if (
- isinstance(node.iter, ast.Call)
- and isinstance(node.iter.func, ast.Name)
- and node.iter.func.id == "enumerate"
- ):
- # For enumerate, if we have "for i, x in enumerate(...)",
- # both i and x should be marked as assigned
- if isinstance(node.target, ast.Tuple):
- for elt in node.target.elts:
- if isinstance(elt, ast.Name):
- self.assigned_names.add(elt.id)
-
- # Visit the rest of the node
- self.generic_visit(node)
-
- def visit_Name(self, node):
- if isinstance(node.ctx, ast.Load) and not (
- node.id == "tool"
- or node.id in builtin_names
- or node.id in arg_names
- or node.id == "self"
- or node.id in self.assigned_names
- ):
- if node.id not in self.from_imports and node.id not in self.imports:
- self.undefined_names.add(node.id)
-
- def visit_Attribute(self, node):
- # Skip self.something
- if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
- self.generic_visit(node)
-
- checker = NameChecker()
- checker.visit(tree)
-
- if checker.undefined_names:
- raise ValueError(
- f"""The following names in forward method are not defined: {', '.join(checker.undefined_names)}.
- Make sure all imports and variables are self-contained within the method.
- """
- )
-
-
AUTHORIZED_TYPES = [
"string",
"boolean",
@@ -339,64 +227,79 @@ class Tool:
"""
os.makedirs(output_dir, exist_ok=True)
class_name = self.__class__.__name__
+ tool_file = os.path.join(output_dir, "tool.py")
# Save tool file
- forward_source_code = inspect.getsource(self.forward)
- validate_args_are_self_contained(forward_source_code)
- tool_code = f"""
-from agents import Tool
+ if type(self).__name__ == "SimpleTool":
+ # Check that imports are self-contained
+ forward_node = ast.parse(textwrap.dedent(inspect.getsource(self.forward)))
+ # If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code
+ method_checker = MethodChecker(set())
+ method_checker.visit(forward_node)
+ if len(method_checker.errors) > 0:
+ raise(ValueError("\n".join(method_checker.errors)))
-class {class_name}(Tool):
- name = "{self.name}"
- description = \"\"\"{self.description}\"\"\"
- inputs = {json.dumps(self.inputs, separators=(',', ':'))}
- output_type = "{self.output_type}"
-""".strip()
+ forward_source_code = inspect.getsource(self.forward)
+ tool_code = textwrap.dedent(f"""
+ from agents import Tool
- def add_self_argument(source_code: str) -> str:
- """Add 'self' as first argument to a function definition if not present."""
- pattern = r"def forward\(((?!self)[^)]*)\)"
+ class {class_name}(Tool):
+ name = "{self.name}"
+ description = "{self.description}"
+ inputs = {json.dumps(self.inputs, separators=(',', ':'))}
+ output_type = "{self.output_type}"
+ """).strip()
+ import re
+ def add_self_argument(source_code: str) -> str:
+ """Add 'self' as first argument to a function definition if not present."""
+ pattern = r'def forward\(((?!self)[^)]*)\)'
+
+ def replacement(match):
+ args = match.group(1).strip()
+ if args: # If there are other arguments
+ return f'def forward(self, {args})'
+ return 'def forward(self)'
+
+ return re.sub(pattern, replacement, source_code)
- def replacement(match):
- args = match.group(1).strip()
- if args: # If there are other arguments
- return f"def forward(self, {args})"
- return "def forward(self)"
+ forward_source_code = forward_source_code.replace(self.name, "forward")
+ forward_source_code = add_self_argument(forward_source_code)
+ forward_source_code = forward_source_code.replace("@tool", "").strip()
+ tool_code += "\n\n" + textwrap.indent(forward_source_code, " ")
+
+ with open(tool_file, "w", encoding="utf-8") as f:
+ f.write(tool_code)
+ else: # If the tool was not created by the @tool decorator, it was made by subclassing Tool
+ if type(self).__name__ in ["SpaceToolWrapper", "LangChainToolWrapper", "GradioToolWrapper"]:
+ raise ValueError(
+ f"Cannot save objects created with from_space, from_langchain or from_gradio, as this would create errors."
+ )
- return re.sub(pattern, replacement, source_code)
+ validate_tool_attributes(self.__class__)
- forward_source_code = forward_source_code.replace(self.name, "forward")
- forward_source_code = add_self_argument(forward_source_code)
- forward_source_code = forward_source_code.replace("@tool", "").strip()
- tool_code += "\n\n" + textwrap.indent(forward_source_code, " ")
- with open(os.path.join(output_dir, "tool.py"), "w", encoding="utf-8") as f:
- f.write(tool_code)
-
- # Save config file
- config_file = os.path.join(output_dir, "tool_config.json")
- tool_config = {
- "tool_class": self.__class__.__name__,
- "description": self.description,
- "name": self.name,
- "inputs": self.inputs,
- "output_type": str(self.output_type),
- }
- with open(config_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")
+ tool_code = instance_to_source(self, base_cls=Tool)
+ with open(tool_file, "w", encoding="utf-8") as f:
+ f.write(tool_code)
# Save app file
app_file = os.path.join(output_dir, "app.py")
with open(app_file, "w", encoding="utf-8") as f:
- f.write(APP_FILE_TEMPLATE.format(class_name=class_name))
+ f.write(textwrap.dedent(f"""
+ from agents import launch_gradio_demo
+ from tool import {class_name}
+
+ tool = {class_name}()
+
+ launch_gradio_demo(tool)
+ """).lstrip())
# Save requirements file
requirements_file = os.path.join(output_dir, "requirements.txt")
- tree = ast.parse(forward_source_code)
- import_finder = ImportFinder()
- import_finder.visit(tree)
-
- imports = list(set(import_finder.packages))
+ imports = []
+ for module in [tool_file]:
+ imports.extend(get_imports(module))
+ imports = list(set(imports))
with open(requirements_file, "w", encoding="utf-8") as f:
f.write("agents_package\n" + "\n".join(imports) + "\n")
@@ -612,7 +515,6 @@ class {class_name}(Tool):
```
"""
from gradio_client import Client, handle_file
- from gradio_client.utils import is_http_url_like
class SpaceToolWrapper(Tool):
def __init__(
@@ -665,6 +567,7 @@ class {class_name}(Tool):
self.is_initialized = True
def sanitize_argument_for_prediction(self, arg):
+ from gradio_client.utils import is_http_url_like
if isinstance(arg, ImageType):
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
arg.save(temp_file.name)
@@ -793,13 +696,13 @@ def compile_jinja_template(template):
return jinja_env.from_string(template)
-def launch_gradio_demo(tool_class: Tool):
+def launch_gradio_demo(tool: Tool):
"""
Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
`inputs` and `output_type`.
Args:
- tool_class (`type`): The class of the tool for which to launch the demo.
+ tool (`type`): The tool for which to launch the demo.
"""
try:
import gradio as gr
@@ -808,11 +711,6 @@ def launch_gradio_demo(tool_class: Tool):
"Gradio should be installed in order to launch a gradio demo."
)
- tool = tool_class()
-
- def fn(*args, **kwargs):
- return tool(*args, **kwargs)
-
TYPE_TO_COMPONENT_CLASS_MAPPING = {
"image": gr.Image,
"audio": gr.Audio,
@@ -822,7 +720,7 @@ def launch_gradio_demo(tool_class: Tool):
}
gradio_inputs = []
- for input_name, input_details in tool_class.inputs.items():
+ for input_name, input_details in tool.inputs.items():
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
input_details["type"]
]
@@ -830,15 +728,15 @@ def launch_gradio_demo(tool_class: Tool):
gradio_inputs.append(new_component)
output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[
- tool_class.output_type
+ tool.output_type
]
- gradio_output = output_gradio_componentclass(label=input_name)
+ gradio_output = output_gradio_componentclass(label="Output")
gr.Interface(
- fn=fn,
+ fn=tool, # This works because `tool` has a __call__ method
inputs=gradio_inputs,
outputs=gradio_output,
- title=tool_class.__name__,
+ title=tool.name,
article=tool.description,
).launch()
@@ -1027,29 +925,34 @@ def tool(tool_function: Callable) -> Tool:
raise TypeHintParsingException(
"Tool return type not found: make sure your function has a return type hint!"
)
- class_name = "".join([el.title() for el in parameters["name"].split("_")])
if parameters["return"]["type"] == "object":
parameters["return"]["type"] = "any"
- class SpecificTool(Tool):
- name = parameters["name"]
- description = parameters["description"]
- inputs = parameters["parameters"]["properties"]
- output_type = parameters["return"]["type"]
-
- @wraps(tool_function)
- def forward(self, *args, **kwargs):
- return tool_function(*args, **kwargs)
+ class SimpleTool(Tool):
+ def __init__(self, name, description, inputs, output_type, function):
+ self.name = name
+ self.description = description
+ self.inputs = inputs
+ self.output_type = output_type
+ self.forward = function
+ self.is_initialized = True
+ simple_tool = SimpleTool(
+ parameters["name"],
+ parameters["description"],
+ parameters["parameters"]["properties"],
+ parameters["return"]["type"],
+ function=tool_function
+ )
original_signature = inspect.signature(tool_function)
new_parameters = [
inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
] + list(original_signature.parameters.values())
new_signature = original_signature.replace(parameters=new_parameters)
- SpecificTool.forward.__signature__ = new_signature
- SpecificTool.__name__ = class_name
- return SpecificTool()
+ simple_tool.forward.__signature__ = new_signature
+ # SimpleTool.__name__ = "".join([el.title() for el in parameters["name"].split("_")])
+ return simple_tool
HUGGINGFACE_DEFAULT_TOOLS = {}
diff --git a/src/agents/tool_validation.py b/src/agents/tool_validation.py
new file mode 100644
index 0000000..7e25afe
--- /dev/null
+++ b/src/agents/tool_validation.py
@@ -0,0 +1,191 @@
+import ast
+import inspect
+import importlib.util
+import builtins
+from pathlib import Path
+from typing import List, Set, Dict
+import textwrap
+
+_BUILTIN_NAMES = set(vars(builtins))
+
+def is_local_import(module_name: str) -> bool:
+ """
+ Check if an import is from a local file or a package.
+ Returns True if it's a local file import.
+ """
+ try:
+ spec = importlib.util.find_spec(module_name)
+ if spec is None:
+ return True # If we can't find the module, assume it's local
+
+ # If the module is found and has a file path, check if it's in site-packages
+ if spec.origin and 'site-packages' not in spec.origin:
+ # Check if it's a .py file in the current directory or subdirectories
+ return spec.origin.endswith('.py')
+
+ return False
+ except ImportError:
+ return True # If there's an import error, assume it's local
+
+class MethodChecker(ast.NodeVisitor):
+ """
+ Checks that a method
+ - only uses defined names
+ - contains no local imports (e.g. numpy is ok but local_script is not)
+ """
+ def __init__(self, class_attributes: Set[str]):
+ self.undefined_names = set()
+ self.imports = {}
+ self.from_imports = {}
+ self.assigned_names = set()
+ self.arg_names = set()
+ self.class_attributes = class_attributes
+ self.errors = []
+
+ def visit_arguments(self, node):
+ """Collect function arguments"""
+ self.arg_names = {arg.arg for arg in node.args}
+ if node.kwarg:
+ self.arg_names.add(node.kwarg.arg)
+ if node.vararg:
+ self.arg_names.add(node.vararg.arg)
+
+ def visit_Import(self, node):
+ for name in node.names:
+ actual_name = name.asname or name.name
+ if is_local_import(actual_name):
+ self.errors.append(f"Local import '{actual_name}'")
+ self.imports[actual_name] = name.name
+
+ def visit_ImportFrom(self, node):
+ module = node.module or ""
+ for name in node.names:
+ actual_name = name.asname or name.name
+ if is_local_import(module):
+ self.errors.append(f"Local import '{module}'")
+ self.from_imports[actual_name] = (module, name.name)
+
+ def visit_Assign(self, node):
+ for target in node.targets:
+ if isinstance(target, ast.Name):
+ self.assigned_names.add(target.id)
+ self.visit(node.value)
+
+ def visit_AnnAssign(self, node):
+ """Track annotated assignments."""
+ if isinstance(node.target, ast.Name):
+ self.assigned_names.add(node.target.id)
+ if node.value:
+ self.visit(node.value)
+
+ def visit_For(self, node):
+ target = node.target
+ if isinstance(target, ast.Name):
+ self.assigned_names.add(target.id)
+ elif isinstance(target, ast.Tuple):
+ for elt in target.elts:
+ if isinstance(elt, ast.Name):
+ self.assigned_names.add(elt.id)
+ self.generic_visit(node)
+
+ def visit_Attribute(self, node):
+ # Skip self.something
+ if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
+ self.generic_visit(node)
+
+ def visit_Name(self, node):
+ if isinstance(node.ctx, ast.Load):
+ if not (
+ node.id in _BUILTIN_NAMES
+ or node.id in self.arg_names
+ or node.id == "self"
+ or node.id in self.class_attributes
+ or node.id in self.imports
+ or node.id in self.from_imports
+ or node.id in self.assigned_names
+ ):
+ self.errors.append(f"Name '{node.id}' is undefined.")
+
+ def visit_Call(self, node):
+ if isinstance(node.func, ast.Name):
+ if not (
+ node.func.id in _BUILTIN_NAMES
+ or node.func.id in self.arg_names
+ or node.func.id == "self"
+ or node.func.id in self.class_attributes
+ or node.func.id in self.imports
+ or node.func.id in self.from_imports
+ or node.func.id in self.assigned_names
+ ):
+ self.errors.append(f"Name '{node.func.id}' is undefined.")
+ self.generic_visit(node)
+
+def validate_tool_attributes(cls) -> None:
+ """
+ Validates that a Tool class follows the proper patterns:
+ 0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!).
+ 1. About the class:
+ - Class attributes should only be strings or dicts
+ - Class attributes cannot be complex attributes
+ 2. About all class methods:
+ - Imports must be from packages, not local files
+ - All methods must be self-contained
+
+ Raises all errors encountered, if no error returns None.
+ """
+ errors = []
+
+ source = textwrap.dedent(inspect.getsource(cls))
+
+ tree = ast.parse(source)
+
+ if not isinstance(tree.body[0], ast.ClassDef):
+ raise ValueError("Source code must define a class")
+
+ # Check that __init__ method takes no arguments
+ if not cls.__init__.__qualname__ == 'Tool.__init__':
+ sig = inspect.signature(cls.__init__)
+ non_self_params = list([arg_name for arg_name in sig.parameters.keys() if arg_name != "self"])
+ if len(non_self_params) > 0:
+ errors.append(f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!")
+
+ class_node = tree.body[0]
+
+ class ClassLevelChecker(ast.NodeVisitor):
+ def __init__(self):
+ self.imported_names = set()
+ self.complex_attributes = set()
+ self.class_attributes = set()
+
+ def visit_Assign(self, node):
+ # Track class attributes
+ for target in node.targets:
+ if isinstance(target, ast.Name):
+ self.class_attributes.add(target.id)
+
+ # Check if the assignment is more complex than simple literals
+ if not all(isinstance(val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set))
+ for val in ast.walk(node.value)):
+ for target in node.targets:
+ if isinstance(target, ast.Name):
+ self.complex_attributes.add(target.id)
+
+ class_level_checker = ClassLevelChecker()
+ class_level_checker.visit(class_node)
+
+ if class_level_checker.complex_attributes:
+ errors.append(
+ f"Complex attributes should be defined in __init__, not as class attributes: "
+ f"{', '.join(class_level_checker.complex_attributes)}"
+ )
+
+ # Run checks on all methods
+ for node in class_node.body:
+ if isinstance(node, ast.FunctionDef):
+ method_checker = MethodChecker(class_level_checker.class_attributes)
+ method_checker.visit(node)
+ errors += [f"- {node.name}: {error}" for error in method_checker.errors]
+
+ if errors:
+ raise ValueError("Tool validation failed:\n" + "\n".join(errors))
+ return
diff --git a/src/agents/tools/__init__.py b/src/agents/tools/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/agents/search.py b/src/agents/tools/search.py
similarity index 99%
rename from src/agents/search.py
rename to src/agents/tools/search.py
index c989bc5..01af0cd 100644
--- a/src/agents/search.py
+++ b/src/agents/tools/search.py
@@ -19,7 +19,7 @@ import re
import requests
from requests.exceptions import RequestException
-from .tools import Tool
+from ..tools import Tool
class DuckDuckGoSearchTool(Tool):
diff --git a/src/agents/utils.py b/src/agents/utils.py
index e7d286a..a5fc8ca 100644
--- a/src/agents/utils.py
+++ b/src/agents/utils.py
@@ -19,6 +19,9 @@ import re
from typing import Tuple, Dict, Union
import ast
from rich.console import Console
+import ast
+import inspect
+import types
from transformers.utils.import_utils import _is_package_available
@@ -127,5 +130,114 @@ class ImportFinder(ast.NodeVisitor):
base_package = node.module.split(".")[0]
self.packages.add(base_package)
+import ast
+import builtins
+from typing import Set, Dict, List
+
+
+def get_method_source(method):
+ """Get source code for a method, including bound methods."""
+ if isinstance(method, types.MethodType):
+ method = method.__func__
+ return inspect.getsource(method).strip()
+
+
+def is_same_method(method1, method2):
+ """Compare two methods by their source code."""
+ try:
+ source1 = get_method_source(method1)
+ source2 = get_method_source(method2)
+
+ # Remove method decorators if any
+ source1 = '\n'.join(line for line in source1.split('\n')
+ if not line.strip().startswith('@'))
+ source2 = '\n'.join(line for line in source2.split('\n')
+ if not line.strip().startswith('@'))
+
+ return source1 == source2
+ except (TypeError, OSError):
+ return False
+
+def is_same_item(item1, item2):
+ """Compare two class items (methods or attributes) for equality."""
+ if callable(item1) and callable(item2):
+ return is_same_method(item1, item2)
+ else:
+ return item1 == item2
+
+def instance_to_source(instance, base_cls=None):
+ """Convert an instance to its class source code representation."""
+ cls = instance.__class__
+ class_name = cls.__name__
+
+ # Start building class lines
+ class_lines = []
+ if base_cls:
+ class_lines.append(f"class {class_name}({base_cls.__name__}):")
+ else:
+ class_lines.append(f"class {class_name}:")
+
+ # Add docstring if it exists and differs from base
+ if cls.__doc__ and (not base_cls or cls.__doc__ != base_cls.__doc__):
+ class_lines.append(f' """{cls.__doc__}"""')
+
+ # Add class-level attributes
+ class_attrs = {
+ name: value for name, value in cls.__dict__.items()
+ if not name.startswith('__') and not callable(value) and
+ not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value)
+ }
+
+ for name, value in class_attrs.items():
+ if isinstance(value, str):
+ class_lines.append(f' {name} = "{value}"')
+ else:
+ class_lines.append(f' {name} = {repr(value)}')
+
+ if class_attrs:
+ class_lines.append("")
+
+ # Add methods
+ methods = {
+ name: func for name, func in cls.__dict__.items()
+ if callable(func) and
+ not (base_cls and hasattr(base_cls, name) and
+ getattr(base_cls, name).__code__.co_code == func.__code__.co_code)
+ }
+
+ for name, method in methods.items():
+ method_source = inspect.getsource(method)
+ # Clean up the indentation
+ method_lines = method_source.split('\n')
+ first_line = method_lines[0]
+ indent = len(first_line) - len(first_line.lstrip())
+ method_lines = [line[indent:] for line in method_lines]
+ method_source = '\n'.join([' ' + line if line.strip() else line
+ for line in method_lines])
+ class_lines.append(method_source)
+ class_lines.append("")
+
+ # Find required imports using ImportFinder
+ import_finder = ImportFinder()
+ import_finder.visit(ast.parse('\n'.join(class_lines)))
+ required_imports = import_finder.packages
+
+ # Build final code with imports
+ final_lines = []
+
+ # Add base class import if needed
+ if base_cls:
+ final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}")
+
+ # Add discovered imports
+ final_lines.extend(required_imports)
+
+ if final_lines: # Add empty line after imports
+ final_lines.append("")
+
+ # Add the class code
+ final_lines.extend(class_lines)
+
+ return '\n'.join(final_lines)
__all__ = []
diff --git a/tests/test_agents.py b/tests/test_agents.py
index 539f4cf..93154c0 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -29,7 +29,7 @@ from agents.agents import (
Toolbox,
ToolCall,
)
-from agents.tools import tool
+from agents.tool import tool
from agents.default_tools import PythonInterpreterTool
from transformers.testing_utils import get_tests_dir
diff --git a/tests/test_all_docs.py b/tests/test_all_docs.py
index e5b46ff..a3feb98 100644
--- a/tests/test_all_docs.py
+++ b/tests/test_all_docs.py
@@ -125,7 +125,7 @@ class TestDocs:
"from_langchain",
]
code_blocks = [
- block.replace("", self.hf_token)
+ block.replace("", self.hf_token).replace("{your_username}", "m-ric")
for block in code_blocks
if not any(
[snippet in block for snippet in excluded_snippets]
diff --git a/tests/test_tools_common.py b/tests/test_tools_common.py
index ae3b622..7d3b730 100644
--- a/tests/test_tools_common.py
+++ b/tests/test_tools_common.py
@@ -26,7 +26,7 @@ from agents.types import (
AgentImage,
AgentText,
)
-from agents.tools import Tool, tool, AUTHORIZED_TYPES
+from agents.tool import Tool, tool, AUTHORIZED_TYPES
from transformers.testing_utils import get_tests_dir
@@ -174,6 +174,8 @@ class ToolTests(unittest.TestCase):
Gets the current time.
"""
return str(datetime.now())
+
+ get_current_time.save("output")
assert "datetime" in str(e)
@@ -188,6 +190,9 @@ class ToolTests(unittest.TestCase):
def forward(self):
return str(datetime.now())
+
+ get_current_time = GetCurrentTimeTool()
+ get_current_time.save("output")
assert "datetime" in str(e)
@@ -210,3 +215,63 @@ class ToolTests(unittest.TestCase):
def forward(self):
from datetime import datetime
return str(datetime.now())
+
+ def test_saving_tool_allows_no_arg_in_init(self):
+ # Test one cannot save tool with additional args in init
+ class FailTool(Tool):
+ name = "specific"
+ description = "test description"
+ inputs = {"input_str": {"type": "string", "description": "input description"}}
+ output_type = "string"
+
+ def __init__(self, url):
+ super().__init__(self)
+ self.url = "none"
+
+ def forward(self, string_input):
+ return self.url + string_input
+
+ fail_tool = FailTool("dummy_url")
+ with pytest.raises(Exception) as e:
+ fail_tool.save('output')
+ assert '__init__' in str(e)
+
+ def test_saving_tool_allows_no_imports_from_outside_methods(self):
+ # Test that using imports from outside functions fails
+ from numpy import random
+ class FailTool2(Tool):
+ name = "specific"
+ description = "test description"
+ inputs = {"input_str": {"type": "string", "description": "input description"}}
+ output_type = "string"
+
+ def useless_method(self):
+ self.client = random.random()
+ return ""
+
+ def forward(self, string_input):
+ return self.useless_method() + string_input
+
+ fail_tool_2 = FailTool2()
+ with pytest.raises(Exception) as e:
+ fail_tool_2.save('output')
+ assert 'random' in str(e)
+
+ # Test that putting these imports inside functions works
+
+ class FailTool3(Tool):
+ name = "specific"
+ description = "test description"
+ inputs = {"input_str": {"type": "string", "description": "input description"}}
+ output_type = "string"
+
+ def useless_method(self):
+ from numpy import random
+ self.client = random.random()
+ return ""
+
+ def forward(self, string_input):
+ return self.useless_method() + string_input
+
+ fail_tool_3 = FailTool3()
+ fail_tool_3.save('output')