Support any and none tool types (#280)
* Support any and none tool types
This commit is contained in:
parent
83ecd572fc
commit
43904f32c7
|
@ -87,3 +87,8 @@ jobs:
|
|||
run: |
|
||||
uv run pytest ./tests/test_utils.py
|
||||
if: ${{ success() || failure() }}
|
||||
|
||||
- name: Function type hints utils tests
|
||||
run: |
|
||||
uv run pytest ./tests/test_function_type_hints_utils.py
|
||||
if: ${{ success() || failure() }}
|
||||
|
|
|
@ -276,20 +276,26 @@ def _parse_google_format_docstring(
|
|||
return description, args_dict, returns
|
||||
|
||||
|
||||
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
|
||||
def _convert_type_hints_to_json_schema(func: Callable, error_on_missing_type_hints: bool = True) -> Dict:
|
||||
type_hints = get_type_hints(func)
|
||||
signature = inspect.signature(func)
|
||||
required = []
|
||||
for param_name, param in signature.parameters.items():
|
||||
if param.annotation == inspect.Parameter.empty:
|
||||
raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
|
||||
if param.default == inspect.Parameter.empty:
|
||||
required.append(param_name)
|
||||
|
||||
properties = {}
|
||||
for param_name, param_type in type_hints.items():
|
||||
properties[param_name] = _parse_type_hint(param_type)
|
||||
|
||||
required = []
|
||||
for param_name, param in signature.parameters.items():
|
||||
if param.annotation == inspect.Parameter.empty and error_on_missing_type_hints:
|
||||
raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
|
||||
if param_name not in properties:
|
||||
properties[param_name] = {}
|
||||
|
||||
if param.default == inspect.Parameter.empty:
|
||||
required.append(param_name)
|
||||
else:
|
||||
properties[param_name]["nullable"] = True
|
||||
|
||||
schema = {"type": "object", "properties": properties}
|
||||
if required:
|
||||
schema["required"] = required
|
||||
|
@ -368,7 +374,8 @@ _BASE_TYPE_MAPPING = {
|
|||
float: {"type": "number"},
|
||||
str: {"type": "string"},
|
||||
bool: {"type": "boolean"},
|
||||
Any: {},
|
||||
Any: {"type": "any"},
|
||||
types.NoneType: {"type": "null"},
|
||||
}
|
||||
|
||||
|
|
@ -16,7 +16,7 @@
|
|||
# limitations under the License.
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .local_python_executor import (
|
||||
BASE_BUILTIN_MODULES,
|
||||
|
@ -85,7 +85,7 @@ class FinalAnswerTool(Tool):
|
|||
inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
|
||||
output_type = "any"
|
||||
|
||||
def forward(self, answer):
|
||||
def forward(self, answer: Any) -> Any:
|
||||
return answer
|
||||
|
||||
|
||||
|
|
|
@ -136,7 +136,7 @@ tool_role_conversions = {
|
|||
}
|
||||
|
||||
|
||||
def get_json_schema(tool: Tool) -> Dict:
|
||||
def get_tool_json_schema(tool: Tool) -> Dict:
|
||||
properties = deepcopy(tool.inputs)
|
||||
required = []
|
||||
for key, value in properties.items():
|
||||
|
@ -240,7 +240,7 @@ class Model:
|
|||
if tools_to_call_from:
|
||||
completion_kwargs.update(
|
||||
{
|
||||
"tools": [get_json_schema(tool) for tool in tools_to_call_from],
|
||||
"tools": [get_tool_json_schema(tool) for tool in tools_to_call_from],
|
||||
"tool_choice": "required",
|
||||
}
|
||||
)
|
||||
|
@ -490,7 +490,7 @@ class TransformersModel(Model):
|
|||
if tools_to_call_from is not None:
|
||||
prompt_tensor = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tools=completion_kwargs.pop("tools", []),
|
||||
tools=[get_tool_json_schema(tool) for tool in tools_to_call_from],
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
add_generation_prompt=True,
|
||||
|
|
|
@ -26,7 +26,7 @@ import textwrap
|
|||
from contextlib import contextmanager
|
||||
from functools import lru_cache, wraps
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union, get_type_hints
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
from huggingface_hub import (
|
||||
create_repo,
|
||||
|
@ -38,9 +38,9 @@ from huggingface_hub import (
|
|||
from huggingface_hub.utils import is_torch_available
|
||||
from packaging import version
|
||||
|
||||
from ._transformers_utils import (
|
||||
from ._function_type_hints_utils import (
|
||||
TypeHintParsingException,
|
||||
_parse_type_hint,
|
||||
_convert_type_hints_to_json_schema,
|
||||
get_imports,
|
||||
get_json_schema,
|
||||
)
|
||||
|
@ -64,22 +64,6 @@ def validate_after_init(cls):
|
|||
return cls
|
||||
|
||||
|
||||
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
|
||||
type_hints = get_type_hints(func)
|
||||
signature = inspect.signature(func)
|
||||
properties = {}
|
||||
for param_name, param_type in type_hints.items():
|
||||
if param_name != "return":
|
||||
properties[param_name] = _parse_type_hint(param_type)
|
||||
if signature.parameters[param_name].default != inspect.Parameter.empty:
|
||||
properties[param_name]["nullable"] = True
|
||||
for param_name in signature.parameters.keys():
|
||||
if signature.parameters[param_name].default != inspect.Parameter.empty:
|
||||
if param_name not in properties: # this can happen if the param has no type hint but a default value
|
||||
properties[param_name] = {"nullable": True}
|
||||
return properties
|
||||
|
||||
|
||||
AUTHORIZED_TYPES = [
|
||||
"string",
|
||||
"boolean",
|
||||
|
@ -87,8 +71,10 @@ AUTHORIZED_TYPES = [
|
|||
"number",
|
||||
"image",
|
||||
"audio",
|
||||
"any",
|
||||
"array",
|
||||
"object",
|
||||
"any",
|
||||
"null",
|
||||
]
|
||||
|
||||
CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
|
||||
|
@ -168,12 +154,15 @@ class Tool:
|
|||
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
|
||||
)
|
||||
|
||||
json_schema = _convert_type_hints_to_json_schema(
|
||||
self.forward
|
||||
) # This function will raise an error on missing docstrings, contrary to get_json_schema
|
||||
json_schema = _convert_type_hints_to_json_schema(self.forward, error_on_missing_type_hints=False)[
|
||||
"properties"
|
||||
] # This function will not raise an error on missing docstrings, contrary to get_json_schema
|
||||
for key, value in self.inputs.items():
|
||||
assert key in json_schema, (
|
||||
f"Input '{key}' should be present in function signature, found only {json_schema.keys()}"
|
||||
)
|
||||
if "nullable" in value:
|
||||
assert key in json_schema and "nullable" in json_schema[key], (
|
||||
assert "nullable" in json_schema[key], (
|
||||
f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature."
|
||||
)
|
||||
if key in json_schema and "nullable" in json_schema[key]:
|
||||
|
@ -887,16 +876,6 @@ class ToolCollection:
|
|||
yield cls(tools)
|
||||
|
||||
|
||||
def get_tool_json_schema(tool_function):
|
||||
tool_json_schema = get_json_schema(tool_function)["function"]
|
||||
tool_parameters = tool_json_schema["parameters"]
|
||||
inputs_schema = tool_parameters["properties"]
|
||||
for input_name in inputs_schema:
|
||||
if "required" not in tool_parameters or input_name not in tool_parameters["required"]:
|
||||
inputs_schema[input_name]["nullable"] = True
|
||||
return tool_json_schema
|
||||
|
||||
|
||||
def tool(tool_function: Callable) -> Tool:
|
||||
"""
|
||||
Converts a function into an instance of a Tool subclass.
|
||||
|
@ -905,7 +884,7 @@ def tool(tool_function: Callable) -> Tool:
|
|||
tool_function: Your function. Should have type hints for each input and a type hint for the output.
|
||||
Should also have a docstring description including an 'Args:' part where each argument is described.
|
||||
"""
|
||||
tool_json_schema = get_tool_json_schema(tool_function)
|
||||
tool_json_schema = get_json_schema(tool_function)["function"]
|
||||
if "return" not in tool_json_schema:
|
||||
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
|
||||
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from smolagents._function_type_hints_utils import get_json_schema
|
||||
|
||||
|
||||
class AgentTextTests(unittest.TestCase):
|
||||
def test_return_none(self):
|
||||
def fn(x: int, y: Optional[Tuple[str, str, float]] = None) -> None:
|
||||
"""
|
||||
Test function
|
||||
Args:
|
||||
x: The first input
|
||||
y: The second input
|
||||
"""
|
||||
pass
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"type": "integer", "description": "The first input"},
|
||||
"y": {
|
||||
"type": "array",
|
||||
"description": "The second input",
|
||||
"nullable": True,
|
||||
"prefixItems": [{"type": "string"}, {"type": "string"}, {"type": "number"}],
|
||||
},
|
||||
},
|
||||
"required": ["x"],
|
||||
},
|
||||
"return": {"type": "null"},
|
||||
}
|
||||
self.assertEqual(
|
||||
schema["function"]["parameters"]["properties"]["y"], expected_schema["parameters"]["properties"]["y"]
|
||||
)
|
||||
self.assertEqual(schema["function"], expected_schema)
|
|
@ -34,7 +34,9 @@ class ModelTests(unittest.TestCase):
|
|||
"""
|
||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||
|
||||
assert "nullable" in models.get_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"]
|
||||
assert (
|
||||
"nullable" in models.get_tool_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"]
|
||||
)
|
||||
|
||||
def test_chatmessage_has_model_dumps_json(self):
|
||||
message = ChatMessage("user", "Hello!")
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
import unittest
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import mcp
|
||||
|
@ -381,13 +381,41 @@ class ToolTests(unittest.TestCase):
|
|||
Get weather in the next days at given location.
|
||||
|
||||
Args:
|
||||
location: the location
|
||||
celsius: is the temperature given in celsius
|
||||
location: The location to get the weather for.
|
||||
celsius: is the temperature given in celsius?
|
||||
"""
|
||||
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
|
||||
|
||||
assert get_weather.inputs["celsius"]["nullable"]
|
||||
|
||||
def test_tool_supports_any_none(self):
|
||||
@tool
|
||||
def get_weather(location: Any) -> None:
|
||||
"""
|
||||
Get weather in the next days at given location.
|
||||
|
||||
Args:
|
||||
location: The location to get the weather for.
|
||||
"""
|
||||
return
|
||||
|
||||
assert get_weather.inputs["location"]["type"] == "any"
|
||||
|
||||
def test_tool_supports_array(self):
|
||||
@tool
|
||||
def get_weather(locations: List[str], months: Optional[Tuple[str, str]] = None) -> Dict[str, float]:
|
||||
"""
|
||||
Get weather in the next days at given locations.
|
||||
|
||||
Args:
|
||||
locations: The locations to get the weather for.
|
||||
months: The months to get the weather for
|
||||
"""
|
||||
return
|
||||
|
||||
assert get_weather.inputs["locations"]["type"] == "array"
|
||||
assert get_weather.inputs["months"]["type"] == "array"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_server_parameters():
|
||||
|
|
Loading…
Reference in New Issue