From 584ce8f36350de4fe7c0925df76e1e02203b60bf Mon Sep 17 00:00:00 2001 From: Aymeric Date: Thu, 19 Dec 2024 16:05:17 +0100 Subject: [PATCH] Consolidate pushing Tools to Hub --- Dockerfile | 5 +- README.md | 17 ++ docs/source/tutorials/tools.md | 27 +-- examples/docker_example.py | 10 +- examples/dummytool.py | 2 +- src/agents/__init__.py | 5 +- src/agents/agents.py | 2 +- src/agents/default_tools.py | 2 +- src/agents/docker_alternative.py | 2 +- src/agents/{tools.py => tool.py} | 271 ++++++++++--------------------- src/agents/tool_validation.py | 191 ++++++++++++++++++++++ src/agents/tools/__init__.py | 0 src/agents/{ => tools}/search.py | 2 +- src/agents/utils.py | 112 +++++++++++++ tests/test_agents.py | 2 +- tests/test_all_docs.py | 2 +- tests/test_tools_common.py | 67 +++++++- 17 files changed, 506 insertions(+), 213 deletions(-) rename src/agents/{tools.py => tool.py} (82%) create mode 100644 src/agents/tool_validation.py create mode 100644 src/agents/tools/__init__.py rename src/agents/{ => tools}/search.py (99%) diff --git a/Dockerfile b/Dockerfile index 321f455..67d2b42 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Base Python image -FROM python:3.12-slim +FROM python:3.9-slim # Set working directory WORKDIR /app @@ -25,4 +25,7 @@ RUN pip install -e . COPY server.py /app/server.py +# Expose the port your server will run on +EXPOSE 65432 + CMD ["python", "/app/server.py"] diff --git a/README.md b/README.md index 86b7adf..f7abcbc 100644 --- a/README.md +++ b/README.md @@ -27,3 +27,20 @@ limitations under the License.

Run agents!

