Force imports inside tool

Aymeric 2024-12-15 14:30:00 +01:00
parent b6fc583d96
commit aef0510e68
2 changed files with 124 additions and 173 deletions

View File

@ -22,6 +22,7 @@ import io
import json
import os
import tempfile
import textwrap
from functools import lru_cache, wraps
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
@ -115,12 +116,77 @@ def validate_after_init(cls, do_validate_forward: bool = True):
    @wraps(original_init)
    def new_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        self.validate_arguments(do_validate_forward=do_validate_forward)

    cls.__init__ = new_init
    return cls
def validate_forward_method_args(cls):
    """Validates that all names in forward method are properly defined.
    In particular it will check that all imports are done within the function."""
    if 'forward' not in cls.__dict__:
        return

    forward = cls.__dict__['forward']

    source_code = textwrap.dedent(inspect.getsource(forward))
    tree = ast.parse(source_code)

    # Get function arguments
    func_node = tree.body[0]
    arg_names = {arg.arg for arg in func_node.args.args}

    import builtins
    builtin_names = set(vars(builtins))

    # Find all used names that aren't arguments or self attributes
    class NameChecker(ast.NodeVisitor):
        def __init__(self):
            self.undefined_names = set()
            self.imports = {}
            self.from_imports = {}

        def visit_Import(self, node):
            """Handle simple imports like 'import datetime'."""
            for name in node.names:
                actual_name = name.asname or name.name
                self.imports[actual_name] = (name.name, actual_name)

        def visit_ImportFrom(self, node):
            """Handle from imports like 'from datetime import datetime'."""
            module = node.module or ''
            for name in node.names:
                actual_name = name.asname or name.name
                self.from_imports[actual_name] = (module, name.name, actual_name)

        def visit_Name(self, node):
            if (isinstance(node.ctx, ast.Load) and not (
                node.id == "tool" or
                node.id in builtin_names or
                node.id in arg_names or
                node.id == 'self'
            )):
                if node.id not in self.from_imports and node.id not in self.imports:
                    self.undefined_names.add(node.id)

        def visit_Attribute(self, node):
            # Skip self.something
            if not (isinstance(node.value, ast.Name) and node.value.id == 'self'):
                self.generic_visit(node)

    checker = NameChecker()
    checker.visit(tree)

    if checker.undefined_names:
        raise ValueError(
            f"""The following names in forward method are not defined: {', '.join(checker.undefined_names)}.
Make sure all imports and variables are defined within the method.
For instance:
"""
        )
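To make the new check concrete, here is a small sketch that is not part of this commit (the tool name, inputs schema and class are illustrative): a `forward` whose names are all arguments, builtins, or imported inside its own body passes validation, while a name resolved from a module-level import ends up in the ValueError above.

# Passes the check: `text` is an argument, `len` is a builtin,
# and `re` is imported inside forward itself.
class WordCountTool(Tool):
    name = "word_count"
    description = "Counts the words in a text"
    inputs = {"text": {"type": "string", "description": "Text to analyze"}}
    output_type = "integer"

    def forward(self, text):
        import re
        return len(re.findall(r"\w+", text))

# By contrast, a forward that refers to a module-level import, for example
# `return np.mean(values)` with `import numpy as np` at the top of the file,
# would fail at class-definition time with the ValueError above listing `np`.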
AUTHORIZED_TYPES = [
    "string",
    "boolean",
@ -136,7 +202,7 @@ CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
class Tool:
    """
    A base class for the functions used by the agent. Subclass this and implement the `forward` method as well as the
    following class attributes:

    - **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it
@ -151,7 +217,7 @@ class Tool:
    - **output_type** (`type`) -- The type of the tool output. This is used by `launch_gradio_demo`
      or to make a nice space from your tool, and also can be used in the generated description for your tool.

    You can also override the method [`~Tool.setup`] if your tool has an expensive operation to perform before being
    usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
    instantiation.
    """
@ -166,8 +232,10 @@ class Tool:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        validate_forward_method_args(cls)
        validate_after_init(cls, do_validate_forward=False)
    def validate_arguments(self, do_validate_forward: bool = True):
        required_attributes = {
            "description": str,
@ -198,17 +266,18 @@ class Tool:
        assert getattr(self, "output_type", None) in AUTHORIZED_TYPES

        if do_validate_forward:
            signature = inspect.signature(self.forward)
            if not set(signature.parameters.keys()) == set(self.inputs.keys()):
                raise Exception(
                    "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
                )
    def forward(self, *args, **kwargs):
        return NotImplementedError("Write this method in your subclass of `Tool`.")
    def __call__(self, *args, **kwargs):
        if not self.is_initialized:
            self.setup()
        args, kwargs = handle_agent_inputs(*args, **kwargs)
        outputs = self.forward(*args, **kwargs)
        return handle_agent_outputs(outputs, self.output_type)
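One consequence of `validate_arguments` above: when forward validation is enabled, the parameter names of `forward` must match the keys of `inputs` exactly. A small sketch, not part of this commit (the tool and its inputs schema are illustrative):

class DistanceTool(Tool):
    name = "distance"
    description = "Computes the absolute difference between two numbers"
    inputs = {
        "a": {"type": "number", "description": "First number"},
        "b": {"type": "number", "description": "Second number"},
    }
    output_type = "number"

    # Parameter names line up with the keys of `inputs`; `abs` is a builtin,
    # so the forward-imports check is satisfied as well.
    def forward(self, a, b):
        return abs(a - b)

A definition such as `def forward(self, x, y)` would instead trigger the Exception above about matching the keys of `inputs`.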
@ -225,7 +294,6 @@ class Tool:
        Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
        tool in `output_dir` as well as autogenerate:

        - a config file named `tool_config.json`
        - an `app.py` file so that your tool can be converted to a space
        - a `requirements.txt` containing the names of the modules used by your tool (as detected when inspecting its
          code)
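For instance, assuming a tool instance named `counter` (hypothetical, as is the output directory name), saving looks like:

counter.save("word_count_tool")
# ./word_count_tool/ now contains the tool's code plus the autogenerated app.py and requirements.txt

Per the docstring, `output_dir` is where everything is written.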
@ -677,166 +745,6 @@ def compile_jinja_template(template):
    return jinja_env.from_string(template)
class PipelineTool(Tool):
    """
    A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
    need to specify:

    - **model_class** (`type`) -- The class to use to load the model in this tool.
    - **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
    - **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
      pre-processor
    - **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
      post-processor (when different from the pre-processor).

    Args:
        model (`str` or [`PreTrainedModel`], *optional*):
            The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
            value of the class attribute `default_checkpoint`.
        pre_processor (`str` or `Any`, *optional*):
            The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a
            tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if
            unset.
        post_processor (`str` or `Any`, *optional*):
            The name of the checkpoint to use for the post-processor, or the instantiated pre-processor (can be a
            tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if
            unset.
        device (`int`, `str` or `torch.device`, *optional*):
            The device on which to execute the model. Will default to any accelerator available (GPU, MPS etc...), the
            CPU otherwise.
        device_map (`str` or `dict`, *optional*):
            If passed along, will be used to instantiate the model.
        model_kwargs (`dict`, *optional*):
            Any keyword argument to send to the model instantiation.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
            running `huggingface-cli login` (stored in `~/.huggingface`).
        hub_kwargs (additional keyword arguments, *optional*):
            Any additional keyword argument to send to the methods that will load the data from the Hub.
    """

    pre_processor_class = AutoProcessor
    model_class = None
    post_processor_class = AutoProcessor
    default_checkpoint = None
    description = "This is a pipeline tool"
    name = "pipeline"
    inputs = {"prompt": str}
    output_type = str

    def __init__(
        self,
        model=None,
        pre_processor=None,
        post_processor=None,
        device=None,
        device_map=None,
        model_kwargs=None,
        token=None,
        **hub_kwargs,
    ):
        if not is_torch_available():
            raise ImportError("Please install torch in order to use this tool.")
        if not is_accelerate_available():
            raise ImportError("Please install accelerate in order to use this tool.")

        if model is None:
            if self.default_checkpoint is None:
                raise ValueError(
                    "This tool does not implement a default checkpoint, you need to pass one."
                )
            model = self.default_checkpoint
        if pre_processor is None:
            pre_processor = model

        self.model = model
        self.pre_processor = pre_processor
        self.post_processor = post_processor
        self.device = device
        self.device_map = device_map
        self.model_kwargs = {} if model_kwargs is None else model_kwargs
        if device_map is not None:
            self.model_kwargs["device_map"] = device_map
        self.hub_kwargs = hub_kwargs
        self.hub_kwargs["token"] = token

        super().__init__()

    def setup(self):
        """
        Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
        """
        if isinstance(self.pre_processor, str):
            self.pre_processor = self.pre_processor_class.from_pretrained(
                self.pre_processor, **self.hub_kwargs
            )

        if isinstance(self.model, str):
            self.model = self.model_class.from_pretrained(
                self.model, **self.model_kwargs, **self.hub_kwargs
            )

        if self.post_processor is None:
            self.post_processor = self.pre_processor
        elif isinstance(self.post_processor, str):
            self.post_processor = self.post_processor_class.from_pretrained(
                self.post_processor, **self.hub_kwargs
            )

        if self.device is None:
            if self.device_map is not None:
                self.device = list(self.model.hf_device_map.values())[0]
            else:
                self.device = PartialState().default_device

        if self.device_map is None:
            self.model.to(self.device)

        super().setup()

    def encode(self, raw_inputs):
        """
        Uses the `pre_processor` to prepare the inputs for the `model`.
        """
        return self.pre_processor(raw_inputs)

    def forward(self, inputs):
        """
        Sends the inputs through the `model`.
        """
        with torch.no_grad():
            return self.model(**inputs)

    def decode(self, outputs):
        """
        Uses the `post_processor` to decode the model output.
        """
        return self.post_processor(outputs)

    def __call__(self, *args, **kwargs):
        args, kwargs = handle_agent_inputs(*args, **kwargs)

        if not self.is_initialized:
            self.setup()

        encoded_inputs = self.encode(*args, **kwargs)

        tensor_inputs = {
            k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)
        }
        non_tensor_inputs = {
            k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)
        }

        encoded_inputs = send_to_device(tensor_inputs, self.device)
        outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
        outputs = send_to_device(outputs, "cpu")
        decoded_outputs = self.decode(outputs)

        return handle_agent_outputs(decoded_outputs, self.output_type)
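For reference (this class is removed by the commit, and every name below is a placeholder rather than a real checkpoint or model class), a subclass was expected to fill in the attributes its docstring lists, overriding `encode`, `forward` and `decode` where the defaults do not fit:

class MyGenerationTool(PipelineTool):
    default_checkpoint = "org/some-checkpoint"      # placeholder checkpoint id
    model_class = AutoModelForSeq2SeqLM             # assumed auto class from transformers
    description = "Runs the checkpoint on a prompt"
    name = "generator"
    inputs = {"prompt": str}
    output_type = str

    def encode(self, prompt):
        return self.pre_processor(prompt, return_tensors="pt")

    def forward(self, inputs):
        # Generation models need generate() rather than a plain forward pass.
        return self.model.generate(**inputs)

    def decode(self, outputs):
        return self.post_processor.decode(outputs[0], skip_special_tokens=True)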
def launch_gradio_demo(tool_class: Tool):
    """
    Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
@ -1060,6 +968,8 @@ def tool(tool_function: Callable) -> Tool:
"Tool return type not found: make sure your function has a return type hint!" "Tool return type not found: make sure your function has a return type hint!"
) )
class_name = f"{parameters['name'].capitalize()}Tool" class_name = f"{parameters['name'].capitalize()}Tool"
if parameters["return"]["type"] == "object":
parameters["return"]["type"] = "any"
class SpecificTool(Tool): class SpecificTool(Tool):
name = parameters["name"] name = parameters["name"]
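The `@tool` decorator path applies the same constraints: a return type hint (checked just above), a docstring that describes every argument, and imports placed inside the function body. A hedged sketch with a made-up function (the docstring format and inferred output type are assumptions):

@tool
def file_size(path: str) -> int:
    """
    Returns the size of a file in bytes.

    Args:
        path: Path to the file to inspect.
    """
    import os  # imported inside the function, so the generated forward stays self-contained
    return os.path.getsize(path)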
@ -1185,4 +1095,4 @@ class Toolbox:
            toolbox_description += f"\t{tool.name}: {tool.description}\n"
        return toolbox_description


__all__ = ["AUTHORIZED_TYPES", "Tool", "tool", "load_tool", "launch_gradio_demo", "Toolbox"]

View File

@ -162,3 +162,44 @@ class ToolTests(unittest.TestCase):
        assert coolfunc.output_type == "number"
        assert "docstring has no description for the argument" in str(e)
    def test_tool_definition_needs_imports_in_function(self):
        with pytest.raises(Exception) as e:
            from datetime import datetime

            @tool
            def get_current_time() -> str:
                """
                Gets the current time.
                """
                return str(datetime.now())

        assert "datetime" in str(e)

        # Also test with classic definition
        with pytest.raises(Exception) as e:

            class GetCurrentTimeTool(Tool):
                name="get_current_time_tool"
                description="Gets the current time"
                inputs = {}
                output_type = "string"

                def forward(self):
                    return str(datetime.now())

        assert "datetime" in str(e)

        @tool
        def get_current_time() -> str:
            """
            Gets the current time.
            """
            from datetime import datetime

            return str(datetime.now())

        class GetCurrentTimeTool(Tool):
            name="get_current_time_tool"
            description="Gets the current time"
            inputs = {}
            output_type = "string"

            def forward(self):
                from datetime import datetime

                return str(datetime.now())