Share full agents (#533)

Co-authored-by: Alex <sysradium@users.noreply.github.com>
Co-authored-by: Parteek <parteekkamboj112@gmail.com>
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Co-authored-by: David Berenstein <david.m.berenstein@gmail.com>

Parent: 360e1a8781
Commit: 1c1418dfa2
@@ -567,13 +567,13 @@ class WavConverter(MediaConverter):

class Mp3Converter(WavConverter):
"""
Converts MP3 files to markdown via extraction of metadata (if `exiftool` is installed), and speech transcription (if `speech_recognition` AND `pydub` are installed).
Converts MP3 and M4A files to markdown via extraction of metadata (if `exiftool` is installed), and speech transcription (if `speech_recognition` AND `pydub` are installed).
"""

def convert(self, local_path, **kwargs) -> Union[None, DocumentConverterResult]:
# Bail if not a MP3
extension = kwargs.get("file_extension", "")
if extension.lower() != ".mp3":
if extension.lower() not in [".mp3", ".m4a"]:
return None

md_content = ""

@@ -600,7 +600,10 @@ class Mp3Converter(WavConverter):
handle, temp_path = tempfile.mkstemp(suffix=".wav")
os.close(handle)
try:
if extension.lower() == ".mp3":
sound = pydub.AudioSegment.from_mp3(local_path)
else:
sound = pydub.AudioSegment.from_file(local_path, format="m4a")
sound.export(temp_path, format="wav")

_args = dict()
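For context, a minimal standalone sketch of the conversion step this hunk introduces, assuming `pydub` is installed and an ffmpeg binary is available for M4A decoding; `to_wav` is a hypothetical helper name, not part of the PR:

```python
import os
import tempfile

import pydub


def to_wav(local_path: str, extension: str) -> str:
    """Convert an .mp3 or .m4a file to a temporary .wav file and return its path."""
    handle, temp_path = tempfile.mkstemp(suffix=".wav")
    os.close(handle)
    if extension.lower() == ".mp3":
        sound = pydub.AudioSegment.from_mp3(local_path)
    else:  # ".m4a" goes through the generic ffmpeg-backed loader
        sound = pydub.AudioSegment.from_file(local_path, format="m4a")
    sound.export(temp_path, format="wav")
    return temp_path
```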
@@ -10,7 +10,7 @@ class TextInspectorTool(Tool):
name = "inspect_file_as_text"
description = """
You cannot load files yourself: instead call this tool to read a file as markdown text and ask questions about it.
This tool handles the following file extensions: [".html", ".htm", ".xlsx", ".pptx", ".wav", ".mp3", ".flac", ".pdf", ".docx"], and all other types of text files. IT DOES NOT HANDLE IMAGES."""
This tool handles the following file extensions: [".html", ".htm", ".xlsx", ".pptx", ".wav", ".mp3", ".m4a", ".flac", ".pdf", ".docx"], and all other types of text files. IT DOES NOT HANDLE IMAGES."""

inputs = {
"file_path": {
@@ -410,7 +410,7 @@ class VisitTool(Tool):
class DownloadTool(Tool):
name = "download_file"
description = """
Download a file at a given URL. The file should be of this format: [".xlsx", ".pptx", ".wav", ".mp3", ".png", ".docx"]
Download a file at a given URL. The file should be of this format: [".xlsx", ".pptx", ".wav", ".mp3", ".m4a", ".png", ".docx"]
After using this tool, for further inspection of this page you should return the download path to your manager via final_answer, and they will be able to inspect it.
DO NOT use this tool for .pdf or .txt or .htm files: for these types of files use visit_page with the file url instead."""
inputs = {"url": {"type": "string", "description": "The relative or absolute url of the file to be downloaded."}}
@@ -24,7 +24,6 @@ TODO: move them to `huggingface_hub` to avoid code duplication.

import inspect
import json
import os
import re
import types
from copy import copy

@@ -46,34 +45,31 @@ from huggingface_hub.utils import is_torch_available
from .utils import _is_pillow_available


def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
def get_imports(code: str) -> List[str]:
"""
Extracts all the libraries (not relative imports this time) that are imported in a file.
Extracts all the libraries (not relative imports) that are imported in a code.

Args:
filename (`str` or `os.PathLike`): The module file to inspect.
code (`str`): Code text to inspect.

Returns:
`List[str]`: The list of all packages required to use the input module.
`list[str]`: List of all packages required to use the input code.
"""
with open(filename, "r", encoding="utf-8") as f:
content = f.read()

# filter out try/except block so in custom code we can have try/except imports
content = re.sub(r"\s*try\s*:.*?except.*?:", "", content, flags=re.DOTALL)
code = re.sub(r"\s*try\s*:.*?except.*?:", "", code, flags=re.DOTALL)

# filter out imports under is_flash_attn_2_available block for avoid import issues in cpu only environment
content = re.sub(
code = re.sub(
r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+",
"",
content,
code,
flags=re.MULTILINE,
)

# Imports of the form `import xxx`
imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
# Imports of the form `import xxx` or `import xxx as yyy`
imports = re.findall(r"^\s*import\s+(\S+?)(?:\s+as\s+\S+)?\s*$", code, flags=re.MULTILINE)
# Imports of the form `from xxx import yyy`
imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
imports += re.findall(r"^\s*from\s+(\S+)\s+import", code, flags=re.MULTILINE)
# Only keep the top-level module
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
return list(set(imports))
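A quick illustration of the new string-based `get_imports` (behavior as defined by the regexes above; the expected output mirrors the tests added later in this PR):

```python
code = """
import numpy as np
from torch.nn import Linear
from .utils import helper
"""
print(sorted(get_imports(code)))  # ['numpy', 'torch'] — `as` aliases handled, relative imports dropped
```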
@ -14,44 +14,29 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = ["AgentMemory", "CodeAgent", "MultiStepAgent", "ToolCallingAgent"]
|
||||
|
||||
import importlib.resources
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
from collections import deque
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, TypedDict, Union
|
||||
|
||||
import jinja2
|
||||
import yaml
|
||||
from huggingface_hub import create_repo, metadata_update, snapshot_download, upload_folder
|
||||
from jinja2 import StrictUndefined, Template
|
||||
from rich.console import Group
|
||||
from rich.panel import Panel
|
||||
from rich.rule import Rule
|
||||
from rich.text import Text
|
||||
|
||||
from smolagents.agent_types import AgentAudio, AgentImage, handle_agent_output_types
|
||||
from smolagents.memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall
|
||||
from smolagents.monitoring import (
|
||||
YELLOW_HEX,
|
||||
AgentLogger,
|
||||
LogLevel,
|
||||
)
|
||||
from smolagents.utils import (
|
||||
AgentError,
|
||||
AgentExecutionError,
|
||||
AgentGenerationError,
|
||||
AgentMaxStepsError,
|
||||
AgentParsingError,
|
||||
parse_code_blobs,
|
||||
parse_json_tool_call,
|
||||
truncate_content,
|
||||
)
|
||||
|
||||
from .agent_types import AgentType
|
||||
from .agent_types import AgentAudio, AgentImage, AgentType, handle_agent_output_types
|
||||
from .default_tools import TOOL_MAPPING, FinalAnswerTool
|
||||
from .e2b_executor import E2BExecutor
|
||||
from .local_python_executor import (
|
||||
|
@ -59,12 +44,30 @@ from .local_python_executor import (
|
|||
LocalPythonInterpreter,
|
||||
fix_final_answer_code,
|
||||
)
|
||||
from .memory import ActionStep, AgentMemory, PlanningStep, SystemPromptStep, TaskStep, ToolCall
|
||||
from .models import (
|
||||
ChatMessage,
|
||||
MessageRole,
|
||||
Model,
|
||||
)
|
||||
from .monitoring import (
|
||||
YELLOW_HEX,
|
||||
AgentLogger,
|
||||
LogLevel,
|
||||
Monitor,
|
||||
)
|
||||
from .monitoring import Monitor
|
||||
from .tools import Tool
|
||||
from .utils import (
|
||||
AgentError,
|
||||
AgentExecutionError,
|
||||
AgentGenerationError,
|
||||
AgentMaxStepsError,
|
||||
AgentParsingError,
|
||||
make_init_file,
|
||||
parse_code_blobs,
|
||||
parse_json_tool_call,
|
||||
truncate_content,
|
||||
)
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@@ -228,9 +231,19 @@ class MultiStepAgent:
)
self.managed_agents = {agent.name: agent for agent in managed_agents}

tool_and_managed_agent_names = [tool.name for tool in tools]
if managed_agents is not None:
tool_and_managed_agent_names += [agent.name for agent in managed_agents]
if len(tool_and_managed_agent_names) != len(set(tool_and_managed_agent_names)):
raise ValueError(
"Each tool or managed_agent should have a unique name! You passed these duplicate names: "
f"{[name for name in tool_and_managed_agent_names if tool_and_managed_agent_names.count(name) > 1]}"
)

for tool in tools:
assert isinstance(tool, Tool), f"This element is not of class Tool: {str(tool)}"
self.tools = {tool.name: tool for tool in tools}

if add_base_tools:
for tool_name, tool_class in TOOL_MAPPING.items():
if tool_name != "python_interpreter" or self.__class__.__name__ == "ToolCallingAgent":
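The effect of the new uniqueness check, sketched with the same ingredients as the tests in this PR (`fake_code_model` stands in for any model callable):

```python
from smolagents import CodeAgent
from smolagents.default_tools import PythonInterpreterTool

try:
    CodeAgent(tools=[PythonInterpreterTool(), PythonInterpreterTool()], model=fake_code_model)
except ValueError as e:
    print(e)  # Each tool or managed_agent should have a unique name! ... ['python_interpreter']
```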
@ -709,6 +722,310 @@ You have been provided with these additional arguments, that you can access usin
|
|||
answer += "\n</summary_of_work>"
|
||||
return answer
|
||||
|
||||
def save(self, output_dir: str, relative_path: Optional[str] = None):
|
||||
"""
|
||||
Saves the relevant code files for your agent. This will copy the code of your agent into `output_dir`, as well as autogenerate:
|
||||
|
||||
- a `tools` folder containing the logic for each of the tools under `tools/{tool_name}.py`.
|
||||
- a `managed_agents` folder containing the logic for each of the managed agents.
|
||||
- an `agent.json` file containing a dictionary representing your agent.
|
||||
- a `prompts.yaml` file containing the prompt templates used by your agent.
|
||||
- an `app.py` file providing a UI for your agent when it is exported to a Space with `agent.push_to_hub()`
|
||||
- a `requirements.txt` containing the names of the modules used by your agent (as detected when inspecting its
|
||||
code)
|
||||
|
||||
Args:
|
||||
output_dir (`str`): The folder in which you want to save your agent.
|
||||
"""
|
||||
make_init_file(output_dir)
|
||||
|
||||
# Recursively save managed agents
|
||||
if self.managed_agents:
|
||||
make_init_file(os.path.join(output_dir, "managed_agents"))
|
||||
for agent_name, agent in self.managed_agents.items():
|
||||
agent_suffix = f"managed_agents.{agent_name}"
|
||||
if relative_path:
|
||||
agent_suffix = relative_path + "." + agent_suffix
|
||||
agent.save(os.path.join(output_dir, "managed_agents", agent_name), relative_path=agent_suffix)
|
||||
|
||||
class_name = self.__class__.__name__
|
||||
|
||||
# Save tools to different .py files
|
||||
for tool in self.tools.values():
|
||||
make_init_file(os.path.join(output_dir, "tools"))
|
||||
tool.save(os.path.join(output_dir, "tools"), tool_file_name=tool.name, make_gradio_app=False)
|
||||
|
||||
# Save prompts to yaml
|
||||
yaml_prompts = yaml.safe_dump(
|
||||
self.prompt_templates,
|
||||
default_style="|", # This forces block literals for all strings
|
||||
default_flow_style=False,
|
||||
width=float("inf"),
|
||||
sort_keys=False,
|
||||
allow_unicode=True,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
with open(os.path.join(output_dir, "prompts.yaml"), "w", encoding="utf-8") as f:
|
||||
f.write(yaml_prompts)
|
||||
|
||||
# Save agent dictionary to json
|
||||
agent_dict = self.to_dict()
|
||||
agent_dict["tools"] = [tool.name for tool in self.tools.values()]
|
||||
with open(os.path.join(output_dir, "agent.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(agent_dict, f, indent=4)
|
||||
|
||||
# Save requirements
|
||||
with open(os.path.join(output_dir, "requirements.txt"), "w", encoding="utf-8") as f:
|
||||
f.writelines(f"{r}\n" for r in agent_dict["requirements"])
|
||||
|
||||
# Make agent.py file with Gradio UI
|
||||
agent_name = f"agent_{self.name}" if getattr(self, "name", None) else "agent"
|
||||
managed_agent_relative_path = relative_path + "." if relative_path is not None else ""
|
||||
app_template = textwrap.dedent("""
|
||||
import yaml
|
||||
import os
|
||||
from smolagents import GradioUI, {{ class_name }}, {{ agent_dict['model']['class'] }}
|
||||
|
||||
# Get current directory path
|
||||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
{% for tool in tools.values() -%}
|
||||
from {{managed_agent_relative_path}}tools.{{ tool.name }} import {{ tool.__class__.__name__ }} as {{ tool.name | camelcase }}
|
||||
{% endfor %}
|
||||
{% for managed_agent in managed_agents.values() -%}
|
||||
from {{managed_agent_relative_path}}managed_agents.{{ managed_agent.name }}.app import agent_{{ managed_agent.name }}
|
||||
{% endfor %}
|
||||
|
||||
model = {{ agent_dict['model']['class'] }}(
|
||||
{% for key in agent_dict['model']['data'] if key not in ['class', 'last_input_token_count', 'last_output_token_count'] -%}
|
||||
{{ key }}={{ agent_dict['model']['data'][key]|repr }},
|
||||
{% endfor %})
|
||||
|
||||
{% for tool in tools.values() -%}
|
||||
{{ tool.name }} = {{ tool.name | camelcase }}()
|
||||
{% endfor %}
|
||||
|
||||
with open(os.path.join(CURRENT_DIR, "prompts.yaml"), 'r') as stream:
|
||||
prompt_templates = yaml.safe_load(stream)
|
||||
|
||||
{{ agent_name }} = {{ class_name }}(
|
||||
model=model,
|
||||
tools=[{% for tool_name in tools.keys() if tool_name != "final_answer" %}{{ tool_name }}{% if not loop.last %}, {% endif %}{% endfor %}],
|
||||
managed_agents=[{% for subagent_name in managed_agents.keys() %}agent_{{ subagent_name }}{% if not loop.last %}, {% endif %}{% endfor %}],
|
||||
{% for attribute_name, value in agent_dict.items() if attribute_name not in ["model", "tools", "prompt_templates", "authorized_imports", "managed_agents", "requirements"] -%}
|
||||
{{ attribute_name }}={{ value|repr }},
|
||||
{% endfor %}prompt_templates=prompt_templates
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
GradioUI({{ agent_name }}).launch()
|
||||
""").strip()
|
||||
template_env = jinja2.Environment(loader=jinja2.BaseLoader(), undefined=jinja2.StrictUndefined)
|
||||
template_env.filters["repr"] = repr
|
||||
template_env.filters["camelcase"] = lambda value: "".join(word.capitalize() for word in value.split("_"))
|
||||
template = template_env.from_string(app_template)
|
||||
|
||||
# Render the app.py file from Jinja2 template
|
||||
app_text = template.render(
|
||||
{
|
||||
"agent_name": agent_name,
|
||||
"class_name": class_name,
|
||||
"agent_dict": agent_dict,
|
||||
"tools": self.tools,
|
||||
"managed_agents": self.managed_agents,
|
||||
"managed_agent_relative_path": managed_agent_relative_path,
|
||||
}
|
||||
)
|
||||
|
||||
with open(os.path.join(output_dir, "app.py"), "w", encoding="utf-8") as f:
|
||||
f.write(app_text + "\n") # Append newline at the end
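A hedged usage sketch of the new `save()` entry point (the output directory name is hypothetical):

```python
agent.save("my_agent_export")
# Expected layout, per the docstring and code above:
# my_agent_export/
# ├── agent.json          # serialized agent (tools listed by name)
# ├── prompts.yaml        # prompt templates
# ├── requirements.txt    # detected module requirements
# ├── app.py              # Gradio UI entry point
# ├── tools/<tool_name>.py
# └── managed_agents/<agent_name>/...
```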
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Converts agent into a dictionary."""
|
||||
# TODO: handle serializing step_callbacks and final_answer_checks
|
||||
for attr in ["final_answer_checks", "step_callbacks"]:
|
||||
if getattr(self, attr, None):
|
||||
self.logger.log(f"This agent has {attr}: they will be ignored by this method.", LogLevel.INFO)
|
||||
|
||||
tool_dicts = [tool.to_dict() for tool in self.tools.values()]
|
||||
tool_requirements = {req for tool in self.tools.values() for req in tool.to_dict()["requirements"]}
|
||||
managed_agents_requirements = {
|
||||
req for managed_agent in self.managed_agents.values() for req in managed_agent.to_dict()["requirements"]
|
||||
}
|
||||
requirements = tool_requirements | managed_agents_requirements
|
||||
if hasattr(self, "authorized_imports"):
|
||||
requirements.update(
|
||||
{package.split(".")[0] for package in self.authorized_imports if package not in BASE_BUILTIN_MODULES}
|
||||
)
|
||||
|
||||
agent_dict = {
|
||||
"tools": tool_dicts,
|
||||
"model": {
|
||||
"class": self.model.__class__.__name__,
|
||||
"data": self.model.to_dict(),
|
||||
},
|
||||
"managed_agents": {
|
||||
managed_agent.name: managed_agent.__class__.__name__ for managed_agent in self.managed_agents.values()
|
||||
},
|
||||
"prompt_templates": self.prompt_templates,
|
||||
"max_steps": self.max_steps,
|
||||
"verbosity_level": int(self.logger.level),
|
||||
"grammar": self.grammar,
|
||||
"planning_interval": self.planning_interval,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"requirements": list(requirements),
|
||||
}
|
||||
if hasattr(self, "authorized_imports"):
|
||||
agent_dict["authorized_imports"] = self.authorized_imports
|
||||
return agent_dict
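For reference, the dictionary produced by `to_dict()` carries these keys, with `authorized_imports` added only when the attribute exists (i.e. for code agents), per the `hasattr` check above:

```python
d = agent.to_dict()
print(sorted(d))
# ['description', 'grammar', 'managed_agents', 'max_steps', 'model', 'name',
#  'planning_interval', 'prompt_templates', 'requirements', 'tools', 'verbosity_level']
# plus 'authorized_imports' for a CodeAgent
```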
|
||||
|
||||
@classmethod
|
||||
def from_hub(
|
||||
cls,
|
||||
repo_id: str,
|
||||
token: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Loads an agent defined on the Hub.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Loading an agent from the Hub means that you'll download the agent and execute its code locally.
ALWAYS inspect the agent you're downloading before loading it within your runtime, as you would do when
|
||||
installing a package using pip/npm/apt.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
repo_id (`str`):
|
||||
The name of the repo on the Hub where your agent is defined.
|
||||
token (`str`, *optional*):
|
||||
The token to identify you on hf.co. If unset, will use the token generated when running
|
||||
`huggingface-cli login` (stored in `~/.huggingface`).
|
||||
trust_remote_code (`bool`, *optional*, defaults to `False`):
This flag marks that you understand the risk of running remote code and that you trust this agent.
If this is not set to `True`, loading the agent from the Hub will fail.
|
||||
kwargs (additional keyword arguments, *optional*):
|
||||
Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
|
||||
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your agent, and the
|
||||
others will be passed along to its init.
|
||||
"""
|
||||
if not trust_remote_code:
|
||||
raise ValueError(
"Loading an agent from the Hub requires you to acknowledge that you trust its code: to do so, pass `trust_remote_code=True`."
)
|
||||
|
||||
# Get the agent's Hub folder.
|
||||
download_kwargs = {"token": token, "repo_type": "space"} | {
|
||||
key: kwargs.pop(key)
|
||||
for key in [
|
||||
"cache_dir",
|
||||
"force_download",
|
||||
"resume_download",
|
||||
"proxies",
|
||||
"revision",
|
||||
"subfolder",
|
||||
"local_files_only",
|
||||
]
|
||||
if key in kwargs
|
||||
}
|
||||
|
||||
download_folder = Path(snapshot_download(repo_id=repo_id, **download_kwargs))
|
||||
return cls.from_folder(download_folder, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_folder(cls, folder: Union[str, Path], **kwargs):
|
||||
"""Loads an agent from a local folder"""
|
||||
folder = Path(folder)
|
||||
agent_dict = json.loads((folder / "agent.json").read_text())
|
||||
|
||||
# Recursively get managed agents
|
||||
managed_agents = []
|
||||
for managed_agent_name, managed_agent_class in agent_dict["managed_agents"].items():
|
||||
agent_cls = getattr(importlib.import_module("smolagents.agents"), managed_agent_class)
|
||||
managed_agents.append(agent_cls.from_folder(folder / "managed_agents" / managed_agent_name))
|
||||
|
||||
tools = []
|
||||
for tool_name in agent_dict["tools"]:
|
||||
tool_code = (folder / "tools" / f"{tool_name}.py").read_text()
|
||||
tools.append(Tool.from_code(tool_code))
|
||||
|
||||
model_class: Model = getattr(importlib.import_module("smolagents.models"), agent_dict["model"]["class"])
|
||||
model = model_class.from_dict(agent_dict["model"]["data"])
|
||||
|
||||
args = dict(
|
||||
model=model,
|
||||
tools=tools,
|
||||
managed_agents=managed_agents,
|
||||
name=agent_dict["name"],
|
||||
description=agent_dict["description"],
|
||||
max_steps=agent_dict["max_steps"],
|
||||
planning_interval=agent_dict["planning_interval"],
|
||||
grammar=agent_dict["grammar"],
|
||||
verbosity_level=agent_dict["verbosity_level"],
|
||||
)
|
||||
if cls.__name__ == "CodeAgent":
|
||||
args["additional_authorized_imports"] = agent_dict["authorized_imports"]
|
||||
args.update(kwargs)
|
||||
return cls(**args)
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
repo_id: str,
|
||||
commit_message: str = "Upload agent",
|
||||
private: Optional[bool] = None,
|
||||
token: Optional[Union[bool, str]] = None,
|
||||
create_pr: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Upload the agent to the Hub.
|
||||
|
||||
Parameters:
|
||||
repo_id (`str`):
|
||||
The name of the repository you want to push to. It should contain your organization name when
|
||||
pushing to a given organization.
|
||||
commit_message (`str`, *optional*, defaults to `"Upload agent"`):
|
||||
Message to commit while pushing.
|
||||
private (`bool`, *optional*, defaults to `None`):
|
||||
Whether to make the repo private. If `None`, the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
|
||||
token (`bool` or `str`, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
|
||||
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||
create_pr (`bool`, *optional*, defaults to `False`):
|
||||
Whether to create a PR with the uploaded files or directly commit.
|
||||
"""
|
||||
repo_url = create_repo(
|
||||
repo_id=repo_id,
|
||||
token=token,
|
||||
private=private,
|
||||
exist_ok=True,
|
||||
repo_type="space",
|
||||
space_sdk="gradio",
|
||||
)
|
||||
repo_id = repo_url.repo_id
|
||||
metadata_update(
|
||||
repo_id,
|
||||
{"tags": ["smolagents", "agent"]},
|
||||
repo_type="space",
|
||||
token=token,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as work_dir:
|
||||
self.save(work_dir)
|
||||
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
|
||||
return upload_folder(
|
||||
repo_id=repo_id,
|
||||
commit_message=commit_message,
|
||||
folder_path=work_dir,
|
||||
token=token,
|
||||
create_pr=create_pr,
|
||||
repo_type="space",
|
||||
)
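A hedged round-trip sketch of the new Hub integration (hypothetical repo id; requires a valid Hugging Face token with write access):

```python
agent.push_to_hub("your-username/my-agent-space")

# Later, or on another machine — loading executes the uploaded code, hence the explicit opt-in:
from smolagents import CodeAgent

agent2 = CodeAgent.from_hub("your-username/my-agent-space", trust_remote_code=True)
```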
|
||||
|
||||
|
||||
class ToolCallingAgent(MultiStepAgent):
|
||||
"""
|
||||
|
@ -863,6 +1180,8 @@ class CodeAgent(MultiStepAgent):
|
|||
):
|
||||
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
||||
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
|
||||
self.use_e2b_executor = use_e2b_executor
|
||||
self.max_print_outputs_length = max_print_outputs_length
|
||||
prompt_templates = prompt_templates or yaml.safe_load(
|
||||
importlib.resources.files("smolagents.prompts").joinpath("code_agent.yaml").read_text()
|
||||
)
|
||||
|
|
|
@@ -141,7 +141,7 @@ def stream_to_gradio(

for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
# Track tokens if model provides them
if hasattr(agent.model, "last_input_token_count"):
if getattr(agent.model, "last_input_token_count", None):
total_input_tokens += agent.model.last_input_token_count
total_output_tokens += agent.model.last_output_token_count
if isinstance(step_log, ActionStep):
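Why the change from `hasattr` to `getattr(..., None)`: a model object may define the attribute but leave it unset, which previously led to summing `None`. A small sketch:

```python
class DummyModel:
    last_input_token_count = None  # attribute exists, but no count recorded yet

m = DummyModel()
print(hasattr(m, "last_input_token_count"))               # True  -> old guard passes, then += None fails
print(bool(getattr(m, "last_input_token_count", None)))   # False -> new guard skips the update
```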
@ -331,6 +331,53 @@ class Model:
|
|||
"""
|
||||
pass # To be implemented in child classes!
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""
|
||||
Converts the model into a JSON-compatible dictionary.
|
||||
"""
|
||||
model_dictionary = {
|
||||
**self.kwargs,
|
||||
"last_input_token_count": self.last_input_token_count,
|
||||
"last_output_token_count": self.last_output_token_count,
|
||||
"model_id": self.model_id,
|
||||
}
|
||||
for attribute in [
|
||||
"custom_role_conversion",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"provider",
|
||||
"timeout",
|
||||
"api_base",
|
||||
"torch_dtype",
|
||||
"device_map",
|
||||
"organization",
|
||||
"project",
|
||||
"azure_endpoint",
|
||||
]:
|
||||
if hasattr(self, attribute):
|
||||
model_dictionary[attribute] = getattr(self, attribute)
|
||||
|
||||
dangerous_attributes = ["token", "api_key"]
|
||||
for attribute_name in dangerous_attributes:
|
||||
if hasattr(self, attribute_name):
|
||||
print(
|
||||
f"For security reasons, we do not export the `{attribute_name}` attribute of your model. Please export it manually."
|
||||
)
|
||||
return model_dictionary
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, model_dictionary: Dict[str, Any]) -> "Model":
|
||||
model_instance = cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in model_dictionary.items()
|
||||
if k not in ["last_input_token_count", "last_output_token_count"]
|
||||
}
|
||||
)
|
||||
model_instance.last_input_token_count = model_dictionary.pop("last_input_token_count", None)
|
||||
model_instance.last_output_token_count = model_dictionary.pop("last_output_token_count", None)
|
||||
return model_instance
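A hedged sketch of the new model (de)serialization path, reusing the constructor arguments from the tests in this PR; sensitive fields such as `token`/`api_key` are deliberately left out of the export:

```python
from smolagents import HfApiModel

model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens=2096, temperature=0.5)
data = model.to_dict()                 # JSON-compatible; no token/api_key inside
restored = HfApiModel.from_dict(data)
assert restored.model_id == model.model_id
```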
|
||||
|
||||
|
||||
class HfApiModel(Model):
|
||||
"""A class to interact with Hugging Face's Inference API for language model interaction.
|
||||
|
|
|
@ -83,6 +83,31 @@ class MethodChecker(ast.NodeVisitor):
|
|||
self.assigned_names.add(elt.id)
|
||||
self.generic_visit(node)
|
||||
|
||||
def _handle_comprehension_generators(self, generators):
|
||||
"""Helper method to handle generators in all types of comprehensions"""
|
||||
for generator in generators:
|
||||
if isinstance(generator.target, ast.Name):
|
||||
self.assigned_names.add(generator.target.id)
|
||||
elif isinstance(generator.target, ast.Tuple):
|
||||
for elt in generator.target.elts:
|
||||
if isinstance(elt, ast.Name):
|
||||
self.assigned_names.add(elt.id)
|
||||
|
||||
def visit_ListComp(self, node):
|
||||
"""Track variables in list comprehensions"""
|
||||
self._handle_comprehension_generators(node.generators)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_DictComp(self, node):
|
||||
"""Track variables in dictionary comprehensions"""
|
||||
self._handle_comprehension_generators(node.generators)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_SetComp(self, node):
|
||||
"""Track variables in set comprehensions"""
|
||||
self._handle_comprehension_generators(node.generators)
|
||||
self.generic_visit(node)
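What the new comprehension handling records, shown standalone with `ast` (this is an illustration of the idea, not the `MethodChecker` class itself):

```python
import ast

comp = ast.parse("[x * y for x, y in pairs]").body[0].value
names = set()
for gen in comp.generators:
    if isinstance(gen.target, ast.Name):
        names.add(gen.target.id)
    elif isinstance(gen.target, ast.Tuple):
        names.update(elt.id for elt in gen.target.elts if isinstance(elt, ast.Name))
print(names)  # {'x', 'y'} — these no longer get reported as undefined names
```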
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
|
||||
self.generic_visit(node)
|
||||
|
@ -121,7 +146,8 @@ class MethodChecker(ast.NodeVisitor):
|
|||
def validate_tool_attributes(cls, check_imports: bool = True) -> None:
|
||||
"""
|
||||
Validates that a Tool class follows the proper patterns:
|
||||
0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!).
|
||||
0. Any argument of __init__ should have a default.
|
||||
Args chosen at init are not traceable, so we cannot rebuild the source code for them, thus any important arg should be defined as a class attribute.
|
||||
1. About the class:
|
||||
- Class attributes should only be strings or dicts
|
||||
- Class attributes cannot be complex attributes
|
||||
|
@ -140,13 +166,20 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
|
|||
if not isinstance(tree.body[0], ast.ClassDef):
|
||||
raise ValueError("Source code must define a class")
|
||||
|
||||
# Check that __init__ method takes no arguments
|
||||
# Check that __init__ method only has arguments with defaults
|
||||
if not cls.__init__.__qualname__ == "Tool.__init__":
|
||||
sig = inspect.signature(cls.__init__)
|
||||
non_self_params = list([arg_name for arg_name in sig.parameters.keys() if arg_name != "self"])
|
||||
if len(non_self_params) > 0:
|
||||
non_default_params = [
|
||||
arg_name
|
||||
for arg_name, param in sig.parameters.items()
|
||||
if arg_name != "self"
|
||||
and param.default == inspect.Parameter.empty
|
||||
and param.kind != inspect.Parameter.VAR_KEYWORD # Excludes **kwargs
|
||||
]
|
||||
if non_default_params:
|
||||
errors.append(
|
||||
f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!"
|
||||
f"This tool has required arguments in __init__: {non_default_params}. "
|
||||
"All parameters of __init__ must have default values!"
|
||||
)
|
||||
|
||||
class_node = tree.body[0]
|
||||
|
@ -198,5 +231,5 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
|
|||
errors += [f"- {node.name}: {error}" for error in method_checker.errors]
|
||||
|
||||
if errors:
|
||||
raise ValueError("Tool validation failed:\n" + "\n".join(errors))
|
||||
raise ValueError(f"Tool validation failed for {cls.__name__}:\n" + "\n".join(errors))
|
||||
return
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import ast
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
|
@ -23,6 +22,7 @@ import os
|
|||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
|
@ -199,24 +199,9 @@ class Tool:
|
|||
"""
|
||||
self.is_initialized = True
|
||||
|
||||
def save(self, output_dir):
|
||||
"""
|
||||
Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
|
||||
tool in `output_dir` as well as autogenerate:
|
||||
|
||||
- a `tool.py` file containing the logic for your tool.
|
||||
- an `app.py` file providing an UI for your tool when it is exported to a Space with `tool.push_to_hub()`
|
||||
- a `requirements.txt` containing the names of the module used by your tool (as detected when inspecting its
|
||||
code)
|
||||
|
||||
Args:
|
||||
output_dir (`str`): The folder in which you want to save your tool.
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
def to_dict(self) -> dict:
|
||||
"""Returns a dictionary representing the tool"""
|
||||
class_name = self.__class__.__name__
|
||||
tool_file = os.path.join(output_dir, "tool.py")
|
||||
|
||||
# Save tool file
|
||||
if type(self).__name__ == "SimpleTool":
|
||||
# Check that imports are self-contained
|
||||
source_code = get_source(self.forward).replace("@tool", "")
|
||||
|
@ -232,7 +217,7 @@ class Tool:
|
|||
tool_code = textwrap.dedent(
|
||||
f"""
|
||||
from smolagents import Tool
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
class {class_name}(Tool):
|
||||
name = "{self.name}"
|
||||
|
@ -272,11 +257,39 @@ class Tool:
|
|||
|
||||
validate_tool_attributes(self.__class__)
|
||||
|
||||
tool_code = instance_to_source(self, base_cls=Tool)
|
||||
tool_code = "from typing import Any, Optional\n" + instance_to_source(self, base_cls=Tool)
|
||||
|
||||
requirements = {el for el in get_imports(tool_code) if el not in sys.stdlib_module_names} | {"smolagents"}
|
||||
|
||||
return {"name": self.name, "code": tool_code, "requirements": requirements}
|
||||
|
||||
def save(self, output_dir: str, tool_file_name: str = "tool", make_gradio_app: bool = True):
|
||||
"""
|
||||
Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
|
||||
tool in `output_dir` as well as autogenerate:
|
||||
|
||||
- a `{tool_file_name}.py` file containing the logic for your tool.
|
||||
If you pass `make_gradio_app=True`, this will also write:
|
||||
- an `app.py` file providing a UI for your tool when it is exported to a Space with `tool.push_to_hub()`
|
||||
- a `requirements.txt` containing the names of the modules used by your tool (as detected when inspecting its
|
||||
code)
|
||||
|
||||
Args:
|
||||
output_dir (`str`): The folder in which you want to save your tool.
|
||||
tool_file_name (`str`, *optional*): The file name in which you want to save your tool.
|
||||
make_gradio_app (`bool`, *optional*, defaults to True): Whether to also export a `requirements.txt` file and Gradio UI.
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
class_name = self.__class__.__name__
|
||||
tool_file = os.path.join(output_dir, f"{tool_file_name}.py")
|
||||
|
||||
tool_dict = self.to_dict()
|
||||
tool_code = tool_dict["code"]
|
||||
|
||||
with open(tool_file, "w", encoding="utf-8") as f:
|
||||
f.write(tool_code.replace(":true,", ":True,").replace(":true}", ":True}"))
|
||||
|
||||
if make_gradio_app:
|
||||
# Save app file
|
||||
app_file = os.path.join(output_dir, "app.py")
|
||||
with open(app_file, "w", encoding="utf-8") as f:
|
||||
|
@ -284,8 +297,7 @@ class Tool:
|
|||
textwrap.dedent(
|
||||
f"""
|
||||
from smolagents import launch_gradio_demo
|
||||
from typing import Optional
|
||||
from tool import {class_name}
|
||||
from {tool_file_name} import {class_name}
|
||||
|
||||
tool = {class_name}()
|
||||
|
||||
|
@ -295,10 +307,9 @@ class Tool:
|
|||
)
|
||||
|
||||
# Save requirements file
|
||||
imports = {el for el in get_imports(tool_file) if el not in sys.stdlib_module_names} | {"smolagents"}
|
||||
requirements_file = os.path.join(output_dir, "requirements.txt")
|
||||
with open(requirements_file, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(imports) + "\n")
|
||||
f.write("\n".join(tool_dict["requirements"]) + "\n")
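A hedged sketch of the reworked tool serialization surface (`MyTool` is a hypothetical `Tool` subclass):

```python
tool = MyTool()
exported = tool.to_dict()    # {"name": ..., "code": ..., "requirements": {...}}
tool.save("exported_tool", tool_file_name=tool.name, make_gradio_app=False)
# -> writes exported_tool/<tool.name>.py only; with make_gradio_app=True it also
#    writes app.py and requirements.txt, as in the agent-level save() above.
```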
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
|
@ -311,14 +322,6 @@ class Tool:
|
|||
"""
|
||||
Upload the tool to the Hub.
|
||||
|
||||
For this method to work properly, your tool must have been defined in a separate module (not `__main__`).
|
||||
For instance:
|
||||
```
|
||||
from my_tool_module import MyTool
|
||||
my_tool = MyTool()
|
||||
my_tool.push_to_hub("my-username/my-space")
|
||||
```
|
||||
|
||||
Parameters:
|
||||
repo_id (`str`):
|
||||
The name of the repository you want to push your tool to. It should contain your organization name when
|
||||
|
@ -342,13 +345,11 @@ class Tool:
|
|||
space_sdk="gradio",
|
||||
)
|
||||
repo_id = repo_url.repo_id
|
||||
metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space", token=token)
|
||||
metadata_update(repo_id, {"tags": ["smolagents", "tool"]}, repo_type="space", token=token)
|
||||
|
||||
with tempfile.TemporaryDirectory() as work_dir:
|
||||
# Save all files.
|
||||
self.save(work_dir)
|
||||
with open(work_dir + "/tool.py", "r") as f:
|
||||
print("\n".join(f.readlines()))
|
||||
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
|
||||
return upload_folder(
|
||||
repo_id=repo_id,
|
||||
|
@ -394,7 +395,7 @@ class Tool:
|
|||
"""
|
||||
if not trust_remote_code:
|
||||
raise ValueError(
"Loading a tool from Hub requires to trust remote code. Make sure you've inspected the repo and pass `trust_remote_code=True` to load the tool."
"Loading a tool from the Hub requires you to acknowledge that you trust its code: to do so, pass `trust_remote_code=True`."
)
|
||||
|
||||
# Get the tool's tool.py file.
|
||||
|
@ -413,27 +414,23 @@ class Tool:
|
|||
)
|
||||
|
||||
tool_code = Path(tool_file).read_text()
|
||||
return Tool.from_code(tool_code, **kwargs)
|
||||
|
||||
# Find the Tool subclass in the namespace
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Save the code to a file
|
||||
module_path = os.path.join(temp_dir, "tool.py")
|
||||
with open(module_path, "w") as f:
|
||||
f.write(tool_code)
|
||||
@classmethod
|
||||
def from_code(cls, tool_code: str, **kwargs):
|
||||
module = types.ModuleType("dynamic_tool")
|
||||
|
||||
print("TOOL CODE:\n", tool_code)
|
||||
exec(tool_code, module.__dict__)
|
||||
|
||||
# Load module from file path
|
||||
spec = importlib.util.spec_from_file_location("tool", module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Find and instantiate the Tool class
|
||||
for item_name in dir(module):
|
||||
item = getattr(module, item_name)
|
||||
if isinstance(item, type) and issubclass(item, Tool) and item != Tool:
|
||||
tool_class = item
|
||||
break
|
||||
# Find the Tool subclass
|
||||
tool_class = next(
|
||||
(
|
||||
obj
|
||||
for _, obj in inspect.getmembers(module, inspect.isclass)
|
||||
if issubclass(obj, Tool) and obj is not Tool
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if tool_class is None:
|
||||
raise ValueError("No Tool subclass found in the code.")
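And the matching hedged sketch for rehydrating a tool from its exported source with the new `from_code` classmethod (the path is hypothetical):

```python
from pathlib import Path

from smolagents import Tool

tool_code = Path("exported_tool/my_tool.py").read_text()
tool = Tool.from_code(tool_code)   # finds the single Tool subclass in the code and instantiates it
```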
|
||||
|
|
|
@ -20,6 +20,7 @@ import importlib.metadata
|
|||
import importlib.util
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import textwrap
|
||||
import types
|
||||
|
@ -414,3 +415,10 @@ def encode_image_base64(image):
|
|||
|
||||
def make_image_url(base64_image):
|
||||
return f"data:image/png;base64,{base64_image}"
|
||||
|
||||
|
||||
def make_init_file(folder: str):
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
# Create __init__
|
||||
with open(os.path.join(folder, "__init__.py"), "w"):
|
||||
pass
|
||||
|
|
|
@ -31,12 +31,13 @@ from smolagents.agents import (
|
|||
ToolCallingAgent,
|
||||
populate_template,
|
||||
)
|
||||
from smolagents.default_tools import PythonInterpreterTool
|
||||
from smolagents.default_tools import DuckDuckGoSearchTool, PythonInterpreterTool, VisitWebpageTool
|
||||
from smolagents.memory import PlanningStep
|
||||
from smolagents.models import (
|
||||
ChatMessage,
|
||||
ChatMessageToolCall,
|
||||
ChatMessageToolCallDefinition,
|
||||
HfApiModel,
|
||||
MessageRole,
|
||||
TransformersModel,
|
||||
)
|
||||
|
@ -436,10 +437,15 @@ class AgentTests(unittest.TestCase):
|
|||
assert len(agent.tools) == 1 # when no tools are provided, only the final_answer tool is added by default
|
||||
|
||||
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
||||
with pytest.raises(ValueError) as e:
|
||||
agent = CodeAgent(tools=toolset_2, model=fake_code_model)
|
||||
assert (
|
||||
len(agent.tools) == 2
|
||||
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
|
||||
assert "Each tool or managed_agent should have a unique name!" in str(e)
|
||||
|
||||
with pytest.raises(ValueError) as e:
|
||||
agent.name = "python_interpreter"
|
||||
agent.description = "empty"
|
||||
CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model, managed_agents=[agent])
|
||||
assert "Each tool or managed_agent should have a unique name!" in str(e)
|
||||
|
||||
# check that python_interpreter base tool does not get added to CodeAgent
|
||||
agent = CodeAgent(tools=[], model=fake_code_model, add_base_tools=True)
|
||||
|
@ -484,132 +490,6 @@ class AgentTests(unittest.TestCase):
|
|||
str_output = capture.get()
|
||||
assert "`additional_authorized_imports`" in str_output.replace("\n", "")
|
||||
|
||||
def test_multiagents(self):
|
||||
class FakeModelMultiagentsManagerAgent:
|
||||
model_id = "fake_model"
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages,
|
||||
stop_sequences=None,
|
||||
grammar=None,
|
||||
tools_to_call_from=None,
|
||||
):
|
||||
if tools_to_call_from is not None:
|
||||
if len(messages) < 3:
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatMessageToolCall(
|
||||
id="call_0",
|
||||
type="function",
|
||||
function=ChatMessageToolCallDefinition(
|
||||
name="search_agent",
|
||||
arguments="Who is the current US president?",
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
else:
|
||||
assert "Report on the current US president" in str(messages)
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatMessageToolCall(
|
||||
id="call_0",
|
||||
type="function",
|
||||
function=ChatMessageToolCallDefinition(
|
||||
name="final_answer", arguments="Final report."
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
else:
|
||||
if len(messages) < 3:
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: Let's call our search agent.
|
||||
Code:
|
||||
```py
|
||||
result = search_agent("Who is the current US president?")
|
||||
```<end_code>
|
||||
""",
|
||||
)
|
||||
else:
|
||||
assert "Report on the current US president" in str(messages)
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: Let's return the report.
|
||||
Code:
|
||||
```py
|
||||
final_answer("Final report.")
|
||||
```<end_code>
|
||||
""",
|
||||
)
|
||||
|
||||
manager_model = FakeModelMultiagentsManagerAgent()
|
||||
|
||||
class FakeModelMultiagentsManagedAgent:
|
||||
model_id = "fake_model"
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages,
|
||||
tools_to_call_from=None,
|
||||
stop_sequences=None,
|
||||
grammar=None,
|
||||
):
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatMessageToolCall(
|
||||
id="call_0",
|
||||
type="function",
|
||||
function=ChatMessageToolCallDefinition(
|
||||
name="final_answer",
|
||||
arguments="Report on the current US president",
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
managed_model = FakeModelMultiagentsManagedAgent()
|
||||
|
||||
web_agent = ToolCallingAgent(
|
||||
tools=[],
|
||||
model=managed_model,
|
||||
max_steps=10,
|
||||
name="search_agent",
|
||||
description="Runs web searches for you. Give it your request as an argument. Make the request as detailed as needed, you can ask for thorough reports",
|
||||
)
|
||||
|
||||
manager_code_agent = CodeAgent(
|
||||
tools=[],
|
||||
model=manager_model,
|
||||
managed_agents=[web_agent],
|
||||
additional_authorized_imports=["time", "numpy", "pandas"],
|
||||
)
|
||||
|
||||
report = manager_code_agent.run("Fake question.")
|
||||
assert report == "Final report."
|
||||
|
||||
manager_toolcalling_agent = ToolCallingAgent(
|
||||
tools=[],
|
||||
model=manager_model,
|
||||
managed_agents=[web_agent],
|
||||
)
|
||||
|
||||
report = manager_toolcalling_agent.run("Fake question.")
|
||||
assert report == "Final report."
|
||||
|
||||
# Test that visualization works
|
||||
manager_code_agent.visualize()
|
||||
|
||||
def test_code_nontrivial_final_answer_works(self):
|
||||
def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
|
||||
return ChatMessage(
|
||||
|
@ -887,6 +767,191 @@ class TestCodeAgent:
|
|||
assert result == expected_summary
|
||||
|
||||
|
||||
class MultiAgentsTests(unittest.TestCase):
|
||||
def test_multiagents_save(self):
|
||||
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct", max_tokens=2096, temperature=0.5)
|
||||
|
||||
web_agent = ToolCallingAgent(
|
||||
model=model,
|
||||
tools=[DuckDuckGoSearchTool(max_results=2), VisitWebpageTool()],
|
||||
name="web_agent",
|
||||
description="does web searches",
|
||||
)
|
||||
code_agent = CodeAgent(model=model, tools=[], name="useless", description="does nothing in particular")
|
||||
|
||||
agent = CodeAgent(
|
||||
model=model,
|
||||
tools=[],
|
||||
additional_authorized_imports=["pandas", "datetime"],
|
||||
managed_agents=[web_agent, code_agent],
|
||||
)
|
||||
agent.save("agent_export")
|
||||
|
||||
expected_structure = {
|
||||
"managed_agents": {
|
||||
"useless": {"tools": {"files": ["final_answer.py"]}, "files": ["agent.json", "prompts.yaml"]},
|
||||
"web_agent": {
|
||||
"tools": {"files": ["final_answer.py", "visit_webpage.py", "web_search.py"]},
|
||||
"files": ["agent.json", "prompts.yaml"],
|
||||
},
|
||||
},
|
||||
"tools": {"files": ["final_answer.py"]},
|
||||
"files": ["app.py", "requirements.txt", "agent.json", "prompts.yaml"],
|
||||
}
|
||||
|
||||
def verify_structure(current_path: Path, structure: dict):
|
||||
for dir_name, contents in structure.items():
|
||||
if dir_name != "files":
|
||||
# For directories, verify they exist and recurse into them
|
||||
dir_path = current_path / dir_name
|
||||
assert dir_path.exists(), f"Directory {dir_path} does not exist"
|
||||
assert dir_path.is_dir(), f"{dir_path} is not a directory"
|
||||
verify_structure(dir_path, contents)
|
||||
else:
|
||||
# For files, verify each exists in the current path
|
||||
for file_name in contents:
|
||||
file_path = current_path / file_name
|
||||
assert file_path.exists(), f"File {file_path} does not exist"
|
||||
assert file_path.is_file(), f"{file_path} is not a file"
|
||||
|
||||
verify_structure(Path("agent_export"), expected_structure)
|
||||
|
||||
# Test that re-loaded agents work as expected.
|
||||
agent2 = CodeAgent.from_folder("agent_export", planning_interval=5)
|
||||
assert agent2.planning_interval == 5 # Check that kwargs are used
|
||||
assert set(agent2.authorized_imports) == set(["pandas", "datetime"] + BASE_BUILTIN_MODULES)
|
||||
assert (
|
||||
agent2.managed_agents["web_agent"].tools["web_search"].max_results == 10
|
||||
) # For now tool init parameters are forgotten
|
||||
assert agent2.model.kwargs["temperature"] == pytest.approx(0.5)
|
||||
|
||||
def test_multiagents(self):
|
||||
class FakeModelMultiagentsManagerAgent:
|
||||
model_id = "fake_model"
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages,
|
||||
stop_sequences=None,
|
||||
grammar=None,
|
||||
tools_to_call_from=None,
|
||||
):
|
||||
if tools_to_call_from is not None:
|
||||
if len(messages) < 3:
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatMessageToolCall(
|
||||
id="call_0",
|
||||
type="function",
|
||||
function=ChatMessageToolCallDefinition(
|
||||
name="search_agent",
|
||||
arguments="Who is the current US president?",
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
else:
|
||||
assert "Report on the current US president" in str(messages)
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatMessageToolCall(
|
||||
id="call_0",
|
||||
type="function",
|
||||
function=ChatMessageToolCallDefinition(
|
||||
name="final_answer", arguments="Final report."
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
else:
|
||||
if len(messages) < 3:
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: Let's call our search agent.
|
||||
Code:
|
||||
```py
|
||||
result = search_agent("Who is the current US president?")
|
||||
```<end_code>
|
||||
""",
|
||||
)
|
||||
else:
|
||||
assert "Report on the current US president" in str(messages)
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="""
|
||||
Thought: Let's return the report.
|
||||
Code:
|
||||
```py
|
||||
final_answer("Final report.")
|
||||
```<end_code>
|
||||
""",
|
||||
)
|
||||
|
||||
manager_model = FakeModelMultiagentsManagerAgent()
|
||||
|
||||
class FakeModelMultiagentsManagedAgent:
|
||||
model_id = "fake_model"
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages,
|
||||
tools_to_call_from=None,
|
||||
stop_sequences=None,
|
||||
grammar=None,
|
||||
):
|
||||
return ChatMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatMessageToolCall(
|
||||
id="call_0",
|
||||
type="function",
|
||||
function=ChatMessageToolCallDefinition(
|
||||
name="final_answer",
|
||||
arguments="Report on the current US president",
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
managed_model = FakeModelMultiagentsManagedAgent()
|
||||
|
||||
web_agent = ToolCallingAgent(
|
||||
tools=[],
|
||||
model=managed_model,
|
||||
max_steps=10,
|
||||
name="search_agent",
|
||||
description="Runs web searches for you. Give it your request as an argument. Make the request as detailed as needed, you can ask for thorough reports",
|
||||
)
|
||||
|
||||
manager_code_agent = CodeAgent(
|
||||
tools=[],
|
||||
model=manager_model,
|
||||
managed_agents=[web_agent],
|
||||
additional_authorized_imports=["time", "numpy", "pandas"],
|
||||
)
|
||||
|
||||
report = manager_code_agent.run("Fake question.")
|
||||
assert report == "Final report."
|
||||
|
||||
manager_toolcalling_agent = ToolCallingAgent(
|
||||
tools=[],
|
||||
model=manager_model,
|
||||
managed_agents=[web_agent],
|
||||
)
|
||||
|
||||
report = manager_toolcalling_agent.run("Fake question.")
|
||||
assert report == "Final report."
|
||||
|
||||
# Test that visualization works
|
||||
manager_code_agent.visualize()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prompt_templates():
|
||||
return {
|
||||
|
|
|
@ -13,13 +13,15 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from smolagents._function_type_hints_utils import get_json_schema
|
||||
import pytest
|
||||
|
||||
from smolagents._function_type_hints_utils import get_imports, get_json_schema
|
||||
|
||||
|
||||
class AgentTextTests(unittest.TestCase):
|
||||
def test_return_none(self):
|
||||
class TestJsonSchema(unittest.TestCase):
|
||||
def test_get_json_schema(self):
|
||||
def fn(x: int, y: Optional[Tuple[str, str, float]] = None) -> None:
|
||||
"""
|
||||
Test function
|
||||
|
@ -52,3 +54,65 @@ class AgentTextTests(unittest.TestCase):
|
|||
schema["function"]["parameters"]["properties"]["y"], expected_schema["parameters"]["properties"]["y"]
|
||||
)
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
|
||||
class TestGetCode:
|
||||
@pytest.mark.parametrize(
|
||||
"code, expected",
|
||||
[
|
||||
(
|
||||
"""
|
||||
import numpy
|
||||
import pandas
|
||||
""",
|
||||
["numpy", "pandas"],
|
||||
),
|
||||
# From imports
|
||||
(
|
||||
"""
|
||||
from torch import nn
|
||||
from transformers import AutoModel
|
||||
""",
|
||||
["torch", "transformers"],
|
||||
),
|
||||
# Mixed case with nested imports
|
||||
(
|
||||
"""
|
||||
import numpy as np
|
||||
from torch.nn import Linear
|
||||
import os.path
|
||||
""",
|
||||
["numpy", "torch", "os"],
|
||||
),
|
||||
# Try/except block (should be filtered)
|
||||
(
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
pass
|
||||
import numpy
|
||||
""",
|
||||
["numpy"],
|
||||
),
|
||||
# Flash attention block (should be filtered)
|
||||
(
|
||||
"""
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func
|
||||
import transformers
|
||||
""",
|
||||
["transformers"],
|
||||
),
|
||||
# Relative imports (should be excluded)
|
||||
(
|
||||
"""
|
||||
from .utils import helper
|
||||
from ..models import transformer
|
||||
""",
|
||||
[],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_imports(self, code: str, expected: List[str]):
|
||||
assert sorted(get_imports(code)) == sorted(expected)
|
||||
|
|
|
@ -215,8 +215,9 @@ class ToolTests(unittest.TestCase):
|
|||
|
||||
return str(datetime.now())
|
||||
|
||||
def test_saving_tool_allows_no_arg_in_init(self):
|
||||
# Test one cannot save tool with additional args in init
|
||||
def test_tool_to_dict_allows_no_arg_in_init(self):
|
||||
"""Test that a tool cannot be saved with required args in init"""
|
||||
|
||||
class FailTool(Tool):
|
||||
name = "specific"
|
||||
description = "test description"
|
||||
|
@ -225,15 +226,31 @@ class ToolTests(unittest.TestCase):
|
|||
|
||||
def __init__(self, url):
|
||||
super().__init__(self)
|
||||
self.url = "none"
|
||||
self.url = url
|
||||
|
||||
def forward(self, string_input: str) -> str:
|
||||
return self.url + string_input
|
||||
|
||||
fail_tool = FailTool("dummy_url")
|
||||
with pytest.raises(Exception) as e:
|
||||
fail_tool.save("output")
|
||||
assert "__init__" in str(e)
|
||||
fail_tool.to_dict()
|
||||
assert "All parameters of __init__ must have default values!" in str(e)
|
||||
|
||||
class PassTool(Tool):
|
||||
name = "specific"
|
||||
description = "test description"
|
||||
inputs = {"string_input": {"type": "string", "description": "input description"}}
|
||||
output_type = "string"
|
||||
|
||||
def __init__(self, url: Optional[str] = "none"):
|
||||
super().__init__(self)
|
||||
self.url = url
|
||||
|
||||
def forward(self, string_input: str) -> str:
|
||||
return self.url + string_input
|
||||
|
||||
fail_tool = PassTool()
|
||||
fail_tool.to_dict()
|
||||
|
||||
def test_saving_tool_allows_no_imports_from_outside_methods(self):
|
||||
# Test that using imports from outside functions fails
|
||||
|
|
|
@ -146,11 +146,12 @@ def test_e2e_class_tool_save():
|
|||
|
||||
test_tool = TestTool()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
test_tool.save(tmp_dir)
|
||||
test_tool.save(tmp_dir, make_gradio_app=True)
|
||||
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
|
||||
assert (
|
||||
pathlib.Path(tmp_dir, "tool.py").read_text()
|
||||
== """from smolagents.tools import Tool
|
||||
== """from typing import Any, Optional
|
||||
from smolagents.tools import Tool
|
||||
import IPython
|
||||
|
||||
class TestTool(Tool):
|
||||
|
@ -173,7 +174,6 @@ class TestTool(Tool):
|
|||
assert (
|
||||
pathlib.Path(tmp_dir, "app.py").read_text()
|
||||
== """from smolagents import launch_gradio_demo
|
||||
from typing import Optional
|
||||
from tool import TestTool
|
||||
|
||||
tool = TestTool()
|
||||
|
@ -201,13 +201,14 @@ def test_e2e_ipython_class_tool_save():
|
|||
import IPython # noqa: F401
|
||||
|
||||
return task
|
||||
TestTool().save("{tmp_dir}")
|
||||
TestTool().save("{tmp_dir}", make_gradio_app=True)
|
||||
""")
|
||||
assert shell.run_cell(code_blob, store_history=True).success
|
||||
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
|
||||
assert (
|
||||
pathlib.Path(tmp_dir, "tool.py").read_text()
|
||||
== """from smolagents.tools import Tool
|
||||
== """from typing import Any, Optional
|
||||
from smolagents.tools import Tool
|
||||
import IPython
|
||||
|
||||
class TestTool(Tool):
|
||||
|
@ -230,7 +231,6 @@ class TestTool(Tool):
|
|||
assert (
|
||||
pathlib.Path(tmp_dir, "app.py").read_text()
|
||||
== """from smolagents import launch_gradio_demo
|
||||
from typing import Optional
|
||||
from tool import TestTool
|
||||
|
||||
tool = TestTool()
|
||||
|
@ -254,12 +254,12 @@ def test_e2e_function_tool_save():
|
|||
return task
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
test_tool.save(tmp_dir)
|
||||
test_tool.save(tmp_dir, make_gradio_app=True)
|
||||
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
|
||||
assert (
|
||||
pathlib.Path(tmp_dir, "tool.py").read_text()
|
||||
== """from smolagents import Tool
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
class SimpleTool(Tool):
|
||||
name = "test_tool"
|
||||
|
@ -283,7 +283,6 @@ class SimpleTool(Tool):
|
|||
assert (
|
||||
pathlib.Path(tmp_dir, "app.py").read_text()
|
||||
== """from smolagents import launch_gradio_demo
|
||||
from typing import Optional
|
||||
from tool import SimpleTool
|
||||
|
||||
tool = SimpleTool()
|
||||
|
@ -311,14 +310,14 @@ def test_e2e_ipython_function_tool_save():
|
|||
|
||||
return task
|
||||
|
||||
test_tool.save("{tmp_dir}")
|
||||
test_tool.save("{tmp_dir}", make_gradio_app=True)
|
||||
""")
|
||||
assert shell.run_cell(code_blob, store_history=True).success
|
||||
assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"}
|
||||
assert (
|
||||
pathlib.Path(tmp_dir, "tool.py").read_text()
|
||||
== """from smolagents import Tool
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
class SimpleTool(Tool):
|
||||
name = "test_tool"
|
||||
|
@ -342,7 +341,6 @@ class SimpleTool(Tool):
|
|||
assert (
|
||||
pathlib.Path(tmp_dir, "app.py").read_text()
|
||||
== """from smolagents import launch_gradio_demo
|
||||
from typing import Optional
|
||||
from tool import SimpleTool
|
||||
|
||||
tool = SimpleTool()