{ "cells": [ { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import datasets\n", "\n", "eval_ds = datasets.load_dataset(\"m-ric/agents_medium_benchmark_2\")[\"train\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define utilities and tools\n", "To run the SERPAPI tool, you will need to have a [SerpAPI](https://serpapi.com/dashboard) API key: for this you need a paid account." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import time\n", "import json\n", "import os\n", "import re\n", "import string\n", "import warnings\n", "from tqdm import tqdm\n", "from typing import List\n", "\n", "from smolagents import (\n", " GoogleSearchTool,\n", " CodeAgent,\n", " ToolCallingAgent,\n", " HfApiModel,\n", " AgentError,\n", " VisitWebpageTool,\n", " PythonInterpreterTool,\n", ")\n", "from smolagents.agents import ActionStep\n", "from dotenv import load_dotenv\n", "\n", "load_dotenv()\n", "os.makedirs(\"output\", exist_ok=True)\n", "\n", "\n", "def serialize_agent_error(obj):\n", " if isinstance(obj, AgentError):\n", " return {\"error_type\": obj.__class__.__name__, \"message\": obj.message}\n", " else:\n", " return str(obj)\n", "\n", "\n", "def answer_questions(eval_ds, file_name, agent, model_id, action_type):\n", " answered_questions = []\n", " if os.path.exists(file_name):\n", " with open(file_name, \"r\") as f:\n", " for line in f:\n", " answered_questions.append(json.loads(line)[\"question\"])\n", "\n", " for _, example in tqdm(enumerate(eval_ds), total=len(eval_ds)):\n", " try:\n", " question = example[\"question\"]\n", " if example[\"source\"] == \"SimpleQA\":\n", " question += \" Answer with only the final number.\"\n", " if question in answered_questions:\n", " continue\n", " start_time = time.time()\n", " answer = agent.run(question)\n", " end_time = time.time()\n", " for step_log in agent.logs:\n", " if hasattr(step_log, \"memory\"):\n", " step_log.memory = None\n", "\n", " # Remove memory from logs to make them more compact.\n", " for step in agent.logs:\n", " if isinstance(step, ActionStep):\n", " step.agent_memory = None\n", "\n", " annotated_example = {\n", " \"model_id\": model_id,\n", " \"agent_action_type\": action_type,\n", " \"question\": question,\n", " \"answer\": answer,\n", " \"true_answer\": example[\"true_answer\"],\n", " \"source\": example[\"source\"],\n", " \"intermediate_steps\": str(agent.logs),\n", " \"start_time\": start_time,\n", " \"end_time\": end_time,\n", " \"token_counts\": agent.monitor.get_total_token_counts(),\n", " }\n", "\n", " with open(file_name, \"a\") as f:\n", " json.dump(annotated_example, f, default=serialize_agent_error)\n", " f.write(\"\\n\") # add a newline for JSONL format\n", " except Exception as e:\n", " print(\"Failed:\", e)\n", "\n", "\n", "def normalize_number_str(number_str: str) -> float:\n", " # we replace these common units and commas to allow\n", " # conversion to float\n", " for char in [\"$\", \"%\", \",\"]:\n", " number_str = number_str.replace(char, \"\")\n", " try:\n", " return float(number_str)\n", " except ValueError:\n", " return float(\"inf\")\n", "\n", "\n", "def split_string(\n", " s: str,\n", " char_list: list[str] = [\",\", \";\"],\n", ") -> list[str]:\n", " pattern = f\"[{''.join(char_list)}]\"\n", " return re.split(pattern, s)\n", "\n", "\n", "def is_float(element: any) -> bool:\n", " try:\n", " float(element)\n", " return True\n", " except ValueError:\n", " return False\n", "\n", "\n", "def 
def normalize_str(input_str, remove_punct=True) -> str:\n", "    \"\"\"\n", "    Normalize a string by:\n", "    - Removing all white spaces\n", "    - Optionally removing punctuation (if remove_punct is True)\n", "    - Converting to lowercase\n", "    Parameters:\n", "    - input_str: str, the string to normalize\n", "    - remove_punct: bool, whether to remove punctuation (default: True)\n", "    Returns:\n", "    - str, the normalized string\n", "    \"\"\"\n", "    # Remove all white spaces. Required e.g. for seagull vs. sea gull\n", "    no_spaces = re.sub(r\"\\s\", \"\", input_str)\n", "\n", "    # Remove punctuation, if specified.\n", "    if remove_punct:\n", "        translator = str.maketrans(\"\", \"\", string.punctuation)\n", "        return no_spaces.lower().translate(translator)\n", "    else:\n", "        return no_spaces.lower()\n", "\n", "\n", "def extract_numbers(text: str) -> List[str]:\n", "    \"\"\"This pattern matches:\n", "    - Optional negative sign\n", "    - Numbers with optional comma thousand separators\n", "    - Optional decimal points with decimal numbers\n", "    \"\"\"\n", "    pattern = r\"-?(?:\\d{1,3}(?:,\\d{3})+|\\d+)(?:\\.\\d+)?\"\n", "\n", "    return [el.replace(\",\", \"\") for el in re.findall(pattern, text)]\n", "\n", "\n", "def get_question_score_gaia(\n", "    model_answer: str,\n", "    ground_truth: str,\n", ") -> bool:\n", "    if is_float(ground_truth):\n", "        normalized_answer = normalize_number_str(str(model_answer))\n", "        return normalized_answer == float(ground_truth)\n", "\n", "    elif any(char in ground_truth for char in [\",\", \";\"]):  # if gt is a list\n", "        # split on commas/semicolons and compare element-wise\n", "        gt_elems = split_string(ground_truth)\n", "        ma_elems = split_string(model_answer)\n", "\n", "        if len(gt_elems) != len(ma_elems):  # check length is the same\n", "            warnings.warn(\n", "                \"Answer lists have different lengths, returning False.\", UserWarning\n", "            )\n", "            return False\n", "\n", "        comparisons = []\n", "        for ma_elem, gt_elem in zip(\n", "            ma_elems, gt_elems\n", "        ):  # compare each element as float or str\n", "            if is_float(gt_elem):\n", "                normalized_ma_elem = normalize_number_str(ma_elem)\n", "                comparisons.append(normalized_ma_elem == float(gt_elem))\n", "            else:\n", "                # we do not remove punct since comparisons can include punct\n", "                comparisons.append(\n", "                    normalize_str(ma_elem, remove_punct=False)\n", "                    == normalize_str(gt_elem, remove_punct=False)\n", "                )\n", "        return all(comparisons)\n", "\n", "    else:  # if gt is a str\n", "        return normalize_str(model_answer) == normalize_str(ground_truth)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run benchmark\n", "\n", "### Open models" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "open_model_ids = [\n", "    \"meta-llama/Llama-3.3-70B-Instruct\",\n", "    # \"Qwen/QwQ-32B-Preview\",\n", "    \"Qwen/Qwen2.5-72B-Instruct\",\n", "    \"Qwen/Qwen2.5-Coder-32B-Instruct\",\n", "    \"meta-llama/Llama-3.2-3B-Instruct\",\n", "    # \"HuggingFaceTB/SmolLM2-1.7B-Instruct\",\n", "    # \"meta-llama/Llama-3.1-70B-Instruct\",\n", "]\n", "\n", "for model_id in open_model_ids:\n", "    print(f\"Evaluating '{model_id}'...\")\n", "    action_type = \"tool_calling\"\n", "    agent = ToolCallingAgent(\n", "        tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n", "        model=HfApiModel(model_id),\n", "        max_iterations=10,\n", "    )\n", "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", "    answer_questions(eval_ds, file_name, agent, model_id, action_type)\n", "\n", "    action_type = \"code\"\n", "    
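# CodeAgent writes each action as a Python snippet instead of a JSON tool\n", "    # call, so it needs no PythonInterpreterTool; extra imports have to be\n", "    # whitelisted via additional_authorized_imports.\n", "    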
agent = CodeAgent(\n", "        tools=[GoogleSearchTool(), VisitWebpageTool()],\n", "        model=HfApiModel(model_id),\n", "        additional_authorized_imports=[\"numpy\"],\n", "        max_iterations=10,\n", "    )\n", "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", "    answer_questions(eval_ds, file_name, agent, model_id, action_type)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Closed models" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from smolagents import LiteLLMModel\n", "\n", "litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n", "\n", "for model_id in litellm_model_ids:\n", "    print(f\"Evaluating '{model_id}'...\")\n", "    action_type = \"tool_calling\"\n", "    agent = ToolCallingAgent(\n", "        tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n", "        model=LiteLLMModel(model_id),\n", "        max_iterations=10,\n", "    )\n", "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", "    answer_questions(eval_ds, file_name, agent, model_id, action_type)\n", "\n", "    action_type = \"code\"\n", "    agent = CodeAgent(\n", "        tools=[GoogleSearchTool(), VisitWebpageTool()],\n", "        model=LiteLLMModel(model_id),\n", "        additional_authorized_imports=[\"numpy\"],\n", "        max_iterations=10,\n", "    )\n", "    file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", "    answer_questions(eval_ds, file_name, agent, model_id, action_type)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# import glob\n", "# import json\n", "# jsonl_files = glob.glob(\"output/*.jsonl\")\n", "\n", "# for file_path in jsonl_files:\n", "#     print(file_path)\n", "#     # Read all lines and drop SimpleQA rows recorded without the answer-format suffix\n", "#     filtered_lines = []\n", "#     removed = 0\n", "#     with open(file_path, 'r', encoding='utf-8') as f:\n", "#         for line in f:\n", "#             try:\n", "#                 data = json.loads(line.strip())\n", "#                 if data[\"source\"] == \"SimpleQA\" and \"Answer with only the final number.\" not in data[\"question\"]:\n", "#                     removed += 1\n", "#                 else:\n", "#                     filtered_lines.append(line)\n", "#             except json.JSONDecodeError:\n", "#                 print(\"Invalid line:\", line)\n", "#                 continue  # Skip invalid JSON lines\n", "#     print(f\"Removed {removed} lines.\")\n", "#     # Write filtered content back to the same file\n", "#     with open(file_path, 'w', encoding='utf-8') as f:\n", "#         f.writelines(filtered_lines)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Score answers" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_37227/1724525657.py:154: UserWarning: Answer lists have different lengths, returning False.\n", "  warnings.warn(\n" ] } ], "source": [ "import pandas as pd\n", "import glob\n", "\n", "res = []\n", "for f in glob.glob(\"output/*.jsonl\"):\n", "    res.append(pd.read_json(f, lines=True))\n", "result_df = pd.concat(res)\n", "\n", "\n", "def get_correct(row):\n", "    if row[\"source\"] == \"GSM8K\":\n", "        numbers_answer = extract_numbers(str(row[\"answer\"]))\n", "        if len(numbers_answer) == 0:\n", "            return False\n", "        return float(numbers_answer[-1]) == float(row[\"true_answer\"])\n", "    else:\n", "        return get_question_score_gaia(str(row[\"answer\"]), str(row[\"true_answer\"]))\n", "\n", "\n", "result_df[\"correct\"] = result_df.apply(get_correct, axis=1)\n", "\n", "
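# Keep only CodeAgent runs, and drop models excluded from the final comparison\n", "result_df = 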
result_df.loc[\n", " (result_df[\"agent_action_type\"] == \"code\")\n", " & (\n", " ~result_df[\"model_id\"].isin(\n", " [\n", " \"meta-llama/Llama-3.2-3B-Instruct\",\n", " \"meta-llama/Llama-3.1-70B-Instruct\",\n", " \"HuggingFaceTB/SmolLM2-1.7B-Instruct\",\n", " ]\n", " )\n", " )\n", "]\n", "result_df = (\n", " (result_df.groupby([\"model_id\", \"source\"])[[\"correct\"]].mean() * 100)\n", " .round(1)\n", " .reset_index()\n", ")\n", "result_df[\"type\"] = \"agent\"" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "vanilla_data = [\n", " [\"gpt-4o\", \"SimpleQA\", 38.2],\n", " [\"gpt-4o\", \"GAIA\", 9.3],\n", " [\"Qwen/Qwen2.5-72B-Instruct\", \"SimpleQA\", 9.1],\n", " [\"anthropic/claude-3-5-sonnet-latest\", \"SimpleQA\", 28.4],\n", " [\"gpt-4o\", \"GSM8K\", 94.3],\n", " [\"anthropic/claude-3-5-sonnet-latest\", \"GSM8K\", 96.4],\n", " [\"meta-llama/Llama-3.3-70B-Instruct\", \"GSM8K\", 95.1],\n", "]\n", "\n", "df2 = pd.DataFrame(vanilla_data, columns=[\"model_id\", \"source\", \"correct\"])\n", "df2[\"type\"] = \"vanilla\"\n", "\n", "combined_df = pd.concat([result_df, df2], ignore_index=True)\n", "\n", "pivot_df = combined_df.pivot_table(\n", " index=[\"model_id\", \"source\"],\n", " columns=[\"type\"],\n", " values=\"correct\",\n", " fill_value=float(\"nan\"),\n", ").reset_index()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Display results" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
typemodel_idsourceagentvanilla
0Qwen/Qwen2.5-72B-InstructGAIA12.5NaN
1Qwen/Qwen2.5-72B-InstructGSM8K82.9NaN
2Qwen/Qwen2.5-72B-InstructSimpleQA42.59.1
3Qwen/Qwen2.5-Coder-32B-InstructGAIA28.1NaN
4Qwen/Qwen2.5-Coder-32B-InstructGSM8K92.9NaN
5Qwen/Qwen2.5-Coder-32B-InstructSimpleQA42.5NaN
6anthropic/claude-3-5-sonnet-latestGAIA43.8NaN
7anthropic/claude-3-5-sonnet-latestGSM8K91.496.4
8anthropic/claude-3-5-sonnet-latestSimpleQA47.528.4
9gpt-4oGAIA25.09.3
10gpt-4oGSM8K91.494.3
11gpt-4oSimpleQA60.038.2
12meta-llama/Llama-3.3-70B-InstructGAIA21.9NaN
13meta-llama/Llama-3.3-70B-InstructGSM8K95.795.1
14meta-llama/Llama-3.3-70B-InstructSimpleQA30.0NaN
\n", "
" ], "text/plain": [ "type model_id source agent vanilla\n", "0 Qwen/Qwen2.5-72B-Instruct GAIA 12.5 NaN\n", "1 Qwen/Qwen2.5-72B-Instruct GSM8K 82.9 NaN\n", "2 Qwen/Qwen2.5-72B-Instruct SimpleQA 42.5 9.1\n", "3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 28.1 NaN\n", "4 Qwen/Qwen2.5-Coder-32B-Instruct GSM8K 92.9 NaN\n", "5 Qwen/Qwen2.5-Coder-32B-Instruct SimpleQA 42.5 NaN\n", "6 anthropic/claude-3-5-sonnet-latest GAIA 43.8 NaN\n", "7 anthropic/claude-3-5-sonnet-latest GSM8K 91.4 96.4\n", "8 anthropic/claude-3-5-sonnet-latest SimpleQA 47.5 28.4\n", "9 gpt-4o GAIA 25.0 9.3\n", "10 gpt-4o GSM8K 91.4 94.3\n", "11 gpt-4o SimpleQA 60.0 38.2\n", "12 meta-llama/Llama-3.3-70B-Instruct GAIA 21.9 NaN\n", "13 meta-llama/Llama-3.3-70B-Instruct GSM8K 95.7 95.1\n", "14 meta-llama/Llama-3.3-70B-Instruct SimpleQA 30.0 NaN" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(pivot_df)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\\begin{array}{llcc}\n", "\\text{Model} & \\text{Task} & \\text{Agent} & \\text{Vanilla} \\\\\n", "\\hline\n", "\\textit{Qwen/Qwen2.5-72B-Instruct} & GAIA & 12.500 & - \\\\\n", "\\; & GSM8K & 82.900 & - \\\\\n", "\\; & SimpleQA & \\textbf{42.500} & 9.100 \\\\\n", "\\hline\n", "\\textit{Qwen/Qwen2.5-Coder-32B-Instruct} & GAIA & 28.100 & - \\\\\n", "\\; & GSM8K & 92.900 & - \\\\\n", "\\; & SimpleQA & 42.500 & - \\\\\n", "\\hline\n", "\\textit{anthropic/claude-3-5-sonnet-latest} & GAIA & 43.800 & - \\\\\n", "\\; & GSM8K & 91.400 & \\textbf{96.400} \\\\\n", "\\; & SimpleQA & \\textbf{47.500} & 28.400 \\\\\n", "\\hline\n", "gpt-4o & GAIA & \\textbf{25.000} & 9.300 \\\\\n", "\\; & GSM8K & 91.400 & \\textbf{94.300} \\\\\n", "\\; & SimpleQA & \\textbf{60.000} & 38.200 \\\\\n", "\\hline\n", "meta-llama/Llama-3.3-70B-Instruct & GAIA & 21.900 & - \\\\\n", "\\; & GSM8K & \\textbf{95.700} & 95.100 \\\\\n", "\\; & SimpleQA & 30.000 & - \\\\\n", "\\hline\n", "\\end{array}\n" ] } ], "source": [ "def create_mathjax_table(pivot_df, formatted_df):\n", " # Start the matrix environment with 4 columns\n", " # l for left-aligned model and task, c for centered numbers\n", " mathjax_table = \"\\\\begin{array}{llcc}\\n\"\n", " mathjax_table += (\n", " \"\\\\text{Model} & \\\\text{Task} & \\\\text{Agent} & \\\\text{Vanilla} \\\\\\\\\\n\"\n", " )\n", " mathjax_table += \"\\\\hline\\n\"\n", "\n", " # Sort the DataFrame by model_id and source\n", " formatted_df = formatted_df.sort_values([\"model_id\", \"source\"])\n", "\n", " current_model = None\n", " for _, row in formatted_df.iterrows():\n", " model = row[\"model_id\"]\n", " source = row[\"source\"]\n", "\n", " # Add a horizontal line between different models\n", " if current_model is not None and current_model != model:\n", " mathjax_table += \"\\\\hline\\n\"\n", "\n", " # Format model name\n", " model_display = model.replace(\"_\", \"\\\\_\")\n", " if \"Qwen\" in model or \"anthropic\" in model:\n", " model_display = f\"\\\\textit{{{model_display}}}\"\n", "\n", " # If it's the same model as previous row, use empty space\n", " if current_model == model:\n", " model_display = \"\\\\;\"\n", "\n", " # Add the data row\n", " mathjax_table += (\n", " f\"{model_display} & {source} & {row['agent']} & {row['vanilla']} \\\\\\\\\\n\"\n", " )\n", "\n", " current_model = model\n", "\n", " mathjax_table += \"\\\\hline\\n\"\n", " mathjax_table += \"\\\\end{array}\"\n", "\n", " return mathjax_table\n", "\n", "\n", "# Usage (after running your previous data 
mathjax_table = create_mathjax_table(pivot_df, formatted_df)\n", "print(mathjax_table)" ] } ], "metadata": { "kernelspec": { "display_name": "compare-agents", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.0" } }, "nbformat": 4, "nbformat_minor": 2 }