Support optional arguments in tool calls
This commit is contained in:
parent
93569bd7c1
commit
e5ca0f0cb8
|
@ -44,7 +44,7 @@ This library offers:
|
|||
|
||||
First install the package.
|
||||
```bash
|
||||
pip install agents
|
||||
pip install smolagents
|
||||
```
|
||||
Then define your agent, give it the tools it needs and run it!
|
||||
```py
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# docstyle-ignore
|
||||
INSTALL_CONTENT = """
|
||||
# Transformers installation
|
||||
! pip install agents
|
||||
# Installation
|
||||
! pip install smolagents
|
||||
# To install from source instead of the last release, comment the command above and uncomment the following one.
|
||||
# ! pip install git+https://github.com/huggingface/agents.git
|
||||
# ! pip install git+https://github.com/huggingface/smolagents.git
|
||||
"""
|
||||
|
||||
notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}]
|
||||
|
|
|
@ -1,19 +1,21 @@
|
|||
from smolagents.agents import ToolCallingAgent
|
||||
from smolagents import tool, HfApiModel, TransformersModel, LiteLLMModel
|
||||
from typing import Optional
|
||||
|
||||
# Choose which LLM engine to use!
|
||||
model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct")
|
||||
model = TransformersModel("meta-llama/Llama-3.2-2B-Instruct")
|
||||
# model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct")
|
||||
# model = TransformersModel("meta-llama/Llama-3.2-2B-Instruct")
|
||||
model = LiteLLMModel("gpt-4o")
|
||||
|
||||
@tool
|
||||
def get_weather(location: str) -> str:
|
||||
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
|
||||
"""
|
||||
Get weather in the next days at given location.
|
||||
Secretly this tool does not care about the location, it hates the weather everywhere.
|
||||
|
||||
Args:
|
||||
location: the location
|
||||
celsius: the temperature
|
||||
"""
|
||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||
|
||||
|
|
|
@ -172,6 +172,7 @@ class GoogleSearchTool(Tool):
|
|||
"filter_year": {
|
||||
"type": "integer",
|
||||
"description": "Optionally restrict results to a certain year",
|
||||
"nullable": True,
|
||||
},
|
||||
}
|
||||
output_type = "string"
|
||||
|
@ -209,9 +210,14 @@ class GoogleSearchTool(Tool):
|
|||
raise ValueError(response.json())
|
||||
|
||||
if "organic_results" not in results.keys():
|
||||
raise Exception(
|
||||
f"'organic_results' key not found for query: '{query}'. Use a less restrictive query."
|
||||
)
|
||||
if filter_year is not None:
|
||||
raise Exception(
|
||||
f"'organic_results' key not found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"'organic_results' key not found for query: '{query}'. Use a less restrictive query."
|
||||
)
|
||||
if len(results["organic_results"]) == 0:
|
||||
year_filter_message = (
|
||||
f" with filter year={filter_year}" if filter_year is not None else ""
|
||||
|
|
|
@ -67,9 +67,12 @@ tool_role_conversions = {
|
|||
|
||||
def get_json_schema(tool: Tool) -> Dict:
|
||||
properties = deepcopy(tool.inputs)
|
||||
for value in properties.values():
|
||||
required = []
|
||||
for key, value in properties.items():
|
||||
if value["type"] == "any":
|
||||
value["type"] = "string"
|
||||
if not ("nullable" in value and value["nullable"]):
|
||||
required.append(key)
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
|
@ -78,7 +81,7 @@ def get_json_schema(tool: Tool) -> Dict:
|
|||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": list(tool.inputs.keys()),
|
||||
"required": required,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
@ -260,10 +260,8 @@ Task: "Which city has the highest population: Guangzhou or Shanghai?"
|
|||
Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.
|
||||
Code:
|
||||
```py
|
||||
population_guangzhou = search("Guangzhou population")
|
||||
print("Population Guangzhou:", population_guangzhou)
|
||||
population_shanghai = search("Shanghai population")
|
||||
print("Population Shanghai:", population_shanghai)
|
||||
for city in ["Guangzhou", "Shanghai"]:
|
||||
print(f"Population {city}:", search(f"{city} population")
|
||||
```<end_code>
|
||||
Observation:
|
||||
Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
|
||||
|
@ -278,11 +276,13 @@ final_answer("Shanghai")
|
|||
---
|
||||
Task: "What is the current age of the pope, raised to the power 0.36?"
|
||||
|
||||
Thought: I will use the tool `wiki` to get the age of the pope, then raise it to the power 0.36.
|
||||
Thought: I will use the tool `wiki` to get the age of the pope, and confirm that with a web search.
|
||||
Code:
|
||||
```py
|
||||
pope_age = wiki(query="current pope age")
|
||||
print("Pope age:", pope_age)
|
||||
pope_age_wiki = wiki(query="current pope age")
|
||||
print("Pope age as per wikipedia:", pope_age_wiki)
|
||||
pope_age_search = web_search(query="current pope age")
|
||||
print("Pope age as per google search:", pope_age_search)
|
||||
```<end_code>
|
||||
Observation:
|
||||
Pope age: "The pope Francis is currently 85 years old."
|
||||
|
|
|
@ -85,7 +85,6 @@ class MethodChecker(ast.NodeVisitor):
|
|||
self.generic_visit(node)
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
# Skip self.something
|
||||
if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
|
||||
self.generic_visit(node)
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ import torch
|
|||
import textwrap
|
||||
from functools import lru_cache, wraps
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from typing import Callable, Dict, List, Optional, Union, get_type_hints
|
||||
from huggingface_hub import (
|
||||
create_repo,
|
||||
get_collection,
|
||||
|
@ -42,6 +42,8 @@ from transformers.utils import (
|
|||
is_accelerate_available,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.utils.chat_template_utils import _parse_type_hint
|
||||
|
||||
from transformers.dynamic_module_utils import get_imports
|
||||
from transformers import AutoProcessor
|
||||
|
||||
|
@ -95,17 +97,27 @@ def setup_default_tools():
|
|||
return default_tools
|
||||
|
||||
|
||||
def validate_after_init(cls, do_validate_forward: bool = True):
|
||||
def validate_after_init(cls):
|
||||
original_init = cls.__init__
|
||||
|
||||
@wraps(original_init)
|
||||
def new_init(self, *args, **kwargs):
|
||||
original_init(self, *args, **kwargs)
|
||||
self.validate_arguments(do_validate_forward=do_validate_forward)
|
||||
self.validate_arguments()
|
||||
|
||||
cls.__init__ = new_init
|
||||
return cls
|
||||
|
||||
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
|
||||
type_hints = get_type_hints(func)
|
||||
signature = inspect.signature(func)
|
||||
properties = {}
|
||||
for param_name, param_type in type_hints.items():
|
||||
if param_name != "return":
|
||||
properties[param_name] = _parse_type_hint(param_type)
|
||||
if signature.parameters[param_name].default != inspect.Parameter.empty:
|
||||
properties[param_name]["nullable"] = True
|
||||
return properties
|
||||
|
||||
AUTHORIZED_TYPES = [
|
||||
"string",
|
||||
|
@ -145,7 +157,7 @@ class Tool:
|
|||
|
||||
name: str
|
||||
description: str
|
||||
inputs: Dict[str, Dict[str, Union[str, type]]]
|
||||
inputs: Dict[str, Dict[str, Union[str, type, bool]]]
|
||||
output_type: str
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -153,9 +165,9 @@ class Tool:
|
|||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
validate_after_init(cls, do_validate_forward=False)
|
||||
validate_after_init(cls)
|
||||
|
||||
def validate_arguments(self, do_validate_forward: bool = True):
|
||||
def validate_arguments(self):
|
||||
required_attributes = {
|
||||
"description": str,
|
||||
"name": str,
|
||||
|
@ -184,12 +196,21 @@ class Tool:
|
|||
)
|
||||
|
||||
assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
|
||||
if do_validate_forward:
|
||||
signature = inspect.signature(self.forward)
|
||||
if not set(signature.parameters.keys()) == set(self.inputs.keys()):
|
||||
raise Exception(
|
||||
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
|
||||
)
|
||||
|
||||
# Validate forward function signature
|
||||
signature = inspect.signature(self.forward)
|
||||
|
||||
if not set(signature.parameters.keys()) == set(self.inputs.keys()):
|
||||
raise Exception(
|
||||
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
|
||||
)
|
||||
|
||||
json_schema = _convert_type_hints_to_json_schema(self.forward)
|
||||
for key, value in self.inputs.items():
|
||||
if "nullable" in value:
|
||||
assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
|
||||
if key in json_schema and "nullable" in json_schema[key]:
|
||||
assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return NotImplementedError("Write this method in your subclass of `Tool`.")
|
||||
|
@ -877,9 +898,6 @@ def tool(tool_function: Callable) -> Tool:
|
|||
"Tool return type not found: make sure your function has a return type hint!"
|
||||
)
|
||||
|
||||
if parameters["return"]["type"] == "object":
|
||||
parameters["return"]["type"] = "any"
|
||||
|
||||
class SimpleTool(Tool):
|
||||
def __init__(self, name, description, inputs, output_type, function):
|
||||
self.name = name
|
||||
|
@ -898,11 +916,10 @@ def tool(tool_function: Callable) -> Tool:
|
|||
)
|
||||
original_signature = inspect.signature(tool_function)
|
||||
new_parameters = [
|
||||
inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
||||
inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)
|
||||
] + list(original_signature.parameters.values())
|
||||
new_signature = original_signature.replace(parameters=new_parameters)
|
||||
simple_tool.forward.__signature__ = new_signature
|
||||
# SimpleTool.__name__ = "".join([el.title() for el in parameters["name"].split("_")])
|
||||
return simple_tool
|
||||
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ from smolagents.types import AGENT_TYPE_MAPPING
|
|||
|
||||
from smolagents.default_tools import FinalAnswerTool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
from .test_tools import ToolTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
|
|
|
@ -26,7 +26,7 @@ from smolagents.local_python_executor import (
|
|||
evaluate_python_code,
|
||||
)
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
from .test_tools import ToolTesterMixin
|
||||
|
||||
|
||||
# Fake function we will use as tool
|
||||
|
|
|
@ -17,7 +17,7 @@ import unittest
|
|||
|
||||
from smolagents import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
from .test_tools import ToolTesterMixin
|
||||
|
||||
|
||||
class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
from typing import Dict, Union, Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
@ -126,9 +126,9 @@ class ToolTests(unittest.TestCase):
|
|||
"description": "the task category (such as text-classification, depth-estimation, etc)",
|
||||
}
|
||||
}
|
||||
output_type = "integer"
|
||||
output_type = "string"
|
||||
|
||||
def forward(self, task):
|
||||
def forward(self, task: str) -> str:
|
||||
return "best model"
|
||||
|
||||
tool = HFModelDownloadsTool()
|
||||
|
@ -223,7 +223,7 @@ class ToolTests(unittest.TestCase):
|
|||
name = "specific"
|
||||
description = "test description"
|
||||
inputs = {
|
||||
"input_str": {"type": "string", "description": "input description"}
|
||||
"string_input": {"type": "string", "description": "input description"}
|
||||
}
|
||||
output_type = "string"
|
||||
|
||||
|
@ -231,7 +231,7 @@ class ToolTests(unittest.TestCase):
|
|||
super().__init__(self)
|
||||
self.url = "none"
|
||||
|
||||
def forward(self, string_input):
|
||||
def forward(self, string_input: str) -> str:
|
||||
return self.url + string_input
|
||||
|
||||
fail_tool = FailTool("dummy_url")
|
||||
|
@ -241,46 +241,127 @@ class ToolTests(unittest.TestCase):
|
|||
|
||||
def test_saving_tool_allows_no_imports_from_outside_methods(self):
|
||||
# Test that using imports from outside functions fails
|
||||
from numpy import random
|
||||
import numpy as np
|
||||
|
||||
class FailTool2(Tool):
|
||||
class FailTool(Tool):
|
||||
name = "specific"
|
||||
description = "test description"
|
||||
inputs = {
|
||||
"input_str": {"type": "string", "description": "input description"}
|
||||
"string_input": {"type": "string", "description": "input description"}
|
||||
}
|
||||
output_type = "string"
|
||||
|
||||
def useless_method(self):
|
||||
self.client = random.random()
|
||||
self.client = np.random.random()
|
||||
return ""
|
||||
|
||||
def forward(self, string_input):
|
||||
return self.useless_method() + string_input
|
||||
|
||||
fail_tool_2 = FailTool2()
|
||||
fail_tool = FailTool()
|
||||
with pytest.raises(Exception) as e:
|
||||
fail_tool_2.save("output")
|
||||
assert "random" in str(e)
|
||||
fail_tool.save("output")
|
||||
assert "'np' is undefined" in str(e)
|
||||
|
||||
# Test that putting these imports inside functions works
|
||||
|
||||
class FailTool3(Tool):
|
||||
class SuccessTool(Tool):
|
||||
name = "specific"
|
||||
description = "test description"
|
||||
inputs = {
|
||||
"input_str": {"type": "string", "description": "input description"}
|
||||
"string_input": {"type": "string", "description": "input description"}
|
||||
}
|
||||
output_type = "string"
|
||||
|
||||
def useless_method(self):
|
||||
from numpy import random
|
||||
import numpy as np
|
||||
|
||||
self.client = random.random()
|
||||
self.client = np.random.random()
|
||||
return ""
|
||||
|
||||
def forward(self, string_input):
|
||||
return self.useless_method() + string_input
|
||||
|
||||
fail_tool_3 = FailTool3()
|
||||
fail_tool_3.save("output")
|
||||
success_tool = SuccessTool()
|
||||
success_tool.save("output")
|
||||
|
||||
def test_tool_missing_class_attributes_raises_error(self):
|
||||
with pytest.raises(Exception) as e:
|
||||
class GetWeatherTool(Tool):
|
||||
name = "get_weather"
|
||||
description = "Get weather in the next days at given location."
|
||||
inputs = {
|
||||
"location": {"type": "string", "description": "the location"},
|
||||
"celsius": {"type": "string", "description": "the temperature type"}
|
||||
}
|
||||
|
||||
def forward(self, location: str, celsius: Optional[bool] = False) -> str:
|
||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||
|
||||
tool = GetWeatherTool()
|
||||
assert "You must set an attribute output_type" in str(e)
|
||||
|
||||
def test_tool_from_decorator_optional_args(self):
|
||||
@tool
|
||||
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
|
||||
"""
|
||||
Get weather in the next days at given location.
|
||||
Secretly this tool does not care about the location, it hates the weather everywhere.
|
||||
|
||||
Args:
|
||||
location: the location
|
||||
celsius: the temperature type
|
||||
"""
|
||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||
|
||||
assert "nullable" in get_weather.inputs["celsius"]
|
||||
assert get_weather.inputs["celsius"]["nullable"] == True
|
||||
assert "nullable" not in get_weather.inputs["location"]
|
||||
|
||||
def test_tool_mismatching_nullable_args_raises_error(self):
|
||||
with pytest.raises(Exception) as e:
|
||||
class GetWeatherTool(Tool):
|
||||
name = "get_weather"
|
||||
description = "Get weather in the next days at given location."
|
||||
inputs = {
|
||||
"location": {"type": "string", "description": "the location"},
|
||||
"celsius": {"type": "string", "description": "the temperature type"}
|
||||
}
|
||||
output_type = "string"
|
||||
|
||||
def forward(self, location: str, celsius: Optional[bool] = False) -> str:
|
||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||
|
||||
tool = GetWeatherTool()
|
||||
assert "Nullable" in str(e)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
class GetWeatherTool2(Tool):
|
||||
name = "get_weather"
|
||||
description = "Get weather in the next days at given location."
|
||||
inputs = {
|
||||
"location": {"type": "string", "description": "the location"},
|
||||
"celsius": {"type": "string", "description": "the temperature type"}
|
||||
}
|
||||
output_type = "string"
|
||||
|
||||
def forward(self, location: str, celsius: bool = False) -> str:
|
||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||
|
||||
tool = GetWeatherTool2()
|
||||
assert "Nullable" in str(e)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
class GetWeatherTool3(Tool):
|
||||
name = "get_weather"
|
||||
description = "Get weather in the next days at given location."
|
||||
inputs = {
|
||||
"location": {"type": "string", "description": "the location"},
|
||||
"celsius": {"type": "string", "description": "the temperature type", "nullable": True}
|
||||
}
|
||||
output_type = "string"
|
||||
|
||||
def forward(self, location, celsius: str) -> str:
|
||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||
|
||||
tool = GetWeatherTool3()
|
||||
assert "Nullable" in str(e)
|
Loading…
Reference in New Issue