parent
cf04285cc1
commit
cb9830a554
|
@ -39,10 +39,6 @@ contains the API docs for the underlying classes.
|
||||||
|
|
||||||
[[autodoc]] Tool
|
[[autodoc]] Tool
|
||||||
|
|
||||||
### Toolbox
|
|
||||||
|
|
||||||
[[autodoc]] Toolbox
|
|
||||||
|
|
||||||
### launch_gradio_demo
|
### launch_gradio_demo
|
||||||
|
|
||||||
[[autodoc]] launch_gradio_demo
|
[[autodoc]] launch_gradio_demo
|
||||||
|
|
|
@ -187,7 +187,7 @@ from smolagents import HfApiModel
|
||||||
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
|
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
|
||||||
|
|
||||||
agent = CodeAgent(tools=[], model=model, add_base_tools=True)
|
agent = CodeAgent(tools=[], model=model, add_base_tools=True)
|
||||||
agent.toolbox.add_tool(model_download_tool)
|
agent.tools.append(model_download_tool)
|
||||||
```
|
```
|
||||||
Now we can leverage the new tool:
|
Now we can leverage the new tool:
|
||||||
|
|
||||||
|
@ -202,11 +202,6 @@ agent.run(
|
||||||
> Beware of not adding too many tools to an agent: this can overwhelm weaker LLM engines.
|
> Beware of not adding too many tools to an agent: this can overwhelm weaker LLM engines.
|
||||||
|
|
||||||
|
|
||||||
Use the `agent.toolbox.update_tool()` method to replace an existing tool in the agent's toolbox.
|
|
||||||
This is useful if your new tool is a one-to-one replacement of the existing tool because the agent already knows how to perform that specific task.
|
|
||||||
Just make sure the new tool follows the same API as the replaced tool or adapt the system prompt template to ensure all examples using the replaced tool are updated.
|
|
||||||
|
|
||||||
|
|
||||||
### Use a collection of tools
|
### Use a collection of tools
|
||||||
|
|
||||||
You can leverage tool collections by using the ToolCollection object, with the slug of the collection you want to use.
|
You can leverage tool collections by using the ToolCollection object, with the slug of the collection you want to use.
|
||||||
|
|
|
@ -18,13 +18,14 @@ import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from rich import box
|
||||||
from rich.console import Group
|
from rich.console import Group
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.rule import Rule
|
from rich.rule import Rule
|
||||||
from rich.syntax import Syntax
|
from rich.syntax import Syntax
|
||||||
from rich.text import Text
|
from rich.text import Text
|
||||||
|
|
||||||
from .default_tools import FinalAnswerTool
|
from .default_tools import FinalAnswerTool, TOOL_MAPPING
|
||||||
from .e2b_executor import E2BExecutor
|
from .e2b_executor import E2BExecutor
|
||||||
from .local_python_executor import (
|
from .local_python_executor import (
|
||||||
BASE_BUILTIN_MODULES,
|
BASE_BUILTIN_MODULES,
|
||||||
|
@ -49,7 +50,6 @@ from .prompts import (
|
||||||
from .tools import (
|
from .tools import (
|
||||||
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||||
Tool,
|
Tool,
|
||||||
Toolbox,
|
|
||||||
get_tool_description_with_args,
|
get_tool_description_with_args,
|
||||||
)
|
)
|
||||||
from .types import AgentAudio, AgentImage, handle_agent_output_types
|
from .types import AgentAudio, AgentImage, handle_agent_output_types
|
||||||
|
@ -107,18 +107,27 @@ class SystemPromptStep(AgentStep):
|
||||||
system_prompt: str
|
system_prompt: str
|
||||||
|
|
||||||
|
|
||||||
def format_prompt_with_tools(
|
def get_tool_descriptions(
|
||||||
toolbox: Toolbox, prompt_template: str, tool_description_template: str
|
tools: Dict[str, Tool], tool_description_template: str
|
||||||
) -> str:
|
) -> str:
|
||||||
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
|
return "\n".join(
|
||||||
prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
|
[
|
||||||
|
get_tool_description_with_args(tool, tool_description_template)
|
||||||
|
for tool in tools.values()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def format_prompt_with_tools(
|
||||||
|
tools: Dict[str, Tool], prompt_template: str, tool_description_template: str
|
||||||
|
) -> str:
|
||||||
|
tool_descriptions = get_tool_descriptions(tools, tool_description_template)
|
||||||
|
prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
|
||||||
if "{{tool_names}}" in prompt:
|
if "{{tool_names}}" in prompt:
|
||||||
prompt = prompt.replace(
|
prompt = prompt.replace(
|
||||||
"{{tool_names}}",
|
"{{tool_names}}",
|
||||||
", ".join([f"'{tool_name}'" for tool_name in toolbox.tools.keys()]),
|
", ".join([f"'{tool.name}'" for tool in tools.values()]),
|
||||||
)
|
)
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
@ -163,7 +172,7 @@ class MultiStepAgent:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tools: Union[List[Tool], Toolbox],
|
tools: List[Tool],
|
||||||
model: Callable[[List[Dict[str, str]]], str],
|
model: Callable[[List[Dict[str, str]]], str],
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
tool_description_template: Optional[str] = None,
|
tool_description_template: Optional[str] = None,
|
||||||
|
@ -172,7 +181,7 @@ class MultiStepAgent:
|
||||||
add_base_tools: bool = False,
|
add_base_tools: bool = False,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
grammar: Optional[Dict[str, str]] = None,
|
grammar: Optional[Dict[str, str]] = None,
|
||||||
managed_agents: Optional[Dict] = None,
|
managed_agents: Optional[List] = None,
|
||||||
step_callbacks: Optional[List[Callable]] = None,
|
step_callbacks: Optional[List[Callable]] = None,
|
||||||
planning_interval: Optional[int] = None,
|
planning_interval: Optional[int] = None,
|
||||||
):
|
):
|
||||||
|
@ -196,17 +205,18 @@ class MultiStepAgent:
|
||||||
|
|
||||||
self.managed_agents = {}
|
self.managed_agents = {}
|
||||||
if managed_agents is not None:
|
if managed_agents is not None:
|
||||||
|
print("NOTNONE")
|
||||||
self.managed_agents = {agent.name: agent for agent in managed_agents}
|
self.managed_agents = {agent.name: agent for agent in managed_agents}
|
||||||
|
|
||||||
if isinstance(tools, Toolbox):
|
self.tools = {tool.name: tool for tool in tools}
|
||||||
self._toolbox = tools
|
|
||||||
if add_base_tools:
|
if add_base_tools:
|
||||||
self._toolbox.add_base_tools(
|
for tool_name, tool_class in TOOL_MAPPING.items():
|
||||||
add_python_interpreter=(self.__class__ == ToolCallingAgent)
|
if (
|
||||||
)
|
tool_name != "python_interpreter"
|
||||||
else:
|
or self.__class__.__name__ == "ToolCallingAgent"
|
||||||
self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
|
):
|
||||||
self._toolbox.add_tool(FinalAnswerTool())
|
self.tools[tool_name] = tool_class()
|
||||||
|
self.tools["final_answer"] = FinalAnswerTool()
|
||||||
|
|
||||||
self.system_prompt = self.initialize_system_prompt()
|
self.system_prompt = self.initialize_system_prompt()
|
||||||
self.input_messages = None
|
self.input_messages = None
|
||||||
|
@ -217,14 +227,9 @@ class MultiStepAgent:
|
||||||
self.step_callbacks = step_callbacks if step_callbacks is not None else []
|
self.step_callbacks = step_callbacks if step_callbacks is not None else []
|
||||||
self.step_callbacks.append(self.monitor.update_metrics)
|
self.step_callbacks.append(self.monitor.update_metrics)
|
||||||
|
|
||||||
@property
|
|
||||||
def toolbox(self) -> Toolbox:
|
|
||||||
"""Get the toolbox currently available to the agent"""
|
|
||||||
return self._toolbox
|
|
||||||
|
|
||||||
def initialize_system_prompt(self):
|
def initialize_system_prompt(self):
|
||||||
self.system_prompt = format_prompt_with_tools(
|
self.system_prompt = format_prompt_with_tools(
|
||||||
self._toolbox,
|
self.tools,
|
||||||
self.system_prompt_template,
|
self.system_prompt_template,
|
||||||
self.tool_description_template,
|
self.tool_description_template,
|
||||||
)
|
)
|
||||||
|
@ -384,10 +389,10 @@ class MultiStepAgent:
|
||||||
This method replaces arguments with the actual values from the state if they refer to state variables.
|
This method replaces arguments with the actual values from the state if they refer to state variables.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
|
tool_name (`str`): Name of the Tool to execute (should be one from self.tools).
|
||||||
arguments (Dict[str, str]): Arguments passed to the Tool.
|
arguments (Dict[str, str]): Arguments passed to the Tool.
|
||||||
"""
|
"""
|
||||||
available_tools = {**self.toolbox.tools, **self.managed_agents}
|
available_tools = {**self.tools, **self.managed_agents}
|
||||||
if tool_name not in available_tools:
|
if tool_name not in available_tools:
|
||||||
error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
error_msg = f"Unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}."
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg)
|
||||||
|
@ -415,7 +420,7 @@ class MultiStepAgent:
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg)
|
||||||
return observation
|
return observation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if tool_name in self.toolbox.tools:
|
if tool_name in self.tools:
|
||||||
tool_description = get_tool_description_with_args(
|
tool_description = get_tool_description_with_args(
|
||||||
available_tools[tool_name]
|
available_tools[tool_name]
|
||||||
)
|
)
|
||||||
|
@ -512,20 +517,26 @@ You have been provided with these additional arguments, that you can access usin
|
||||||
Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
|
Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
|
||||||
"""
|
"""
|
||||||
final_answer = None
|
final_answer = None
|
||||||
step_number = 0
|
self.step_number = 0
|
||||||
while final_answer is None and step_number < self.max_steps:
|
while final_answer is None and self.step_number < self.max_steps:
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
step_log = ActionStep(step=step_number, start_time=step_start_time)
|
step_log = ActionStep(step=self.step_number, start_time=step_start_time)
|
||||||
try:
|
try:
|
||||||
if (
|
if (
|
||||||
self.planning_interval is not None
|
self.planning_interval is not None
|
||||||
and step_number % self.planning_interval == 0
|
and self.step_number % self.planning_interval == 0
|
||||||
):
|
):
|
||||||
self.planning_step(
|
self.planning_step(
|
||||||
task, is_first_step=(step_number == 0), step=step_number
|
task,
|
||||||
|
is_first_step=(self.step_number == 0),
|
||||||
|
step=self.step_number,
|
||||||
)
|
)
|
||||||
console.print(
|
console.print(
|
||||||
Rule(f"[bold]Step {step_number}", characters="━", style=YELLOW_HEX)
|
Rule(
|
||||||
|
f"[bold]Step {self.step_number}",
|
||||||
|
characters="━",
|
||||||
|
style=YELLOW_HEX,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run one step!
|
# Run one step!
|
||||||
|
@ -538,10 +549,10 @@ You have been provided with these additional arguments, that you can access usin
|
||||||
self.logs.append(step_log)
|
self.logs.append(step_log)
|
||||||
for callback in self.step_callbacks:
|
for callback in self.step_callbacks:
|
||||||
callback(step_log)
|
callback(step_log)
|
||||||
step_number += 1
|
self.step_number += 1
|
||||||
yield step_log
|
yield step_log
|
||||||
|
|
||||||
if final_answer is None and step_number == self.max_steps:
|
if final_answer is None and self.step_number == self.max_steps:
|
||||||
error_message = "Reached max steps."
|
error_message = "Reached max steps."
|
||||||
final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
|
final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
|
||||||
self.logs.append(final_step_log)
|
self.logs.append(final_step_log)
|
||||||
|
@ -561,20 +572,26 @@ You have been provided with these additional arguments, that you can access usin
|
||||||
Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
|
Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
|
||||||
"""
|
"""
|
||||||
final_answer = None
|
final_answer = None
|
||||||
step_number = 0
|
self.step_number = 0
|
||||||
while final_answer is None and step_number < self.max_steps:
|
while final_answer is None and self.step_number < self.max_steps:
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
step_log = ActionStep(step=step_number, start_time=step_start_time)
|
step_log = ActionStep(step=self.step_number, start_time=step_start_time)
|
||||||
try:
|
try:
|
||||||
if (
|
if (
|
||||||
self.planning_interval is not None
|
self.planning_interval is not None
|
||||||
and step_number % self.planning_interval == 0
|
and self.step_number % self.planning_interval == 0
|
||||||
):
|
):
|
||||||
self.planning_step(
|
self.planning_step(
|
||||||
task, is_first_step=(step_number == 0), step=step_number
|
task,
|
||||||
|
is_first_step=(self.step_number == 0),
|
||||||
|
step=self.step_number,
|
||||||
)
|
)
|
||||||
console.print(
|
console.print(
|
||||||
Rule(f"[bold]Step {step_number}", characters="━", style=YELLOW_HEX)
|
Rule(
|
||||||
|
f"[bold]Step {self.step_number}",
|
||||||
|
characters="━",
|
||||||
|
style=YELLOW_HEX,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run one step!
|
# Run one step!
|
||||||
|
@ -589,9 +606,9 @@ You have been provided with these additional arguments, that you can access usin
|
||||||
self.logs.append(step_log)
|
self.logs.append(step_log)
|
||||||
for callback in self.step_callbacks:
|
for callback in self.step_callbacks:
|
||||||
callback(step_log)
|
callback(step_log)
|
||||||
step_number += 1
|
self.step_number += 1
|
||||||
|
|
||||||
if final_answer is None and step_number == self.max_steps:
|
if final_answer is None and self.step_number == self.max_steps:
|
||||||
error_message = "Reached max steps."
|
error_message = "Reached max steps."
|
||||||
final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
|
final_step_log = ActionStep(error=AgentMaxStepsError(error_message))
|
||||||
self.logs.append(final_step_log)
|
self.logs.append(final_step_log)
|
||||||
|
@ -637,8 +654,8 @@ Now begin!""",
|
||||||
"role": MessageRole.USER,
|
"role": MessageRole.USER,
|
||||||
"content": USER_PROMPT_PLAN.format(
|
"content": USER_PROMPT_PLAN.format(
|
||||||
task=task,
|
task=task,
|
||||||
tool_descriptions=self._toolbox.show_tool_descriptions(
|
tool_descriptions=get_tool_descriptions(
|
||||||
self.tool_description_template
|
self.tools, self.tool_description_template
|
||||||
),
|
),
|
||||||
managed_agents_descriptions=(
|
managed_agents_descriptions=(
|
||||||
show_agents_descriptions(self.managed_agents)
|
show_agents_descriptions(self.managed_agents)
|
||||||
|
@ -692,8 +709,8 @@ Now begin!""",
|
||||||
"role": MessageRole.USER,
|
"role": MessageRole.USER,
|
||||||
"content": USER_PROMPT_PLAN_UPDATE.format(
|
"content": USER_PROMPT_PLAN_UPDATE.format(
|
||||||
task=task,
|
task=task,
|
||||||
tool_descriptions=self._toolbox.show_tool_descriptions(
|
tool_descriptions=get_tool_descriptions(
|
||||||
self.tool_description_template
|
self.tools, self.tool_description_template
|
||||||
),
|
),
|
||||||
managed_agents_descriptions=(
|
managed_agents_descriptions=(
|
||||||
show_agents_descriptions(self.managed_agents)
|
show_agents_descriptions(self.managed_agents)
|
||||||
|
@ -761,7 +778,7 @@ class ToolCallingAgent(MultiStepAgent):
|
||||||
try:
|
try:
|
||||||
tool_name, tool_arguments, tool_call_id = self.model.get_tool_call(
|
tool_name, tool_arguments, tool_call_id = self.model.get_tool_call(
|
||||||
self.input_messages,
|
self.input_messages,
|
||||||
available_tools=list(self.toolbox._tools.values()),
|
available_tools=list(self.tools.values()),
|
||||||
stop_sequences=["Observation:"],
|
stop_sequences=["Observation:"],
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -856,7 +873,7 @@ class CodeAgent(MultiStepAgent):
|
||||||
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
|
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
|
||||||
)
|
)
|
||||||
|
|
||||||
all_tools = {**self.toolbox.tools, **self.managed_agents}
|
all_tools = {**self.tools, **self.managed_agents}
|
||||||
if use_e2b_executor:
|
if use_e2b_executor:
|
||||||
self.python_executor = E2BExecutor(
|
self.python_executor = E2BExecutor(
|
||||||
self.additional_authorized_imports, list(all_tools.values())
|
self.additional_authorized_imports, list(all_tools.values())
|
||||||
|
@ -941,10 +958,10 @@ class CodeAgent(MultiStepAgent):
|
||||||
lexer="python",
|
lexer="python",
|
||||||
theme="monokai",
|
theme="monokai",
|
||||||
word_wrap=True,
|
word_wrap=True,
|
||||||
line_numbers=True,
|
|
||||||
),
|
),
|
||||||
title="[bold]Executing this code:",
|
title="[bold]Executing this code:",
|
||||||
title_align="left",
|
title_align="left",
|
||||||
|
box=box.HORIZONTALS,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
observation = ""
|
observation = ""
|
||||||
|
@ -1045,5 +1062,4 @@ __all__ = [
|
||||||
"MultiStepAgent",
|
"MultiStepAgent",
|
||||||
"CodeAgent",
|
"CodeAgent",
|
||||||
"ToolCallingAgent",
|
"ToolCallingAgent",
|
||||||
"Toolbox",
|
|
||||||
]
|
]
|
||||||
|
|
|
@ -322,6 +322,15 @@ class SpeechToTextTool(PipelineTool):
|
||||||
return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
||||||
|
|
||||||
|
|
||||||
|
TOOL_MAPPING = {
|
||||||
|
tool_class.name: tool_class
|
||||||
|
for tool_class in [
|
||||||
|
PythonInterpreterTool,
|
||||||
|
DuckDuckGoSearchTool,
|
||||||
|
VisitWebpageTool,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PythonInterpreterTool",
|
"PythonInterpreterTool",
|
||||||
"FinalAnswerTool",
|
"FinalAnswerTool",
|
||||||
|
|
|
@ -157,6 +157,14 @@ class Model:
|
||||||
):
|
):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_tool_call(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
available_tools: List[Tool],
|
||||||
|
stop_sequences,
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
|
|
|
@ -25,7 +25,7 @@ import tempfile
|
||||||
import textwrap
|
import textwrap
|
||||||
from functools import lru_cache, wraps
|
from functools import lru_cache, wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, List, Optional, Union, get_type_hints
|
from typing import Callable, Dict, Optional, Union, get_type_hints
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import (
|
from huggingface_hub import (
|
||||||
|
@ -85,18 +85,6 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
|
||||||
return "space"
|
return "space"
|
||||||
|
|
||||||
|
|
||||||
def setup_default_tools():
|
|
||||||
default_tools = {}
|
|
||||||
main_module = importlib.import_module("smolagents")
|
|
||||||
|
|
||||||
for task_name, tool_class_name in TOOL_MAPPING.items():
|
|
||||||
tool_class = getattr(main_module, tool_class_name)
|
|
||||||
tool_instance = tool_class()
|
|
||||||
default_tools[tool_class.name] = tool_instance
|
|
||||||
|
|
||||||
return default_tools
|
|
||||||
|
|
||||||
|
|
||||||
def validate_after_init(cls):
|
def validate_after_init(cls):
|
||||||
original_init = cls.__init__
|
original_init = cls.__init__
|
||||||
|
|
||||||
|
@ -727,10 +715,10 @@ def get_tool_description_with_args(
|
||||||
if description_template is None:
|
if description_template is None:
|
||||||
description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
||||||
compiled_template = compile_jinja_template(description_template)
|
compiled_template = compile_jinja_template(description_template)
|
||||||
rendered = compiled_template.render(
|
tool_description = compiled_template.render(
|
||||||
tool=tool,
|
tool=tool,
|
||||||
)
|
)
|
||||||
return rendered
|
return tool_description
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
|
@ -806,13 +794,6 @@ def launch_gradio_demo(tool: Tool):
|
||||||
).launch()
|
).launch()
|
||||||
|
|
||||||
|
|
||||||
TOOL_MAPPING = {
|
|
||||||
"python_interpreter": "PythonInterpreterTool",
|
|
||||||
"web_search": "DuckDuckGoSearchTool",
|
|
||||||
"transcriber": "SpeechToTextTool",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def load_tool(
|
def load_tool(
|
||||||
task_or_repo_id,
|
task_or_repo_id,
|
||||||
model_repo_id: Optional[str] = None,
|
model_repo_id: Optional[str] = None,
|
||||||
|
@ -821,7 +802,7 @@ def load_tool(
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Main function to quickly load a tool, be it on the Hub or in the Transformers library.
|
Main function to quickly load a tool from the Hub.
|
||||||
|
|
||||||
<Tip warning={true}>
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
@ -854,13 +835,6 @@ def load_tool(
|
||||||
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
|
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
|
||||||
will be passed along to its init.
|
will be passed along to its init.
|
||||||
"""
|
"""
|
||||||
if task_or_repo_id in TOOL_MAPPING:
|
|
||||||
tool_class_name = TOOL_MAPPING[task_or_repo_id]
|
|
||||||
main_module = importlib.import_module("smolagents")
|
|
||||||
tools_module = main_module
|
|
||||||
tool_class = getattr(tools_module, tool_class_name)
|
|
||||||
return tool_class(token=token, **kwargs)
|
|
||||||
else:
|
|
||||||
return Tool.from_hub(
|
return Tool.from_hub(
|
||||||
task_or_repo_id,
|
task_or_repo_id,
|
||||||
model_repo_id=model_repo_id,
|
model_repo_id=model_repo_id,
|
||||||
|
@ -961,107 +935,6 @@ def tool(tool_function: Callable) -> Tool:
|
||||||
return simple_tool
|
return simple_tool
|
||||||
|
|
||||||
|
|
||||||
HUGGINGFACE_DEFAULT_TOOLS = {}
|
|
||||||
|
|
||||||
|
|
||||||
class Toolbox:
|
|
||||||
"""
|
|
||||||
The toolbox contains all tools that the agent can perform operations with, as well as a few methods to
|
|
||||||
manage them.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tools (`List[Tool]`):
|
|
||||||
The list of tools to instantiate the toolbox with
|
|
||||||
add_base_tools (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
|
||||||
Whether to add the tools available within `transformers` to the toolbox.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, tools: List[Tool], add_base_tools: bool = False):
|
|
||||||
self._tools = {tool.name: tool for tool in tools}
|
|
||||||
if add_base_tools:
|
|
||||||
self.add_base_tools()
|
|
||||||
|
|
||||||
def add_base_tools(self, add_python_interpreter: bool = False):
|
|
||||||
global HUGGINGFACE_DEFAULT_TOOLS
|
|
||||||
if len(HUGGINGFACE_DEFAULT_TOOLS.keys()) == 0:
|
|
||||||
HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools()
|
|
||||||
for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
|
|
||||||
if tool.name != "python_interpreter" or add_python_interpreter:
|
|
||||||
self.add_tool(tool)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def tools(self) -> Dict[str, Tool]:
|
|
||||||
"""Get all tools currently in the toolbox"""
|
|
||||||
return self._tools
|
|
||||||
|
|
||||||
def show_tool_descriptions(
|
|
||||||
self, tool_description_template: Optional[str] = None
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Returns the description of all tools in the toolbox
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_description_template (`str`, *optional*):
|
|
||||||
The template to use to describe the tools. If not provided, the default template will be used.
|
|
||||||
"""
|
|
||||||
return "\n".join(
|
|
||||||
[
|
|
||||||
get_tool_description_with_args(tool, tool_description_template)
|
|
||||||
for tool in self._tools.values()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_tool(self, tool: Tool):
|
|
||||||
"""
|
|
||||||
Adds a tool to the toolbox
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool (`Tool`):
|
|
||||||
The tool to add to the toolbox.
|
|
||||||
"""
|
|
||||||
if tool.name in self._tools:
|
|
||||||
raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.")
|
|
||||||
self._tools[tool.name] = tool
|
|
||||||
|
|
||||||
def remove_tool(self, tool_name: str):
|
|
||||||
"""
|
|
||||||
Removes a tool from the toolbox
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name (`str`):
|
|
||||||
The tool to remove from the toolbox.
|
|
||||||
"""
|
|
||||||
if tool_name not in self._tools:
|
|
||||||
raise KeyError(
|
|
||||||
f"Error: tool {tool_name} not found in toolbox for removal, should be instead one of {list(self._tools.keys())}."
|
|
||||||
)
|
|
||||||
del self._tools[tool_name]
|
|
||||||
|
|
||||||
def update_tool(self, tool: Tool):
|
|
||||||
"""
|
|
||||||
Updates a tool in the toolbox according to its name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool (`Tool`):
|
|
||||||
The tool to update to the toolbox.
|
|
||||||
"""
|
|
||||||
if tool.name not in self._tools:
|
|
||||||
raise KeyError(
|
|
||||||
f"Error: tool {tool.name} not found in toolbox for update, should be instead one of {list(self._tools.keys())}."
|
|
||||||
)
|
|
||||||
self._tools[tool.name] = tool
|
|
||||||
|
|
||||||
def clear_toolbox(self):
|
|
||||||
"""Clears the toolbox"""
|
|
||||||
self._tools = {}
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
toolbox_description = "Toolbox contents:\n"
|
|
||||||
for tool in self._tools.values():
|
|
||||||
toolbox_description += f"\t{tool.name}: {tool.description}\n"
|
|
||||||
return toolbox_description
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineTool(Tool):
|
class PipelineTool(Tool):
|
||||||
"""
|
"""
|
||||||
A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
|
A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
|
||||||
|
@ -1234,6 +1107,5 @@ __all__ = [
|
||||||
"tool",
|
"tool",
|
||||||
"load_tool",
|
"load_tool",
|
||||||
"launch_gradio_demo",
|
"launch_gradio_demo",
|
||||||
"Toolbox",
|
|
||||||
"ToolCollection",
|
"ToolCollection",
|
||||||
]
|
]
|
||||||
|
|
|
@ -18,14 +18,12 @@ import unittest
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
|
||||||
from transformers.testing_utils import get_tests_dir
|
from transformers.testing_utils import get_tests_dir
|
||||||
|
|
||||||
from smolagents.agents import (
|
from smolagents.agents import (
|
||||||
AgentMaxStepsError,
|
AgentMaxStepsError,
|
||||||
CodeAgent,
|
CodeAgent,
|
||||||
ManagedAgent,
|
ManagedAgent,
|
||||||
Toolbox,
|
|
||||||
ToolCall,
|
ToolCall,
|
||||||
ToolCallingAgent,
|
ToolCallingAgent,
|
||||||
)
|
)
|
||||||
|
@ -289,37 +287,35 @@ class AgentTests(unittest.TestCase):
|
||||||
assert len(agent.logs) == 8
|
assert len(agent.logs) == 8
|
||||||
assert type(agent.logs[-1].error) is AgentMaxStepsError
|
assert type(agent.logs[-1].error) is AgentMaxStepsError
|
||||||
|
|
||||||
|
def test_tool_descriptions_get_baked_in_system_prompt(self):
|
||||||
|
tool = PythonInterpreterTool()
|
||||||
|
tool.name = "fake_tool_name"
|
||||||
|
tool.description = "fake_tool_description"
|
||||||
|
agent = CodeAgent(tools=[tool], model=fake_code_model)
|
||||||
|
agent.run("Empty task")
|
||||||
|
assert tool.name in agent.system_prompt
|
||||||
|
assert tool.description in agent.system_prompt
|
||||||
|
|
||||||
def test_init_agent_with_different_toolsets(self):
|
def test_init_agent_with_different_toolsets(self):
|
||||||
toolset_1 = []
|
toolset_1 = []
|
||||||
agent = CodeAgent(tools=toolset_1, model=fake_code_model)
|
agent = CodeAgent(tools=toolset_1, model=fake_code_model)
|
||||||
assert (
|
assert (
|
||||||
len(agent.toolbox.tools) == 1
|
len(agent.tools) == 1
|
||||||
) # when no tools are provided, only the final_answer tool is added by default
|
) # when no tools are provided, only the final_answer tool is added by default
|
||||||
|
|
||||||
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
||||||
agent = CodeAgent(tools=toolset_2, model=fake_code_model)
|
agent = CodeAgent(tools=toolset_2, model=fake_code_model)
|
||||||
assert (
|
assert (
|
||||||
len(agent.toolbox.tools) == 2
|
len(agent.tools) == 2
|
||||||
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
|
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
|
||||||
|
|
||||||
toolset_3 = Toolbox(toolset_2)
|
# check that python_interpreter base tool does not get added to CodeAgent
|
||||||
agent = CodeAgent(tools=toolset_3, model=fake_code_model)
|
|
||||||
assert (
|
|
||||||
len(agent.toolbox.tools) == 2
|
|
||||||
) # same as previous one, where toolset_3 is an instantiation of previous one
|
|
||||||
|
|
||||||
# check that add_base_tools will not interfere with existing tools
|
|
||||||
with pytest.raises(KeyError) as e:
|
|
||||||
agent = ToolCallingAgent(
|
|
||||||
tools=toolset_3, model=FakeToolCallModel(), add_base_tools=True
|
|
||||||
)
|
|
||||||
assert "already exists in the toolbox" in str(e)
|
|
||||||
|
|
||||||
# check that python_interpreter base tool does not get added to code agents
|
|
||||||
agent = CodeAgent(tools=[], model=fake_code_model, add_base_tools=True)
|
agent = CodeAgent(tools=[], model=fake_code_model, add_base_tools=True)
|
||||||
assert (
|
assert len(agent.tools) == 3 # added final_answer tool + search + visit_webpage
|
||||||
len(agent.toolbox.tools) == 3
|
|
||||||
) # added final_answer tool + search + transcribe
|
# check that python_interpreter base tool gets added to ToolCallingAgent
|
||||||
|
agent = ToolCallingAgent(tools=[], model=fake_code_model, add_base_tools=True)
|
||||||
|
assert len(agent.tools) == 4 # added final_answer tool + search + visit_webpage
|
||||||
|
|
||||||
def test_function_persistence_across_steps(self):
|
def test_function_persistence_across_steps(self):
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
|
|
|
@ -18,8 +18,7 @@ import unittest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from smolagents import load_tool
|
from smolagents.default_tools import BASE_PYTHON_TOOLS, PythonInterpreterTool
|
||||||
from smolagents.default_tools import BASE_PYTHON_TOOLS
|
|
||||||
from smolagents.local_python_executor import (
|
from smolagents.local_python_executor import (
|
||||||
InterpreterError,
|
InterpreterError,
|
||||||
evaluate_python_code,
|
evaluate_python_code,
|
||||||
|
@ -37,7 +36,7 @@ def add_two(x):
|
||||||
|
|
||||||
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"])
|
self.tool = PythonInterpreterTool(authorized_imports=["sqlite3"])
|
||||||
self.tool.setup()
|
self.tool.setup()
|
||||||
|
|
||||||
def test_exact_match_arg(self):
|
def test_exact_match_arg(self):
|
||||||
|
|
|
@ -15,14 +15,14 @@
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from smolagents import load_tool
|
from smolagents import DuckDuckGoSearchTool
|
||||||
|
|
||||||
from .test_tools import ToolTesterMixin
|
from .test_tools import ToolTesterMixin
|
||||||
|
|
||||||
|
|
||||||
class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):
|
class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tool = load_tool("web_search")
|
self.tool = DuckDuckGoSearchTool()
|
||||||
self.tool.setup()
|
self.tool.setup()
|
||||||
|
|
||||||
def test_exact_match_arg(self):
|
def test_exact_match_arg(self):
|
||||||
|
|
Loading…
Reference in New Issue