Share full agents (#533)

Co-authored-by: Alex <sysradium@users.noreply.github.com>
Co-authored-by: Parteek <parteekkamboj112@gmail.com>
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Co-authored-by: David Berenstein <david.m.berenstein@gmail.com>
This commit is contained in:
Aymeric Roucher 2025-02-13 16:24:44 +01:00 committed by GitHub
parent 360e1a8781
commit 1c1418dfa2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 820 additions and 273 deletions

View File

@ -567,13 +567,13 @@ class WavConverter(MediaConverter):
class Mp3Converter(WavConverter): class Mp3Converter(WavConverter):
""" """
Converts MP3 files to markdown via extraction of metadata (if `exiftool` is installed), and speech transcription (if `speech_recognition` AND `pydub` are installed). Converts MP3 and M4A files to markdown via extraction of metadata (if `exiftool` is installed), and speech transcription (if `speech_recognition` AND `pydub` are installed).
""" """
def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]: def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
# Bail if not a MP3 # Bail if not a MP3
extension = kwargs.get("file_extension", "") extension = kwargs.get("file_extension", "")
if extension.lower() != ".mp3": if extension.lower() not in [".mp3", ".m4a"]:
return None return None
md_content = "" md_content = ""
@ -600,7 +600,10 @@ class Mp3Converter(WavConverter):
handle, temp_path = tempfile.mkstemp(suffix=".wav") handle, temp_path = tempfile.mkstemp(suffix=".wav")
os.close(handle) os.close(handle)
try: try:
if extension.lower() == ".mp3":
sound = pydub.AudioSegment.from_mp3(local_path) sound = pydub.AudioSegment.from_mp3(local_path)
else:
sound = pydub.AudioSegment.from_file(local_path, format="m4a")
sound.export(temp_path, format="wav") sound.export(temp_path, format="wav")
_args = dict() _args = dict()

View File

