Add linter rules + apply make style (#255)

* Add linter rules + apply make style
This commit is contained in:
Lucain 2025-01-18 19:01:15 +01:00 committed by GitHub
parent 5aa0f2b53d
commit 6e1373a324
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 378 additions and 944 deletions

View File

@ -181,6 +181,7 @@
"import datasets\n",
"import pandas as pd\n",
"\n",
"\n",
"eval_ds = datasets.load_dataset(\"m-ric/smol_agents_benchmark\")[\"test\"]\n",
"pd.DataFrame(eval_ds)"
]
@ -199,26 +200,28 @@
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import json\n",
"import os\n",
"import re\n",
"import string\n",
"import time\n",
"import warnings\n",
"from tqdm import tqdm\n",
"from typing import List\n",
"\n",
"from dotenv import load_dotenv\n",
"from tqdm import tqdm\n",
"\n",
"from smolagents import (\n",
" GoogleSearchTool,\n",
" CodeAgent,\n",
" ToolCallingAgent,\n",
" HfApiModel,\n",
" AgentError,\n",
" VisitWebpageTool,\n",
" CodeAgent,\n",
" GoogleSearchTool,\n",
" HfApiModel,\n",
" PythonInterpreterTool,\n",
" ToolCallingAgent,\n",
" VisitWebpageTool,\n",
")\n",
"from smolagents.agents import ActionStep\n",
"from dotenv import load_dotenv\n",
"\n",
"\n",
"load_dotenv()\n",
"os.makedirs(\"output\", exist_ok=True)\n",
@ -231,9 +234,7 @@
" return str(obj)\n",
"\n",
"\n",
"def answer_questions(\n",
" eval_ds, file_name, agent, model_id, action_type, is_vanilla_llm=False\n",
"):\n",
"def answer_questions(eval_ds, file_name, agent, model_id, action_type, is_vanilla_llm=False):\n",
" answered_questions = []\n",
" if os.path.exists(file_name):\n",
" with open(file_name, \"r\") as f:\n",
@ -365,23 +366,18 @@
" ma_elems = split_string(model_answer)\n",
"\n",
" if len(gt_elems) != len(ma_elems): # check length is the same\n",
" warnings.warn(\n",
" \"Answer lists have different lengths, returning False.\", UserWarning\n",
" )\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
" return False\n",
"\n",
" comparisons = []\n",
" for ma_elem, gt_elem in zip(\n",
" ma_elems, gt_elems\n",
" ): # compare each element as float or str\n",
" for ma_elem, gt_elem in zip(ma_elems, gt_elems): # compare each element as float or str\n",
" if is_float(gt_elem):\n",
" normalized_ma_elem = normalize_number_str(ma_elem)\n",
" comparisons.append(normalized_ma_elem == float(gt_elem))\n",
" else:\n",
" # we do not remove punct since comparisons can include punct\n",
" comparisons.append(\n",
" normalize_str(ma_elem, remove_punct=False)\n",
" == normalize_str(gt_elem, remove_punct=False)\n",
" normalize_str(ma_elem, remove_punct=False) == normalize_str(gt_elem, remove_punct=False)\n",
" )\n",
" return all(comparisons)\n",
"\n",
@ -441,9 +437,7 @@
" action_type = \"vanilla\"\n",
" llm = HfApiModel(model_id)\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" answer_questions(\n",
" eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True\n",
" )"
" answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)"
]
},
{
@ -461,6 +455,7 @@
"source": [
"from smolagents import LiteLLMModel\n",
"\n",
"\n",
"litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n",
"\n",
"for model_id in litellm_model_ids:\n",
@ -492,9 +487,7 @@
" action_type = \"vanilla\"\n",
" llm = LiteLLMModel(model_id)\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" answer_questions(\n",
" eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True\n",
" )"
" answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)"
]
},
{
@ -556,9 +549,11 @@
}
],
"source": [
"import pandas as pd\n",
"import glob\n",
"\n",
"import pandas as pd\n",
"\n",
"\n",
"res = []\n",
"for file_path in glob.glob(\"output/*.jsonl\"):\n",
" data = []\n",
@ -595,11 +590,7 @@
"\n",
"result_df[\"correct\"] = result_df.apply(get_correct, axis=1)\n",
"\n",
"result_df = (\n",
" (result_df.groupby([\"model_id\", \"source\", \"action_type\"])[[\"correct\"]].mean() * 100)\n",
" .round(1)\n",
" .reset_index()\n",
")"
"result_df = (result_df.groupby([\"model_id\", \"source\", \"action_type\"])[[\"correct\"]].mean() * 100).round(1).reset_index()"
]
},
{
@ -895,6 +886,7 @@
"import pandas as pd\n",
"from matplotlib.legend_handler import HandlerTuple # Added import\n",
"\n",
"\n",
"# Assuming pivot_df is your original dataframe\n",
"models = pivot_df[\"model_id\"].unique()\n",
"sources = pivot_df[\"source\"].unique()\n",
@ -961,14 +953,10 @@
"handles, labels = ax.get_legend_handles_labels()\n",
"unique_sources = sources\n",
"legend_elements = [\n",
" (handles[i * 2], handles[i * 2 + 1], labels[i * 2].replace(\" (Agent)\", \"\"))\n",
" for i in range(len(unique_sources))\n",
" (handles[i * 2], handles[i * 2 + 1], labels[i * 2].replace(\" (Agent)\", \"\")) for i in range(len(unique_sources))\n",
"]\n",
"custom_legend = ax.legend(\n",
" [\n",
" (agent_handle, vanilla_handle)\n",
" for agent_handle, vanilla_handle, _ in legend_elements\n",
" ],\n",
" [(agent_handle, vanilla_handle) for agent_handle, vanilla_handle, _ in legend_elements],\n",
" [label for _, _, label in legend_elements],\n",
" handler_map={tuple: HandlerTuple(ndivide=None)},\n",
" bbox_to_anchor=(1.05, 1),\n",
@ -1006,9 +994,7 @@
" # Start the matrix environment with 4 columns\n",
" # l for left-aligned model and task, c for centered numbers\n",
" mathjax_table = \"\\\\begin{array}{llcc}\\n\"\n",
" mathjax_table += (\n",
" \"\\\\text{Model} & \\\\text{Task} & \\\\text{Agent} & \\\\text{Vanilla} \\\\\\\\\\n\"\n",
" )\n",
" mathjax_table += \"\\\\text{Model} & \\\\text{Task} & \\\\text{Agent} & \\\\text{Vanilla} \\\\\\\\\\n\"\n",
" mathjax_table += \"\\\\hline\\n\"\n",
"\n",
" # Sort the DataFrame by model_id and source\n",
@ -1033,9 +1019,7 @@
" model_display = \"\\\\;\"\n",
"\n",
" # Add the data row\n",
" mathjax_table += (\n",
" f\"{model_display} & {source} & {row['agent']} & {row['vanilla']} \\\\\\\\\\n\"\n",
" )\n",
" mathjax_table += f\"{model_display} & {source} & {row['agent']} & {row['vanilla']} \\\\\\\\\\n\"\n",
"\n",
" current_model = model\n",
"\n",

View File

@ -1,7 +1,9 @@
from smolagents import Tool, CodeAgent, HfApiModel
from smolagents.default_tools import VisitWebpageTool
from dotenv import load_dotenv
from smolagents import CodeAgent, HfApiModel, Tool
from smolagents.default_tools import VisitWebpageTool
load_dotenv()
@ -16,10 +18,11 @@ class GetCatImageTool(Tool):
self.url = "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"
def forward(self):
from PIL import Image
import requests
from io import BytesIO
import requests
from PIL import Image
response = requests.get(self.url)
return Image.open(BytesIO(response.content))
@ -46,4 +49,5 @@ agent.run(
# Try the agent in a Gradio UI
from smolagents import GradioUI
GradioUI(agent).launch()

View File

@ -1,4 +1,5 @@
from smolagents import CodeAgent, HfApiModel, GradioUI
from smolagents import CodeAgent, GradioUI, HfApiModel
agent = CodeAgent(tools=[], model=HfApiModel(), max_steps=4, verbosity_level=1)

View File

@ -1,24 +1,22 @@
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from smolagents import (
CodeAgent,
DuckDuckGoSearchTool,
VisitWebpageTool,
HfApiModel,
ManagedAgent,
ToolCallingAgent,
HfApiModel,
VisitWebpageTool,
)
# Let's setup the instrumentation first
trace_provider = TracerProvider()
trace_provider.add_span_processor(
SimpleSpanProcessor(OTLPSpanExporter("http://0.0.0.0:6006/v1/traces"))
)
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter("http://0.0.0.0:6006/v1/traces")))
SmolagentsInstrumentor().instrument(tracer_provider=trace_provider, skip_dep_check=True)
@ -39,6 +37,4 @@ manager_agent = CodeAgent(
model=model,
managed_agents=[managed_agent],
)
manager_agent.run(
"If the US keeps it 2024 growth rate, how many years would it take for the GDP to double?"
)
manager_agent.run("If the US keeps it 2024 growth rate, how many years would it take for the GDP to double?")

View File

