Support optional arguments in tool calls

This commit is contained in:
Aymeric 2024-12-26 11:55:20 +01:00
parent 93569bd7c1
commit e5ca0f0cb8
12 changed files with 167 additions and 59 deletions

View File

@ -44,7 +44,7 @@ This library offers:
First install the package. First install the package.
```bash ```bash
pip install agents pip install smolagents
``` ```
Then define your agent, give it the tools it needs and run it! Then define your agent, give it the tools it needs and run it!
```py ```py

View File

@ -1,9 +1,9 @@
# docstyle-ignore # docstyle-ignore
INSTALL_CONTENT = """ INSTALL_CONTENT = """
# Transformers installation # Installation
! pip install agents ! pip install smolagents
# To install from source instead of the last release, comment the command above and uncomment the following one. # To install from source instead of the last release, comment the command above and uncomment the following one.
# ! pip install git+https://github.com/huggingface/agents.git # ! pip install git+https://github.com/huggingface/smolagents.git
""" """
notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}] notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}]

View File

@ -1,19 +1,21 @@
from smolagents.agents import ToolCallingAgent from smolagents.agents import ToolCallingAgent
from smolagents import tool, HfApiModel, TransformersModel, LiteLLMModel from smolagents import tool, HfApiModel, TransformersModel, LiteLLMModel
from typing import Optional
# Choose which LLM engine to use! # Choose which LLM engine to use!
model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct") # model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct")
model = TransformersModel("meta-llama/Llama-3.2-2B-Instruct") # model = TransformersModel("meta-llama/Llama-3.2-2B-Instruct")
model = LiteLLMModel("gpt-4o") model = LiteLLMModel("gpt-4o")
@tool @tool
def get_weather(location: str) -> str: def get_weather(location: str, celsius: Optional[bool] = False) -> str:
""" """
Get weather in the next days at given location. Get weather in the next days at given location.
Secretly this tool does not care about the location, it hates the weather everywhere. Secretly this tool does not care about the location, it hates the weather everywhere.
Args: Args:
location: the location location: the location
celsius: the temperature
""" """
return "The weather is UNGODLY with torrential rains and temperatures below -10°C" return "The weather is UNGODLY with torrential rains and temperatures below -10°C"

View File