@ -10,7 +10,7 @@ class TextInspectorTool(Tool):
name = "inspect_file_as_text" name = "inspect_file_as_text"
description = """ description = """
You cannot load files yourself: instead call this tool to read a file as markdown text and ask questions about it. You cannot load files yourself: instead call this tool to read a file as markdown text and ask questions about it.
This tool handles the following file extensions: [".html", ".htm", ".xlsx", ".pptx", ".wav", ".mp3", ".flac", ".pdf", ".docx"], and all other types of text files. IT DOES NOT HANDLE IMAGES.""" This tool handles the following file extensions: [".html", ".htm", ".xlsx", ".pptx", ".wav", ".mp3", ".m4a", ".flac", ".pdf", ".docx"], and all other types of text files. IT DOES NOT HANDLE IMAGES."""
inputs = { inputs = {
"file_path": { "file_path": {

View File

@ -410,7 +410,7 @@ class VisitTool(Tool):
class DownloadTool(Tool): class DownloadTool(Tool):
name = "download_file" name = "download_file"
description = """ description = """
Download a file at a given URL. The file should be of this format: [".xlsx", ".pptx", ".wav", ".mp3", ".png", ".docx"] Download a file at a given URL. The file should be of this format: [".xlsx", ".pptx", ".wav", ".mp3", ".m4a", ".png", ".docx"]
After using this tool, for further inspection of this page you should return the download path to your manager via final_answer, and they will be able to inspect it. After using this tool, for further inspection of this page you should return the download path to your manager via final_answer, and they will be able to inspect it.
DO NOT use this tool for .pdf or .txt or .htm files: for these types of files use visit_page with the file url instead.""" DO NOT use this tool for .pdf or .txt or .htm files: for these types of files use visit_page with the file url instead."""
inputs = {"url": {"type": "string", "description": "The relative or absolute url of the file to be downloaded."}} inputs = {"url": {"type": "string", "description": "The relative or absolute url of the file to be downloaded."}}

View File

@ -24,7 +24,6 @@ TODO: move them to `huggingface_hub` to avoid code duplication.
import inspect import inspect
import json import json
import os
import re import re
import types import types
from copy import copy from copy import copy
@ -46,34 +45,31 @@ from huggingface_hub.utils import is_torch_available
from .utils import _is_pillow_available from .utils import _is_pillow_available
def get_imports(filename: Union[str, os.PathLike]) -> List[str]: def get_imports(code: str) -> List[str]:
""" """
Extracts all the libraries (not relative imports this time) that are imported in a file. Extracts all the libraries (not relative imports) that are imported in a code.
Args: Args:
filename (`str` or `os.PathLike`): The module file to inspect. code (`str`): Code text to inspect.
Returns: Returns:
`List[str]`: The list of all packages required to use the input module. `list[str]`: List of all packages required to use the input code.
""" """
with open(filename, "r", encoding="utf-8") as f:
content = f.read()
# filter out try/except block so in custom code we can have try/except imports # filter out try/except block so in custom code we can have try/except imports
content = re.sub(r"\s*try\s*:.*?except.*?:", "", content, flags=re.DOTALL) code = re.sub(r"\s*try\s*:.*?except.*?:", "", code, flags=re.DOTALL)
# filter out imports under is_flash_attn_2_available block for avoid import issues in cpu only environment # filter out imports under is_flash_attn_2_available block for avoid import issues in cpu only environment
content = re.sub( code = re.sub(
r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+", r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+",
"", "",
content, code,
flags=re.MULTILINE, flags=re.MULTILINE,
) )
# Imports of the form `import xxx` # Imports of the form `import xxx` or `import xxx as yyy`
imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) imports = re.findall(r"^\s*import\s+(\S+?)(?:\s+as\s+\S+)?\s*$", code, flags=re.MULTILINE)
# Imports of the form `from xxx import yyy` # Imports of the form `from xxx import yyy`
imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) imports += re.findall(r"^\s*from\s+(\S+)\s+import", code, flags=re.MULTILINE)
# Only keep the top-level module # Only keep the top-level module
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
return list(set(imports)) return list(set(imports))

View File

@ -14,44 +14,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib
__all__ = ["AgentMemory", "CodeAgent", "MultiStepAgent", "ToolCallingAgent"]
import importlib.resources
import inspect import inspect
import json
import os
import re import re
import tempfile
import textwrap import textwrap
import time import time
from collections import deque from collections import deque
from logging import getLogger from logging import getLogger
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, TypedDict, Union from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, TypedDict, Union
import jinja2
import yaml import yaml
from huggingface_hub import create_repo, metadata_update, snapshot_download, upload_folder
from jinja2 import StrictUndefined, Template from jinja2 import StrictUndefined, Template
from rich.console import Group from rich.console import Group
from rich.panel import Panel from rich.panel import Panel
from rich.rule import Rule from rich.rule import Rule
from rich.text import Text from rich.text import Text
from smolagents.agent_types import AgentAudio, AgentImage, handle_agent_output_types from .agent_types import AgentAudio, AgentImage, AgentType, handle_agent_output_types
from smolagents.memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall
from smolagents.monitoring import (
YELLOW_HEX,
AgentLogger,
LogLevel,
)
from smolagents.utils import (
AgentError,
AgentExecutionError,
AgentGenerationError,
AgentMaxStepsError,
AgentParsingError,
parse_code_blobs,
parse_json_tool_call,
truncate_content,
)
from .agent_types import AgentType
from .default_tools import TOOL_MAPPING, FinalAnswerTool from .default_tools import TOOL_MAPPING, FinalAnswerTool
from .e2b_executor import E2BExecutor from .e2b_executor import E2BExecutor
from .local_python_executor import ( from .local_python_executor import (
@ -59,12 +44,30 @@ from .local_python_executor import (
LocalPythonInterpreter, LocalPythonInterpreter,
fix_final_answer_code, fix_final_answer_code,
) )
from .memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall
from .models import ( from .models import (
ChatMessage, ChatMessage,
MessageRole, MessageRole,
Model,
)
from .monitoring import (
YELLOW_HEX,
AgentLogger,
LogLevel,
Monitor,
) )
from .monitoring import Monitor
from .tools import Tool from .tools import Tool
from .utils import (
AgentError,
AgentExecutionError,
AgentGenerationError,
AgentMaxStepsError,
AgentParsingError,
make_init_file,
parse_code_blobs,
parse_json_tool_call,
truncate_content,
)
logger = getLogger(__name__) logger = getLogger(__name__)
@ -228,9 +231,19 @@ class MultiStepAgent:
) )
self.managed_agents = {agent.name: agent for agent in managed_agents} self.managed_agents = {agent.name: agent for agent in managed_agents}
tool_and_managed_agent_names = [tool.name for tool in tools]
if managed_agents is not None:
tool_and_managed_agent_names += [agent.name for agent in managed_agents]
if len(tool_and_managed_agent_names) != len(set(tool_and_managed_agent_names)):
raise ValueError(
"Each tool or managed_agent should have a unique name! You passed these duplicate names: "
f"{[name for name in tool_and_managed_agent_names if tool_and_managed_agent_names.count(name) > 1]}"
)
for tool in tools: for tool in tools:
assert isinstance(tool, Tool), f"This element is not of class Tool: {str(tool)}" assert isinstance(tool, Tool), f"This element is not of class Tool: {str(tool)}"
self.tools = {tool.name: tool for tool in tools} self.tools = {tool.name: tool for tool in tools}
if add_base_tools: if add_base_tools:
for tool_name, tool_class in TOOL_MAPPING.items(): for tool_name, tool_class in TOOL_MAPPING.items():
if tool_name != "python_interpreter" or self.__class__.__name__ == "ToolCallingAgent": if tool_name != "python_interpreter" or self.__class__.__name__ == "ToolCallingAgent":
@ -709,6 +722,310 @@ You have been provided with these additional arguments, that you can access usin
answer += "\n</summary_of_work>" answer += "\n</summary_of_work>"
return answer return answer
def save(self, output_dir: str, relative_path: Optional[str] = None):
"""
Saves the relevant code files for your agent. This will copy the code of your agent in `output_dir` as well as autogenerate:
- a `tools` folder containing the logic for each of the tools under `tools/{tool_name}.py`.
- a `managed_agents` folder containing the logic for each of the managed agents.
- an `agent.json` file containing a dictionary representing your agent.
- a `prompt.yaml` file containing the prompt templates used by your agent.
- an `app.py` file providing a UI for your agent when it is exported to a Space with `agent.push_to_hub()`
- a `requirements.txt` containing the names of the modules used by your tool (as detected when inspecting its
code)
Args:
output_dir (`str`): The folder in which you want to save your tool.
"""
make_init_file(output_dir)
# Recursively save managed agents
if self.managed_agents:
make_init_file(os.path.join(output_dir, "managed_agents"))
for agent_name, agent in self.managed_agents.items():
agent_suffix = f"managed_agents.{agent_name}"
if relative_path:
agent_suffix = relative_path + "." + agent_suffix
agent.save(os.path.join(output_dir, "managed_agents", agent_name), relative_path=agent_suffix)
class_name = self.__class__.__name__
# Save tools to different .py files
for tool in self.tools.values():
make_init_file(os.path.join(output_dir, "tools"))
tool.save(os.path.join(output_dir, "tools"), tool_file_name=tool.name, make_gradio_app=False)
# Save prompts to yaml
yaml_prompts = yaml.safe_dump(
self.prompt_templates,
default_style="|", # This forces block literals for all strings
default_flow_style=False,
width=float("inf"),
sort_keys=False,
allow_unicode=True,
indent=2,
)
with open(os.path.join(output_dir, "prompts.yaml"), "w", encoding="utf-8") as f:
f.write(yaml_prompts)
# Save agent dictionary to json
agent_dict = self.to_dict()
agent_dict["tools"] = [tool.name for tool in self.tools.values()]
with open(os.path.join(output_dir, "agent.json"), "w", encoding="utf-8") as f:
json.dump(agent_dict, f, indent=4)
# Save requirements
with open(os.path.join(output_dir, "requirements.txt"), "w", encoding="utf-8") as f:
f.writelines(f"{r}\n" for r in agent_dict["requirements"])
# Make agent.py file with Gradio UI
agent_name = f"agent_{self.name}" if getattr(self, "name", None) else "agent"
managed_agent_relative_path = relative_path + "." if relative_path is not None else ""
app_template = textwrap.dedent("""
import yaml
import os
from smolagents import GradioUI, {{ class_name }}, {{ agent_dict['model']['class'] }}
# Get current directory path
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
{% for tool in tools.values() -%}
from {{managed_agent_relative_path}}tools.{{ tool.name }} import {{ tool.__class__.__name__ }} as {{ tool.name | camelcase }}
{% endfor %}
{% for managed_agent in managed_agents.values() -%}
from {{managed_agent_relative_path}}managed_agents.{{ managed_agent.name }}.app import agent_{{ managed_agent.name }}
{% endfor %}
model = {{ agent_dict['model']['class'] }}(
{% for key in agent_dict['model']['data'] if key not in ['class', 'last_input_token_count', 'last_output_token_count'] -%}
{{ key }}={{ agent_dict['model']['data'][key]|repr }},
{% endfor %})
{% for tool in tools.values() -%}
{{ tool.name }} = {{ tool.name | camelcase }}()
{% endfor %}
with open(os.path.join(CURRENT_DIR, "prompts.yaml"), 'r') as stream:
prompt_templates = yaml.safe_load(stream)
{{ agent_name }} = {{ class_name }}(
model=model,
tools=[{% for tool_name in tools.keys() if tool_name != "final_answer" %}{{ tool_name }}{% if not loop.last %}, {% endif %}{% endfor %}],
managed_agents=[{% for subagent_name in managed_agents.keys() %}agent_{{ subagent_name }}{% if not loop.last %}, {% endif %}{% endfor %}],
{% for attribute_name, value in agent_dict.items() if attribute_name not in ["model", "tools", "prompt_templates", "authorized_imports", "managed_agents", "requirements"] -%}
{{ attribute_name }}={{ value|repr }},
{% endfor %}prompt_templates=prompt_templates
)
if __name__ == "__main__":
GradioUI({{ agent_name }}).launch()
""").strip()
template_env = jinja2.Environment(loader=jinja2.BaseLoader(), undefined=jinja2.StrictUndefined)
template_env.filters["repr"] = repr
template_env.filters["camelcase"] = lambda value: "".join(word.capitalize() for word in value.split("_"))
template = template_env.from_string(app_template)
# Render the app.py file from Jinja2 template
app_text = template.render(
{
"agent_name": agent_name,
"class_name": class_name,
"agent_dict": agent_dict,
"tools": self.tools,
"managed_agents": self.managed_agents,
"managed_agent_relative_path": managed_agent_relative_path,
}
)
with open(os.path.join(output_dir, "app.py"), "w", encoding="utf-8") as f:
f.write(app_text + "\n") # Append newline at the end
def to_dict(self) -> Dict[str, Any]:
"""Converts agent into a dictionary."""
# TODO: handle serializing step_callbacks and final_answer_checks
for attr in ["final_answer_checks", "step_callbacks"]:
if getattr(self, attr, None):
self.logger.log(f"This agent has {attr}: they will be ignored by this method.", LogLevel.INFO)
tool_dicts = [tool.to_dict() for tool in self.tools.values()]
tool_requirements = {req for tool in self.tools.values() for req in tool.to_dict()["requirements"]}
managed_agents_requirements = {
req for managed_agent in self.managed_agents.values() for req in managed_agent.to_dict()["requirements"]
}
requirements = tool_requirements | managed_agents_requirements
if hasattr(self, "authorized_imports"):
requirements.update(
{package.split(".")[0] for package in self.authorized_imports if package not in BASE_BUILTIN_MODULES}
)
agent_dict = {
"tools": tool_dicts,
"model": {
"class": self.model.__class__.__name__,
"data": self.model.to_dict(),
},
"managed_agents": {
managed_agent.name: managed_agent.__class__.__name__ for managed_agent in self.managed_agents.values()
},
"prompt_templates": self.prompt_templates,
"max_steps": self.max_steps,
"verbosity_level": int(self.logger.level),
"grammar": self.grammar,
"planning_interval": self.planning_interval,
"name": self.name,
"description": self.description,
"requirements": list(requirements),
}
if hasattr(self, "authorized_imports"):
agent_dict["authorized_imports"] = self.authorized_imports
return agent_dict
@classmethod
def from_hub(
cls,
repo_id: str,
token: Optional[str] = None,
trust_remote_code: bool = False,
**kwargs,
):
"""
Loads an agent defined on the Hub.
<Tip warning={true}>
Loading a tool from the Hub means that you'll download the tool and execute it locally.
ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
installing a package using pip/npm/apt.
</Tip>
Args:
repo_id (`str`):
The name of the repo on the Hub where your tool is defined.
token (`str`, *optional*):
The token to identify you on hf.co. If unset, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
trust_remote_code(`bool`, *optional*, defaults to False):
This flags marks that you understand the risk of running remote code and that you trust this tool.
If not setting this to True, loading the tool from Hub will fail.
kwargs (additional keyword arguments, *optional*):
Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your agent, and the
others will be passed along to its init.
"""
if not trust_remote_code:
raise ValueError(
"Loading an agent from Hub requires to acknowledge you trust its code: to do so, pass `trust_remote_code=True`."
)
# Get the agent's Hub folder.
download_kwargs = {"token": token, "repo_type": "space"} | {
key: kwargs.pop(key)
for key in [
"cache_dir",
"force_download",
"resume_download",
"proxies",
"revision",
"subfolder",
"local_files_only",
]
if key in kwargs
}
download_folder = Path(snapshot_download(repo_id=repo_id, **download_kwargs))
return cls.from_folder(download_folder, **kwargs)
@classmethod
def from_folder(cls, folder: Union[str, Path], **kwargs):
"""Loads an agent from a local folder"""
folder = Path(folder)
agent_dict = json.loads((folder / "agent.json").read_text())
# Recursively get managed agents
managed_agents = []
for managed_agent_name, managed_agent_class in agent_dict["managed_agents"].items():
agent_cls = getattr(importlib.import_module("smolagents.agents"), managed_agent_class)
managed_agents.append(agent_cls.from_folder(folder / "managed_agents" / managed_agent_name))
tools = []
for tool_name in agent_dict["tools"]:
tool_code = (folder / "tools" / f"{tool_name}.py").read_text()
tools.append(Tool.from_code(tool_code))
model_class: Model = getattr(importlib.import_module("smolagents.models"), agent_dict["model"]["class"])
model = model_class.from_dict(agent_dict["model"]["data"])
args = dict(
model=model,
tools=tools,
managed_agents=managed_agents,
name=agent_dict["name"],
description=agent_dict["description"],
max_steps=agent_dict["max_steps"],
planning_interval=agent_dict["planning_interval"],
grammar=agent_dict["grammar"],
verbosity_level=agent_dict["verbosity_level"],
)
if cls.__name__ == "CodeAgent":
args["additional_authorized_imports"] = agent_dict["authorized_imports"]
args.update(kwargs)
return cls(**args)
def push_to_hub(
self,
repo_id: str,
commit_message: str = "Upload agent",
private: Optional[bool] = None,
token: Optional[Union[bool, str]] = None,
create_pr: bool = False,
) -> str:
"""
Upload the agent to the Hub.
Parameters:
repo_id (`str`):
The name of the repository you want to push to. It should contain your organization name when
pushing to a given organization.
commit_message (`str`, *optional*, defaults to `"Upload agent"`):
Message to commit while pushing.
private (`bool`, *optional*, defaults to `None`):
Whether to make the repo private. If `None`, the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
token (`bool` or `str`, *optional*):
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
create_pr (`bool`, *optional*, defaults to `False`):
Whether to create a PR with the uploaded files or directly commit.
"""
repo_url = create_repo(
repo_id=repo_id,
token=token,
private=private,
exist_ok=True,
repo_type="space",
space_sdk="gradio",
)
repo_id = repo_url.repo_id
metadata_update(
repo_id,
{"tags": ["smolagents", "agent"]},
repo_type="space",
token=token,
overwrite=True,
)
with tempfile.TemporaryDirectory() as work_dir:
self.save(work_dir)
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
return upload_folder(
repo_id=repo_id,
commit_message=commit_message,
folder_path=work_dir,
token=token,
create_pr=create_pr,
repo_type="space",
)
class ToolCallingAgent(MultiStepAgent): class ToolCallingAgent(MultiStepAgent):
""" """
@ -863,6 +1180,8 @@ class CodeAgent(MultiStepAgent):
): ):
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)) self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
self.use_e2b_executor = use_e2b_executor
self.max_print_outputs_length = max_print_outputs_length
prompt_templates = prompt_templates or yaml.safe_load( prompt_templates = prompt_templates or yaml.safe_load(
importlib.resources.files("smolagents.prompts").joinpath("code_agent.yaml").read_text() importlib.resources.files("smolagents.prompts").joinpath("code_agent.yaml").read_text()
) )

