Support any and none tool types (#280)
* Support any and none tool types
This commit is contained in:
parent
83ecd572fc
commit
43904f32c7
|
@ -87,3 +87,8 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
uv run pytest ./tests/test_utils.py
|
uv run pytest ./tests/test_utils.py
|
||||||
if: ${{ success() || failure() }}
|
if: ${{ success() || failure() }}
|
||||||
|
|
||||||
|
- name: Function type hints utils tests
|
||||||
|
run: |
|
||||||
|
uv run pytest ./tests/test_function_type_hints_utils.py
|
||||||
|
if: ${{ success() || failure() }}
|
||||||
|
|
|
@ -276,20 +276,26 @@ def _parse_google_format_docstring(
|
||||||
return description, args_dict, returns
|
return description, args_dict, returns
|
||||||
|
|
||||||
|
|
||||||
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
|
def _convert_type_hints_to_json_schema(func: Callable, error_on_missing_type_hints: bool = True) -> Dict:
|
||||||
type_hints = get_type_hints(func)
|
type_hints = get_type_hints(func)
|
||||||
signature = inspect.signature(func)
|
signature = inspect.signature(func)
|
||||||
required = []
|
|
||||||
for param_name, param in signature.parameters.items():
|
|
||||||
if param.annotation == inspect.Parameter.empty:
|
|
||||||
raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
|
|
||||||
if param.default == inspect.Parameter.empty:
|
|
||||||
required.append(param_name)
|
|
||||||
|
|
||||||
properties = {}
|
properties = {}
|
||||||
for param_name, param_type in type_hints.items():
|
for param_name, param_type in type_hints.items():
|
||||||
properties[param_name] = _parse_type_hint(param_type)
|
properties[param_name] = _parse_type_hint(param_type)
|
||||||
|
|
||||||
|
required = []
|
||||||
|
for param_name, param in signature.parameters.items():
|
||||||
|
if param.annotation == inspect.Parameter.empty and error_on_missing_type_hints:
|
||||||
|
raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
|
||||||
|
if param_name not in properties:
|
||||||
|
properties[param_name] = {}
|
||||||
|
|
||||||
|
if param.default == inspect.Parameter.empty:
|
||||||
|
required.append(param_name)
|
||||||
|
else:
|
||||||
|
properties[param_name]["nullable"] = True
|
||||||
|
|
||||||
schema = {"type": "object", "properties": properties}
|
schema = {"type": "object", "properties": properties}
|
||||||
if required:
|
if required:
|
||||||
schema["required"] = required
|
schema["required"] = required
|
||||||
|
@ -368,7 +374,8 @@ _BASE_TYPE_MAPPING = {
|
||||||
float: {"type": "number"},
|
float: {"type": "number"},
|
||||||
str: {"type": "string"},
|
str: {"type": "string"},
|
||||||
bool: {"type": "boolean"},
|
bool: {"type": "boolean"},
|
||||||
Any: {},
|
Any: {"type": "any"},
|
||||||
|
types.NoneType: {"type": "null"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from .local_python_executor import (
|
from .local_python_executor import (
|
||||||
BASE_BUILTIN_MODULES,
|
BASE_BUILTIN_MODULES,
|
||||||
|
@ -85,7 +85,7 @@ class FinalAnswerTool(Tool):
|
||||||
inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
|
inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
|
||||||
output_type = "any"
|
output_type = "any"
|
||||||
|
|
||||||
def forward(self, answer):
|
def forward(self, answer: Any) -> Any:
|
||||||
return answer
|
return answer
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -136,7 +136,7 @@ tool_role_conversions = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_json_schema(tool: Tool) -> Dict:
|
def get_tool_json_schema(tool: Tool) -> Dict:
|
||||||
properties = deepcopy(tool.inputs)
|
properties = deepcopy(tool.inputs)
|
||||||
required = []
|
required = []
|
||||||
for key, value in properties.items():
|
for key, value in properties.items():
|
||||||
|
@ -240,7 +240,7 @@ class Model:
|
||||||
if tools_to_call_from:
|
if tools_to_call_from:
|
||||||
completion_kwargs.update(
|
completion_kwargs.update(
|
||||||
{
|
{
|
||||||
"tools": [get_json_schema(tool) for tool in tools_to_call_from],
|
"tools": [get_tool_json_schema(tool) for tool in tools_to_call_from],
|
||||||
"tool_choice": "required",
|
"tool_choice": "required",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -490,7 +490,7 @@ class TransformersModel(Model):
|
||||||
if tools_to_call_from is not None:
|
if tools_to_call_from is not None:
|
||||||
prompt_tensor = self.tokenizer.apply_chat_template(
|
prompt_tensor = self.tokenizer.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
tools=completion_kwargs.pop("tools", []),
|
tools=[get_tool_json_schema(tool) for tool in tools_to_call_from],
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
|
|
|
@ -26,7 +26,7 @@ import textwrap
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import lru_cache, wraps
|
from functools import lru_cache, wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, List, Optional, Union, get_type_hints
|
from typing import Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
from huggingface_hub import (
|
from huggingface_hub import (
|
||||||
create_repo,
|
create_repo,
|
||||||
|
@ -38,9 +38,9 @@ from huggingface_hub import (
|
||||||
from huggingface_hub.utils import is_torch_available
|
from huggingface_hub.utils import is_torch_available
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from ._transformers_utils import (
|
from ._function_type_hints_utils import (
|
||||||
TypeHintParsingException,
|
TypeHintParsingException,
|
||||||
_parse_type_hint,
|
_convert_type_hints_to_json_schema,
|
||||||
get_imports,
|
get_imports,
|
||||||
get_json_schema,
|
get_json_schema,
|
||||||
)
|
)
|
||||||
|
@ -64,22 +64,6 @@ def validate_after_init(cls):
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
|
|
||||||
type_hints = get_type_hints(func)
|
|
||||||
signature = inspect.signature(func)
|
|
||||||
properties = {}
|
|
||||||
for param_name, param_type in type_hints.items():
|
|
||||||
if param_name != "return":
|
|
||||||
properties[param_name] = _parse_type_hint(param_type)
|
|
||||||
if signature.parameters[param_name].default != inspect.Parameter.empty:
|
|
||||||
properties[param_name]["nullable"] = True
|
|
||||||
for param_name in signature.parameters.keys():
|
|
||||||
if signature.parameters[param_name].default != inspect.Parameter.empty:
|
|
||||||
if param_name not in properties: # this can happen if the param has no type hint but a default value
|
|
||||||
properties[param_name] = {"nullable": True}
|
|
||||||
return properties
|
|
||||||
|
|
||||||
|
|
||||||
AUTHORIZED_TYPES = [
|
AUTHORIZED_TYPES = [
|
||||||
"string",
|
"string",
|
||||||
"boolean",
|
"boolean",
|
||||||
|
@ -87,8 +71,10 @@ AUTHORIZED_TYPES = [
|
||||||
"number",
|
"number",
|
||||||
"image",
|
"image",
|
||||||
"audio",
|
"audio",
|
||||||
"any",
|
"array",
|
||||||
"object",
|
"object",
|
||||||
|
"any",
|
||||||
|
"null",
|
||||||
]
|
]
|
||||||
|
|
||||||
CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
|
CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
|
||||||
|
@ -168,12 +154,15 @@ class Tool:
|
||||||
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
|
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
|
||||||
)
|
)
|
||||||
|
|
||||||
json_schema = _convert_type_hints_to_json_schema(
|
json_schema = _convert_type_hints_to_json_schema(self.forward, error_on_missing_type_hints=False)[
|
||||||
self.forward
|
"properties"
|
||||||
) # This function will raise an error on missing docstrings, contrary to get_json_schema
|
] # This function will not raise an error on missing docstrings, contrary to get_json_schema
|
||||||
for key, value in self.inputs.items():
|
for key, value in self.inputs.items():
|
||||||
|
assert key in json_schema, (
|
||||||
|
f"Input '{key}' should be present in function signature, found only {json_schema.keys()}"
|
||||||
|
)
|
||||||
if "nullable" in value:
|
if "nullable" in value:
|
||||||
assert key in json_schema and "nullable" in json_schema[key], (
|
assert "nullable" in json_schema[key], (
|
||||||
f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
|
f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
|
||||||
)
|
)
|
||||||
if key in json_schema and "nullable" in json_schema[key]:
|
if key in json_schema and "nullable" in json_schema[key]:
|
||||||
|
@ -887,16 +876,6 @@ class ToolCollection:
|
||||||
yield cls(tools)
|
yield cls(tools)
|
||||||
|
|
||||||
|
|
||||||
def get_tool_json_schema(tool_function):
|
|
||||||
tool_json_schema = get_json_schema(tool_function)["function"]
|
|
||||||
tool_parameters = tool_json_schema["parameters"]
|
|
||||||
inputs_schema = tool_parameters["properties"]
|
|
||||||
for input_name in inputs_schema:
|
|
||||||
if "required" not in tool_parameters or input_name not in tool_parameters["required"]:
|
|
||||||
inputs_schema[input_name]["nullable"] = True
|
|
||||||
return tool_json_schema
|
|
||||||
|
|
||||||
|
|
||||||
def tool(tool_function: Callable) -> Tool:
|
def tool(tool_function: Callable) -> Tool:
|
||||||
"""
|
"""
|
||||||
Converts a function into an instance of a Tool subclass.
|
Converts a function into an instance of a Tool subclass.
|
||||||
|
@ -905,7 +884,7 @@ def tool(tool_function: Callable) -> Tool:
|
||||||
tool_function: Your function. Should have type hints for each input and a type hint for the output.
|
tool_function: Your function. Should have type hints for each input and a type hint for the output.
|
||||||
Should also have a docstring description including an 'Args:' part where each argument is described.
|
Should also have a docstring description including an 'Args:' part where each argument is described.
|
||||||
"""
|
"""
|
||||||
tool_json_schema = get_tool_json_schema(tool_function)
|
tool_json_schema = get_json_schema(tool_function)["function"]
|
||||||
if "return" not in tool_json_schema:
|
if "return" not in tool_json_schema:
|
||||||
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
|
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import unittest
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from smolagents._function_type_hints_utils import get_json_schema
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTextTests(unittest.TestCase):
|
||||||
|
def test_return_none(self):
|
||||||
|
def fn(x: int, y: Optional[Tuple[str, str, float]] = None) -> None:
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
Args:
|
||||||
|
x: The first input
|
||||||
|
y: The second input
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
schema = get_json_schema(fn)
|
||||||
|
expected_schema = {
|
||||||
|
"name": "fn",
|
||||||
|
"description": "Test function",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {"type": "integer", "description": "The first input"},
|
||||||
|
"y": {
|
||||||
|
"type": "array",
|
||||||
|
"description": "The second input",
|
||||||
|
"nullable": True,
|
||||||
|
"prefixItems": [{"type": "string"}, {"type": "string"}, {"type": "number"}],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["x"],
|
||||||
|
},
|
||||||
|
"return": {"type": "null"},
|
||||||
|
}
|
||||||
|
self.assertEqual(
|
||||||
|
schema["function"]["parameters"]["properties"]["y"], expected_schema["parameters"]["properties"]["y"]
|
||||||
|
)
|
||||||
|
self.assertEqual(schema["function"], expected_schema)
|
|
@ -34,7 +34,9 @@ class ModelTests(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||||
|
|
||||||
assert "nullable" in models.get_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"]
|
assert (
|
||||||
|
"nullable" in models.get_tool_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"]
|
||||||
|
)
|
||||||
|
|
||||||
def test_chatmessage_has_model_dumps_json(self):
|
def test_chatmessage_has_model_dumps_json(self):
|
||||||
message = ChatMessage("user", "Hello!")
|
message = ChatMessage("user", "Hello!")
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Dict, Optional, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import mcp
|
import mcp
|
||||||
|
@ -381,13 +381,41 @@ class ToolTests(unittest.TestCase):
|
||||||
Get weather in the next days at given location.
|
Get weather in the next days at given location.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
location: the location
|
location: The location to get the weather for.
|
||||||
celsius: is the temperature given in celsius
|
celsius: is the temperature given in celsius?
|
||||||
"""
|
"""
|
||||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||||
|
|
||||||
assert get_weather.inputs["celsius"]["nullable"]
|
assert get_weather.inputs["celsius"]["nullable"]
|
||||||
|
|
||||||
|
def test_tool_supports_any_none(self):
|
||||||
|
@tool
|
||||||
|
def get_weather(location: Any) -> None:
|
||||||
|
"""
|
||||||
|
Get weather in the next days at given location.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
location: The location to get the weather for.
|
||||||
|
"""
|
||||||
|
return
|
||||||
|
|
||||||
|
assert get_weather.inputs["location"]["type"] == "any"
|
||||||
|
|
||||||
|
def test_tool_supports_array(self):
|
||||||
|
@tool
|
||||||
|
def get_weather(locations: List[str], months: Optional[Tuple[str, str]] = None) -> Dict[str, float]:
|
||||||
|
"""
|
||||||
|
Get weather in the next days at given locations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locations: The locations to get the weather for.
|
||||||
|
months: The months to get the weather for
|
||||||
|
"""
|
||||||
|
return
|
||||||
|
|
||||||
|
assert get_weather.inputs["locations"]["type"] == "array"
|
||||||
|
assert get_weather.inputs["months"]["type"] == "array"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_server_parameters():
|
def mock_server_parameters():
|
||||||
|
|
Loading…
Reference in New Issue