Force imports inside tool
This commit is contained in:
parent
b6fc583d96
commit
aef0510e68
|
@ -22,6 +22,7 @@ import io
|
|||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import textwrap
|
||||
from functools import lru_cache, wraps
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
@ -115,12 +116,77 @@ def validate_after_init(cls, do_validate_forward: bool = True):
|
|||
@wraps(original_init)
|
||||
def new_init(self, *args, **kwargs):
|
||||
original_init(self, *args, **kwargs)
|
||||
if not isinstance(self, PipelineTool):
|
||||
self.validate_arguments(do_validate_forward=do_validate_forward)
|
||||
self.validate_arguments(do_validate_forward=do_validate_forward)
|
||||
|
||||
cls.__init__ = new_init
|
||||
return cls
|
||||
|
||||
def validate_forward_method_args(cls):
|
||||
"""Validates that all names in forward method are properly defined.
|
||||
In particular it will check that all imports are done within the function."""
|
||||
if 'forward' not in cls.__dict__:
|
||||
return
|
||||
|
||||
forward = cls.__dict__['forward']
|
||||
source_code = textwrap.dedent(inspect.getsource(forward))
|
||||
tree = ast.parse(source_code)
|
||||
|
||||
# Get function arguments
|
||||
func_node = tree.body[0]
|
||||
arg_names = {arg.arg for arg in func_node.args.args}
|
||||
|
||||
|
||||
import builtins
|
||||
builtin_names = set(vars(builtins))
|
||||
|
||||
|
||||
# Find all used names that aren't arguments or self attributes
|
||||
class NameChecker(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
self.undefined_names = set()
|
||||
self.imports = {}
|
||||
self.from_imports = {}
|
||||
|
||||
def visit_Import(self, node):
|
||||
"""Handle simple imports like 'import datetime'."""
|
||||
for name in node.names:
|
||||
actual_name = name.asname or name.name
|
||||
self.imports[actual_name] = (name.name, actual_name)
|
||||
|
||||
def visit_ImportFrom(self, node):
|
||||
"""Handle from imports like 'from datetime import datetime'."""
|
||||
module = node.module or ''
|
||||
for name in node.names:
|
||||
actual_name = name.asname or name.name
|
||||
self.from_imports[actual_name] = (module, name.name, actual_name)
|
||||
|
||||
def visit_Name(self, node):
|
||||
if (isinstance(node.ctx, ast.Load) and not (
|
||||
node.id == "tool" or
|
||||
node.id in builtin_names or
|
||||
node.id in arg_names or
|
||||
node.id == 'self'
|
||||
)):
|
||||
if node.id not in self.from_imports and node.id not in self.imports:
|
||||
self.undefined_names.add(node.id)
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
# Skip self.something
|
||||
if not (isinstance(node.value, ast.Name) and node.value.id == 'self'):
|
||||
self.generic_visit(node)
|
||||
|
||||
checker = NameChecker()
|
||||
checker.visit(tree)
|
||||
|
||||
if checker.undefined_names:
|
||||
raise ValueError(
|
||||
f"""The following names in forward method are not defined: {', '.join(checker.undefined_names)}.
|
||||
Make sure all imports and variables are defined within the method.
|
||||
For instance:
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
AUTHORIZED_TYPES = [
|
||||
"string",
|
||||
"boolean",
|
||||
|
@ -136,7 +202,7 @@ CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
|
|||
|
||||
class Tool:
|
||||
"""
|
||||
A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
|
||||
A base class for the functions used by the agent. Subclass this and implement the `forward` method as well as the
|
||||
following class attributes:
|
||||
|
||||
- **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it
|
||||
|
@ -151,7 +217,7 @@ class Tool:
|
|||
- **output_type** (`type`) -- The type of the tool output. This is used by `launch_gradio_demo`
|
||||
or to make a nice space from your tool, and also can be used in the generated description for your tool.
|
||||
|
||||
You can also override the method [`~Tool.setup`] if your tool as an expensive operation to perform before being
|
||||
You can also override the method [`~Tool.setup`] if your tool has an expensive operation to perform before being
|
||||
usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
|
||||
instantiation.
|
||||
"""
|
||||
|
@ -166,8 +232,10 @@ class Tool:
|
|||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
validate_forward_method_args(cls)
|
||||
validate_after_init(cls, do_validate_forward=False)
|
||||
|
||||
|
||||
def validate_arguments(self, do_validate_forward: bool = True):
|
||||
required_attributes = {
|
||||
"description": str,
|
||||
|
@ -198,17 +266,18 @@ class Tool:
|
|||
|
||||
assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
|
||||
if do_validate_forward:
|
||||
if not isinstance(self, PipelineTool):
|
||||
signature = inspect.signature(self.forward)
|
||||
if not set(signature.parameters.keys()) == set(self.inputs.keys()):
|
||||
raise Exception(
|
||||
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
|
||||
)
|
||||
signature = inspect.signature(self.forward)
|
||||
if not set(signature.parameters.keys()) == set(self.inputs.keys()):
|
||||
raise Exception(
|
||||
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
|
||||
)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return NotImplemented("Write this method in your subclass of `Tool`.")
|
||||
return NotImplementedError("Write this method in your subclass of `Tool`.")
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if not self.is_initialized:
|
||||
self.setup()
|
||||
args, kwargs = handle_agent_inputs(*args, **kwargs)
|
||||
outputs = self.forward(*args, **kwargs)
|
||||
return handle_agent_outputs(outputs, self.output_type)
|
||||
|
@ -225,7 +294,6 @@ class Tool:
|
|||
Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
|
||||
tool in `output_dir` as well as autogenerate:
|
||||
|
||||
- a config file named `tool_config.json`
|
||||
- an `app.py` file so that your tool can be converted to a space
|
||||
- a `requirements.txt` containing the names of the module used by your tool (as detected when inspecting its
|
||||
code)
|
||||
|
@ -677,166 +745,6 @@ def compile_jinja_template(template):
|
|||
return jinja_env.from_string(template)
|
||||
|
||||
|
||||
class PipelineTool(Tool):
|
||||
"""
|
||||
A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
|
||||
need to specify:
|
||||
|
||||
- **model_class** (`type`) -- The class to use to load the model in this tool.
|
||||
- **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
|
||||
- **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
|
||||
pre-processor
|
||||
- **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
|
||||
post-processor (when different from the pre-processor).
|
||||
|
||||
Args:
|
||||
model (`str` or [`PreTrainedModel`], *optional*):
|
||||
The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
|
||||
value of the class attribute `default_checkpoint`.
|
||||
pre_processor (`str` or `Any`, *optional*):
|
||||
The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a
|
||||
tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if
|
||||
unset.
|
||||
post_processor (`str` or `Any`, *optional*):
|
||||
The name of the checkpoint to use for the post-processor, or the instantiated pre-processor (can be a
|
||||
tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if
|
||||
unset.
|
||||
device (`int`, `str` or `torch.device`, *optional*):
|
||||
The device on which to execute the model. Will default to any accelerator available (GPU, MPS etc...), the
|
||||
CPU otherwise.
|
||||
device_map (`str` or `dict`, *optional*):
|
||||
If passed along, will be used to instantiate the model.
|
||||
model_kwargs (`dict`, *optional*):
|
||||
Any keyword argument to send to the model instantiation.
|
||||
token (`str`, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
|
||||
running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||
hub_kwargs (additional keyword arguments, *optional*):
|
||||
Any additional keyword argument to send to the methods that will load the data from the Hub.
|
||||
"""
|
||||
|
||||
pre_processor_class = AutoProcessor
|
||||
model_class = None
|
||||
post_processor_class = AutoProcessor
|
||||
default_checkpoint = None
|
||||
description = "This is a pipeline tool"
|
||||
name = "pipeline"
|
||||
inputs = {"prompt": str}
|
||||
output_type = str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model=None,
|
||||
pre_processor=None,
|
||||
post_processor=None,
|
||||
device=None,
|
||||
device_map=None,
|
||||
model_kwargs=None,
|
||||
token=None,
|
||||
**hub_kwargs,
|
||||
):
|
||||
if not is_torch_available():
|
||||
raise ImportError("Please install torch in order to use this tool.")
|
||||
|
||||
if not is_accelerate_available():
|
||||
raise ImportError("Please install accelerate in order to use this tool.")
|
||||
|
||||
if model is None:
|
||||
if self.default_checkpoint is None:
|
||||
raise ValueError(
|
||||
"This tool does not implement a default checkpoint, you need to pass one."
|
||||
)
|
||||
model = self.default_checkpoint
|
||||
if pre_processor is None:
|
||||
pre_processor = model
|
||||
|
||||
self.model = model
|
||||
self.pre_processor = pre_processor
|
||||
self.post_processor = post_processor
|
||||
self.device = device
|
||||
self.device_map = device_map
|
||||
self.model_kwargs = {} if model_kwargs is None else model_kwargs
|
||||
if device_map is not None:
|
||||
self.model_kwargs["device_map"] = device_map
|
||||
self.hub_kwargs = hub_kwargs
|
||||
self.hub_kwargs["token"] = token
|
||||
|
||||
super().__init__()
|
||||
|
||||
def setup(self):
|
||||
"""
|
||||
Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
|
||||
"""
|
||||
if isinstance(self.pre_processor, str):
|
||||
self.pre_processor = self.pre_processor_class.from_pretrained(
|
||||
self.pre_processor, **self.hub_kwargs
|
||||
)
|
||||
|
||||
if isinstance(self.model, str):
|
||||
self.model = self.model_class.from_pretrained(
|
||||
self.model, **self.model_kwargs, **self.hub_kwargs
|
||||
)
|
||||
|
||||
if self.post_processor is None:
|
||||
self.post_processor = self.pre_processor
|
||||
elif isinstance(self.post_processor, str):
|
||||
self.post_processor = self.post_processor_class.from_pretrained(
|
||||
self.post_processor, **self.hub_kwargs
|
||||
)
|
||||
|
||||
if self.device is None:
|
||||
if self.device_map is not None:
|
||||
self.device = list(self.model.hf_device_map.values())[0]
|
||||
else:
|
||||
self.device = PartialState().default_device
|
||||
|
||||
if self.device_map is None:
|
||||
self.model.to(self.device)
|
||||
|
||||
super().setup()
|
||||
|
||||
def encode(self, raw_inputs):
|
||||
"""
|
||||
Uses the `pre_processor` to prepare the inputs for the `model`.
|
||||
"""
|
||||
return self.pre_processor(raw_inputs)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""
|
||||
Sends the inputs through the `model`.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
return self.model(**inputs)
|
||||
|
||||
def decode(self, outputs):
|
||||
"""
|
||||
Uses the `post_processor` to decode the model output.
|
||||
"""
|
||||
return self.post_processor(outputs)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
args, kwargs = handle_agent_inputs(*args, **kwargs)
|
||||
|
||||
if not self.is_initialized:
|
||||
self.setup()
|
||||
|
||||
encoded_inputs = self.encode(*args, **kwargs)
|
||||
|
||||
tensor_inputs = {
|
||||
k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)
|
||||
}
|
||||
non_tensor_inputs = {
|
||||
k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)
|
||||
}
|
||||
|
||||
encoded_inputs = send_to_device(tensor_inputs, self.device)
|
||||
outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
|
||||
outputs = send_to_device(outputs, "cpu")
|
||||
decoded_outputs = self.decode(outputs)
|
||||
|
||||
return handle_agent_outputs(decoded_outputs, self.output_type)
|
||||
|
||||
|
||||
def launch_gradio_demo(tool_class: Tool):
|
||||
"""
|
||||
Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
|
||||
|
@ -1060,6 +968,8 @@ def tool(tool_function: Callable) -> Tool:
|
|||
"Tool return type not found: make sure your function has a return type hint!"
|
||||
)
|
||||
class_name = f"{parameters['name'].capitalize()}Tool"
|
||||
if parameters["return"]["type"] == "object":
|
||||
parameters["return"]["type"] = "any"
|
||||
|
||||
class SpecificTool(Tool):
|
||||
name = parameters["name"]
|
||||
|
@ -1185,4 +1095,4 @@ class Toolbox:
|
|||
toolbox_description += f"\t{tool.name}: {tool.description}\n"
|
||||
return toolbox_description
|
||||
|
||||
__all__ = ["AUTHORIZED_TYPES", "Tool", "tool", "load_tool", "launch_gradio_demo", "Toolbox"]
|
||||
__all__ = ["AUTHORIZED_TYPES", "Tool", "tool", "load_tool", "launch_gradio_demo", "Toolbox"]
|
||||
|
|
|
@ -162,3 +162,44 @@ class ToolTests(unittest.TestCase):
|
|||
|
||||
assert coolfunc.output_type == "number"
|
||||
assert "docstring has no description for the argument" in str(e)
|
||||
|
||||
def test_tool_definition_needs_imports_in_function(self):
|
||||
with pytest.raises(Exception) as e:
|
||||
from datetime import datetime
|
||||
@tool
|
||||
def get_current_time() -> str:
|
||||
"""
|
||||
Gets the current time.
|
||||
"""
|
||||
return str(datetime.now())
|
||||
assert "datetime" in str(e)
|
||||
|
||||
# Also test with classic definition
|
||||
with pytest.raises(Exception) as e:
|
||||
class GetCurrentTimeTool(Tool):
|
||||
name="get_current_time_tool"
|
||||
description="Gets the current time"
|
||||
inputs = {}
|
||||
output_type = "string"
|
||||
|
||||
def forward(self):
|
||||
return str(datetime.now())
|
||||
assert "datetime" in str(e)
|
||||
|
||||
@tool
|
||||
def get_current_time() -> str:
|
||||
"""
|
||||
Gets the current time.
|
||||
"""
|
||||
from datetime import datetime
|
||||
return str(datetime.now())
|
||||
|
||||
class GetCurrentTimeTool(Tool):
|
||||
name="get_current_time_tool"
|
||||
description="Gets the current time"
|
||||
inputs = {}
|
||||
output_type = "string"
|
||||
|
||||
def forward(self):
|
||||
from datetime import datetime
|
||||
return str(datetime.now())
|
||||
|
|
Loading…
Reference in New Issue