@ -8,13 +8,10 @@ from langchain_community.retrievers import BM25Retriever
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
knowledge_base = knowledge_base.filter(
lambda row: row["source"].startswith("huggingface/transformers")
)
knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))
source_docs = [
Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
for doc in knowledge_base
Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) for doc in knowledge_base
]
text_splitter = RecursiveCharacterTextSplitter(
@ -51,14 +48,12 @@ class RetrieverTool(Tool):
query,
)
return "\nRetrieved documents:\n" + "".join(
[
f"\n\n===== Document {str(i)} =====\n" + doc.page_content
for i, doc in enumerate(docs)
]
[f"\n\n===== Document {str(i)} =====\n" + doc.page_content for i, doc in enumerate(docs)]
)
from smolagents import HfApiModel, CodeAgent
from smolagents import CodeAgent, HfApiModel
retriever_tool = RetrieverTool(docs_processed)
agent = CodeAgent(
@ -68,9 +63,7 @@ agent = CodeAgent(
verbosity_level=2,
)
agent_output = agent.run(
"For a transformers model training, which is slower, the forward or the backward pass?"
)
agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")
print("Final output:")
print(agent_output)

View File

@ -1,16 +1,17 @@
from sqlalchemy import (
create_engine,
MetaData,
Table,
Column,
String,
Integer,
Float,
Integer,
MetaData,
String,
Table,
create_engine,
insert,
inspect,
text,
)
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
@ -40,9 +41,7 @@ for row in rows:
inspector = inspect(engine)
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]
table_description = "Columns:\n" + "\n".join(
[f" - {name}: {col_type}" for name, col_type in columns_info]
)
table_description = "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
print(table_description)
from smolagents import tool
@ -72,6 +71,7 @@ def sql_engine(query: str) -> str:
from smolagents import CodeAgent, HfApiModel
agent = CodeAgent(
tools=[sql_engine],
model=HfApiModel("meta-llama/Meta-Llama-3.1-8B-Instruct"),

View File

@ -1,7 +1,9 @@
from smolagents.agents import ToolCallingAgent
from smolagents import tool, LiteLLMModel
from typing import Optional
from smolagents import LiteLLMModel, tool
from smolagents.agents import ToolCallingAgent
# Choose which LLM engine to use!
# model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
# model = TransformersModel(model_id="meta-llama/Llama-3.2-2B-Instruct")

View File

@ -13,8 +13,10 @@ Usage:
import os
from mcp import StdioServerParameters
from smolagents import CodeAgent, HfApiModel, ToolCollection
mcp_server_params = StdioServerParameters(
command="uvx",
args=["--quiet", "pubmedmcp@0.1.3"],

View File

@ -1,7 +1,9 @@
from smolagents.agents import ToolCallingAgent
from smolagents import tool, LiteLLMModel
from typing import Optional
from smolagents import LiteLLMModel, tool
from smolagents.agents import ToolCallingAgent
model = LiteLLMModel(
model_id="ollama_chat/llama3.2",
api_base="http://localhost:11434", # replace with remote open-ai compatible server if necessary

View File

@ -60,9 +60,18 @@ dev = [
addopts = "-sv --durations=0"
[tool.ruff]
lint.ignore = ["F403"]
line-length = 119
lint.ignore = [
"F403", # undefined-local-with-import-star
"E501", # line-too-long
]
lint.select = ["E", "F", "I", "W"]
[tool.ruff.lint.per-file-ignores]
"examples/*" = [
"E402", # module-import-not-at-top-of-file
]
[tool.ruff.lint.isort]
known-first-party = ["smolagents"]
lines-after-imports = 2

View File

@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
from transformers.utils import _LazyModule
from transformers.utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .agents import *
from .default_tools import *

View File

@ -16,18 +16,17 @@
# limitations under the License.
import time
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from enum import IntEnum
from rich import box
from rich.console import Group
from rich.console import Console, Group
from rich.panel import Panel
from rich.rule import Rule
from rich.syntax import Syntax
from rich.text import Text
from rich.console import Console
from .default_tools import FinalAnswerTool, TOOL_MAPPING
from .default_tools import TOOL_MAPPING, FinalAnswerTool
from .e2b_executor import E2BExecutor
from .local_python_executor import (
BASE_BUILTIN_MODULES,
@ -112,20 +111,11 @@ class SystemPromptStep(AgentStepLog):
system_prompt: str
def get_tool_descriptions(
tools: Dict[str, Tool], tool_description_template: str
) -> str:
return "\n".join(
[
get_tool_description_with_args(tool, tool_description_template)
for tool in tools.values()
]
)
def get_tool_descriptions(tools: Dict[str, Tool], tool_description_template: str) -> str:
return "\n".join([get_tool_description_with_args(tool, tool_description_template) for tool in tools.values()])
def format_prompt_with_tools(
tools: Dict[str, Tool], prompt_template: str, tool_description_template: str
) -> str:
def format_prompt_with_tools(tools: Dict[str, Tool], prompt_template: str, tool_description_template: str) -> str:
tool_descriptions = get_tool_descriptions(tools, tool_description_template)
prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
if "{{tool_names}}" in prompt:
@ -159,9 +149,7 @@ def format_prompt_with_managed_agents_descriptions(
f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'"
)
if len(managed_agents.keys()) > 0:
return prompt_template.replace(
agent_descriptions_placeholder, show_agents_descriptions(managed_agents)
)
return prompt_template.replace(agent_descriptions_placeholder, show_agents_descriptions(managed_agents))
else:
return prompt_template.replace(agent_descriptions_placeholder, "")
@ -214,9 +202,7 @@ class MultiStepAgent:
self.model = model
self.system_prompt_template = system_prompt
self.tool_description_template = (
tool_description_template
if tool_description_template
else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
tool_description_template if tool_description_template else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
)
self.max_steps = max_steps
self.tool_parser = tool_parser
@ -231,10 +217,7 @@ class MultiStepAgent:
self.tools = {tool.name: tool for tool in tools}
if add_base_tools:
for tool_name, tool_class in TOOL_MAPPING.items():
if (
tool_name != "python_interpreter"
or self.__class__.__name__ == "ToolCallingAgent"
):
if tool_name != "python_interpreter" or self.__class__.__name__ == "ToolCallingAgent":
self.tools[tool_name] = tool_class()
self.tools["final_answer"] = FinalAnswerTool()
@ -253,15 +236,11 @@ class MultiStepAgent:
self.system_prompt_template,
self.tool_description_template,
)
self.system_prompt = format_prompt_with_managed_agents_descriptions(
self.system_prompt, self.managed_agents
)
self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
return self.system_prompt
def write_inner_memory_from_logs(
self, summary_mode: Optional[bool] = False
) -> List[Dict[str, str]]:
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
"""
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
that can be used as input to the LLM.
@ -355,10 +334,7 @@ class MultiStepAgent:
return memory
def get_succinct_logs(self):
return [
{key: value for key, value in log.items() if key != "agent_memory"}
for log in self.logs
]
return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs]
def extract_action(self, llm_output: str, split_token: str) -> Tuple[str, str]:
"""
@ -402,9 +378,7 @@ class MultiStepAgent:
except Exception as e:
return f"Error in generating final LLM output:\n{e}"
def execute_tool_call(
self, tool_name: str, arguments: Union[Dict[str, str], str]
) -> Any:
def execute_tool_call(self, tool_name: str, arguments: Union[Dict[str, str], str]) -> Any:
"""
Execute tool with the provided input and returns the result.
This method replaces arguments with the actual values from the state if they refer to state variables.
@ -423,9 +397,7 @@ class MultiStepAgent:
if tool_name in self.managed_agents:
observation = available_tools[tool_name].__call__(arguments)
else:
observation = available_tools[tool_name].__call__(
arguments, sanitize_inputs_outputs=True
)
observation = available_tools[tool_name].__call__(arguments, sanitize_inputs_outputs=True)
elif isinstance(arguments, dict):
for key, value in arguments.items():
if isinstance(value, str) and value in self.state:
@ -433,18 +405,14 @@ class MultiStepAgent:
if tool_name in self.managed_agents:
observation = available_tools[tool_name].__call__(**arguments)
else:
observation = available_tools[tool_name].__call__(
**arguments, sanitize_inputs_outputs=True
)
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
else:
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
raise AgentExecutionError(error_msg)
return observation
except Exception as e:
if tool_name in self.tools:
tool_description = get_tool_description_with_args(
available_tools[tool_name]
)
tool_description = get_tool_description_with_args(available_tools[tool_name])
error_msg = (
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
f"As a reminder, this tool's description is the following:\n{tool_description}"
@ -544,10 +512,7 @@ You have been provided with these additional arguments, that you can access usin
step_start_time = time.time()
step_log = ActionStep(step=self.step_number, start_time=step_start_time)
try:
if (
self.planning_interval is not None
and self.step_number % self.planning_interval == 0
):
if self.planning_interval is not None and self.step_number % self.planning_interval == 0:
self.planning_step(
task,
is_first_step=(self.step_number == 0),
@ -600,10 +565,7 @@ You have been provided with these additional arguments, that you can access usin
step_start_time = time.time()
step_log = ActionStep(step=self.step_number, start_time=step_start_time)
try:
if (
self.planning_interval is not None
and self.step_number % self.planning_interval == 0
):
if self.planning_interval is not None and self.step_number % self.planning_interval == 0:
self.planning_step(
task,
is_first_step=(self.step_number == 0),
@ -668,9 +630,7 @@ You have been provided with these additional arguments, that you can access usin
Now begin!""",
}
answer_facts = self.model(
[message_prompt_facts, message_prompt_task]
).content
answer_facts = self.model([message_prompt_facts, message_prompt_task]).content
message_system_prompt_plan = {
"role": MessageRole.SYSTEM,
@ -680,12 +640,8 @@ Now begin!""",
"role": MessageRole.USER,
"content": USER_PROMPT_PLAN.format(
task=task,
tool_descriptions=get_tool_descriptions(
self.tools, self.tool_description_template
),
managed_agents_descriptions=(
show_agents_descriptions(self.managed_agents)
),
tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template),
managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)),
answer_facts=answer_facts,
),
}
@ -702,9 +658,7 @@ Now begin!""",
```
{answer_facts}
```""".strip()
self.logs.append(
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
)
self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
self.logger.log(
Rule("[bold]Initial plan", style="orange"),
Text(final_plan_redaction),
@ -724,9 +678,7 @@ Now begin!""",
"role": MessageRole.USER,
"content": USER_PROMPT_FACTS_UPDATE,
}
facts_update = self.model(
[facts_update_system_prompt] + agent_memory + [facts_update_message]
).content
facts_update = self.model([facts_update_system_prompt] + agent_memory + [facts_update_message]).content
# Redact updated plan
plan_update_message = {
@ -737,12 +689,8 @@ Now begin!""",
"role": MessageRole.USER,
"content": USER_PROMPT_PLAN_UPDATE.format(
task=task,
tool_descriptions=get_tool_descriptions(
self.tools, self.tool_description_template
),
managed_agents_descriptions=(
show_agents_descriptions(self.managed_agents)
),
tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template),
managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)),
facts_update=facts_update,
remaining_steps=(self.max_steps - step),
),
@ -753,16 +701,12 @@ Now begin!""",
).content
# Log final facts and plan
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(
task=task, plan_update=plan_update
)
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
final_facts_redaction = f"""Here is the updated list of the facts that I know:
```
{facts_update}
```"""
self.logs.append(
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
)
self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
self.logger.log(
Rule("[bold]Updated plan", style="orange"),
Text(final_plan_redaction),
@ -816,19 +760,13 @@ class ToolCallingAgent(MultiStepAgent):
tool_arguments = tool_call.function.arguments
except Exception as e:
raise AgentGenerationError(
f"Error in generating tool call with model:\n{e}"
)
raise AgentGenerationError(f"Error in generating tool call with model:\n{e}")
log_entry.tool_calls = [
ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)
]
log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
# Execute
self.logger.log(
Panel(
Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")
),
Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")),
level=LogLevel.INFO,
)
if tool_name == "final_answer":
@ -900,16 +838,10 @@ class CodeAgent(MultiStepAgent):
if system_prompt is None:
system_prompt = CODE_SYSTEM_PROMPT
self.additional_authorized_imports = (
additional_authorized_imports if additional_authorized_imports else []
)
self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
)
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
if "{{authorized_imports}}" not in system_prompt:
raise AgentError(
"Tag '{{authorized_imports}}' should be provided in the prompt."
)
raise AgentError("Tag '{{authorized_imports}}' should be provided in the prompt.")
super().__init__(
tools=tools,
model=model,
@ -966,9 +898,7 @@ class CodeAgent(MultiStepAgent):
log_entry.agent_memory = agent_memory.copy()
try:
additional_args = (
{"grammar": self.grammar} if self.grammar is not None else {}
)
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
llm_output = self.model(
self.input_messages,
stop_sequences=["<end_code>", "Observation:"],
@ -999,9 +929,7 @@ class CodeAgent(MultiStepAgent):
try:
code_action = fix_final_answer_code(parse_code_blobs(llm_output))
except Exception as e:
error_msg = (
f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
)
error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
raise AgentParsingError(error_msg)
log_entry.tool_calls = [
@ -1088,17 +1016,13 @@ class ManagedAgent:
self.description = description
self.additional_prompting = additional_prompting
self.provide_run_summary = provide_run_summary
self.managed_agent_prompt = (
managed_agent_prompt if managed_agent_prompt else MANAGED_AGENT_PROMPT
)
self.managed_agent_prompt = managed_agent_prompt if managed_agent_prompt else MANAGED_AGENT_PROMPT
def write_full_task(self, task):
"""Adds additional prompting for the managed agent, like 'add more detail in your answer'."""
full_task = self.managed_agent_prompt.format(name=self.name, task=task)
if self.additional_prompting:
full_task = full_task.replace(
"\n{{additional_prompting}}", self.additional_prompting
).strip()
full_task = full_task.replace("\n{{additional_prompting}}", self.additional_prompting).strip()
else:
full_task = full_task.replace("\n{{additional_prompting}}", "").strip()
return full_task
@ -1107,9 +1031,7 @@ class ManagedAgent:
full_task = self.write_full_task(request)
output = self.agent.run(full_task, **kwargs)
if self.provide_run_summary:
answer = (
f"Here is the final answer from your managed agent '{self.name}':\n"
)
answer = f"Here is the final answer from your managed agent '{self.name}':\n"
answer += str(output)
answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
for message in self.agent.write_inner_memory_from_logs(summary_mode=True):

View File

@ -20,8 +20,6 @@ from dataclasses import dataclass
from typing import Dict, Optional
from huggingface_hub import hf_hub_download, list_spaces
from transformers.utils import is_offline_mode, is_torch_available
from .local_python_executor import (
@ -32,6 +30,7 @@ from .local_python_executor import (
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
from .types import AgentAudio
if is_torch_available():
from transformers.models.whisper import (
WhisperForConditionalGeneration,
@ -61,9 +60,7 @@ def get_remote_tools(logger, organization="huggingface-tools"):
tools = {}
for space_info in spaces:
repo_id = space_info.id
resolved_config_file = hf_hub_download(
repo_id, TOOL_CONFIG_FILE, repo_type="space"
)
resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
with open(resolved_config_file, encoding="utf-8") as reader:
config = json.load(reader)
task = repo_id.split("/")[-1]
@ -94,9 +91,7 @@ class PythonInterpreterTool(Tool):
if authorized_imports is None:
self.authorized_imports = list(set(BASE_BUILTIN_MODULES))
else:
self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(authorized_imports)
)
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(authorized_imports))
self.inputs = {
"code": {
"type": "string",
@ -126,9 +121,7 @@ class PythonInterpreterTool(Tool):
class FinalAnswerTool(Tool):
name = "final_answer"
description = "Provides a final answer to the given problem."
inputs = {
"answer": {"type": "any", "description": "The final answer to the problem"}
}
inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
output_type = "any"
def forward(self, answer):
@ -138,9 +131,7 @@ class FinalAnswerTool(Tool):
class UserInputTool(Tool):
name = "user_input"
description = "Asks for user's input on a specific question"
inputs = {
"question": {"type": "string", "description": "The question to ask the user"}
}
inputs = {"question": {"type": "string", "description": "The question to ask the user"}}
output_type = "string"
def forward(self, question):
@ -151,9 +142,7 @@ class UserInputTool(Tool):
class DuckDuckGoSearchTool(Tool):
name = "web_search"
description = """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."""
inputs = {
"query": {"type": "string", "description": "The search query to perform."}
}
inputs = {"query": {"type": "string", "description": "The search query to perform."}}
output_type = "string"
def __init__(self, *args, max_results=10, **kwargs):
@ -169,10 +158,7 @@ class DuckDuckGoSearchTool(Tool):
def forward(self, query: str) -> str:
results = self.ddgs.text(query, max_results=self.max_results)
postprocessed_results = [
f"[{result['title']}]({result['href']})\n{result['body']}"
for result in results
]
postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
@ -199,9 +185,7 @@ class GoogleSearchTool(Tool):
import requests
if self.serpapi_key is None:
raise ValueError(
"Missing SerpAPI key. Make sure you have 'SERPAPI_API_KEY' in your env variables."
)
raise ValueError("Missing SerpAPI key. Make sure you have 'SERPAPI_API_KEY' in your env variables.")
params = {
"engine": "google",
@ -210,9 +194,7 @@ class GoogleSearchTool(Tool):
"google_domain": "google.com",
}
if filter_year is not None:
params["tbs"] = (
f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
)
params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
response = requests.get("https://serpapi.com/search.json", params=params)
@ -227,13 +209,9 @@ class GoogleSearchTool(Tool):
f"'organic_results' key not found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
)
else:
raise Exception(
f"'organic_results' key not found for query: '{query}'. Use a less restrictive query."
)
raise Exception(f"'organic_results' key not found for query: '{query}'. Use a less restrictive query.")
if len(results["organic_results"]) == 0:
year_filter_message = (
f" with filter year={filter_year}" if filter_year is not None else ""
)
year_filter_message = f" with filter year={filter_year}" if filter_year is not None else ""
return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."
web_snippets = []
@ -253,9 +231,7 @@ class GoogleSearchTool(Tool):
redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"
redacted_version = redacted_version.replace(
"Your browser can't play this video.", ""
)
redacted_version = redacted_version.replace("Your browser can't play this video.", "")
web_snippets.append(redacted_version)
return "## Search Results\n" + "\n\n".join(web_snippets)
@ -263,7 +239,9 @@ class GoogleSearchTool(Tool):
class VisitWebpageTool(Tool):
name = "visit_webpage"
description = "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
description = (
"Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
)
inputs = {
"url": {
"type": "string",
@ -277,6 +255,7 @@ class VisitWebpageTool(Tool):
import requests
from markdownify import markdownify
from requests.exceptions import RequestException
from smolagents.utils import truncate_content
except ImportError:
raise ImportError(

View File

@ -28,6 +28,7 @@ from .tool_validation import validate_tool_attributes
from .tools import Tool
from .utils import BASE_BUILTIN_MODULES, instance_to_source
load_dotenv()
@ -45,9 +46,7 @@ class E2BExecutor:
self.logger = logger
additional_imports = additional_imports + ["pickle5", "smolagents"]
if len(additional_imports) > 0:
execution = self.sbx.commands.run(
"pip install " + " ".join(additional_imports)
)
execution = self.sbx.commands.run("pip install " + " ".join(additional_imports))
if execution.error:
raise Exception(f"Error installing dependencies: {execution.error}")
else:
@ -61,9 +60,7 @@ class E2BExecutor:
tool_code += f"\n{tool.name} = {tool.__class__.__name__}()\n"
tool_codes.append(tool_code)
tool_definition_code = "\n".join(
[f"import {module}" for module in BASE_BUILTIN_MODULES]
)
tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES])
tool_definition_code += textwrap.dedent("""
class Tool:
def __call__(self, *args, **kwargs):
@ -122,9 +119,7 @@ locals().update({key: value for key, value in pickle_dict.items()})
for attribute_name in ["jpeg", "png"]:
if getattr(result, attribute_name) is not None:
image_output = getattr(result, attribute_name)
decoded_bytes = base64.b64decode(
image_output.encode("utf-8")
)
decoded_bytes = base64.b64decode(image_output.encode("utf-8"))
return Image.open(BytesIO(decoded_bytes)), execution_logs
for attribute_name in [
"chart",

View File

@ -13,14 +13,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gradio as gr
import shutil
import os
import mimetypes
import os
import re
import shutil
from typing import Optional
import gradio as gr
from .agents import ActionStep, AgentStepLog, MultiStepAgent
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
@ -59,9 +59,7 @@ def stream_to_gradio(
):
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
for step_log in agent.run(
task, stream=True, reset=reset_agent_memory, additional_args=additional_args
):
for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
for message in pull_messages_from_step(step_log, test_mode=test_mode):
yield message
@ -147,14 +145,10 @@ class GradioUI:
sanitized_name = "".join(sanitized_name)
# Save the uploaded file to the specified folder
file_path = os.path.join(
self.file_upload_folder, os.path.basename(sanitized_name)
)
file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
shutil.copy(file.name, file_path)
return gr.Textbox(
f"File uploaded: {file_path}", visible=True
), file_uploads_log + [file_path]
return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path]
def log_user_message(self, text_input, file_uploads_log):
return (
@ -183,9 +177,7 @@ class GradioUI:
# If an upload folder is provided, enable the upload feature
if self.file_upload_folder is not None:
upload_file = gr.File(label="Upload a file")
upload_status = gr.Textbox(
label="Upload Status", interactive=False, visible=False
)
upload_status = gr.Textbox(label="Upload Status", interactive=False, visible=False)
upload_file.change(
self.upload_file,
[upload_file, file_uploads_log],

View File

@ -42,8 +42,7 @@ class InterpreterError(ValueError):
ERRORS = {
name: getattr(builtins, name)
for name in dir(builtins)
if isinstance(getattr(builtins, name), type)
and issubclass(getattr(builtins, name), BaseException)
if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)
}
PRINT_OUTPUTS, DEFAULT_MAX_LEN_OUTPUT = "", 50000
@ -167,9 +166,7 @@ def evaluate_unaryop(
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> Any:
operand = evaluate_ast(
expression.operand, state, static_tools, custom_tools, authorized_imports
)
operand = evaluate_ast(expression.operand, state, static_tools, custom_tools, authorized_imports)
if isinstance(expression.op, ast.USub):
return -operand
elif isinstance(expression.op, ast.UAdd):
@ -179,9 +176,7 @@ def evaluate_unaryop(
elif isinstance(expression.op, ast.Invert):
return ~operand
else:
raise InterpreterError(
f"Unary operation {expression.op.__class__.__name__} is not supported."
)
raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
def evaluate_lambda(
@ -217,23 +212,17 @@ def evaluate_while(
) -> None:
max_iterations = 1000
iterations = 0
while evaluate_ast(
while_loop.test, state, static_tools, custom_tools, authorized_imports
):
while evaluate_ast(while_loop.test, state, static_tools, custom_tools, authorized_imports):
for node in while_loop.body:
try:
evaluate_ast(
node, state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
except BreakException:
return None
except ContinueException:
break
iterations += 1
if iterations > max_iterations:
raise InterpreterError(
f"Maximum number of {max_iterations} iterations in While loop exceeded"
)
raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
return None
@ -248,8 +237,7 @@ def create_function(
func_state = state.copy()
arg_names = [arg.arg for arg in func_def.args.args]
default_values = [
evaluate_ast(d, state, static_tools, custom_tools, authorized_imports)
for d in func_def.args.defaults
evaluate_ast(d, state, static_tools, custom_tools, authorized_imports) for d in func_def.args.defaults
]
# Apply default values
@ -286,9 +274,7 @@ def create_function(
result = None
try:
for stmt in func_def.body:
result = evaluate_ast(
stmt, func_state, static_tools, custom_tools, authorized_imports
)
result = evaluate_ast(stmt, func_state, static_tools, custom_tools, authorized_imports)
except ReturnException as e:
result = e.value
@ -307,9 +293,7 @@ def evaluate_function_def(
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> Callable:
custom_tools[func_def.name] = create_function(
func_def, state, static_tools, custom_tools, authorized_imports
)
custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools, authorized_imports)
return custom_tools[func_def.name]
@ -321,17 +305,12 @@ def evaluate_class_def(
authorized_imports: List[str],
) -> type:
class_name = class_def.name
bases = [
evaluate_ast(base, state, static_tools, custom_tools, authorized_imports)
for base in class_def.bases
]
bases = [evaluate_ast(base, state, static_tools, custom_tools, authorized_imports) for base in class_def.bases]
class_dict = {}
for stmt in class_def.body:
if isinstance(stmt, ast.FunctionDef):
class_dict[stmt.name] = evaluate_function_def(
stmt, state, static_tools, custom_tools, authorized_imports
)
class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools, authorized_imports)
elif isinstance(stmt, ast.Assign):
for target in stmt.targets:
if isinstance(target, ast.Name):
@ -351,9 +330,7 @@ def evaluate_class_def(
authorized_imports,
)
else:
raise InterpreterError(
f"Unsupported statement in class body: {stmt.__class__.__name__}"
)
raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
new_class = type(class_name, tuple(bases), class_dict)
state[class_name] = new_class
@ -371,38 +348,26 @@ def evaluate_augassign(
if isinstance(target, ast.Name):
return state.get(target.id, 0)
elif isinstance(target, ast.Subscript):
obj = evaluate_ast(
target.value, state, static_tools, custom_tools, authorized_imports
)
key = evaluate_ast(
target.slice, state, static_tools, custom_tools, authorized_imports
)
obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
return obj[key]
elif isinstance(target, ast.Attribute):
obj = evaluate_ast(
target.value, state, static_tools, custom_tools, authorized_imports
)
obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
return getattr(obj, target.attr)
elif isinstance(target, ast.Tuple):
return tuple(get_current_value(elt) for elt in target.elts)
elif isinstance(target, ast.List):
return [get_current_value(elt) for elt in target.elts]
else:
raise InterpreterError(
"AugAssign not supported for {type(target)} targets."
)
# The f-prefix is required: without it the "{type(target)}" placeholder is emitted verbatim
# and the error never names the unsupported target node type.
raise InterpreterError(f"AugAssign not supported for {type(target)} targets.")
current_value = get_current_value(expression.target)
value_to_add = evaluate_ast(
expression.value, state, static_tools, custom_tools, authorized_imports
)
value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
if isinstance(expression.op, ast.Add):
if isinstance(current_value, list):
if not isinstance(value_to_add, list):
raise InterpreterError(
f"Cannot add non-list value {value_to_add} to a list."
)
raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
updated_value = current_value + value_to_add
else:
updated_value = current_value + value_to_add
@ -429,9 +394,7 @@ def evaluate_augassign(
elif isinstance(expression.op, ast.RShift):
updated_value = current_value >> value_to_add
else:
raise InterpreterError(
f"Operation {type(expression.op).__name__} is not supported."
)
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
# Update the state
set_value(
@ -455,16 +418,12 @@ def evaluate_boolop(
) -> bool:
if isinstance(node.op, ast.And):
for value in node.values:
if not evaluate_ast(
value, state, static_tools, custom_tools, authorized_imports
):
if not evaluate_ast(value, state, static_tools, custom_tools, authorized_imports):
return False
return True
elif isinstance(node.op, ast.Or):
for value in node.values:
if evaluate_ast(
value, state, static_tools, custom_tools, authorized_imports
):
if evaluate_ast(value, state, static_tools, custom_tools, authorized_imports):
return True
return False
@ -477,12 +436,8 @@ def evaluate_binop(
authorized_imports: List[str],
) -> Any:
# Recursively evaluate the left and right operands
left_val = evaluate_ast(
binop.left, state, static_tools, custom_tools, authorized_imports
)
right_val = evaluate_ast(
binop.right, state, static_tools, custom_tools, authorized_imports
)
left_val = evaluate_ast(binop.left, state, static_tools, custom_tools, authorized_imports)
right_val = evaluate_ast(binop.right, state, static_tools, custom_tools, authorized_imports)
# Determine the operation based on the type of the operator in the BinOp
if isinstance(binop.op, ast.Add):
@ -510,9 +465,7 @@ def evaluate_binop(
elif isinstance(binop.op, ast.RShift):
return left_val >> right_val
else:
raise NotImplementedError(
f"Binary operation {type(binop.op).__name__} is not implemented."
)
raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
def evaluate_assign(
@ -522,17 +475,13 @@ def evaluate_assign(
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> Any:
result = evaluate_ast(
assign.value, state, static_tools, custom_tools, authorized_imports
)
result = evaluate_ast(assign.value, state, static_tools, custom_tools, authorized_imports)
if len(assign.targets) == 1:
target = assign.targets[0]
set_value(target, result, state, static_tools, custom_tools, authorized_imports)
else:
if len(assign.targets) != len(result):
raise InterpreterError(
f"Assign failed: expected {len(result)} values but got {len(assign.targets)}."
)
raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
expanded_values = []
for tgt in assign.targets:
if isinstance(tgt, ast.Starred):
@ -554,9 +503,7 @@ def set_value(
) -> None:
if isinstance(target, ast.Name):
if target.id in static_tools:
raise InterpreterError(
f"Cannot assign to name '{target.id}': doing this would erase the existing tool!"
)
raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
state[target.id] = value
elif isinstance(target, ast.Tuple):
if not isinstance(value, tuple):
@ -567,21 +514,13 @@ def set_value(
if len(target.elts) != len(value):
raise InterpreterError("Cannot unpack tuple of wrong size")
for i, elem in enumerate(target.elts):
set_value(
elem, value[i], state, static_tools, custom_tools, authorized_imports
)
set_value(elem, value[i], state, static_tools, custom_tools, authorized_imports)
elif isinstance(target, ast.Subscript):
obj = evaluate_ast(
target.value, state, static_tools, custom_tools, authorized_imports
)
key = evaluate_ast(
target.slice, state, static_tools, custom_tools, authorized_imports
)
obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
obj[key] = value
elif isinstance(target, ast.Attribute):
obj = evaluate_ast(
target.value, state, static_tools, custom_tools, authorized_imports
)
obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
setattr(obj, target.attr, value)
@ -593,15 +532,11 @@ def evaluate_call(
authorized_imports: List[str],
) -> Any:
if not (
isinstance(call.func, ast.Attribute)
or isinstance(call.func, ast.Name)
or isinstance(call.func, ast.Subscript)
isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name) or isinstance(call.func, ast.Subscript)
):
raise InterpreterError(f"This is not a correct function: {call.func}).")
if isinstance(call.func, ast.Attribute):
obj = evaluate_ast(
call.func.value, state, static_tools, custom_tools, authorized_imports
)
obj = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports)
func_name = call.func.attr
if not hasattr(obj, func_name):
raise InterpreterError(f"Object {obj} has no attribute {func_name}")
@ -623,18 +558,12 @@ def evaluate_call(
)
elif isinstance(call.func, ast.Subscript):
value = evaluate_ast(
call.func.value, state, static_tools, custom_tools, authorized_imports
)
index = evaluate_ast(
call.func.slice, state, static_tools, custom_tools, authorized_imports
)
value = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports)
index = evaluate_ast(call.func.slice, state, static_tools, custom_tools, authorized_imports)
if isinstance(value, (list, tuple)):
func = value[index]
else:
raise InterpreterError(
f"Cannot subscript object of type {type(value).__name__}"
)
raise InterpreterError(f"Cannot subscript object of type {type(value).__name__}")
if not callable(func):
raise InterpreterError(f"This is not a correct function: {call.func}).")
@ -642,20 +571,12 @@ def evaluate_call(
args = []
for arg in call.args:
if isinstance(arg, ast.Starred):
args.extend(
evaluate_ast(
arg.value, state, static_tools, custom_tools, authorized_imports
)
)
args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools, authorized_imports))
else:
args.append(
evaluate_ast(arg, state, static_tools, custom_tools, authorized_imports)
)
args.append(evaluate_ast(arg, state, static_tools, custom_tools, authorized_imports))
kwargs = {
keyword.arg: evaluate_ast(
keyword.value, state, static_tools, custom_tools, authorized_imports
)
keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools, authorized_imports)
for keyword in call.keywords
}
@ -693,17 +614,11 @@ def evaluate_subscript(
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> Any:
index = evaluate_ast(
subscript.slice, state, static_tools, custom_tools, authorized_imports
)
value = evaluate_ast(
subscript.value, state, static_tools, custom_tools, authorized_imports
)
index = evaluate_ast(subscript.slice, state, static_tools, custom_tools, authorized_imports)
value = evaluate_ast(subscript.value, state, static_tools, custom_tools, authorized_imports)
if isinstance(value, str) and isinstance(index, str):
raise InterpreterError(
"You're trying to subscript a string with a string index, which is impossible"
)
raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
if isinstance(value, pd.core.indexing._LocIndexer):
parent_object = value.obj
return parent_object.loc[index]
@ -718,15 +633,11 @@ def evaluate_subscript(
return value[index]
elif isinstance(value, (list, tuple)):
if not (-len(value) <= index < len(value)):
raise InterpreterError(
f"Index {index} out of bounds for list of length {len(value)}"
)
raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
return value[int(index)]
elif isinstance(value, str):
if not (-len(value) <= index < len(value)):
raise InterpreterError(
f"Index {index} out of bounds for string of length {len(value)}"
)
raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
return value[index]
elif index in value:
return value[index]
@ -765,12 +676,9 @@ def evaluate_condition(
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> bool:
left = evaluate_ast(
condition.left, state, static_tools, custom_tools, authorized_imports
)
left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports)
comparators = [
evaluate_ast(c, state, static_tools, custom_tools, authorized_imports)
for c in condition.comparators
evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) for c in condition.comparators
]
ops = [type(op) for op in condition.ops]
@ -818,21 +726,15 @@ def evaluate_if(
authorized_imports: List[str],
) -> Any:
result = None
test_result = evaluate_ast(
if_statement.test, state, static_tools, custom_tools, authorized_imports
)
test_result = evaluate_ast(if_statement.test, state, static_tools, custom_tools, authorized_imports)
if test_result:
for line in if_statement.body:
line_result = evaluate_ast(
line, state, static_tools, custom_tools, authorized_imports
)
line_result = evaluate_ast(line, state, static_tools, custom_tools, authorized_imports)
if line_result is not None:
result = line_result
else:
for line in if_statement.orelse:
line_result = evaluate_ast(
line, state, static_tools, custom_tools, authorized_imports
)
line_result = evaluate_ast(line, state, static_tools, custom_tools, authorized_imports)
if line_result is not None:
result = line_result
return result
@ -846,9 +748,7 @@ def evaluate_for(
authorized_imports: List[str],
) -> Any:
result = None
iterator = evaluate_ast(
for_loop.iter, state, static_tools, custom_tools, authorized_imports
)
iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools, authorized_imports)
for counter in iterator:
set_value(
for_loop.target,
@ -860,9 +760,7 @@ def evaluate_for(
)
for node in for_loop.body:
try:
line_result = evaluate_ast(
node, state, static_tools, custom_tools, authorized_imports
)
line_result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
if line_result is not None:
result = line_result
except BreakException:
@ -882,9 +780,7 @@ def evaluate_listcomp(
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> List[Any]:
def inner_evaluate(
generators: List[ast.comprehension], index: int, current_state: Dict[str, Any]
) -> List[Any]:
def inner_evaluate(generators: List[ast.comprehension], index: int, current_state: Dict[str, Any]) -> List[Any]:
if index >= len(generators):
return [
evaluate_ast(
@ -912,9 +808,7 @@ def evaluate_listcomp(
else:
new_state[generator.target.id] = value
if all(
evaluate_ast(
if_clause, new_state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
for if_clause in generator.ifs
):
result.extend(inner_evaluate(generators, index + 1, new_state))
@ -938,32 +832,24 @@ def evaluate_try(
for handler in try_node.handlers:
if handler.type is None or isinstance(
e,
evaluate_ast(
handler.type, state, static_tools, custom_tools, authorized_imports
),
evaluate_ast(handler.type, state, static_tools, custom_tools, authorized_imports),
):
matched = True
if handler.name:
state[handler.name] = e
for stmt in handler.body:
evaluate_ast(
stmt, state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
break
if not matched:
raise e
else:
if try_node.orelse:
for stmt in try_node.orelse:
evaluate_ast(
stmt, state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
finally:
if try_node.finalbody:
for stmt in try_node.finalbody:
evaluate_ast(
stmt, state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
def evaluate_raise(
@ -974,15 +860,11 @@ def evaluate_raise(
authorized_imports: List[str],
) -> None:
if raise_node.exc is not None:
exc = evaluate_ast(
raise_node.exc, state, static_tools, custom_tools, authorized_imports
)
exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools, authorized_imports)
else:
exc = None
if raise_node.cause is not None:
cause = evaluate_ast(
raise_node.cause, state, static_tools, custom_tools, authorized_imports
)
cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools, authorized_imports)
else:
cause = None
if exc is not None:
@ -1001,14 +883,10 @@ def evaluate_assert(
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> None:
test_result = evaluate_ast(
assert_node.test, state, static_tools, custom_tools, authorized_imports
)
test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools, authorized_imports)
if not test_result:
if assert_node.msg:
msg = evaluate_ast(
assert_node.msg, state, static_tools, custom_tools, authorized_imports
)
msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools, authorized_imports)
raise AssertionError(msg)
else:
# Include the failing condition in the assertion message
@ -1025,9 +903,7 @@ def evaluate_with(
) -> None:
contexts = []
for item in with_node.items:
context_expr = evaluate_ast(
item.context_expr, state, static_tools, custom_tools, authorized_imports
)
context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools, authorized_imports)
if item.optional_vars:
state[item.optional_vars.id] = context_expr.__enter__()
contexts.append(state[item.optional_vars.id])
@ -1069,19 +945,14 @@ def get_safe_module(unsafe_module, dangerous_patterns, visited=None):
# Copy all attributes by reference, recursively checking modules
for attr_name in dir(unsafe_module):
# Skip dangerous patterns at any level
if any(
pattern in f"{unsafe_module.__name__}.{attr_name}"
for pattern in dangerous_patterns
):
if any(pattern in f"{unsafe_module.__name__}.{attr_name}" for pattern in dangerous_patterns):
continue
attr_value = getattr(unsafe_module, attr_name)
# Recursively process nested modules, passing visited set
if isinstance(attr_value, ModuleType):
attr_value = get_safe_module(
attr_value, dangerous_patterns, visited=visited
)
attr_value = get_safe_module(attr_value, dangerous_patterns, visited=visited)
setattr(safe_module, attr_name, attr_value)
@ -1116,18 +987,14 @@ def import_modules(expression, state, authorized_imports):
module_path = module_name.split(".")
if any([module in dangerous_patterns for module in module_path]):
return False
module_subpaths = [
".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
]
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
return any(subpath in authorized_imports for subpath in module_subpaths)
if isinstance(expression, ast.Import):
for alias in expression.names:
if check_module_authorized(alias.name):
raw_module = import_module(alias.name)
state[alias.asname or alias.name] = get_safe_module(
raw_module, dangerous_patterns
)
state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns)
else:
raise InterpreterError(
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
@ -1135,9 +1002,7 @@ def import_modules(expression, state, authorized_imports):
return None
elif isinstance(expression, ast.ImportFrom):
if check_module_authorized(expression.module):
raw_module = __import__(
expression.module, fromlist=[alias.name for alias in expression.names]
)
raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
for alias in expression.names:
state[alias.asname or alias.name] = get_safe_module(
getattr(raw_module, alias.name), dangerous_patterns
@ -1156,9 +1021,7 @@ def evaluate_dictcomp(
) -> Dict[Any, Any]:
result = {}
for gen in dictcomp.generators:
iter_value = evaluate_ast(
gen.iter, state, static_tools, custom_tools, authorized_imports
)
iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools, authorized_imports)
for value in iter_value:
new_state = state.copy()
set_value(
@ -1170,9 +1033,7 @@ def evaluate_dictcomp(
authorized_imports,
)
if all(
evaluate_ast(
if_clause, new_state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
for if_clause in gen.ifs
):
key = evaluate_ast(
@ -1229,202 +1090,116 @@ def evaluate_ast(
if isinstance(expression, ast.Assign):
# Assignment -> we evaluate the assignment which should update the state
# We return the variable assigned as it may be used to determine the final result.
return evaluate_assign(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_assign(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.AugAssign):
return evaluate_augassign(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_augassign(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Call):
# Function call -> we return the value of the function call
return evaluate_call(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_call(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Constant):
# Constant -> just return the value
return expression.value
elif isinstance(expression, ast.Tuple):
return tuple(
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports)
for elt in expression.elts
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts
)
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
return evaluate_listcomp(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_listcomp(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.UnaryOp):
return evaluate_unaryop(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_unaryop(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Starred):
return evaluate_ast(
expression.value, state, static_tools, custom_tools, authorized_imports
)
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.BoolOp):
# Boolean operation -> evaluate the operation
return evaluate_boolop(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_boolop(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Break):
raise BreakException()
elif isinstance(expression, ast.Continue):
raise ContinueException()
elif isinstance(expression, ast.BinOp):
# Binary operation -> execute operation
return evaluate_binop(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_binop(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Compare):
# Comparison -> evaluate the comparison
return evaluate_condition(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_condition(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Lambda):
return evaluate_lambda(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_lambda(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.FunctionDef):
return evaluate_function_def(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_function_def(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Dict):
# Dict -> evaluate all keys and values
keys = [
evaluate_ast(k, state, static_tools, custom_tools, authorized_imports)
for k in expression.keys
]
values = [
evaluate_ast(v, state, static_tools, custom_tools, authorized_imports)
for v in expression.values
]
keys = [evaluate_ast(k, state, static_tools, custom_tools, authorized_imports) for k in expression.keys]
values = [evaluate_ast(v, state, static_tools, custom_tools, authorized_imports) for v in expression.values]
return dict(zip(keys, values))
elif isinstance(expression, ast.Expr):
# Expression -> evaluate the content
return evaluate_ast(
expression.value, state, static_tools, custom_tools, authorized_imports
)
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.For):
# For loop -> execute the loop
return evaluate_for(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_for(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.FormattedValue):
# Formatted value (part of f-string) -> evaluate the content and return
return evaluate_ast(
expression.value, state, static_tools, custom_tools, authorized_imports
)
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.If):
# If -> execute the right branch
return evaluate_if(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_if(expression, state, static_tools, custom_tools, authorized_imports)
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
return evaluate_ast(
expression.value, state, static_tools, custom_tools, authorized_imports
)
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.JoinedStr):
return "".join(
[
str(
evaluate_ast(
v, state, static_tools, custom_tools, authorized_imports
)
)
for v in expression.values
]
[str(evaluate_ast(v, state, static_tools, custom_tools, authorized_imports)) for v in expression.values]
)
elif isinstance(expression, ast.List):
# List -> evaluate all elements
return [
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports)
for elt in expression.elts
]
return [evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts]
elif isinstance(expression, ast.Name):
# Name -> pick up the value in the state
return evaluate_name(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_name(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Subscript):
# Subscript -> return the value of the indexing
return evaluate_subscript(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_subscript(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.IfExp):
test_val = evaluate_ast(
expression.test, state, static_tools, custom_tools, authorized_imports
)
test_val = evaluate_ast(expression.test, state, static_tools, custom_tools, authorized_imports)
if test_val:
return evaluate_ast(
expression.body, state, static_tools, custom_tools, authorized_imports
)
return evaluate_ast(expression.body, state, static_tools, custom_tools, authorized_imports)
else:
return evaluate_ast(
expression.orelse, state, static_tools, custom_tools, authorized_imports
)
return evaluate_ast(expression.orelse, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Attribute):
value = evaluate_ast(
expression.value, state, static_tools, custom_tools, authorized_imports
)
value = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
return getattr(value, expression.attr)
elif isinstance(expression, ast.Slice):
return slice(
evaluate_ast(
expression.lower, state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(expression.lower, state, static_tools, custom_tools, authorized_imports)
if expression.lower is not None
else None,
evaluate_ast(
expression.upper, state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(expression.upper, state, static_tools, custom_tools, authorized_imports)
if expression.upper is not None
else None,
evaluate_ast(
expression.step, state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(expression.step, state, static_tools, custom_tools, authorized_imports)
if expression.step is not None
else None,
)
elif isinstance(expression, ast.DictComp):
return evaluate_dictcomp(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_dictcomp(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.While):
return evaluate_while(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_while(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, (ast.Import, ast.ImportFrom)):
return import_modules(expression, state, authorized_imports)
elif isinstance(expression, ast.ClassDef):
return evaluate_class_def(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_class_def(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Try):
return evaluate_try(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_try(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Raise):
return evaluate_raise(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_raise(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Assert):
return evaluate_assert(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_assert(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.With):
return evaluate_with(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_with(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Set):
return {
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports)
for elt in expression.elts
}
return {evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts}
elif isinstance(expression, ast.Return):
raise ReturnException(
evaluate_ast(
expression.value, state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
if expression.value
else None
)
@ -1488,18 +1263,12 @@ def evaluate_python_code(
try:
for node in expression.body:
result = evaluate_ast(
node, state, static_tools, custom_tools, authorized_imports
)
state["print_outputs"] = truncate_content(
PRINT_OUTPUTS, max_length=max_print_outputs_length
)
result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
is_final_answer = False
return result, is_final_answer
except FinalAnswerException as e:
state["print_outputs"] = truncate_content(
PRINT_OUTPUTS, max_length=max_print_outputs_length
)
state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
is_final_answer = True
return e.value, is_final_answer
except InterpreterError as e:
@ -1521,9 +1290,7 @@ class LocalPythonInterpreter:
if max_print_outputs_length is None:
self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT
self.additional_authorized_imports = additional_authorized_imports
self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
)
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
# Add base trusted tools to list
self.static_tools = {
**tools,
@ -1531,9 +1298,7 @@ class LocalPythonInterpreter:
}
# TODO: assert that all modules in self.authorized_imports are installed locally
def __call__(
self, code_action: str, additional_variables: Dict
) -> Tuple[Any, str, bool]:
def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str, bool]:
self.state.update(additional_variables)
output, is_final_answer = evaluate_python_code(
code_action,

View File

@ -14,17 +14,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, asdict
import json
import logging
import os
import random
from copy import deepcopy
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Dict, List, Optional, Union, Any
from typing import Any, Dict, List, Optional, Union
from huggingface_hub import InferenceClient
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
@ -35,6 +34,7 @@ from transformers import (
from .tools import Tool
logger = logging.getLogger(__name__)
DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
@ -100,10 +100,7 @@ class ChatMessage:
# Build a `cls` instance from `message`, copying its `role` and `content`.
# `tool_calls` stays None when `message` has no `tool_calls` attribute (or it is
# None); otherwise every entry is converted with ChatMessageToolCall.from_hf_api.
def from_hf_api(cls, message) -> "ChatMessage":
tool_calls = None
if getattr(message, "tool_calls", None) is not None:
# NOTE(review): the wrapped and the single-line comprehension below are the same
# statement in two line-wrappings (this diff view lists both forms).
tool_calls = [
ChatMessageToolCall.from_hf_api(tool_call)
for tool_call in message.tool_calls
]
tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls]
return cls(role=message.role, content=message.content, tool_calls=tool_calls)
@ -172,17 +169,12 @@ def get_clean_message_list(
role = message["role"]
if role not in MessageRole.roles():
raise ValueError(
f"Incorrect role {role}, only {MessageRole.roles()} are supported for now."
)
raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
if role in role_conversions:
message["role"] = role_conversions[role]
if (
len(final_message_list) > 0
and message["role"] == final_message_list[-1]["role"]
):
if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
final_message_list[-1]["content"] += "\n=======\n" + message["content"]
else:
final_message_list.append(message)
@ -292,9 +284,7 @@ class HfApiModel(Model):
Gets an LLM output message for the given list of input messages.
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
"""
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
if tools_to_call_from:
response = self.client.chat.completions.create(
messages=messages,
@ -367,9 +357,7 @@ class TransformersModel(Model):
default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
if model_id is None:
model_id = default_model_id
logger.warning(
f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'"
)
logger.warning(f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'")
self.model_id = model_id
self.kwargs = kwargs
if device_map is None:
@ -389,9 +377,7 @@ class TransformersModel(Model):
)
self.model_id = default_model_id
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
self.model = AutoModelForCausalLM.from_pretrained(
model_id, device_map=device_map, torch_dtype=torch_dtype
)
self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=torch_dtype)
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
class StopOnStrings(StoppingCriteria):
@ -404,16 +390,9 @@ class TransformersModel(Model):
self.stream = ""
# Stopping check: decode `input_ids[0][-1]` with special tokens skipped, append
# the text to `self.stream`, and return True as soon as the accumulated stream
# ends with one of `self.stop_strings`; otherwise return False.
# `scores` and `**kwargs` are accepted but not used in this body.
def __call__(self, input_ids, scores, **kwargs):
# NOTE(review): each wrapped statement below is immediately followed by its
# single-line form; both are the same statement (this diff view lists both).
generated = self.tokenizer.decode(
input_ids[0][-1], skip_special_tokens=True
)
generated = self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
self.stream += generated
if any(
[
self.stream.endswith(stop_string)
for stop_string in self.stop_strings
]
):
if any([self.stream.endswith(stop_string) for stop_string in self.stop_strings]):
return True
return False
@ -426,9 +405,7 @@ class TransformersModel(Model):
grammar: Optional[str] = None,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
if tools_to_call_from is not None:
prompt_tensor = self.tokenizer.apply_chat_template(
messages,
@ -448,9 +425,7 @@ class TransformersModel(Model):
out = self.model.generate(
**prompt_tensor,
stopping_criteria=(
self.make_stopping_criteria(stop_sequences) if stop_sequences else None
),
stopping_criteria=(self.make_stopping_criteria(stop_sequences) if stop_sequences else None),
**self.kwargs,
)
generated_tokens = out[0, count_prompt_tokens:]
@ -475,9 +450,7 @@ class TransformersModel(Model):
ChatMessageToolCall(
id="".join(random.choices("0123456789", k=5)),
type="function",
function=ChatMessageToolCallDefinition(
name=tool_name, arguments=tool_arguments
),
function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
)
],
)
@ -525,9 +498,7 @@ class LiteLLMModel(Model):
grammar: Optional[str] = None,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
import litellm
if tools_to_call_from:
@ -604,11 +575,7 @@ class OpenAIServerModel(Model):
) -> ChatMessage:
messages = get_clean_message_list(
messages,
role_conversions=(
self.custom_role_conversions
if self.custom_role_conversions
else tool_role_conversions
),
role_conversions=(self.custom_role_conversions if self.custom_role_conversions else tool_role_conversions),
)
if tools_to_call_from:
response = self.client.chat.completions.create(

View File

@ -22,10 +22,7 @@ class Monitor:
self.step_durations = []
self.tracked_model = tracked_model
self.logger = logger
if (
getattr(self.tracked_model, "last_input_token_count", "Not found")
!= "Not found"
):
if getattr(self.tracked_model, "last_input_token_count", "Not found") != "Not found":
self.total_input_token_count = 0
self.total_output_token_count = 0
@ -48,7 +45,9 @@ class Monitor:
if getattr(self.tracked_model, "last_input_token_count", None) is not None:
self.total_input_token_count += self.tracked_model.last_input_token_count
self.total_output_token_count += self.tracked_model.last_output_token_count
console_outputs += f"| Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}"
console_outputs += (
f"| Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}"
)
console_outputs += "]"
self.logger.log(Text(console_outputs, style="dim"), level=1)

View File

@ -6,6 +6,7 @@ from typing import Set
from .utils import BASE_BUILTIN_MODULES
_BUILTIN_NAMES = set(vars(builtins))
@ -141,9 +142,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
# Check that __init__ method takes no arguments
if not cls.__init__.__qualname__ == "Tool.__init__":
sig = inspect.signature(cls.__init__)
non_self_params = list(
[arg_name for arg_name in sig.parameters.keys() if arg_name != "self"]
)
non_self_params = list([arg_name for arg_name in sig.parameters.keys() if arg_name != "self"])
if len(non_self_params) > 0:
errors.append(
f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!"
@ -174,9 +173,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
# Check if the assignment is more complex than simple literals
if not all(
isinstance(
val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set)
)
isinstance(val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set))
for val in ast.walk(node.value)
):
for target in node.targets:
@ -195,9 +192,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
# Run checks on all methods
for node in class_node.body:
if isinstance(node, ast.FunctionDef):
method_checker = MethodChecker(
class_level_checker.class_attributes, check_imports=check_imports
)
method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports)
method_checker.visit(node)
errors += [f"- {node.name}: {error}" for error in method_checker.errors]

View File

@ -36,7 +36,6 @@ from huggingface_hub import (
upload_folder,
)
from huggingface_hub.utils import RepositoryNotFoundError
from packaging import version
from transformers.dynamic_module_utils import get_imports
from transformers.utils import (
@ -52,6 +51,7 @@ from .tool_validation import MethodChecker, validate_tool_attributes
from .types import ImageType, handle_agent_input_types, handle_agent_output_types
from .utils import instance_to_source
logger = logging.getLogger(__name__)
if is_accelerate_available():
@ -77,9 +77,7 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
return "model"
except RepositoryNotFoundError:
raise EnvironmentError(
f"`{repo_id}` does not seem to be a valid repo identifier on the Hub."
)
raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
except Exception:
return "model"
except Exception:
@ -109,9 +107,7 @@ def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
properties[param_name]["nullable"] = True
for param_name in signature.parameters.keys():
if signature.parameters[param_name].default != inspect.Parameter.empty:
if (
param_name not in properties
): # this can happen if the param has no type hint but a default value
if param_name not in properties: # this can happen if the param has no type hint but a default value
properties[param_name] = {"nullable": True}
return properties
@ -181,9 +177,7 @@ class Tool:
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
)
for input_name, input_content in self.inputs.items():
assert isinstance(input_content, dict), (
f"Input '{input_name}' should be a dictionary."
)
assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
assert "type" in input_content and "description" in input_content, (
f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
)
@ -348,15 +342,7 @@ class Tool:
imports = []
for module in [tool_file]:
imports.extend(get_imports(module))
imports = list(
set(
[
el
for el in imports + ["smolagents"]
if el not in sys.stdlib_module_names
]
)
)
imports = list(set([el for el in imports + ["smolagents"] if el not in sys.stdlib_module_names]))
with open(requirements_file, "w", encoding="utf-8") as f:
f.write("\n".join(imports) + "\n")
@ -410,9 +396,7 @@ class Tool:
print(work_dir)
with open(work_dir + "/tool.py", "r") as f:
print("\n".join(f.readlines()))
logger.info(
f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
)
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
return upload_folder(
repo_id=repo_id,
commit_message=commit_message,
@ -592,9 +576,7 @@ class Tool:
self.name = name
self.description = description
self.client = Client(space_id, hf_token=token)
space_description = self.client.view_api(
return_format="dict", print_info=False
)["named_endpoints"]
space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"]
# If api_name is not defined, take the first of the available APIs for this space
if api_name is None:
@ -607,9 +589,7 @@ class Tool:
try:
space_description_api = space_description[api_name]
except KeyError:
raise KeyError(
f"Could not find specified {api_name=} among available api names."
)
raise KeyError(f"Could not find specified {api_name=} among available api names.")
self.inputs = {}
for parameter in space_description_api["parameters"]:
@ -683,8 +663,7 @@ class Tool:
self._gradio_tool = _gradio_tool
func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
self.inputs = {
key: {"type": CONVERSION_DICT[value.annotation], "description": ""}
for key, value in func_args
key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
}
self.forward = self._gradio_tool.run
@ -726,9 +705,7 @@ DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
"""
def get_tool_description_with_args(
tool: Tool, description_template: Optional[str] = None
) -> str:
def get_tool_description_with_args(tool: Tool, description_template: Optional[str] = None) -> str:
if description_template is None:
description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
compiled_template = compile_jinja_template(description_template)
@ -748,10 +725,7 @@ def compile_jinja_template(template):
raise ImportError("template requires jinja2 to be installed.")
if version.parse(jinja2.__version__) < version.parse("3.1.0"):
raise ImportError(
"template requires jinja2>=3.1.0 to be installed. Your version is "
f"{jinja2.__version__}."
)
raise ImportError(f"template requires jinja2>=3.1.0 to be installed. Your version is {jinja2.__version__}.")
def raise_exception(message):
raise TemplateError(message)
@ -772,9 +746,7 @@ def launch_gradio_demo(tool: Tool):
try:
import gradio as gr
except ImportError:
raise ImportError(
"Gradio should be installed in order to launch a gradio demo."
)
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
TYPE_TO_COMPONENT_CLASS_MAPPING = {
"image": gr.Image,
@ -791,9 +763,7 @@ def launch_gradio_demo(tool: Tool):
gradio_inputs = []
for input_name, input_details in tool.inputs.items():
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
input_details["type"]
]
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]]
new_component = input_gradio_component_class(label=input_name)
gradio_inputs.append(new_component)
@ -922,14 +892,9 @@ class ToolCollection:
```
"""
_collection = get_collection(collection_slug, token=token)
_hub_repo_ids = {
item.item_id for item in _collection.items if item.item_type == "space"
}
_hub_repo_ids = {item.item_id for item in _collection.items if item.item_type == "space"}
tools = {
Tool.from_hub(repo_id, token, trust_remote_code)
for repo_id in _hub_repo_ids
}
tools = {Tool.from_hub(repo_id, token, trust_remote_code) for repo_id in _hub_repo_ids}
return cls(tools)
@ -986,9 +951,7 @@ def tool(tool_function: Callable) -> Tool:
"""
parameters = get_json_schema(tool_function)["function"]
if "return" not in parameters:
raise TypeHintParsingException(
"Tool return type not found: make sure your function has a return type hint!"
)
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
class SimpleTool(Tool):
def __init__(self, name, description, inputs, output_type, function):
@ -1007,9 +970,9 @@ def tool(tool_function: Callable) -> Tool:
function=tool_function,
)
original_signature = inspect.signature(tool_function)
new_parameters = [
inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)
] + list(original_signature.parameters.values())
new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)] + list(
original_signature.parameters.values()
)
new_signature = original_signature.replace(parameters=new_parameters)
simple_tool.forward.__signature__ = new_signature
return simple_tool
@ -1082,9 +1045,7 @@ class PipelineTool(Tool):
if model is None:
if self.default_checkpoint is None:
raise ValueError(
"This tool does not implement a default checkpoint, you need to pass one."
)
raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
model = self.default_checkpoint
if pre_processor is None:
pre_processor = model
@ -1107,21 +1068,15 @@ class PipelineTool(Tool):
Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
"""
if isinstance(self.pre_processor, str):
self.pre_processor = self.pre_processor_class.from_pretrained(
self.pre_processor, **self.hub_kwargs
)
self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
if isinstance(self.model, str):
self.model = self.model_class.from_pretrained(
self.model, **self.model_kwargs, **self.hub_kwargs
)
self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
if self.post_processor is None:
self.post_processor = self.pre_processor
elif isinstance(self.post_processor, str):
self.post_processor = self.post_processor_class.from_pretrained(
self.post_processor, **self.hub_kwargs
)
self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
if self.device is None:
if self.device_map is not None:
@ -1165,12 +1120,8 @@ class PipelineTool(Tool):
encoded_inputs = self.encode(*args, **kwargs)
tensor_inputs = {
k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)
}
non_tensor_inputs = {
k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)
}
tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)}
non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)}
encoded_inputs = send_to_device(tensor_inputs, self.device)
outputs = self.forward({**encoded_inputs, **non_tensor_inputs})

View File

@ -27,6 +27,7 @@ from transformers.utils import (
is_vision_available,
)
logger = logging.getLogger(__name__)
if is_vision_available():
@ -113,9 +114,7 @@ class AgentImage(AgentType, ImageType):
elif isinstance(value, np.ndarray):
self._tensor = torch.from_numpy(value)
else:
raise TypeError(
f"Unsupported type for {self.__class__.__name__}: {type(value)}"
)
raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
def _ipython_display_(self, include=None, exclude=None):
"""
@ -264,9 +263,7 @@ if is_torch_available():
# Convert AgentType arguments to their raw values via `to_raw()`.
# Positional arguments come back as a list and keyword arguments as a dict;
# values that are not AgentType instances pass through unchanged.
def handle_agent_input_types(*args, **kwargs):
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
# NOTE(review): the wrapped and the single-line dict comprehension below are the
# same statement in two line-wrappings (this diff view lists both forms).
kwargs = {
k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()
}
kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
return args, kwargs
@ -279,9 +276,7 @@ def handle_agent_output_types(output, output_type=None):
# If the class does not have defined output, then we map according to the type
for _k, _v in INSTANCE_TYPE_MAPPING.items():
if isinstance(output, _k):
if (
_k is not object
): # avoid converting to audio if torch is not installed
if _k is not object: # avoid converting to audio if torch is not installed
return _v(output)
return output

View File

@ -83,9 +83,7 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
try:
first_accolade_index = json_blob.find("{")
last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace(
'\\"', "'"
)
json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'")
json_data = json.loads(json_blob, strict=False)
return json_data
except json.JSONDecodeError as e:
@ -162,9 +160,7 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
MAX_LENGTH_TRUNCATE_CONTENT = 20000
def truncate_content(
content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT
) -> str:
def truncate_content(content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT) -> str:
if len(content) <= max_length:
return content
else:
@ -206,12 +202,8 @@ def is_same_method(method1, method2):
source2 = get_method_source(method2)
# Remove method decorators if any
source1 = "\n".join(
line for line in source1.split("\n") if not line.strip().startswith("@")
)
source2 = "\n".join(
line for line in source2.split("\n") if not line.strip().startswith("@")
)
source1 = "\n".join(line for line in source1.split("\n") if not line.strip().startswith("@"))
source2 = "\n".join(line for line in source2.split("\n") if not line.strip().startswith("@"))
return source1 == source2
except (TypeError, OSError):
@ -248,9 +240,7 @@ def instance_to_source(instance, base_cls=None):
for name, value in cls.__dict__.items()
if not name.startswith("__")
and not callable(value)
and not (
base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value
)
and not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value)
}
for name, value in class_attrs.items():
@ -271,9 +261,7 @@ def instance_to_source(instance, base_cls=None):
for name, func in cls.__dict__.items()
if callable(func)
and not (
base_cls
and hasattr(base_cls, name)
and getattr(base_cls, name).__code__.co_code == func.__code__.co_code
base_cls and hasattr(base_cls, name) and getattr(base_cls, name).__code__.co_code == func.__code__.co_code
)
}
@ -284,9 +272,7 @@ def instance_to_source(instance, base_cls=None):
first_line = method_lines[0]
indent = len(first_line) - len(first_line.lstrip())
method_lines = [line[indent:] for line in method_lines]
method_source = "\n".join(
[" " + line if line.strip() else line for line in method_lines]
)
method_source = "\n".join([" " + line if line.strip() else line for line in method_lines])
class_lines.append(method_source)
class_lines.append("")

View File

@ -28,13 +28,13 @@ from smolagents.agents import (
ToolCallingAgent,
)
from smolagents.default_tools import PythonInterpreterTool
from smolagents.tools import tool
from smolagents.types import AgentImage, AgentText
from smolagents.models import (
ChatMessage,
ChatMessageToolCall,
ChatMessageToolCallDefinition,
)
from smolagents.tools import tool
from smolagents.types import AgentImage, AgentText
from smolagents.utils import BASE_BUILTIN_MODULES
@ -44,9 +44,7 @@ def get_new_path(suffix="") -> str:
class FakeToolCallModel:
def __call__(
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
):
def __call__(self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None):
if len(messages) < 3:
return ChatMessage(
role="assistant",
@ -69,18 +67,14 @@ class FakeToolCallModel:
ChatMessageToolCall(
id="call_1",
type="function",
function=ChatMessageToolCallDefinition(
name="final_answer", arguments={"answer": "7.2904"}
),
function=ChatMessageToolCallDefinition(name="final_answer", arguments={"answer": "7.2904"}),
)
],
)
class FakeToolCallModelImage:
def __call__(
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
):
def __call__(self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None):
if len(messages) < 3:
return ChatMessage(
role="assistant",
@ -104,9 +98,7 @@ class FakeToolCallModelImage:
ChatMessageToolCall(
id="call_1",
type="function",
function=ChatMessageToolCallDefinition(
name="final_answer", arguments="image.png"
),
function=ChatMessageToolCallDefinition(name="final_answer", arguments="image.png"),
)
],
)
@ -271,17 +263,13 @@ print(result)
class AgentTests(unittest.TestCase):
# Runs a CodeAgent (PythonInterpreterTool + fake_code_model_single_step) with
# single_step=True and checks that the output is a str containing "7.2904".
def test_fake_single_step_code_agent(self):
# NOTE(review): the wrapped and the single-line CodeAgent(...) construction below
# are the same statement in two line-wrappings (this diff view lists both forms).
agent = CodeAgent(
tools=[PythonInterpreterTool()], model=fake_code_model_single_step
)
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_single_step)
output = agent.run("What is 2 multiplied by 3.6452?", single_step=True)
assert isinstance(output, str)
assert "7.2904" in output
# Runs a ToolCallingAgent (PythonInterpreterTool + FakeToolCallModel) on the
# multiplication task and checks that the output is a str containing "7.2904".
def test_fake_toolcalling_agent(self):
# NOTE(review): the wrapped and the single-line ToolCallingAgent(...) construction
# below are the same statement in two line-wrappings (this diff view lists both).
agent = ToolCallingAgent(
tools=[PythonInterpreterTool()], model=FakeToolCallModel()
)
agent = ToolCallingAgent(tools=[PythonInterpreterTool()], model=FakeToolCallModel())
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, str)
assert "7.2904" in output
@ -301,9 +289,7 @@ class AgentTests(unittest.TestCase):
"""
return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png")
agent = ToolCallingAgent(
tools=[fake_image_generation_tool], model=FakeToolCallModelImage()
)
agent = ToolCallingAgent(tools=[fake_image_generation_tool], model=FakeToolCallModelImage())
output = agent.run("Make me an image.")
assert isinstance(output, AgentImage)
assert isinstance(agent.state["image.png"], Image.Image)
@ -315,9 +301,7 @@ class AgentTests(unittest.TestCase):
assert output == 7.2904
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
assert agent.logs[3].tool_calls == [
ToolCall(
name="python_interpreter", arguments="final_answer(7.2904)", id="call_3"
)
ToolCall(name="python_interpreter", arguments="final_answer(7.2904)", id="call_3")
]
def test_additional_args_added_to_task(self):
@ -351,9 +335,7 @@ class AgentTests(unittest.TestCase):
assert "Code execution failed at line 'print = 2' because of" in str(agent.logs)
def test_code_agent_syntax_error_show_offending_lines(self):
agent = CodeAgent(
tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error
)
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText)
assert output == "got an error"
@ -391,9 +373,7 @@ class AgentTests(unittest.TestCase):
def test_init_agent_with_different_toolsets(self):
toolset_1 = []
agent = CodeAgent(tools=toolset_1, model=fake_code_model)
assert (
len(agent.tools) == 1
) # when no tools are provided, only the final_answer tool is added by default
assert len(agent.tools) == 1 # when no tools are provided, only the final_answer tool is added by default
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
agent = CodeAgent(tools=toolset_2, model=fake_code_model)
@ -436,9 +416,7 @@ class AgentTests(unittest.TestCase):
assert "You can also give requests to team members." not in agent.system_prompt
print("ok1")
assert "{{managed_agents_descriptions}}" not in agent.system_prompt
assert (
"You can also give requests to team members." in manager_agent.system_prompt
)
assert "You can also give requests to team members." in manager_agent.system_prompt
def test_code_agent_missing_import_triggers_advice_in_error_log(self):
agent = CodeAgent(tools=[], model=fake_code_model_import)

