Add linter rules + apply make style (#255)

* Add linter rules + apply make style
Lucain 2025-01-18 19:01:15 +01:00 committed by GitHub
parent 5aa0f2b53d
commit 6e1373a324
32 changed files with 378 additions and 944 deletions


@ -181,6 +181,7 @@
"import datasets\n", "import datasets\n",
"import pandas as pd\n", "import pandas as pd\n",
"\n", "\n",
"\n",
"eval_ds = datasets.load_dataset(\"m-ric/smol_agents_benchmark\")[\"test\"]\n", "eval_ds = datasets.load_dataset(\"m-ric/smol_agents_benchmark\")[\"test\"]\n",
"pd.DataFrame(eval_ds)" "pd.DataFrame(eval_ds)"
] ]
@ -199,26 +200,28 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import time\n",
"import json\n", "import json\n",
"import os\n", "import os\n",
"import re\n", "import re\n",
"import string\n", "import string\n",
"import time\n",
"import warnings\n", "import warnings\n",
"from tqdm import tqdm\n",
"from typing import List\n", "from typing import List\n",
"\n", "\n",
"from dotenv import load_dotenv\n",
"from tqdm import tqdm\n",
"\n",
"from smolagents import (\n", "from smolagents import (\n",
" GoogleSearchTool,\n",
" CodeAgent,\n",
" ToolCallingAgent,\n",
" HfApiModel,\n",
" AgentError,\n", " AgentError,\n",
" VisitWebpageTool,\n", " CodeAgent,\n",
" GoogleSearchTool,\n",
" HfApiModel,\n",
" PythonInterpreterTool,\n", " PythonInterpreterTool,\n",
" ToolCallingAgent,\n",
" VisitWebpageTool,\n",
")\n", ")\n",
"from smolagents.agents import ActionStep\n", "from smolagents.agents import ActionStep\n",
"from dotenv import load_dotenv\n", "\n",
"\n", "\n",
"load_dotenv()\n", "load_dotenv()\n",
"os.makedirs(\"output\", exist_ok=True)\n", "os.makedirs(\"output\", exist_ok=True)\n",
@ -231,9 +234,7 @@
" return str(obj)\n", " return str(obj)\n",
"\n", "\n",
"\n", "\n",
"def answer_questions(\n", "def answer_questions(eval_ds, file_name, agent, model_id, action_type, is_vanilla_llm=False):\n",
" eval_ds, file_name, agent, model_id, action_type, is_vanilla_llm=False\n",
"):\n",
" answered_questions = []\n", " answered_questions = []\n",
" if os.path.exists(file_name):\n", " if os.path.exists(file_name):\n",
" with open(file_name, \"r\") as f:\n", " with open(file_name, \"r\") as f:\n",
@ -365,23 +366,18 @@
" ma_elems = split_string(model_answer)\n", " ma_elems = split_string(model_answer)\n",
"\n", "\n",
" if len(gt_elems) != len(ma_elems): # check length is the same\n", " if len(gt_elems) != len(ma_elems): # check length is the same\n",
" warnings.warn(\n", " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
" \"Answer lists have different lengths, returning False.\", UserWarning\n",
" )\n",
" return False\n", " return False\n",
"\n", "\n",
" comparisons = []\n", " comparisons = []\n",
" for ma_elem, gt_elem in zip(\n", " for ma_elem, gt_elem in zip(ma_elems, gt_elems): # compare each element as float or str\n",
" ma_elems, gt_elems\n",
" ): # compare each element as float or str\n",
" if is_float(gt_elem):\n", " if is_float(gt_elem):\n",
" normalized_ma_elem = normalize_number_str(ma_elem)\n", " normalized_ma_elem = normalize_number_str(ma_elem)\n",
" comparisons.append(normalized_ma_elem == float(gt_elem))\n", " comparisons.append(normalized_ma_elem == float(gt_elem))\n",
" else:\n", " else:\n",
" # we do not remove punct since comparisons can include punct\n", " # we do not remove punct since comparisons can include punct\n",
" comparisons.append(\n", " comparisons.append(\n",
" normalize_str(ma_elem, remove_punct=False)\n", " normalize_str(ma_elem, remove_punct=False) == normalize_str(gt_elem, remove_punct=False)\n",
" == normalize_str(gt_elem, remove_punct=False)\n",
" )\n", " )\n",
" return all(comparisons)\n", " return all(comparisons)\n",
"\n", "\n",
@ -441,9 +437,7 @@
" action_type = \"vanilla\"\n", " action_type = \"vanilla\"\n",
" llm = HfApiModel(model_id)\n", " llm = HfApiModel(model_id)\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", " file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" answer_questions(\n", " answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)"
" eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True\n",
" )"
] ]
}, },
{ {
@ -461,6 +455,7 @@
"source": [ "source": [
"from smolagents import LiteLLMModel\n", "from smolagents import LiteLLMModel\n",
"\n", "\n",
"\n",
"litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n", "litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n",
"\n", "\n",
"for model_id in litellm_model_ids:\n", "for model_id in litellm_model_ids:\n",
@ -492,9 +487,7 @@
" action_type = \"vanilla\"\n", " action_type = \"vanilla\"\n",
" llm = LiteLLMModel(model_id)\n", " llm = LiteLLMModel(model_id)\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", " file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" answer_questions(\n", " answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)"
" eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True\n",
" )"
] ]
}, },
{ {
@ -556,9 +549,11 @@
} }
], ],
"source": [ "source": [
"import pandas as pd\n",
"import glob\n", "import glob\n",
"\n", "\n",
"import pandas as pd\n",
"\n",
"\n",
"res = []\n", "res = []\n",
"for file_path in glob.glob(\"output/*.jsonl\"):\n", "for file_path in glob.glob(\"output/*.jsonl\"):\n",
" data = []\n", " data = []\n",
@ -595,11 +590,7 @@
"\n", "\n",
"result_df[\"correct\"] = result_df.apply(get_correct, axis=1)\n", "result_df[\"correct\"] = result_df.apply(get_correct, axis=1)\n",
"\n", "\n",
"result_df = (\n", "result_df = (result_df.groupby([\"model_id\", \"source\", \"action_type\"])[[\"correct\"]].mean() * 100).round(1).reset_index()"
" (result_df.groupby([\"model_id\", \"source\", \"action_type\"])[[\"correct\"]].mean() * 100)\n",
" .round(1)\n",
" .reset_index()\n",
")"
] ]
}, },
{ {
@ -895,6 +886,7 @@
"import pandas as pd\n", "import pandas as pd\n",
"from matplotlib.legend_handler import HandlerTuple # Added import\n", "from matplotlib.legend_handler import HandlerTuple # Added import\n",
"\n", "\n",
"\n",
"# Assuming pivot_df is your original dataframe\n", "# Assuming pivot_df is your original dataframe\n",
"models = pivot_df[\"model_id\"].unique()\n", "models = pivot_df[\"model_id\"].unique()\n",
"sources = pivot_df[\"source\"].unique()\n", "sources = pivot_df[\"source\"].unique()\n",
@ -961,14 +953,10 @@
"handles, labels = ax.get_legend_handles_labels()\n", "handles, labels = ax.get_legend_handles_labels()\n",
"unique_sources = sources\n", "unique_sources = sources\n",
"legend_elements = [\n", "legend_elements = [\n",
" (handles[i * 2], handles[i * 2 + 1], labels[i * 2].replace(\" (Agent)\", \"\"))\n", " (handles[i * 2], handles[i * 2 + 1], labels[i * 2].replace(\" (Agent)\", \"\")) for i in range(len(unique_sources))\n",
" for i in range(len(unique_sources))\n",
"]\n", "]\n",
"custom_legend = ax.legend(\n", "custom_legend = ax.legend(\n",
" [\n", " [(agent_handle, vanilla_handle) for agent_handle, vanilla_handle, _ in legend_elements],\n",
" (agent_handle, vanilla_handle)\n",
" for agent_handle, vanilla_handle, _ in legend_elements\n",
" ],\n",
" [label for _, _, label in legend_elements],\n", " [label for _, _, label in legend_elements],\n",
" handler_map={tuple: HandlerTuple(ndivide=None)},\n", " handler_map={tuple: HandlerTuple(ndivide=None)},\n",
" bbox_to_anchor=(1.05, 1),\n", " bbox_to_anchor=(1.05, 1),\n",
@ -1006,9 +994,7 @@
" # Start the matrix environment with 4 columns\n", " # Start the matrix environment with 4 columns\n",
" # l for left-aligned model and task, c for centered numbers\n", " # l for left-aligned model and task, c for centered numbers\n",
" mathjax_table = \"\\\\begin{array}{llcc}\\n\"\n", " mathjax_table = \"\\\\begin{array}{llcc}\\n\"\n",
" mathjax_table += (\n", " mathjax_table += \"\\\\text{Model} & \\\\text{Task} & \\\\text{Agent} & \\\\text{Vanilla} \\\\\\\\\\n\"\n",
" \"\\\\text{Model} & \\\\text{Task} & \\\\text{Agent} & \\\\text{Vanilla} \\\\\\\\\\n\"\n",
" )\n",
" mathjax_table += \"\\\\hline\\n\"\n", " mathjax_table += \"\\\\hline\\n\"\n",
"\n", "\n",
" # Sort the DataFrame by model_id and source\n", " # Sort the DataFrame by model_id and source\n",
@ -1033,9 +1019,7 @@
" model_display = \"\\\\;\"\n", " model_display = \"\\\\;\"\n",
"\n", "\n",
" # Add the data row\n", " # Add the data row\n",
" mathjax_table += (\n", " mathjax_table += f\"{model_display} & {source} & {row['agent']} & {row['vanilla']} \\\\\\\\\\n\"\n",
" f\"{model_display} & {source} & {row['agent']} & {row['vanilla']} \\\\\\\\\\n\"\n",
" )\n",
"\n", "\n",
" current_model = model\n", " current_model = model\n",
"\n", "\n",


@ -1,7 +1,9 @@
from smolagents import Tool, CodeAgent, HfApiModel
from smolagents.default_tools import VisitWebpageTool
from dotenv import load_dotenv from dotenv import load_dotenv
from smolagents import CodeAgent, HfApiModel, Tool
from smolagents.default_tools import VisitWebpageTool
load_dotenv() load_dotenv()
@ -16,10 +18,11 @@ class GetCatImageTool(Tool):
self.url = "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png" self.url = "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"
def forward(self): def forward(self):
from PIL import Image
import requests
from io import BytesIO from io import BytesIO
import requests
from PIL import Image
response = requests.get(self.url) response = requests.get(self.url)
return Image.open(BytesIO(response.content)) return Image.open(BytesIO(response.content))
@ -46,4 +49,5 @@ agent.run(
# Try the agent in a Gradio UI # Try the agent in a Gradio UI
from smolagents import GradioUI from smolagents import GradioUI
GradioUI(agent).launch() GradioUI(agent).launch()


@ -1,4 +1,5 @@
from smolagents import CodeAgent, HfApiModel, GradioUI from smolagents import CodeAgent, GradioUI, HfApiModel
agent = CodeAgent(tools=[], model=HfApiModel(), max_steps=4, verbosity_level=1) agent = CodeAgent(tools=[], model=HfApiModel(), max_steps=4, verbosity_level=1)


@ -1,24 +1,22 @@
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from smolagents import ( from smolagents import (
CodeAgent, CodeAgent,
DuckDuckGoSearchTool, DuckDuckGoSearchTool,
VisitWebpageTool, HfApiModel,
ManagedAgent, ManagedAgent,
ToolCallingAgent, ToolCallingAgent,
HfApiModel, VisitWebpageTool,
) )
# Let's setup the instrumentation first # Let's setup the instrumentation first
trace_provider = TracerProvider() trace_provider = TracerProvider()
trace_provider.add_span_processor( trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter("http://0.0.0.0:6006/v1/traces")))
SimpleSpanProcessor(OTLPSpanExporter("http://0.0.0.0:6006/v1/traces"))
)
SmolagentsInstrumentor().instrument(tracer_provider=trace_provider, skip_dep_check=True) SmolagentsInstrumentor().instrument(tracer_provider=trace_provider, skip_dep_check=True)
@ -39,6 +37,4 @@ manager_agent = CodeAgent(
model=model, model=model,
managed_agents=[managed_agent], managed_agents=[managed_agent],
) )
manager_agent.run( manager_agent.run("If the US keeps it 2024 growth rate, how many years would it take for the GDP to double?")
"If the US keeps it 2024 growth rate, how many years would it take for the GDP to double?"
)


@ -8,13 +8,10 @@ from langchain_community.retrievers import BM25Retriever
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train") knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
knowledge_base = knowledge_base.filter( knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))
lambda row: row["source"].startswith("huggingface/transformers")
)
source_docs = [ source_docs = [
Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) for doc in knowledge_base
for doc in knowledge_base
] ]
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
@ -51,14 +48,12 @@ class RetrieverTool(Tool):
query, query,
) )
return "\nRetrieved documents:\n" + "".join( return "\nRetrieved documents:\n" + "".join(
[ [f"\n\n===== Document {str(i)} =====\n" + doc.page_content for i, doc in enumerate(docs)]
f"\n\n===== Document {str(i)} =====\n" + doc.page_content
for i, doc in enumerate(docs)
]
) )
from smolagents import HfApiModel, CodeAgent from smolagents import CodeAgent, HfApiModel
retriever_tool = RetrieverTool(docs_processed) retriever_tool = RetrieverTool(docs_processed)
agent = CodeAgent( agent = CodeAgent(
@ -68,9 +63,7 @@ agent = CodeAgent(
verbosity_level=2, verbosity_level=2,
) )
agent_output = agent.run( agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")
"For a transformers model training, which is slower, the forward or the backward pass?"
)
print("Final output:") print("Final output:")
print(agent_output) print(agent_output)


@ -1,16 +1,17 @@
from sqlalchemy import ( from sqlalchemy import (
create_engine,
MetaData,
Table,
Column, Column,
String,
Integer,
Float, Float,
Integer,
MetaData,
String,
Table,
create_engine,
insert, insert,
inspect, inspect,
text, text,
) )
engine = create_engine("sqlite:///:memory:") engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData() metadata_obj = MetaData()
@ -40,9 +41,7 @@ for row in rows:
inspector = inspect(engine) inspector = inspect(engine)
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")] columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]
table_description = "Columns:\n" + "\n".join( table_description = "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
[f" - {name}: {col_type}" for name, col_type in columns_info]
)
print(table_description) print(table_description)
from smolagents import tool from smolagents import tool
@ -72,6 +71,7 @@ def sql_engine(query: str) -> str:
from smolagents import CodeAgent, HfApiModel from smolagents import CodeAgent, HfApiModel
agent = CodeAgent( agent = CodeAgent(
tools=[sql_engine], tools=[sql_engine],
model=HfApiModel("meta-llama/Meta-Llama-3.1-8B-Instruct"), model=HfApiModel("meta-llama/Meta-Llama-3.1-8B-Instruct"),


@ -1,7 +1,9 @@
from smolagents.agents import ToolCallingAgent
from smolagents import tool, LiteLLMModel
from typing import Optional from typing import Optional
from smolagents import LiteLLMModel, tool
from smolagents.agents import ToolCallingAgent
# Choose which LLM engine to use! # Choose which LLM engine to use!
# model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct") # model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
# model = TransformersModel(model_id="meta-llama/Llama-3.2-2B-Instruct") # model = TransformersModel(model_id="meta-llama/Llama-3.2-2B-Instruct")


@ -13,8 +13,10 @@ Usage:
import os import os
from mcp import StdioServerParameters from mcp import StdioServerParameters
from smolagents import CodeAgent, HfApiModel, ToolCollection from smolagents import CodeAgent, HfApiModel, ToolCollection
mcp_server_params = StdioServerParameters( mcp_server_params = StdioServerParameters(
command="uvx", command="uvx",
args=["--quiet", "pubmedmcp@0.1.3"], args=["--quiet", "pubmedmcp@0.1.3"],


@ -1,7 +1,9 @@
from smolagents.agents import ToolCallingAgent
from smolagents import tool, LiteLLMModel
from typing import Optional from typing import Optional
from smolagents import LiteLLMModel, tool
from smolagents.agents import ToolCallingAgent
model = LiteLLMModel( model = LiteLLMModel(
model_id="ollama_chat/llama3.2", model_id="ollama_chat/llama3.2",
api_base="http://localhost:11434", # replace with remote open-ai compatible server if necessary api_base="http://localhost:11434", # replace with remote open-ai compatible server if necessary


@@ -60,9 +60,18 @@ dev = [
 addopts = "-sv --durations=0"
 
 [tool.ruff]
-lint.ignore = ["F403"]
+line-length = 119
+lint.ignore = [
+    "F403",  # undefined-local-with-import-star
+    "E501",  # line-too-long
+]
+lint.select = ["E", "F", "I", "W"]
 
 [tool.ruff.lint.per-file-ignores]
 "examples/*" = [
     "E402",  # module-import-not-at-top-of-file
 ]
+
+[tool.ruff.lint.isort]
+known-first-party = ["smolagents"]
+lines-after-imports = 2
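
For context, the new [tool.ruff] block selects the pycodestyle (E, W), pyflakes (F) and isort (I) rule sets, sets the formatter line length to 119 characters (while ignoring E501 so long lines are not flagged by the linter), and makes the import sorter treat smolagents as first-party with two blank lines after imports. A minimal sketch of how such a configuration is usually applied locally — assuming, as the commit title suggests but this page does not show, that the repo's make style target simply wraps the standard ruff commands:

    ruff check --fix .   # apply the selected lint rules (E, F, I, W), including import sorting
    ruff format .        # reformat code to the configured 119-character line length

The same commands without --fix, and ruff format --check, give the read-only variant typically used for CI checks.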


@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
from transformers.utils import _LazyModule from transformers.utils import _LazyModule
from transformers.utils.import_utils import define_import_structure from transformers.utils.import_utils import define_import_structure
if TYPE_CHECKING: if TYPE_CHECKING:
from .agents import * from .agents import *
from .default_tools import * from .default_tools import *


@ -16,18 +16,17 @@
# limitations under the License. # limitations under the License.
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from enum import IntEnum
from rich import box from rich import box
from rich.console import Group from rich.console import Console, Group
from rich.panel import Panel from rich.panel import Panel
from rich.rule import Rule from rich.rule import Rule
from rich.syntax import Syntax from rich.syntax import Syntax
from rich.text import Text from rich.text import Text
from rich.console import Console
from .default_tools import FinalAnswerTool, TOOL_MAPPING from .default_tools import TOOL_MAPPING, FinalAnswerTool
from .e2b_executor import E2BExecutor from .e2b_executor import E2BExecutor
from .local_python_executor import ( from .local_python_executor import (
BASE_BUILTIN_MODULES, BASE_BUILTIN_MODULES,
@ -112,20 +111,11 @@ class SystemPromptStep(AgentStepLog):
system_prompt: str system_prompt: str
def get_tool_descriptions( def get_tool_descriptions(tools: Dict[str, Tool], tool_description_template: str) -> str:
tools: Dict[str, Tool], tool_description_template: str return "\n".join([get_tool_description_with_args(tool, tool_description_template) for tool in tools.values()])
) -> str:
return "\n".join(
[
get_tool_description_with_args(tool, tool_description_template)
for tool in tools.values()
]
)
def format_prompt_with_tools( def format_prompt_with_tools(tools: Dict[str, Tool], prompt_template: str, tool_description_template: str) -> str:
tools: Dict[str, Tool], prompt_template: str, tool_description_template: str
) -> str:
tool_descriptions = get_tool_descriptions(tools, tool_description_template) tool_descriptions = get_tool_descriptions(tools, tool_description_template)
prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions) prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
if "{{tool_names}}" in prompt: if "{{tool_names}}" in prompt:
@ -159,9 +149,7 @@ def format_prompt_with_managed_agents_descriptions(
f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'" f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'"
) )
if len(managed_agents.keys()) > 0: if len(managed_agents.keys()) > 0:
return prompt_template.replace( return prompt_template.replace(agent_descriptions_placeholder, show_agents_descriptions(managed_agents))
agent_descriptions_placeholder, show_agents_descriptions(managed_agents)
)
else: else:
return prompt_template.replace(agent_descriptions_placeholder, "") return prompt_template.replace(agent_descriptions_placeholder, "")
@ -214,9 +202,7 @@ class MultiStepAgent:
self.model = model self.model = model
self.system_prompt_template = system_prompt self.system_prompt_template = system_prompt
self.tool_description_template = ( self.tool_description_template = (
tool_description_template tool_description_template if tool_description_template else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
if tool_description_template
else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
) )
self.max_steps = max_steps self.max_steps = max_steps
self.tool_parser = tool_parser self.tool_parser = tool_parser
@ -231,10 +217,7 @@ class MultiStepAgent:
self.tools = {tool.name: tool for tool in tools} self.tools = {tool.name: tool for tool in tools}
if add_base_tools: if add_base_tools:
for tool_name, tool_class in TOOL_MAPPING.items(): for tool_name, tool_class in TOOL_MAPPING.items():
if ( if tool_name != "python_interpreter" or self.__class__.__name__ == "ToolCallingAgent":
tool_name != "python_interpreter"
or self.__class__.__name__ == "ToolCallingAgent"
):
self.tools[tool_name] = tool_class() self.tools[tool_name] = tool_class()
self.tools["final_answer"] = FinalAnswerTool() self.tools["final_answer"] = FinalAnswerTool()
@ -253,15 +236,11 @@ class MultiStepAgent:
self.system_prompt_template, self.system_prompt_template,
self.tool_description_template, self.tool_description_template,
) )
self.system_prompt = format_prompt_with_managed_agents_descriptions( self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
self.system_prompt, self.managed_agents
)
return self.system_prompt return self.system_prompt
def write_inner_memory_from_logs( def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
self, summary_mode: Optional[bool] = False
) -> List[Dict[str, str]]:
""" """
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
that can be used as input to the LLM. that can be used as input to the LLM.
@ -355,10 +334,7 @@ class MultiStepAgent:
return memory return memory
def get_succinct_logs(self): def get_succinct_logs(self):
return [ return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs]
{key: value for key, value in log.items() if key != "agent_memory"}
for log in self.logs
]
def extract_action(self, llm_output: str, split_token: str) -> Tuple[str, str]: def extract_action(self, llm_output: str, split_token: str) -> Tuple[str, str]:
""" """
@ -402,9 +378,7 @@ class MultiStepAgent:
except Exception as e: except Exception as e:
return f"Error in generating final LLM output:\n{e}" return f"Error in generating final LLM output:\n{e}"
def execute_tool_call( def execute_tool_call(self, tool_name: str, arguments: Union[Dict[str, str], str]) -> Any:
self, tool_name: str, arguments: Union[Dict[str, str], str]
) -> Any:
""" """
Execute tool with the provided input and returns the result. Execute tool with the provided input and returns the result.
This method replaces arguments with the actual values from the state if they refer to state variables. This method replaces arguments with the actual values from the state if they refer to state variables.
@ -423,9 +397,7 @@ class MultiStepAgent:
if tool_name in self.managed_agents: if tool_name in self.managed_agents:
observation = available_tools[tool_name].__call__(arguments) observation = available_tools[tool_name].__call__(arguments)
else: else:
observation = available_tools[tool_name].__call__( observation = available_tools[tool_name].__call__(arguments, sanitize_inputs_outputs=True)
arguments, sanitize_inputs_outputs=True
)
elif isinstance(arguments, dict): elif isinstance(arguments, dict):
for key, value in arguments.items(): for key, value in arguments.items():
if isinstance(value, str) and value in self.state: if isinstance(value, str) and value in self.state:
@ -433,18 +405,14 @@ class MultiStepAgent:
if tool_name in self.managed_agents: if tool_name in self.managed_agents:
observation = available_tools[tool_name].__call__(**arguments) observation = available_tools[tool_name].__call__(**arguments)
else: else:
observation = available_tools[tool_name].__call__( observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
**arguments, sanitize_inputs_outputs=True
)
else: else:
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}." error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg)
return observation return observation
except Exception as e: except Exception as e:
if tool_name in self.tools: if tool_name in self.tools:
tool_description = get_tool_description_with_args( tool_description = get_tool_description_with_args(available_tools[tool_name])
available_tools[tool_name]
)
error_msg = ( error_msg = (
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n" f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
f"As a reminder, this tool's description is the following:\n{tool_description}" f"As a reminder, this tool's description is the following:\n{tool_description}"
@ -544,10 +512,7 @@ You have been provided with these additional arguments, that you can access usin
step_start_time = time.time() step_start_time = time.time()
step_log = ActionStep(step=self.step_number, start_time=step_start_time) step_log = ActionStep(step=self.step_number, start_time=step_start_time)
try: try:
if ( if self.planning_interval is not None and self.step_number % self.planning_interval == 0:
self.planning_interval is not None
and self.step_number % self.planning_interval == 0
):
self.planning_step( self.planning_step(
task, task,
is_first_step=(self.step_number == 0), is_first_step=(self.step_number == 0),
@ -600,10 +565,7 @@ You have been provided with these additional arguments, that you can access usin
step_start_time = time.time() step_start_time = time.time()
step_log = ActionStep(step=self.step_number, start_time=step_start_time) step_log = ActionStep(step=self.step_number, start_time=step_start_time)
try: try:
if ( if self.planning_interval is not None and self.step_number % self.planning_interval == 0:
self.planning_interval is not None
and self.step_number % self.planning_interval == 0
):
self.planning_step( self.planning_step(
task, task,
is_first_step=(self.step_number == 0), is_first_step=(self.step_number == 0),
@ -668,9 +630,7 @@ You have been provided with these additional arguments, that you can access usin
Now begin!""", Now begin!""",
} }
answer_facts = self.model( answer_facts = self.model([message_prompt_facts, message_prompt_task]).content
[message_prompt_facts, message_prompt_task]
).content
message_system_prompt_plan = { message_system_prompt_plan = {
"role": MessageRole.SYSTEM, "role": MessageRole.SYSTEM,
@ -680,12 +640,8 @@ Now begin!""",
"role": MessageRole.USER, "role": MessageRole.USER,
"content": USER_PROMPT_PLAN.format( "content": USER_PROMPT_PLAN.format(
task=task, task=task,
tool_descriptions=get_tool_descriptions( tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template),
self.tools, self.tool_description_template managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)),
),
managed_agents_descriptions=(
show_agents_descriptions(self.managed_agents)
),
answer_facts=answer_facts, answer_facts=answer_facts,
), ),
} }
@ -702,9 +658,7 @@ Now begin!""",
``` ```
{answer_facts} {answer_facts}
```""".strip() ```""".strip()
self.logs.append( self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
)
self.logger.log( self.logger.log(
Rule("[bold]Initial plan", style="orange"), Rule("[bold]Initial plan", style="orange"),
Text(final_plan_redaction), Text(final_plan_redaction),
@ -724,9 +678,7 @@ Now begin!""",
"role": MessageRole.USER, "role": MessageRole.USER,
"content": USER_PROMPT_FACTS_UPDATE, "content": USER_PROMPT_FACTS_UPDATE,
} }
facts_update = self.model( facts_update = self.model([facts_update_system_prompt] + agent_memory + [facts_update_message]).content
[facts_update_system_prompt] + agent_memory + [facts_update_message]
).content
# Redact updated plan # Redact updated plan
plan_update_message = { plan_update_message = {
@ -737,12 +689,8 @@ Now begin!""",
"role": MessageRole.USER, "role": MessageRole.USER,
"content": USER_PROMPT_PLAN_UPDATE.format( "content": USER_PROMPT_PLAN_UPDATE.format(
task=task, task=task,
tool_descriptions=get_tool_descriptions( tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template),
self.tools, self.tool_description_template managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)),
),
managed_agents_descriptions=(
show_agents_descriptions(self.managed_agents)
),
facts_update=facts_update, facts_update=facts_update,
remaining_steps=(self.max_steps - step), remaining_steps=(self.max_steps - step),
), ),
@ -753,16 +701,12 @@ Now begin!""",
).content ).content
# Log final facts and plan # Log final facts and plan
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format( final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
task=task, plan_update=plan_update
)
final_facts_redaction = f"""Here is the updated list of the facts that I know: final_facts_redaction = f"""Here is the updated list of the facts that I know:
``` ```
{facts_update} {facts_update}
```""" ```"""
self.logs.append( self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
)
self.logger.log( self.logger.log(
Rule("[bold]Updated plan", style="orange"), Rule("[bold]Updated plan", style="orange"),
Text(final_plan_redaction), Text(final_plan_redaction),
@ -816,19 +760,13 @@ class ToolCallingAgent(MultiStepAgent):
tool_arguments = tool_call.function.arguments tool_arguments = tool_call.function.arguments
except Exception as e: except Exception as e:
raise AgentGenerationError( raise AgentGenerationError(f"Error in generating tool call with model:\n{e}")
f"Error in generating tool call with model:\n{e}"
)
log_entry.tool_calls = [ log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)
]
# Execute # Execute
self.logger.log( self.logger.log(
Panel( Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")),
Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")
),
level=LogLevel.INFO, level=LogLevel.INFO,
) )
if tool_name == "final_answer": if tool_name == "final_answer":
@ -900,16 +838,10 @@ class CodeAgent(MultiStepAgent):
if system_prompt is None: if system_prompt is None:
system_prompt = CODE_SYSTEM_PROMPT system_prompt = CODE_SYSTEM_PROMPT
self.additional_authorized_imports = ( self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
additional_authorized_imports if additional_authorized_imports else [] self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
)
self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
)
if "{{authorized_imports}}" not in system_prompt: if "{{authorized_imports}}" not in system_prompt:
raise AgentError( raise AgentError("Tag '{{authorized_imports}}' should be provided in the prompt.")
"Tag '{{authorized_imports}}' should be provided in the prompt."
)
super().__init__( super().__init__(
tools=tools, tools=tools,
model=model, model=model,
@ -966,9 +898,7 @@ class CodeAgent(MultiStepAgent):
log_entry.agent_memory = agent_memory.copy() log_entry.agent_memory = agent_memory.copy()
try: try:
additional_args = ( additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
{"grammar": self.grammar} if self.grammar is not None else {}
)
llm_output = self.model( llm_output = self.model(
self.input_messages, self.input_messages,
stop_sequences=["<end_code>", "Observation:"], stop_sequences=["<end_code>", "Observation:"],
@ -999,9 +929,7 @@ class CodeAgent(MultiStepAgent):
try: try:
code_action = fix_final_answer_code(parse_code_blobs(llm_output)) code_action = fix_final_answer_code(parse_code_blobs(llm_output))
except Exception as e: except Exception as e:
error_msg = ( error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
)
raise AgentParsingError(error_msg) raise AgentParsingError(error_msg)
log_entry.tool_calls = [ log_entry.tool_calls = [
@ -1088,17 +1016,13 @@ class ManagedAgent:
self.description = description self.description = description
self.additional_prompting = additional_prompting self.additional_prompting = additional_prompting
self.provide_run_summary = provide_run_summary self.provide_run_summary = provide_run_summary
self.managed_agent_prompt = ( self.managed_agent_prompt = managed_agent_prompt if managed_agent_prompt else MANAGED_AGENT_PROMPT
managed_agent_prompt if managed_agent_prompt else MANAGED_AGENT_PROMPT
)
def write_full_task(self, task): def write_full_task(self, task):
"""Adds additional prompting for the managed agent, like 'add more detail in your answer'.""" """Adds additional prompting for the managed agent, like 'add more detail in your answer'."""
full_task = self.managed_agent_prompt.format(name=self.name, task=task) full_task = self.managed_agent_prompt.format(name=self.name, task=task)
if self.additional_prompting: if self.additional_prompting:
full_task = full_task.replace( full_task = full_task.replace("\n{{additional_prompting}}", self.additional_prompting).strip()
"\n{{additional_prompting}}", self.additional_prompting
).strip()
else: else:
full_task = full_task.replace("\n{{additional_prompting}}", "").strip() full_task = full_task.replace("\n{{additional_prompting}}", "").strip()
return full_task return full_task
@ -1107,9 +1031,7 @@ class ManagedAgent:
full_task = self.write_full_task(request) full_task = self.write_full_task(request)
output = self.agent.run(full_task, **kwargs) output = self.agent.run(full_task, **kwargs)
if self.provide_run_summary: if self.provide_run_summary:
answer = ( answer = f"Here is the final answer from your managed agent '{self.name}':\n"
f"Here is the final answer from your managed agent '{self.name}':\n"
)
answer += str(output) answer += str(output)
answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n" answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
for message in self.agent.write_inner_memory_from_logs(summary_mode=True): for message in self.agent.write_inner_memory_from_logs(summary_mode=True):


@ -20,8 +20,6 @@ from dataclasses import dataclass
from typing import Dict, Optional from typing import Dict, Optional
from huggingface_hub import hf_hub_download, list_spaces from huggingface_hub import hf_hub_download, list_spaces
from transformers.utils import is_offline_mode, is_torch_available from transformers.utils import is_offline_mode, is_torch_available
from .local_python_executor import ( from .local_python_executor import (
@ -32,6 +30,7 @@ from .local_python_executor import (
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
from .types import AgentAudio from .types import AgentAudio
if is_torch_available(): if is_torch_available():
from transformers.models.whisper import ( from transformers.models.whisper import (
WhisperForConditionalGeneration, WhisperForConditionalGeneration,
@ -61,9 +60,7 @@ def get_remote_tools(logger, organization="huggingface-tools"):
tools = {} tools = {}
for space_info in spaces: for space_info in spaces:
repo_id = space_info.id repo_id = space_info.id
resolved_config_file = hf_hub_download( resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
repo_id, TOOL_CONFIG_FILE, repo_type="space"
)
with open(resolved_config_file, encoding="utf-8") as reader: with open(resolved_config_file, encoding="utf-8") as reader:
config = json.load(reader) config = json.load(reader)
task = repo_id.split("/")[-1] task = repo_id.split("/")[-1]
@ -94,9 +91,7 @@ class PythonInterpreterTool(Tool):
if authorized_imports is None: if authorized_imports is None:
self.authorized_imports = list(set(BASE_BUILTIN_MODULES)) self.authorized_imports = list(set(BASE_BUILTIN_MODULES))
else: else:
self.authorized_imports = list( self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(authorized_imports))
set(BASE_BUILTIN_MODULES) | set(authorized_imports)
)
self.inputs = { self.inputs = {
"code": { "code": {
"type": "string", "type": "string",
@ -126,9 +121,7 @@ class PythonInterpreterTool(Tool):
class FinalAnswerTool(Tool): class FinalAnswerTool(Tool):
name = "final_answer" name = "final_answer"
description = "Provides a final answer to the given problem." description = "Provides a final answer to the given problem."
inputs = { inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
"answer": {"type": "any", "description": "The final answer to the problem"}
}
output_type = "any" output_type = "any"
def forward(self, answer): def forward(self, answer):
@ -138,9 +131,7 @@ class FinalAnswerTool(Tool):
class UserInputTool(Tool): class UserInputTool(Tool):
name = "user_input" name = "user_input"
description = "Asks for user's input on a specific question" description = "Asks for user's input on a specific question"
inputs = { inputs = {"question": {"type": "string", "description": "The question to ask the user"}}
"question": {"type": "string", "description": "The question to ask the user"}
}
output_type = "string" output_type = "string"
def forward(self, question): def forward(self, question):
@ -151,9 +142,7 @@ class UserInputTool(Tool):
class DuckDuckGoSearchTool(Tool): class DuckDuckGoSearchTool(Tool):
name = "web_search" name = "web_search"
description = """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results.""" description = """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."""
inputs = { inputs = {"query": {"type": "string", "description": "The search query to perform."}}
"query": {"type": "string", "description": "The search query to perform."}
}
output_type = "string" output_type = "string"
def __init__(self, *args, max_results=10, **kwargs): def __init__(self, *args, max_results=10, **kwargs):
@ -169,10 +158,7 @@ class DuckDuckGoSearchTool(Tool):
def forward(self, query: str) -> str: def forward(self, query: str) -> str:
results = self.ddgs.text(query, max_results=self.max_results) results = self.ddgs.text(query, max_results=self.max_results)
postprocessed_results = [ postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
f"[{result['title']}]({result['href']})\n{result['body']}"
for result in results
]
return "## Search Results\n\n" + "\n\n".join(postprocessed_results) return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
@ -199,9 +185,7 @@ class GoogleSearchTool(Tool):
import requests import requests
if self.serpapi_key is None: if self.serpapi_key is None:
raise ValueError( raise ValueError("Missing SerpAPI key. Make sure you have 'SERPAPI_API_KEY' in your env variables.")
"Missing SerpAPI key. Make sure you have 'SERPAPI_API_KEY' in your env variables."
)
params = { params = {
"engine": "google", "engine": "google",
@ -210,9 +194,7 @@ class GoogleSearchTool(Tool):
"google_domain": "google.com", "google_domain": "google.com",
} }
if filter_year is not None: if filter_year is not None:
params["tbs"] = ( params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
)
response = requests.get("https://serpapi.com/search.json", params=params) response = requests.get("https://serpapi.com/search.json", params=params)
@ -227,13 +209,9 @@ class GoogleSearchTool(Tool):
f"'organic_results' key not found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year." f"'organic_results' key not found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
) )
else: else:
raise Exception( raise Exception(f"'organic_results' key not found for query: '{query}'. Use a less restrictive query.")
f"'organic_results' key not found for query: '{query}'. Use a less restrictive query."
)
if len(results["organic_results"]) == 0: if len(results["organic_results"]) == 0:
year_filter_message = ( year_filter_message = f" with filter year={filter_year}" if filter_year is not None else ""
f" with filter year={filter_year}" if filter_year is not None else ""
)
return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter." return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."
web_snippets = [] web_snippets = []
@ -253,9 +231,7 @@ class GoogleSearchTool(Tool):
redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}" redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"
redacted_version = redacted_version.replace( redacted_version = redacted_version.replace("Your browser can't play this video.", "")
"Your browser can't play this video.", ""
)
web_snippets.append(redacted_version) web_snippets.append(redacted_version)
return "## Search Results\n" + "\n\n".join(web_snippets) return "## Search Results\n" + "\n\n".join(web_snippets)
@ -263,7 +239,9 @@ class GoogleSearchTool(Tool):
class VisitWebpageTool(Tool): class VisitWebpageTool(Tool):
name = "visit_webpage" name = "visit_webpage"
description = "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages." description = (
"Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
)
inputs = { inputs = {
"url": { "url": {
"type": "string", "type": "string",
@ -277,6 +255,7 @@ class VisitWebpageTool(Tool):
import requests import requests
from markdownify import markdownify from markdownify import markdownify
from requests.exceptions import RequestException from requests.exceptions import RequestException
from smolagents.utils import truncate_content from smolagents.utils import truncate_content
except ImportError: except ImportError:
raise ImportError( raise ImportError(


@ -28,6 +28,7 @@ from .tool_validation import validate_tool_attributes
from .tools import Tool from .tools import Tool
from .utils import BASE_BUILTIN_MODULES, instance_to_source from .utils import BASE_BUILTIN_MODULES, instance_to_source
load_dotenv() load_dotenv()
@ -45,9 +46,7 @@ class E2BExecutor:
self.logger = logger self.logger = logger
additional_imports = additional_imports + ["pickle5", "smolagents"] additional_imports = additional_imports + ["pickle5", "smolagents"]
if len(additional_imports) > 0: if len(additional_imports) > 0:
execution = self.sbx.commands.run( execution = self.sbx.commands.run("pip install " + " ".join(additional_imports))
"pip install " + " ".join(additional_imports)
)
if execution.error: if execution.error:
raise Exception(f"Error installing dependencies: {execution.error}") raise Exception(f"Error installing dependencies: {execution.error}")
else: else:
@ -61,9 +60,7 @@ class E2BExecutor:
tool_code += f"\n{tool.name} = {tool.__class__.__name__}()\n" tool_code += f"\n{tool.name} = {tool.__class__.__name__}()\n"
tool_codes.append(tool_code) tool_codes.append(tool_code)
tool_definition_code = "\n".join( tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES])
[f"import {module}" for module in BASE_BUILTIN_MODULES]
)
tool_definition_code += textwrap.dedent(""" tool_definition_code += textwrap.dedent("""
class Tool: class Tool:
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
@ -122,9 +119,7 @@ locals().update({key: value for key, value in pickle_dict.items()})
for attribute_name in ["jpeg", "png"]: for attribute_name in ["jpeg", "png"]:
if getattr(result, attribute_name) is not None: if getattr(result, attribute_name) is not None:
image_output = getattr(result, attribute_name) image_output = getattr(result, attribute_name)
decoded_bytes = base64.b64decode( decoded_bytes = base64.b64decode(image_output.encode("utf-8"))
image_output.encode("utf-8")
)
return Image.open(BytesIO(decoded_bytes)), execution_logs return Image.open(BytesIO(decoded_bytes)), execution_logs
for attribute_name in [ for attribute_name in [
"chart", "chart",


@ -13,14 +13,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gradio as gr
import shutil
import os
import mimetypes import mimetypes
import os
import re import re
import shutil
from typing import Optional from typing import Optional
import gradio as gr
from .agents import ActionStep, AgentStepLog, MultiStepAgent from .agents import ActionStep, AgentStepLog, MultiStepAgent
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
@ -59,9 +59,7 @@ def stream_to_gradio(
): ):
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
for step_log in agent.run( for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
task, stream=True, reset=reset_agent_memory, additional_args=additional_args
):
for message in pull_messages_from_step(step_log, test_mode=test_mode): for message in pull_messages_from_step(step_log, test_mode=test_mode):
yield message yield message
@ -147,14 +145,10 @@ class GradioUI:
sanitized_name = "".join(sanitized_name) sanitized_name = "".join(sanitized_name)
# Save the uploaded file to the specified folder # Save the uploaded file to the specified folder
file_path = os.path.join( file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
self.file_upload_folder, os.path.basename(sanitized_name)
)
shutil.copy(file.name, file_path) shutil.copy(file.name, file_path)
return gr.Textbox( return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path]
f"File uploaded: {file_path}", visible=True
), file_uploads_log + [file_path]
def log_user_message(self, text_input, file_uploads_log): def log_user_message(self, text_input, file_uploads_log):
return ( return (
@ -183,9 +177,7 @@ class GradioUI:
# If an upload folder is provided, enable the upload feature # If an upload folder is provided, enable the upload feature
if self.file_upload_folder is not None: if self.file_upload_folder is not None:
upload_file = gr.File(label="Upload a file") upload_file = gr.File(label="Upload a file")
upload_status = gr.Textbox( upload_status = gr.Textbox(label="Upload Status", interactive=False, visible=False)
label="Upload Status", interactive=False, visible=False
)
upload_file.change( upload_file.change(
self.upload_file, self.upload_file,
[upload_file, file_uploads_log], [upload_file, file_uploads_log],


@ -42,8 +42,7 @@ class InterpreterError(ValueError):
ERRORS = { ERRORS = {
name: getattr(builtins, name) name: getattr(builtins, name)
for name in dir(builtins) for name in dir(builtins)
if isinstance(getattr(builtins, name), type) if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)
and issubclass(getattr(builtins, name), BaseException)
} }
PRINT_OUTPUTS, DEFAULT_MAX_LEN_OUTPUT = "", 50000 PRINT_OUTPUTS, DEFAULT_MAX_LEN_OUTPUT = "", 50000
@ -167,9 +166,7 @@ def evaluate_unaryop(
custom_tools: Dict[str, Callable], custom_tools: Dict[str, Callable],
authorized_imports: List[str], authorized_imports: List[str],
) -> Any: ) -> Any:
operand = evaluate_ast( operand = evaluate_ast(expression.operand, state, static_tools, custom_tools, authorized_imports)
expression.operand, state, static_tools, custom_tools, authorized_imports
)
if isinstance(expression.op, ast.USub): if isinstance(expression.op, ast.USub):
return -operand return -operand
elif isinstance(expression.op, ast.UAdd): elif isinstance(expression.op, ast.UAdd):
@ -179,9 +176,7 @@ def evaluate_unaryop(
elif isinstance(expression.op, ast.Invert): elif isinstance(expression.op, ast.Invert):
return ~operand return ~operand
else: else:
raise InterpreterError( raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
f"Unary operation {expression.op.__class__.__name__} is not supported."
)
def evaluate_lambda( def evaluate_lambda(
@ -217,23 +212,17 @@ def evaluate_while(
) -> None: ) -> None:
max_iterations = 1000 max_iterations = 1000
iterations = 0 iterations = 0
while evaluate_ast( while evaluate_ast(while_loop.test, state, static_tools, custom_tools, authorized_imports):
while_loop.test, state, static_tools, custom_tools, authorized_imports
):
for node in while_loop.body: for node in while_loop.body:
try: try:
evaluate_ast( evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
node, state, static_tools, custom_tools, authorized_imports
)
except BreakException: except BreakException:
return None return None
except ContinueException: except ContinueException:
break break
iterations += 1 iterations += 1
if iterations > max_iterations: if iterations > max_iterations:
raise InterpreterError( raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
f"Maximum number of {max_iterations} iterations in While loop exceeded"
)
return None return None
@ -248,8 +237,7 @@ def create_function(
func_state = state.copy() func_state = state.copy()
arg_names = [arg.arg for arg in func_def.args.args] arg_names = [arg.arg for arg in func_def.args.args]
default_values = [ default_values = [
evaluate_ast(d, state, static_tools, custom_tools, authorized_imports) evaluate_ast(d, state, static_tools, custom_tools, authorized_imports) for d in func_def.args.defaults
for d in func_def.args.defaults
] ]
# Apply default values # Apply default values
@ -286,9 +274,7 @@ def create_function(
result = None result = None
try: try:
for stmt in func_def.body: for stmt in func_def.body:
result = evaluate_ast( result = evaluate_ast(stmt, func_state, static_tools, custom_tools, authorized_imports)
stmt, func_state, static_tools, custom_tools, authorized_imports
)
except ReturnException as e: except ReturnException as e:
result = e.value result = e.value
@ -307,9 +293,7 @@ def evaluate_function_def(
custom_tools: Dict[str, Callable], custom_tools: Dict[str, Callable],
authorized_imports: List[str], authorized_imports: List[str],
) -> Callable: ) -> Callable:
custom_tools[func_def.name] = create_function( custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools, authorized_imports)
func_def, state, static_tools, custom_tools, authorized_imports
)
return custom_tools[func_def.name] return custom_tools[func_def.name]
@ -321,17 +305,12 @@ def evaluate_class_def(
authorized_imports: List[str], authorized_imports: List[str],
) -> type: ) -> type:
class_name = class_def.name class_name = class_def.name
bases = [ bases = [evaluate_ast(base, state, static_tools, custom_tools, authorized_imports) for base in class_def.bases]
evaluate_ast(base, state, static_tools, custom_tools, authorized_imports)
for base in class_def.bases
]
class_dict = {} class_dict = {}
for stmt in class_def.body: for stmt in class_def.body:
if isinstance(stmt, ast.FunctionDef): if isinstance(stmt, ast.FunctionDef):
class_dict[stmt.name] = evaluate_function_def( class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools, authorized_imports)
stmt, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(stmt, ast.Assign): elif isinstance(stmt, ast.Assign):
for target in stmt.targets: for target in stmt.targets:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
@ -351,9 +330,7 @@ def evaluate_class_def(
authorized_imports, authorized_imports,
) )
else: else:
raise InterpreterError( raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
f"Unsupported statement in class body: {stmt.__class__.__name__}"
)
new_class = type(class_name, tuple(bases), class_dict) new_class = type(class_name, tuple(bases), class_dict)
state[class_name] = new_class state[class_name] = new_class
@ -371,38 +348,26 @@ def evaluate_augassign(
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
return state.get(target.id, 0) return state.get(target.id, 0)
elif isinstance(target, ast.Subscript): elif isinstance(target, ast.Subscript):
obj = evaluate_ast( obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
target.value, state, static_tools, custom_tools, authorized_imports key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
)
key = evaluate_ast(
target.slice, state, static_tools, custom_tools, authorized_imports
)
return obj[key] return obj[key]
elif isinstance(target, ast.Attribute): elif isinstance(target, ast.Attribute):
obj = evaluate_ast( obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
target.value, state, static_tools, custom_tools, authorized_imports
)
return getattr(obj, target.attr) return getattr(obj, target.attr)
elif isinstance(target, ast.Tuple): elif isinstance(target, ast.Tuple):
return tuple(get_current_value(elt) for elt in target.elts) return tuple(get_current_value(elt) for elt in target.elts)
elif isinstance(target, ast.List): elif isinstance(target, ast.List):
return [get_current_value(elt) for elt in target.elts] return [get_current_value(elt) for elt in target.elts]
else: else:
raise InterpreterError( raise InterpreterError("AugAssign not supported for {type(target)} targets.")
"AugAssign not supported for {type(target)} targets."
)
current_value = get_current_value(expression.target) current_value = get_current_value(expression.target)
value_to_add = evaluate_ast( value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
expression.value, state, static_tools, custom_tools, authorized_imports
)
if isinstance(expression.op, ast.Add): if isinstance(expression.op, ast.Add):
if isinstance(current_value, list): if isinstance(current_value, list):
if not isinstance(value_to_add, list): if not isinstance(value_to_add, list):
raise InterpreterError( raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
f"Cannot add non-list value {value_to_add} to a list."
)
updated_value = current_value + value_to_add updated_value = current_value + value_to_add
else: else:
updated_value = current_value + value_to_add updated_value = current_value + value_to_add
@ -429,9 +394,7 @@ def evaluate_augassign(
elif isinstance(expression.op, ast.RShift): elif isinstance(expression.op, ast.RShift):
updated_value = current_value >> value_to_add updated_value = current_value >> value_to_add
else: else:
raise InterpreterError( raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
f"Operation {type(expression.op).__name__} is not supported."
)
# Update the state # Update the state
set_value( set_value(
@ -455,16 +418,12 @@ def evaluate_boolop(
) -> bool: ) -> bool:
if isinstance(node.op, ast.And): if isinstance(node.op, ast.And):
for value in node.values: for value in node.values:
if not evaluate_ast( if not evaluate_ast(value, state, static_tools, custom_tools, authorized_imports):
value, state, static_tools, custom_tools, authorized_imports
):
return False return False
return True return True
elif isinstance(node.op, ast.Or): elif isinstance(node.op, ast.Or):
for value in node.values: for value in node.values:
if evaluate_ast( if evaluate_ast(value, state, static_tools, custom_tools, authorized_imports):
value, state, static_tools, custom_tools, authorized_imports
):
return True return True
return False return False
@ -477,12 +436,8 @@ def evaluate_binop(
authorized_imports: List[str], authorized_imports: List[str],
) -> Any: ) -> Any:
# Recursively evaluate the left and right operands # Recursively evaluate the left and right operands
left_val = evaluate_ast( left_val = evaluate_ast(binop.left, state, static_tools, custom_tools, authorized_imports)
binop.left, state, static_tools, custom_tools, authorized_imports right_val = evaluate_ast(binop.right, state, static_tools, custom_tools, authorized_imports)
)
right_val = evaluate_ast(
binop.right, state, static_tools, custom_tools, authorized_imports
)
# Determine the operation based on the type of the operator in the BinOp # Determine the operation based on the type of the operator in the BinOp
if isinstance(binop.op, ast.Add): if isinstance(binop.op, ast.Add):
@ -510,9 +465,7 @@ def evaluate_binop(
elif isinstance(binop.op, ast.RShift): elif isinstance(binop.op, ast.RShift):
return left_val >> right_val return left_val >> right_val
else: else:
raise NotImplementedError( raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
f"Binary operation {type(binop.op).__name__} is not implemented."
)
def evaluate_assign( def evaluate_assign(
@ -522,17 +475,13 @@ def evaluate_assign(
custom_tools: Dict[str, Callable], custom_tools: Dict[str, Callable],
authorized_imports: List[str], authorized_imports: List[str],
) -> Any: ) -> Any:
result = evaluate_ast( result = evaluate_ast(assign.value, state, static_tools, custom_tools, authorized_imports)
assign.value, state, static_tools, custom_tools, authorized_imports
)
if len(assign.targets) == 1: if len(assign.targets) == 1:
target = assign.targets[0] target = assign.targets[0]
set_value(target, result, state, static_tools, custom_tools, authorized_imports) set_value(target, result, state, static_tools, custom_tools, authorized_imports)
else: else:
if len(assign.targets) != len(result): if len(assign.targets) != len(result):
raise InterpreterError( raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
f"Assign failed: expected {len(result)} values but got {len(assign.targets)}."
)
expanded_values = [] expanded_values = []
for tgt in assign.targets: for tgt in assign.targets:
if isinstance(tgt, ast.Starred): if isinstance(tgt, ast.Starred):
@ -554,9 +503,7 @@ def set_value(
) -> None: ) -> None:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
if target.id in static_tools: if target.id in static_tools:
raise InterpreterError( raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
f"Cannot assign to name '{target.id}': doing this would erase the existing tool!"
)
state[target.id] = value state[target.id] = value
elif isinstance(target, ast.Tuple): elif isinstance(target, ast.Tuple):
if not isinstance(value, tuple): if not isinstance(value, tuple):
@ -567,21 +514,13 @@ def set_value(
if len(target.elts) != len(value): if len(target.elts) != len(value):
raise InterpreterError("Cannot unpack tuple of wrong size") raise InterpreterError("Cannot unpack tuple of wrong size")
for i, elem in enumerate(target.elts): for i, elem in enumerate(target.elts):
set_value( set_value(elem, value[i], state, static_tools, custom_tools, authorized_imports)
elem, value[i], state, static_tools, custom_tools, authorized_imports
)
elif isinstance(target, ast.Subscript): elif isinstance(target, ast.Subscript):
obj = evaluate_ast( obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
target.value, state, static_tools, custom_tools, authorized_imports key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
)
key = evaluate_ast(
target.slice, state, static_tools, custom_tools, authorized_imports
)
obj[key] = value obj[key] = value
elif isinstance(target, ast.Attribute): elif isinstance(target, ast.Attribute):
obj = evaluate_ast( obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
target.value, state, static_tools, custom_tools, authorized_imports
)
setattr(obj, target.attr, value) setattr(obj, target.attr, value)
@ -593,15 +532,11 @@ def evaluate_call(
authorized_imports: List[str], authorized_imports: List[str],
) -> Any: ) -> Any:
if not ( if not (
isinstance(call.func, ast.Attribute) isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name) or isinstance(call.func, ast.Subscript)
or isinstance(call.func, ast.Name)
or isinstance(call.func, ast.Subscript)
): ):
raise InterpreterError(f"This is not a correct function: {call.func}).") raise InterpreterError(f"This is not a correct function: {call.func}).")
if isinstance(call.func, ast.Attribute): if isinstance(call.func, ast.Attribute):
obj = evaluate_ast( obj = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports)
call.func.value, state, static_tools, custom_tools, authorized_imports
)
func_name = call.func.attr func_name = call.func.attr
if not hasattr(obj, func_name): if not hasattr(obj, func_name):
raise InterpreterError(f"Object {obj} has no attribute {func_name}") raise InterpreterError(f"Object {obj} has no attribute {func_name}")
@ -623,18 +558,12 @@ def evaluate_call(
) )
elif isinstance(call.func, ast.Subscript): elif isinstance(call.func, ast.Subscript):
value = evaluate_ast( value = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports)
call.func.value, state, static_tools, custom_tools, authorized_imports index = evaluate_ast(call.func.slice, state, static_tools, custom_tools, authorized_imports)
)
index = evaluate_ast(
call.func.slice, state, static_tools, custom_tools, authorized_imports
)
if isinstance(value, (list, tuple)): if isinstance(value, (list, tuple)):
func = value[index] func = value[index]
else: else:
raise InterpreterError( raise InterpreterError(f"Cannot subscript object of type {type(value).__name__}")
f"Cannot subscript object of type {type(value).__name__}"
)
if not callable(func): if not callable(func):
raise InterpreterError(f"This is not a correct function: {call.func}).") raise InterpreterError(f"This is not a correct function: {call.func}).")
@ -642,20 +571,12 @@ def evaluate_call(
args = [] args = []
for arg in call.args: for arg in call.args:
if isinstance(arg, ast.Starred): if isinstance(arg, ast.Starred):
args.extend( args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools, authorized_imports))
evaluate_ast(
arg.value, state, static_tools, custom_tools, authorized_imports
)
)
else: else:
args.append( args.append(evaluate_ast(arg, state, static_tools, custom_tools, authorized_imports))
evaluate_ast(arg, state, static_tools, custom_tools, authorized_imports)
)
kwargs = { kwargs = {
keyword.arg: evaluate_ast( keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools, authorized_imports)
keyword.value, state, static_tools, custom_tools, authorized_imports
)
for keyword in call.keywords for keyword in call.keywords
} }
@ -693,17 +614,11 @@ def evaluate_subscript(
custom_tools: Dict[str, Callable], custom_tools: Dict[str, Callable],
authorized_imports: List[str], authorized_imports: List[str],
) -> Any: ) -> Any:
index = evaluate_ast( index = evaluate_ast(subscript.slice, state, static_tools, custom_tools, authorized_imports)
subscript.slice, state, static_tools, custom_tools, authorized_imports value = evaluate_ast(subscript.value, state, static_tools, custom_tools, authorized_imports)
)
value = evaluate_ast(
subscript.value, state, static_tools, custom_tools, authorized_imports
)
if isinstance(value, str) and isinstance(index, str): if isinstance(value, str) and isinstance(index, str):
raise InterpreterError( raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
"You're trying to subscript a string with a string index, which is impossible"
)
if isinstance(value, pd.core.indexing._LocIndexer): if isinstance(value, pd.core.indexing._LocIndexer):
parent_object = value.obj parent_object = value.obj
return parent_object.loc[index] return parent_object.loc[index]
@ -718,15 +633,11 @@ def evaluate_subscript(
return value[index] return value[index]
elif isinstance(value, (list, tuple)): elif isinstance(value, (list, tuple)):
if not (-len(value) <= index < len(value)): if not (-len(value) <= index < len(value)):
raise InterpreterError( raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
f"Index {index} out of bounds for list of length {len(value)}"
)
return value[int(index)] return value[int(index)]
elif isinstance(value, str): elif isinstance(value, str):
if not (-len(value) <= index < len(value)): if not (-len(value) <= index < len(value)):
raise InterpreterError( raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
f"Index {index} out of bounds for string of length {len(value)}"
)
return value[index] return value[index]
elif index in value: elif index in value:
return value[index] return value[index]
@ -765,12 +676,9 @@ def evaluate_condition(
custom_tools: Dict[str, Callable], custom_tools: Dict[str, Callable],
authorized_imports: List[str], authorized_imports: List[str],
) -> bool: ) -> bool:
left = evaluate_ast( left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports)
condition.left, state, static_tools, custom_tools, authorized_imports
)
comparators = [ comparators = [
evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) for c in condition.comparators
for c in condition.comparators
] ]
ops = [type(op) for op in condition.ops] ops = [type(op) for op in condition.ops]
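Side note on the `evaluate_condition` hunk above (illustrative only, not part of the diff): a chained comparison parses into a single `ast.Compare` node with one `left` operand plus parallel `ops` and `comparators` lists, which is why the evaluator walks the two lists in lockstep. A minimal standard-library check:

import ast

node = ast.parse("1 < x <= 5", mode="eval").body  # one ast.Compare node
assert isinstance(node, ast.Compare)
print(type(node.left).__name__)                       # Constant
print([type(op).__name__ for op in node.ops])         # ['Lt', 'LtE']
print([type(c).__name__ for c in node.comparators])   # ['Name', 'Constant']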
@ -818,21 +726,15 @@ def evaluate_if(
authorized_imports: List[str], authorized_imports: List[str],
) -> Any: ) -> Any:
result = None result = None
test_result = evaluate_ast( test_result = evaluate_ast(if_statement.test, state, static_tools, custom_tools, authorized_imports)
if_statement.test, state, static_tools, custom_tools, authorized_imports
)
if test_result: if test_result:
for line in if_statement.body: for line in if_statement.body:
line_result = evaluate_ast( line_result = evaluate_ast(line, state, static_tools, custom_tools, authorized_imports)
line, state, static_tools, custom_tools, authorized_imports
)
if line_result is not None: if line_result is not None:
result = line_result result = line_result
else: else:
for line in if_statement.orelse: for line in if_statement.orelse:
line_result = evaluate_ast( line_result = evaluate_ast(line, state, static_tools, custom_tools, authorized_imports)
line, state, static_tools, custom_tools, authorized_imports
)
if line_result is not None: if line_result is not None:
result = line_result result = line_result
return result return result
@ -846,9 +748,7 @@ def evaluate_for(
authorized_imports: List[str], authorized_imports: List[str],
) -> Any: ) -> Any:
result = None result = None
iterator = evaluate_ast( iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools, authorized_imports)
for_loop.iter, state, static_tools, custom_tools, authorized_imports
)
for counter in iterator: for counter in iterator:
set_value( set_value(
for_loop.target, for_loop.target,
@ -860,9 +760,7 @@ def evaluate_for(
) )
for node in for_loop.body: for node in for_loop.body:
try: try:
line_result = evaluate_ast( line_result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
node, state, static_tools, custom_tools, authorized_imports
)
if line_result is not None: if line_result is not None:
result = line_result result = line_result
except BreakException: except BreakException:
@ -882,9 +780,7 @@ def evaluate_listcomp(
custom_tools: Dict[str, Callable], custom_tools: Dict[str, Callable],
authorized_imports: List[str], authorized_imports: List[str],
) -> List[Any]: ) -> List[Any]:
def inner_evaluate( def inner_evaluate(generators: List[ast.comprehension], index: int, current_state: Dict[str, Any]) -> List[Any]:
generators: List[ast.comprehension], index: int, current_state: Dict[str, Any]
) -> List[Any]:
if index >= len(generators): if index >= len(generators):
return [ return [
evaluate_ast( evaluate_ast(
@ -912,9 +808,7 @@ def evaluate_listcomp(
else: else:
new_state[generator.target.id] = value new_state[generator.target.id] = value
if all( if all(
evaluate_ast( evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
if_clause, new_state, static_tools, custom_tools, authorized_imports
)
for if_clause in generator.ifs for if_clause in generator.ifs
): ):
result.extend(inner_evaluate(generators, index + 1, new_state)) result.extend(inner_evaluate(generators, index + 1, new_state))
@ -938,32 +832,24 @@ def evaluate_try(
for handler in try_node.handlers: for handler in try_node.handlers:
if handler.type is None or isinstance( if handler.type is None or isinstance(
e, e,
evaluate_ast( evaluate_ast(handler.type, state, static_tools, custom_tools, authorized_imports),
handler.type, state, static_tools, custom_tools, authorized_imports
),
): ):
matched = True matched = True
if handler.name: if handler.name:
state[handler.name] = e state[handler.name] = e
for stmt in handler.body: for stmt in handler.body:
evaluate_ast( evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
stmt, state, static_tools, custom_tools, authorized_imports
)
break break
if not matched: if not matched:
raise e raise e
else: else:
if try_node.orelse: if try_node.orelse:
for stmt in try_node.orelse: for stmt in try_node.orelse:
evaluate_ast( evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
stmt, state, static_tools, custom_tools, authorized_imports
)
finally: finally:
if try_node.finalbody: if try_node.finalbody:
for stmt in try_node.finalbody: for stmt in try_node.finalbody:
evaluate_ast( evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
stmt, state, static_tools, custom_tools, authorized_imports
)
def evaluate_raise( def evaluate_raise(
@ -974,15 +860,11 @@ def evaluate_raise(
authorized_imports: List[str], authorized_imports: List[str],
) -> None: ) -> None:
if raise_node.exc is not None: if raise_node.exc is not None:
exc = evaluate_ast( exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools, authorized_imports)
raise_node.exc, state, static_tools, custom_tools, authorized_imports
)
else: else:
exc = None exc = None
if raise_node.cause is not None: if raise_node.cause is not None:
cause = evaluate_ast( cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools, authorized_imports)
raise_node.cause, state, static_tools, custom_tools, authorized_imports
)
else: else:
cause = None cause = None
if exc is not None: if exc is not None:
@ -1001,14 +883,10 @@ def evaluate_assert(
custom_tools: Dict[str, Callable], custom_tools: Dict[str, Callable],
authorized_imports: List[str], authorized_imports: List[str],
) -> None: ) -> None:
test_result = evaluate_ast( test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools, authorized_imports)
assert_node.test, state, static_tools, custom_tools, authorized_imports
)
if not test_result: if not test_result:
if assert_node.msg: if assert_node.msg:
msg = evaluate_ast( msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools, authorized_imports)
assert_node.msg, state, static_tools, custom_tools, authorized_imports
)
raise AssertionError(msg) raise AssertionError(msg)
else: else:
# Include the failing condition in the assertion message # Include the failing condition in the assertion message
@ -1025,9 +903,7 @@ def evaluate_with(
) -> None: ) -> None:
contexts = [] contexts = []
for item in with_node.items: for item in with_node.items:
context_expr = evaluate_ast( context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools, authorized_imports)
item.context_expr, state, static_tools, custom_tools, authorized_imports
)
if item.optional_vars: if item.optional_vars:
state[item.optional_vars.id] = context_expr.__enter__() state[item.optional_vars.id] = context_expr.__enter__()
contexts.append(state[item.optional_vars.id]) contexts.append(state[item.optional_vars.id])
@ -1069,19 +945,14 @@ def get_safe_module(unsafe_module, dangerous_patterns, visited=None):
# Copy all attributes by reference, recursively checking modules # Copy all attributes by reference, recursively checking modules
for attr_name in dir(unsafe_module): for attr_name in dir(unsafe_module):
# Skip dangerous patterns at any level # Skip dangerous patterns at any level
if any( if any(pattern in f"{unsafe_module.__name__}.{attr_name}" for pattern in dangerous_patterns):
pattern in f"{unsafe_module.__name__}.{attr_name}"
for pattern in dangerous_patterns
):
continue continue
attr_value = getattr(unsafe_module, attr_name) attr_value = getattr(unsafe_module, attr_name)
# Recursively process nested modules, passing visited set # Recursively process nested modules, passing visited set
if isinstance(attr_value, ModuleType): if isinstance(attr_value, ModuleType):
attr_value = get_safe_module( attr_value = get_safe_module(attr_value, dangerous_patterns, visited=visited)
attr_value, dangerous_patterns, visited=visited
)
setattr(safe_module, attr_name, attr_value) setattr(safe_module, attr_name, attr_value)
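Side note before the `import_modules` hunks below (illustrative, not part of the diff): the authorization check builds every dotted prefix of the requested module path, so whitelisting a top-level package also authorizes its submodules. The subpath expression from the diff, run on its own:

module_path = "numpy.random.mtrand".split(".")
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
print(module_subpaths)  # ['numpy', 'numpy.random', 'numpy.random.mtrand']
print(any(subpath in ["numpy"] for subpath in module_subpaths))  # True: authorizing "numpy" covers the whole subtree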
@ -1116,18 +987,14 @@ def import_modules(expression, state, authorized_imports):
module_path = module_name.split(".") module_path = module_name.split(".")
if any([module in dangerous_patterns for module in module_path]): if any([module in dangerous_patterns for module in module_path]):
return False return False
module_subpaths = [ module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
]
return any(subpath in authorized_imports for subpath in module_subpaths) return any(subpath in authorized_imports for subpath in module_subpaths)
if isinstance(expression, ast.Import): if isinstance(expression, ast.Import):
for alias in expression.names: for alias in expression.names:
if check_module_authorized(alias.name): if check_module_authorized(alias.name):
raw_module = import_module(alias.name) raw_module = import_module(alias.name)
state[alias.asname or alias.name] = get_safe_module( state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns)
raw_module, dangerous_patterns
)
else: else:
raise InterpreterError( raise InterpreterError(
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}" f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
@ -1135,9 +1002,7 @@ def import_modules(expression, state, authorized_imports):
return None return None
elif isinstance(expression, ast.ImportFrom): elif isinstance(expression, ast.ImportFrom):
if check_module_authorized(expression.module): if check_module_authorized(expression.module):
raw_module = __import__( raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
expression.module, fromlist=[alias.name for alias in expression.names]
)
for alias in expression.names: for alias in expression.names:
state[alias.asname or alias.name] = get_safe_module( state[alias.asname or alias.name] = get_safe_module(
getattr(raw_module, alias.name), dangerous_patterns getattr(raw_module, alias.name), dangerous_patterns
@ -1156,9 +1021,7 @@ def evaluate_dictcomp(
) -> Dict[Any, Any]: ) -> Dict[Any, Any]:
result = {} result = {}
for gen in dictcomp.generators: for gen in dictcomp.generators:
iter_value = evaluate_ast( iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools, authorized_imports)
gen.iter, state, static_tools, custom_tools, authorized_imports
)
for value in iter_value: for value in iter_value:
new_state = state.copy() new_state = state.copy()
set_value( set_value(
@ -1170,9 +1033,7 @@ def evaluate_dictcomp(
authorized_imports, authorized_imports,
) )
if all( if all(
evaluate_ast( evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
if_clause, new_state, static_tools, custom_tools, authorized_imports
)
for if_clause in gen.ifs for if_clause in gen.ifs
): ):
key = evaluate_ast( key = evaluate_ast(
@ -1229,202 +1090,116 @@ def evaluate_ast(
if isinstance(expression, ast.Assign): if isinstance(expression, ast.Assign):
# Assignment -> we evaluate the assignment which should update the state # Assignment -> we evaluate the assignment which should update the state
# We return the variable assigned as it may be used to determine the final result. # We return the variable assigned as it may be used to determine the final result.
return evaluate_assign( return evaluate_assign(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.AugAssign): elif isinstance(expression, ast.AugAssign):
return evaluate_augassign( return evaluate_augassign(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.Call): elif isinstance(expression, ast.Call):
# Function call -> we return the value of the function call # Function call -> we return the value of the function call
return evaluate_call( return evaluate_call(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.Constant): elif isinstance(expression, ast.Constant):
# Constant -> just return the value # Constant -> just return the value
return expression.value return expression.value
elif isinstance(expression, ast.Tuple): elif isinstance(expression, ast.Tuple):
return tuple( return tuple(
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts
for elt in expression.elts
) )
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)): elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
return evaluate_listcomp( return evaluate_listcomp(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.UnaryOp): elif isinstance(expression, ast.UnaryOp):
return evaluate_unaryop( return evaluate_unaryop(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.Starred): elif isinstance(expression, ast.Starred):
return evaluate_ast( return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
expression.value, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.BoolOp): elif isinstance(expression, ast.BoolOp):
# Boolean operation -> evaluate the operation # Boolean operation -> evaluate the operation
return evaluate_boolop( return evaluate_boolop(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.Break): elif isinstance(expression, ast.Break):
raise BreakException() raise BreakException()
elif isinstance(expression, ast.Continue): elif isinstance(expression, ast.Continue):
raise ContinueException() raise ContinueException()
elif isinstance(expression, ast.BinOp): elif isinstance(expression, ast.BinOp):
# Binary operation -> execute operation # Binary operation -> execute operation
return evaluate_binop( return evaluate_binop(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.Compare): elif isinstance(expression, ast.Compare):
# Comparison -> evaluate the comparison # Comparison -> evaluate the comparison
return evaluate_condition( return evaluate_condition(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.Lambda): elif isinstance(expression, ast.Lambda):
return evaluate_lambda( return evaluate_lambda(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.FunctionDef): elif isinstance(expression, ast.FunctionDef):
return evaluate_function_def( return evaluate_function_def(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.Dict): elif isinstance(expression, ast.Dict):
# Dict -> evaluate all keys and values # Dict -> evaluate all keys and values
keys = [ keys = [evaluate_ast(k, state, static_tools, custom_tools, authorized_imports) for k in expression.keys]
evaluate_ast(k, state, static_tools, custom_tools, authorized_imports) values = [evaluate_ast(v, state, static_tools, custom_tools, authorized_imports) for v in expression.values]
for k in expression.keys
]
values = [
evaluate_ast(v, state, static_tools, custom_tools, authorized_imports)
for v in expression.values
]
return dict(zip(keys, values)) return dict(zip(keys, values))
elif isinstance(expression, ast.Expr): elif isinstance(expression, ast.Expr):
# Expression -> evaluate the content # Expression -> evaluate the content
return evaluate_ast( return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
expression.value, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.For): elif isinstance(expression, ast.For):
# For loop -> execute the loop # For loop -> execute the loop
return evaluate_for( return evaluate_for(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.FormattedValue): elif isinstance(expression, ast.FormattedValue):
# Formatted value (part of f-string) -> evaluate the content and return # Formatted value (part of f-string) -> evaluate the content and return
return evaluate_ast( return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
expression.value, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.If): elif isinstance(expression, ast.If):
# If -> execute the right branch # If -> execute the right branch
return evaluate_if( return evaluate_if(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif hasattr(ast, "Index") and isinstance(expression, ast.Index): elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
return evaluate_ast( return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
expression.value, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.JoinedStr): elif isinstance(expression, ast.JoinedStr):
return "".join( return "".join(
[ [str(evaluate_ast(v, state, static_tools, custom_tools, authorized_imports)) for v in expression.values]
str(
evaluate_ast(
v, state, static_tools, custom_tools, authorized_imports
)
)
for v in expression.values
]
) )
elif isinstance(expression, ast.List): elif isinstance(expression, ast.List):
# List -> evaluate all elements # List -> evaluate all elements
return [ return [evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts]
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports)
for elt in expression.elts
]
elif isinstance(expression, ast.Name): elif isinstance(expression, ast.Name):
# Name -> pick up the value in the state # Name -> pick up the value in the state
return evaluate_name( return evaluate_name(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.Subscript): elif isinstance(expression, ast.Subscript):
# Subscript -> return the value of the indexing # Subscript -> return the value of the indexing
return evaluate_subscript( return evaluate_subscript(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.IfExp): elif isinstance(expression, ast.IfExp):
test_val = evaluate_ast( test_val = evaluate_ast(expression.test, state, static_tools, custom_tools, authorized_imports)
expression.test, state, static_tools, custom_tools, authorized_imports
)
if test_val: if test_val:
return evaluate_ast( return evaluate_ast(expression.body, state, static_tools, custom_tools, authorized_imports)
expression.body, state, static_tools, custom_tools, authorized_imports
)
else: else:
return evaluate_ast( return evaluate_ast(expression.orelse, state, static_tools, custom_tools, authorized_imports)
expression.orelse, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.Attribute): elif isinstance(expression, ast.Attribute):
value = evaluate_ast( value = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
expression.value, state, static_tools, custom_tools, authorized_imports
)
return getattr(value, expression.attr) return getattr(value, expression.attr)
elif isinstance(expression, ast.Slice): elif isinstance(expression, ast.Slice):
return slice( return slice(
evaluate_ast( evaluate_ast(expression.lower, state, static_tools, custom_tools, authorized_imports)
expression.lower, state, static_tools, custom_tools, authorized_imports
)
if expression.lower is not None if expression.lower is not None
else None, else None,
evaluate_ast( evaluate_ast(expression.upper, state, static_tools, custom_tools, authorized_imports)
expression.upper, state, static_tools, custom_tools, authorized_imports
)
if expression.upper is not None if expression.upper is not None
else None, else None,
evaluate_ast( evaluate_ast(expression.step, state, static_tools, custom_tools, authorized_imports)
expression.step, state, static_tools, custom_tools, authorized_imports
)
if expression.step is not None if expression.step is not None
else None, else None,
) )
elif isinstance(expression, ast.DictComp): elif isinstance(expression, ast.DictComp):
return evaluate_dictcomp( return evaluate_dictcomp(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.While): elif isinstance(expression, ast.While):
return evaluate_while( return evaluate_while(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, (ast.Import, ast.ImportFrom)): elif isinstance(expression, (ast.Import, ast.ImportFrom)):
return import_modules(expression, state, authorized_imports) return import_modules(expression, state, authorized_imports)
elif isinstance(expression, ast.ClassDef): elif isinstance(expression, ast.ClassDef):
return evaluate_class_def( return evaluate_class_def(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.Try): elif isinstance(expression, ast.Try):
return evaluate_try( return evaluate_try(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.Raise): elif isinstance(expression, ast.Raise):
return evaluate_raise( return evaluate_raise(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.Assert): elif isinstance(expression, ast.Assert):
return evaluate_assert( return evaluate_assert(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.With): elif isinstance(expression, ast.With):
return evaluate_with( return evaluate_with(expression, state, static_tools, custom_tools, authorized_imports)
expression, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(expression, ast.Set): elif isinstance(expression, ast.Set):
return { return {evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts}
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports)
for elt in expression.elts
}
elif isinstance(expression, ast.Return): elif isinstance(expression, ast.Return):
raise ReturnException( raise ReturnException(
evaluate_ast( evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
expression.value, state, static_tools, custom_tools, authorized_imports
)
if expression.value if expression.value
else None else None
) )
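One detail from the large `evaluate_ast` hunk above, spelled out with a quick illustrative snippet (not from the codebase): f-strings parse into an `ast.JoinedStr` whose `values` interleave plain `ast.Constant` parts with `ast.FormattedValue` nodes, which is why both branches appear in the dispatch.

import ast

node = ast.parse('f"x is {x} today"', mode="eval").body
assert isinstance(node, ast.JoinedStr)
print([type(v).__name__ for v in node.values])  # ['Constant', 'FormattedValue', 'Constant']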
@ -1488,18 +1263,12 @@ def evaluate_python_code(
try: try:
for node in expression.body: for node in expression.body:
result = evaluate_ast( result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
node, state, static_tools, custom_tools, authorized_imports state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
)
state["print_outputs"] = truncate_content(
PRINT_OUTPUTS, max_length=max_print_outputs_length
)
is_final_answer = False is_final_answer = False
return result, is_final_answer return result, is_final_answer
except FinalAnswerException as e: except FinalAnswerException as e:
state["print_outputs"] = truncate_content( state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
PRINT_OUTPUTS, max_length=max_print_outputs_length
)
is_final_answer = True is_final_answer = True
return e.value, is_final_answer return e.value, is_final_answer
except InterpreterError as e: except InterpreterError as e:
@ -1521,9 +1290,7 @@ class LocalPythonInterpreter:
if max_print_outputs_length is None: if max_print_outputs_length is None:
self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT
self.additional_authorized_imports = additional_authorized_imports self.additional_authorized_imports = additional_authorized_imports
self.authorized_imports = list( self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
)
# Add base trusted tools to list # Add base trusted tools to list
self.static_tools = { self.static_tools = {
**tools, **tools,
@ -1531,9 +1298,7 @@ class LocalPythonInterpreter:
} }
# TODO: assert self.authorized imports are all installed locally # TODO: assert self.authorized imports are all installed locally
def __call__( def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str, bool]:
self, code_action: str, additional_variables: Dict
) -> Tuple[Any, str, bool]:
self.state.update(additional_variables) self.state.update(additional_variables)
output, is_final_answer = evaluate_python_code( output, is_final_answer = evaluate_python_code(
code_action, code_action,

View File

@ -14,17 +14,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass, asdict
import json import json
import logging import logging
import os import os
import random import random
from copy import deepcopy from copy import deepcopy
from dataclasses import asdict, dataclass
from enum import Enum from enum import Enum
from typing import Dict, List, Optional, Union, Any from typing import Any, Dict, List, Optional, Union
from huggingface_hub import InferenceClient from huggingface_hub import InferenceClient
from transformers import ( from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
@ -35,6 +34,7 @@ from transformers import (
from .tools import Tool from .tools import Tool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_JSONAGENT_REGEX_GRAMMAR = { DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
@ -100,10 +100,7 @@ class ChatMessage:
def from_hf_api(cls, message) -> "ChatMessage": def from_hf_api(cls, message) -> "ChatMessage":
tool_calls = None tool_calls = None
if getattr(message, "tool_calls", None) is not None: if getattr(message, "tool_calls", None) is not None:
tool_calls = [ tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls]
ChatMessageToolCall.from_hf_api(tool_call)
for tool_call in message.tool_calls
]
return cls(role=message.role, content=message.content, tool_calls=tool_calls) return cls(role=message.role, content=message.content, tool_calls=tool_calls)
@ -172,17 +169,12 @@ def get_clean_message_list(
role = message["role"] role = message["role"]
if role not in MessageRole.roles(): if role not in MessageRole.roles():
raise ValueError( raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
f"Incorrect role {role}, only {MessageRole.roles()} are supported for now."
)
if role in role_conversions: if role in role_conversions:
message["role"] = role_conversions[role] message["role"] = role_conversions[role]
if ( if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
len(final_message_list) > 0
and message["role"] == final_message_list[-1]["role"]
):
final_message_list[-1]["content"] += "\n=======\n" + message["content"] final_message_list[-1]["content"] += "\n=======\n" + message["content"]
else: else:
final_message_list.append(message) final_message_list.append(message)
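For the `get_clean_message_list` hunk above, a small self-contained illustration of the merge rule (simplified, no role conversion, not part of the diff): consecutive messages that share a role are concatenated with the "\n=======\n" separator.

messages = [
    {"role": "user", "content": "first"},
    {"role": "user", "content": "second"},
    {"role": "assistant", "content": "reply"},
]
merged = []
for message in messages:
    if merged and message["role"] == merged[-1]["role"]:
        merged[-1]["content"] += "\n=======\n" + message["content"]
    else:
        merged.append(dict(message))
print(len(merged))  # 2: the two consecutive user messages were merged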
@ -292,9 +284,7 @@ class HfApiModel(Model):
Gets an LLM output message for the given list of input messages. Gets an LLM output message for the given list of input messages.
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call. If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
""" """
messages = get_clean_message_list( messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
messages, role_conversions=tool_role_conversions
)
if tools_to_call_from: if tools_to_call_from:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
messages=messages, messages=messages,
@ -367,9 +357,7 @@ class TransformersModel(Model):
default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct" default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
if model_id is None: if model_id is None:
model_id = default_model_id model_id = default_model_id
logger.warning( logger.warning(f"`model_id` not provided, using this default tokenizer for token counts: '{model_id}'")
f"`model_id` not provided, using this default tokenizer for token counts: '{model_id}'"
)
self.model_id = model_id self.model_id = model_id
self.kwargs = kwargs self.kwargs = kwargs
if device_map is None: if device_map is None:
@ -389,9 +377,7 @@ class TransformersModel(Model):
) )
self.model_id = default_model_id self.model_id = default_model_id
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=torch_dtype)
model_id, device_map=device_map, torch_dtype=torch_dtype
)
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList: def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
class StopOnStrings(StoppingCriteria): class StopOnStrings(StoppingCriteria):
@ -404,16 +390,9 @@ class TransformersModel(Model):
self.stream = "" self.stream = ""
def __call__(self, input_ids, scores, **kwargs): def __call__(self, input_ids, scores, **kwargs):
generated = self.tokenizer.decode( generated = self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
input_ids[0][-1], skip_special_tokens=True
)
self.stream += generated self.stream += generated
if any( if any([self.stream.endswith(stop_string) for stop_string in self.stop_strings]):
[
self.stream.endswith(stop_string)
for stop_string in self.stop_strings
]
):
return True return True
return False return False
@ -426,9 +405,7 @@ class TransformersModel(Model):
grammar: Optional[str] = None, grammar: Optional[str] = None,
tools_to_call_from: Optional[List[Tool]] = None, tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage: ) -> ChatMessage:
messages = get_clean_message_list( messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
messages, role_conversions=tool_role_conversions
)
if tools_to_call_from is not None: if tools_to_call_from is not None:
prompt_tensor = self.tokenizer.apply_chat_template( prompt_tensor = self.tokenizer.apply_chat_template(
messages, messages,
@ -448,9 +425,7 @@ class TransformersModel(Model):
out = self.model.generate( out = self.model.generate(
**prompt_tensor, **prompt_tensor,
stopping_criteria=( stopping_criteria=(self.make_stopping_criteria(stop_sequences) if stop_sequences else None),
self.make_stopping_criteria(stop_sequences) if stop_sequences else None
),
**self.kwargs, **self.kwargs,
) )
generated_tokens = out[0, count_prompt_tokens:] generated_tokens = out[0, count_prompt_tokens:]
@ -475,9 +450,7 @@ class TransformersModel(Model):
ChatMessageToolCall( ChatMessageToolCall(
id="".join(random.choices("0123456789", k=5)), id="".join(random.choices("0123456789", k=5)),
type="function", type="function",
function=ChatMessageToolCallDefinition( function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
name=tool_name, arguments=tool_arguments
),
) )
], ],
) )
@ -525,9 +498,7 @@ class LiteLLMModel(Model):
grammar: Optional[str] = None, grammar: Optional[str] = None,
tools_to_call_from: Optional[List[Tool]] = None, tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage: ) -> ChatMessage:
messages = get_clean_message_list( messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
messages, role_conversions=tool_role_conversions
)
import litellm import litellm
if tools_to_call_from: if tools_to_call_from:
@ -604,11 +575,7 @@ class OpenAIServerModel(Model):
) -> ChatMessage: ) -> ChatMessage:
messages = get_clean_message_list( messages = get_clean_message_list(
messages, messages,
role_conversions=( role_conversions=(self.custom_role_conversions if self.custom_role_conversions else tool_role_conversions),
self.custom_role_conversions
if self.custom_role_conversions
else tool_role_conversions
),
) )
if tools_to_call_from: if tools_to_call_from:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(

View File

@ -22,10 +22,7 @@ class Monitor:
self.step_durations = [] self.step_durations = []
self.tracked_model = tracked_model self.tracked_model = tracked_model
self.logger = logger self.logger = logger
if ( if getattr(self.tracked_model, "last_input_token_count", "Not found") != "Not found":
getattr(self.tracked_model, "last_input_token_count", "Not found")
!= "Not found"
):
self.total_input_token_count = 0 self.total_input_token_count = 0
self.total_output_token_count = 0 self.total_output_token_count = 0
@ -48,7 +45,9 @@ class Monitor:
if getattr(self.tracked_model, "last_input_token_count", None) is not None: if getattr(self.tracked_model, "last_input_token_count", None) is not None:
self.total_input_token_count += self.tracked_model.last_input_token_count self.total_input_token_count += self.tracked_model.last_input_token_count
self.total_output_token_count += self.tracked_model.last_output_token_count self.total_output_token_count += self.tracked_model.last_output_token_count
console_outputs += f"| Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}" console_outputs += (
f"| Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}"
)
console_outputs += "]" console_outputs += "]"
self.logger.log(Text(console_outputs, style="dim"), level=1) self.logger.log(Text(console_outputs, style="dim"), level=1)

View File

@ -6,6 +6,7 @@ from typing import Set
from .utils import BASE_BUILTIN_MODULES from .utils import BASE_BUILTIN_MODULES
_BUILTIN_NAMES = set(vars(builtins)) _BUILTIN_NAMES = set(vars(builtins))
@ -141,9 +142,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
# Check that __init__ method takes no arguments # Check that __init__ method takes no arguments
if not cls.__init__.__qualname__ == "Tool.__init__": if not cls.__init__.__qualname__ == "Tool.__init__":
sig = inspect.signature(cls.__init__) sig = inspect.signature(cls.__init__)
non_self_params = list( non_self_params = list([arg_name for arg_name in sig.parameters.keys() if arg_name != "self"])
[arg_name for arg_name in sig.parameters.keys() if arg_name != "self"]
)
if len(non_self_params) > 0: if len(non_self_params) > 0:
errors.append( errors.append(
f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!" f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!"
@ -174,9 +173,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
# Check if the assignment is more complex than simple literals # Check if the assignment is more complex than simple literals
if not all( if not all(
isinstance( isinstance(val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set))
val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set)
)
for val in ast.walk(node.value) for val in ast.walk(node.value)
): ):
for target in node.targets: for target in node.targets:
@ -195,9 +192,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
# Run checks on all methods # Run checks on all methods
for node in class_node.body: for node in class_node.body:
if isinstance(node, ast.FunctionDef): if isinstance(node, ast.FunctionDef):
method_checker = MethodChecker( method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports)
class_level_checker.class_attributes, check_imports=check_imports
)
method_checker.visit(node) method_checker.visit(node)
errors += [f"- {node.name}: {error}" for error in method_checker.errors] errors += [f"- {node.name}: {error}" for error in method_checker.errors]

View File

@ -36,7 +36,6 @@ from huggingface_hub import (
upload_folder, upload_folder,
) )
from huggingface_hub.utils import RepositoryNotFoundError from huggingface_hub.utils import RepositoryNotFoundError
from packaging import version from packaging import version
from transformers.dynamic_module_utils import get_imports from transformers.dynamic_module_utils import get_imports
from transformers.utils import ( from transformers.utils import (
@ -52,6 +51,7 @@ from .tool_validation import MethodChecker, validate_tool_attributes
from .types import ImageType, handle_agent_input_types, handle_agent_output_types from .types import ImageType, handle_agent_input_types, handle_agent_output_types
from .utils import instance_to_source from .utils import instance_to_source
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if is_accelerate_available(): if is_accelerate_available():
@ -77,9 +77,7 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs) hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
return "model" return "model"
except RepositoryNotFoundError: except RepositoryNotFoundError:
raise EnvironmentError( raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
f"`{repo_id}` does not seem to be a valid repo identifier on the Hub."
)
except Exception: except Exception:
return "model" return "model"
except Exception: except Exception:
@ -109,9 +107,7 @@ def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
properties[param_name]["nullable"] = True properties[param_name]["nullable"] = True
for param_name in signature.parameters.keys(): for param_name in signature.parameters.keys():
if signature.parameters[param_name].default != inspect.Parameter.empty: if signature.parameters[param_name].default != inspect.Parameter.empty:
if ( if param_name not in properties: # this can happen if the param has no type hint but a default value
param_name not in properties
): # this can happen if the param has no type hint but a default value
properties[param_name] = {"nullable": True} properties[param_name] = {"nullable": True}
return properties return properties
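On the `_convert_type_hints_to_json_schema` hunk above: a parameter can carry a default value without a type hint and still appear in `inspect.signature`, which is the case the inline comment covers. A short illustrative check (the `example` function is a made-up stand-in, not from the codebase):

import inspect

def example(a: int, b=3):
    return a + b

for name, param in inspect.signature(example).parameters.items():
    has_default = param.default is not inspect.Parameter.empty
    has_hint = param.annotation is not inspect.Parameter.empty
    print(name, has_default, has_hint)
# a False True
# b True False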
@ -181,9 +177,7 @@ class Tool:
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead." f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
) )
for input_name, input_content in self.inputs.items(): for input_name, input_content in self.inputs.items():
assert isinstance(input_content, dict), ( assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
f"Input '{input_name}' should be a dictionary."
)
assert "type" in input_content and "description" in input_content, ( assert "type" in input_content and "description" in input_content, (
f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}." f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
) )
@ -348,15 +342,7 @@ class Tool:
imports = [] imports = []
for module in [tool_file]: for module in [tool_file]:
imports.extend(get_imports(module)) imports.extend(get_imports(module))
imports = list( imports = list(set([el for el in imports + ["smolagents"] if el not in sys.stdlib_module_names]))
set(
[
el
for el in imports + ["smolagents"]
if el not in sys.stdlib_module_names
]
)
)
with open(requirements_file, "w", encoding="utf-8") as f: with open(requirements_file, "w", encoding="utf-8") as f:
f.write("\n".join(imports) + "\n") f.write("\n".join(imports) + "\n")
@ -410,9 +396,7 @@ class Tool:
print(work_dir) print(work_dir)
with open(work_dir + "/tool.py", "r") as f: with open(work_dir + "/tool.py", "r") as f:
print("\n".join(f.readlines())) print("\n".join(f.readlines()))
logger.info( logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
)
return upload_folder( return upload_folder(
repo_id=repo_id, repo_id=repo_id,
commit_message=commit_message, commit_message=commit_message,
@ -592,9 +576,7 @@ class Tool:
self.name = name self.name = name
self.description = description self.description = description
self.client = Client(space_id, hf_token=token) self.client = Client(space_id, hf_token=token)
space_description = self.client.view_api( space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"]
return_format="dict", print_info=False
)["named_endpoints"]
# If api_name is not defined, take the first of the available APIs for this space # If api_name is not defined, take the first of the available APIs for this space
if api_name is None: if api_name is None:
@ -607,9 +589,7 @@ class Tool:
try: try:
space_description_api = space_description[api_name] space_description_api = space_description[api_name]
except KeyError: except KeyError:
raise KeyError( raise KeyError(f"Could not find specified {api_name=} among available api names.")
f"Could not find specified {api_name=} among available api names."
)
self.inputs = {} self.inputs = {}
for parameter in space_description_api["parameters"]: for parameter in space_description_api["parameters"]:
@ -683,8 +663,7 @@ class Tool:
self._gradio_tool = _gradio_tool self._gradio_tool = _gradio_tool
func_args = list(inspect.signature(_gradio_tool.run).parameters.items()) func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
self.inputs = { self.inputs = {
key: {"type": CONVERSION_DICT[value.annotation], "description": ""} key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
for key, value in func_args
} }
self.forward = self._gradio_tool.run self.forward = self._gradio_tool.run
@ -726,9 +705,7 @@ DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
""" """
def get_tool_description_with_args( def get_tool_description_with_args(tool: Tool, description_template: Optional[str] = None) -> str:
tool: Tool, description_template: Optional[str] = None
) -> str:
if description_template is None: if description_template is None:
description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
compiled_template = compile_jinja_template(description_template) compiled_template = compile_jinja_template(description_template)
@ -748,10 +725,7 @@ def compile_jinja_template(template):
raise ImportError("template requires jinja2 to be installed.") raise ImportError("template requires jinja2 to be installed.")
if version.parse(jinja2.__version__) < version.parse("3.1.0"): if version.parse(jinja2.__version__) < version.parse("3.1.0"):
raise ImportError( raise ImportError(f"template requires jinja2>=3.1.0 to be installed. Your version is {jinja2.__version__}.")
"template requires jinja2>=3.1.0 to be installed. Your version is "
f"{jinja2.__version__}."
)
def raise_exception(message): def raise_exception(message):
raise TemplateError(message) raise TemplateError(message)
@ -772,9 +746,7 @@ def launch_gradio_demo(tool: Tool):
try: try:
import gradio as gr import gradio as gr
except ImportError: except ImportError:
raise ImportError( raise ImportError("Gradio should be installed in order to launch a gradio demo.")
"Gradio should be installed in order to launch a gradio demo."
)
TYPE_TO_COMPONENT_CLASS_MAPPING = { TYPE_TO_COMPONENT_CLASS_MAPPING = {
"image": gr.Image, "image": gr.Image,
@ -791,9 +763,7 @@ def launch_gradio_demo(tool: Tool):
gradio_inputs = [] gradio_inputs = []
for input_name, input_details in tool.inputs.items(): for input_name, input_details in tool.inputs.items():
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[ input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]]
input_details["type"]
]
new_component = input_gradio_component_class(label=input_name) new_component = input_gradio_component_class(label=input_name)
gradio_inputs.append(new_component) gradio_inputs.append(new_component)
@ -922,14 +892,9 @@ class ToolCollection:
``` ```
""" """
_collection = get_collection(collection_slug, token=token) _collection = get_collection(collection_slug, token=token)
_hub_repo_ids = { _hub_repo_ids = {item.item_id for item in _collection.items if item.item_type == "space"}
item.item_id for item in _collection.items if item.item_type == "space"
}
tools = { tools = {Tool.from_hub(repo_id, token, trust_remote_code) for repo_id in _hub_repo_ids}
Tool.from_hub(repo_id, token, trust_remote_code)
for repo_id in _hub_repo_ids
}
return cls(tools) return cls(tools)
@ -986,9 +951,7 @@ def tool(tool_function: Callable) -> Tool:
""" """
parameters = get_json_schema(tool_function)["function"] parameters = get_json_schema(tool_function)["function"]
if "return" not in parameters: if "return" not in parameters:
raise TypeHintParsingException( raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
"Tool return type not found: make sure your function has a return type hint!"
)
class SimpleTool(Tool): class SimpleTool(Tool):
def __init__(self, name, description, inputs, output_type, function): def __init__(self, name, description, inputs, output_type, function):
@ -1007,9 +970,9 @@ def tool(tool_function: Callable) -> Tool:
function=tool_function, function=tool_function,
) )
original_signature = inspect.signature(tool_function) original_signature = inspect.signature(tool_function)
new_parameters = [ new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)] + list(
inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY) original_signature.parameters.values()
] + list(original_signature.parameters.values()) )
new_signature = original_signature.replace(parameters=new_parameters) new_signature = original_signature.replace(parameters=new_parameters)
simple_tool.forward.__signature__ = new_signature simple_tool.forward.__signature__ = new_signature
return simple_tool return simple_tool
@ -1082,9 +1045,7 @@ class PipelineTool(Tool):
if model is None: if model is None:
if self.default_checkpoint is None: if self.default_checkpoint is None:
raise ValueError( raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
"This tool does not implement a default checkpoint, you need to pass one."
)
model = self.default_checkpoint model = self.default_checkpoint
if pre_processor is None: if pre_processor is None:
pre_processor = model pre_processor = model
@ -1107,21 +1068,15 @@ class PipelineTool(Tool):
Instantiates the `pre_processor`, `model` and `post_processor` if necessary. Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
""" """
if isinstance(self.pre_processor, str): if isinstance(self.pre_processor, str):
self.pre_processor = self.pre_processor_class.from_pretrained( self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
self.pre_processor, **self.hub_kwargs
)
if isinstance(self.model, str): if isinstance(self.model, str):
self.model = self.model_class.from_pretrained( self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
self.model, **self.model_kwargs, **self.hub_kwargs
)
if self.post_processor is None: if self.post_processor is None:
self.post_processor = self.pre_processor self.post_processor = self.pre_processor
elif isinstance(self.post_processor, str): elif isinstance(self.post_processor, str):
self.post_processor = self.post_processor_class.from_pretrained( self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
self.post_processor, **self.hub_kwargs
)
if self.device is None: if self.device is None:
if self.device_map is not None: if self.device_map is not None:
@ -1165,12 +1120,8 @@ class PipelineTool(Tool):
encoded_inputs = self.encode(*args, **kwargs) encoded_inputs = self.encode(*args, **kwargs)
tensor_inputs = { tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)}
k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor) non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)}
}
non_tensor_inputs = {
k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)
}
encoded_inputs = send_to_device(tensor_inputs, self.device) encoded_inputs = send_to_device(tensor_inputs, self.device)
outputs = self.forward({**encoded_inputs, **non_tensor_inputs}) outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
View File
@ -27,6 +27,7 @@ from transformers.utils import (
is_vision_available, is_vision_available,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if is_vision_available(): if is_vision_available():
@ -113,9 +114,7 @@ class AgentImage(AgentType, ImageType):
elif isinstance(value, np.ndarray): elif isinstance(value, np.ndarray):
self._tensor = torch.from_numpy(value) self._tensor = torch.from_numpy(value)
else: else:
raise TypeError( raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
f"Unsupported type for {self.__class__.__name__}: {type(value)}"
)
def _ipython_display_(self, include=None, exclude=None): def _ipython_display_(self, include=None, exclude=None):
""" """
@ -264,9 +263,7 @@ if is_torch_available():
def handle_agent_input_types(*args, **kwargs): def handle_agent_input_types(*args, **kwargs):
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args] args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
kwargs = { kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()
}
return args, kwargs return args, kwargs
@ -279,9 +276,7 @@ def handle_agent_output_types(output, output_type=None):
# If the class does not have defined output, then we map according to the type # If the class does not have defined output, then we map according to the type
for _k, _v in INSTANCE_TYPE_MAPPING.items(): for _k, _v in INSTANCE_TYPE_MAPPING.items():
if isinstance(output, _k): if isinstance(output, _k):
if ( if _k is not object: # avoid converting to audio if torch is not installed
_k is not object
): # avoid converting to audio if torch is not installed
return _v(output) return _v(output)
return output return output
View File
@ -83,9 +83,7 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
try: try:
first_accolade_index = json_blob.find("{") first_accolade_index = json_blob.find("{")
last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1] last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace( json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'")
'\\"', "'"
)
json_data = json.loads(json_blob, strict=False) json_data = json.loads(json_blob, strict=False)
return json_data return json_data
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
@ -162,9 +160,7 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
MAX_LENGTH_TRUNCATE_CONTENT = 20000 MAX_LENGTH_TRUNCATE_CONTENT = 20000
def truncate_content( def truncate_content(content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT) -> str:
content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT
) -> str:
if len(content) <= max_length: if len(content) <= max_length:
return content return content
else: else:
@ -206,12 +202,8 @@ def is_same_method(method1, method2):
source2 = get_method_source(method2) source2 = get_method_source(method2)
# Remove method decorators if any # Remove method decorators if any
source1 = "\n".join( source1 = "\n".join(line for line in source1.split("\n") if not line.strip().startswith("@"))
line for line in source1.split("\n") if not line.strip().startswith("@") source2 = "\n".join(line for line in source2.split("\n") if not line.strip().startswith("@"))
)
source2 = "\n".join(
line for line in source2.split("\n") if not line.strip().startswith("@")
)
return source1 == source2 return source1 == source2
except (TypeError, OSError): except (TypeError, OSError):
@ -248,9 +240,7 @@ def instance_to_source(instance, base_cls=None):
for name, value in cls.__dict__.items() for name, value in cls.__dict__.items()
if not name.startswith("__") if not name.startswith("__")
and not callable(value) and not callable(value)
and not ( and not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value)
base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value
)
} }
for name, value in class_attrs.items(): for name, value in class_attrs.items():
@ -271,9 +261,7 @@ def instance_to_source(instance, base_cls=None):
for name, func in cls.__dict__.items() for name, func in cls.__dict__.items()
if callable(func) if callable(func)
and not ( and not (
base_cls base_cls and hasattr(base_cls, name) and getattr(base_cls, name).__code__.co_code == func.__code__.co_code
and hasattr(base_cls, name)
and getattr(base_cls, name).__code__.co_code == func.__code__.co_code
) )
} }
@ -284,9 +272,7 @@ def instance_to_source(instance, base_cls=None):
first_line = method_lines[0] first_line = method_lines[0]
indent = len(first_line) - len(first_line.lstrip()) indent = len(first_line) - len(first_line.lstrip())
method_lines = [line[indent:] for line in method_lines] method_lines = [line[indent:] for line in method_lines]
method_source = "\n".join( method_source = "\n".join([" " + line if line.strip() else line for line in method_lines])
[" " + line if line.strip() else line for line in method_lines]
)
class_lines.append(method_source) class_lines.append(method_source)
class_lines.append("") class_lines.append("")
View File
@ -28,13 +28,13 @@ from smolagents.agents import (
ToolCallingAgent, ToolCallingAgent,
) )
from smolagents.default_tools import PythonInterpreterTool from smolagents.default_tools import PythonInterpreterTool
from smolagents.tools import tool
from smolagents.types import AgentImage, AgentText
from smolagents.models import ( from smolagents.models import (
ChatMessage, ChatMessage,
ChatMessageToolCall, ChatMessageToolCall,
ChatMessageToolCallDefinition, ChatMessageToolCallDefinition,
) )
from smolagents.tools import tool
from smolagents.types import AgentImage, AgentText
from smolagents.utils import BASE_BUILTIN_MODULES from smolagents.utils import BASE_BUILTIN_MODULES
@ -44,9 +44,7 @@ def get_new_path(suffix="") -> str:
class FakeToolCallModel: class FakeToolCallModel:
def __call__( def __call__(self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None):
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
):
if len(messages) < 3: if len(messages) < 3:
return ChatMessage( return ChatMessage(
role="assistant", role="assistant",
@ -69,18 +67,14 @@ class FakeToolCallModel:
ChatMessageToolCall( ChatMessageToolCall(
id="call_1", id="call_1",
type="function", type="function",
function=ChatMessageToolCallDefinition( function=ChatMessageToolCallDefinition(name="final_answer", arguments={"answer": "7.2904"}),
name="final_answer", arguments={"answer": "7.2904"}
),
) )
], ],
) )
class FakeToolCallModelImage: class FakeToolCallModelImage:
def __call__( def __call__(self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None):
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
):
if len(messages) < 3: if len(messages) < 3:
return ChatMessage( return ChatMessage(
role="assistant", role="assistant",
@ -104,9 +98,7 @@ class FakeToolCallModelImage:
ChatMessageToolCall( ChatMessageToolCall(
id="call_1", id="call_1",
type="function", type="function",
function=ChatMessageToolCallDefinition( function=ChatMessageToolCallDefinition(name="final_answer", arguments="image.png"),
name="final_answer", arguments="image.png"
),
) )
], ],
) )
@ -271,17 +263,13 @@ print(result)
class AgentTests(unittest.TestCase): class AgentTests(unittest.TestCase):
def test_fake_single_step_code_agent(self): def test_fake_single_step_code_agent(self):
agent = CodeAgent( agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_single_step)
tools=[PythonInterpreterTool()], model=fake_code_model_single_step
)
output = agent.run("What is 2 multiplied by 3.6452?", single_step=True) output = agent.run("What is 2 multiplied by 3.6452?", single_step=True)
assert isinstance(output, str) assert isinstance(output, str)
assert "7.2904" in output assert "7.2904" in output
def test_fake_toolcalling_agent(self): def test_fake_toolcalling_agent(self):
agent = ToolCallingAgent( agent = ToolCallingAgent(tools=[PythonInterpreterTool()], model=FakeToolCallModel())
tools=[PythonInterpreterTool()], model=FakeToolCallModel()
)
output = agent.run("What is 2 multiplied by 3.6452?") output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, str) assert isinstance(output, str)
assert "7.2904" in output assert "7.2904" in output
@ -301,9 +289,7 @@ class AgentTests(unittest.TestCase):
""" """
return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png") return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png")
agent = ToolCallingAgent( agent = ToolCallingAgent(tools=[fake_image_generation_tool], model=FakeToolCallModelImage())
tools=[fake_image_generation_tool], model=FakeToolCallModelImage()
)
output = agent.run("Make me an image.") output = agent.run("Make me an image.")
assert isinstance(output, AgentImage) assert isinstance(output, AgentImage)
assert isinstance(agent.state["image.png"], Image.Image) assert isinstance(agent.state["image.png"], Image.Image)
@ -315,9 +301,7 @@ class AgentTests(unittest.TestCase):
assert output == 7.2904 assert output == 7.2904
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?" assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
assert agent.logs[3].tool_calls == [ assert agent.logs[3].tool_calls == [
ToolCall( ToolCall(name="python_interpreter", arguments="final_answer(7.2904)", id="call_3")
name="python_interpreter", arguments="final_answer(7.2904)", id="call_3"
)
] ]
def test_additional_args_added_to_task(self): def test_additional_args_added_to_task(self):
@ -351,9 +335,7 @@ class AgentTests(unittest.TestCase):
assert "Code execution failed at line 'print = 2' because of" in str(agent.logs) assert "Code execution failed at line 'print = 2' because of" in str(agent.logs)
def test_code_agent_syntax_error_show_offending_lines(self): def test_code_agent_syntax_error_show_offending_lines(self):
agent = CodeAgent( agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error)
tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error
)
output = agent.run("What is 2 multiplied by 3.6452?") output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText) assert isinstance(output, AgentText)
assert output == "got an error" assert output == "got an error"
@ -391,9 +373,7 @@ class AgentTests(unittest.TestCase):
def test_init_agent_with_different_toolsets(self): def test_init_agent_with_different_toolsets(self):
toolset_1 = [] toolset_1 = []
agent = CodeAgent(tools=toolset_1, model=fake_code_model) agent = CodeAgent(tools=toolset_1, model=fake_code_model)
assert ( assert len(agent.tools) == 1 # when no tools are provided, only the final_answer tool is added by default
len(agent.tools) == 1
) # when no tools are provided, only the final_answer tool is added by default
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()] toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
agent = CodeAgent(tools=toolset_2, model=fake_code_model) agent = CodeAgent(tools=toolset_2, model=fake_code_model)
@ -436,9 +416,7 @@ class AgentTests(unittest.TestCase):
assert "You can also give requests to team members." not in agent.system_prompt assert "You can also give requests to team members." not in agent.system_prompt
print("ok1") print("ok1")
assert "{{managed_agents_descriptions}}" not in agent.system_prompt assert "{{managed_agents_descriptions}}" not in agent.system_prompt
assert ( assert "You can also give requests to team members." in manager_agent.system_prompt
"You can also give requests to team members." in manager_agent.system_prompt
)
def test_code_agent_missing_import_triggers_advice_in_error_log(self): def test_code_agent_missing_import_triggers_advice_in_error_log(self):
agent = CodeAgent(tools=[], model=fake_code_model_import) agent = CodeAgent(tools=[], model=fake_code_model_import)
View File
@ -136,9 +136,7 @@ class TestDocs:
try: try:
code_blocks = [ code_blocks = [
( (
block.replace( block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", os.getenv("HF_TOKEN"))
"<YOUR_HUGGINGFACEHUB_API_TOKEN>", os.getenv("HF_TOKEN")
)
.replace("YOUR_ANTHROPIC_API_KEY", os.getenv("ANTHROPIC_API_KEY")) .replace("YOUR_ANTHROPIC_API_KEY", os.getenv("ANTHROPIC_API_KEY"))
.replace("{your_username}", "m-ric") .replace("{your_username}", "m-ric")
) )
@ -150,9 +148,7 @@ class TestDocs:
except SubprocessCallException as e: except SubprocessCallException as e:
pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}") pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}")
except Exception: except Exception:
pytest.fail( pytest.fail(f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}")
f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}"
)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def _setup(self): def _setup(self):
@ -174,6 +170,4 @@ def pytest_generate_tests(metafunc):
test_class.setup_class() test_class.setup_class()
# Parameterize with the markdown files # Parameterize with the markdown files
metafunc.parametrize( metafunc.parametrize("doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files])
"doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files]
)
View File
@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest import unittest
import pytest import pytest
from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool
@ -23,14 +24,10 @@ from .test_tools import ToolTesterMixin
class DefaultToolTests(unittest.TestCase): class DefaultToolTests(unittest.TestCase):
def test_visit_webpage(self): def test_visit_webpage(self):
arguments = { arguments = {"url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"}
"url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"
}
result = VisitWebpageTool()(arguments) result = VisitWebpageTool()(arguments)
assert isinstance(result, str) assert isinstance(result, str)
assert ( assert "* [About Wikipedia](/wiki/Wikipedia:About)" in result # Proper wikipedia pages have an About
"* [About Wikipedia](/wiki/Wikipedia:About)" in result
) # Proper wikipedia pages have an About
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
@ -59,12 +56,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
for _input, expected_input in zip(inputs, self.tool.inputs.values()): for _input, expected_input in zip(inputs, self.tool.inputs.values()):
input_type = expected_input["type"] input_type = expected_input["type"]
if isinstance(input_type, list): if isinstance(input_type, list):
_inputs.append( _inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
[
AGENT_TYPE_MAPPING[_input_type](_input)
for _input_type in input_type
]
)
else: else:
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) _inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
View File
@ -26,6 +26,7 @@ from smolagents.types import AGENT_TYPE_MAPPING
from .test_tools import ToolTesterMixin from .test_tools import ToolTesterMixin
if is_torch_available(): if is_torch_available():
import torch import torch
@ -45,11 +46,7 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
def create_inputs(self): def create_inputs(self):
inputs_text = {"answer": "Text input"} inputs_text = {"answer": "Text input"}
inputs_image = { inputs_image = {"answer": Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png").resize((512, 512))}
"answer": Image.open(
Path(get_tests_dir("fixtures")) / "000000039769.png"
).resize((512, 512))
}
inputs_audio = {"answer": torch.Tensor(np.ones(3000))} inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio} return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio}
View File
@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest
import json import json
import unittest
from typing import Optional from typing import Optional
from smolagents import models, tool, ChatMessage, HfApiModel, TransformersModel from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
class ModelTests(unittest.TestCase): class ModelTests(unittest.TestCase):
@ -33,12 +33,7 @@ class ModelTests(unittest.TestCase):
""" """
return "The weather is UNGODLY with torrential rains and temperatures below -10°C" return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
assert ( assert "nullable" in models.get_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"]
"nullable"
in models.get_json_schema(get_weather)["function"]["parameters"][
"properties"
]["celsius"]
)
def test_chatmessage_has_model_dumps_json(self): def test_chatmessage_has_model_dumps_json(self):
message = ChatMessage("user", "Hello!") message = ChatMessage("user", "Hello!")
View File
@ -43,9 +43,7 @@ class FakeLLMModel:
ChatMessageToolCall( ChatMessageToolCall(
id="fake_id", id="fake_id",
type="function", type="function",
function=ChatMessageToolCallDefinition( function=ChatMessageToolCallDefinition(name="final_answer", arguments={"answer": "image"}),
name="final_answer", arguments={"answer": "image"}
),
) )
], ],
) )
@ -122,9 +120,7 @@ class MonitoringTester(unittest.TestCase):
) )
agent.run("Fake task") agent.run("Fake task")
self.assertEqual( self.assertEqual(agent.monitor.total_input_token_count, 20) # Should have done two monitoring callbacks
agent.monitor.total_input_token_count, 20
) # Should have done two monitoring callbacks
self.assertEqual(agent.monitor.total_output_token_count, 0) self.assertEqual(agent.monitor.total_output_token_count, 0)
def test_streaming_agent_text_output(self): def test_streaming_agent_text_output(self):
View File
@ -55,10 +55,7 @@ class PythonInterpreterTester(unittest.TestCase):
code = "print = '3'" code = "print = '3'"
with pytest.raises(InterpreterError) as e: with pytest.raises(InterpreterError) as e:
evaluate_python_code(code, {"print": print}, state={}) evaluate_python_code(code, {"print": print}, state={})
assert ( assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e)
"Cannot assign to name 'print': doing this would erase the existing tool!"
in str(e)
)
def test_subscript_call(self): def test_subscript_call(self):
code = """def foo(x,y):return x*y\n\ndef boo(y):\n\treturn y**3\nfun = [foo, boo]\nresult_foo = fun[0](4,2)\nresult_boo = fun[1](4)""" code = """def foo(x,y):return x*y\n\ndef boo(y):\n\treturn y**3\nfun = [foo, boo]\nresult_foo = fun[0](4,2)\nresult_boo = fun[1](4)"""
@ -92,9 +89,7 @@ class PythonInterpreterTester(unittest.TestCase):
state = {"x": 3} state = {"x": 3}
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state) result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
self.assertDictEqual(result, {"x": 3, "y": 5}) self.assertDictEqual(result, {"x": 3, "y": 5})
self.assertDictEqual( self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
)
def test_evaluate_expression(self): def test_evaluate_expression(self):
code = "x = 3\ny = 5" code = "x = 3\ny = 5"
@ -110,9 +105,7 @@ class PythonInterpreterTester(unittest.TestCase):
result, _ = evaluate_python_code(code, {}, state=state) result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment. # evaluate returns the value of the last assignment.
assert result == "This is x: 3." assert result == "This is x: 3."
self.assertDictEqual( self.assertDictEqual(state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""})
state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""}
)
def test_evaluate_if(self): def test_evaluate_if(self):
code = "if x <= 3:\n y = 2\nelse:\n y = 5" code = "if x <= 3:\n y = 2\nelse:\n y = 5"
@ -153,15 +146,11 @@ class PythonInterpreterTester(unittest.TestCase):
state = {"x": 3} state = {"x": 3}
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state) result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5 assert result == 5
self.assertDictEqual( self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
)
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)" code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
state = {} state = {}
evaluate_python_code( evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
code, {"min": min, "print": print, "round": round}, state=state
)
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62} assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
def test_subscript_string_with_string_index_raises_appropriate_error(self): def test_subscript_string_with_string_index_raises_appropriate_error(self):
@ -317,9 +306,7 @@ print(check_digits)
assert result == {0: 0, 1: 1, 2: 4} assert result == {0: 0, 1: 1, 2: 4}
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}" code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
result, _ = evaluate_python_code( result, _ = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
code, {"print": print}, state={}, authorized_imports=["pandas"]
)
assert result == {102: "b"} assert result == {102: "b"}
code = """ code = """
@ -367,9 +354,7 @@ else:
best_city = "Manhattan" best_city = "Manhattan"
best_city best_city
""" """
result, _ = evaluate_python_code( result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
)
assert result == "Brooklyn" assert result == "Brooklyn"
code = """if d > e and a < b: code = """if d > e and a < b:
@ -380,9 +365,7 @@ else:
best_city = "Manhattan" best_city = "Manhattan"
best_city best_city
""" """
result, _ = evaluate_python_code( result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
)
assert result == "Sacramento" assert result == "Sacramento"
def test_if_conditions(self): def test_if_conditions(self):
@ -398,9 +381,7 @@ if char.isalpha():
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 2.0 assert result == 2.0
code = ( code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
"from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
)
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "lose" assert result == "lose"
@ -434,14 +415,10 @@ if char.isalpha():
# Test submodules are handled properly, thus not raising error # Test submodules are handled properly, thus not raising error
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()" code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
result, _ = evaluate_python_code( result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
)
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()" code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
result, _ = evaluate_python_code( result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
)
def test_additional_imports(self): def test_additional_imports(self):
code = "import numpy as np" code = "import numpy as np"
@ -613,9 +590,7 @@ except ValueError as e:
def test_types_as_objects(self): def test_types_as_objects(self):
code = "type_a = float(2); type_b = str; type_c = int" code = "type_a = float(2); type_b = str; type_c = int"
state = {} state = {}
result, is_final_answer = evaluate_python_code( result, is_final_answer = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
code, {"float": float, "str": str, "int": int}, state=state
)
assert result is int assert result is int
def test_tuple_id(self): def test_tuple_id(self):
@ -733,9 +708,7 @@ while True:
break break
i""" i"""
result, is_final_answer = evaluate_python_code( result, is_final_answer = evaluate_python_code(code, {"print": print, "round": round}, state={})
code, {"print": print, "round": round}, state={}
)
assert result == 3 assert result == 3
assert not is_final_answer assert not is_final_answer
@ -781,9 +754,7 @@ out = [i for sublist in all_res for i in sublist]
out[:10] out[:10]
""" """
state = {} state = {}
result, is_final_answer = evaluate_python_code( result, is_final_answer = evaluate_python_code(code, {"print": print, "range": range}, state=state)
code, {"print": print, "range": range}, state=state
)
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3] assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
def test_pandas(self): def test_pandas(self):
@ -798,9 +769,7 @@ parts_with_5_set_count = df[df['SetCount'] == 5.0]
parts_with_5_set_count[['Quantity', 'SetCount']].values[1] parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
""" """
state = {} state = {}
result, _ = evaluate_python_code( result, _ = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"])
code, {}, state=state, authorized_imports=["pandas"]
)
assert np.array_equal(result, [-1, 5]) assert np.array_equal(result, [-1, 5])
code = """ code = """
@ -811,9 +780,7 @@ df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
# Filter the DataFrame to get only the rows with outdated atomic numbers # Filter the DataFrame to get only the rows with outdated atomic numbers
filtered_df = df.loc[df['AtomicNumber'].isin([104])] filtered_df = df.loc[df['AtomicNumber'].isin([104])]
""" """
result, _ = evaluate_python_code( result, _ = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
code, {"print": print}, state={}, authorized_imports=["pandas"]
)
assert np.array_equal(result.values[0], [104, 1]) assert np.array_equal(result.values[0], [104, 1])
# Test groupby # Test groupby
@ -825,9 +792,7 @@ data = pd.DataFrame.from_dict([
]) ])
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean() survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
""" """
result, _ = evaluate_python_code( result, _ = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
code, {}, state={}, authorized_imports=["pandas"]
)
assert result.values[1] == 0.5 assert result.values[1] == 0.5
# Test loc and iloc # Test loc and iloc
@ -839,11 +804,9 @@ data = pd.DataFrame.from_dict([
]) ])
survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean() survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean() survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
survival_rate_sorted = data.sort_values(by='Survived', ascending=False).iloc[0] survival_rate_sorted = data.sort_values(by='Survived', ascending=False).iloc[0]
""" """
result, _ = evaluate_python_code( result, _ = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
code, {}, state={}, authorized_imports=["pandas"]
)
def test_starred(self): def test_starred(self):
code = """ code = """
@ -864,9 +827,7 @@ coords_barcelona = (41.3869, 2.1660)
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona) distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
""" """
result, _ = evaluate_python_code( result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
code, {"print": print, "map": map}, state={}, authorized_imports=["math"]
)
assert round(result, 1) == 622395.4 assert round(result, 1) == 622395.4
def test_for(self): def test_for(self):
View File
@ -16,7 +16,7 @@ import unittest
from pathlib import Path from pathlib import Path
from textwrap import dedent from textwrap import dedent
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from unittest.mock import patch, MagicMock from unittest.mock import MagicMock, patch
import mcp import mcp
import numpy as np import numpy as np
@ -32,6 +32,7 @@ from smolagents.types import (
AgentText, AgentText,
) )
if is_torch_available(): if is_torch_available():
import torch import torch
@ -48,9 +49,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
if input_type == "string": if input_type == "string":
inputs[input_name] = "Text input" inputs[input_name] = "Text input"
elif input_type == "image": elif input_type == "image":
inputs[input_name] = Image.open( inputs[input_name] = Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png").resize((512, 512))
Path(get_tests_dir("fixtures")) / "000000039769.png"
).resize((512, 512))
elif input_type == "audio": elif input_type == "audio":
inputs[input_name] = np.ones(3000) inputs[input_name] = np.ones(3000)
else: else:
@ -224,9 +223,7 @@ class ToolTests(unittest.TestCase):
class FailTool(Tool): class FailTool(Tool):
name = "specific" name = "specific"
description = "test description" description = "test description"
inputs = { inputs = {"string_input": {"type": "string", "description": "input description"}}
"string_input": {"type": "string", "description": "input description"}
}
output_type = "string" output_type = "string"
def __init__(self, url): def __init__(self, url):
@ -248,9 +245,7 @@ class ToolTests(unittest.TestCase):
class FailTool(Tool): class FailTool(Tool):
name = "specific" name = "specific"
description = "test description" description = "test description"
inputs = { inputs = {"string_input": {"type": "string", "description": "input description"}}
"string_input": {"type": "string", "description": "input description"}
}
output_type = "string" output_type = "string"
def useless_method(self): def useless_method(self):
@ -269,9 +264,7 @@ class ToolTests(unittest.TestCase):
class SuccessTool(Tool): class SuccessTool(Tool):
name = "specific" name = "specific"
description = "test description" description = "test description"
inputs = { inputs = {"string_input": {"type": "string", "description": "input description"}}
"string_input": {"type": "string", "description": "input description"}
}
output_type = "string" output_type = "string"
def useless_method(self): def useless_method(self):
@ -300,9 +293,7 @@ class ToolTests(unittest.TestCase):
}, },
} }
def forward( def forward(self, location: str, celsius: Optional[bool] = False) -> str:
self, location: str, celsius: Optional[bool] = False
) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C" return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
GetWeatherTool() GetWeatherTool()
@ -340,9 +331,7 @@ class ToolTests(unittest.TestCase):
} }
output_type = "string" output_type = "string"
def forward( def forward(self, location: str, celsius: Optional[bool] = False) -> str:
self, location: str, celsius: Optional[bool] = False
) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C" return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
GetWeatherTool() GetWeatherTool()
@ -410,9 +399,7 @@ def mock_smolagents_adapter():
class TestToolCollection: class TestToolCollection:
def test_from_mcp( def test_from_mcp(self, mock_server_parameters, mock_mcp_adapt, mock_smolagents_adapter):
self, mock_server_parameters, mock_mcp_adapt, mock_smolagents_adapter
):
with ToolCollection.from_mcp(mock_server_parameters) as tool_collection: with ToolCollection.from_mcp(mock_server_parameters) as tool_collection:
assert isinstance(tool_collection, ToolCollection) assert isinstance(tool_collection, ToolCollection)
assert len(tool_collection.tools) == 2 assert len(tool_collection.tools) == 2
@ -440,9 +427,5 @@ class TestToolCollection:
with ToolCollection.from_mcp(mcp_server_params) as tool_collection: with ToolCollection.from_mcp(mcp_server_params) as tool_collection:
assert len(tool_collection.tools) == 1, "Expected 1 tool" assert len(tool_collection.tools) == 1, "Expected 1 tool"
assert tool_collection.tools[0].name == "echo_tool", ( assert tool_collection.tools[0].name == "echo_tool", "Expected tool name to be 'echo_tool'"
"Expected tool name to be 'echo_tool'" assert tool_collection.tools[0](text="Hello") == "Hello", "Expected tool to echo the input text"
)
assert tool_collection.tools[0](text="Hello") == "Hello", (
"Expected tool to echo the input text"
)
View File
@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest import unittest
import pytest import pytest
from smolagents.utils import parse_code_blobs from smolagents.utils import parse_code_blobs
View File
@ -16,6 +16,7 @@
from pathlib import Path from pathlib import Path
ROOT = Path(__file__).parent.parent ROOT = Path(__file__).parent.parent
TESTS_FOLDER = ROOT / "tests" TESTS_FOLDER = ROOT / "tests"
@ -37,11 +38,7 @@ def check_tests_in_ci():
if path.name.startswith("test_") if path.name.startswith("test_")
] ]
ci_workflow_file_content = CI_WORKFLOW_FILE.read_text() ci_workflow_file_content = CI_WORKFLOW_FILE.read_text()
missing_test_files = [ missing_test_files = [test_file for test_file in test_files if test_file not in ci_workflow_file_content]
test_file
for test_file in test_files
if test_file not in ci_workflow_file_content
]
if missing_test_files: if missing_test_files:
print( print(
"❌ Some test files seem to be ignored in the CI:\n" "❌ Some test files seem to be ignored in the CI:\n"