View File

@ -141,7 +141,7 @@ def stream_to_gradio(
for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args): for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
# Track tokens if model provides them # Track tokens if model provides them
if hasattr(agent.model, "last_input_token_count"): if getattr(agent.model, "last_input_token_count", None):
total_input_tokens += agent.model.last_input_token_count total_input_tokens += agent.model.last_input_token_count
total_output_tokens += agent.model.last_output_token_count total_output_tokens += agent.model.last_output_token_count
if isinstance(step_log, ActionStep): if isinstance(step_log, ActionStep):

View File

@ -331,6 +331,53 @@ class Model:
""" """
pass # To be implemented in child classes! pass # To be implemented in child classes!
def to_dict(self) -> Dict:
"""
Converts the model into a JSON-compatible dictionary.
"""
model_dictionary = {
**self.kwargs,
"last_input_token_count": self.last_input_token_count,
"last_output_token_count": self.last_output_token_count,
"model_id": self.model_id,
}
for attribute in [
"custom_role_conversion",
"temperature",
"max_tokens",
"provider",
"timeout",
"api_base",
"torch_dtype",
"device_map",
"organization",
"project",
"azure_endpoint",
]:
if hasattr(self, attribute):
model_dictionary[attribute] = getattr(self, attribute)
dangerous_attributes = ["token", "api_key"]
for attribute_name in dangerous_attributes:
if hasattr(self, attribute_name):
print(
f"For security reasons, we do not export the `{attribute_name}` attribute of your model. Please export it manually."
)
return model_dictionary
@classmethod
def from_dict(cls, model_dictionary: Dict[str, Any]) -> "Model":
model_instance = cls(
**{
k: v
for k, v in model_dictionary.items()
if k not in ["last_input_token_count", "last_output_token_count"]
}
)
model_instance.last_input_token_count = model_dictionary.pop("last_input_token_count", None)
model_instance.last_output_token_count = model_dictionary.pop("last_output_token_count", None)
return model_instance
class HfApiModel(Model): class HfApiModel(Model):
"""A class to interact with Hugging Face's Inference API for language model interaction. """A class to interact with Hugging Face's Inference API for language model interaction.