View File

@ -136,9 +136,7 @@ class TestDocs:
try:
code_blocks = [
(
block.replace(
"<YOUR_HUGGINGFACEHUB_API_TOKEN>", os.getenv("HF_TOKEN")
)
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", os.getenv("HF_TOKEN"))
.replace("YOUR_ANTHROPIC_API_KEY", os.getenv("ANTHROPIC_API_KEY"))
.replace("{your_username}", "m-ric")
)
@ -150,9 +148,7 @@ class TestDocs:
except SubprocessCallException as e:
pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}")
except Exception:
pytest.fail(
f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}"
)
pytest.fail(f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}")
@pytest.fixture(autouse=True)
def _setup(self):
@ -174,6 +170,4 @@ def pytest_generate_tests(metafunc):
test_class.setup_class()
# Parameterize with the markdown files
metafunc.parametrize(
"doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files]
)
metafunc.parametrize("doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files])

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool
@ -23,14 +24,10 @@ from .test_tools import ToolTesterMixin
# Test case for the default tools; the only test visible here exercises VisitWebpageTool.
class DefaultToolTests(unittest.TestCase):
# Calls VisitWebpageTool with a Wikipedia URL and checks that the result is a
# str containing the "About Wikipedia" markdown link.
# NOTE(review): presumably fetches the live page, so this needs network access — confirm.
def test_visit_webpage(self):
# NOTE(review): each wrapped statement below is immediately followed by its
# single-line form; both are the same statement (this diff view lists both).
arguments = {
"url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"
}
arguments = {"url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"}
result = VisitWebpageTool()(arguments)
assert isinstance(result, str)
assert (
"* [About Wikipedia](/wiki/Wikipedia:About)" in result
) # Proper wikipedia pages have an About
assert "* [About Wikipedia](/wiki/Wikipedia:About)" in result # Proper wikipedia pages have an About
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
@ -59,12 +56,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
input_type = expected_input["type"]
if isinstance(input_type, list):
_inputs.append(
[
AGENT_TYPE_MAPPING[_input_type](_input)
for _input_type in input_type
]
)
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
else:
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))

