Make default tools more robust (#186)

Aymeric Roucher 2025-01-14 14:57:11 +01:00 committed by GitHub
parent 12a2e6f4b4
commit 5f32373551
10 changed files with 296 additions and 356 deletions

View File

@@ -36,6 +36,9 @@ jobs:
- name: Agent tests
run: |
uv run pytest -sv ./tests/test_agents.py
- name: Default tools tests
run: |
uv run pytest -sv ./tests/test_default_tools.py
- name: Final answer tests
run: |
uv run pytest -sv ./tests/test_final_answer.py
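
The new step can be reproduced outside CI; a hedged local equivalent of the command above (assuming uv is installed, as in this workflow):

# Hedged local equivalent of the new CI step; the command mirrors the workflow.
import subprocess

subprocess.run(
    ["uv", "run", "pytest", "-sv", "./tests/test_default_tools.py"],
    check=True,  # fail loudly, as the CI job would
)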

View File

@@ -21,7 +21,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -29,8 +29,7 @@
"output_type": "stream",
"text": [
"/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Generating train split: 100%|██████████| 132/132 [00:00<00:00, 17393.36 examples/s]\n"
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
@@ -173,7 +172,7 @@
"[132 rows x 4 columns]"
]
},
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -196,19 +195,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/aymeric/venv/test/lib/python3.12/site-packages/pydantic/_internal/_config.py:345: UserWarning: Valid config keys have changed in V2:\n",
"* 'fields' has been removed\n",
" warnings.warn(message, UserWarning)\n"
]
}
],
"outputs": [],
"source": [
"import time\n",
"import json\n",
@@ -408,100 +397,9 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluating 'meta-llama/Llama-3.3-70B-Instruct'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 132/132 [00:00<00:00, 27061.35it/s]\n",
"100%|██████████| 132/132 [00:00<00:00, 34618.15it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluating 'Qwen/Qwen2.5-72B-Instruct'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 132/132 [00:00<00:00, 33008.29it/s]\n",
"100%|██████████| 132/132 [00:00<00:00, 36292.90it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluating 'Qwen/Qwen2.5-Coder-32B-Instruct'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 132/132 [00:00<00:00, 29165.47it/s]\n",
"100%|██████████| 132/132 [00:00<00:00, 30378.50it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluating 'meta-llama/Llama-3.2-3B-Instruct'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 132/132 [00:00<00:00, 33453.06it/s]\n",
"100%|██████████| 132/132 [00:00<00:00, 34763.79it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluating 'meta-llama/Llama-3.1-8B-Instruct'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 132/132 [00:00<00:00, 35246.25it/s]\n",
"100%|██████████| 132/132 [00:00<00:00, 28551.81it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluating 'mistralai/Mistral-Nemo-Instruct-2407'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 132/132 [00:00<00:00, 32441.59it/s]\n",
"100%|██████████| 132/132 [00:00<00:00, 35542.67it/s]\n"
]
}
],
"outputs": [],
"source": [
"open_model_ids = [\n",
" \"meta-llama/Llama-3.3-70B-Instruct\",\n",
@@ -554,42 +452,9 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluating 'gpt-4o'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 132/132 [00:00<00:00, 36136.55it/s]\n",
"100%|██████████| 132/132 [00:00<00:00, 33451.04it/s]\n",
"100%|██████████| 132/132 [00:00<00:00, 39146.44it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluating 'anthropic/claude-3-5-sonnet-latest'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 132/132 [00:00<00:00, 31512.79it/s]\n",
"100%|██████████| 132/132 [00:00<00:00, 33576.82it/s]\n",
"100%|██████████| 132/132 [00:00<00:00, 36075.33it/s]\n"
]
}
],
"outputs": [],
"source": [
"from smolagents import LiteLLMModel\n",
"\n",
@@ -614,7 +479,7 @@
" agent = CodeAgent(\n",
" tools=[GoogleSearchTool(), VisitWebpageTool()],\n",
" model=LiteLLMModel(model_id),\n",
" additional_authorized_imports=[\"numpy\"],\n",
" additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
" max_steps=10,\n",
" )\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
@@ -631,34 +496,39 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# import glob\n",
"# import json\n",
"\n",
"# jsonl_files = glob.glob(f\"output/*.jsonl\")\n",
"\n",
"# for file_path in jsonl_files:\n",
"# print(file_path)\n",
"# # Read all lines and filter out SimpleQA sources\n",
"# filtered_lines = []\n",
"# removed = 0\n",
"# with open(file_path, 'r', encoding='utf-8') as f:\n",
"# for line in f:\n",
"# try:\n",
"# data = json.loads(line.strip())\n",
"# if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n",
"# removed +=1\n",
"# else:\n",
"# filtered_lines.append(line)\n",
"# except json.JSONDecodeError:\n",
"# print(\"Invalid line:\", line)\n",
"# continue # Skip invalid JSON lines\n",
"# print(f\"Removed {removed} lines.\")\n",
"# # Write filtered content back to the same file\n",
"# with open(file_path, 'w', encoding='utf-8') as f:\n",
"# f.writelines(filtered_lines)"
"# if \"-Nemo-\" in file_path and \"-vanilla-\" in file_path:\n",
"# print(file_path)\n",
"# # Read all lines and filter out SimpleQA sources\n",
"# filtered_lines = []\n",
"# removed = 0\n",
"# with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
"# for line in f:\n",
"# try:\n",
"# data = json.loads(line.strip())\n",
"# data[\"answer\"] = data[\"answer\"][\"content\"]\n",
"# # if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n",
"# # removed +=1\n",
"# # else:\n",
"# filtered_lines.append(json.dumps(data) + \"\\n\")\n",
"# except json.JSONDecodeError:\n",
"# print(\"Invalid line:\", line)\n",
"# continue # Skip invalid JSON lines\n",
"# print(f\"Removed {removed} lines.\")\n",
"# # Write filtered content back to the same file\n",
"# with open(\n",
"# str(file_path).replace(\"-vanilla-\", \"-vanilla2-\"), \"w\", encoding=\"utf-8\"\n",
"# ) as f:\n",
"# f.writelines(filtered_lines)"
]
},
{
@@ -670,14 +540,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_6037/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n",
"/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_27085/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\n"
]
}
@@ -731,7 +601,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -752,7 +622,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 11,
"metadata": {},
"outputs": [
{
@@ -794,28 +664,28 @@
" <th>1</th>\n",
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
" <td>MATH</td>\n",
" <td>74.0</td>\n",
" <td>76.0</td>\n",
" <td>30.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
" <td>SimpleQA</td>\n",
" <td>70.0</td>\n",
" <td>88.0</td>\n",
" <td>10.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
" <td>GAIA</td>\n",
" <td>18.8</td>\n",
" <td>25.0</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
" <td>MATH</td>\n",
" <td>76.0</td>\n",
" <td>86.0</td>\n",
" <td>60.0</td>\n",
" </tr>\n",
" <tr>\n",
@@ -829,63 +699,63 @@
" <th>6</th>\n",
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
" <td>GAIA</td>\n",
" <td>40.6</td>\n",
" <td>NaN</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
" <td>MATH</td>\n",
" <td>67.0</td>\n",
" <td>NaN</td>\n",
" <td>50.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
" <td>SimpleQA</td>\n",
" <td>90.0</td>\n",
" <td>NaN</td>\n",
" <td>34.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>gpt-4o</td>\n",
" <td>GAIA</td>\n",
" <td>28.1</td>\n",
" <td>25.6</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>gpt-4o</td>\n",
" <td>MATH</td>\n",
" <td>70.0</td>\n",
" <td>58.0</td>\n",
" <td>40.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>gpt-4o</td>\n",
" <td>SimpleQA</td>\n",
" <td>88.0</td>\n",
" <td>86.0</td>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
" <td>GAIA</td>\n",
" <td>0.0</td>\n",
" <td>3.1</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
" <td>MATH</td>\n",
" <td>42.0</td>\n",
" <td>14.0</td>\n",
" <td>18.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
" <td>SimpleQA</td>\n",
" <td>54.0</td>\n",
" <td>2.0</td>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
@@ -899,49 +769,49 @@
" <th>16</th>\n",
" <td>meta-llama/Llama-3.2-3B-Instruct</td>\n",
" <td>MATH</td>\n",
" <td>32.0</td>\n",
" <td>40.0</td>\n",
" <td>12.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>meta-llama/Llama-3.2-3B-Instruct</td>\n",
" <td>SimpleQA</td>\n",
" <td>4.0</td>\n",
" <td>20.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
" <td>GAIA</td>\n",
" <td>34.4</td>\n",
" <td>31.2</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
" <td>MATH</td>\n",
" <td>82.0</td>\n",
" <td>72.0</td>\n",
" <td>40.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
" <td>SimpleQA</td>\n",
" <td>84.0</td>\n",
" <td>78.0</td>\n",
" <td>12.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n",
" <td>GAIA</td>\n",
" <td>3.1</td>\n",
" <td>0.0</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n",
" <td>MATH</td>\n",
" <td>20.0</td>\n",
" <td>30.0</td>\n",
" <td>22.0</td>\n",
" </tr>\n",
" <tr>\n",
@@ -949,7 +819,7 @@
" <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n",
" <td>SimpleQA</td>\n",
" <td>30.0</td>\n",
" <td>0.0</td>\n",
" <td>6.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
@@ -958,29 +828,29 @@
"text/plain": [
"action_type model_id source code vanilla\n",
"0 Qwen/Qwen2.5-72B-Instruct GAIA 28.1 6.2\n",
"1 Qwen/Qwen2.5-72B-Instruct MATH 74.0 30.0\n",
"2 Qwen/Qwen2.5-72B-Instruct SimpleQA 70.0 10.0\n",
"3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 18.8 3.1\n",
"4 Qwen/Qwen2.5-Coder-32B-Instruct MATH 76.0 60.0\n",
"1 Qwen/Qwen2.5-72B-Instruct MATH 76.0 30.0\n",
"2 Qwen/Qwen2.5-72B-Instruct SimpleQA 88.0 10.0\n",
"3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 25.0 3.1\n",
"4 Qwen/Qwen2.5-Coder-32B-Instruct MATH 86.0 60.0\n",
"5 Qwen/Qwen2.5-Coder-32B-Instruct SimpleQA 86.0 8.0\n",
"6 anthropic/claude-3-5-sonnet-latest GAIA 40.6 3.1\n",
"7 anthropic/claude-3-5-sonnet-latest MATH 67.0 50.0\n",
"8 anthropic/claude-3-5-sonnet-latest SimpleQA 90.0 34.0\n",
"9 gpt-4o GAIA 28.1 3.1\n",
"10 gpt-4o MATH 70.0 40.0\n",
"11 gpt-4o SimpleQA 88.0 6.0\n",
"12 meta-llama/Llama-3.1-8B-Instruct GAIA 0.0 0.0\n",
"13 meta-llama/Llama-3.1-8B-Instruct MATH 42.0 18.0\n",
"14 meta-llama/Llama-3.1-8B-Instruct SimpleQA 54.0 6.0\n",
"6 anthropic/claude-3-5-sonnet-latest GAIA NaN 3.1\n",
"7 anthropic/claude-3-5-sonnet-latest MATH NaN 50.0\n",
"8 anthropic/claude-3-5-sonnet-latest SimpleQA NaN 34.0\n",
"9 gpt-4o GAIA 25.6 3.1\n",
"10 gpt-4o MATH 58.0 40.0\n",
"11 gpt-4o SimpleQA 86.0 6.0\n",
"12 meta-llama/Llama-3.1-8B-Instruct GAIA 3.1 0.0\n",
"13 meta-llama/Llama-3.1-8B-Instruct MATH 14.0 18.0\n",
"14 meta-llama/Llama-3.1-8B-Instruct SimpleQA 2.0 6.0\n",
"15 meta-llama/Llama-3.2-3B-Instruct GAIA 3.1 0.0\n",
"16 meta-llama/Llama-3.2-3B-Instruct MATH 32.0 12.0\n",
"17 meta-llama/Llama-3.2-3B-Instruct SimpleQA 4.0 0.0\n",
"18 meta-llama/Llama-3.3-70B-Instruct GAIA 34.4 3.1\n",
"19 meta-llama/Llama-3.3-70B-Instruct MATH 82.0 40.0\n",
"20 meta-llama/Llama-3.3-70B-Instruct SimpleQA 84.0 12.0\n",
"21 mistralai/Mistral-Nemo-Instruct-2407 GAIA 3.1 0.0\n",
"22 mistralai/Mistral-Nemo-Instruct-2407 MATH 20.0 22.0\n",
"23 mistralai/Mistral-Nemo-Instruct-2407 SimpleQA 30.0 0.0"
"16 meta-llama/Llama-3.2-3B-Instruct MATH 40.0 12.0\n",
"17 meta-llama/Llama-3.2-3B-Instruct SimpleQA 20.0 0.0\n",
"18 meta-llama/Llama-3.3-70B-Instruct GAIA 31.2 3.1\n",
"19 meta-llama/Llama-3.3-70B-Instruct MATH 72.0 40.0\n",
"20 meta-llama/Llama-3.3-70B-Instruct SimpleQA 78.0 12.0\n",
"21 mistralai/Mistral-Nemo-Instruct-2407 GAIA 0.0 3.1\n",
"22 mistralai/Mistral-Nemo-Instruct-2407 MATH 30.0 22.0\n",
"23 mistralai/Mistral-Nemo-Instruct-2407 SimpleQA 30.0 6.0"
]
},
"metadata": {},
@@ -1005,6 +875,15 @@
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mnotebook controller is DISPOSED. \n",
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
]
}
],
"source": [

View File

@@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import json
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -809,26 +808,9 @@ class ToolCallingAgent(MultiStepAgent):
tools_to_call_from=list(self.tools.values()),
stop_sequences=["Observation:"],
)
# Extract tool call from model output
if (
type(model_message.tool_calls) is list
and len(model_message.tool_calls) > 0
):
tool_calls = model_message.tool_calls[0]
tool_arguments = tool_calls.function.arguments
tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
else:
start, end = (
model_message.content.find("{"),
model_message.content.rfind("}") + 1,
)
tool_calls = json.loads(model_message.content[start:end])
tool_arguments = tool_calls["tool_arguments"]
tool_name, tool_call_id = (
tool_calls["tool_name"],
f"call_{len(self.logs)}",
)
tool_call = model_message.tool_calls[0]
tool_name, tool_call_id = tool_call.function.name, tool_call.id
tool_arguments = tool_call.function.arguments
except Exception as e:
raise AgentGenerationError(
@@ -887,7 +869,10 @@ class ToolCallingAgent(MultiStepAgent):
updated_information = f"Stored '{observation_name}' in memory."
else:
updated_information = str(observation).strip()
self.logger.log(f"Observations: {updated_information}", level=LogLevel.INFO)
self.logger.log(
f"Observations: {updated_information.replace('[', '|')}", # escape potential rich-tag-like components
level=LogLevel.INFO,
)
log_entry.observations = updated_information
return None
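
For context on the replace above: the agent logger renders through rich, which parses square-bracketed spans as markup tags. A minimal sketch of the failure mode being avoided (assuming rich as the console backend, which smolagents uses):

# Raw observations such as "[Document 1]" can be misread as rich markup;
# substituting "[" with "|" keeps them rendering literally.
from rich.console import Console

console = Console()
observation = "[Document 1] Retrieved content..."
console.print(observation.replace("[", "|"))  # prints: |Document 1] Retrieved content...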

View File

@@ -31,6 +31,7 @@ from .local_python_executor import (
)
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
from .types import AgentAudio
from .utils import truncate_content
if is_torch_available():
from transformers.models.whisper import (
@@ -112,18 +113,15 @@ class PythonInterpreterTool(Tool):
def forward(self, code: str) -> str:
state = {}
try:
output = str(
self.python_evaluator(
code,
state=state,
static_tools=self.base_python_tools,
authorized_imports=self.authorized_imports,
)[0]  # The second element is the boolean is_final_answer
)
return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
except Exception as e:
return f"Error: {str(e)}"
output = str(
self.python_evaluator(
code,
state=state,
static_tools=self.base_python_tools,
authorized_imports=self.authorized_imports,
)[0]  # The second element is the boolean is_final_answer
)
return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
class FinalAnswerTool(Tool):
@@ -295,7 +293,7 @@ class VisitWebpageTool(Tool):
# Remove multiple line breaks
markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
return markdown_content
return truncate_content(markdown_content)
except RequestException as e:
return f"Error fetching the webpage: {str(e)}"

View File

@@ -14,20 +14,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import json
import logging
import os
import random
from copy import deepcopy
from enum import Enum
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union, Any
from huggingface_hub import (
InferenceClient,
ChatCompletionOutputMessage,
ChatCompletionOutputToolCall,
ChatCompletionOutputFunctionDefinition,
)
from huggingface_hub import InferenceClient
from transformers import (
AutoModelForCausalLM,
@@ -58,6 +54,27 @@ if _is_package_available("litellm"):
import litellm
@dataclass
class ChatMessageToolCallDefinition:
arguments: Any
name: str
description: Optional[str] = None
@dataclass
class ChatMessageToolCall:
function: ChatMessageToolCallDefinition
id: str
type: str
@dataclass
class ChatMessage:
role: str
content: Optional[str] = None
tool_calls: Optional[List[ChatMessageToolCall]] = None
class MessageRole(str, Enum):
USER = "user"
ASSISTANT = "assistant"
@@ -140,6 +157,17 @@ def get_clean_message_list(
return final_message_list
def parse_dictionary(possible_dictionary: str) -> Union[Dict, str]:
try:
start, end = (
possible_dictionary.find("{"),
possible_dictionary.rfind("}") + 1,
)
return json.loads(possible_dictionary[start:end])
except Exception:
return possible_dictionary
class Model:
def __init__(self):
self.last_input_token_count = None
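
parse_dictionary gives models a shared fallback for pulling a JSON object out of free-form text. A hedged usage sketch of the module-level helper:

# Extracts the outermost {...} span if it parses as JSON; otherwise the
# input string is returned unchanged.
from smolagents.models import parse_dictionary

raw = 'Action:\n{"tool_name": "final_answer", "tool_arguments": "done"}'
print(parse_dictionary(raw))  # {'tool_name': 'final_answer', 'tool_arguments': 'done'}
print(parse_dictionary("no braces here"))  # 'no braces here'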
@@ -157,7 +185,7 @@ class Model:
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
) -> ChatCompletionOutputMessage:
) -> ChatMessage:
"""Process the input messages and return the model's response.
Parameters:
@@ -228,7 +256,7 @@ class HfApiModel(Model):
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatCompletionOutputMessage:
) -> ChatMessage:
"""
Gets an LLM output message for the given list of input messages.
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
@@ -329,7 +357,7 @@ class TransformersModel(Model):
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatCompletionOutputMessage:
) -> ChatMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
@@ -365,21 +393,21 @@
if stop_sequences is not None:
output = remove_stop_sequences(output, stop_sequences)
if tools_to_call_from is None:
return ChatCompletionOutputMessage(role="assistant", content=output)
return ChatMessage(role="assistant", content=output)
else:
if "Action:" in output:
output = output.split("Action:", 1)[1].strip()
parsed_output = json.loads(output)
tool_name = parsed_output.get("tool_name")
tool_arguments = parsed_output.get("tool_arguments")
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
ChatMessageToolCall(
id="".join(random.choices("0123456789", k=5)),
type="function",
function=ChatCompletionOutputFunctionDefinition(
function=ChatMessageToolCallDefinition(
name=tool_name, arguments=tool_arguments
),
)
@@ -414,7 +442,7 @@ class LiteLLMModel(Model):
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatCompletionOutputMessage:
) -> ChatMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
@@ -485,7 +513,7 @@ class OpenAIServerModel(Model):
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatCompletionOutputMessage:
) -> ChatMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)

View File

@@ -221,6 +221,16 @@ class Tool:
def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs):
if not self.is_initialized:
self.setup()
# Handle the case where arguments might be passed as a single dictionary
if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict):
potential_kwargs = args[0]
# If the dictionary keys match our input parameters, convert it to kwargs
if all(key in self.inputs for key in potential_kwargs):
args = ()
kwargs = potential_kwargs
if sanitize_inputs_outputs:
args, kwargs = handle_agent_input_types(*args, **kwargs)
outputs = self.forward(*args, **kwargs)
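
This is what lets the new default-tools test below call a tool with a single argument dict; the two invocations are equivalent:

# Equivalent under the new dict handling: the dict is unpacked into kwargs
# only when its keys all match the tool's declared inputs.
from smolagents.default_tools import VisitWebpageTool

tool = VisitWebpageTool()
page_a = tool({"url": "https://en.wikipedia.org/wiki/Python_(programming_language)"})
page_b = tool(url="https://en.wikipedia.org/wiki/Python_(programming_language)")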

View File

@@ -30,10 +30,10 @@ from smolagents.agents import (
from smolagents.default_tools import PythonInterpreterTool
from smolagents.tools import tool
from smolagents.types import AgentImage, AgentText
from huggingface_hub import (
ChatCompletionOutputMessage,
ChatCompletionOutputToolCall,
ChatCompletionOutputFunctionDefinition,
from smolagents.models import (
ChatMessage,
ChatMessageToolCall,
ChatMessageToolCallDefinition,
)
@@ -47,28 +47,28 @@ class FakeToolCallModel:
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
):
if len(messages) < 3:
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
ChatMessageToolCall(
id="call_0",
type="function",
function=ChatCompletionOutputFunctionDefinition(
function=ChatMessageToolCallDefinition(
name="python_interpreter", arguments={"code": "2*3.6452"}
),
)
],
)
else:
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
ChatMessageToolCall(
id="call_1",
type="function",
function=ChatCompletionOutputFunctionDefinition(
function=ChatMessageToolCallDefinition(
name="final_answer", arguments={"answer": "7.2904"}
),
)
@@ -81,14 +81,14 @@ class FakeToolCallModelImage:
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
):
if len(messages) < 3:
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
ChatMessageToolCall(
id="call_0",
type="function",
function=ChatCompletionOutputFunctionDefinition(
function=ChatMessageToolCallDefinition(
name="fake_image_generation_tool",
arguments={"prompt": "An image of a cat"},
),
@@ -96,14 +96,14 @@ class FakeToolCallModelImage:
],
)
else:
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
ChatMessageToolCall(
id="call_1",
type="function",
function=ChatCompletionOutputFunctionDefinition(
function=ChatMessageToolCallDefinition(
name="final_answer", arguments="image.png"
),
)
@@ -114,7 +114,7 @@ class FakeToolCallModelImage:
def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
@@ -125,7 +125,7 @@ result = 2**3.6452
""",
)
else: # We're at step 2
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="""
Thought: I can now answer the initial question
@@ -140,7 +140,7 @@ final_answer(7.2904)
def fake_code_model_error(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
@@ -154,7 +154,7 @@ print("Ok, calculation done!")
""",
)
else: # We're at step 2
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="""
Thought: I can now answer the initial question
@@ -169,7 +169,7 @@ final_answer("got an error")
def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
@@ -183,7 +183,7 @@ print("Ok, calculation done!")
""",
)
else: # We're at step 2
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="""
Thought: I can now answer the initial question
@@ -196,7 +196,7 @@ final_answer("got an error")
def fake_code_model_import(messages, stop_sequences=None) -> str:
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="""
Thought: I can answer the question
@@ -212,7 +212,7 @@ final_answer("got an error")
def fake_code_functiondef(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="""
Thought: Let's define the function. special_marker
@@ -226,7 +226,7 @@ def moving_average(x, w):
""",
)
else: # We're at step 2
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="""
Thought: I can now answer the initial question
@@ -241,7 +241,7 @@ final_answer(res)
def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str:
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
@@ -255,7 +255,7 @@ final_answer(result)
def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str:
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
@@ -454,14 +454,14 @@ class AgentTests(unittest.TestCase):
):
if tools_to_call_from is not None:
if len(messages) < 3:
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
ChatMessageToolCall(
id="call_0",
type="function",
function=ChatCompletionOutputFunctionDefinition(
function=ChatMessageToolCallDefinition(
name="search_agent",
arguments="Who is the current US president?",
),
@@ -470,14 +470,14 @@ class AgentTests(unittest.TestCase):
)
else:
assert "Report on the current US president" in str(messages)
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
ChatMessageToolCall(
id="call_0",
type="function",
function=ChatCompletionOutputFunctionDefinition(
function=ChatMessageToolCallDefinition(
name="final_answer", arguments="Final report."
),
)
@@ -485,7 +485,7 @@ class AgentTests(unittest.TestCase):
)
else:
if len(messages) < 3:
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="""
Thought: Let's call our search agent.
@@ -497,7 +497,7 @@ result = search_agent("Who is the current US president?")
)
else:
assert "Report on the current US president" in str(messages)
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="""
Thought: Let's return the report.
@@ -518,14 +518,14 @@ final_answer("Final report.")
stop_sequences=None,
grammar=None,
):
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
ChatMessageToolCall(
id="call_0",
type="function",
function=ChatCompletionOutputFunctionDefinition(
function=ChatMessageToolCallDefinition(
name="final_answer",
arguments="Report on the current US president",
),
@@ -568,7 +568,7 @@ final_answer("Final report.")
def test_code_nontrivial_final_answer_works(self):
def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="""Code:
```py

View File

@@ -0,0 +1,83 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool
from smolagents.types import AGENT_TYPE_MAPPING
from .test_tools import ToolTesterMixin
class DefaultToolTests(unittest.TestCase):
def test_visit_webpage(self):
arguments = {
"url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"
}
result = VisitWebpageTool()(arguments)
assert isinstance(result, str)
assert (
"* [About Wikipedia](/wiki/Wikipedia:About)" in result
)  # Proper Wikipedia pages have an About link
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = PythonInterpreterTool(authorized_imports=["numpy"])
self.tool.setup()
def test_exact_match_arg(self):
result = self.tool("(2 / 2) * 4")
self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
def test_exact_match_kwarg(self):
result = self.tool(code="(2 / 2) * 4")
self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
def test_agent_type_output(self):
inputs = ["2 * 2"]
output = self.tool(*inputs, sanitize_inputs_outputs=True)
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, output_type))
def test_agent_types_inputs(self):
inputs = ["2 * 2"]
_inputs = []
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
input_type = expected_input["type"]
if isinstance(input_type, list):
_inputs.append(
[
AGENT_TYPE_MAPPING[_input_type](_input)
for _input_type in input_type
]
)
else:
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
# Should not raise an error
output = self.tool(*inputs, sanitize_inputs_outputs=True)
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, output_type))
def test_imports_work(self):
result = self.tool("import numpy as np")
assert "import from numpy is not allowed" not in result.lower()
def test_unauthorized_imports_fail(self):
with pytest.raises(Exception) as e:
self.tool("import sympy as sp")
assert "sympy" in str(e).lower()

View File

@@ -23,9 +23,9 @@ from smolagents import (
stream_to_gradio,
)
from huggingface_hub import (
ChatCompletionOutputMessage,
ChatCompletionOutputToolCall,
ChatCompletionOutputFunctionDefinition,
from smolagents.models import (
ChatMessage,
ChatMessageToolCall,
ChatMessageToolCallDefinition,
)
@@ -36,21 +36,21 @@ class FakeLLMModel:
def __call__(self, prompt, tools_to_call_from=None, **kwargs):
if tools_to_call_from is not None:
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionOutputToolCall(
ChatMessageToolCall(
id="fake_id",
type="function",
function=ChatCompletionOutputFunctionDefinition(
function=ChatMessageToolCallDefinition(
name="final_answer", arguments={"answer": "image"}
),
)
],
)
else:
return ChatCompletionOutputMessage(
return ChatMessage(
role="assistant",
content="""
Code:
@@ -91,9 +91,7 @@ class MonitoringTester(unittest.TestCase):
self.last_output_token_count = 20
def __call__(self, prompt, **kwargs):
return ChatCompletionOutputMessage(
role="assistant", content="Malformed answer"
)
return ChatMessage(role="assistant", content="Malformed answer")
agent = CodeAgent(
tools=[],

View File

@@ -18,15 +18,12 @@ import unittest
import numpy as np
import pytest
from smolagents.default_tools import BASE_PYTHON_TOOLS, PythonInterpreterTool
from smolagents.default_tools import BASE_PYTHON_TOOLS
from smolagents.local_python_executor import (
InterpreterError,
evaluate_python_code,
fix_final_answer_code,
)
from smolagents.types import AGENT_TYPE_MAPPING
from .test_tools import ToolTesterMixin
# Fake function we will use as tool
@@ -34,47 +31,6 @@ def add_two(x):
return x + 2
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = PythonInterpreterTool(authorized_imports=["sqlite3"])
self.tool.setup()
def test_exact_match_arg(self):
result = self.tool("(2 / 2) * 4")
self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
def test_exact_match_kwarg(self):
result = self.tool(code="(2 / 2) * 4")
self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
def test_agent_type_output(self):
inputs = ["2 * 2"]
output = self.tool(*inputs, sanitize_inputs_outputs=True)
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, output_type))
def test_agent_types_inputs(self):
inputs = ["2 * 2"]
_inputs = []
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
input_type = expected_input["type"]
if isinstance(input_type, list):
_inputs.append(
[
AGENT_TYPE_MAPPING[_input_type](_input)
for _input_type in input_type
]
)
else:
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
# Should not raise an error
output = self.tool(*inputs, sanitize_inputs_outputs=True)
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, output_type))
class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_assign(self):
code = "x = 3"