View File

@ -83,6 +83,31 @@ class MethodChecker(ast.NodeVisitor):
self.assigned_names.add(elt.id) self.assigned_names.add(elt.id)
self.generic_visit(node) self.generic_visit(node)
def _handle_comprehension_generators(self, generators):
"""Helper method to handle generators in all types of comprehensions"""
for generator in generators:
if isinstance(generator.target, ast.Name):
self.assigned_names.add(generator.target.id)
elif isinstance(generator.target, ast.Tuple):
for elt in generator.target.elts:
if isinstance(elt, ast.Name):
self.assigned_names.add(elt.id)
def visit_ListComp(self, node):
"""Track variables in list comprehensions"""
self._handle_comprehension_generators(node.generators)
self.generic_visit(node)
def visit_DictComp(self, node):
"""Track variables in dictionary comprehensions"""
self._handle_comprehension_generators(node.generators)
self.generic_visit(node)
def visit_SetComp(self, node):
"""Track variables in set comprehensions"""
self._handle_comprehension_generators(node.generators)
self.generic_visit(node)
def visit_Attribute(self, node): def visit_Attribute(self, node):
if not (isinstance(node.value, ast.Name) and node.value.id == "self"): if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
self.generic_visit(node) self.generic_visit(node)
@ -121,7 +146,8 @@ class MethodChecker(ast.NodeVisitor):
def validate_tool_attributes(cls, check_imports: bool = True) -> None: def validate_tool_attributes(cls, check_imports: bool = True) -> None:
""" """
Validates that a Tool class follows the proper patterns: Validates that a Tool class follows the proper patterns:
0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!). 0. Any argument of __init__ should have a default.
Args chosen at init are not traceable, so we cannot rebuild the source code for them, thus any important arg should be defined as a class attribute.
1. About the class: 1. About the class:
- Class attributes should only be strings or dicts - Class attributes should only be strings or dicts
- Class attributes cannot be complex attributes - Class attributes cannot be complex attributes
@ -140,13 +166,20 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
if not isinstance(tree.body[0], ast.ClassDef): if not isinstance(tree.body[0], ast.ClassDef):
raise ValueError("Source code must define a class") raise ValueError("Source code must define a class")
# Check that __init__ method takes no arguments # Check that __init__ method only has arguments with defaults
if not cls.__init__.__qualname__ == "Tool.__init__": if not cls.__init__.__qualname__ == "Tool.__init__":
sig = inspect.signature(cls.__init__) sig = inspect.signature(cls.__init__)
non_self_params = list([arg_name for arg_name in sig.parameters.keys() if arg_name != "self"]) non_default_params = [
if len(non_self_params) > 0: arg_name
for arg_name, param in sig.parameters.items()
if arg_name != "self"
and param.default == inspect.Parameter.empty
and param.kind != inspect.Parameter.VAR_KEYWORD # Excludes **kwargs
]
if non_default_params:
errors.append( errors.append(
f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!" f"This tool has required arguments in __init__: {non_default_params}. "
"All parameters of __init__ must have default values!"
) )
class_node = tree.body[0] class_node = tree.body[0]
@ -198,5 +231,5 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
errors += [f"- {node.name}: {error}" for error in method_checker.errors] errors += [f"- {node.name}: {error}" for error in method_checker.errors]
if errors: if errors:
raise ValueError("Tool validation failed:\n" + "\n".join(errors)) raise ValueError(f"Tool validation failed for {cls.__name__}:\n" + "\n".join(errors))
return return