@ -172,6 +172,7 @@ class GoogleSearchTool(Tool):
"filter_year": { "filter_year": {
"type": "integer", "type": "integer",
"description": "Optionally restrict results to a certain year", "description": "Optionally restrict results to a certain year",
"nullable": True,
}, },
} }
output_type = "string" output_type = "string"
@ -209,9 +210,14 @@ class GoogleSearchTool(Tool):
raise ValueError(response.json()) raise ValueError(response.json())
if "organic_results" not in results.keys(): if "organic_results" not in results.keys():
raise Exception( if filter_year is not None:
f"'organic_results' key not found for query: '{query}'. Use a less restrictive query." raise Exception(
) f"'organic_results' key not found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
)
else:
raise Exception(
f"'organic_results' key not found for query: '{query}'. Use a less restrictive query."
)
if len(results["organic_results"]) == 0: if len(results["organic_results"]) == 0:
year_filter_message = ( year_filter_message = (
f" with filter year={filter_year}" if filter_year is not None else "" f" with filter year={filter_year}" if filter_year is not None else ""

View File

@ -67,9 +67,12 @@ tool_role_conversions = {
def get_json_schema(tool: Tool) -> Dict: def get_json_schema(tool: Tool) -> Dict:
properties = deepcopy(tool.inputs) properties = deepcopy(tool.inputs)
for value in properties.values(): required = []
for key, value in properties.items():
if value["type"] == "any": if value["type"] == "any":
value["type"] = "string" value["type"] = "string"
if not ("nullable" in value and value["nullable"]):
required.append(key)
return { return {
"type": "function", "type": "function",
"function": { "function": {
@ -78,7 +81,7 @@ def get_json_schema(tool: Tool) -> Dict:
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": properties, "properties": properties,
"required": list(tool.inputs.keys()), "required": required,
}, },
}, },
} }

View File

@ -260,10 +260,8 @@ Task: "Which city has the highest population: Guangzhou or Shanghai?"
Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities. Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.
Code: Code:
```py ```py
population_guangzhou = search("Guangzhou population") for city in ["Guangzhou", "Shanghai"]:
print("Population Guangzhou:", population_guangzhou) print(f"Population {city}:", search(f"{city} population")
population_shanghai = search("Shanghai population")
print("Population Shanghai:", population_shanghai)
```<end_code> ```<end_code>
Observation: Observation:
Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.'] Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
@ -278,11 +276,13 @@ final_answer("Shanghai")
--- ---
Task: "What is the current age of the pope, raised to the power 0.36?" Task: "What is the current age of the pope, raised to the power 0.36?"
Thought: I will use the tool `wiki` to get the age of the pope, then raise it to the power 0.36. Thought: I will use the tool `wiki` to get the age of the pope, and confirm that with a web search.
Code: Code:
```py ```py
pope_age = wiki(query="current pope age") pope_age_wiki = wiki(query="current pope age")
print("Pope age:", pope_age) print("Pope age as per wikipedia:", pope_age_wiki)
pope_age_search = web_search(query="current pope age")
print("Pope age as per google search:", pope_age_search)
```<end_code> ```<end_code>
Observation: Observation:
Pope age: "The pope Francis is currently 85 years old." Pope age: "The pope Francis is currently 85 years old."

View File

@ -85,7 +85,6 @@ class MethodChecker(ast.NodeVisitor):
self.generic_visit(node) self.generic_visit(node)
def visit_Attribute(self, node): def visit_Attribute(self, node):
# Skip self.something
if not (isinstance(node.value, ast.Name) and node.value.id == "self"): if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
self.generic_visit(node) self.generic_visit(node)

View File

@ -24,7 +24,7 @@ import torch
import textwrap import textwrap
from functools import lru_cache, wraps from functools import lru_cache, wraps
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, List, Optional, Union from typing import Callable, Dict, List, Optional, Union, get_type_hints
from huggingface_hub import ( from huggingface_hub import (
create_repo, create_repo,
get_collection, get_collection,
@ -42,6 +42,8 @@ from transformers.utils import (
is_accelerate_available, is_accelerate_available,
is_torch_available, is_torch_available,
) )
from transformers.utils.chat_template_utils import _parse_type_hint
from transformers.dynamic_module_utils import get_imports from transformers.dynamic_module_utils import get_imports
from transformers import AutoProcessor from transformers import AutoProcessor
@ -95,17 +97,27 @@ def setup_default_tools():
return default_tools return default_tools
def validate_after_init(cls, do_validate_forward: bool = True): def validate_after_init(cls):
original_init = cls.__init__ original_init = cls.__init__
@wraps(original_init) @wraps(original_init)
def new_init(self, *args, **kwargs): def new_init(self, *args, **kwargs):
original_init(self, *args, **kwargs) original_init(self, *args, **kwargs)
self.validate_arguments(do_validate_forward=do_validate_forward) self.validate_arguments()
cls.__init__ = new_init cls.__init__ = new_init
return cls return cls
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
type_hints = get_type_hints(func)
signature = inspect.signature(func)
properties = {}
for param_name, param_type in type_hints.items():
if param_name != "return":
properties[param_name] = _parse_type_hint(param_type)
if signature.parameters[param_name].default != inspect.Parameter.empty:
properties[param_name]["nullable"] = True
return properties
AUTHORIZED_TYPES = [ AUTHORIZED_TYPES = [
"string", "string",
@ -145,7 +157,7 @@ class Tool:
name: str name: str
description: str description: str
inputs: Dict[str, Dict[str, Union[str, type]]] inputs: Dict[str, Dict[str, Union[str, type, bool]]]
output_type: str output_type: str
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -153,9 +165,9 @@ class Tool:
def __init_subclass__(cls, **kwargs): def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
validate_after_init(cls, do_validate_forward=False) validate_after_init(cls)
def validate_arguments(self, do_validate_forward: bool = True): def validate_arguments(self):
required_attributes = { required_attributes = {
"description": str, "description": str,
"name": str, "name": str,
@ -184,12 +196,21 @@ class Tool:
) )
assert getattr(self, "output_type", None) in AUTHORIZED_TYPES assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
if do_validate_forward:
signature = inspect.signature(self.forward) # Validate forward function signature
if not set(signature.parameters.keys()) == set(self.inputs.keys()): signature = inspect.signature(self.forward)
raise Exception(
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." if not set(signature.parameters.keys()) == set(self.inputs.keys()):
) raise Exception(
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
)
json_schema = _convert_type_hints_to_json_schema(self.forward)
for key, value in self.inputs.items():
if "nullable" in value:
assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
if key in json_schema and "nullable" in json_schema[key]:
assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return NotImplementedError("Write this method in your subclass of `Tool`.") return NotImplementedError("Write this method in your subclass of `Tool`.")
@ -877,9 +898,6 @@ def tool(tool_function: Callable) -> Tool:
"Tool return type not found: make sure your function has a return type hint!" "Tool return type not found: make sure your function has a return type hint!"
) )
if parameters["return"]["type"] == "object":
parameters["return"]["type"] = "any"
class SimpleTool(Tool): class SimpleTool(Tool):
def __init__(self, name, description, inputs, output_type, function): def __init__(self, name, description, inputs, output_type, function):
self.name = name self.name = name
@ -898,11 +916,10 @@ def tool(tool_function: Callable) -> Tool:
) )
original_signature = inspect.signature(tool_function) original_signature = inspect.signature(tool_function)
new_parameters = [ new_parameters = [
inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD) inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)
] + list(original_signature.parameters.values()) ] + list(original_signature.parameters.values())
new_signature = original_signature.replace(parameters=new_parameters) new_signature = original_signature.replace(parameters=new_parameters)
simple_tool.forward.__signature__ = new_signature simple_tool.forward.__signature__ = new_signature
# SimpleTool.__name__ = "".join([el.title() for el in parameters["name"].split("_")])
return simple_tool return simple_tool

