diff --git a/src/agents/tools.py b/src/agents/tools.py index b7a11d4..fe10b80 100644 --- a/src/agents/tools.py +++ b/src/agents/tools.py @@ -22,6 +22,7 @@ import io import json import os import tempfile +import textwrap from functools import lru_cache, wraps from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union @@ -115,12 +116,77 @@ def validate_after_init(cls, do_validate_forward: bool = True): @wraps(original_init) def new_init(self, *args, **kwargs): original_init(self, *args, **kwargs) - if not isinstance(self, PipelineTool): - self.validate_arguments(do_validate_forward=do_validate_forward) + self.validate_arguments(do_validate_forward=do_validate_forward) cls.__init__ = new_init return cls +def validate_forward_method_args(cls): + """Validates that all names in forward method are properly defined. + In particular it will check that all imports are done within the function.""" + if 'forward' not in cls.__dict__: + return + + forward = cls.__dict__['forward'] + source_code = textwrap.dedent(inspect.getsource(forward)) + tree = ast.parse(source_code) + + # Get function arguments + func_node = tree.body[0] + arg_names = {arg.arg for arg in func_node.args.args} + + + import builtins + builtin_names = set(vars(builtins)) + + + # Find all used names that aren't arguments or self attributes + class NameChecker(ast.NodeVisitor): + def __init__(self): + self.undefined_names = set() + self.imports = {} + self.from_imports = {} + + def visit_Import(self, node): + """Handle simple imports like 'import datetime'.""" + for name in node.names: + actual_name = name.asname or name.name + self.imports[actual_name] = (name.name, actual_name) + + def visit_ImportFrom(self, node): + """Handle from imports like 'from datetime import datetime'.""" + module = node.module or '' + for name in node.names: + actual_name = name.asname or name.name + self.from_imports[actual_name] = (module, name.name, actual_name) + + def visit_Name(self, node): + if 
(isinstance(node.ctx, ast.Load) and not ( + node.id == "tool" or + node.id in builtin_names or + node.id in arg_names or + node.id == 'self' + )): + if node.id not in self.from_imports and node.id not in self.imports: + self.undefined_names.add(node.id) + + def visit_Attribute(self, node): + # Skip self.something + if not (isinstance(node.value, ast.Name) and node.value.id == 'self'): + self.generic_visit(node) + + checker = NameChecker() + checker.visit(tree) + + if checker.undefined_names: + raise ValueError( + f"""The following names in forward method are not defined: {', '.join(checker.undefined_names)}. + Make sure all imports and variables are defined within the method. + For instance: + def forward(self): import datetime; return str(datetime.datetime.now()) + """ + ) + AUTHORIZED_TYPES = [ "string", "boolean", @@ -136,7 +202,7 @@ CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"} class Tool: """ - A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the + A base class for the functions used by the agent. Subclass this and implement the `forward` method as well as the following class attributes: - **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it @@ -151,7 +217,7 @@ - **output_type** (`type`) -- The type of the tool output. This is used by `launch_gradio_demo` or to make a nice space from your tool, and also can be used in the generated description for your tool. - You can also override the method [`~Tool.setup`] if your tool as an expensive operation to perform before being + You can also override the method [`~Tool.setup`] if your tool has an expensive operation to perform before being usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at instantiation. 
""" @@ -166,8 +232,10 @@ class Tool: def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) + validate_forward_method_args(cls) validate_after_init(cls, do_validate_forward=False) + def validate_arguments(self, do_validate_forward: bool = True): required_attributes = { "description": str, @@ -198,17 +266,18 @@ class Tool: assert getattr(self, "output_type", None) in AUTHORIZED_TYPES if do_validate_forward: - if not isinstance(self, PipelineTool): - signature = inspect.signature(self.forward) - if not set(signature.parameters.keys()) == set(self.inputs.keys()): - raise Exception( - "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." - ) + signature = inspect.signature(self.forward) + if not set(signature.parameters.keys()) == set(self.inputs.keys()): + raise Exception( + "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." + ) def forward(self, *args, **kwargs): - return NotImplemented("Write this method in your subclass of `Tool`.") + return NotImplementedError("Write this method in your subclass of `Tool`.") def __call__(self, *args, **kwargs): + if not self.is_initialized: + self.setup() args, kwargs = handle_agent_inputs(*args, **kwargs) outputs = self.forward(*args, **kwargs) return handle_agent_outputs(outputs, self.output_type) @@ -225,7 +294,6 @@ class Tool: Saves the relevant code files for your tool so it can be pushed to the Hub. 
This will copy the code of your tool in `output_dir` as well as autogenerate: - - a config file named `tool_config.json` - an `app.py` file so that your tool can be converted to a space - a `requirements.txt` containing the names of the module used by your tool (as detected when inspecting its code) @@ -677,166 +745,6 @@ def compile_jinja_template(template): return jinja_env.from_string(template) -class PipelineTool(Tool): - """ - A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will - need to specify: - - - **model_class** (`type`) -- The class to use to load the model in this tool. - - **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one. - - **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the - pre-processor - - **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the - post-processor (when different from the pre-processor). - - Args: - model (`str` or [`PreTrainedModel`], *optional*): - The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the - value of the class attribute `default_checkpoint`. - pre_processor (`str` or `Any`, *optional*): - The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a - tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if - unset. - post_processor (`str` or `Any`, *optional*): - The name of the checkpoint to use for the post-processor, or the instantiated pre-processor (can be a - tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if - unset. - device (`int`, `str` or `torch.device`, *optional*): - The device on which to execute the model. 
Will default to any accelerator available (GPU, MPS etc...), the - CPU otherwise. - device_map (`str` or `dict`, *optional*): - If passed along, will be used to instantiate the model. - model_kwargs (`dict`, *optional*): - Any keyword argument to send to the model instantiation. - token (`str`, *optional*): - The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when - running `huggingface-cli login` (stored in `~/.huggingface`). - hub_kwargs (additional keyword arguments, *optional*): - Any additional keyword argument to send to the methods that will load the data from the Hub. - """ - - pre_processor_class = AutoProcessor - model_class = None - post_processor_class = AutoProcessor - default_checkpoint = None - description = "This is a pipeline tool" - name = "pipeline" - inputs = {"prompt": str} - output_type = str - - def __init__( - self, - model=None, - pre_processor=None, - post_processor=None, - device=None, - device_map=None, - model_kwargs=None, - token=None, - **hub_kwargs, - ): - if not is_torch_available(): - raise ImportError("Please install torch in order to use this tool.") - - if not is_accelerate_available(): - raise ImportError("Please install accelerate in order to use this tool.") - - if model is None: - if self.default_checkpoint is None: - raise ValueError( - "This tool does not implement a default checkpoint, you need to pass one." 
- ) - model = self.default_checkpoint - if pre_processor is None: - pre_processor = model - - self.model = model - self.pre_processor = pre_processor - self.post_processor = post_processor - self.device = device - self.device_map = device_map - self.model_kwargs = {} if model_kwargs is None else model_kwargs - if device_map is not None: - self.model_kwargs["device_map"] = device_map - self.hub_kwargs = hub_kwargs - self.hub_kwargs["token"] = token - - super().__init__() - - def setup(self): - """ - Instantiates the `pre_processor`, `model` and `post_processor` if necessary. - """ - if isinstance(self.pre_processor, str): - self.pre_processor = self.pre_processor_class.from_pretrained( - self.pre_processor, **self.hub_kwargs - ) - - if isinstance(self.model, str): - self.model = self.model_class.from_pretrained( - self.model, **self.model_kwargs, **self.hub_kwargs - ) - - if self.post_processor is None: - self.post_processor = self.pre_processor - elif isinstance(self.post_processor, str): - self.post_processor = self.post_processor_class.from_pretrained( - self.post_processor, **self.hub_kwargs - ) - - if self.device is None: - if self.device_map is not None: - self.device = list(self.model.hf_device_map.values())[0] - else: - self.device = PartialState().default_device - - if self.device_map is None: - self.model.to(self.device) - - super().setup() - - def encode(self, raw_inputs): - """ - Uses the `pre_processor` to prepare the inputs for the `model`. - """ - return self.pre_processor(raw_inputs) - - def forward(self, inputs): - """ - Sends the inputs through the `model`. - """ - with torch.no_grad(): - return self.model(**inputs) - - def decode(self, outputs): - """ - Uses the `post_processor` to decode the model output. 
- """ - return self.post_processor(outputs) - - def __call__(self, *args, **kwargs): - args, kwargs = handle_agent_inputs(*args, **kwargs) - - if not self.is_initialized: - self.setup() - - encoded_inputs = self.encode(*args, **kwargs) - - tensor_inputs = { - k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor) - } - non_tensor_inputs = { - k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor) - } - - encoded_inputs = send_to_device(tensor_inputs, self.device) - outputs = self.forward({**encoded_inputs, **non_tensor_inputs}) - outputs = send_to_device(outputs, "cpu") - decoded_outputs = self.decode(outputs) - - return handle_agent_outputs(decoded_outputs, self.output_type) - - def launch_gradio_demo(tool_class: Tool): """ Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes @@ -1060,6 +968,8 @@ def tool(tool_function: Callable) -> Tool: "Tool return type not found: make sure your function has a return type hint!" 
) class_name = f"{parameters['name'].capitalize()}Tool" + if parameters["return"]["type"] == "object": + parameters["return"]["type"] = "any" class SpecificTool(Tool): name = parameters["name"] @@ -1185,4 +1095,4 @@ class Toolbox: toolbox_description += f"\t{tool.name}: {tool.description}\n" return toolbox_description -__all__ = ["AUTHORIZED_TYPES", "Tool", "tool", "load_tool", "launch_gradio_demo", "Toolbox"] \ No newline at end of file +__all__ = ["AUTHORIZED_TYPES", "Tool", "tool", "load_tool", "launch_gradio_demo", "Toolbox"] diff --git a/tests/test_tools_common.py b/tests/test_tools_common.py index 260aae2..36c70be 100644 --- a/tests/test_tools_common.py +++ b/tests/test_tools_common.py @@ -162,3 +162,44 @@ class ToolTests(unittest.TestCase): assert coolfunc.output_type == "number" assert "docstring has no description for the argument" in str(e) + + def test_tool_definition_needs_imports_in_function(self): + with pytest.raises(Exception) as e: + from datetime import datetime + @tool + def get_current_time() -> str: + """ + Gets the current time. + """ + return str(datetime.now()) + assert "datetime" in str(e) + + # Also test with classic definition + with pytest.raises(Exception) as e: + class GetCurrentTimeTool(Tool): + name="get_current_time_tool" + description="Gets the current time" + inputs = {} + output_type = "string" + + def forward(self): + return str(datetime.now()) + assert "datetime" in str(e) + + @tool + def get_current_time() -> str: + """ + Gets the current time. + """ + from datetime import datetime + return str(datetime.now()) + + class GetCurrentTimeTool(Tool): + name="get_current_time_tool" + description="Gets the current time" + inputs = {} + output_type = "string" + + def forward(self): + from datetime import datetime + return str(datetime.now())