diff --git a/README.md b/README.md index 0a8082e..6f7c848 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ This library offers: First install the package. ```bash -pip install agents +pip install smolagents ``` Then define your agent, give it the tools it needs and run it! ```py diff --git a/docs/source/_config.py b/docs/source/_config.py index 532a4ed..81f6de0 100644 --- a/docs/source/_config.py +++ b/docs/source/_config.py @@ -1,9 +1,9 @@ # docstyle-ignore INSTALL_CONTENT = """ -# Transformers installation -! pip install agents +# Installation +! pip install smolagents # To install from source instead of the last release, comment the command above and uncomment the following one. -# ! pip install git+https://github.com/huggingface/agents.git +# ! pip install git+https://github.com/huggingface/smolagents.git """ notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}] diff --git a/examples/tool_calling_agent_from_any_llm.py b/examples/tool_calling_agent_from_any_llm.py index 8ef272e..8fee81d 100644 --- a/examples/tool_calling_agent_from_any_llm.py +++ b/examples/tool_calling_agent_from_any_llm.py @@ -1,19 +1,21 @@ from smolagents.agents import ToolCallingAgent from smolagents import tool, HfApiModel, TransformersModel, LiteLLMModel +from typing import Optional # Choose which LLM engine to use! -model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct") -model = TransformersModel("meta-llama/Llama-3.2-2B-Instruct") +# model = HfApiModel("meta-llama/Llama-3.3-70B-Instruct") +# model = TransformersModel("meta-llama/Llama-3.2-2B-Instruct") model = LiteLLMModel("gpt-4o") @tool -def get_weather(location: str) -> str: +def get_weather(location: str, celsius: Optional[bool] = False) -> str: """ Get weather in the next days at given location. Secretly this tool does not care about the location, it hates the weather everywhere. Args: location: the location + celsius: the temperature """ return "The weather is UNGODLY with torrential rains and temperatures below -10°C" diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py index 0f5a763..6339836 100644 --- a/src/smolagents/default_tools.py +++ b/src/smolagents/default_tools.py @@ -172,6 +172,7 @@ class GoogleSearchTool(Tool): "filter_year": { "type": "integer", "description": "Optionally restrict results to a certain year", + "nullable": True, }, } output_type = "string" @@ -209,9 +210,14 @@ class GoogleSearchTool(Tool): raise ValueError(response.json()) if "organic_results" not in results.keys(): - raise Exception( - f"'organic_results' key not found for query: '{query}'. Use a less restrictive query." - ) + if filter_year is not None: + raise Exception( + f"'organic_results' key not found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year." + ) + else: + raise Exception( + f"'organic_results' key not found for query: '{query}'. Use a less restrictive query." + ) if len(results["organic_results"]) == 0: year_filter_message = ( f" with filter year={filter_year}" if filter_year is not None else "" diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 9ae010d..6107ccb 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -67,9 +67,12 @@ tool_role_conversions = { def get_json_schema(tool: Tool) -> Dict: properties = deepcopy(tool.inputs) - for value in properties.values(): + required = [] + for key, value in properties.items(): if value["type"] == "any": value["type"] = "string" + if not ("nullable" in value and value["nullable"]): + required.append(key) return { "type": "function", "function": { @@ -78,7 +81,7 @@ def get_json_schema(tool: Tool) -> Dict: "parameters": { "type": "object", "properties": properties, - "required": list(tool.inputs.keys()), + "required": required, }, }, } diff --git a/src/smolagents/prompts.py b/src/smolagents/prompts.py index d4404ea..969a5a2 100644 --- a/src/smolagents/prompts.py +++ b/src/smolagents/prompts.py @@ -260,10 +260,8 @@ Task: "Which city has the highest population: Guangzhou or Shanghai?" Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities. Code: ```py -population_guangzhou = search("Guangzhou population") -print("Population Guangzhou:", population_guangzhou) -population_shanghai = search("Shanghai population") -print("Population Shanghai:", population_shanghai) +for city in ["Guangzhou", "Shanghai"]: + print(f"Population {city}:", search(f"{city} population") ``` Observation: Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.'] @@ -278,11 +276,13 @@ final_answer("Shanghai") --- Task: "What is the current age of the pope, raised to the power 0.36?" -Thought: I will use the tool `wiki` to get the age of the pope, then raise it to the power 0.36. +Thought: I will use the tool `wiki` to get the age of the pope, and confirm that with a web search. Code: ```py -pope_age = wiki(query="current pope age") -print("Pope age:", pope_age) +pope_age_wiki = wiki(query="current pope age") +print("Pope age as per wikipedia:", pope_age_wiki) +pope_age_search = web_search(query="current pope age") +print("Pope age as per google search:", pope_age_search) ``` Observation: Pope age: "The pope Francis is currently 85 years old." diff --git a/src/smolagents/tool_validation.py b/src/smolagents/tool_validation.py index 9a0a3b4..2dcdc45 100644 --- a/src/smolagents/tool_validation.py +++ b/src/smolagents/tool_validation.py @@ -85,7 +85,6 @@ class MethodChecker(ast.NodeVisitor): self.generic_visit(node) def visit_Attribute(self, node): - # Skip self.something if not (isinstance(node.value, ast.Name) and node.value.id == "self"): self.generic_visit(node) diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index 45b1201..ad633ef 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -24,7 +24,7 @@ import torch import textwrap from functools import lru_cache, wraps from pathlib import Path -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union, get_type_hints from huggingface_hub import ( create_repo, get_collection, @@ -42,6 +42,8 @@ from transformers.utils import ( is_accelerate_available, is_torch_available, ) +from transformers.utils.chat_template_utils import _parse_type_hint + from transformers.dynamic_module_utils import get_imports from transformers import AutoProcessor @@ -95,17 +97,27 @@ def setup_default_tools(): return default_tools -def validate_after_init(cls, do_validate_forward: bool = True): +def validate_after_init(cls): original_init = cls.__init__ @wraps(original_init) def new_init(self, *args, **kwargs): original_init(self, *args, **kwargs) - self.validate_arguments(do_validate_forward=do_validate_forward) + self.validate_arguments() cls.__init__ = new_init return cls +def _convert_type_hints_to_json_schema(func: Callable) -> Dict: + type_hints = get_type_hints(func) + signature = inspect.signature(func) + properties = {} + for param_name, param_type in type_hints.items(): + if param_name != "return": + properties[param_name] = _parse_type_hint(param_type) + if signature.parameters[param_name].default != inspect.Parameter.empty: + properties[param_name]["nullable"] = True + return properties AUTHORIZED_TYPES = [ "string", @@ -145,7 +157,7 @@ class Tool: name: str description: str - inputs: Dict[str, Dict[str, Union[str, type]]] + inputs: Dict[str, Dict[str, Union[str, type, bool]]] output_type: str def __init__(self, *args, **kwargs): @@ -153,9 +165,9 @@ class Tool: def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - validate_after_init(cls, do_validate_forward=False) + validate_after_init(cls) - def validate_arguments(self, do_validate_forward: bool = True): + def validate_arguments(self): required_attributes = { "description": str, "name": str, @@ -184,12 +196,21 @@ class Tool: ) assert getattr(self, "output_type", None) in AUTHORIZED_TYPES - if do_validate_forward: - signature = inspect.signature(self.forward) - if not set(signature.parameters.keys()) == set(self.inputs.keys()): - raise Exception( - "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." - ) + + # Validate forward function signature + signature = inspect.signature(self.forward) + + if not set(signature.parameters.keys()) == set(self.inputs.keys()): + raise Exception( + "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." + ) + + json_schema = _convert_type_hints_to_json_schema(self.forward) + for key, value in self.inputs.items(): + if "nullable" in value: + assert (key in json_schema and "nullable" in json_schema[key]), f"Nullable argument '{key}' in inputs should have key 'nullable' set to True in function signature." + if key in json_schema and "nullable" in json_schema[key]: + assert "nullable" in value, f"Nullable argument '{key}' in function signature should have key 'nullable' set to True in inputs." def forward(self, *args, **kwargs): return NotImplementedError("Write this method in your subclass of `Tool`.") @@ -877,9 +898,6 @@ def tool(tool_function: Callable) -> Tool: "Tool return type not found: make sure your function has a return type hint!" ) - if parameters["return"]["type"] == "object": - parameters["return"]["type"] = "any" - class SimpleTool(Tool): def __init__(self, name, description, inputs, output_type, function): self.name = name @@ -898,11 +916,10 @@ def tool(tool_function: Callable) -> Tool: ) original_signature = inspect.signature(tool_function) new_parameters = [ - inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD) + inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY) ] + list(original_signature.parameters.values()) new_signature = original_signature.replace(parameters=new_parameters) simple_tool.forward.__signature__ = new_signature - # SimpleTool.__name__ = "".join([el.title() for el in parameters["name"].split("_")]) return simple_tool diff --git a/tests/test_final_answer.py b/tests/test_final_answer.py index 381e17d..8e79774 100644 --- a/tests/test_final_answer.py +++ b/tests/test_final_answer.py @@ -25,7 +25,7 @@ from smolagents.types import AGENT_TYPE_MAPPING from smolagents.default_tools import FinalAnswerTool -from .test_tools_common import ToolTesterMixin +from .test_tools import ToolTesterMixin if is_torch_available(): diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py index 76b57e2..61683dd 100644 --- a/tests/test_python_interpreter.py +++ b/tests/test_python_interpreter.py @@ -26,7 +26,7 @@ from smolagents.local_python_executor import ( evaluate_python_code, ) -from .test_tools_common import ToolTesterMixin +from .test_tools import ToolTesterMixin # Fake function we will use as tool diff --git a/tests/test_search.py b/tests/test_search.py index 8d972c7..9660d4f 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -17,7 +17,7 @@ import unittest from smolagents import load_tool -from .test_tools_common import ToolTesterMixin +from .test_tools import ToolTesterMixin class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin): diff --git a/tests/test_tools_common.py b/tests/test_tools.py similarity index 63% rename from tests/test_tools_common.py rename to tests/test_tools.py index 96f804f..77ae2b3 100644 --- a/tests/test_tools_common.py +++ b/tests/test_tools.py @@ -14,7 +14,7 @@ # limitations under the License. import unittest from pathlib import Path -from typing import Dict, Union +from typing import Dict, Union, Optional import numpy as np import pytest @@ -126,9 +126,9 @@ class ToolTests(unittest.TestCase): "description": "the task category (such as text-classification, depth-estimation, etc)", } } - output_type = "integer" + output_type = "string" - def forward(self, task): + def forward(self, task: str) -> str: return "best model" tool = HFModelDownloadsTool() @@ -223,7 +223,7 @@ class ToolTests(unittest.TestCase): name = "specific" description = "test description" inputs = { - "input_str": {"type": "string", "description": "input description"} + "string_input": {"type": "string", "description": "input description"} } output_type = "string" @@ -231,7 +231,7 @@ class ToolTests(unittest.TestCase): super().__init__(self) self.url = "none" - def forward(self, string_input): + def forward(self, string_input: str) -> str: return self.url + string_input fail_tool = FailTool("dummy_url") @@ -241,46 +241,127 @@ class ToolTests(unittest.TestCase): def test_saving_tool_allows_no_imports_from_outside_methods(self): # Test that using imports from outside functions fails - from numpy import random + import numpy as np - class FailTool2(Tool): + class FailTool(Tool): name = "specific" description = "test description" inputs = { - "input_str": {"type": "string", "description": "input description"} + "string_input": {"type": "string", "description": "input description"} } output_type = "string" def useless_method(self): - self.client = random.random() + self.client = np.random.random() return "" def forward(self, string_input): return self.useless_method() + string_input - fail_tool_2 = FailTool2() + fail_tool = FailTool() with pytest.raises(Exception) as e: - fail_tool_2.save("output") - assert "random" in str(e) + fail_tool.save("output") + assert "'np' is undefined" in str(e) # Test that putting these imports inside functions works - - class FailTool3(Tool): + class SuccessTool(Tool): name = "specific" description = "test description" inputs = { - "input_str": {"type": "string", "description": "input description"} + "string_input": {"type": "string", "description": "input description"} } output_type = "string" def useless_method(self): - from numpy import random + import numpy as np - self.client = random.random() + self.client = np.random.random() return "" def forward(self, string_input): return self.useless_method() + string_input - fail_tool_3 = FailTool3() - fail_tool_3.save("output") + success_tool = SuccessTool() + success_tool.save("output") + + def test_tool_missing_class_attributes_raises_error(self): + with pytest.raises(Exception) as e: + class GetWeatherTool(Tool): + name = "get_weather" + description = "Get weather in the next days at given location." + inputs = { + "location": {"type": "string", "description": "the location"}, + "celsius": {"type": "string", "description": "the temperature type"} + } + + def forward(self, location: str, celsius: Optional[bool] = False) -> str: + return "The weather is UNGODLY with torrential rains and temperatures below -10°C" + + tool = GetWeatherTool() + assert "You must set an attribute output_type" in str(e) + + def test_tool_from_decorator_optional_args(self): + @tool + def get_weather(location: str, celsius: Optional[bool] = False) -> str: + """ + Get weather in the next days at given location. + Secretly this tool does not care about the location, it hates the weather everywhere. + + Args: + location: the location + celsius: the temperature type + """ + return "The weather is UNGODLY with torrential rains and temperatures below -10°C" + + assert "nullable" in get_weather.inputs["celsius"] + assert get_weather.inputs["celsius"]["nullable"] == True + assert "nullable" not in get_weather.inputs["location"] + + def test_tool_mismatching_nullable_args_raises_error(self): + with pytest.raises(Exception) as e: + class GetWeatherTool(Tool): + name = "get_weather" + description = "Get weather in the next days at given location." + inputs = { + "location": {"type": "string", "description": "the location"}, + "celsius": {"type": "string", "description": "the temperature type"} + } + output_type = "string" + + def forward(self, location: str, celsius: Optional[bool] = False) -> str: + return "The weather is UNGODLY with torrential rains and temperatures below -10°C" + + tool = GetWeatherTool() + assert "Nullable" in str(e) + + with pytest.raises(Exception) as e: + class GetWeatherTool2(Tool): + name = "get_weather" + description = "Get weather in the next days at given location." + inputs = { + "location": {"type": "string", "description": "the location"}, + "celsius": {"type": "string", "description": "the temperature type"} + } + output_type = "string" + + def forward(self, location: str, celsius: bool = False) -> str: + return "The weather is UNGODLY with torrential rains and temperatures below -10°C" + + tool = GetWeatherTool2() + assert "Nullable" in str(e) + + with pytest.raises(Exception) as e: + class GetWeatherTool3(Tool): + name = "get_weather" + description = "Get weather in the next days at given location." + inputs = { + "location": {"type": "string", "description": "the location"}, + "celsius": {"type": "string", "description": "the temperature type", "nullable": True} + } + output_type = "string" + + def forward(self, location, celsius: str) -> str: + return "The weather is UNGODLY with torrential rains and temperatures below -10°C" + + tool = GetWeatherTool3() + assert "Nullable" in str(e)