View File

@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import ast import ast
import importlib
import inspect import inspect
import json import json
import logging import logging
@ -23,6 +22,7 @@ import os
import sys import sys
import tempfile import tempfile
import textwrap import textwrap
import types
from contextlib import contextmanager from contextlib import contextmanager
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
@ -199,24 +199,9 @@ class Tool:
""" """
self.is_initialized = True self.is_initialized = True
def save(self, output_dir): def to_dict(self) -> dict:
""" """Returns a dictionary representing the tool"""
Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
tool in `output_dir` as well as autogenerate:
- a `tool.py` file containing the logic for your tool.
- an `app.py` file providing an UI for your tool when it is exported to a Space with `tool.push_to_hub()`
- a `requirements.txt` containing the names of the module used by your tool (as detected when inspecting its
code)
Args:
output_dir (`str`): The folder in which you want to save your tool.
"""
os.makedirs(output_dir, exist_ok=True)
class_name = self.__class__.__name__ class_name = self.__class__.__name__
tool_file = os.path.join(output_dir, "tool.py")
# Save tool file
if type(self).__name__ == "SimpleTool": if type(self).__name__ == "SimpleTool":
# Check that imports are self-contained # Check that imports are self-contained
source_code = get_source(self.forward).replace("@tool", "") source_code = get_source(self.forward).replace("@tool", "")
@ -232,7 +217,7 @@ class Tool:
tool_code = textwrap.dedent( tool_code = textwrap.dedent(
f""" f"""
from smolagents import Tool from smolagents import Tool
from typing import Optional from typing import Any, Optional
class {class_name}(Tool): class {class_name}(Tool):
name = "{self.name}" name = "{self.name}"
@ -272,11 +257,39 @@ class Tool:
validate_tool_attributes(self.__class__) validate_tool_attributes(self.__class__)
tool_code = instance_to_source(self, base_cls=Tool) tool_code = "from typing import Any, Optional\n" + instance_to_source(self, base_cls=Tool)
requirements = {el for el in get_imports(tool_code) if el not in sys.stdlib_module_names} | {"smolagents"}
return {"name": self.name, "code": tool_code, "requirements": requirements}
def save(self, output_dir: str, tool_file_name: str = "tool", make_gradio_app: bool = True):
"""
Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
tool in `output_dir` as well as autogenerate:
- a `{tool_file_name}.py` file containing the logic for your tool.
If you pass `make_gradio_app=True`, this will also write:
- an `app.py` file providing a UI for your tool when it is exported to a Space with `tool.push_to_hub()`
- a `requirements.txt` containing the names of the modules used by your tool (as detected when inspecting its
code)
Args:
output_dir (`str`): The folder in which you want to save your tool.
tool_file_name (`str`, *optional*): The file name in which you want to save your tool.
make_gradio_app (`bool`, *optional*, defaults to True): Whether to also export a `requirements.txt` file and Gradio UI.
"""
os.makedirs(output_dir, exist_ok=True)
class_name = self.__class__.__name__
tool_file = os.path.join(output_dir, f"{tool_file_name}.py")
tool_dict = self.to_dict()
tool_code = tool_dict["code"]
with open(tool_file, "w", encoding="utf-8") as f: with open(tool_file, "w", encoding="utf-8") as f:
f.write(tool_code.replace(":true,", ":True,").replace(":true}", ":True}")) f.write(tool_code.replace(":true,", ":True,").replace(":true}", ":True}"))
if make_gradio_app:
# Save app file # Save app file
app_file = os.path.join(output_dir, "app.py") app_file = os.path.join(output_dir, "app.py")
with open(app_file, "w", encoding="utf-8") as f: with open(app_file, "w", encoding="utf-8") as f:
@ -284,8 +297,7 @@ class Tool:
textwrap.dedent( textwrap.dedent(
f""" f"""
from smolagents import launch_gradio_demo from smolagents import launch_gradio_demo
from typing import Optional from {tool_file_name} import {class_name}
from tool import {class_name}
tool = {class_name}() tool = {class_name}()
@ -295,10 +307,9 @@ class Tool:
) )
# Save requirements file # Save requirements file
imports = {el for el in get_imports(tool_file) if el not in sys.stdlib_module_names} | {"smolagents"}
requirements_file = os.path.join(output_dir, "requirements.txt") requirements_file = os.path.join(output_dir, "requirements.txt")
with open(requirements_file, "w", encoding="utf-8") as f: with open(requirements_file, "w", encoding="utf-8") as f:
f.write("\n".join(imports) + "\n") f.write("\n".join(tool_dict["requirements"]) + "\n")
def push_to_hub( def push_to_hub(
self, self,
@ -311,14 +322,6 @@ class Tool:
""" """
Upload the tool to the Hub. Upload the tool to the Hub.
For this method to work properly, your tool must have been defined in a separate module (not `__main__`).
For instance:
```
from my_tool_module import MyTool
my_tool = MyTool()
my_tool.push_to_hub("my-username/my-space")
```
Parameters: Parameters:
repo_id (`str`): repo_id (`str`):
The name of the repository you want to push your tool to. It should contain your organization name when The name of the repository you want to push your tool to. It should contain your organization name when
@ -342,13 +345,11 @@ class Tool:
space_sdk="gradio", space_sdk="gradio",
) )
repo_id = repo_url.repo_id repo_id = repo_url.repo_id
metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space", token=token) metadata_update(repo_id, {"tags": ["smolagents", "tool"]}, repo_type="space", token=token)
with tempfile.TemporaryDirectory() as work_dir: with tempfile.TemporaryDirectory() as work_dir:
# Save all files. # Save all files.
self.save(work_dir) self.save(work_dir)
with open(work_dir + "/tool.py", "r") as f:
print("\n".join(f.readlines()))
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}") logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
return upload_folder( return upload_folder(
repo_id=repo_id, repo_id=repo_id,
@ -394,7 +395,7 @@ class Tool:
""" """
if not trust_remote_code: if not trust_remote_code:
raise ValueError( raise ValueError(
"Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool." "Loading a tool from Hub requires to acknowledge you trust its code: to do so, pass `trust_remote_code=True`."
) )
# Get the tool's tool.py file. # Get the tool's tool.py file.
@ -413,27 +414,23 @@ class Tool:
) )
tool_code = Path(tool_file).read_text() tool_code = Path(tool_file).read_text()
return Tool.from_code(tool_code, **kwargs)
# Find the Tool subclass in the namespace @classmethod
with tempfile.TemporaryDirectory() as temp_dir: def from_code(cls, tool_code: str, **kwargs):
# Save the code to a file module = types.ModuleType("dynamic_tool")
module_path = os.path.join(temp_dir, "tool.py")
with open(module_path, "w") as f:
f.write(tool_code)
print("TOOL CODE:\n", tool_code) exec(tool_code, module.__dict__)
# Load module from file path # Find the Tool subclass
spec = importlib.util.spec_from_file_location("tool", module_path) tool_class = next(
module = importlib.util.module_from_spec(spec) (
spec.loader.exec_module(module) obj
for _, obj in inspect.getmembers(module, inspect.isclass)
# Find and instantiate the Tool class if issubclass(obj, Tool) and obj is not Tool
for item_name in dir(module): ),
item = getattr(module, item_name) None,
if isinstance(item, type) and issubclass(item, Tool) and item != Tool: )
tool_class = item
break
if tool_class is None: if tool_class is None:
raise ValueError("No Tool subclass found in the code.") raise ValueError("No Tool subclass found in the code.")

View File

@ -20,6 +20,7 @@ import importlib.metadata
import importlib.util import importlib.util
import inspect import inspect
import json import json
import os
import re import re
import textwrap import textwrap
import types import types
@ -414,3 +415,10 @@ def encode_image_base64(image):
def make_image_url(base64_image): def make_image_url(base64_image):
return f"data:image/png;base64,{base64_image}" return f"data:image/png;base64,{base64_image}"
def make_init_file(folder: str):
os.makedirs(folder, exist_ok=True)
# Create __init__
with open(os.path.join(folder, "__init__.py"), "w"):
pass

View File

@ -31,12 +31,13 @@ from smolagents.agents import (
ToolCallingAgent, ToolCallingAgent,
populate_template, populate_template,
) )
from smolagents.default_tools import PythonInterpreterTool from smolagents.default_tools import DuckDuckGoSearchTool, PythonInterpreterTool, VisitWebpageTool
from smolagents.memory import PlanningStep from smolagents.memory import PlanningStep
from smolagents.models import ( from smolagents.models import (
ChatMessage, ChatMessage,
ChatMessageToolCall, ChatMessageToolCall,
ChatMessageToolCallDefinition, ChatMessageToolCallDefinition,
HfApiModel,
MessageRole, MessageRole,
TransformersModel, TransformersModel,
) )
@ -436,10 +437,15 @@ class AgentTests(unittest.TestCase):
assert len(agent.tools) == 1 # when no tools are provided, only the final_answer tool is added by default assert len(agent.tools) == 1 # when no tools are provided, only the final_answer tool is added by default
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()] toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
with pytest.raises(ValueError) as e:
agent = CodeAgent(tools=toolset_2, model=fake_code_model) agent = CodeAgent(tools=toolset_2, model=fake_code_model)
assert ( assert "Each tool or managed_agent should have a unique name!" in str(e)
len(agent.tools) == 2
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer with pytest.raises(ValueError) as e:
agent.name = "python_interpreter"
agent.description = "empty"
CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model, managed_agents=[agent])
assert "Each tool or managed_agent should have a unique name!" in str(e)
# check that python_interpreter base tool does not get added to CodeAgent # check that python_interpreter base tool does not get added to CodeAgent
agent = CodeAgent(tools=[], model=fake_code_model, add_base_tools=True) agent = CodeAgent(tools=[], model=fake_code_model, add_base_tools=True)
@ -484,132 +490,6 @@ class AgentTests(unittest.TestCase):
str_output = capture.get() str_output = capture.get()
assert "`additional_authorized_imports`" in str_output.replace("\n", "") assert "`additional_authorized_imports`" in str_output.replace("\n", "")
def test_multiagents(self):
class FakeModelMultiagentsManagerAgent:
model_id = "fake_model"
def __call__(
self,
messages,
stop_sequences=None,
grammar=None,
tools_to_call_from=None,
):
if tools_to_call_from is not None:
if len(messages) < 3:
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatMessageToolCall(
id="call_0",
type="function",
function=ChatMessageToolCallDefinition(
name="search_agent",
arguments="Who is the current US president?",
),
)
],
)
else:
assert "Report on the current US president" in str(messages)
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatMessageToolCall(
id="call_0",
type="function",
function=ChatMessageToolCallDefinition(
name="final_answer", arguments="Final report."
),
)
],
)
else:
if len(messages) < 3:
return ChatMessage(
role="assistant",
content="""
Thought: Let's call our search agent.
Code:
```py
result = search_agent("Who is the current US president?")
```<end_code>
""",
)
else:
assert "Report on the current US president" in str(messages)
return ChatMessage(
role="assistant",
content="""
Thought: Let's return the report.
Code:
```py
final_answer("Final report.")
```<end_code>
""",
)
manager_model = FakeModelMultiagentsManagerAgent()
class FakeModelMultiagentsManagedAgent:
model_id = "fake_model"
def __call__(
self,
messages,
tools_to_call_from=None,
stop_sequences=None,
grammar=None,
):
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatMessageToolCall(
id="call_0",
type="function",
function=ChatMessageToolCallDefinition(
name="final_answer",
arguments="Report on the current US president",
),
)
],
)
managed_model = FakeModelMultiagentsManagedAgent()
web_agent = ToolCallingAgent(
tools=[],
model=managed_model,
max_steps=10,
name="search_agent",
description="Runs web searches for you. Give it your request as an argument. Make the request as detailed as needed, you can ask for thorough reports",
)
manager_code_agent = CodeAgent(
tools=[],
model=manager_model,
managed_agents=[web_agent],
additional_authorized_imports=["time", "numpy", "pandas"],
)
report = manager_code_agent.run("Fake question.")
assert report == "Final report."
manager_toolcalling_agent = ToolCallingAgent(
tools=[],
model=manager_model,
managed_agents=[web_agent],
)
report = manager_toolcalling_agent.run("Fake question.")
assert report == "Final report."
# Test that visualization works
manager_code_agent.visualize()
def test_code_nontrivial_final_answer_works(self): def test_code_nontrivial_final_answer_works(self):
def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None): def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
return ChatMessage( return ChatMessage(
@ -887,6 +767,191 @@ class TestCodeAgent:
assert result == expected_summary assert result == expected_summary
class MultiAgentsTests(unittest.TestCase):
def test_multiagents_save(self):
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens=2096, temperature=0.5)
web_agent = ToolCallingAgent(
model=model,
tools=[DuckDuckGoSearchTool(max_results=2), VisitWebpageTool()],
name="web_agent",
description="does web searches",
)
code_agent = CodeAgent(model=model, tools=[], name="useless", description="does nothing in particular")
agent = CodeAgent(
model=model,
tools=[],
additional_authorized_imports=["pandas", "datetime"],
managed_agents=[web_agent, code_agent],
)
agent.save("agent_export")
expected_structure = {
"managed_agents": {
"useless": {"tools": {"files": ["final_answer.py"]}, "files": ["agent.json", "prompts.yaml"]},
"web_agent": {
"tools": {"files": ["final_answer.py", "visit_webpage.py", "web_search.py"]},
"files": ["agent.json", "prompts.yaml"],
},
},
"tools": {"files": ["final_answer.py"]},
"files": ["app.py", "requirements.txt", "agent.json", "prompts.yaml"],
}
def verify_structure(current_path: Path, structure: dict):
for dir_name, contents in structure.items():
if dir_name != "files":
# For directories, verify they exist and recurse into them
dir_path = current_path / dir_name
assert dir_path.exists(), f"Directory {dir_path} does not exist"
assert dir_path.is_dir(), f"{dir_path} is not a directory"
verify_structure(dir_path, contents)
else:
# For files, verify each exists in the current path
for file_name in contents:
file_path = current_path / file_name
assert file_path.exists(), f"File {file_path} does not exist"
assert file_path.is_file(), f"{file_path} is not a file"
verify_structure(Path("agent_export"), expected_structure)
# Test that re-loaded agents work as expected.
agent2 = CodeAgent.from_folder("agent_export", planning_interval=5)
assert agent2.planning_interval == 5 # Check that kwargs are used
assert set(agent2.authorized_imports) == set(["pandas", "datetime"] + BASE_BUILTIN_MODULES)
assert (
agent2.managed_agents["web_agent"].tools["web_search"].max_results == 10
) # For now tool init parameters are forgotten
assert agent2.model.kwargs["temperature"] == pytest.approx(0.5)
def test_multiagents(self):
class FakeModelMultiagentsManagerAgent:
model_id = "fake_model"
def __call__(
self,
messages,
stop_sequences=None,
grammar=None,
tools_to_call_from=None,
):
if tools_to_call_from is not None:
if len(messages) < 3:
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatMessageToolCall(
id="call_0",
type="function",
function=ChatMessageToolCallDefinition(
name="search_agent",
arguments="Who is the current US president?",
),
)
],
)
else:
assert "Report on the current US president" in str(messages)
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatMessageToolCall(
id="call_0",
type="function",
function=ChatMessageToolCallDefinition(
name="final_answer", arguments="Final report."
),
)
],
)
else:
if len(messages) < 3:
return ChatMessage(
role="assistant",
content="""
Thought: Let's call our search agent.
Code:
```py
result = search_agent("Who is the current US president?")
```<end_code>
""",
)
else:
assert "Report on the current US president" in str(messages)
return ChatMessage(
role="assistant",
content="""
Thought: Let's return the report.
Code:
```py
final_answer("Final report.")
```<end_code>
""",
)
manager_model = FakeModelMultiagentsManagerAgent()
class FakeModelMultiagentsManagedAgent:
model_id = "fake_model"
def __call__(
self,
messages,
tools_to_call_from=None,
stop_sequences=None,
grammar=None,
):
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatMessageToolCall(
id="call_0",
type="function",
function=ChatMessageToolCallDefinition(
name="final_answer",
arguments="Report on the current US president",
),
)
],
)
managed_model = FakeModelMultiagentsManagedAgent()
web_agent = ToolCallingAgent(
tools=[],
model=managed_model,
max_steps=10,
name="search_agent",
description="Runs web searches for you. Give it your request as an argument. Make the request as detailed as needed, you can ask for thorough reports",
)
manager_code_agent = CodeAgent(
tools=[],
model=manager_model,
managed_agents=[web_agent],
additional_authorized_imports=["time", "numpy", "pandas"],
)
report = manager_code_agent.run("Fake question.")
assert report == "Final report."
manager_toolcalling_agent = ToolCallingAgent(
tools=[],
model=manager_model,
managed_agents=[web_agent],
)
report = manager_toolcalling_agent.run("Fake question.")
assert report == "Final report."
# Test that visualization works
manager_code_agent.visualize()
@pytest.fixture @pytest.fixture
def prompt_templates(): def prompt_templates():
return { return {

View File

@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest import unittest
from typing import Optional, Tuple from typing import List, Optional, Tuple
from smolagents._function_type_hints_utils import get_json_schema import pytest
from smolagents._function_type_hints_utils import get_imports, get_json_schema
class AgentTextTests(unittest.TestCase): class TestJsonSchema(unittest.TestCase):
def test_return_none(self): def test_get_json_schema(self):
def fn(x: int, y: Optional[Tuple[str, str, float]] = None) -> None: def fn(x: int, y: Optional[Tuple[str, str, float]] = None) -> None:
""" """
Test function Test function
@ -52,3 +54,65 @@ class AgentTextTests(unittest.TestCase):
schema["function"]["parameters"]["properties"]["y"], expected_schema["parameters"]["properties"]["y"] schema["function"]["parameters"]["properties"]["y"], expected_schema["parameters"]["properties"]["y"]
) )
self.assertEqual(schema["function"], expected_schema) self.assertEqual(schema["function"], expected_schema)
class TestGetCode:
@pytest.mark.parametrize(
"code, expected",
[
(
"""
import numpy
import pandas
""",
["numpy", "pandas"],
),
# From imports
(
"""
from torch import nn
from transformers import AutoModel
""",
["torch", "transformers"],
),
# Mixed case with nested imports
(
"""
import numpy as np
from torch.nn import Linear
import os.path
""",
["numpy", "torch", "os"],
),
# Try/except block (should be filtered)
(
"""
try:
import torch
except ImportError:
pass
import numpy
""",
["numpy"],
),
# Flash attention block (should be filtered)
(
"""
if is_flash_attn_2_available():
from flash_attn import flash_attn_func
import transformers
""",
["transformers"],
),
# Relative imports (should be excluded)
(
"""
from .utils import helper
from ..models import transformer
""",
[],
),
],
)
def test_get_imports(self, code: str, expected: List[str]):
assert sorted(get_imports(code)) == sorted(expected)

