Add linter rules + apply make style (#255)
* Add linter rules + apply make style

parent 5aa0f2b53d
commit 6e1373a324
@ -181,6 +181,7 @@
"import datasets\n",
"import pandas as pd\n",
"\n",
"\n",
"eval_ds = datasets.load_dataset(\"m-ric/smol_agents_benchmark\")[\"test\"]\n",
"pd.DataFrame(eval_ds)"
]
@ -199,26 +200,28 @@
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import json\n",
"import os\n",
"import re\n",
"import string\n",
"import time\n",
"import warnings\n",
"from tqdm import tqdm\n",
"from typing import List\n",
"\n",
"from dotenv import load_dotenv\n",
"from tqdm import tqdm\n",
"\n",
"from smolagents import (\n",
" GoogleSearchTool,\n",
" CodeAgent,\n",
" ToolCallingAgent,\n",
" HfApiModel,\n",
" AgentError,\n",
" VisitWebpageTool,\n",
" CodeAgent,\n",
" GoogleSearchTool,\n",
" HfApiModel,\n",
" PythonInterpreterTool,\n",
" ToolCallingAgent,\n",
" VisitWebpageTool,\n",
")\n",
"from smolagents.agents import ActionStep\n",
"from dotenv import load_dotenv\n",
"\n",
"\n",
"load_dotenv()\n",
"os.makedirs(\"output\", exist_ok=True)\n",
@ -231,9 +234,7 @@
" return str(obj)\n",
"\n",
"\n",
"def answer_questions(\n",
" eval_ds, file_name, agent, model_id, action_type, is_vanilla_llm=False\n",
"):\n",
"def answer_questions(eval_ds, file_name, agent, model_id, action_type, is_vanilla_llm=False):\n",
" answered_questions = []\n",
" if os.path.exists(file_name):\n",
" with open(file_name, \"r\") as f:\n",
@ -365,23 +366,18 @@
" ma_elems = split_string(model_answer)\n",
"\n",
" if len(gt_elems) != len(ma_elems): # check length is the same\n",
" warnings.warn(\n",
" \"Answer lists have different lengths, returning False.\", UserWarning\n",
" )\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
" return False\n",
"\n",
" comparisons = []\n",
" for ma_elem, gt_elem in zip(\n",
" ma_elems, gt_elems\n",
" ): # compare each element as float or str\n",
" for ma_elem, gt_elem in zip(ma_elems, gt_elems): # compare each element as float or str\n",
" if is_float(gt_elem):\n",
" normalized_ma_elem = normalize_number_str(ma_elem)\n",
" comparisons.append(normalized_ma_elem == float(gt_elem))\n",
" else:\n",
" # we do not remove punct since comparisons can include punct\n",
" comparisons.append(\n",
" normalize_str(ma_elem, remove_punct=False)\n",
" == normalize_str(gt_elem, remove_punct=False)\n",
" normalize_str(ma_elem, remove_punct=False) == normalize_str(gt_elem, remove_punct=False)\n",
" )\n",
" return all(comparisons)\n",
"\n",
@ -441,9 +437,7 @@
" action_type = \"vanilla\"\n",
" llm = HfApiModel(model_id)\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" answer_questions(\n",
" eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True\n",
" )"
" answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)"
]
},
{
@ -461,6 +455,7 @@
"source": [
"from smolagents import LiteLLMModel\n",
"\n",
"\n",
"litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n",
"\n",
"for model_id in litellm_model_ids:\n",
@ -492,9 +487,7 @@
|
|||
" action_type = \"vanilla\"\n",
|
||||
" llm = LiteLLMModel(model_id)\n",
|
||||
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
||||
" answer_questions(\n",
|
||||
" eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True\n",
|
||||
" )"
|
||||
" answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -556,9 +549,11 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import glob\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"res = []\n",
|
||||
"for file_path in glob.glob(\"output/*.jsonl\"):\n",
|
||||
" data = []\n",
|
||||
|
@ -595,11 +590,7 @@
|
|||
"\n",
|
||||
"result_df[\"correct\"] = result_df.apply(get_correct, axis=1)\n",
|
||||
"\n",
|
||||
"result_df = (\n",
|
||||
" (result_df.groupby([\"model_id\", \"source\", \"action_type\"])[[\"correct\"]].mean() * 100)\n",
|
||||
" .round(1)\n",
|
||||
" .reset_index()\n",
|
||||
")"
|
||||
"result_df = (result_df.groupby([\"model_id\", \"source\", \"action_type\"])[[\"correct\"]].mean() * 100).round(1).reset_index()"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -895,6 +886,7 @@
|
|||
"import pandas as pd\n",
|
||||
"from matplotlib.legend_handler import HandlerTuple # Added import\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Assuming pivot_df is your original dataframe\n",
|
||||
"models = pivot_df[\"model_id\"].unique()\n",
|
||||
"sources = pivot_df[\"source\"].unique()\n",
|
||||
|
@ -961,14 +953,10 @@
|
|||
"handles, labels = ax.get_legend_handles_labels()\n",
|
||||
"unique_sources = sources\n",
|
||||
"legend_elements = [\n",
|
||||
" (handles[i * 2], handles[i * 2 + 1], labels[i * 2].replace(\" (Agent)\", \"\"))\n",
|
||||
" for i in range(len(unique_sources))\n",
|
||||
" (handles[i * 2], handles[i * 2 + 1], labels[i * 2].replace(\" (Agent)\", \"\")) for i in range(len(unique_sources))\n",
|
||||
"]\n",
|
||||
"custom_legend = ax.legend(\n",
|
||||
" [\n",
|
||||
" (agent_handle, vanilla_handle)\n",
|
||||
" for agent_handle, vanilla_handle, _ in legend_elements\n",
|
||||
" ],\n",
|
||||
" [(agent_handle, vanilla_handle) for agent_handle, vanilla_handle, _ in legend_elements],\n",
|
||||
" [label for _, _, label in legend_elements],\n",
|
||||
" handler_map={tuple: HandlerTuple(ndivide=None)},\n",
|
||||
" bbox_to_anchor=(1.05, 1),\n",
|
||||
|
@ -1006,9 +994,7 @@
|
|||
" # Start the matrix environment with 4 columns\n",
|
||||
" # l for left-aligned model and task, c for centered numbers\n",
|
||||
" mathjax_table = \"\\\\begin{array}{llcc}\\n\"\n",
|
||||
" mathjax_table += (\n",
|
||||
" \"\\\\text{Model} & \\\\text{Task} & \\\\text{Agent} & \\\\text{Vanilla} \\\\\\\\\\n\"\n",
|
||||
" )\n",
|
||||
" mathjax_table += \"\\\\text{Model} & \\\\text{Task} & \\\\text{Agent} & \\\\text{Vanilla} \\\\\\\\\\n\"\n",
|
||||
" mathjax_table += \"\\\\hline\\n\"\n",
|
||||
"\n",
|
||||
" # Sort the DataFrame by model_id and source\n",
|
||||
|
@ -1033,9 +1019,7 @@
|
|||
" model_display = \"\\\\;\"\n",
|
||||
"\n",
|
||||
" # Add the data row\n",
|
||||
" mathjax_table += (\n",
|
||||
" f\"{model_display} & {source} & {row['agent']} & {row['vanilla']} \\\\\\\\\\n\"\n",
|
||||
" )\n",
|
||||
" mathjax_table += f\"{model_display} & {source} & {row['agent']} & {row['vanilla']} \\\\\\\\\\n\"\n",
|
||||
"\n",
|
||||
" current_model = model\n",
|
||||
"\n",
|
||||
|
|
|
@ -1,7 +1,9 @@
from smolagents import Tool, CodeAgent, HfApiModel
from smolagents.default_tools import VisitWebpageTool
from dotenv import load_dotenv

from smolagents import CodeAgent, HfApiModel, Tool
from smolagents.default_tools import VisitWebpageTool


load_dotenv()
@ -16,10 +18,11 @@ class GetCatImageTool(Tool):
self.url = "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"

def forward(self):
from PIL import Image
import requests
from io import BytesIO

import requests
from PIL import Image

response = requests.get(self.url)

return Image.open(BytesIO(response.content))
@ -46,4 +49,5 @@ agent.run(
# Try the agent in a Gradio UI
from smolagents import GradioUI


GradioUI(agent).launch()
@ -1,4 +1,5 @@
from smolagents import CodeAgent, HfApiModel, GradioUI
from smolagents import CodeAgent, GradioUI, HfApiModel


agent = CodeAgent(tools=[], model=HfApiModel(), max_steps=4, verbosity_level=1)
@ -1,24 +1,22 @@
|
|||
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
|
||||
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
|
||||
|
||||
from smolagents import (
|
||||
CodeAgent,
|
||||
DuckDuckGoSearchTool,
|
||||
VisitWebpageTool,
|
||||
HfApiModel,
|
||||
ManagedAgent,
|
||||
ToolCallingAgent,
|
||||
HfApiModel,
|
||||
VisitWebpageTool,
|
||||
)
|
||||
|
||||
|
||||
# Let's setup the instrumentation first
|
||||
|
||||
trace_provider = TracerProvider()
|
||||
trace_provider.add_span_processor(
|
||||
SimpleSpanProcessor(OTLPSpanExporter("http://0.0.0.0:6006/v1/traces"))
|
||||
)
|
||||
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter("http://0.0.0.0:6006/v1/traces")))
|
||||
|
||||
SmolagentsInstrumentor().instrument(tracer_provider=trace_provider, skip_dep_check=True)
|
||||
|
||||
|
@ -39,6 +37,4 @@ manager_agent = CodeAgent(
|
|||
model=model,
|
||||
managed_agents=[managed_agent],
|
||||
)
|
||||
manager_agent.run(
|
||||
"If the US keeps it 2024 growth rate, how many years would it take for the GDP to double?"
|
||||
)
|
||||
manager_agent.run("If the US keeps it 2024 growth rate, how many years would it take for the GDP to double?")
|
||||
|
|
|
@ -8,13 +8,10 @@ from langchain_community.retrievers import BM25Retriever
|
|||
|
||||
|
||||
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
|
||||
knowledge_base = knowledge_base.filter(
|
||||
lambda row: row["source"].startswith("huggingface/transformers")
|
||||
)
|
||||
knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))
|
||||
|
||||
source_docs = [
|
||||
Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
|
||||
for doc in knowledge_base
|
||||
Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}) for doc in knowledge_base
|
||||
]
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
|
@ -51,14 +48,12 @@ class RetrieverTool(Tool):
|
|||
query,
|
||||
)
|
||||
return "\nRetrieved documents:\n" + "".join(
|
||||
[
|
||||
f"\n\n===== Document {str(i)} =====\n" + doc.page_content
|
||||
for i, doc in enumerate(docs)
|
||||
]
|
||||
[f"\n\n===== Document {str(i)} =====\n" + doc.page_content for i, doc in enumerate(docs)]
|
||||
)
|
||||
|
||||
|
||||
from smolagents import HfApiModel, CodeAgent
|
||||
from smolagents import CodeAgent, HfApiModel
|
||||
|
||||
|
||||
retriever_tool = RetrieverTool(docs_processed)
|
||||
agent = CodeAgent(
|
||||
|
@ -68,9 +63,7 @@ agent = CodeAgent(
|
|||
verbosity_level=2,
|
||||
)
|
||||
|
||||
agent_output = agent.run(
|
||||
"For a transformers model training, which is slower, the forward or the backward pass?"
|
||||
)
|
||||
agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")
|
||||
|
||||
print("Final output:")
|
||||
print(agent_output)
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
from sqlalchemy import (
|
||||
create_engine,
|
||||
MetaData,
|
||||
Table,
|
||||
Column,
|
||||
String,
|
||||
Integer,
|
||||
Float,
|
||||
Integer,
|
||||
MetaData,
|
||||
String,
|
||||
Table,
|
||||
create_engine,
|
||||
insert,
|
||||
inspect,
|
||||
text,
|
||||
)
|
||||
|
||||
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
metadata_obj = MetaData()
|
||||
|
||||
|
@ -40,9 +41,7 @@ for row in rows:
|
|||
inspector = inspect(engine)
|
||||
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]
|
||||
|
||||
table_description = "Columns:\n" + "\n".join(
|
||||
[f" - {name}: {col_type}" for name, col_type in columns_info]
|
||||
)
|
||||
table_description = "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
|
||||
print(table_description)
|
||||
|
||||
from smolagents import tool
|
||||
|
@ -72,6 +71,7 @@ def sql_engine(query: str) -> str:
|
|||
|
||||
from smolagents import CodeAgent, HfApiModel
|
||||
|
||||
|
||||
agent = CodeAgent(
|
||||
tools=[sql_engine],
|
||||
model=HfApiModel("meta-llama/Meta-Llama-3.1-8B-Instruct"),
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
from smolagents.agents import ToolCallingAgent
|
||||
from smolagents import tool, LiteLLMModel
|
||||
from typing import Optional
|
||||
|
||||
from smolagents import LiteLLMModel, tool
|
||||
from smolagents.agents import ToolCallingAgent
|
||||
|
||||
|
||||
# Choose which LLM engine to use!
|
||||
# model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
|
||||
# model = TransformersModel(model_id="meta-llama/Llama-3.2-2B-Instruct")
|
||||
|
|
|
@ -13,8 +13,10 @@ Usage:
|
|||
import os
|
||||
|
||||
from mcp import StdioServerParameters
|
||||
|
||||
from smolagents import CodeAgent, HfApiModel, ToolCollection
|
||||
|
||||
|
||||
mcp_server_params = StdioServerParameters(
|
||||
command="uvx",
|
||||
args=["--quiet", "pubmedmcp@0.1.3"],
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
from smolagents.agents import ToolCallingAgent
|
||||
from smolagents import tool, LiteLLMModel
|
||||
from typing import Optional
|
||||
|
||||
from smolagents import LiteLLMModel, tool
|
||||
from smolagents.agents import ToolCallingAgent
|
||||
|
||||
|
||||
model = LiteLLMModel(
|
||||
model_id="ollama_chat/llama3.2",
|
||||
api_base="http://localhost:11434", # replace with remote open-ai compatible server if necessary
|
||||
|
|
|
@ -60,9 +60,18 @@ dev = [
addopts = "-sv --durations=0"

[tool.ruff]
lint.ignore = ["F403"]
line-length = 119
lint.ignore = [
"F403", # undefined-local-with-import-star
"E501", # line-too-long
]
lint.select = ["E", "F", "I", "W"]

[tool.ruff.lint.per-file-ignores]
"examples/*" = [
"E402", # module-import-not-at-top-of-file
]

[tool.ruff.lint.isort]
known-first-party = ["smolagents"]
lines-after-imports = 2
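A minimal sketch of the import layout these settings enforce, for reference only (the statements are lifted from the diffs above and assume smolagents and python-dotenv are installed): with "I" selected and known-first-party = ["smolagents"], imports are grouped as standard library, then third-party, then first-party, each group alphabetized, with two blank lines after the import block; E501 is ignored so lines may run to the 119-character formatter limit, and E402 is only relaxed under examples/*.

# Standard-library imports come first
import os

# Third-party imports next
from dotenv import load_dotenv

# First-party imports last ("smolagents" is declared known-first-party)
from smolagents import CodeAgent, HfApiModel


# Two blank lines separate imports from code (lines-after-imports = 2)
load_dotenv()
os.makedirs("output", exist_ok=True)
agent = CodeAgent(tools=[], model=HfApiModel(), max_steps=4, verbosity_level=1)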
@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
|
|||
from transformers.utils import _LazyModule
|
||||
from transformers.utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .agents import *
|
||||
from .default_tools import *
|
||||
|
|
|
@ -16,18 +16,17 @@
# limitations under the License.
import time
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from enum import IntEnum
from rich import box
from rich.console import Group
from rich.console import Console, Group
from rich.panel import Panel
from rich.rule import Rule
from rich.syntax import Syntax
from rich.text import Text
from rich.console import Console

from .default_tools import FinalAnswerTool, TOOL_MAPPING
from .default_tools import TOOL_MAPPING, FinalAnswerTool
from .e2b_executor import E2BExecutor
from .local_python_executor import (
BASE_BUILTIN_MODULES,
@ -112,20 +111,11 @@ class SystemPromptStep(AgentStepLog):
|
|||
system_prompt: str
|
||||
|
||||
|
||||
def get_tool_descriptions(
|
||||
tools: Dict[str, Tool], tool_description_template: str
|
||||
) -> str:
|
||||
return "\n".join(
|
||||
[
|
||||
get_tool_description_with_args(tool, tool_description_template)
|
||||
for tool in tools.values()
|
||||
]
|
||||
)
|
||||
def get_tool_descriptions(tools: Dict[str, Tool], tool_description_template: str) -> str:
|
||||
return "\n".join([get_tool_description_with_args(tool, tool_description_template) for tool in tools.values()])
|
||||
|
||||
|
||||
def format_prompt_with_tools(
|
||||
tools: Dict[str, Tool], prompt_template: str, tool_description_template: str
|
||||
) -> str:
|
||||
def format_prompt_with_tools(tools: Dict[str, Tool], prompt_template: str, tool_description_template: str) -> str:
|
||||
tool_descriptions = get_tool_descriptions(tools, tool_description_template)
|
||||
prompt = prompt_template.replace("{{tool_descriptions}}", tool_descriptions)
|
||||
if "{{tool_names}}" in prompt:
|
||||
|
@ -159,9 +149,7 @@ def format_prompt_with_managed_agents_descriptions(
|
|||
f"Provided prompt template does not contain the managed agents descriptions placeholder '{agent_descriptions_placeholder}'"
|
||||
)
|
||||
if len(managed_agents.keys()) > 0:
|
||||
return prompt_template.replace(
|
||||
agent_descriptions_placeholder, show_agents_descriptions(managed_agents)
|
||||
)
|
||||
return prompt_template.replace(agent_descriptions_placeholder, show_agents_descriptions(managed_agents))
|
||||
else:
|
||||
return prompt_template.replace(agent_descriptions_placeholder, "")
|
||||
|
||||
|
@ -214,9 +202,7 @@ class MultiStepAgent:
|
|||
self.model = model
|
||||
self.system_prompt_template = system_prompt
|
||||
self.tool_description_template = (
|
||||
tool_description_template
|
||||
if tool_description_template
|
||||
else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
||||
tool_description_template if tool_description_template else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
||||
)
|
||||
self.max_steps = max_steps
|
||||
self.tool_parser = tool_parser
|
||||
|
@ -231,10 +217,7 @@ class MultiStepAgent:
|
|||
self.tools = {tool.name: tool for tool in tools}
|
||||
if add_base_tools:
|
||||
for tool_name, tool_class in TOOL_MAPPING.items():
|
||||
if (
|
||||
tool_name != "python_interpreter"
|
||||
or self.__class__.__name__ == "ToolCallingAgent"
|
||||
):
|
||||
if tool_name != "python_interpreter" or self.__class__.__name__ == "ToolCallingAgent":
|
||||
self.tools[tool_name] = tool_class()
|
||||
self.tools["final_answer"] = FinalAnswerTool()
|
||||
|
||||
|
@ -253,15 +236,11 @@ class MultiStepAgent:
|
|||
self.system_prompt_template,
|
||||
self.tool_description_template,
|
||||
)
|
||||
self.system_prompt = format_prompt_with_managed_agents_descriptions(
|
||||
self.system_prompt, self.managed_agents
|
||||
)
|
||||
self.system_prompt = format_prompt_with_managed_agents_descriptions(self.system_prompt, self.managed_agents)
|
||||
|
||||
return self.system_prompt
|
||||
|
||||
def write_inner_memory_from_logs(
|
||||
self, summary_mode: Optional[bool] = False
|
||||
) -> List[Dict[str, str]]:
|
||||
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
|
||||
that can be used as input to the LLM.
|
||||
|
@ -355,10 +334,7 @@ class MultiStepAgent:
|
|||
return memory
|
||||
|
||||
def get_succinct_logs(self):
|
||||
return [
|
||||
{key: value for key, value in log.items() if key != "agent_memory"}
|
||||
for log in self.logs
|
||||
]
|
||||
return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs]
|
||||
|
||||
def extract_action(self, llm_output: str, split_token: str) -> Tuple[str, str]:
|
||||
"""
|
||||
|
@ -402,9 +378,7 @@ class MultiStepAgent:
|
|||
except Exception as e:
|
||||
return f"Error in generating final LLM output:\n{e}"
|
||||
|
||||
def execute_tool_call(
|
||||
self, tool_name: str, arguments: Union[Dict[str, str], str]
|
||||
) -> Any:
|
||||
def execute_tool_call(self, tool_name: str, arguments: Union[Dict[str, str], str]) -> Any:
|
||||
"""
|
||||
Execute tool with the provided input and returns the result.
|
||||
This method replaces arguments with the actual values from the state if they refer to state variables.
|
||||
|
@ -423,9 +397,7 @@ class MultiStepAgent:
|
|||
if tool_name in self.managed_agents:
|
||||
observation = available_tools[tool_name].__call__(arguments)
|
||||
else:
|
||||
observation = available_tools[tool_name].__call__(
|
||||
arguments, sanitize_inputs_outputs=True
|
||||
)
|
||||
observation = available_tools[tool_name].__call__(arguments, sanitize_inputs_outputs=True)
|
||||
elif isinstance(arguments, dict):
|
||||
for key, value in arguments.items():
|
||||
if isinstance(value, str) and value in self.state:
|
||||
|
@ -433,18 +405,14 @@ class MultiStepAgent:
|
|||
if tool_name in self.managed_agents:
|
||||
observation = available_tools[tool_name].__call__(**arguments)
|
||||
else:
|
||||
observation = available_tools[tool_name].__call__(
|
||||
**arguments, sanitize_inputs_outputs=True
|
||||
)
|
||||
observation = available_tools[tool_name].__call__(**arguments, sanitize_inputs_outputs=True)
|
||||
else:
|
||||
error_msg = f"Arguments passed to tool should be a dict or string: got a {type(arguments)}."
|
||||
raise AgentExecutionError(error_msg)
|
||||
return observation
|
||||
except Exception as e:
|
||||
if tool_name in self.tools:
|
||||
tool_description = get_tool_description_with_args(
|
||||
available_tools[tool_name]
|
||||
)
|
||||
tool_description = get_tool_description_with_args(available_tools[tool_name])
|
||||
error_msg = (
|
||||
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
|
||||
f"As a reminder, this tool's description is the following:\n{tool_description}"
|
||||
|
@ -544,10 +512,7 @@ You have been provided with these additional arguments, that you can access usin
|
|||
step_start_time = time.time()
|
||||
step_log = ActionStep(step=self.step_number, start_time=step_start_time)
|
||||
try:
|
||||
if (
|
||||
self.planning_interval is not None
|
||||
and self.step_number % self.planning_interval == 0
|
||||
):
|
||||
if self.planning_interval is not None and self.step_number % self.planning_interval == 0:
|
||||
self.planning_step(
|
||||
task,
|
||||
is_first_step=(self.step_number == 0),
|
||||
|
@ -600,10 +565,7 @@ You have been provided with these additional arguments, that you can access usin
|
|||
step_start_time = time.time()
|
||||
step_log = ActionStep(step=self.step_number, start_time=step_start_time)
|
||||
try:
|
||||
if (
|
||||
self.planning_interval is not None
|
||||
and self.step_number % self.planning_interval == 0
|
||||
):
|
||||
if self.planning_interval is not None and self.step_number % self.planning_interval == 0:
|
||||
self.planning_step(
|
||||
task,
|
||||
is_first_step=(self.step_number == 0),
|
||||
|
@ -668,9 +630,7 @@ You have been provided with these additional arguments, that you can access usin
|
|||
Now begin!""",
|
||||
}
|
||||
|
||||
answer_facts = self.model(
|
||||
[message_prompt_facts, message_prompt_task]
|
||||
).content
|
||||
answer_facts = self.model([message_prompt_facts, message_prompt_task]).content
|
||||
|
||||
message_system_prompt_plan = {
|
||||
"role": MessageRole.SYSTEM,
|
||||
|
@ -680,12 +640,8 @@ Now begin!""",
|
|||
"role": MessageRole.USER,
|
||||
"content": USER_PROMPT_PLAN.format(
|
||||
task=task,
|
||||
tool_descriptions=get_tool_descriptions(
|
||||
self.tools, self.tool_description_template
|
||||
),
|
||||
managed_agents_descriptions=(
|
||||
show_agents_descriptions(self.managed_agents)
|
||||
),
|
||||
tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template),
|
||||
managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)),
|
||||
answer_facts=answer_facts,
|
||||
),
|
||||
}
|
||||
|
@ -702,9 +658,7 @@ Now begin!""",
|
|||
```
|
||||
{answer_facts}
|
||||
```""".strip()
|
||||
self.logs.append(
|
||||
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
|
||||
)
|
||||
self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
|
||||
self.logger.log(
|
||||
Rule("[bold]Initial plan", style="orange"),
|
||||
Text(final_plan_redaction),
|
||||
|
@ -724,9 +678,7 @@ Now begin!""",
|
|||
"role": MessageRole.USER,
|
||||
"content": USER_PROMPT_FACTS_UPDATE,
|
||||
}
|
||||
facts_update = self.model(
|
||||
[facts_update_system_prompt] + agent_memory + [facts_update_message]
|
||||
).content
|
||||
facts_update = self.model([facts_update_system_prompt] + agent_memory + [facts_update_message]).content
|
||||
|
||||
# Redact updated plan
|
||||
plan_update_message = {
|
||||
|
@ -737,12 +689,8 @@ Now begin!""",
|
|||
"role": MessageRole.USER,
|
||||
"content": USER_PROMPT_PLAN_UPDATE.format(
|
||||
task=task,
|
||||
tool_descriptions=get_tool_descriptions(
|
||||
self.tools, self.tool_description_template
|
||||
),
|
||||
managed_agents_descriptions=(
|
||||
show_agents_descriptions(self.managed_agents)
|
||||
),
|
||||
tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template),
|
||||
managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)),
|
||||
facts_update=facts_update,
|
||||
remaining_steps=(self.max_steps - step),
|
||||
),
|
||||
|
@ -753,16 +701,12 @@ Now begin!""",
|
|||
).content
|
||||
|
||||
# Log final facts and plan
|
||||
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(
|
||||
task=task, plan_update=plan_update
|
||||
)
|
||||
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
|
||||
final_facts_redaction = f"""Here is the updated list of the facts that I know:
|
||||
```
|
||||
{facts_update}
|
||||
```"""
|
||||
self.logs.append(
|
||||
PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction)
|
||||
)
|
||||
self.logs.append(PlanningStep(plan=final_plan_redaction, facts=final_facts_redaction))
|
||||
self.logger.log(
|
||||
Rule("[bold]Updated plan", style="orange"),
|
||||
Text(final_plan_redaction),
|
||||
|
@ -816,19 +760,13 @@ class ToolCallingAgent(MultiStepAgent):
|
|||
tool_arguments = tool_call.function.arguments
|
||||
|
||||
except Exception as e:
|
||||
raise AgentGenerationError(
|
||||
f"Error in generating tool call with model:\n{e}"
|
||||
)
|
||||
raise AgentGenerationError(f"Error in generating tool call with model:\n{e}")
|
||||
|
||||
log_entry.tool_calls = [
|
||||
ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)
|
||||
]
|
||||
log_entry.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)]
|
||||
|
||||
# Execute
|
||||
self.logger.log(
|
||||
Panel(
|
||||
Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")
|
||||
),
|
||||
Panel(Text(f"Calling tool: '{tool_name}' with arguments: {tool_arguments}")),
|
||||
level=LogLevel.INFO,
|
||||
)
|
||||
if tool_name == "final_answer":
|
||||
|
@ -900,16 +838,10 @@ class CodeAgent(MultiStepAgent):
|
|||
if system_prompt is None:
|
||||
system_prompt = CODE_SYSTEM_PROMPT
|
||||
|
||||
self.additional_authorized_imports = (
|
||||
additional_authorized_imports if additional_authorized_imports else []
|
||||
)
|
||||
self.authorized_imports = list(
|
||||
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
|
||||
)
|
||||
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
||||
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
|
||||
if "{{authorized_imports}}" not in system_prompt:
|
||||
raise AgentError(
|
||||
"Tag '{{authorized_imports}}' should be provided in the prompt."
|
||||
)
|
||||
raise AgentError("Tag '{{authorized_imports}}' should be provided in the prompt.")
|
||||
super().__init__(
|
||||
tools=tools,
|
||||
model=model,
|
||||
|
@ -966,9 +898,7 @@ class CodeAgent(MultiStepAgent):
|
|||
log_entry.agent_memory = agent_memory.copy()
|
||||
|
||||
try:
|
||||
additional_args = (
|
||||
{"grammar": self.grammar} if self.grammar is not None else {}
|
||||
)
|
||||
additional_args = {"grammar": self.grammar} if self.grammar is not None else {}
|
||||
llm_output = self.model(
|
||||
self.input_messages,
|
||||
stop_sequences=["<end_code>", "Observation:"],
|
||||
|
@ -999,9 +929,7 @@ class CodeAgent(MultiStepAgent):
|
|||
try:
|
||||
code_action = fix_final_answer_code(parse_code_blobs(llm_output))
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
|
||||
)
|
||||
error_msg = f"Error in code parsing:\n{e}\nMake sure to provide correct code blobs."
|
||||
raise AgentParsingError(error_msg)
|
||||
|
||||
log_entry.tool_calls = [
|
||||
|
@ -1088,17 +1016,13 @@ class ManagedAgent:
|
|||
self.description = description
|
||||
self.additional_prompting = additional_prompting
|
||||
self.provide_run_summary = provide_run_summary
|
||||
self.managed_agent_prompt = (
|
||||
managed_agent_prompt if managed_agent_prompt else MANAGED_AGENT_PROMPT
|
||||
)
|
||||
self.managed_agent_prompt = managed_agent_prompt if managed_agent_prompt else MANAGED_AGENT_PROMPT
|
||||
|
||||
def write_full_task(self, task):
|
||||
"""Adds additional prompting for the managed agent, like 'add more detail in your answer'."""
|
||||
full_task = self.managed_agent_prompt.format(name=self.name, task=task)
|
||||
if self.additional_prompting:
|
||||
full_task = full_task.replace(
|
||||
"\n{{additional_prompting}}", self.additional_prompting
|
||||
).strip()
|
||||
full_task = full_task.replace("\n{{additional_prompting}}", self.additional_prompting).strip()
|
||||
else:
|
||||
full_task = full_task.replace("\n{{additional_prompting}}", "").strip()
|
||||
return full_task
|
||||
|
@ -1107,9 +1031,7 @@ class ManagedAgent:
|
|||
full_task = self.write_full_task(request)
|
||||
output = self.agent.run(full_task, **kwargs)
|
||||
if self.provide_run_summary:
|
||||
answer = (
|
||||
f"Here is the final answer from your managed agent '{self.name}':\n"
|
||||
)
|
||||
answer = f"Here is the final answer from your managed agent '{self.name}':\n"
|
||||
answer += str(output)
|
||||
answer += f"\n\nFor more detail, find below a summary of this agent's work:\nSUMMARY OF WORK FROM AGENT '{self.name}':\n"
|
||||
for message in self.agent.write_inner_memory_from_logs(summary_mode=True):
|
||||
|
|
|
@ -20,8 +20,6 @@ from dataclasses import dataclass
|
|||
from typing import Dict, Optional
|
||||
|
||||
from huggingface_hub import hf_hub_download, list_spaces
|
||||
|
||||
|
||||
from transformers.utils import is_offline_mode, is_torch_available
|
||||
|
||||
from .local_python_executor import (
|
||||
|
@ -32,6 +30,7 @@ from .local_python_executor import (
|
|||
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
|
||||
from .types import AgentAudio
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.models.whisper import (
|
||||
WhisperForConditionalGeneration,
|
||||
|
@ -61,9 +60,7 @@ def get_remote_tools(logger, organization="huggingface-tools"):
|
|||
tools = {}
|
||||
for space_info in spaces:
|
||||
repo_id = space_info.id
|
||||
resolved_config_file = hf_hub_download(
|
||||
repo_id, TOOL_CONFIG_FILE, repo_type="space"
|
||||
)
|
||||
resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
|
||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||
config = json.load(reader)
|
||||
task = repo_id.split("/")[-1]
|
||||
|
@ -94,9 +91,7 @@ class PythonInterpreterTool(Tool):
|
|||
if authorized_imports is None:
|
||||
self.authorized_imports = list(set(BASE_BUILTIN_MODULES))
|
||||
else:
|
||||
self.authorized_imports = list(
|
||||
set(BASE_BUILTIN_MODULES) | set(authorized_imports)
|
||||
)
|
||||
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(authorized_imports))
|
||||
self.inputs = {
|
||||
"code": {
|
||||
"type": "string",
|
||||
|
@ -126,9 +121,7 @@ class PythonInterpreterTool(Tool):
|
|||
class FinalAnswerTool(Tool):
|
||||
name = "final_answer"
|
||||
description = "Provides a final answer to the given problem."
|
||||
inputs = {
|
||||
"answer": {"type": "any", "description": "The final answer to the problem"}
|
||||
}
|
||||
inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
|
||||
output_type = "any"
|
||||
|
||||
def forward(self, answer):
|
||||
|
@ -138,9 +131,7 @@ class FinalAnswerTool(Tool):
|
|||
class UserInputTool(Tool):
|
||||
name = "user_input"
|
||||
description = "Asks for user's input on a specific question"
|
||||
inputs = {
|
||||
"question": {"type": "string", "description": "The question to ask the user"}
|
||||
}
|
||||
inputs = {"question": {"type": "string", "description": "The question to ask the user"}}
|
||||
output_type = "string"
|
||||
|
||||
def forward(self, question):
|
||||
|
@ -151,9 +142,7 @@ class UserInputTool(Tool):
|
|||
class DuckDuckGoSearchTool(Tool):
|
||||
name = "web_search"
|
||||
description = """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."""
|
||||
inputs = {
|
||||
"query": {"type": "string", "description": "The search query to perform."}
|
||||
}
|
||||
inputs = {"query": {"type": "string", "description": "The search query to perform."}}
|
||||
output_type = "string"
|
||||
|
||||
def __init__(self, *args, max_results=10, **kwargs):
|
||||
|
@ -169,10 +158,7 @@ class DuckDuckGoSearchTool(Tool):
|
|||
|
||||
def forward(self, query: str) -> str:
|
||||
results = self.ddgs.text(query, max_results=self.max_results)
|
||||
postprocessed_results = [
|
||||
f"[{result['title']}]({result['href']})\n{result['body']}"
|
||||
for result in results
|
||||
]
|
||||
postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
|
||||
return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
|
||||
|
||||
|
||||
|
@ -199,9 +185,7 @@ class GoogleSearchTool(Tool):
|
|||
import requests
|
||||
|
||||
if self.serpapi_key is None:
|
||||
raise ValueError(
|
||||
"Missing SerpAPI key. Make sure you have 'SERPAPI_API_KEY' in your env variables."
|
||||
)
|
||||
raise ValueError("Missing SerpAPI key. Make sure you have 'SERPAPI_API_KEY' in your env variables.")
|
||||
|
||||
params = {
|
||||
"engine": "google",
|
||||
|
@ -210,9 +194,7 @@ class GoogleSearchTool(Tool):
|
|||
"google_domain": "google.com",
|
||||
}
|
||||
if filter_year is not None:
|
||||
params["tbs"] = (
|
||||
f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
|
||||
)
|
||||
params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
|
||||
|
||||
response = requests.get("https://serpapi.com/search.json", params=params)
|
||||
|
||||
|
@ -227,13 +209,9 @@ class GoogleSearchTool(Tool):
|
|||
f"'organic_results' key not found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"'organic_results' key not found for query: '{query}'. Use a less restrictive query."
|
||||
)
|
||||
raise Exception(f"'organic_results' key not found for query: '{query}'. Use a less restrictive query.")
|
||||
if len(results["organic_results"]) == 0:
|
||||
year_filter_message = (
|
||||
f" with filter year={filter_year}" if filter_year is not None else ""
|
||||
)
|
||||
year_filter_message = f" with filter year={filter_year}" if filter_year is not None else ""
|
||||
return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."
|
||||
|
||||
web_snippets = []
|
||||
|
@ -253,9 +231,7 @@ class GoogleSearchTool(Tool):
|
|||
|
||||
redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"
|
||||
|
||||
redacted_version = redacted_version.replace(
|
||||
"Your browser can't play this video.", ""
|
||||
)
|
||||
redacted_version = redacted_version.replace("Your browser can't play this video.", "")
|
||||
web_snippets.append(redacted_version)
|
||||
|
||||
return "## Search Results\n" + "\n\n".join(web_snippets)
|
||||
|
@ -263,7 +239,9 @@ class GoogleSearchTool(Tool):
|
|||
|
||||
class VisitWebpageTool(Tool):
|
||||
name = "visit_webpage"
|
||||
description = "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
|
||||
description = (
|
||||
"Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
|
||||
)
|
||||
inputs = {
|
||||
"url": {
|
||||
"type": "string",
|
||||
|
@ -277,6 +255,7 @@ class VisitWebpageTool(Tool):
|
|||
import requests
|
||||
from markdownify import markdownify
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
from smolagents.utils import truncate_content
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
|
|
|
@ -28,6 +28,7 @@ from .tool_validation import validate_tool_attributes
|
|||
from .tools import Tool
|
||||
from .utils import BASE_BUILTIN_MODULES, instance_to_source
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
|
@ -45,9 +46,7 @@ class E2BExecutor:
|
|||
self.logger = logger
|
||||
additional_imports = additional_imports + ["pickle5", "smolagents"]
|
||||
if len(additional_imports) > 0:
|
||||
execution = self.sbx.commands.run(
|
||||
"pip install " + " ".join(additional_imports)
|
||||
)
|
||||
execution = self.sbx.commands.run("pip install " + " ".join(additional_imports))
|
||||
if execution.error:
|
||||
raise Exception(f"Error installing dependencies: {execution.error}")
|
||||
else:
|
||||
|
@ -61,9 +60,7 @@ class E2BExecutor:
|
|||
tool_code += f"\n{tool.name} = {tool.__class__.__name__}()\n"
|
||||
tool_codes.append(tool_code)
|
||||
|
||||
tool_definition_code = "\n".join(
|
||||
[f"import {module}" for module in BASE_BUILTIN_MODULES]
|
||||
)
|
||||
tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES])
|
||||
tool_definition_code += textwrap.dedent("""
|
||||
class Tool:
|
||||
def __call__(self, *args, **kwargs):
|
||||
|
@ -122,9 +119,7 @@ locals().update({key: value for key, value in pickle_dict.items()})
|
|||
for attribute_name in ["jpeg", "png"]:
|
||||
if getattr(result, attribute_name) is not None:
|
||||
image_output = getattr(result, attribute_name)
|
||||
decoded_bytes = base64.b64decode(
|
||||
image_output.encode("utf-8")
|
||||
)
|
||||
decoded_bytes = base64.b64decode(image_output.encode("utf-8"))
|
||||
return Image.open(BytesIO(decoded_bytes)), execution_logs
|
||||
for attribute_name in [
|
||||
"chart",
|
||||
|
|
|
@ -13,14 +13,14 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import gradio as gr
|
||||
import shutil
|
||||
import os
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
|
||||
import shutil
|
||||
from typing import Optional
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .agents import ActionStep, AgentStepLog, MultiStepAgent
|
||||
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
|
||||
|
||||
|
@ -59,9 +59,7 @@ def stream_to_gradio(
|
|||
):
|
||||
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
|
||||
|
||||
for step_log in agent.run(
|
||||
task, stream=True, reset=reset_agent_memory, additional_args=additional_args
|
||||
):
|
||||
for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
|
||||
for message in pull_messages_from_step(step_log, test_mode=test_mode):
|
||||
yield message
|
||||
|
||||
|
@ -147,14 +145,10 @@ class GradioUI:
|
|||
sanitized_name = "".join(sanitized_name)
|
||||
|
||||
# Save the uploaded file to the specified folder
|
||||
file_path = os.path.join(
|
||||
self.file_upload_folder, os.path.basename(sanitized_name)
|
||||
)
|
||||
file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
|
||||
shutil.copy(file.name, file_path)
|
||||
|
||||
return gr.Textbox(
|
||||
f"File uploaded: {file_path}", visible=True
|
||||
), file_uploads_log + [file_path]
|
||||
return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path]
|
||||
|
||||
def log_user_message(self, text_input, file_uploads_log):
|
||||
return (
|
||||
|
@ -183,9 +177,7 @@ class GradioUI:
|
|||
# If an upload folder is provided, enable the upload feature
|
||||
if self.file_upload_folder is not None:
|
||||
upload_file = gr.File(label="Upload a file")
|
||||
upload_status = gr.Textbox(
|
||||
label="Upload Status", interactive=False, visible=False
|
||||
)
|
||||
upload_status = gr.Textbox(label="Upload Status", interactive=False, visible=False)
|
||||
upload_file.change(
|
||||
self.upload_file,
|
||||
[upload_file, file_uploads_log],
|
||||
|
|
|
@ -42,8 +42,7 @@ class InterpreterError(ValueError):
|
|||
ERRORS = {
|
||||
name: getattr(builtins, name)
|
||||
for name in dir(builtins)
|
||||
if isinstance(getattr(builtins, name), type)
|
||||
and issubclass(getattr(builtins, name), BaseException)
|
||||
if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)
|
||||
}
|
||||
|
||||
PRINT_OUTPUTS, DEFAULT_MAX_LEN_OUTPUT = "", 50000
|
||||
|
@ -167,9 +166,7 @@ def evaluate_unaryop(
|
|||
custom_tools: Dict[str, Callable],
|
||||
authorized_imports: List[str],
|
||||
) -> Any:
|
||||
operand = evaluate_ast(
|
||||
expression.operand, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
operand = evaluate_ast(expression.operand, state, static_tools, custom_tools, authorized_imports)
|
||||
if isinstance(expression.op, ast.USub):
|
||||
return -operand
|
||||
elif isinstance(expression.op, ast.UAdd):
|
||||
|
@ -179,9 +176,7 @@ def evaluate_unaryop(
|
|||
elif isinstance(expression.op, ast.Invert):
|
||||
return ~operand
|
||||
else:
|
||||
raise InterpreterError(
|
||||
f"Unary operation {expression.op.__class__.__name__} is not supported."
|
||||
)
|
||||
raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
|
||||
|
||||
|
||||
def evaluate_lambda(
|
||||
|
@ -217,23 +212,17 @@ def evaluate_while(
|
|||
) -> None:
|
||||
max_iterations = 1000
|
||||
iterations = 0
|
||||
while evaluate_ast(
|
||||
while_loop.test, state, static_tools, custom_tools, authorized_imports
|
||||
):
|
||||
while evaluate_ast(while_loop.test, state, static_tools, custom_tools, authorized_imports):
|
||||
for node in while_loop.body:
|
||||
try:
|
||||
evaluate_ast(
|
||||
node, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
|
||||
except BreakException:
|
||||
return None
|
||||
except ContinueException:
|
||||
break
|
||||
iterations += 1
|
||||
if iterations > max_iterations:
|
||||
raise InterpreterError(
|
||||
f"Maximum number of {max_iterations} iterations in While loop exceeded"
|
||||
)
|
||||
raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
|
||||
return None
|
||||
|
||||
|
||||
|
@ -248,8 +237,7 @@ def create_function(
|
|||
func_state = state.copy()
|
||||
arg_names = [arg.arg for arg in func_def.args.args]
|
||||
default_values = [
|
||||
evaluate_ast(d, state, static_tools, custom_tools, authorized_imports)
|
||||
for d in func_def.args.defaults
|
||||
evaluate_ast(d, state, static_tools, custom_tools, authorized_imports) for d in func_def.args.defaults
|
||||
]
|
||||
|
||||
# Apply default values
|
||||
|
@ -286,9 +274,7 @@ def create_function(
|
|||
result = None
|
||||
try:
|
||||
for stmt in func_def.body:
|
||||
result = evaluate_ast(
|
||||
stmt, func_state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
result = evaluate_ast(stmt, func_state, static_tools, custom_tools, authorized_imports)
|
||||
except ReturnException as e:
|
||||
result = e.value
|
||||
|
||||
|
@ -307,9 +293,7 @@ def evaluate_function_def(
|
|||
custom_tools: Dict[str, Callable],
|
||||
authorized_imports: List[str],
|
||||
) -> Callable:
|
||||
custom_tools[func_def.name] = create_function(
|
||||
func_def, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools, authorized_imports)
|
||||
return custom_tools[func_def.name]
|
||||
|
||||
|
||||
|
@ -321,17 +305,12 @@ def evaluate_class_def(
|
|||
authorized_imports: List[str],
|
||||
) -> type:
|
||||
class_name = class_def.name
|
||||
bases = [
|
||||
evaluate_ast(base, state, static_tools, custom_tools, authorized_imports)
|
||||
for base in class_def.bases
|
||||
]
|
||||
bases = [evaluate_ast(base, state, static_tools, custom_tools, authorized_imports) for base in class_def.bases]
|
||||
class_dict = {}
|
||||
|
||||
for stmt in class_def.body:
|
||||
if isinstance(stmt, ast.FunctionDef):
|
||||
class_dict[stmt.name] = evaluate_function_def(
|
||||
stmt, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools, authorized_imports)
|
||||
elif isinstance(stmt, ast.Assign):
|
||||
for target in stmt.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
|
@ -351,9 +330,7 @@ def evaluate_class_def(
|
|||
authorized_imports,
|
||||
)
|
||||
else:
|
||||
raise InterpreterError(
|
||||
f"Unsupported statement in class body: {stmt.__class__.__name__}"
|
||||
)
|
||||
raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
|
||||
|
||||
new_class = type(class_name, tuple(bases), class_dict)
|
||||
state[class_name] = new_class
|
||||
|
@ -371,38 +348,26 @@ def evaluate_augassign(
|
|||
if isinstance(target, ast.Name):
|
||||
return state.get(target.id, 0)
|
||||
elif isinstance(target, ast.Subscript):
|
||||
obj = evaluate_ast(
|
||||
target.value, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
key = evaluate_ast(
|
||||
target.slice, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
|
||||
key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
|
||||
return obj[key]
|
||||
elif isinstance(target, ast.Attribute):
|
||||
obj = evaluate_ast(
|
||||
target.value, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
|
||||
return getattr(obj, target.attr)
|
||||
elif isinstance(target, ast.Tuple):
|
||||
return tuple(get_current_value(elt) for elt in target.elts)
|
||||
elif isinstance(target, ast.List):
|
||||
return [get_current_value(elt) for elt in target.elts]
|
||||
else:
|
||||
raise InterpreterError(
|
||||
"AugAssign not supported for {type(target)} targets."
|
||||
)
|
||||
raise InterpreterError("AugAssign not supported for {type(target)} targets.")
|
||||
|
||||
current_value = get_current_value(expression.target)
|
||||
value_to_add = evaluate_ast(
|
||||
expression.value, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
|
||||
|
||||
if isinstance(expression.op, ast.Add):
|
||||
if isinstance(current_value, list):
|
||||
if not isinstance(value_to_add, list):
|
||||
raise InterpreterError(
|
||||
f"Cannot add non-list value {value_to_add} to a list."
|
||||
)
|
||||
raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
|
||||
updated_value = current_value + value_to_add
|
||||
else:
|
||||
updated_value = current_value + value_to_add
|
||||
|
@ -429,9 +394,7 @@ def evaluate_augassign(
|
|||
elif isinstance(expression.op, ast.RShift):
|
||||
updated_value = current_value >> value_to_add
|
||||
else:
|
||||
raise InterpreterError(
|
||||
f"Operation {type(expression.op).__name__} is not supported."
|
||||
)
|
||||
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
|
||||
|
||||
# Update the state
|
||||
set_value(
|
||||
|
@ -455,16 +418,12 @@ def evaluate_boolop(
|
|||
) -> bool:
|
||||
if isinstance(node.op, ast.And):
|
||||
for value in node.values:
|
||||
if not evaluate_ast(
|
||||
value, state, static_tools, custom_tools, authorized_imports
|
||||
):
|
||||
if not evaluate_ast(value, state, static_tools, custom_tools, authorized_imports):
|
||||
return False
|
||||
return True
|
||||
elif isinstance(node.op, ast.Or):
|
||||
for value in node.values:
|
||||
if evaluate_ast(
|
||||
value, state, static_tools, custom_tools, authorized_imports
|
||||
):
|
||||
if evaluate_ast(value, state, static_tools, custom_tools, authorized_imports):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -477,12 +436,8 @@ def evaluate_binop(
|
|||
authorized_imports: List[str],
|
||||
) -> Any:
|
||||
# Recursively evaluate the left and right operands
|
||||
left_val = evaluate_ast(
|
||||
binop.left, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
right_val = evaluate_ast(
|
||||
binop.right, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
left_val = evaluate_ast(binop.left, state, static_tools, custom_tools, authorized_imports)
|
||||
right_val = evaluate_ast(binop.right, state, static_tools, custom_tools, authorized_imports)
|
||||
|
||||
# Determine the operation based on the type of the operator in the BinOp
|
||||
if isinstance(binop.op, ast.Add):
|
||||
|
@ -510,9 +465,7 @@ def evaluate_binop(
|
|||
elif isinstance(binop.op, ast.RShift):
|
||||
return left_val >> right_val
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Binary operation {type(binop.op).__name__} is not implemented."
|
||||
)
|
||||
raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
|
||||
|
||||
|
||||
def evaluate_assign(
|
||||
|
@ -522,17 +475,13 @@ def evaluate_assign(
|
|||
custom_tools: Dict[str, Callable],
|
||||
authorized_imports: List[str],
|
||||
) -> Any:
|
||||
result = evaluate_ast(
|
||||
assign.value, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
result = evaluate_ast(assign.value, state, static_tools, custom_tools, authorized_imports)
|
||||
if len(assign.targets) == 1:
|
||||
target = assign.targets[0]
|
||||
set_value(target, result, state, static_tools, custom_tools, authorized_imports)
|
||||
else:
|
||||
if len(assign.targets) != len(result):
|
||||
raise InterpreterError(
|
||||
f"Assign failed: expected {len(result)} values but got {len(assign.targets)}."
|
||||
)
|
||||
raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
|
||||
expanded_values = []
|
||||
for tgt in assign.targets:
|
||||
if isinstance(tgt, ast.Starred):
|
||||
|
@ -554,9 +503,7 @@ def set_value(
|
|||
) -> None:
|
||||
if isinstance(target, ast.Name):
|
||||
if target.id in static_tools:
|
||||
raise InterpreterError(
|
||||
f"Cannot assign to name '{target.id}': doing this would erase the existing tool!"
|
||||
)
|
||||
raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
|
||||
state[target.id] = value
|
||||
elif isinstance(target, ast.Tuple):
|
||||
if not isinstance(value, tuple):
|
||||
|
@ -567,21 +514,13 @@ def set_value(
|
|||
if len(target.elts) != len(value):
|
||||
raise InterpreterError("Cannot unpack tuple of wrong size")
|
||||
for i, elem in enumerate(target.elts):
|
||||
set_value(
|
||||
elem, value[i], state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
set_value(elem, value[i], state, static_tools, custom_tools, authorized_imports)
|
||||
elif isinstance(target, ast.Subscript):
|
||||
obj = evaluate_ast(
|
||||
target.value, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
key = evaluate_ast(
|
||||
target.slice, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
|
||||
key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
|
||||
obj[key] = value
|
||||
elif isinstance(target, ast.Attribute):
|
||||
obj = evaluate_ast(
|
||||
target.value, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
|
||||
setattr(obj, target.attr, value)
|
||||
|
||||
|
||||
|
@ -593,15 +532,11 @@ def evaluate_call(
|
|||
authorized_imports: List[str],
|
||||
) -> Any:
|
||||
if not (
|
||||
isinstance(call.func, ast.Attribute)
|
||||
or isinstance(call.func, ast.Name)
|
||||
or isinstance(call.func, ast.Subscript)
|
||||
isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name) or isinstance(call.func, ast.Subscript)
|
||||
):
|
||||
raise InterpreterError(f"This is not a correct function: {call.func}).")
|
||||
if isinstance(call.func, ast.Attribute):
|
||||
obj = evaluate_ast(
|
||||
call.func.value, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
obj = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports)
|
||||
func_name = call.func.attr
|
||||
if not hasattr(obj, func_name):
|
||||
raise InterpreterError(f"Object {obj} has no attribute {func_name}")
|
||||
|
@ -623,18 +558,12 @@ def evaluate_call(
|
|||
)
|
||||
|
||||
elif isinstance(call.func, ast.Subscript):
|
||||
value = evaluate_ast(
|
||||
call.func.value, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
index = evaluate_ast(
|
||||
call.func.slice, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
value = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports)
|
||||
index = evaluate_ast(call.func.slice, state, static_tools, custom_tools, authorized_imports)
|
||||
if isinstance(value, (list, tuple)):
|
||||
func = value[index]
|
||||
else:
|
||||
raise InterpreterError(
|
||||
f"Cannot subscript object of type {type(value).__name__}"
|
||||
)
|
||||
raise InterpreterError(f"Cannot subscript object of type {type(value).__name__}")
|
||||
|
||||
if not callable(func):
|
||||
raise InterpreterError(f"This is not a correct function: {call.func}).")
|
||||
|
@ -642,20 +571,12 @@ def evaluate_call(
|
|||
args = []
|
||||
for arg in call.args:
|
||||
if isinstance(arg, ast.Starred):
|
||||
args.extend(
|
||||
evaluate_ast(
|
||||
arg.value, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
)
|
||||
args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools, authorized_imports))
|
||||
else:
|
||||
args.append(
|
||||
evaluate_ast(arg, state, static_tools, custom_tools, authorized_imports)
|
||||
)
|
||||
args.append(evaluate_ast(arg, state, static_tools, custom_tools, authorized_imports))
|
||||
|
||||
kwargs = {
|
||||
keyword.arg: evaluate_ast(
|
||||
keyword.value, state, static_tools, custom_tools, authorized_imports
|
||||
)
|
||||
keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools, authorized_imports)
|
||||
for keyword in call.keywords
|
||||
}
|
||||
|
||||
|
@ -693,17 +614,11 @@ def evaluate_subscript(
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> Any:
index = evaluate_ast(
subscript.slice, state, static_tools, custom_tools, authorized_imports
)
value = evaluate_ast(
subscript.value, state, static_tools, custom_tools, authorized_imports
)
index = evaluate_ast(subscript.slice, state, static_tools, custom_tools, authorized_imports)
value = evaluate_ast(subscript.value, state, static_tools, custom_tools, authorized_imports)
if isinstance(value, str) and isinstance(index, str):
raise InterpreterError(
"You're trying to subscript a string with a string index, which is impossible"
)
raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
if isinstance(value, pd.core.indexing._LocIndexer):
parent_object = value.obj
return parent_object.loc[index]

@ -718,15 +633,11 @@ def evaluate_subscript(
return value[index]
elif isinstance(value, (list, tuple)):
if not (-len(value) <= index < len(value)):
raise InterpreterError(
f"Index {index} out of bounds for list of length {len(value)}"
)
raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
return value[int(index)]
elif isinstance(value, str):
if not (-len(value) <= index < len(value)):
raise InterpreterError(
f"Index {index} out of bounds for string of length {len(value)}"
)
raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
return value[index]
elif index in value:
return value[index]

@ -765,12 +676,9 @@ def evaluate_condition(
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> bool:
left = evaluate_ast(
condition.left, state, static_tools, custom_tools, authorized_imports
)
left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports)
comparators = [
evaluate_ast(c, state, static_tools, custom_tools, authorized_imports)
for c in condition.comparators
evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) for c in condition.comparators
]
ops = [type(op) for op in condition.ops]
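Note: the hunks above only re-wrap existing calls; the subscript semantics (bounds checks, the string-index error) are unchanged. A minimal sketch of how they surface to a caller, assuming `evaluate_python_code` is importable from `smolagents.local_python_executor` and still returns the `(result, is_final_answer)` tuple shown later in this diff:

    # Sketch under the assumptions above -- not part of this commit.
    from smolagents.local_python_executor import evaluate_python_code

    result, _ = evaluate_python_code("x = [1, 2, 3]\nx[1]", state={})
    print(result)  # 2

    try:
        evaluate_python_code("x = [1, 2, 3]\nx[5]", state={})
    except Exception as e:  # InterpreterError raised by the bounds check above
        print(e)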
@ -818,21 +726,15 @@ def evaluate_if(
authorized_imports: List[str],
) -> Any:
result = None
test_result = evaluate_ast(
if_statement.test, state, static_tools, custom_tools, authorized_imports
)
test_result = evaluate_ast(if_statement.test, state, static_tools, custom_tools, authorized_imports)
if test_result:
for line in if_statement.body:
line_result = evaluate_ast(
line, state, static_tools, custom_tools, authorized_imports
)
line_result = evaluate_ast(line, state, static_tools, custom_tools, authorized_imports)
if line_result is not None:
result = line_result
else:
for line in if_statement.orelse:
line_result = evaluate_ast(
line, state, static_tools, custom_tools, authorized_imports
)
line_result = evaluate_ast(line, state, static_tools, custom_tools, authorized_imports)
if line_result is not None:
result = line_result
return result

@ -846,9 +748,7 @@ def evaluate_for(
authorized_imports: List[str],
) -> Any:
result = None
iterator = evaluate_ast(
for_loop.iter, state, static_tools, custom_tools, authorized_imports
)
iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools, authorized_imports)
for counter in iterator:
set_value(
for_loop.target,

@ -860,9 +760,7 @@ def evaluate_for(
)
for node in for_loop.body:
try:
line_result = evaluate_ast(
node, state, static_tools, custom_tools, authorized_imports
)
line_result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
if line_result is not None:
result = line_result
except BreakException:

@ -882,9 +780,7 @@ def evaluate_listcomp(
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> List[Any]:
def inner_evaluate(
generators: List[ast.comprehension], index: int, current_state: Dict[str, Any]
) -> List[Any]:
def inner_evaluate(generators: List[ast.comprehension], index: int, current_state: Dict[str, Any]) -> List[Any]:
if index >= len(generators):
return [
evaluate_ast(
@ -912,9 +808,7 @@ def evaluate_listcomp(
else:
new_state[generator.target.id] = value
if all(
evaluate_ast(
if_clause, new_state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
for if_clause in generator.ifs
):
result.extend(inner_evaluate(generators, index + 1, new_state))

@ -938,32 +832,24 @@ def evaluate_try(
for handler in try_node.handlers:
if handler.type is None or isinstance(
e,
evaluate_ast(
handler.type, state, static_tools, custom_tools, authorized_imports
),
evaluate_ast(handler.type, state, static_tools, custom_tools, authorized_imports),
):
matched = True
if handler.name:
state[handler.name] = e
for stmt in handler.body:
evaluate_ast(
stmt, state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
break
if not matched:
raise e
else:
if try_node.orelse:
for stmt in try_node.orelse:
evaluate_ast(
stmt, state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
finally:
if try_node.finalbody:
for stmt in try_node.finalbody:
evaluate_ast(
stmt, state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)

def evaluate_raise(

@ -974,15 +860,11 @@ def evaluate_raise(
authorized_imports: List[str],
) -> None:
if raise_node.exc is not None:
exc = evaluate_ast(
raise_node.exc, state, static_tools, custom_tools, authorized_imports
)
exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools, authorized_imports)
else:
exc = None
if raise_node.cause is not None:
cause = evaluate_ast(
raise_node.cause, state, static_tools, custom_tools, authorized_imports
)
cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools, authorized_imports)
else:
cause = None
if exc is not None:
@ -1001,14 +883,10 @@ def evaluate_assert(
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> None:
test_result = evaluate_ast(
assert_node.test, state, static_tools, custom_tools, authorized_imports
)
test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools, authorized_imports)
if not test_result:
if assert_node.msg:
msg = evaluate_ast(
assert_node.msg, state, static_tools, custom_tools, authorized_imports
)
msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools, authorized_imports)
raise AssertionError(msg)
else:
# Include the failing condition in the assertion message

@ -1025,9 +903,7 @@ def evaluate_with(
) -> None:
contexts = []
for item in with_node.items:
context_expr = evaluate_ast(
item.context_expr, state, static_tools, custom_tools, authorized_imports
)
context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools, authorized_imports)
if item.optional_vars:
state[item.optional_vars.id] = context_expr.__enter__()
contexts.append(state[item.optional_vars.id])

@ -1069,19 +945,14 @@ def get_safe_module(unsafe_module, dangerous_patterns, visited=None):
# Copy all attributes by reference, recursively checking modules
for attr_name in dir(unsafe_module):
# Skip dangerous patterns at any level
if any(
pattern in f"{unsafe_module.__name__}.{attr_name}"
for pattern in dangerous_patterns
):
if any(pattern in f"{unsafe_module.__name__}.{attr_name}" for pattern in dangerous_patterns):
continue
attr_value = getattr(unsafe_module, attr_name)
# Recursively process nested modules, passing visited set
if isinstance(attr_value, ModuleType):
attr_value = get_safe_module(
attr_value, dangerous_patterns, visited=visited
)
attr_value = get_safe_module(attr_value, dangerous_patterns, visited=visited)
setattr(safe_module, attr_name, attr_value)

@ -1116,18 +987,14 @@ def import_modules(expression, state, authorized_imports):
module_path = module_name.split(".")
if any([module in dangerous_patterns for module in module_path]):
return False
module_subpaths = [
".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
]
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
return any(subpath in authorized_imports for subpath in module_subpaths)
if isinstance(expression, ast.Import):
for alias in expression.names:
if check_module_authorized(alias.name):
raw_module = import_module(alias.name)
state[alias.asname or alias.name] = get_safe_module(
raw_module, dangerous_patterns
)
state[alias.asname or alias.name] = get_safe_module(raw_module, dangerous_patterns)
else:
raise InterpreterError(
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"

@ -1135,9 +1002,7 @@ def import_modules(expression, state, authorized_imports):
return None
elif isinstance(expression, ast.ImportFrom):
if check_module_authorized(expression.module):
raw_module = __import__(
expression.module, fromlist=[alias.name for alias in expression.names]
)
raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
for alias in expression.names:
state[alias.asname or alias.name] = get_safe_module(
getattr(raw_module, alias.name), dangerous_patterns

@ -1156,9 +1021,7 @@ def evaluate_dictcomp(
) -> Dict[Any, Any]:
result = {}
for gen in dictcomp.generators:
iter_value = evaluate_ast(
gen.iter, state, static_tools, custom_tools, authorized_imports
)
iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools, authorized_imports)
for value in iter_value:
new_state = state.copy()
set_value(

@ -1170,9 +1033,7 @@ def evaluate_dictcomp(
authorized_imports,
)
if all(
evaluate_ast(
if_clause, new_state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
for if_clause in gen.ifs
):
key = evaluate_ast(
@ -1229,202 +1090,116 @@ def evaluate_ast(
if isinstance(expression, ast.Assign):
# Assignment -> we evaluate the assignment which should update the state
# We return the variable assigned as it may be used to determine the final result.
return evaluate_assign(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_assign(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.AugAssign):
return evaluate_augassign(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_augassign(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Call):
# Function call -> we return the value of the function call
return evaluate_call(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_call(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Constant):
# Constant -> just return the value
return expression.value
elif isinstance(expression, ast.Tuple):
return tuple(
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports)
for elt in expression.elts
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts
)
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
return evaluate_listcomp(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_listcomp(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.UnaryOp):
return evaluate_unaryop(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_unaryop(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Starred):
return evaluate_ast(
expression.value, state, static_tools, custom_tools, authorized_imports
)
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.BoolOp):
# Boolean operation -> evaluate the operation
return evaluate_boolop(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_boolop(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Break):
raise BreakException()
elif isinstance(expression, ast.Continue):
raise ContinueException()
elif isinstance(expression, ast.BinOp):
# Binary operation -> execute operation
return evaluate_binop(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_binop(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Compare):
# Comparison -> evaluate the comparison
return evaluate_condition(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_condition(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Lambda):
return evaluate_lambda(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_lambda(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.FunctionDef):
return evaluate_function_def(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_function_def(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Dict):
# Dict -> evaluate all keys and values
keys = [
evaluate_ast(k, state, static_tools, custom_tools, authorized_imports)
for k in expression.keys
]
values = [
evaluate_ast(v, state, static_tools, custom_tools, authorized_imports)
for v in expression.values
]
keys = [evaluate_ast(k, state, static_tools, custom_tools, authorized_imports) for k in expression.keys]
values = [evaluate_ast(v, state, static_tools, custom_tools, authorized_imports) for v in expression.values]
return dict(zip(keys, values))
elif isinstance(expression, ast.Expr):
# Expression -> evaluate the content
return evaluate_ast(
expression.value, state, static_tools, custom_tools, authorized_imports
)
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.For):
# For loop -> execute the loop
return evaluate_for(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_for(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.FormattedValue):
# Formatted value (part of f-string) -> evaluate the content and return
return evaluate_ast(
expression.value, state, static_tools, custom_tools, authorized_imports
)
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.If):
# If -> execute the right branch
return evaluate_if(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_if(expression, state, static_tools, custom_tools, authorized_imports)
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
return evaluate_ast(
expression.value, state, static_tools, custom_tools, authorized_imports
)
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.JoinedStr):
return "".join(
[
str(
evaluate_ast(
v, state, static_tools, custom_tools, authorized_imports
)
)
for v in expression.values
]
[str(evaluate_ast(v, state, static_tools, custom_tools, authorized_imports)) for v in expression.values]
)
elif isinstance(expression, ast.List):
# List -> evaluate all elements
return [
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports)
for elt in expression.elts
]
return [evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts]
elif isinstance(expression, ast.Name):
# Name -> pick up the value in the state
return evaluate_name(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_name(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Subscript):
# Subscript -> return the value of the indexing
return evaluate_subscript(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_subscript(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.IfExp):
test_val = evaluate_ast(
expression.test, state, static_tools, custom_tools, authorized_imports
)
test_val = evaluate_ast(expression.test, state, static_tools, custom_tools, authorized_imports)
if test_val:
return evaluate_ast(
expression.body, state, static_tools, custom_tools, authorized_imports
)
return evaluate_ast(expression.body, state, static_tools, custom_tools, authorized_imports)
else:
return evaluate_ast(
expression.orelse, state, static_tools, custom_tools, authorized_imports
)
return evaluate_ast(expression.orelse, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Attribute):
value = evaluate_ast(
expression.value, state, static_tools, custom_tools, authorized_imports
)
value = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
return getattr(value, expression.attr)
elif isinstance(expression, ast.Slice):
return slice(
evaluate_ast(
expression.lower, state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(expression.lower, state, static_tools, custom_tools, authorized_imports)
if expression.lower is not None
else None,
evaluate_ast(
expression.upper, state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(expression.upper, state, static_tools, custom_tools, authorized_imports)
if expression.upper is not None
else None,
evaluate_ast(
expression.step, state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(expression.step, state, static_tools, custom_tools, authorized_imports)
if expression.step is not None
else None,
)
elif isinstance(expression, ast.DictComp):
return evaluate_dictcomp(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_dictcomp(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.While):
return evaluate_while(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_while(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, (ast.Import, ast.ImportFrom)):
return import_modules(expression, state, authorized_imports)
elif isinstance(expression, ast.ClassDef):
return evaluate_class_def(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_class_def(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Try):
return evaluate_try(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_try(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Raise):
return evaluate_raise(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_raise(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Assert):
return evaluate_assert(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_assert(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.With):
return evaluate_with(
expression, state, static_tools, custom_tools, authorized_imports
)
return evaluate_with(expression, state, static_tools, custom_tools, authorized_imports)
elif isinstance(expression, ast.Set):
return {
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports)
for elt in expression.elts
}
return {evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts}
elif isinstance(expression, ast.Return):
raise ReturnException(
evaluate_ast(
expression.value, state, static_tools, custom_tools, authorized_imports
)
evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
if expression.value
else None
)
@ -1488,18 +1263,12 @@ def evaluate_python_code(
try:
for node in expression.body:
result = evaluate_ast(
node, state, static_tools, custom_tools, authorized_imports
)
state["print_outputs"] = truncate_content(
PRINT_OUTPUTS, max_length=max_print_outputs_length
)
result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
is_final_answer = False
return result, is_final_answer
except FinalAnswerException as e:
state["print_outputs"] = truncate_content(
PRINT_OUTPUTS, max_length=max_print_outputs_length
)
state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
is_final_answer = True
return e.value, is_final_answer
except InterpreterError as e:

@ -1521,9 +1290,7 @@ class LocalPythonInterpreter:
if max_print_outputs_length is None:
self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT
self.additional_authorized_imports = additional_authorized_imports
self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
)
self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
# Add base trusted tools to list
self.static_tools = {
**tools,

@ -1531,9 +1298,7 @@ class LocalPythonInterpreter:
}
# TODO: assert self.authorized imports are all installed locally
def __call__(
self, code_action: str, additional_variables: Dict
) -> Tuple[Any, str, bool]:
def __call__(self, code_action: str, additional_variables: Dict) -> Tuple[Any, str, bool]:
self.state.update(additional_variables)
output, is_final_answer = evaluate_python_code(
code_action,
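Note: the LocalPythonInterpreter hunks above only collapse signatures onto one line; `__call__(code_action, additional_variables)` still returns the `Tuple[Any, str, bool]` shown. A hedged usage sketch, assuming the constructor keywords `additional_authorized_imports` and `tools` visible in this diff:

    # Sketch under the assumptions above -- adjust if the constructor differs.
    from smolagents.local_python_executor import LocalPythonInterpreter

    interpreter = LocalPythonInterpreter(additional_authorized_imports=[], tools={})
    output, logs, is_final_answer = interpreter("result = 3 * 4\nresult", additional_variables={})
    print(output, is_final_answer)  # 12 False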
@ -14,17 +14,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, asdict
import json
import logging
import os
import random
from copy import deepcopy
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Dict, List, Optional, Union, Any
from typing import Any, Dict, List, Optional, Union
from huggingface_hub import InferenceClient
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,

@ -35,6 +34,7 @@ from transformers import (
from .tools import Tool
logger = logging.getLogger(__name__)
DEFAULT_JSONAGENT_REGEX_GRAMMAR = {

@ -100,10 +100,7 @@ class ChatMessage:
def from_hf_api(cls, message) -> "ChatMessage":
tool_calls = None
if getattr(message, "tool_calls", None) is not None:
tool_calls = [
ChatMessageToolCall.from_hf_api(tool_call)
for tool_call in message.tool_calls
]
tool_calls = [ChatMessageToolCall.from_hf_api(tool_call) for tool_call in message.tool_calls]
return cls(role=message.role, content=message.content, tool_calls=tool_calls)

@ -172,17 +169,12 @@ def get_clean_message_list(
role = message["role"]
if role not in MessageRole.roles():
raise ValueError(
f"Incorrect role {role}, only {MessageRole.roles()} are supported for now."
)
raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
if role in role_conversions:
message["role"] = role_conversions[role]
if (
len(final_message_list) > 0
and message["role"] == final_message_list[-1]["role"]
):
if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
final_message_list[-1]["content"] += "\n=======\n" + message["content"]
else:
final_message_list.append(message)

@ -292,9 +284,7 @@ class HfApiModel(Model):
Gets an LLM output message for the given list of input messages.
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
"""
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
if tools_to_call_from:
response = self.client.chat.completions.create(
messages=messages,

@ -367,9 +357,7 @@ class TransformersModel(Model):
default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
if model_id is None:
model_id = default_model_id
logger.warning(
f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'"
)
logger.warning(f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'")
self.model_id = model_id
self.kwargs = kwargs
if device_map is None:
@ -389,9 +377,7 @@ class TransformersModel(Model):
)
self.model_id = default_model_id
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
self.model = AutoModelForCausalLM.from_pretrained(
model_id, device_map=device_map, torch_dtype=torch_dtype
)
self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, torch_dtype=torch_dtype)
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
class StopOnStrings(StoppingCriteria):

@ -404,16 +390,9 @@ class TransformersModel(Model):
self.stream = ""
def __call__(self, input_ids, scores, **kwargs):
generated = self.tokenizer.decode(
input_ids[0][-1], skip_special_tokens=True
)
generated = self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
self.stream += generated
if any(
[
self.stream.endswith(stop_string)
for stop_string in self.stop_strings
]
):
if any([self.stream.endswith(stop_string) for stop_string in self.stop_strings]):
return True
return False

@ -426,9 +405,7 @@ class TransformersModel(Model):
grammar: Optional[str] = None,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
if tools_to_call_from is not None:
prompt_tensor = self.tokenizer.apply_chat_template(
messages,

@ -448,9 +425,7 @@ class TransformersModel(Model):
out = self.model.generate(
**prompt_tensor,
stopping_criteria=(
self.make_stopping_criteria(stop_sequences) if stop_sequences else None
),
stopping_criteria=(self.make_stopping_criteria(stop_sequences) if stop_sequences else None),
**self.kwargs,
)
generated_tokens = out[0, count_prompt_tokens:]

@ -475,9 +450,7 @@ class TransformersModel(Model):
ChatMessageToolCall(
id="".join(random.choices("0123456789", k=5)),
type="function",
function=ChatMessageToolCallDefinition(
name=tool_name, arguments=tool_arguments
),
function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
)
],
)

@ -525,9 +498,7 @@ class LiteLLMModel(Model):
grammar: Optional[str] = None,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
messages = get_clean_message_list(messages, role_conversions=tool_role_conversions)
import litellm
if tools_to_call_from:

@ -604,11 +575,7 @@ class OpenAIServerModel(Model):
) -> ChatMessage:
messages = get_clean_message_list(
messages,
role_conversions=(
self.custom_role_conversions
if self.custom_role_conversions
else tool_role_conversions
),
role_conversions=(self.custom_role_conversions if self.custom_role_conversions else tool_role_conversions),
)
if tools_to_call_from:
response = self.client.chat.completions.create(
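Note: `get_clean_message_list` keeps merging consecutive messages that share a role (the "\n=======\n" join above); the reformat does not touch that behavior. A hedged sketch, assuming both names are importable from `smolagents.models`:

    # Sketch under the assumptions above.
    from smolagents.models import get_clean_message_list, tool_role_conversions

    messages = [
        {"role": "user", "content": "Hello"},
        {"role": "user", "content": "Are you still there?"},
    ]
    print(get_clean_message_list(messages, role_conversions=tool_role_conversions))
    # [{'role': 'user', 'content': 'Hello\n=======\nAre you still there?'}]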
@ -22,10 +22,7 @@ class Monitor:
self.step_durations = []
self.tracked_model = tracked_model
self.logger = logger
if (
getattr(self.tracked_model, "last_input_token_count", "Not found")
!= "Not found"
):
if getattr(self.tracked_model, "last_input_token_count", "Not found") != "Not found":
self.total_input_token_count = 0
self.total_output_token_count = 0

@ -48,7 +45,9 @@ class Monitor:
if getattr(self.tracked_model, "last_input_token_count", None) is not None:
self.total_input_token_count += self.tracked_model.last_input_token_count
self.total_output_token_count += self.tracked_model.last_output_token_count
console_outputs += f"| Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}"
console_outputs += (
f"| Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}"
)
console_outputs += "]"
self.logger.log(Text(console_outputs, style="dim"), level=1)
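Note: the Monitor change above only re-wraps the token-count f-string to satisfy the line-length limit. A self-contained illustration of the string it builds (the values here are made up):

    total_input_token_count, total_output_token_count = 1234, 567
    console_outputs = f"| Input tokens: {total_input_token_count:,} | Output tokens: {total_output_token_count:,}"
    print(console_outputs)  # | Input tokens: 1,234 | Output tokens: 567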
@ -6,6 +6,7 @@ from typing import Set
from .utils import BASE_BUILTIN_MODULES
_BUILTIN_NAMES = set(vars(builtins))

@ -141,9 +142,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
# Check that __init__ method takes no arguments
if not cls.__init__.__qualname__ == "Tool.__init__":
sig = inspect.signature(cls.__init__)
non_self_params = list(
[arg_name for arg_name in sig.parameters.keys() if arg_name != "self"]
)
non_self_params = list([arg_name for arg_name in sig.parameters.keys() if arg_name != "self"])
if len(non_self_params) > 0:
errors.append(
f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!"

@ -174,9 +173,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
# Check if the assignment is more complex than simple literals
if not all(
isinstance(
val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set)
)
isinstance(val, (ast.Str, ast.Num, ast.Constant, ast.Dict, ast.List, ast.Set))
for val in ast.walk(node.value)
):
for target in node.targets:

@ -195,9 +192,7 @@ def validate_tool_attributes(cls, check_imports: bool = True) -> None:
# Run checks on all methods
for node in class_node.body:
if isinstance(node, ast.FunctionDef):
method_checker = MethodChecker(
class_level_checker.class_attributes, check_imports=check_imports
)
method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports)
method_checker.visit(node)
errors += [f"- {node.name}: {error}" for error in method_checker.errors]
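Note: the validator hunks above keep the same check: a Tool subclass must not add extra `__init__` arguments beyond `self`. A minimal stdlib-only sketch of that check (the class name here is hypothetical):

    import inspect

    class MyTool:
        def __init__(self, endpoint):  # extra arg -> the validator would flag this
            self.endpoint = endpoint

    sig = inspect.signature(MyTool.__init__)
    non_self_params = [name for name in sig.parameters.keys() if name != "self"]
    print(non_self_params)  # ['endpoint']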
@ -36,7 +36,6 @@ from huggingface_hub import (
upload_folder,
)
from huggingface_hub.utils import RepositoryNotFoundError
from packaging import version
from transformers.dynamic_module_utils import get_imports
from transformers.utils import (

@ -52,6 +51,7 @@ from .tool_validation import MethodChecker, validate_tool_attributes
from .types import ImageType, handle_agent_input_types, handle_agent_output_types
from .utils import instance_to_source
logger = logging.getLogger(__name__)
if is_accelerate_available():

@ -77,9 +77,7 @@ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
return "model"
except RepositoryNotFoundError:
raise EnvironmentError(
f"`{repo_id}` does not seem to be a valid repo identifier on the Hub."
)
raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
except Exception:
return "model"
except Exception:

@ -109,9 +107,7 @@ def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
properties[param_name]["nullable"] = True
for param_name in signature.parameters.keys():
if signature.parameters[param_name].default != inspect.Parameter.empty:
if (
param_name not in properties
): # this can happen if the param has no type hint but a default value
if param_name not in properties: # this can happen if the param has no type hint but a default value
properties[param_name] = {"nullable": True}
return properties

@ -181,9 +177,7 @@ class Tool:
f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
)
for input_name, input_content in self.inputs.items():
assert isinstance(input_content, dict), (
f"Input '{input_name}' should be a dictionary."
)
assert isinstance(input_content, dict), f"Input '{input_name}' should be a dictionary."
assert "type" in input_content and "description" in input_content, (
f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
)

@ -348,15 +342,7 @@ class Tool:
imports = []
for module in [tool_file]:
imports.extend(get_imports(module))
imports = list(
set(
[
el
for el in imports + ["smolagents"]
if el not in sys.stdlib_module_names
]
)
)
imports = list(set([el for el in imports + ["smolagents"] if el not in sys.stdlib_module_names]))
with open(requirements_file, "w", encoding="utf-8") as f:
f.write("\n".join(imports) + "\n")

@ -410,9 +396,7 @@ class Tool:
print(work_dir)
with open(work_dir + "/tool.py", "r") as f:
print("\n".join(f.readlines()))
logger.info(
f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
)
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
return upload_folder(
repo_id=repo_id,
commit_message=commit_message,

@ -592,9 +576,7 @@ class Tool:
self.name = name
self.description = description
self.client = Client(space_id, hf_token=token)
space_description = self.client.view_api(
return_format="dict", print_info=False
)["named_endpoints"]
space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"]
# If api_name is not defined, take the first of the available APIs for this space
if api_name is None:

@ -607,9 +589,7 @@ class Tool:
try:
space_description_api = space_description[api_name]
except KeyError:
raise KeyError(
f"Could not find specified {api_name=} among available api names."
)
raise KeyError(f"Could not find specified {api_name=} among available api names.")
self.inputs = {}
for parameter in space_description_api["parameters"]:

@ -683,8 +663,7 @@ class Tool:
self._gradio_tool = _gradio_tool
func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
self.inputs = {
key: {"type": CONVERSION_DICT[value.annotation], "description": ""}
for key, value in func_args
key: {"type": CONVERSION_DICT[value.annotation], "description": ""} for key, value in func_args
}
self.forward = self._gradio_tool.run

@ -726,9 +705,7 @@ DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
"""
def get_tool_description_with_args(
tool: Tool, description_template: Optional[str] = None
) -> str:
def get_tool_description_with_args(tool: Tool, description_template: Optional[str] = None) -> str:
if description_template is None:
description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
compiled_template = compile_jinja_template(description_template)
@ -748,10 +725,7 @@ def compile_jinja_template(template):
raise ImportError("template requires jinja2 to be installed.")
if version.parse(jinja2.__version__) < version.parse("3.1.0"):
raise ImportError(
"template requires jinja2>=3.1.0 to be installed. Your version is "
f"{jinja2.__version__}."
)
raise ImportError(f"template requires jinja2>=3.1.0 to be installed. Your version is {jinja2.__version__}.")
def raise_exception(message):
raise TemplateError(message)

@ -772,9 +746,7 @@ def launch_gradio_demo(tool: Tool):
try:
import gradio as gr
except ImportError:
raise ImportError(
"Gradio should be installed in order to launch a gradio demo."
)
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
TYPE_TO_COMPONENT_CLASS_MAPPING = {
"image": gr.Image,

@ -791,9 +763,7 @@ def launch_gradio_demo(tool: Tool):
gradio_inputs = []
for input_name, input_details in tool.inputs.items():
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
input_details["type"]
]
input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[input_details["type"]]
new_component = input_gradio_component_class(label=input_name)
gradio_inputs.append(new_component)

@ -922,14 +892,9 @@ class ToolCollection:
```
"""
_collection = get_collection(collection_slug, token=token)
_hub_repo_ids = {
item.item_id for item in _collection.items if item.item_type == "space"
}
_hub_repo_ids = {item.item_id for item in _collection.items if item.item_type == "space"}
tools = {
Tool.from_hub(repo_id, token, trust_remote_code)
for repo_id in _hub_repo_ids
}
tools = {Tool.from_hub(repo_id, token, trust_remote_code) for repo_id in _hub_repo_ids}
return cls(tools)

@ -986,9 +951,7 @@ def tool(tool_function: Callable) -> Tool:
"""
parameters = get_json_schema(tool_function)["function"]
if "return" not in parameters:
raise TypeHintParsingException(
"Tool return type not found: make sure your function has a return type hint!"
)
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
class SimpleTool(Tool):
def __init__(self, name, description, inputs, output_type, function):

@ -1007,9 +970,9 @@ def tool(tool_function: Callable) -> Tool:
function=tool_function,
)
original_signature = inspect.signature(tool_function)
new_parameters = [
inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)
] + list(original_signature.parameters.values())
new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)] + list(
original_signature.parameters.values()
)
new_signature = original_signature.replace(parameters=new_parameters)
simple_tool.forward.__signature__ = new_signature
return simple_tool

@ -1082,9 +1045,7 @@ class PipelineTool(Tool):
if model is None:
if self.default_checkpoint is None:
raise ValueError(
"This tool does not implement a default checkpoint, you need to pass one."
)
raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
model = self.default_checkpoint
if pre_processor is None:
pre_processor = model

@ -1107,21 +1068,15 @@ class PipelineTool(Tool):
Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
"""
if isinstance(self.pre_processor, str):
self.pre_processor = self.pre_processor_class.from_pretrained(
self.pre_processor, **self.hub_kwargs
)
self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
if isinstance(self.model, str):
self.model = self.model_class.from_pretrained(
self.model, **self.model_kwargs, **self.hub_kwargs
)
self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
if self.post_processor is None:
self.post_processor = self.pre_processor
elif isinstance(self.post_processor, str):
self.post_processor = self.post_processor_class.from_pretrained(
self.post_processor, **self.hub_kwargs
)
self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
if self.device is None:
if self.device_map is not None:

@ -1165,12 +1120,8 @@ class PipelineTool(Tool):
encoded_inputs = self.encode(*args, **kwargs)
tensor_inputs = {
k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)
}
non_tensor_inputs = {
k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)
}
tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)}
non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)}
encoded_inputs = send_to_device(tensor_inputs, self.device)
outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
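Note: the requirements-collection line above (now a single line) drops stdlib modules and always adds smolagents itself. A stdlib-only illustration with a made-up import list (requires Python 3.10+ for `sys.stdlib_module_names`):

    import sys

    imports = ["json", "requests", "os"]
    imports = list(set([el for el in imports + ["smolagents"] if el not in sys.stdlib_module_names]))
    print(sorted(imports))  # ['requests', 'smolagents']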
@ -27,6 +27,7 @@ from transformers.utils import (
is_vision_available,
)
logger = logging.getLogger(__name__)
if is_vision_available():

@ -113,9 +114,7 @@ class AgentImage(AgentType, ImageType):
elif isinstance(value, np.ndarray):
self._tensor = torch.from_numpy(value)
else:
raise TypeError(
f"Unsupported type for {self.__class__.__name__}: {type(value)}"
)
raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
def _ipython_display_(self, include=None, exclude=None):
"""

@ -264,9 +263,7 @@ if is_torch_available():
def handle_agent_input_types(*args, **kwargs):
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
kwargs = {
k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()
}
kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
return args, kwargs

@ -279,9 +276,7 @@ def handle_agent_output_types(output, output_type=None):
# If the class does not have defined output, then we map according to the type
for _k, _v in INSTANCE_TYPE_MAPPING.items():
if isinstance(output, _k):
if (
_k is not object
): # avoid converting to audio if torch is not installed
if _k is not object: # avoid converting to audio if torch is not installed
return _v(output)
return output
@ -83,9 +83,7 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
try:
first_accolade_index = json_blob.find("{")
last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace(
'\\"', "'"
)
json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'")
json_data = json.loads(json_blob, strict=False)
return json_data
except json.JSONDecodeError as e:

@ -162,9 +160,7 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
MAX_LENGTH_TRUNCATE_CONTENT = 20000
def truncate_content(
content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT
) -> str:
def truncate_content(content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT) -> str:
if len(content) <= max_length:
return content
else:

@ -206,12 +202,8 @@ def is_same_method(method1, method2):
source2 = get_method_source(method2)
# Remove method decorators if any
source1 = "\n".join(
line for line in source1.split("\n") if not line.strip().startswith("@")
)
source2 = "\n".join(
line for line in source2.split("\n") if not line.strip().startswith("@")
)
source1 = "\n".join(line for line in source1.split("\n") if not line.strip().startswith("@"))
source2 = "\n".join(line for line in source2.split("\n") if not line.strip().startswith("@"))
return source1 == source2
except (TypeError, OSError):

@ -248,9 +240,7 @@ def instance_to_source(instance, base_cls=None):
for name, value in cls.__dict__.items()
if not name.startswith("__")
and not callable(value)
and not (
base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value
)
and not (base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value)
}
for name, value in class_attrs.items():

@ -271,9 +261,7 @@ def instance_to_source(instance, base_cls=None):
for name, func in cls.__dict__.items()
if callable(func)
and not (
base_cls
and hasattr(base_cls, name)
and getattr(base_cls, name).__code__.co_code == func.__code__.co_code
base_cls and hasattr(base_cls, name) and getattr(base_cls, name).__code__.co_code == func.__code__.co_code
)
}

@ -284,9 +272,7 @@ def instance_to_source(instance, base_cls=None):
first_line = method_lines[0]
indent = len(first_line) - len(first_line.lstrip())
method_lines = [line[indent:] for line in method_lines]
method_source = "\n".join(
[" " + line if line.strip() else line for line in method_lines]
)
method_source = "\n".join([" " + line if line.strip() else line for line in method_lines])
class_lines.append(method_source)
class_lines.append("")
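Note: `truncate_content` keeps its default `MAX_LENGTH_TRUNCATE_CONTENT` cap after the reformat; only the signature was collapsed onto one line. A hedged sketch, assuming it is importable from `smolagents.utils`:

    # Sketch under the assumption above.
    from smolagents.utils import truncate_content

    print(truncate_content("short text", max_length=50))    # returned unchanged
    print(len(truncate_content("x" * 100_000)) < 100_000)   # True: long content is truncated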
@ -28,13 +28,13 @@ from smolagents.agents import (
ToolCallingAgent,
)
from smolagents.default_tools import PythonInterpreterTool
from smolagents.tools import tool
from smolagents.types import AgentImage, AgentText
from smolagents.models import (
ChatMessage,
ChatMessageToolCall,
ChatMessageToolCallDefinition,
)
from smolagents.tools import tool
from smolagents.types import AgentImage, AgentText
from smolagents.utils import BASE_BUILTIN_MODULES

@ -44,9 +44,7 @@ def get_new_path(suffix="") -> str:
class FakeToolCallModel:
def __call__(
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
):
def __call__(self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None):
if len(messages) < 3:
return ChatMessage(
role="assistant",

@ -69,18 +67,14 @@ class FakeToolCallModel:
ChatMessageToolCall(
id="call_1",
type="function",
function=ChatMessageToolCallDefinition(
name="final_answer", arguments={"answer": "7.2904"}
),
function=ChatMessageToolCallDefinition(name="final_answer", arguments={"answer": "7.2904"}),
)
],
)
class FakeToolCallModelImage:
def __call__(
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
):
def __call__(self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None):
if len(messages) < 3:
return ChatMessage(
role="assistant",

@ -104,9 +98,7 @@ class FakeToolCallModelImage:
ChatMessageToolCall(
id="call_1",
type="function",
function=ChatMessageToolCallDefinition(
name="final_answer", arguments="image.png"
),
function=ChatMessageToolCallDefinition(name="final_answer", arguments="image.png"),
)
],
)

@ -271,17 +263,13 @@ print(result)
class AgentTests(unittest.TestCase):
def test_fake_single_step_code_agent(self):
agent = CodeAgent(
tools=[PythonInterpreterTool()], model=fake_code_model_single_step
)
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_single_step)
output = agent.run("What is 2 multiplied by 3.6452?", single_step=True)
assert isinstance(output, str)
assert "7.2904" in output
def test_fake_toolcalling_agent(self):
agent = ToolCallingAgent(
tools=[PythonInterpreterTool()], model=FakeToolCallModel()
)
agent = ToolCallingAgent(tools=[PythonInterpreterTool()], model=FakeToolCallModel())
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, str)
assert "7.2904" in output

@ -301,9 +289,7 @@ class AgentTests(unittest.TestCase):
"""
return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png")
agent = ToolCallingAgent(
tools=[fake_image_generation_tool], model=FakeToolCallModelImage()
)
agent = ToolCallingAgent(tools=[fake_image_generation_tool], model=FakeToolCallModelImage())
output = agent.run("Make me an image.")
assert isinstance(output, AgentImage)
assert isinstance(agent.state["image.png"], Image.Image)

@ -315,9 +301,7 @@ class AgentTests(unittest.TestCase):
assert output == 7.2904
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
assert agent.logs[3].tool_calls == [
ToolCall(
name="python_interpreter", arguments="final_answer(7.2904)", id="call_3"
)
ToolCall(name="python_interpreter", arguments="final_answer(7.2904)", id="call_3")
]
def test_additional_args_added_to_task(self):

@ -351,9 +335,7 @@ class AgentTests(unittest.TestCase):
assert "Code execution failed at line 'print = 2' because of" in str(agent.logs)
def test_code_agent_syntax_error_show_offending_lines(self):
agent = CodeAgent(
tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error
)
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText)
assert output == "got an error"

@ -391,9 +373,7 @@ class AgentTests(unittest.TestCase):
def test_init_agent_with_different_toolsets(self):
toolset_1 = []
agent = CodeAgent(tools=toolset_1, model=fake_code_model)
assert (
len(agent.tools) == 1
) # when no tools are provided, only the final_answer tool is added by default
assert len(agent.tools) == 1 # when no tools are provided, only the final_answer tool is added by default
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
agent = CodeAgent(tools=toolset_2, model=fake_code_model)

@ -436,9 +416,7 @@ class AgentTests(unittest.TestCase):
assert "You can also give requests to team members." not in agent.system_prompt
print("ok1")
assert "{{managed_agents_descriptions}}" not in agent.system_prompt
assert (
"You can also give requests to team members." in manager_agent.system_prompt
)
assert "You can also give requests to team members." in manager_agent.system_prompt
def test_code_agent_missing_import_triggers_advice_in_error_log(self):
agent = CodeAgent(tools=[], model=fake_code_model_import)
@ -136,9 +136,7 @@ class TestDocs:
try:
code_blocks = [
(
block.replace(
"<YOUR_HUGGINGFACEHUB_API_TOKEN>", os.getenv("HF_TOKEN")
)
block.replace("<YOUR_HUGGINGFACEHUB_API_TOKEN>", os.getenv("HF_TOKEN"))
.replace("YOUR_ANTHROPIC_API_KEY", os.getenv("ANTHROPIC_API_KEY"))
.replace("{your_username}", "m-ric")
)

@ -150,9 +148,7 @@ class TestDocs:
except SubprocessCallException as e:
pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}")
except Exception:
pytest.fail(
f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}"
)
pytest.fail(f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}")
@pytest.fixture(autouse=True)
def _setup(self):

@ -174,6 +170,4 @@ def pytest_generate_tests(metafunc):
test_class.setup_class()
# Parameterize with the markdown files
metafunc.parametrize(
"doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files]
)
metafunc.parametrize("doc_path", test_class.md_files, ids=[f.stem for f in test_class.md_files])
@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import pytest

from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool

@ -23,14 +24,10 @@ from .test_tools import ToolTesterMixin

class DefaultToolTests(unittest.TestCase):
def test_visit_webpage(self):
arguments = {
"url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"
}
arguments = {"url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"}
result = VisitWebpageTool()(arguments)
assert isinstance(result, str)
assert (
"* [About Wikipedia](/wiki/Wikipedia:About)" in result
) # Proper wikipedia pages have an About
assert "* [About Wikipedia](/wiki/Wikipedia:About)" in result # Proper wikipedia pages have an About


class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):

@ -59,12 +56,7 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
input_type = expected_input["type"]
if isinstance(input_type, list):
_inputs.append(
[
AGENT_TYPE_MAPPING[_input_type](_input)
for _input_type in input_type
]
)
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
else:
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))

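The test_visit_webpage test above also doubles as a usage recipe for the default tool; a minimal sketch mirroring the call shape exercised in the test:

from smolagents.default_tools import VisitWebpageTool

# Same argument shape as in the test: a dict with a single "url" key.
arguments = {"url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"}
result = VisitWebpageTool()(arguments)
# The tool returns the page as a string of markdown-style text.
print(result[:200])
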
@ -26,6 +26,7 @@ from smolagents.types import AGENT_TYPE_MAPPING

from .test_tools import ToolTesterMixin


if is_torch_available():
import torch

@ -45,11 +46,7 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):

def create_inputs(self):
inputs_text = {"answer": "Text input"}
inputs_image = {
"answer": Image.open(
Path(get_tests_dir("fixtures")) / "000000039769.png"
).resize((512, 512))
}
inputs_image = {"answer": Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png").resize((512, 512))}
inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio}

@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import json
import unittest
from typing import Optional

from smolagents import models, tool, ChatMessage, HfApiModel, TransformersModel
from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool


class ModelTests(unittest.TestCase):

@ -33,12 +33,7 @@ class ModelTests(unittest.TestCase):
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"

assert (
"nullable"
in models.get_json_schema(get_weather)["function"]["parameters"][
"properties"
]["celsius"]
)
assert "nullable" in models.get_json_schema(get_weather)["function"]["parameters"]["properties"]["celsius"]

def test_chatmessage_has_model_dumps_json(self):
message = ChatMessage("user", "Hello!")

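The collapsed assertion above checks that an Optional parameter is reported as nullable in the generated JSON schema. A minimal sketch of that behaviour in isolation; the function body and docstring here are illustrative, not the ones from the test:

from typing import Optional

from smolagents import models


def get_weather(location: str, celsius: Optional[bool] = False) -> str:
    """
    Gets the weather at a location.

    Args:
        location: The place to look up.
        celsius: Whether to report the temperature in Celsius.
    """
    return "sunny"


schema = models.get_json_schema(get_weather)
# The Optional[bool] parameter is expected to carry a "nullable" marker.
print(schema["function"]["parameters"]["properties"]["celsius"])
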
@ -43,9 +43,7 @@ class FakeLLMModel:
ChatMessageToolCall(
id="fake_id",
type="function",
function=ChatMessageToolCallDefinition(
name="final_answer", arguments={"answer": "image"}
),
function=ChatMessageToolCallDefinition(name="final_answer", arguments={"answer": "image"}),
)
],
)

@ -122,9 +120,7 @@ class MonitoringTester(unittest.TestCase):
)
agent.run("Fake task")

self.assertEqual(
agent.monitor.total_input_token_count, 20
) # Should have done two monitoring callbacks
self.assertEqual(agent.monitor.total_input_token_count, 20) # Should have done two monitoring callbacks
self.assertEqual(agent.monitor.total_output_token_count, 0)

def test_streaming_agent_text_output(self):

@ -55,10 +55,7 @@ class PythonInterpreterTester(unittest.TestCase):
code = "print = '3'"
with pytest.raises(InterpreterError) as e:
evaluate_python_code(code, {"print": print}, state={})
assert (
"Cannot assign to name 'print': doing this would erase the existing tool!"
in str(e)
)
assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e)

def test_subscript_call(self):
code = """def foo(x,y):return x*y\n\ndef boo(y):\n\treturn y**3\nfun = [foo, boo]\nresult_foo = fun[0](4,2)\nresult_boo = fun[1](4)"""

@ -92,9 +89,7 @@ class PythonInterpreterTester(unittest.TestCase):
state = {"x": 3}
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
self.assertDictEqual(result, {"x": 3, "y": 5})
self.assertDictEqual(
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
)
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})

def test_evaluate_expression(self):
code = "x = 3\ny = 5"

@ -110,9 +105,7 @@ class PythonInterpreterTester(unittest.TestCase):
result, _ = evaluate_python_code(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == "This is x: 3."
self.assertDictEqual(
state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""}
)
self.assertDictEqual(state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""})

def test_evaluate_if(self):
code = "if x <= 3:\n y = 2\nelse:\n y = 5"

@ -153,15 +146,11 @@ class PythonInterpreterTester(unittest.TestCase):
state = {"x": 3}
result, _ = evaluate_python_code(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqual(
state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}
)
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})

code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
state = {}
evaluate_python_code(
code, {"min": min, "print": print, "round": round}, state=state
)
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}

def test_subscript_string_with_string_index_raises_appropriate_error(self):

@ -317,9 +306,7 @@ print(check_digits)
assert result == {0: 0, 1: 1, 2: 4}

code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
result, _ = evaluate_python_code(
code, {"print": print}, state={}, authorized_imports=["pandas"]
)
result, _ = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
assert result == {102: "b"}

code = """

@ -367,9 +354,7 @@ else:
best_city = "Manhattan"
best_city
"""
result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
)
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
assert result == "Brooklyn"

code = """if d > e and a < b:

@ -380,9 +365,7 @@ else:
best_city = "Manhattan"
best_city
"""
result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
)
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
assert result == "Sacramento"

def test_if_conditions(self):

@ -398,9 +381,7 @@ if char.isalpha():
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == 2.0

code = (
"from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
)
code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "lose"

@ -434,14 +415,10 @@ if char.isalpha():

# Test submodules are handled properly, thus not raising error
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
)
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])

code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
result, _ = evaluate_python_code(
code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"]
)
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])

def test_additional_imports(self):
code = "import numpy as np"

@ -613,9 +590,7 @@ except ValueError as e:
def test_types_as_objects(self):
code = "type_a = float(2); type_b = str; type_c = int"
state = {}
result, is_final_answer = evaluate_python_code(
code, {"float": float, "str": str, "int": int}, state=state
)
result, is_final_answer = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
assert result is int

def test_tuple_id(self):

@ -733,9 +708,7 @@ while True:
break

i"""
result, is_final_answer = evaluate_python_code(
code, {"print": print, "round": round}, state={}
)
result, is_final_answer = evaluate_python_code(code, {"print": print, "round": round}, state={})
assert result == 3
assert not is_final_answer

@ -781,9 +754,7 @@ out = [i for sublist in all_res for i in sublist]
out[:10]
"""
state = {}
result, is_final_answer = evaluate_python_code(
code, {"print": print, "range": range}, state=state
)
result, is_final_answer = evaluate_python_code(code, {"print": print, "range": range}, state=state)
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]

def test_pandas(self):

@ -798,9 +769,7 @@ parts_with_5_set_count = df[df['SetCount'] == 5.0]
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
"""
state = {}
result, _ = evaluate_python_code(
code, {}, state=state, authorized_imports=["pandas"]
)
result, _ = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"])
assert np.array_equal(result, [-1, 5])

code = """

@ -811,9 +780,7 @@ df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
# Filter the DataFrame to get only the rows with outdated atomic numbers
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
"""
result, _ = evaluate_python_code(
code, {"print": print}, state={}, authorized_imports=["pandas"]
)
result, _ = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
assert np.array_equal(result.values[0], [104, 1])

# Test groupby

@ -825,9 +792,7 @@ data = pd.DataFrame.from_dict([
])
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
"""
result, _ = evaluate_python_code(
code, {}, state={}, authorized_imports=["pandas"]
)
result, _ = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
assert result.values[1] == 0.5

# Test loc and iloc

@ -839,11 +804,9 @@ data = pd.DataFrame.from_dict([
])
survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
survival_rate_sorted = data.sort_values(by='Survived', ascending=False).iloc[0]
survival_rate_sorted = data.sort_values(by='Survived', ascending=False).iloc[0]
"""
result, _ = evaluate_python_code(
code, {}, state={}, authorized_imports=["pandas"]
)
result, _ = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])

def test_starred(self):
code = """

@ -864,9 +827,7 @@ coords_barcelona = (41.3869, 2.1660)

distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
"""
result, _ = evaluate_python_code(
code, {"print": print, "map": map}, state={}, authorized_imports=["math"]
)
result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
assert round(result, 1) == 622395.4

def test_for(self):

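The interpreter tests above all drive the same entry point. A minimal sketch of that call shape; the import path is assumed here, since the diff only shows the call sites:

# Import path assumed; the tests above only show the call sites.
from smolagents.local_python_executor import evaluate_python_code

code = "import numpy as np\nnp.arange(3).sum()"
state = {}
# Returns the value of the last expression, plus a flag indicating whether
# final_answer() was called inside the snippet.
result, is_final_answer = evaluate_python_code(code, {}, state=state, authorized_imports=["numpy"])
print(result, is_final_answer)  # expected: 3 False
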
@ -16,7 +16,7 @@ import unittest
from pathlib import Path
from textwrap import dedent
from typing import Dict, Optional, Union
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch

import mcp
import numpy as np

@ -32,6 +32,7 @@ from smolagents.types import (
AgentText,
)


if is_torch_available():
import torch

@ -48,9 +49,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
if input_type == "string":
inputs[input_name] = "Text input"
elif input_type == "image":
inputs[input_name] = Image.open(
Path(get_tests_dir("fixtures")) / "000000039769.png"
).resize((512, 512))
inputs[input_name] = Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png").resize((512, 512))
elif input_type == "audio":
inputs[input_name] = np.ones(3000)
else:

@ -224,9 +223,7 @@ class ToolTests(unittest.TestCase):
class FailTool(Tool):
name = "specific"
description = "test description"
inputs = {
"string_input": {"type": "string", "description": "input description"}
}
inputs = {"string_input": {"type": "string", "description": "input description"}}
output_type = "string"

def __init__(self, url):

@ -248,9 +245,7 @@ class ToolTests(unittest.TestCase):
class FailTool(Tool):
name = "specific"
description = "test description"
inputs = {
"string_input": {"type": "string", "description": "input description"}
}
inputs = {"string_input": {"type": "string", "description": "input description"}}
output_type = "string"

def useless_method(self):

@ -269,9 +264,7 @@ class ToolTests(unittest.TestCase):
class SuccessTool(Tool):
name = "specific"
description = "test description"
inputs = {
"string_input": {"type": "string", "description": "input description"}
}
inputs = {"string_input": {"type": "string", "description": "input description"}}
output_type = "string"

def useless_method(self):

@ -300,9 +293,7 @@ class ToolTests(unittest.TestCase):
},
}

def forward(
self, location: str, celsius: Optional[bool] = False
) -> str:
def forward(self, location: str, celsius: Optional[bool] = False) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"

GetWeatherTool()

@ -340,9 +331,7 @@ class ToolTests(unittest.TestCase):
}
output_type = "string"

def forward(
self, location: str, celsius: Optional[bool] = False
) -> str:
def forward(self, location: str, celsius: Optional[bool] = False) -> str:
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"

GetWeatherTool()

@ -410,9 +399,7 @@ def mock_smolagents_adapter():


class TestToolCollection:
def test_from_mcp(
self, mock_server_parameters, mock_mcp_adapt, mock_smolagents_adapter
):
def test_from_mcp(self, mock_server_parameters, mock_mcp_adapt, mock_smolagents_adapter):
with ToolCollection.from_mcp(mock_server_parameters) as tool_collection:
assert isinstance(tool_collection, ToolCollection)
assert len(tool_collection.tools) == 2

@ -440,9 +427,5 @@ class TestToolCollection:

with ToolCollection.from_mcp(mcp_server_params) as tool_collection:
assert len(tool_collection.tools) == 1, "Expected 1 tool"
assert tool_collection.tools[0].name == "echo_tool", (
"Expected tool name to be 'echo_tool'"
)
assert tool_collection.tools[0](text="Hello") == "Hello", (
"Expected tool to echo the input text"
)
assert tool_collection.tools[0].name == "echo_tool", "Expected tool name to be 'echo_tool'"
assert tool_collection.tools[0](text="Hello") == "Hello", "Expected tool to echo the input text"

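The tool tests above repeat the same subclass pattern; a minimal sketch of a well-formed tool in the collapsed style, assuming Tool is exposed at the top level of the smolagents package:

from smolagents import Tool


class GetWeatherTool(Tool):
    name = "get_weather"
    description = "Gets the current weather at a location."
    inputs = {"location": {"type": "string", "description": "the location to check"}}
    output_type = "string"

    def forward(self, location: str) -> str:
        # Illustrative stub; a real tool would call a weather API here.
        return f"The weather in {location} is fine."


weather_tool = GetWeatherTool()
print(weather_tool(location="Paris"))
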
@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import pytest

from smolagents.utils import parse_code_blobs

@ -16,6 +16,7 @@

from pathlib import Path


ROOT = Path(__file__).parent.parent

TESTS_FOLDER = ROOT / "tests"

@ -37,11 +38,7 @@ def check_tests_in_ci():
if path.name.startswith("test_")
]
ci_workflow_file_content = CI_WORKFLOW_FILE.read_text()
missing_test_files = [
test_file
for test_file in test_files
if test_file not in ci_workflow_file_content
]
missing_test_files = [test_file for test_file in test_files if test_file not in ci_workflow_file_content]
if missing_test_files:
print(
"❌ Some test files seem to be ignored in the CI:\n"