View File

@ -26,6 +26,7 @@ from smolagents.types import AGENT_TYPE_MAPPING
from .test_tools import ToolTesterMixin
if is_torch_available():
import torch
@ -45,11 +46,7 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
# Build one sample `answer` payload per kind and return them keyed by
# "string", "image" and "audio": a text string, the fixture PNG resized to
# 512x512, and a tensor built from 3000 ones.
def create_inputs(self):
inputs_text = {"answer": "Text input"}
# NOTE(review): the wrapped and the single-line `inputs_image` assignment below
# are the same statement in two line-wrappings (this diff view lists both forms).
inputs_image = {
"answer": Image.open(
Path(get_tests_dir("fixtures")) / "000000039769.png"
).resize((512, 512))
}
inputs_image = {"answer": Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png").resize((512, 512))}
inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio}

View File

@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import json
import unittest
from typing import Optional
from smolagents import models, tool, ChatMessage, HfApiModel, TransformersModel
from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
class ModelTests(unittest.TestCase):
@ -33,12 +33,7 @@ class ModelTests(unittest.TestCase):
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
assert (
"nullable"
in models.get_json_schema(get_weather)["function"]["parameters"][
"properties"
]["celsius"]
)
assert "nullable" in models.get_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"]
def test_chatmessage_has_model_dumps_json(self):
message = ChatMessage("user", "Hello!")