View File

@ -215,8 +215,9 @@ class ToolTests(unittest.TestCase):
return str(datetime.now()) return str(datetime.now())
def test_saving_tool_allows_no_arg_in_init(self): def test_tool_to_dict_allows_no_arg_in_init(self):
# Test one cannot save tool with additional args in init """Test that a tool cannot be saved with required args in init"""
class FailTool(Tool): class FailTool(Tool):
name = "specific" name = "specific"
description = "test description" description = "test description"
@ -225,15 +226,31 @@ class ToolTests(unittest.TestCase):
def __init__(self, url): def __init__(self, url):
super().__init__(self) super().__init__(self)
self.url = "none" self.url = url
def forward(self, string_input: str) -> str: def forward(self, string_input: str) -> str:
return self.url + string_input return self.url + string_input
fail_tool = FailTool("dummy_url") fail_tool = FailTool("dummy_url")
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
fail_tool.save("output") fail_tool.to_dict()
assert "__init__" in str(e) assert "All parameters of __init__ must have default values!" in str(e)
class PassTool(Tool):
name = "specific"
description = "test description"
inputs = {"string_input": {"type": "string", "description": "input description"}}
output_type = "string"
def __init__(self, url: Optional[str] = "none"):
super().__init__(self)
self.url = url
def forward(self, string_input: str) -> str:
return self.url + string_input
fail_tool = PassTool()
fail_tool.to_dict()
def test_saving_tool_allows_no_imports_from_outside_methods(self): def test_saving_tool_allows_no_imports_from_outside_methods(self):
# Test that using imports from outside functions fails # Test that using imports from outside functions fails