View File

@ -25,7 +25,7 @@ from smolagents.types import AGENT_TYPE_MAPPING
from smolagents.default_tools import FinalAnswerTool from smolagents.default_tools import FinalAnswerTool
from .test_tools_common import ToolTesterMixin from .test_tools import ToolTesterMixin
if is_torch_available(): if is_torch_available():

View File

@ -26,7 +26,7 @@ from smolagents.local_python_executor import (
evaluate_python_code, evaluate_python_code,
) )
from .test_tools_common import ToolTesterMixin from .test_tools import ToolTesterMixin
# Fake function we will use as tool # Fake function we will use as tool

View File

@ -17,7 +17,7 @@ import unittest
from smolagents import load_tool from smolagents import load_tool
from .test_tools_common import ToolTesterMixin from .test_tools import ToolTesterMixin
class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin): class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import unittest import unittest
from pathlib import Path from pathlib import Path
from typing import Dict, Union from typing import Dict, Union, Optional
import numpy as np import numpy as np
import pytest import pytest
@ -126,9 +126,9 @@ class ToolTests(unittest.TestCase):
"description": "the task category (such as text-classification, depth-estimation, etc)", "description": "the task category (such as text-classification, depth-estimation, etc)",
} }
} }
output_type = "integer" output_type = "string"
def forward(self, task): def forward(self, task: str) -> str:
return "best model" return "best model"
tool = HFModelDownloadsTool() tool = HFModelDownloadsTool()
@ -223,7 +223,7 @@ class ToolTests(unittest.TestCase):
name = "specific" name = "specific"
description = "test description" description = "test description"
inputs = { inputs = {
"input_str": {"type": "string", "description": "input description"} "string_input": {"type": "string", "description": "input description"}
} }
output_type = "string" output_type = "string"
@ -231,7 +231,7 @@ class ToolTests(unittest.TestCase):
super().__init__(self) super().__init__(self)
self.url = "none" self.url = "none"
def forward(self, string_input): def forward(self, string_input: str) -> str:
return self.url + string_input return self.url + string_input
fail_tool = FailTool("dummy_url") fail_tool = FailTool("dummy_url")
@ -241,46 +241,127 @@ class ToolTests(unittest.TestCase):
def test_saving_tool_allows_no_imports_from_outside_methods(self): def test_saving_tool_allows_no_imports_from_outside_methods(self):
# Test that using imports from outside functions fails # Test that using imports from outside functions fails
from numpy import random import numpy as np
class FailTool2(Tool): class FailTool(Tool):
name = "specific" name = "specific"
description = "test description" description = "test description"
inputs = { inputs = {
"input_str": {"type": "string", "description": "input description"} "string_input": {"type": "string", "description": "input description"}
} }
output_type = "string" output_type = "string"
def useless_method(self): def useless_method(self):
self.client = random.random() self.client = np.random.random()
return "" return ""
def forward(self, string_input): def forward(self, string_input):
return self.useless_method() + string_input return self.useless_method() + string_input
fail_tool_2 = FailTool2() fail_tool = FailTool()
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
fail_tool_2.save("output") fail_tool.save("output")
assert "random" in str(e) assert "'np' is undefined" in str(e)
# Test that putting these imports inside functions works # Test that putting these imports inside functions works
class SuccessTool(Tool):
class FailTool3(Tool):
name = "specific" name = "specific"
description = "test description" description = "test description"
inputs = { inputs = {
"input_str": {"type": "string", "description": "input description"} "string_input": {"type": "string", "description": "input description"}
} }
output_type = "string" output_type = "string"
def useless_method(self): def useless_method(self):
from numpy import random import numpy as np
self.client = random.random() self.client = np.random.random()
return "" return ""
def forward(self, string_input): def forward(self, string_input):
return self.useless_method() + string_input return self.useless_method() + string_input
fail_tool_3 = FailTool3() success_tool = SuccessTool()
fail_tool_3.save("output") success_tool.save("output")
def test_tool_missing_class_attributes_raises_error(self):
with pytest.raises(Exception) as e:
class GetWeatherTool(Tool):
name = "get_weather"
description = "Get weather in the next days at given location."
inputs = {
"location": {"type": "string", "description": "the location"},
"celsius": {"type": "string", "description": "the temperature type"}
}
def forward(self, location: str, celsius: Optional[bool] = False) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
tool = GetWeatherTool()
assert "You must set an attribute output_type" in str(e)
def test_tool_from_decorator_optional_args(self):
@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
"""
Get weather in the next days at given location.
Secretly this tool does not care about the location, it hates the weather everywhere.
Args:
location: the location
celsius: the temperature type
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
assert "nullable" in get_weather.inputs["celsius"]
assert get_weather.inputs["celsius"]["nullable"] == True
assert "nullable" not in get_weather.inputs["location"]
def test_tool_mismatching_nullable_args_raises_error(self):
with pytest.raises(Exception) as e:
class GetWeatherTool(Tool):
name = "get_weather"
description = "Get weather in the next days at given location."
inputs = {
"location": {"type": "string", "description": "the location"},
"celsius": {"type": "string", "description": "the temperature type"}
}
output_type = "string"
def forward(self, location: str, celsius: Optional[bool] = False) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
tool = GetWeatherTool()
assert "Nullable" in str(e)
with pytest.raises(Exception) as e:
class GetWeatherTool2(Tool):
name = "get_weather"
description = "Get weather in the next days at given location."
inputs = {
"location": {"type": "string", "description": "the location"},
"celsius": {"type": "string", "description": "the temperature type"}
}
output_type = "string"
def forward(self, location: str, celsius: bool = False) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
tool = GetWeatherTool2()
assert "Nullable" in str(e)
with pytest.raises(Exception) as e:
class GetWeatherTool3(Tool):
name = "get_weather"
description = "Get weather in the next days at given location."
inputs = {
"location": {"type": "string", "description": "the location"},
"celsius": {"type": "string", "description": "the temperature type", "nullable": True}
}
output_type = "string"
def forward(self, location, celsius: str) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
tool = GetWeatherTool3()
assert "Nullable" in str(e)