Support optional arguments in tool calls

This commit is contained in:
Aymeric 2024-12-26 11:55:20 +01:00
parent 93569bd7c1
commit e5ca0f0cb8
12 changed files with 167 additions and 59 deletions

View File

@ -44,7 +44,7 @@ This library offers:
First install the package.
```bash
pip install agents
pip install smolagents
```
Then define your agent, give it the tools it needs and run it!
```py

View File

@ -1,9 +1,9 @@
# docstyle-ignore
INSTALL_CONTENT = """
# Transformers installation
! pip install agents
# Installation
! pip install smolagents
# To install from source instead of the last release, comment the command above and uncomment the following one.
# ! pip install git+https://github.com/huggingface/agents.git
# ! pip install git+https://github.com/huggingface/smolagents.git
"""
notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}]

View File

@ -1,19 +1,21 @@
from smolagents.agents import ToolCallingAgent
from smolagents import tool, HfApiModel, TransformersModel, LiteLLMModel
from typing import Optional
# Choose which LLM engine to use!
model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct")
model = TransformersModel("meta-llama/Llama-3.2-2B-Instruct")
# model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct")
# model = TransformersModel("meta-llama/Llama-3.2-2B-Instruct")
model = LiteLLMModel("gpt-4o")
@tool
def get_weather(location: str) -> str:
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
"""
Get weather in the next days at given location.
Secretly this tool does not care about the location, it hates the weather everywhere.
Args:
location: the location
celsius: the temperature
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"

View File

@ -172,6 +172,7 @@ class GoogleSearchTool(Tool):
"filter_year": {
"type": "integer",
"description": "Optionally restrict results to a certain year",
"nullable": True,
},
}
output_type = "string"
@ -209,6 +210,11 @@ class GoogleSearchTool(Tool):
raise ValueError(response.json())
if "organic_results" not in results.keys():
if filter_year is not None:
raise Exception(
f"'organic_results' key not found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
)
else:
raise Exception(
f"'organic_results' key not found for query: '{query}'. Use a less restrictive query."
)

View File

@ -67,9 +67,12 @@ tool_role_conversions = {
def get_json_schema(tool: Tool) -> Dict:
properties = deepcopy(tool.inputs)
for value in properties.values():
required = []
for key, value in properties.items():
if value["type"] == "any":
value["type"] = "string"
if not ("nullable" in value and value["nullable"]):
required.append(key)
return {
"type": "function",
"function": {
@ -78,7 +81,7 @@ def get_json_schema(tool: Tool) -> Dict:
"parameters": {
"type": "object",
"properties": properties,
"required": list(tool.inputs.keys()),
"required": required,
},
},
}

View File

@ -260,10 +260,8 @@ Task: "Which city has the highest population: Guangzhou or Shanghai?"
Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.
Code:
```py
population_guangzhou = search("Guangzhou population")
print("Population Guangzhou:", population_guangzhou)
population_shanghai = search("Shanghai population")
print("Population Shanghai:", population_shanghai)
for city in ["Guangzhou", "Shanghai"]:
print(f"Population {city}:", search(f"{city} population")
```<end_code>
Observation:
Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
@ -278,11 +276,13 @@ final_answer("Shanghai")
---
Task: "What is the current age of the pope, raised to the power 0.36?"
Thought: I will use the tool `wiki` to get the age of the pope, then raise it to the power 0.36.
Thought: I will use the tool `wiki` to get the age of the pope, and confirm that with a web search.
Code:
```py
pope_age = wiki(query="current pope age")
print("Pope age:", pope_age)
pope_age_wiki = wiki(query="current pope age")
print("Pope age as per wikipedia:", pope_age_wiki)
pope_age_search = web_search(query="current pope age")
print("Pope age as per google search:", pope_age_search)
```<end_code>
Observation:
Pope age: "The pope Francis is currently 85 years old."

View File

@ -85,7 +85,6 @@ class MethodChecker(ast.NodeVisitor):
self.generic_visit(node)
def visit_Attribute(self, node):
# Skip self.something
if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
self.generic_visit(node)

View File

@ -24,7 +24,7 @@ import torch
import textwrap
from functools import lru_cache, wraps
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union, get_type_hints
from huggingface_hub import (
create_repo,
get_collection,
@ -42,6 +42,8 @@ from transformers.utils import (
is_accelerate_available,
is_torch_available,
)
from transformers.utils.chat_template_utils import _parse_type_hint
from transformers.dynamic_module_utils import get_imports
from transformers import AutoProcessor
@ -95,17 +97,27 @@ def setup_default_tools():
return default_tools
def validate_after_init(cls, do_validate_forward: bool = True):
def validate_after_init(cls):
original_init = cls.__init__
@wraps(original_init)
def new_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
self.validate_arguments(do_validate_forward=do_validate_forward)
self.validate_arguments()
cls.__init__ = new_init
return cls
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
type_hints = get_type_hints(func)
signature = inspect.signature(func)
properties = {}
for param_name, param_type in type_hints.items():
if param_name != "return":
properties[param_name] = _parse_type_hint(param_type)
if signature.parameters[param_name].default != inspect.Parameter.empty:
properties[param_name]["nullable"] = True
return properties
AUTHORIZED_TYPES = [
"string",
@ -145,7 +157,7 @@ class Tool:
name: str
description: str
inputs: Dict[str, Dict[str, Union[str, type]]]
inputs: Dict[str, Dict[str, Union[str, type, bool]]]
output_type: str
def __init__(self, *args, **kwargs):
@ -153,9 +165,9 @@ class Tool:
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
validate_after_init(cls, do_validate_forward=False)
validate_after_init(cls)
def validate_arguments(self, do_validate_forward: bool = True):
def validate_arguments(self):
required_attributes = {
"description": str,
"name": str,
@ -184,13 +196,22 @@ class Tool:
)
assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
if do_validate_forward:
# Validate forward function signature
signature = inspect.signature(self.forward)
if not set(signature.parameters.keys()) == set(self.inputs.keys()):
raise Exception(
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
)
json_schema = _convert_type_hints_to_json_schema(self.forward)
for key, value in self.inputs.items():
if "nullable" in value:
assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
if key in json_schema and "nullable" in json_schema[key]:
assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs."
def forward(self, *args, **kwargs):
return NotImplementedError("Write this method in your subclass of `Tool`.")
@ -877,9 +898,6 @@ def tool(tool_function: Callable) -> Tool:
"Tool return type not found: make sure your function has a return type hint!"
)
if parameters["return"]["type"] == "object":
parameters["return"]["type"] = "any"
class SimpleTool(Tool):
def __init__(self, name, description, inputs, output_type, function):
self.name = name
@ -898,11 +916,10 @@ def tool(tool_function: Callable) -> Tool:
)
original_signature = inspect.signature(tool_function)
new_parameters = [
inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)
] + list(original_signature.parameters.values())
new_signature = original_signature.replace(parameters=new_parameters)
simple_tool.forward.__signature__ = new_signature
# SimpleTool.__name__ = "".join([el.title() for el in parameters["name"].split("_")])
return simple_tool

View File

@ -25,7 +25,7 @@ from smolagents.types import AGENT_TYPE_MAPPING
from smolagents.default_tools import FinalAnswerTool
from .test_tools_common import ToolTesterMixin
from .test_tools import ToolTesterMixin
if is_torch_available():

View File

@ -26,7 +26,7 @@ from smolagents.local_python_executor import (
evaluate_python_code,
)
from .test_tools_common import ToolTesterMixin
from .test_tools import ToolTesterMixin
# Fake function we will use as tool

View File

@ -17,7 +17,7 @@ import unittest
from smolagents import load_tool
from .test_tools_common import ToolTesterMixin
from .test_tools import ToolTesterMixin
class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):

View File

@ -14,7 +14,7 @@
# limitations under the License.
import unittest
from pathlib import Path
from typing import Dict, Union
from typing import Dict, Union, Optional
import numpy as np
import pytest
@ -126,9 +126,9 @@ class ToolTests(unittest.TestCase):
"description": "the task category (such as text-classification, depth-estimation, etc)",
}
}
output_type = "integer"
output_type = "string"
def forward(self, task):
def forward(self, task: str) -> str:
return "best model"
tool = HFModelDownloadsTool()
@ -223,7 +223,7 @@ class ToolTests(unittest.TestCase):
name = "specific"
description = "test description"
inputs = {
"input_str": {"type": "string", "description": "input description"}
"string_input": {"type": "string", "description": "input description"}
}
output_type = "string"
@ -231,7 +231,7 @@ class ToolTests(unittest.TestCase):
super().__init__(self)
self.url = "none"
def forward(self, string_input):
def forward(self, string_input: str) -> str:
return self.url + string_input
fail_tool = FailTool("dummy_url")
@ -241,46 +241,127 @@ class ToolTests(unittest.TestCase):
def test_saving_tool_allows_no_imports_from_outside_methods(self):
# Test that using imports from outside functions fails
from numpy import random
import numpy as np
class FailTool2(Tool):
class FailTool(Tool):
name = "specific"
description = "test description"
inputs = {
"input_str": {"type": "string", "description": "input description"}
"string_input": {"type": "string", "description": "input description"}
}
output_type = "string"
def useless_method(self):
self.client = random.random()
self.client = np.random.random()
return ""
def forward(self, string_input):
return self.useless_method() + string_input
fail_tool_2 = FailTool2()
fail_tool = FailTool()
with pytest.raises(Exception) as e:
fail_tool_2.save("output")
assert "random" in str(e)
fail_tool.save("output")
assert "'np' is undefined" in str(e)
# Test that putting these imports inside functions works
class FailTool3(Tool):
class SuccessTool(Tool):
name = "specific"
description = "test description"
inputs = {
"input_str": {"type": "string", "description": "input description"}
"string_input": {"type": "string", "description": "input description"}
}
output_type = "string"
def useless_method(self):
from numpy import random
import numpy as np
self.client = random.random()
self.client = np.random.random()
return ""
def forward(self, string_input):
return self.useless_method() + string_input
fail_tool_3 = FailTool3()
fail_tool_3.save("output")
success_tool = SuccessTool()
success_tool.save("output")
def test_tool_missing_class_attributes_raises_error(self):
with pytest.raises(Exception) as e:
class GetWeatherTool(Tool):
name = "get_weather"
description = "Get weather in the next days at given location."
inputs = {
"location": {"type": "string", "description": "the location"},
"celsius": {"type": "string", "description": "the temperature type"}
}
def forward(self, location: str, celsius: Optional[bool] = False) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
tool = GetWeatherTool()
assert "You must set an attribute output_type" in str(e)
def test_tool_from_decorator_optional_args(self):
@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
"""
Get weather in the next days at given location.
Secretly this tool does not care about the location, it hates the weather everywhere.
Args:
location: the location
celsius: the temperature type
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
assert "nullable" in get_weather.inputs["celsius"]
assert get_weather.inputs["celsius"]["nullable"] == True
assert "nullable" not in get_weather.inputs["location"]
def test_tool_mismatching_nullable_args_raises_error(self):
with pytest.raises(Exception) as e:
class GetWeatherTool(Tool):
name = "get_weather"
description = "Get weather in the next days at given location."
inputs = {
"location": {"type": "string", "description": "the location"},
"celsius": {"type": "string", "description": "the temperature type"}
}
output_type = "string"
def forward(self, location: str, celsius: Optional[bool] = False) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
tool = GetWeatherTool()
assert "Nullable" in str(e)
with pytest.raises(Exception) as e:
class GetWeatherTool2(Tool):
name = "get_weather"
description = "Get weather in the next days at given location."
inputs = {
"location": {"type": "string", "description": "the location"},
"celsius": {"type": "string", "description": "the temperature type"}
}
output_type = "string"
def forward(self, location: str, celsius: bool = False) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
tool = GetWeatherTool2()
assert "Nullable" in str(e)
with pytest.raises(Exception) as e:
class GetWeatherTool3(Tool):
name = "get_weather"
description = "Get weather in the next days at given location."
inputs = {
"location": {"type": "string", "description": "the location"},
"celsius": {"type": "string", "description": "the temperature type", "nullable": True}
}
output_type = "string"
def forward(self, location, celsius: str) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
tool = GetWeatherTool3()
assert "Nullable" in str(e)