View File

@ -146,11 +146,12 @@ def test_e2e_class_tool_save():
test_tool = TestTool() test_tool = TestTool()
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
test_tool.save(tmp_dir) test_tool.save(tmp_dir, make_gradio_app=True)
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
assert ( assert (
pathlib.Path(tmp_dir, "tool.py").read_text() pathlib.Path(tmp_dir, "tool.py").read_text()
== """from smolagents.tools import Tool == """from typing import Any, Optional
from smolagents.tools import Tool
import IPython import IPython
class TestTool(Tool): class TestTool(Tool):
@ -173,7 +174,6 @@ class TestTool(Tool):
assert ( assert (
pathlib.Path(tmp_dir, "app.py").read_text() pathlib.Path(tmp_dir, "app.py").read_text()
== """from smolagents import launch_gradio_demo == """from smolagents import launch_gradio_demo
from typing import Optional
from tool import TestTool from tool import TestTool
tool = TestTool() tool = TestTool()
@ -201,13 +201,14 @@ def test_e2e_ipython_class_tool_save():
import IPython # noqa: F401 import IPython # noqa: F401
return task return task
TestTool().save("{tmp_dir}") TestTool().save("{tmp_dir}", make_gradio_app=True)
""") """)
assert shell.run_cell(code_blob, store_history=True).success assert shell.run_cell(code_blob, store_history=True).success
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
assert ( assert (
pathlib.Path(tmp_dir, "tool.py").read_text() pathlib.Path(tmp_dir, "tool.py").read_text()
== """from smolagents.tools import Tool == """from typing import Any, Optional
from smolagents.tools import Tool
import IPython import IPython
class TestTool(Tool): class TestTool(Tool):
@ -230,7 +231,6 @@ class TestTool(Tool):
assert ( assert (
pathlib.Path(tmp_dir, "app.py").read_text() pathlib.Path(tmp_dir, "app.py").read_text()
== """from smolagents import launch_gradio_demo == """from smolagents import launch_gradio_demo
from typing import Optional
from tool import TestTool from tool import TestTool
tool = TestTool() tool = TestTool()
@ -254,12 +254,12 @@ def test_e2e_function_tool_save():
return task return task
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
test_tool.save(tmp_dir) test_tool.save(tmp_dir, make_gradio_app=True)
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
assert ( assert (
pathlib.Path(tmp_dir, "tool.py").read_text() pathlib.Path(tmp_dir, "tool.py").read_text()
== """from smolagents import Tool == """from smolagents import Tool
from typing import Optional from typing import Any, Optional
class SimpleTool(Tool): class SimpleTool(Tool):
name = "test_tool" name = "test_tool"
@ -283,7 +283,6 @@ class SimpleTool(Tool):
assert ( assert (
pathlib.Path(tmp_dir, "app.py").read_text() pathlib.Path(tmp_dir, "app.py").read_text()
== """from smolagents import launch_gradio_demo == """from smolagents import launch_gradio_demo
from typing import Optional
from tool import SimpleTool from tool import SimpleTool
tool = SimpleTool() tool = SimpleTool()
@ -311,14 +310,14 @@ def test_e2e_ipython_function_tool_save():
return task return task
test_tool.save("{tmp_dir}") test_tool.save("{tmp_dir}", make_gradio_app=True)
""") """)
assert shell.run_cell(code_blob, store_history=True).success assert shell.run_cell(code_blob, store_history=True).success
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
assert ( assert (
pathlib.Path(tmp_dir, "tool.py").read_text() pathlib.Path(tmp_dir, "tool.py").read_text()
== """from smolagents import Tool == """from smolagents import Tool
from typing import Optional from typing import Any, Optional
class SimpleTool(Tool): class SimpleTool(Tool):
name = "test_tool" name = "test_tool"
@ -342,7 +341,6 @@ class SimpleTool(Tool):
assert ( assert (
pathlib.Path(tmp_dir, "app.py").read_text() pathlib.Path(tmp_dir, "app.py").read_text()
== """from smolagents import launch_gradio_demo == """from smolagents import launch_gradio_demo
from typing import Optional
from tool import SimpleTool from tool import SimpleTool
tool = SimpleTool() tool = SimpleTool()