+ +W + +
+ + +
+ +To run Docker, run `docker build -t pyrunner:latest .` + +This will use the Local Dockerfile to create your Docker image! \ No newline at end of file diff --git a/docs/source/tutorials/tools.md b/docs/source/tutorials/tools.md index 113fb71..7ea6619 100644 --- a/docs/source/tutorials/tools.md +++ b/docs/source/tutorials/tools.md @@ -23,11 +23,11 @@ Here, we're going to see advanced tool usage. > If you're new to `agents`, make sure to first read the main [agents documentation](./agents). -### Directly define a tool by subclassing Tool, and share it to the Hub +### Directly define a tool by subclassing Tool -Let's take again the tool example from main documentation, for which we had implemented a `tool` decorator. +Let's take again the tool example from the [quicktour](../quicktour), for which we had implemented a `@tool` decorator. The `tool` decorator is the standard format, but sometimes you need more: use several methods in a class for more clarity, or using additional class attributes. -If you need to add variation, like custom attributes for your tool, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass. +In this case, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass. The custom tool needs: - An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name it `model_download_counter`. @@ -67,19 +67,28 @@ tool = HFModelDownloadsTool() Now the custom `HfModelDownloadsTool` class is ready. +### Share your tool to the Hub + You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access. 
```python -tool.push_to_hub("m-ric/hf-model-downloads", token="") +tool.push_to_hub("{your_username}/hf-model-downloads", token="") ``` -Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent. +For the push to the Hub to work, your tool will need to respect some rules: +- All methods are self-contained, i.e. they only use variables that come either from their arguments or from definitions and imports made inside the method itself. +- If you subclass the `__init__` method, you can give it no other argument than `self`. This is because arguments set during a specific tool instance's initialization are hard to track, which prevents sharing them properly to the Hub. And anyway, the idea of making a specific class is that you can already set class attributes for anything you need to hard-code (just set `your_variable=(...)` directly under the `class YourTool(Tool):` line). And of course you can still create an instance attribute anywhere in your code by assigning to `self.your_variable`. + +Once your tool is pushed to the Hub, you can load it with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent. Since running tools means running custom code, you need to make sure you trust the repository, and pass `trust_remote_code=True`. ```python from agents import load_tool, CodeAgent -model_download_tool = load_tool("m-ric/hf-model-downloads", trust_remote_code=True) +model_download_tool = load_tool( + "{your_username}/hf-model-downloads", + trust_remote_code=True +) ``` ### Import a Space as a tool 🚀 @@ -214,8 +223,4 @@ agent = CodeAgent(tools=[*image_tool_collection.tools], llm_engine=llm_engine, a agent.run("Please draw me a picture of rivers and lakes.") ``` -To speed up the start, tools are loaded only if called by the agent. - -This gets you this image: - - +To speed up the start, tools are loaded only if called by the agent. 
\ No newline at end of file diff --git a/examples/docker_example.py b/examples/docker_example.py index f229bb9..3938409 100644 --- a/examples/docker_example.py +++ b/examples/docker_example.py @@ -1,9 +1,8 @@ -from agents.search import DuckDuckGoSearchTool +from agents.tools.search import DuckDuckGoSearchTool from agents.docker_alternative import DockerPythonInterpreter -test = """ -from agents.tools import Tool +from agents.tool import Tool class DummyTool(Tool): name = "echo" @@ -17,7 +16,6 @@ class DummyTool(Tool): def forward(self, cmd: str) -> str: return cmd -""" container = DockerPythonInterpreter() @@ -30,10 +28,8 @@ breakpoint() print("---------") -output = container.execute(test) -print(output) -output = container.execute("res = DummyTool(cmd='echo this'); print(res)") +output = container.execute("res = DummyTool(cmd='echo this'); print(res())") print(output) container.stop() diff --git a/examples/dummytool.py b/examples/dummytool.py index 75edfd9..f917629 100644 --- a/examples/dummytool.py +++ b/examples/dummytool.py @@ -1,4 +1,4 @@ -from agents.tools import Tool +from agents.tool import Tool class DummyTool(Tool): diff --git a/src/agents/__init__.py b/src/agents/__init__.py index f496ba8..40932d6 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -30,11 +30,12 @@ if TYPE_CHECKING: from .local_python_executor import * from .monitoring import * from .prompts import * - from .search import * - from .tools import * + from .tools.search import * + from .tool import * from .types import * from .utils import * + else: import sys diff --git a/src/agents/agents.py b/src/agents/agents.py index ea11390..acfdc01 100644 --- a/src/agents/agents.py +++ b/src/agents/agents.py @@ -43,7 +43,7 @@ from .prompts import ( SYSTEM_PROMPT_PLAN, ) from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code -from .tools import ( +from .tool import ( DEFAULT_TOOL_DESCRIPTION_TEMPLATE, Tool, get_tool_description_with_args, diff --git 
a/src/agents/default_tools.py b/src/agents/default_tools.py index 60704bd..ca41136 100644 --- a/src/agents/default_tools.py +++ b/src/agents/default_tools.py @@ -24,7 +24,7 @@ from huggingface_hub import hf_hub_download, list_spaces from transformers.utils import is_offline_mode from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code -from .tools import TOOL_CONFIG_FILE, Tool +from .tool import TOOL_CONFIG_FILE, Tool def custom_print(*args): diff --git a/src/agents/docker_alternative.py b/src/agents/docker_alternative.py index b035c7e..711d2c8 100644 --- a/src/agents/docker_alternative.py +++ b/src/agents/docker_alternative.py @@ -3,7 +3,7 @@ from typing import List, Optional import warnings import socket -from agents.tools import Tool +from agents.tool import Tool class DockerPythonInterpreter: def __init__(self): diff --git a/src/agents/tools.py b/src/agents/tool.py similarity index 82% rename from src/agents/tools.py rename to src/agents/tool.py index fe589f0..7671394 100644 --- a/src/agents/tools.py +++ b/src/agents/tool.py @@ -47,8 +47,10 @@ from transformers.utils import ( is_torch_available, is_vision_available, ) +from transformers.dynamic_module_utils import get_imports from .types import ImageType, handle_agent_inputs, handle_agent_outputs -from .utils import ImportFinder +from .utils import instance_to_source +from .tool_validation import validate_tool_attributes, MethodChecker import logging @@ -97,14 +99,6 @@ def setup_default_tools(): return default_tools -# docstyle-ignore -APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo -from tool import {class_name} - -launch_gradio_demo({class_name}) -""" - - def validate_after_init(cls, do_validate_forward: bool = True): original_init = cls.__init__ @@ -117,112 +111,6 @@ def validate_after_init(cls, do_validate_forward: bool = True): return cls -def validate_args_are_self_contained(source_code): - """Validates that all names in forward method are properly defined. 
- In particular it will check that all imports are done within the function.""" - print("CODDDD", source_code) - tree = ast.parse(textwrap.dedent(source_code)) - - # Get function arguments - func_node = tree.body[0] - arg_names = {arg.arg for arg in func_node.args.args} | {"kwargs"} - - builtin_names = set(vars(builtins)) - - class NameChecker(ast.NodeVisitor): - def __init__(self): - self.undefined_names = set() - self.imports = {} - self.from_imports = {} - self.assigned_names = set() - - def visit_Import(self, node): - """Handle simple imports like 'import datetime'.""" - for name in node.names: - actual_name = name.asname or name.name - self.imports[actual_name] = (name.name, actual_name) - - def visit_ImportFrom(self, node): - """Handle from imports like 'from datetime import datetime'.""" - module = node.module or "" - for name in node.names: - actual_name = name.asname or name.name - self.from_imports[actual_name] = (module, name.name, actual_name) - - def visit_Assign(self, node): - """Track variable assignments.""" - for target in node.targets: - if isinstance(target, ast.Name): - self.assigned_names.add(target.id) - self.visit(node.value) - - def visit_AnnAssign(self, node): - """Track annotated assignments.""" - if isinstance(node.target, ast.Name): - self.assigned_names.add(node.target.id) - if node.value: - self.visit(node.value) - - def _handle_for_target(self, target) -> Set[str]: - """Extract all names from a for loop target.""" - names = set() - if isinstance(target, ast.Name): - names.add(target.id) - elif isinstance(target, ast.Tuple): - for elt in target.elts: - if isinstance(elt, ast.Name): - names.add(elt.id) - return names - - def visit_For(self, node): - """Track for-loop target variables and handle enumerate specially.""" - # Add names from the target - target_names = self._handle_for_target(node.target) - self.assigned_names.update(target_names) - - # Special handling for enumerate - if ( - isinstance(node.iter, ast.Call) - and 
isinstance(node.iter.func, ast.Name) - and node.iter.func.id == "enumerate" - ): - # For enumerate, if we have "for i, x in enumerate(...)", - # both i and x should be marked as assigned - if isinstance(node.target, ast.Tuple): - for elt in node.target.elts: - if isinstance(elt, ast.Name): - self.assigned_names.add(elt.id) - - # Visit the rest of the node - self.generic_visit(node) - - def visit_Name(self, node): - if isinstance(node.ctx, ast.Load) and not ( - node.id == "tool" - or node.id in builtin_names - or node.id in arg_names - or node.id == "self" - or node.id in self.assigned_names - ): - if node.id not in self.from_imports and node.id not in self.imports: - self.undefined_names.add(node.id) - - def visit_Attribute(self, node): - # Skip self.something - if not (isinstance(node.value, ast.Name) and node.value.id == "self"): - self.generic_visit(node) - - checker = NameChecker() - checker.visit(tree) - - if checker.undefined_names: - raise ValueError( - f"""The following names in forward method are not defined: {', '.join(checker.undefined_names)}. - Make sure all imports and variables are self-contained within the method. 
- """ - ) - - AUTHORIZED_TYPES = [ "string", "boolean", @@ -339,64 +227,79 @@ class Tool: """ os.makedirs(output_dir, exist_ok=True) class_name = self.__class__.__name__ + tool_file = os.path.join(output_dir, "tool.py") # Save tool file - forward_source_code = inspect.getsource(self.forward) - validate_args_are_self_contained(forward_source_code) - tool_code = f""" -from agents import Tool + if type(self).__name__ == "SimpleTool": + # Check that imports are self-contained + forward_node = ast.parse(textwrap.dedent(inspect.getsource(self.forward))) + # If tool was created using '@tool' decorator, it has only a forward pass, so it's simpler to just get its code + method_checker = MethodChecker(set()) + method_checker.visit(forward_node) + if len(method_checker.errors) > 0: + raise(ValueError("\n".join(method_checker.errors))) -class {class_name}(Tool): - name = "{self.name}" - description = \"\"\"{self.description}\"\"\" - inputs = {json.dumps(self.inputs, separators=(',', ':'))} - output_type = "{self.output_type}" -""".strip() + forward_source_code = inspect.getsource(self.forward) + tool_code = textwrap.dedent(f""" + from agents import Tool - def add_self_argument(source_code: str) -> str: - """Add 'self' as first argument to a function definition if not present.""" - pattern = r"def forward\(((?!self)[^)]*)\)" + class {class_name}(Tool): + name = "{self.name}" + description = "{self.description}" + inputs = {json.dumps(self.inputs, separators=(',', ':'))} + output_type = "{self.output_type}" + """).strip() + import re + def add_self_argument(source_code: str) -> str: + """Add 'self' as first argument to a function definition if not present.""" + pattern = r'def forward\(((?!self)[^)]*)\)' + + def replacement(match): + args = match.group(1).strip() + if args: # If there are other arguments + return f'def forward(self, {args})' + return 'def forward(self)' + + return re.sub(pattern, replacement, source_code) - def replacement(match): - args = match.group(1).strip() 
- if args: # If there are other arguments - return f"def forward(self, {args})" - return "def forward(self)" + forward_source_code = forward_source_code.replace(self.name, "forward") + forward_source_code = add_self_argument(forward_source_code) + forward_source_code = forward_source_code.replace("@tool", "").strip() + tool_code += "\n\n" + textwrap.indent(forward_source_code, " ") + + with open(tool_file, "w", encoding="utf-8") as f: + f.write(tool_code) + else: # If the tool was not created by the @tool decorator, it was made by subclassing Tool + if type(self).__name__ in ["SpaceToolWrapper", "LangChainToolWrapper", "GradioToolWrapper"]: + raise ValueError( + f"Cannot save objects created with from_space, from_langchain or from_gradio, as this would create errors." + ) - return re.sub(pattern, replacement, source_code) + validate_tool_attributes(self.__class__) - forward_source_code = forward_source_code.replace(self.name, "forward") - forward_source_code = add_self_argument(forward_source_code) - forward_source_code = forward_source_code.replace("@tool", "").strip() - tool_code += "\n\n" + textwrap.indent(forward_source_code, " ") - with open(os.path.join(output_dir, "tool.py"), "w", encoding="utf-8") as f: - f.write(tool_code) - - # Save config file - config_file = os.path.join(output_dir, "tool_config.json") - tool_config = { - "tool_class": self.__class__.__name__, - "description": self.description, - "name": self.name, - "inputs": self.inputs, - "output_type": str(self.output_type), - } - with open(config_file, "w", encoding="utf-8") as f: - f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n") + tool_code = instance_to_source(self, base_cls=Tool) + with open(tool_file, "w", encoding="utf-8") as f: + f.write(tool_code) # Save app file app_file = os.path.join(output_dir, "app.py") with open(app_file, "w", encoding="utf-8") as f: - f.write(APP_FILE_TEMPLATE.format(class_name=class_name)) + f.write(textwrap.dedent(f""" + from agents import 
launch_gradio_demo + from tool import {class_name} + + tool = {class_name}() + + launch_gradio_demo(tool) + """).lstrip()) # Save requirements file requirements_file = os.path.join(output_dir, "requirements.txt") - tree = ast.parse(forward_source_code) - import_finder = ImportFinder() - import_finder.visit(tree) - - imports = list(set(import_finder.packages)) + imports = [] + for module in [tool_file]: + imports.extend(get_imports(module)) + imports = list(set(imports)) with open(requirements_file, "w", encoding="utf-8") as f: f.write("agents_package\n" + "\n".join(imports) + "\n") @@ -612,7 +515,6 @@ class {class_name}(Tool): ``` """ from gradio_client import Client, handle_file - from gradio_client.utils import is_http_url_like class SpaceToolWrapper(Tool): def __init__( @@ -665,6 +567,7 @@ class {class_name}(Tool): self.is_initialized = True def sanitize_argument_for_prediction(self, arg): + from gradio_client.utils import is_http_url_like if isinstance(arg, ImageType): temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) arg.save(temp_file.name) @@ -793,13 +696,13 @@ def compile_jinja_template(template): return jinja_env.from_string(template) -def launch_gradio_demo(tool_class: Tool): +def launch_gradio_demo(tool: Tool): """ Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes `inputs` and `output_type`. Args: - tool_class (`type`): The class of the tool for which to launch the demo. + tool (`type`): The tool for which to launch the demo. """ try: import gradio as gr @@ -808,11 +711,6 @@ def launch_gradio_demo(tool_class: Tool): "Gradio should be installed in order to launch a gradio demo." 
) - tool = tool_class() - - def fn(*args, **kwargs): - return tool(*args, **kwargs) - TYPE_TO_COMPONENT_CLASS_MAPPING = { "image": gr.Image, "audio": gr.Audio, @@ -822,7 +720,7 @@ def launch_gradio_demo(tool_class: Tool): } gradio_inputs = [] - for input_name, input_details in tool_class.inputs.items(): + for input_name, input_details in tool.inputs.items(): input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[ input_details["type"] ] @@ -830,15 +728,15 @@ def launch_gradio_demo(tool_class: Tool): gradio_inputs.append(new_component) output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[ - tool_class.output_type + tool.output_type ] - gradio_output = output_gradio_componentclass(label=input_name) + gradio_output = output_gradio_componentclass(label="Output") gr.Interface( - fn=fn, + fn=tool, # This works because `tool` has a __call__ method inputs=gradio_inputs, outputs=gradio_output, - title=tool_class.__name__, + title=tool.name, article=tool.description, ).launch() @@ -1027,29 +925,34 @@ def tool(tool_function: Callable) -> Tool: raise TypeHintParsingException( "Tool return type not found: make sure your function has a return type hint!" 
) - class_name = "".join([el.title() for el in parameters["name"].split("_")]) if parameters["return"]["type"] == "object": parameters["return"]["type"] = "any" - class SpecificTool(Tool): - name = parameters["name"] - description = parameters["description"] - inputs = parameters["parameters"]["properties"] - output_type = parameters["return"]["type"] - - @wraps(tool_function) - def forward(self, *args, **kwargs): - return tool_function(*args, **kwargs) + class SimpleTool(Tool): + def __init__(self, name, description, inputs, output_type, function): + self.name = name + self.description = description + self.inputs = inputs + self.output_type = output_type + self.forward = function + self.is_initialized = True + simple_tool = SimpleTool( + parameters["name"], + parameters["description"], + parameters["parameters"]["properties"], + parameters["return"]["type"], + function=tool_function + ) original_signature = inspect.signature(tool_function) new_parameters = [ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD) ] + list(original_signature.parameters.values()) new_signature = original_signature.replace(parameters=new_parameters) - SpecificTool.forward.__signature__ = new_signature - SpecificTool.__name__ = class_name - return SpecificTool() + simple_tool.forward.__signature__ = new_signature + # SimpleTool.__name__ = "".join([el.title() for el in parameters["name"].split("_")]) + return simple_tool HUGGINGFACE_DEFAULT_TOOLS = {} diff --git a/src/agents/tool_validation.py b/src/agents/tool_validation.py new file mode 100644 index 0000000..7e25afe --- /dev/null +++ b/src/agents/tool_validation.py @@ -0,0 +1,191 @@ +import ast +import inspect +import importlib.util +import builtins +from pathlib import Path +from typing import List, Set, Dict +import textwrap + +_BUILTIN_NAMES = set(vars(builtins)) + +def is_local_import(module_name: str) -> bool: + """ + Check if an import is from a local file or a package. + Returns True if it's a local file import. 
+ """ + try: + spec = importlib.util.find_spec(module_name) + if spec is None: + return True # If we can't find the module, assume it's local + + # If the module is found and has a file path, check if it's in site-packages + if spec.origin and 'site-packages' not in spec.origin: + # Check if it's a .py file in the current directory or subdirectories + return spec.origin.endswith('.py') + + return False + except ImportError: + return True # If there's an import error, assume it's local + +class MethodChecker(ast.NodeVisitor): + """ + Checks that a method + - only uses defined names + - contains no local imports (e.g. numpy is ok but local_script is not) + """ + def __init__(self, class_attributes: Set[str]): + self.undefined_names = set() + self.imports = {} + self.from_imports = {} + self.assigned_names = set() + self.arg_names = set() + self.class_attributes = class_attributes + self.errors = [] + + def visit_arguments(self, node): + """Collect function arguments""" + self.arg_names = {arg.arg for arg in node.args} + if node.kwarg: + self.arg_names.add(node.kwarg.arg) + if node.vararg: + self.arg_names.add(node.vararg.arg) + + def visit_Import(self, node): + for name in node.names: + actual_name = name.asname or name.name + if is_local_import(actual_name): + self.errors.append(f"Local import '{actual_name}'") + self.imports[actual_name] = name.name + + def visit_ImportFrom(self, node): + module = node.module or "" + for name in node.names: + actual_name = name.asname or name.name + if is_local_import(module): + self.errors.append(f"Local import '{module}'") + self.from_imports[actual_name] = (module, name.name) + + def visit_Assign(self, node): + for target in node.targets: + if isinstance(target, ast.Name): + self.assigned_names.add(target.id) + self.visit(node.value) + + def visit_AnnAssign(self, node): + """Track annotated assignments.""" + if isinstance(node.target, ast.Name): + self.assigned_names.add(node.target.id) + if node.value: + self.visit(node.value) 
+ + def visit_For(self, node): + target = node.target + if isinstance(target, ast.Name): + self.assigned_names.add(target.id) + elif isinstance(target, ast.Tuple): + for elt in target.elts: + if isinstance(elt, ast.Name): + self.assigned_names.add(elt.id) + self.generic_visit(node) + + def visit_Attribute(self, node): + # Skip self.something + if not (isinstance(node.value, ast.Name) and node.value.id == "self"): + self.generic_visit(node) + + def visit_Name(self, node): + if isinstance(node.ctx, ast.Load): + if not ( + node.id in _BUILTIN_NAMES + or node.id in self.arg_names + or node.id == "self" + or node.id in self.class_attributes + or node.id in self.imports + or node.id in self.from_imports + or node.id in self.assigned_names + ): + self.errors.append(f"Name '{node.id}' is undefined.") + + def visit_Call(self, node): + if isinstance(node.func, ast.Name): + if not ( + node.func.id in _BUILTIN_NAMES + or node.func.id in self.arg_names + or node.func.id == "self" + or node.func.id in self.class_attributes + or node.func.id in self.imports + or node.func.id in self.from_imports + or node.func.id in self.assigned_names + ): + self.errors.append(f"Name '{node.func.id}' is undefined.") + self.generic_visit(node) + +def validate_tool_attributes(cls) -> None: + """ + Validates that a Tool class follows the proper patterns: + 0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!). + 1. About the class: + - Class attributes should only be strings or dicts + - Class attributes cannot be complex attributes + 2. About all class methods: + - Imports must be from packages, not local files + - All methods must be self-contained + + Raises all errors encountered, if no error returns None. 
+ """ + errors = [] + + source = textwrap.dedent(inspect.getsource(cls)) + + tree = ast.parse(source) + + if not isinstance(tree.body[0], ast.ClassDef): + raise ValueError("Source code must define a class") + + # Check that __init__ method takes no arguments + if not cls.__init__.__qualname__ == 'Tool.__init__': + sig = inspect.signature(cls.__init__) + non_self_params = list([arg_name for arg_name in sig.parameters.keys() if arg_name != "self"]) + if len(non_self_params) > 0: + errors.append(f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!") + + class_node = tree.body[0] + + class ClassLevelChecker(ast.NodeVisitor): + def __init__(self): + self.imported_names = set() + self.complex_attributes = set() + self.class_attributes = set() + + def visit_Assign(self, node): + # Track class attributes + for target in node.targets: + if isinstance(target, ast.Name): + self.class_attributes.add(target.id) + + # Check if the assignment is more complex than simple literals + if not all(isinstance(val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set)) + for val in ast.walk(node.value)): + for target in node.targets: + if isinstance(target, ast.Name): + self.complex_attributes.add(target.id) + + class_level_checker = ClassLevelChecker() + class_level_checker.visit(class_node) + + if class_level_checker.complex_attributes: + errors.append( + f"Complex attributes should be defined in __init__, not as class attributes: " + f"{', '.join(class_level_checker.complex_attributes)}" + ) + + # Run checks on all methods + for node in class_node.body: + if isinstance(node, ast.FunctionDef): + method_checker = MethodChecker(class_level_checker.class_attributes) + method_checker.visit(node) + errors += [f"- {node.name}: {error}" for error in method_checker.errors] + + if errors: + raise ValueError("Tool validation failed:\n" + "\n".join(errors)) + return diff --git a/src/agents/tools/__init__.py 
b/src/agents/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/agents/search.py b/src/agents/tools/search.py similarity index 99% rename from src/agents/search.py rename to src/agents/tools/search.py index c989bc5..01af0cd 100644 --- a/src/agents/search.py +++ b/src/agents/tools/search.py @@ -19,7 +19,7 @@ import re import requests from requests.exceptions import RequestException -from .tools import Tool +from ..tools import Tool class DuckDuckGoSearchTool(Tool): diff --git a/src/agents/utils.py b/src/agents/utils.py index e7d286a..a5fc8ca 100644 --- a/src/agents/utils.py +++ b/src/agents/utils.py @@ -19,6 +19,9 @@ import re from typing import Tuple, Dict, Union import ast from rich.console import Console +import ast +import inspect +import types from transformers.utils.import_utils import _is_package_available @@ -127,5 +130,114 @@ class ImportFinder(ast.NodeVisitor): base_package = node.module.split(".")[0] self.packages.add(base_package) +import ast +import builtins +from typing import Set, Dict, List + + +def get_method_source(method): + """Get source code for a method, including bound methods.""" + if isinstance(method, types.MethodType): + method = method.__func__ + return inspect.getsource(method).strip() + + +def is_same_method(method1, method2): + """Compare two methods by their source code.""" + try: + source1 = get_method_source(method1) + source2 = get_method_source(method2) + + # Remove method decorators if any + source1 = '\n'.join(line for line in source1.split('\n') + if not line.strip().startswith('@')) + source2 = '\n'.join(line for line in source2.split('\n') + if not line.strip().startswith('@')) + + return source1 == source2 + except (TypeError, OSError): + return False + +def is_same_item(item1, item2): + """Compare two class items (methods or attributes) for equality.""" + if callable(item1) and callable(item2): + return is_same_method(item1, item2) + else: + return item1 == item2 + +def 
instance_to_source(instance, base_cls=None): + """Convert an instance to its class source code representation.""" + cls = instance.__class__ + class_name = cls.__name__ + + # Start building class lines + class_lines = [] + if base_cls: + class_lines.append(f"class {class_name}({base_cls.__name__}):") + else: + class_lines.append(f"class {class_name}:") + + # Add docstring if it exists and differs from base + if cls.__doc__ and (not base_cls or cls.__doc__ != base_cls.__doc__): + class_lines.append(f' """{cls.__doc__}"""') + + # Add class-level attributes + class_attrs = { + name: value for name, value in cls.__dict__.items() + if not name.startswith('__') and not callable(value) and + not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value) + } + + for name, value in class_attrs.items(): + if isinstance(value, str): + class_lines.append(f' {name} = "{value}"') + else: + class_lines.append(f' {name} = {repr(value)}') + + if class_attrs: + class_lines.append("") + + # Add methods + methods = { + name: func for name, func in cls.__dict__.items() + if callable(func) and + not (base_cls and hasattr(base_cls, name) and + getattr(base_cls, name).__code__.co_code == func.__code__.co_code) + } + + for name, method in methods.items(): + method_source = inspect.getsource(method) + # Clean up the indentation + method_lines = method_source.split('\n') + first_line = method_lines[0] + indent = len(first_line) - len(first_line.lstrip()) + method_lines = [line[indent:] for line in method_lines] + method_source = '\n'.join([' ' + line if line.strip() else line + for line in method_lines]) + class_lines.append(method_source) + class_lines.append("") + + # Find required imports using ImportFinder + import_finder = ImportFinder() + import_finder.visit(ast.parse('\n'.join(class_lines))) + required_imports = import_finder.packages + + # Build final code with imports + final_lines = [] + + # Add base class import if needed + if base_cls: + final_lines.append(f"from 
{base_cls.__module__} import {base_cls.__name__}") + + # Add discovered imports + final_lines.extend(required_imports) + + if final_lines: # Add empty line after imports + final_lines.append("") + + # Add the class code + final_lines.extend(class_lines) + + return '\n'.join(final_lines) __all__ = [] diff --git a/tests/test_agents.py b/tests/test_agents.py index 539f4cf..93154c0 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -29,7 +29,7 @@ from agents.agents import ( Toolbox, ToolCall, ) -from agents.tools import tool +from agents.tool import tool from agents.default_tools import PythonInterpreterTool from transformers.testing_utils import get_tests_dir diff --git a/tests/test_all_docs.py b/tests/test_all_docs.py index e5b46ff..a3feb98 100644 --- a/tests/test_all_docs.py +++ b/tests/test_all_docs.py @@ -125,7 +125,7 @@ class TestDocs: "from_langchain", ] code_blocks = [ - block.replace("", self.hf_token) + block.replace("", self.hf_token).replace("{your_username}", "m-ric") for block in code_blocks if not any( [snippet in block for snippet in excluded_snippets] diff --git a/tests/test_tools_common.py b/tests/test_tools_common.py index ae3b622..7d3b730 100644 --- a/tests/test_tools_common.py +++ b/tests/test_tools_common.py @@ -26,7 +26,7 @@ from agents.types import ( AgentImage, AgentText, ) -from agents.tools import Tool, tool, AUTHORIZED_TYPES +from agents.tool import Tool, tool, AUTHORIZED_TYPES from transformers.testing_utils import get_tests_dir @@ -174,6 +174,8 @@ class ToolTests(unittest.TestCase): Gets the current time. 
""" return str(datetime.now()) + + get_current_time.save("output") assert "datetime" in str(e) @@ -188,6 +190,9 @@ class ToolTests(unittest.TestCase): def forward(self): return str(datetime.now()) + + get_current_time = GetCurrentTimeTool() + get_current_time.save("output") assert "datetime" in str(e) @@ -210,3 +215,63 @@ class ToolTests(unittest.TestCase): def forward(self): from datetime import datetime return str(datetime.now()) + + def test_saving_tool_allows_no_arg_in_init(self): + # Test one cannot save tool with additional args in init + class FailTool(Tool): + name = "specific" + description = "test description" + inputs = {"input_str": {"type": "string", "description": "input description"}} + output_type = "string" + + def __init__(self, url): + super().__init__(self) + self.url = "none" + + def forward(self, string_input): + return self.url + string_input + + fail_tool = FailTool("dummy_url") + with pytest.raises(Exception) as e: + fail_tool.save('output') + assert '__init__' in str(e) + + def test_saving_tool_allows_no_imports_from_outside_methods(self): + # Test that using imports from outside functions fails + from numpy import random + class FailTool2(Tool): + name = "specific" + description = "test description" + inputs = {"input_str": {"type": "string", "description": "input description"}} + output_type = "string" + + def useless_method(self): + self.client = random.random() + return "" + + def forward(self, string_input): + return self.useless_method() + string_input + + fail_tool_2 = FailTool2() + with pytest.raises(Exception) as e: + fail_tool_2.save('output') + assert 'random' in str(e) + + # Test that putting these imports inside functions works + + class FailTool3(Tool): + name = "specific" + description = "test description" + inputs = {"input_str": {"type": "string", "description": "input description"}} + output_type = "string" + + def useless_method(self): + from numpy import random + self.client = random.random() + return "" + + def 
forward(self, string_input): + return self.useless_method() + string_input + + fail_tool_3 = FailTool3() + fail_tool_3.save('output')