View File

@ -43,9 +43,7 @@ class FakeLLMModel:
ChatMessageToolCall(
id="fake_id",
type="function",
function=ChatMessageToolCallDefinition(
name="final_answer", arguments={"answer": "image"}
),
function=ChatMessageToolCallDefinition(name="final_answer", arguments={"answer": "image"}),
)
],
)
@ -122,9 +120,7 @@ class MonitoringTester(unittest.TestCase):
)
agent.run("Fake task")
self.assertEqual(
agent.monitor.total_input_token_count, 20
) # Should have done two monitoring callbacks
self.assertEqual(agent.monitor.total_input_token_count, 20) # Should have done two monitoring callbacks
self.assertEqual(agent.monitor.total_output_token_count, 0)
def test_streaming_agent_text_output(self):

View File

@ -55,10 +55,7 @@ class PythonInterpreterTester(unittest.TestCase):
code = "print = '3'"
with pytest.raises(InterpreterError) as e:
evaluate_python_code(code, {"print": print}, state={})
assert (
"Cannot assign to name 'print': doing this would erase the existing tool!"
in str(e)
)
assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e)
def test_subscript_call(self):
code = """def foo(x,y):return x*y\n\ndef boo(y):\n\treturn y**3\nfun = [foo, boo]\nresult_foo = fun[0](4,2)\nresult_boo = fun[1](4)"""
@ -92,9 +89,7 @@ class PythonInterpreterTester(unittest.TestCase):
state = {"x": 3}
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
self.assertDictEqual(result, {"x": 3, "y": 5})
self.assertDictEqual(
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
)
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
def test_evaluate_expression(self):
code = "x = 3\ny = 5"
@ -110,9 +105,7 @@ class PythonInterpreterTester(unittest.TestCase):
result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == "This is x: 3."
self.assertDictEqual(
state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""}
)
self.assertDictEqual(state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""})
def test_evaluate_if(self):
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
@ -153,15 +146,11 @@ class PythonInterpreterTester(unittest.TestCase):
state = {"x": 3}
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqual(
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
)
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
state = {}
evaluate_python_code(
code, {"min": min, "print": print, "round": round}, state=state
)
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
def test_subscript_string_with_string_index_raises_appropriate_error(self):
@ -317,9 +306,7 @@ print(check_digits)
assert result == {0: 0, 1: 1, 2: 4}
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
result, _ = evaluate_python_code(
code, {"print": print}, state={}, authorized_imports=["pandas"]
)
result, _ = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
assert result == {102: "b"}
code = """
@ -367,9 +354,7 @@ else:
best_city = "Manhattan"
best_city
"""
result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
)
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
assert result == "Brooklyn"
code = """if d > e and a < b:
@ -380,9 +365,7 @@ else:
best_city = "Manhattan"
best_city
"""
result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
)
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
assert result == "Sacramento"
def test_if_conditions(self):
@ -398,9 +381,7 @@ if char.isalpha():
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 2.0
code = (
"from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
)
code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "lose"
@ -434,14 +415,10 @@ if char.isalpha():
# Test submodules are handled properly, thus not raising error
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
)
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
)
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
def test_additional_imports(self):
code = "import numpy as np"
@ -613,9 +590,7 @@ except ValueError as e:
def test_types_as_objects(self):
code = "type_a = float(2); type_b = str; type_c = int"
state = {}
result, is_final_answer = evaluate_python_code(
code, {"float": float, "str": str, "int": int}, state=state
)
result, is_final_answer = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
assert result is int
def test_tuple_id(self):
@ -733,9 +708,7 @@ while True:
break
i"""
result, is_final_answer = evaluate_python_code(
code, {"print": print, "round": round}, state={}
)
result, is_final_answer = evaluate_python_code(code, {"print": print, "round": round}, state={})
assert result == 3
assert not is_final_answer
@ -781,9 +754,7 @@ out = [i for sublist in all_res for i in sublist]
out[:10]
"""
state = {}
result, is_final_answer = evaluate_python_code(
code, {"print": print, "range": range}, state=state
)
result, is_final_answer = evaluate_python_code(code, {"print": print, "range": range}, state=state)
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
def test_pandas(self):
@ -798,9 +769,7 @@ parts_with_5_set_count = df[df['SetCount'] == 5.0]
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
"""
state = {}
result, _ = evaluate_python_code(
code, {}, state=state, authorized_imports=["pandas"]
)
result, _ = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"])
assert np.array_equal(result, [-1, 5])
code = """
@ -811,9 +780,7 @@ df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
# Filter the DataFrame to get only the rows with outdated atomic numbers
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
"""
result, _ = evaluate_python_code(
code, {"print": print}, state={}, authorized_imports=["pandas"]
)
result, _ = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
assert np.array_equal(result.values[0], [104, 1])
# Test groupby
@ -825,9 +792,7 @@ data = pd.DataFrame.from_dict([
])
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
"""
result, _ = evaluate_python_code(
code, {}, state={}, authorized_imports=["pandas"]
)
result, _ = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
assert result.values[1] == 0.5
# Test loc and iloc
@ -839,11 +804,9 @@ data = pd.DataFrame.from_dict([
])
survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
survival_rate_sorted = data.sort_values(by='Survived', ascending=False).iloc[0]
survival_rate_sorted = data.sort_values(by='Survived', ascending=False).iloc[0]
"""
result, _ = evaluate_python_code(
code, {}, state={}, authorized_imports=["pandas"]
)
result, _ = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
def test_starred(self):
code = """
@ -864,9 +827,7 @@ coords_barcelona = (41.3869, 2.1660)
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
"""
result, _ = evaluate_python_code(
code, {"print": print, "map": map}, state={}, authorized_imports=["math"]
)
result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
assert round(result, 1) == 622395.4
def test_for(self):

View File

@ -16,7 +16,7 @@ import unittest
from pathlib import Path
from textwrap import dedent
from typing import Dict, Optional, Union
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch
import mcp
import numpy as np
@ -32,6 +32,7 @@ from smolagents.types import (
AgentText,
)
if is_torch_available():
import torch
@ -48,9 +49,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
if input_type == "string":
inputs[input_name] = "Text input"
elif input_type == "image":
inputs[input_name] = Image.open(
Path(get_tests_dir("fixtures")) / "000000039769.png"
).resize((512, 512))
inputs[input_name] = Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png").resize((512, 512))
elif input_type == "audio":
inputs[input_name] = np.ones(3000)
else:
@ -224,9 +223,7 @@ class ToolTests(unittest.TestCase):
class FailTool(Tool):
name = "specific"
description = "test description"
inputs = {
"string_input": {"type": "string", "description": "input description"}
}
inputs = {"string_input": {"type": "string", "description": "input description"}}
output_type = "string"
def __init__(self, url):
@ -248,9 +245,7 @@ class ToolTests(unittest.TestCase):
class FailTool(Tool):
name = "specific"
description = "test description"
inputs = {
"string_input": {"type": "string", "description": "input description"}
}
inputs = {"string_input": {"type": "string", "description": "input description"}}
output_type = "string"
def useless_method(self):
@ -269,9 +264,7 @@ class ToolTests(unittest.TestCase):
class SuccessTool(Tool):
name = "specific"
description = "test description"
inputs = {
"string_input": {"type": "string", "description": "input description"}
}
inputs = {"string_input": {"type": "string", "description": "input description"}}
output_type = "string"
def useless_method(self):
@ -300,9 +293,7 @@ class ToolTests(unittest.TestCase):
},
}
def forward(
self, location: str, celsius: Optional[bool] = False
) -> str:
def forward(self, location: str, celsius: Optional[bool] = False) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
GetWeatherTool()
@ -340,9 +331,7 @@ class ToolTests(unittest.TestCase):
}
output_type = "string"
def forward(
self, location: str, celsius: Optional[bool] = False
) -> str:
def forward(self, location: str, celsius: Optional[bool] = False) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
GetWeatherTool()
@ -410,9 +399,7 @@ def mock_smolagents_adapter():
class TestToolCollection:
def test_from_mcp(
self, mock_server_parameters, mock_mcp_adapt, mock_smolagents_adapter
):
def test_from_mcp(self, mock_server_parameters, mock_mcp_adapt, mock_smolagents_adapter):
with ToolCollection.from_mcp(mock_server_parameters) as tool_collection:
assert isinstance(tool_collection, ToolCollection)
assert len(tool_collection.tools) == 2
@ -440,9 +427,5 @@ class TestToolCollection:
with ToolCollection.from_mcp(mcp_server_params) as tool_collection:
assert len(tool_collection.tools) == 1, "Expected 1 tool"
assert tool_collection.tools[0].name == "echo_tool", (
"Expected tool name to be 'echo_tool'"
)
assert tool_collection.tools[0](text="Hello") == "Hello", (
"Expected tool to echo the input text"
)
assert tool_collection.tools[0].name == "echo_tool", "Expected tool name to be 'echo_tool'"
assert tool_collection.tools[0](text="Hello") == "Hello", "Expected tool to echo the input text"

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
from smolagents.utils import parse_code_blobs

View File

@ -16,6 +16,7 @@
from pathlib import Path
ROOT = Path(__file__).parent.parent
TESTS_FOLDER = ROOT / "tests"
@ -37,11 +38,7 @@ def check_tests_in_ci():
if path.name.startswith("test_")
]
ci_workflow_file_content = CI_WORKFLOW_FILE.read_text()
missing_test_files = [
test_file
for test_file in test_files
if test_file not in ci_workflow_file_content
]
missing_test_files = [test_file for test_file in test_files if test_file not in ci_workflow_file_content]
if missing_test_files:
print(
"❌ Some test files seem to be ignored in the CI:\n"