diff --git a/docs/source/en/reference/tools.md b/docs/source/en/reference/tools.md index 41064c4..022ad35 100644 --- a/docs/source/en/reference/tools.md +++ b/docs/source/en/reference/tools.md @@ -39,10 +39,6 @@ contains the API docs for the underlying classes. [[autodoc]] Tool -### Toolbox - -[[autodoc]] Toolbox - ### launch_gradio_demo [[autodoc]] launch_gradio_demo diff --git a/docs/source/en/tutorials/tools.md b/docs/source/en/tutorials/tools.md index c86da57..014cd3b 100644 --- a/docs/source/en/tutorials/tools.md +++ b/docs/source/en/tutorials/tools.md @@ -187,7 +187,7 @@ from smolagents import HfApiModel model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct") agent = CodeAgent(tools=[], model=model, add_base_tools=True) -agent.toolbox.add_tool(model_download_tool) +agent.tools[model_download_tool.name] = model_download_tool ``` Now we can leverage the new tool: @@ -202,11 +202,6 @@ agent.run( > Beware of not adding too many tools to an agent: this can overwhelm weaker LLM engines. -Use the `agent.toolbox.update_tool()` method to replace an existing tool in the agent's toolbox. -This is useful if your new tool is a one-to-one replacement of the existing tool because the agent already knows how to perform that specific task. -Just make sure the new tool follows the same API as the replaced tool or adapt the system prompt template to ensure all examples using the replaced tool are updated. - - ### Use a collection of tools You can leverage tool collections by using the ToolCollection object, with the slug of the collection you want to use. 
diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index d8c8c4f..ab320e5 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -18,13 +18,14 @@ import time from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from rich import box from rich.console import Group from rich.panel import Panel from rich.rule import Rule from rich.syntax import Syntax from rich.text import Text -from .default_tools import FinalAnswerTool +from .default_tools import FinalAnswerTool, TOOL_MAPPING from .e2b_executor import E2BExecutor from .local_python_executor import ( BASE_BUILTIN_MODULES, @@ -49,7 +50,6 @@ from .prompts import ( from .tools import ( DEFAULT_TOOL_DESCRIPTION_TEMPLATE, Tool, - Toolbox, get_tool_description_with_args, ) from .types import AgentAudio, AgentImage, handle_agent_output_types @@ -107,18 +107,27 @@ class SystemPromptStep(AgentStep): system_prompt: str -def format_prompt_with_tools( - toolbox: Toolbox, prompt_template: str, tool_description_template: str +def get_tool_descriptions( + tools: Dict[str, Tool], tool_description_template: str ) -> str: - tool_descriptions = toolbox.show_tool_descriptions(tool_description_template) - prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions) + return "\n".join( + [ + get_tool_description_with_args(tool, tool_description_template) + for tool in tools.values() + ] + ) + +def format_prompt_with_tools( + tools: Dict[str, Tool], prompt_template: str, tool_description_template: str +) -> str: + tool_descriptions = get_tool_descriptions(tools, tool_description_template) + prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions) if "{{tool_names}}" in prompt: prompt = prompt.replace( "{{tool_names}}", - ", ".join([f"'{tool_name}'" for tool_name in toolbox.tools.keys()]), + ", ".join([f"'{tool.name}'" for tool in tools.values()]), ) - return prompt @@ -163,7 +172,7 @@ class MultiStepAgent: def __init__( 
self, - tools: Union[List[Tool], Toolbox], + tools: List[Tool], model: Callable[[List[Dict[str, str]]], str], system_prompt: Optional[str] = None, tool_description_template: Optional[str] = None, @@ -172,7 +181,7 @@ class MultiStepAgent: add_base_tools: bool = False, verbose: bool = False, grammar: Optional[Dict[str, str]] = None, - managed_agents: Optional[Dict] = None, + managed_agents: Optional[List] = None, step_callbacks: Optional[List[Callable]] = None, planning_interval: Optional[int] = None, ): @@ -196,17 +205,18 @@ class MultiStepAgent: self.managed_agents = {} if managed_agents is not None: + print("NOTNONE") self.managed_agents = {agent.name: agent for agent in managed_agents} - if isinstance(tools, Toolbox): - self._toolbox = tools - if add_base_tools: - self._toolbox.add_base_tools( - add_python_interpreter=(self.__class__ == ToolCallingAgent) - ) - else: - self._toolbox = Toolbox(tools, add_base_tools=add_base_tools) - self._toolbox.add_tool(FinalAnswerTool()) + self.tools = {tool.name: tool for tool in tools} + if add_base_tools: + for tool_name, tool_class in TOOL_MAPPING.items(): + if ( + tool_name != "python_interpreter" + or self.__class__.__name__ == "ToolCallingAgent" + ): + self.tools[tool_name] = tool_class() + self.tools["final_answer"] = FinalAnswerTool() self.system_prompt = self.initialize_system_prompt() self.input_messages = None @@ -217,14 +227,9 @@ class MultiStepAgent: self.step_callbacks = step_callbacks if step_callbacks is not None else [] self.step_callbacks.append(self.monitor.update_metrics) - @property - def toolbox(self) -> Toolbox: - """Get the toolbox currently available to the agent""" - return self._toolbox - def initialize_system_prompt(self): self.system_prompt = format_prompt_with_tools( - self._toolbox, + self.tools, self.system_prompt_template, self.tool_description_template, ) @@ -384,10 +389,10 @@ class MultiStepAgent: This method replaces arguments with the actual values from the state if they refer to state 
variables. Args: - tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox). + tool_name (`str`): Name of the Tool to execute (should be one from self.tools). arguments (Dict[str, str]): Arguments passed to the Tool. """ - available_tools = {**self.toolbox.tools, **self.managed_agents} + available_tools = {**self.tools, **self.managed_agents} if tool_name not in available_tools: error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." raise AgentExecutionError(error_msg) @@ -415,7 +420,7 @@ class MultiStepAgent: raise AgentExecutionError(error_msg) return observation except Exception as e: - if tool_name in self.toolbox.tools: + if tool_name in self.tools: tool_description = get_tool_description_with_args( available_tools[tool_name] ) @@ -512,20 +517,26 @@ You have been provided with these additional arguments, that you can access usin Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method. """ final_answer = None - step_number = 0 - while final_answer is None and step_number < self.max_steps: + self.step_number = 0 + while final_answer is None and self.step_number < self.max_steps: step_start_time = time.time() - step_log = ActionStep(step=step_number, start_time=step_start_time) + step_log = ActionStep(step=self.step_number, start_time=step_start_time) try: if ( self.planning_interval is not None - and step_number % self.planning_interval == 0 + and self.step_number % self.planning_interval == 0 ): self.planning_step( - task, is_first_step=(step_number == 0), step=step_number + task, + is_first_step=(self.step_number == 0), + step=self.step_number, ) console.print( - Rule(f"[bold]Step {step_number}", characters="━", style=YELLOW_HEX) + Rule( + f"[bold]Step {self.step_number}", + characters="━", + style=YELLOW_HEX, + ) ) # Run one step! 
@@ -538,10 +549,10 @@ You have been provided with these additional arguments, that you can access usin self.logs.append(step_log) for callback in self.step_callbacks: callback(step_log) - step_number += 1 + self.step_number += 1 yield step_log - if final_answer is None and step_number == self.max_steps: + if final_answer is None and self.step_number == self.max_steps: error_message = "Reached max steps." final_step_log = ActionStep(error=AgentMaxStepsError(error_message)) self.logs.append(final_step_log) @@ -561,20 +572,26 @@ You have been provided with these additional arguments, that you can access usin Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method. """ final_answer = None - step_number = 0 - while final_answer is None and step_number < self.max_steps: + self.step_number = 0 + while final_answer is None and self.step_number < self.max_steps: step_start_time = time.time() - step_log = ActionStep(step=step_number, start_time=step_start_time) + step_log = ActionStep(step=self.step_number, start_time=step_start_time) try: if ( self.planning_interval is not None - and step_number % self.planning_interval == 0 + and self.step_number % self.planning_interval == 0 ): self.planning_step( - task, is_first_step=(step_number == 0), step=step_number + task, + is_first_step=(self.step_number == 0), + step=self.step_number, ) console.print( - Rule(f"[bold]Step {step_number}", characters="━", style=YELLOW_HEX) + Rule( + f"[bold]Step {self.step_number}", + characters="━", + style=YELLOW_HEX, + ) ) # Run one step! @@ -589,9 +606,9 @@ You have been provided with these additional arguments, that you can access usin self.logs.append(step_log) for callback in self.step_callbacks: callback(step_log) - step_number += 1 + self.step_number += 1 - if final_answer is None and step_number == self.max_steps: + if final_answer is None and self.step_number == self.max_steps: error_message = "Reached max steps." 
final_step_log = ActionStep(error=AgentMaxStepsError(error_message)) self.logs.append(final_step_log) @@ -637,8 +654,8 @@ Now begin!""", "role": MessageRole.USER, "content": USER_PROMPT_PLAN.format( task=task, - tool_descriptions=self._toolbox.show_tool_descriptions( - self.tool_description_template + tool_descriptions=get_tool_descriptions( + self.tools, self.tool_description_template ), managed_agents_descriptions=( show_agents_descriptions(self.managed_agents) @@ -692,8 +709,8 @@ Now begin!""", "role": MessageRole.USER, "content": USER_PROMPT_PLAN_UPDATE.format( task=task, - tool_descriptions=self._toolbox.show_tool_descriptions( - self.tool_description_template + tool_descriptions=get_tool_descriptions( + self.tools, self.tool_description_template ), managed_agents_descriptions=( show_agents_descriptions(self.managed_agents) @@ -761,7 +778,7 @@ class ToolCallingAgent(MultiStepAgent): try: tool_name, tool_arguments, tool_call_id = self.model.get_tool_call( self.input_messages, - available_tools=list(self.toolbox._tools.values()), + available_tools=list(self.tools.values()), stop_sequences=["Observation:"], ) except Exception as e: @@ -856,7 +873,7 @@ class CodeAgent(MultiStepAgent): f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution." 
) - all_tools = {**self.toolbox.tools, **self.managed_agents} + all_tools = {**self.tools, **self.managed_agents} if use_e2b_executor: self.python_executor = E2BExecutor( self.additional_authorized_imports, list(all_tools.values()) @@ -941,10 +958,10 @@ class CodeAgent(MultiStepAgent): lexer="python", theme="monokai", word_wrap=True, - line_numbers=True, ), title="[bold]Executing this code:", title_align="left", + box=box.HORIZONTALS, ) ) observation = "" @@ -1045,5 +1062,4 @@ __all__ = [ "MultiStepAgent", "CodeAgent", "ToolCallingAgent", - "Toolbox", ] diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py index 5959cda..79539fd 100644 --- a/src/smolagents/default_tools.py +++ b/src/smolagents/default_tools.py @@ -322,6 +322,15 @@ class SpeechToTextTool(PipelineTool): return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0] +TOOL_MAPPING = { + tool_class.name: tool_class + for tool_class in [ + PythonInterpreterTool, + DuckDuckGoSearchTool, + VisitWebpageTool, + ] +} + __all__ = [ "PythonInterpreterTool", "FinalAnswerTool", diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 32d08f4..a7bea46 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -157,6 +157,14 @@ class Model: ): raise NotImplementedError + def get_tool_call( + self, + messages: List[Dict[str, str]], + available_tools: List[Tool], + stop_sequences, + ): + raise NotImplementedError + def __call__( self, messages: List[Dict[str, str]], diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py index 5dae400..12d7d63 100644 --- a/src/smolagents/tools.py +++ b/src/smolagents/tools.py @@ -25,7 +25,7 @@ import tempfile import textwrap from functools import lru_cache, wraps from pathlib import Path -from typing import Callable, Dict, List, Optional, Union, get_type_hints +from typing import Callable, Dict, Optional, Union, get_type_hints import torch from huggingface_hub import ( @@ -85,18 +85,6 @@ def 
get_repo_type(repo_id, repo_type=None, **hub_kwargs): return "space" -def setup_default_tools(): - default_tools = {} - main_module = importlib.import_module("smolagents") - - for task_name, tool_class_name in TOOL_MAPPING.items(): - tool_class = getattr(main_module, tool_class_name) - tool_instance = tool_class() - default_tools[tool_class.name] = tool_instance - - return default_tools - - def validate_after_init(cls): original_init = cls.__init__ @@ -727,10 +715,10 @@ def get_tool_description_with_args( if description_template is None: description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE compiled_template = compile_jinja_template(description_template) - rendered = compiled_template.render( + tool_description = compiled_template.render( tool=tool, ) - return rendered + return tool_description @lru_cache @@ -806,13 +794,6 @@ def launch_gradio_demo(tool: Tool): ).launch() -TOOL_MAPPING = { - "python_interpreter": "PythonInterpreterTool", - "web_search": "DuckDuckGoSearchTool", - "transcriber": "SpeechToTextTool", -} - - def load_tool( task_or_repo_id, model_repo_id: Optional[str] = None, @@ -821,7 +802,7 @@ def load_tool( **kwargs, ): """ - Main function to quickly load a tool, be it on the Hub or in the Transformers library. + Main function to quickly load a tool from the Hub. @@ -854,20 +835,13 @@ def load_tool( `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others will be passed along to its init. 
""" - if task_or_repo_id in TOOL_MAPPING: - tool_class_name = TOOL_MAPPING[task_or_repo_id] - main_module = importlib.import_module("smolagents") - tools_module = main_module - tool_class = getattr(tools_module, tool_class_name) - return tool_class(token=token, **kwargs) - else: - return Tool.from_hub( - task_or_repo_id, - model_repo_id=model_repo_id, - token=token, - trust_remote_code=trust_remote_code, - **kwargs, - ) + return Tool.from_hub( + task_or_repo_id, + model_repo_id=model_repo_id, + token=token, + trust_remote_code=trust_remote_code, + **kwargs, + ) def add_description(description): @@ -961,107 +935,6 @@ def tool(tool_function: Callable) -> Tool: return simple_tool -HUGGINGFACE_DEFAULT_TOOLS = {} - - -class Toolbox: - """ - The toolbox contains all tools that the agent can perform operations with, as well as a few methods to - manage them. - - Args: - tools (`List[Tool]`): - The list of tools to instantiate the toolbox with - add_base_tools (`bool`, defaults to `False`, *optional*, defaults to `False`): - Whether to add the tools available within `transformers` to the toolbox. 
- """ - - def __init__(self, tools: List[Tool], add_base_tools: bool = False): - self._tools = {tool.name: tool for tool in tools} - if add_base_tools: - self.add_base_tools() - - def add_base_tools(self, add_python_interpreter: bool = False): - global HUGGINGFACE_DEFAULT_TOOLS - if len(HUGGINGFACE_DEFAULT_TOOLS.keys()) == 0: - HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools() - for tool in HUGGINGFACE_DEFAULT_TOOLS.values(): - if tool.name != "python_interpreter" or add_python_interpreter: - self.add_tool(tool) - - @property - def tools(self) -> Dict[str, Tool]: - """Get all tools currently in the toolbox""" - return self._tools - - def show_tool_descriptions( - self, tool_description_template: Optional[str] = None - ) -> str: - """ - Returns the description of all tools in the toolbox - - Args: - tool_description_template (`str`, *optional*): - The template to use to describe the tools. If not provided, the default template will be used. - """ - return "\n".join( - [ - get_tool_description_with_args(tool, tool_description_template) - for tool in self._tools.values() - ] - ) - - def add_tool(self, tool: Tool): - """ - Adds a tool to the toolbox - - Args: - tool (`Tool`): - The tool to add to the toolbox. - """ - if tool.name in self._tools: - raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.") - self._tools[tool.name] = tool - - def remove_tool(self, tool_name: str): - """ - Removes a tool from the toolbox - - Args: - tool_name (`str`): - The tool to remove from the toolbox. - """ - if tool_name not in self._tools: - raise KeyError( - f"Error: tool {tool_name} not found in toolbox for removal, should be instead one of {list(self._tools.keys())}." - ) - del self._tools[tool_name] - - def update_tool(self, tool: Tool): - """ - Updates a tool in the toolbox according to its name. - - Args: - tool (`Tool`): - The tool to update to the toolbox. 
- """ - if tool.name not in self._tools: - raise KeyError( - f"Error: tool {tool.name} not found in toolbox for update, should be instead one of {list(self._tools.keys())}." - ) - self._tools[tool.name] = tool - - def clear_toolbox(self): - """Clears the toolbox""" - self._tools = {} - - def __repr__(self): - toolbox_description = "Toolbox contents:\n" - for tool in self._tools.values(): - toolbox_description += f"\t{tool.name}: {tool.description}\n" - return toolbox_description - - class PipelineTool(Tool): """ A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will @@ -1234,6 +1107,5 @@ __all__ = [ "tool", "load_tool", "launch_gradio_demo", - "Toolbox", "ToolCollection", ] diff --git a/tests/test_agents.py b/tests/test_agents.py index 2d666e6..9327285 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -18,14 +18,12 @@ import unittest import uuid from pathlib import Path -import pytest from transformers.testing_utils import get_tests_dir from smolagents.agents import ( AgentMaxStepsError, CodeAgent, ManagedAgent, - Toolbox, ToolCall, ToolCallingAgent, ) @@ -289,37 +287,35 @@ class AgentTests(unittest.TestCase): assert len(agent.logs) == 8 assert type(agent.logs[-1].error) is AgentMaxStepsError + def test_tool_descriptions_get_baked_in_system_prompt(self): + tool = PythonInterpreterTool() + tool.name = "fake_tool_name" + tool.description = "fake_tool_description" + agent = CodeAgent(tools=[tool], model=fake_code_model) + agent.run("Empty task") + assert tool.name in agent.system_prompt + assert tool.description in agent.system_prompt + def test_init_agent_with_different_toolsets(self): toolset_1 = [] agent = CodeAgent(tools=toolset_1, model=fake_code_model) assert ( - len(agent.toolbox.tools) == 1 + len(agent.tools) == 1 ) # when no tools are provided, only the final_answer tool is added by default toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()] agent = 
CodeAgent(tools=toolset_2, model=fake_code_model) assert ( - len(agent.toolbox.tools) == 2 + len(agent.tools) == 2 ) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer - toolset_3 = Toolbox(toolset_2) - agent = CodeAgent(tools=toolset_3, model=fake_code_model) - assert ( - len(agent.toolbox.tools) == 2 - ) # same as previous one, where toolset_3 is an instantiation of previous one - - # check that add_base_tools will not interfere with existing tools - with pytest.raises(KeyError) as e: - agent = ToolCallingAgent( - tools=toolset_3, model=FakeToolCallModel(), add_base_tools=True - ) - assert "already exists in the toolbox" in str(e) - - # check that python_interpreter base tool does not get added to code agents + # check that python_interpreter base tool does not get added to CodeAgent agent = CodeAgent(tools=[], model=fake_code_model, add_base_tools=True) - assert ( - len(agent.toolbox.tools) == 3 - ) # added final_answer tool + search + transcribe + assert len(agent.tools) == 3 # added final_answer tool + search + visit_webpage + + # check that python_interpreter base tool gets added to ToolCallingAgent + agent = ToolCallingAgent(tools=[], model=fake_code_model, add_base_tools=True) + assert len(agent.tools) == 4 # added final_answer tool + search + visit_webpage + python_interpreter def test_function_persistence_across_steps(self): agent = CodeAgent( diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py index 5f4ffc4..5944066 100644 --- a/tests/test_python_interpreter.py +++ b/tests/test_python_interpreter.py @@ -18,8 +18,7 @@ import unittest import numpy as np import pytest -from smolagents import load_tool -from smolagents.default_tools import BASE_PYTHON_TOOLS +from smolagents.default_tools import BASE_PYTHON_TOOLS, PythonInterpreterTool from smolagents.local_python_executor import ( InterpreterError, evaluate_python_code, @@ -37,7 +36,7 @@ def add_two(x): class 
PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): def setUp(self): - self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"]) + self.tool = PythonInterpreterTool(authorized_imports=["sqlite3"]) self.tool.setup() def test_exact_match_arg(self): diff --git a/tests/test_search.py b/tests/test_search.py index 488b97b..7fc6c26 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -15,14 +15,14 @@ import unittest -from smolagents import load_tool +from smolagents import DuckDuckGoSearchTool from .test_tools import ToolTesterMixin class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin): def setUp(self): - self.tool = load_tool("web_search") + self.tool = DuckDuckGoSearchTool() self.tool.setup() def test_exact_match_arg(self):