smolagents/examples/benchmark.ipynb

1060 lines
113 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [
"!pip install -e .. datasets sympy numpy matplotlib seaborn -q # Install dev version of smolagents + some packages"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>question</th>\n",
" <th>source</th>\n",
" <th>true_answer</th>\n",
" <th>true_reasoning</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>If Eliud Kipchoge could maintain his record-ma...</td>\n",
" <td>GAIA</td>\n",
" <td>17</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>How many studio albums were published by Merce...</td>\n",
" <td>GAIA</td>\n",
" <td>3</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Here's a fun riddle that I think you'll enjoy....</td>\n",
" <td>GAIA</td>\n",
" <td>3</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>My family reunion is this week, and I was assi...</td>\n",
" <td>GAIA</td>\n",
" <td>2</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>In Emily Midkiff's June 2014 article in a jour...</td>\n",
" <td>GAIA</td>\n",
" <td>fluffy</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>127</th>\n",
" <td>What year was the municipality of San Carlos, ...</td>\n",
" <td>SimpleQA</td>\n",
" <td>1786</td>\n",
" <td>['https://en.wikipedia.org/wiki/San_Carlos,_An...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>128</th>\n",
" <td>In which year was Maria Elena Walsh named Illu...</td>\n",
" <td>SimpleQA</td>\n",
" <td>1985</td>\n",
" <td>['https://en.wikipedia.org/wiki/Mar%C3%ADa_Ele...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>129</th>\n",
" <td>What is the durability of the Istarelle spear ...</td>\n",
" <td>SimpleQA</td>\n",
" <td>800</td>\n",
" <td>['http://demonssouls.wikidot.com/spear', 'http...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>130</th>\n",
" <td>What is the number of the executive order that...</td>\n",
" <td>SimpleQA</td>\n",
" <td>7034</td>\n",
" <td>['https://www.loc.gov/collections/federal-thea...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>131</th>\n",
" <td>Within plus or minus one minute, when was Marq...</td>\n",
" <td>SimpleQA</td>\n",
" <td>77</td>\n",
" <td>['https://www.fifa.com/fifaplus/en/match-centr...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>132 rows × 4 columns</p>\n",
"</div>"
],
"text/plain": [
" question source true_answer \\\n",
"0 If Eliud Kipchoge could maintain his record-ma... GAIA 17 \n",
"1 How many studio albums were published by Merce... GAIA 3 \n",
"2 Here's a fun riddle that I think you'll enjoy.... GAIA 3 \n",
"3 My family reunion is this week, and I was assi... GAIA 2 \n",
"4 In Emily Midkiff's June 2014 article in a jour... GAIA fluffy \n",
".. ... ... ... \n",
"127 What year was the municipality of San Carlos, ... SimpleQA 1786 \n",
"128 In which year was Maria Elena Walsh named Illu... SimpleQA 1985 \n",
"129 What is the durability of the Istarelle spear ... SimpleQA 800 \n",
"130 What is the number of the executive order that... SimpleQA 7034 \n",
"131 Within plus or minus one minute, when was Marq... SimpleQA 77 \n",
"\n",
" true_reasoning \n",
"0 None \n",
"1 None \n",
"2 None \n",
"3 None \n",
"4 None \n",
".. ... \n",
"127 ['https://en.wikipedia.org/wiki/San_Carlos,_An... \n",
"128 ['https://en.wikipedia.org/wiki/Mar%C3%ADa_Ele... \n",
"129 ['http://demonssouls.wikidot.com/spear', 'http... \n",
"130 ['https://www.loc.gov/collections/federal-thea... \n",
"131 ['https://www.fifa.com/fifaplus/en/match-centr... \n",
"\n",
"[132 rows x 4 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import datasets\n",
"import pandas as pd\n",
"\n",
"\n",
"eval_ds = datasets.load_dataset(\"m-ric/smol_agents_benchmark\")[\"test\"]\n",
"pd.DataFrame(eval_ds)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define utilities and tools\n",
"The `GoogleSearchTool` used below relies on [SerpAPI](https://serpapi.com/dashboard), so you will need a SerpAPI API key, which requires a paid account."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"import re\n",
"import string\n",
"import time\n",
"import warnings\n",
"from typing import List\n",
"\n",
"from dotenv import load_dotenv\n",
"from tqdm import tqdm\n",
"\n",
"from smolagents import (\n",
" AgentError,\n",
" CodeAgent,\n",
" GoogleSearchTool,\n",
" HfApiModel,\n",
" PythonInterpreterTool,\n",
" ToolCallingAgent,\n",
" VisitWebpageTool,\n",
")\n",
"from smolagents.agents import ActionStep\n",
"\n",
"\n",
"load_dotenv()\n",
"os.makedirs(\"output\", exist_ok=True)\n",
"\n",
"\n",
"def serialize_agent_error(obj):\n",
"    \"\"\"Fallback serializer for `json.dump`: describe an AgentError as a dict, stringify anything else.\"\"\"\n",
"    if not isinstance(obj, AgentError):\n",
"        return str(obj)\n",
"    return {\"error_type\": obj.__class__.__name__, \"message\": obj.message}\n",
"\n",
"\n",
"def answer_questions(eval_ds, file_name, agent, model_id, action_type, is_vanilla_llm=False):\n",
"    \"\"\"Answer every question of `eval_ds` and append one JSON line per answer to `file_name`.\n",
"\n",
"    Questions already recorded in `file_name` are skipped, so an interrupted run can be resumed.\n",
"    A failure on one example is printed and does not stop the loop.\n",
"\n",
"    Args:\n",
"        eval_ds: sized iterable of examples with \"question\", \"source\" and \"true_answer\" keys.\n",
"        file_name: path of the JSONL results file (appended to).\n",
"        agent: object exposing `run`, `logs` and `monitor`; when `is_vanilla_llm` is True, a model\n",
"            called directly on a list of chat messages.\n",
"        model_id: identifier stored under \"model_id\" in each record.\n",
"        action_type: label stored under \"agent_action_type\" in each record.\n",
"        is_vanilla_llm: call `agent` as a bare model instead of using `agent.run`.\n",
"    \"\"\"\n",
"    # A set keeps the already-answered lookup O(1) when resuming a long run.\n",
"    answered_questions = set()\n",
"    if os.path.exists(file_name):\n",
"        with open(file_name, \"r\") as f:\n",
"            for line in f:\n",
"                if line.strip():  # tolerate blank lines in the results file\n",
"                    answered_questions.add(json.loads(line)[\"question\"])\n",
"\n",
"    for example in tqdm(eval_ds, total=len(eval_ds)):\n",
"        try:\n",
"            question = example[\"question\"]\n",
"            if example[\"source\"] == \"SimpleQA\":\n",
"                question += \" Answer with only the final number.\"\n",
"            if example[\"source\"] == \"MATH\":\n",
"                question += \" Write code, not latex.\"\n",
"            # Compare the suffixed prompt: that is the form stored in the results file.\n",
"            if question in answered_questions:\n",
"                continue\n",
"            start_time = time.time()\n",
"\n",
"            if is_vanilla_llm:\n",
"                llm = agent\n",
"                answer = str(llm([{\"role\": \"user\", \"content\": question}]).content)\n",
"                token_count = {\n",
"                    \"input\": llm.last_input_token_count,\n",
"                    \"output\": llm.last_output_token_count,\n",
"                }\n",
"                intermediate_steps = str([])\n",
"            else:\n",
"                answer = str(agent.run(question))\n",
"                token_count = agent.monitor.get_total_token_counts()\n",
"                # Drop the memory snapshots BEFORE stringifying the logs; clearing them\n",
"                # afterwards (as was done previously) left the saved steps uncompacted.\n",
"                for step in agent.logs:\n",
"                    if isinstance(step, ActionStep):\n",
"                        step.agent_memory = None\n",
"                intermediate_steps = str(agent.logs)\n",
"\n",
"            end_time = time.time()\n",
"            annotated_example = {\n",
"                \"model_id\": model_id,\n",
"                \"agent_action_type\": action_type,\n",
"                \"question\": question,\n",
"                \"answer\": answer,\n",
"                \"true_answer\": example[\"true_answer\"],\n",
"                \"source\": example[\"source\"],\n",
"                \"intermediate_steps\": intermediate_steps,\n",
"                \"start_time\": start_time,\n",
"                \"end_time\": end_time,\n",
"                \"token_counts\": token_count,\n",
"            }\n",
"\n",
"            with open(file_name, \"a\") as f:\n",
"                json.dump(annotated_example, f, default=serialize_agent_error)\n",
"                f.write(\"\\n\")  # add a newline for JSONL format\n",
"        except Exception as e:\n",
"            # Best effort: report the failure and move on to the next question.\n",
"            print(\"Failed:\", e)\n",
"\n",
"\n",
"def normalize_number_str(number_str: str) -> float:\n",
" # we replace these common units and commas to allow\n",
" # conversion to float\n",
" for char in [\"$\", \"%\", \",\"]:\n",
" number_str = number_str.replace(char, \"\")\n",
" try:\n",
" return float(number_str)\n",
" except ValueError:\n",
" return float(\"inf\")\n",
"\n",
"\n",
"def split_string(\n",
" s: str,\n",
" char_list: list[str] = [\",\", \";\"],\n",
") -> list[str]:\n",
" pattern = f\"[{''.join(char_list)}]\"\n",
" return re.split(pattern, s)\n",
"\n",
"\n",
"def is_float(element: any) -> bool:\n",
" try:\n",
" float(element)\n",
" return True\n",
" except ValueError:\n",
" return False\n",
"\n",
"\n",
"def normalize_str(input_str, remove_punct=True) -> str:\n",
" \"\"\"\n",
" Normalize a string by:\n",
" - Removing all white spaces\n",
" - Optionally removing punctuation (if remove_punct is True)\n",
" - Converting to lowercase\n",
" Parameters:\n",
" - input_str: str, the string to normalize\n",
" - remove_punct: bool, whether to remove punctuation (default: True)\n",
" Returns:\n",
" - str, the normalized string\n",
" \"\"\"\n",
" # Remove all white spaces. Required e.g for seagull vs. sea gull\n",
" no_spaces = re.sub(r\"\\s\", \"\", input_str)\n",
"\n",
" # Remove punctuation, if specified.\n",
" if remove_punct:\n",
" translator = str.maketrans(\"\", \"\", string.punctuation)\n",
" return no_spaces.lower().translate(translator)\n",
" else:\n",
" return no_spaces.lower()\n",
"\n",
"\n",
"def extract_numbers(text: str) -> List[str]:\n",
" \"\"\"This pattern matches:\n",
" - Optional negative sign\n",
" - Numbers with optional comma thousand separators\n",
" - Optional decimal points with decimal numbers\n",
" \"\"\"\n",
" pattern = r\"-?(?:\\d{1,3}(?:,\\d{3})+|\\d+)(?:\\.\\d+)?\"\n",
"\n",
" return [el.replace(\",\", \"\") for el in re.findall(pattern, text)]\n",
"\n",
"\n",
"def get_question_score_gaia(\n",
"    model_answer: str,\n",
"    ground_truth: str,\n",
") -> bool:\n",
"    \"\"\"Scoring function used to score answers from the GAIA benchmark.\n",
"\n",
"    Numeric ground truths are compared as floats, \",\"/\";\"-separated ground truths element\n",
"    by element, and anything else as normalized strings.\n",
"    \"\"\"\n",
"    # Numeric ground truth: compare as floats.\n",
"    if is_float(ground_truth):\n",
"        return normalize_number_str(str(model_answer)) == float(ground_truth)\n",
"\n",
"    # Plain string ground truth: compare normalized text.\n",
"    if \",\" not in ground_truth and \";\" not in ground_truth:\n",
"        return normalize_str(model_answer) == normalize_str(ground_truth)\n",
"\n",
"    # List ground truth: both sides must split into the same number of elements.\n",
"    gt_elems = split_string(ground_truth)\n",
"    ma_elems = split_string(model_answer)\n",
"    if len(gt_elems) != len(ma_elems):\n",
"        warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"        return False\n",
"\n",
"    def elements_match(ma_elem, gt_elem):\n",
"        if is_float(gt_elem):\n",
"            return normalize_number_str(ma_elem) == float(gt_elem)\n",
"        # we do not remove punct since comparisons can include punct\n",
"        return normalize_str(ma_elem, remove_punct=False) == normalize_str(gt_elem, remove_punct=False)\n",
"\n",
"    return all([elements_match(ma_elem, gt_elem) for ma_elem, gt_elem in zip(ma_elems, gt_elems)])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Benchmark agents\n",
"\n",
"### Open models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"open_model_ids = [\n",
" \"meta-llama/Llama-3.3-70B-Instruct\",\n",
" # \"Qwen/QwQ-32B-Preview\",\n",
" \"Qwen/Qwen2.5-72B-Instruct\",\n",
" \"Qwen/Qwen2.5-Coder-32B-Instruct\",\n",
" \"meta-llama/Llama-3.2-3B-Instruct\",\n",
" \"meta-llama/Llama-3.1-8B-Instruct\",\n",
" \"mistralai/Mistral-Nemo-Instruct-2407\",\n",
" # \"HuggingFaceTB/SmolLM2-1.7B-Instruct\",\n",
" # \"meta-llama/Llama-3.1-70B-Instruct\",\n",
"]\n",
"\n",
"for model_id in open_model_ids:\n",
" print(f\"Evaluating '{model_id}'...\")\n",
" # action_type = \"tool_calling\"\n",
" # agent = ToolCallingAgent(\n",
" # tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n",
" # model=HfApiModel(model_id),\n",
" # max_steps=10,\n",
" # )\n",
" # file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" # answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
"\n",
" action_type = \"code\"\n",
" agent = CodeAgent(\n",
" tools=[GoogleSearchTool(), VisitWebpageTool()],\n",
" model=HfApiModel(model_id),\n",
" additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
" max_steps=10,\n",
" )\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
"\n",
" # Also evaluate vanilla model\n",
" action_type = \"vanilla\"\n",
" llm = HfApiModel(model_id)\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Closed models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from smolagents import LiteLLMModel\n",
"\n",
"\n",
"litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n",
"\n",
"for model_id in litellm_model_ids:\n",
" print(f\"Evaluating '{model_id}'...\")\n",
" action_type = \"tool_calling\"\n",
" agent = ToolCallingAgent(\n",
" tools=[\n",
" GoogleSearchTool(),\n",
" VisitWebpageTool(),\n",
" PythonInterpreterTool([\"numpy\", \"sympy\"]),\n",
" ],\n",
" model=LiteLLMModel(model_id),\n",
" max_steps=10,\n",
" )\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
"\n",
" action_type = \"code\"\n",
" agent = CodeAgent(\n",
" tools=[GoogleSearchTool(), VisitWebpageTool()],\n",
" model=LiteLLMModel(model_id),\n",
" additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
" max_steps=10,\n",
" )\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
"\n",
" # Also evaluate vanilla model\n",
" action_type = \"vanilla\"\n",
" llm = LiteLLMModel(model_id)\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
" answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# import glob\n",
"# import json\n",
"\n",
"# jsonl_files = glob.glob(f\"output/*.jsonl\")\n",
"\n",
"# for file_path in jsonl_files:\n",
"# if \"-Nemo-\" in file_path and \"-vanilla-\" in file_path:\n",
"# print(file_path)\n",
"# # Read all lines and filter out SimpleQA sources\n",
"# filtered_lines = []\n",
"# removed = 0\n",
"# with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
"# for line in f:\n",
"# try:\n",
"# data = json.loads(line.strip())\n",
"# data[\"answer\"] = data[\"answer\"][\"content\"]\n",
"# # if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n",
"# # removed +=1\n",
"# # else:\n",
"# filtered_lines.append(json.dumps(data) + \"\\n\")\n",
"# except json.JSONDecodeError:\n",
"# print(\"Invalid line:\", line)\n",
"# continue # Skip invalid JSON lines\n",
"# print(f\"Removed {removed} lines.\")\n",
"# # Write filtered content back to the same file\n",
"# with open(\n",
"# str(file_path).replace(\"-vanilla-\", \"-vanilla2-\"), \"w\", encoding=\"utf-8\"\n",
"# ) as f:\n",
"# f.writelines(filtered_lines)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Score answers"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_27085/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\n"
]
}
],
"source": [
"import glob\n",
"\n",
"import pandas as pd\n",
"\n",
"\n",
"# Load every results file into one DataFrame, tagging each row with its action type.\n",
"res = []\n",
"for file_path in glob.glob(\"output/*.jsonl\"):\n",
"    data = []\n",
"    with open(file_path) as f:\n",
"        for line in f:\n",
"            try:\n",
"                # Use standard json module instead of pandas.json to handle large numbers better\n",
"                record = json.loads(line)\n",
"                data.append(record)\n",
"            except json.JSONDecodeError as e:\n",
"                print(f\"Error parsing line in {file_path}: {e}\")\n",
"                continue\n",
"\n",
"    try:\n",
"        smoldf = pd.DataFrame(data)\n",
"        # Fallback label taken from the file name. Checking only for \"-vanilla-\" used to file\n",
"        # tool_calling runs (and renamed vanilla files) under \"code\", mixing their scores.\n",
"        inferred_action_type = next(\n",
"            (kind for kind in (\"vanilla\", \"tool_calling\") if f\"-{kind}-\" in file_path),\n",
"            \"code\",\n",
"        )\n",
"        # Prefer the action type recorded with each answer when it is available.\n",
"        if \"agent_action_type\" in smoldf.columns:\n",
"            smoldf[\"action_type\"] = smoldf[\"agent_action_type\"].fillna(inferred_action_type)\n",
"        else:\n",
"            smoldf[\"action_type\"] = inferred_action_type\n",
"        res.append(smoldf)\n",
"    except Exception as e:\n",
"        print(f\"Error creating DataFrame from {file_path}: {e}\")\n",
"        continue\n",
"\n",
"result_df = pd.concat(res)\n",
"\n",
"\n",
"def get_correct(row):\n",
"    \"\"\"Return whether the answer stored in `row` matches the row's true answer.\"\"\"\n",
"    if row[\"source\"] != \"MATH\":\n",
"        return get_question_score_gaia(str(row[\"answer\"]), str(row[\"true_answer\"]))\n",
"\n",
"    # MATH: only the last number appearing in the answer is compared.\n",
"    numbers_answer = extract_numbers(str(row[\"answer\"]))\n",
"    if not numbers_answer:\n",
"        return False\n",
"    return float(numbers_answer[-1]) == float(row[\"true_answer\"])\n",
"\n",
"\n",
"result_df[\"correct\"] = result_df.apply(get_correct, axis=1)\n",
"\n",
"result_df = (result_df.groupby([\"model_id\", \"source\", \"action_type\"])[[\"correct\"]].mean() * 100).round(1).reset_index()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"pivot_df = result_df.pivot_table(\n",
" index=[\"model_id\", \"source\"],\n",
" columns=[\"action_type\"],\n",
" values=\"correct\",\n",
" fill_value=float(\"nan\"),\n",
").reset_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Display results"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th>action_type</th>\n",
" <th>model_id</th>\n",
" <th>source</th>\n",
" <th>code</th>\n",
" <th>vanilla</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
" <td>GAIA</td>\n",
" <td>28.1</td>\n",
" <td>6.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
" <td>MATH</td>\n",
" <td>76.0</td>\n",
" <td>30.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
" <td>SimpleQA</td>\n",
" <td>88.0</td>\n",
" <td>10.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
" <td>GAIA</td>\n",
" <td>25.0</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
" <td>MATH</td>\n",
" <td>86.0</td>\n",
" <td>60.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
" <td>SimpleQA</td>\n",
" <td>86.0</td>\n",
" <td>8.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
" <td>GAIA</td>\n",
" <td>NaN</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
" <td>MATH</td>\n",
" <td>NaN</td>\n",
" <td>50.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
" <td>SimpleQA</td>\n",
" <td>NaN</td>\n",
" <td>34.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>gpt-4o</td>\n",
" <td>GAIA</td>\n",
" <td>25.6</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>gpt-4o</td>\n",
" <td>MATH</td>\n",
" <td>58.0</td>\n",
" <td>40.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>gpt-4o</td>\n",
" <td>SimpleQA</td>\n",
" <td>86.0</td>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
" <td>GAIA</td>\n",
" <td>3.1</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
" <td>MATH</td>\n",
" <td>14.0</td>\n",
" <td>18.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>meta-llama/Llama-3.1-8B-Instruct</td>\n",
" <td>SimpleQA</td>\n",
" <td>2.0</td>\n",
" <td>6.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>meta-llama/Llama-3.2-3B-Instruct</td>\n",
" <td>GAIA</td>\n",
" <td>3.1</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>meta-llama/Llama-3.2-3B-Instruct</td>\n",
" <td>MATH</td>\n",
" <td>40.0</td>\n",
" <td>12.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>meta-llama/Llama-3.2-3B-Instruct</td>\n",
" <td>SimpleQA</td>\n",
" <td>20.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
" <td>GAIA</td>\n",
" <td>31.2</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
" <td>MATH</td>\n",
" <td>72.0</td>\n",
" <td>40.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
" <td>SimpleQA</td>\n",
" <td>78.0</td>\n",
" <td>12.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n",
" <td>GAIA</td>\n",
" <td>0.0</td>\n",
" <td>3.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n",
" <td>MATH</td>\n",
" <td>30.0</td>\n",
" <td>22.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <td>mistralai/Mistral-Nemo-Instruct-2407</td>\n",
" <td>SimpleQA</td>\n",
" <td>30.0</td>\n",
" <td>6.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"action_type model_id source code vanilla\n",
"0 Qwen/Qwen2.5-72B-Instruct GAIA 28.1 6.2\n",
"1 Qwen/Qwen2.5-72B-Instruct MATH 76.0 30.0\n",
"2 Qwen/Qwen2.5-72B-Instruct SimpleQA 88.0 10.0\n",
"3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 25.0 3.1\n",
"4 Qwen/Qwen2.5-Coder-32B-Instruct MATH 86.0 60.0\n",
"5 Qwen/Qwen2.5-Coder-32B-Instruct SimpleQA 86.0 8.0\n",
"6 anthropic/claude-3-5-sonnet-latest GAIA NaN 3.1\n",
"7 anthropic/claude-3-5-sonnet-latest MATH NaN 50.0\n",
"8 anthropic/claude-3-5-sonnet-latest SimpleQA NaN 34.0\n",
"9 gpt-4o GAIA 25.6 3.1\n",
"10 gpt-4o MATH 58.0 40.0\n",
"11 gpt-4o SimpleQA 86.0 6.0\n",
"12 meta-llama/Llama-3.1-8B-Instruct GAIA 3.1 0.0\n",
"13 meta-llama/Llama-3.1-8B-Instruct MATH 14.0 18.0\n",
"14 meta-llama/Llama-3.1-8B-Instruct SimpleQA 2.0 6.0\n",
"15 meta-llama/Llama-3.2-3B-Instruct GAIA 3.1 0.0\n",
"16 meta-llama/Llama-3.2-3B-Instruct MATH 40.0 12.0\n",
"17 meta-llama/Llama-3.2-3B-Instruct SimpleQA 20.0 0.0\n",
"18 meta-llama/Llama-3.3-70B-Instruct GAIA 31.2 3.1\n",
"19 meta-llama/Llama-3.3-70B-Instruct MATH 72.0 40.0\n",
"20 meta-llama/Llama-3.3-70B-Instruct SimpleQA 78.0 12.0\n",
"21 mistralai/Mistral-Nemo-Instruct-2407 GAIA 0.0 3.1\n",
"22 mistralai/Mistral-Nemo-Instruct-2407 MATH 30.0 22.0\n",
"23 mistralai/Mistral-Nemo-Instruct-2407 SimpleQA 30.0 6.0"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display(pivot_df)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1500x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mnotebook controller is DISPOSED. \n",
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
]
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"from matplotlib.legend_handler import HandlerTuple # Added import\n",
"\n",
"\n",
"# Assuming pivot_df is your original dataframe\n",
"models = pivot_df[\"model_id\"].unique()\n",
"sources = pivot_df[\"source\"].unique()\n",
"\n",
"# Create figure and axis\n",
"plt.style.use(\"seaborn-v0_8-white\")\n",
"fig, ax = plt.subplots(figsize=(15, 6))\n",
"\n",
"# Set the width of each bar group and positions of the bars\n",
"width = 0.15 # width of each bar\n",
"spacing = 0.02 # space between bars within a group\n",
"group_spacing = 0.2 # space between model groups\n",
"\n",
"# Calculate positions for the bars\n",
"num_sources = len(sources)\n",
"total_width_per_group = (width + spacing) * num_sources * 2 # *2 for agent and vanilla\n",
"x = np.arange(len(models)) * (total_width_per_group + group_spacing)\n",
"\n",
"# Plot bars for each source\n",
"for i, source in enumerate(sources):\n",
" source_data = pivot_df[pivot_df[\"source\"] == source]\n",
" agent_scores = [\n",
" source_data[source_data[\"model_id\"] == model][\"code\"].values[0]\n",
" if len(source_data[source_data[\"model_id\"] == model]) > 0\n",
" else np.nan\n",
" for model in models\n",
" ]\n",
" vanilla_scores = [\n",
" source_data[source_data[\"model_id\"] == model][\"vanilla\"].values[0]\n",
" if len(source_data[source_data[\"model_id\"] == model]) > 0\n",
" else np.nan\n",
" for model in models\n",
" ]\n",
"\n",
" # Position calculation for each pair of bars\n",
" pos = x + i * (width * 2 + spacing)\n",
"\n",
" agent_bars = ax.bar(pos, agent_scores, width, label=f\"{source} (Agent)\", alpha=0.8)\n",
" vanilla_bars = ax.bar(\n",
" pos + width * 0.6,\n",
" vanilla_scores,\n",
" width,\n",
" hatch=\"////\",\n",
" alpha=0.5,\n",
" hatch_linewidth=2,\n",
" label=f\"{source} (Vanilla)\",\n",
" color=\"white\",\n",
" edgecolor=agent_bars[0].get_facecolor(),\n",
" )\n",
"\n",
"# Customize the plot\n",
"ax.set_ylabel(\"Score\")\n",
"ax.set_title(\"Model Performance Comparison\")\n",
"\n",
"# Set x-axis ticks in the middle of each group\n",
"group_centers = x + (total_width_per_group - spacing) / 2\n",
"ax.set_xticks(group_centers)\n",
"\n",
"# Wrap long model names to prevent overlap\n",
"wrapped_labels = [\"\\n\".join(model.split(\"/\")) for model in models]\n",
"ax.set_xticklabels(wrapped_labels, rotation=0, ha=\"center\")\n",
"\n",
"# Modify legend to combine agent and vanilla entries\n",
"handles, labels = ax.get_legend_handles_labels()\n",
"unique_sources = sources\n",
"legend_elements = [\n",
" (handles[i * 2], handles[i * 2 + 1], labels[i * 2].replace(\" (Agent)\", \"\")) for i in range(len(unique_sources))\n",
"]\n",
"custom_legend = ax.legend(\n",
" [(agent_handle, vanilla_handle) for agent_handle, vanilla_handle, _ in legend_elements],\n",
" [label for _, _, label in legend_elements],\n",
" handler_map={tuple: HandlerTuple(ndivide=None)},\n",
" bbox_to_anchor=(1.05, 1),\n",
" loc=\"upper left\",\n",
")\n",
"\n",
"ax.yaxis.grid(True, linestyle=\"--\", alpha=0.3)\n",
"ax.set_ylim(bottom=0)\n",
"plt.tight_layout()\n",
"ax.spines[\"top\"].set_visible(False)\n",
"ax.spines[\"right\"].set_visible(False)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'formatted_df' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[12], line 45\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m mathjax_table\n\u001b[1;32m 44\u001b[0m \u001b[38;5;66;03m# Usage (after running your previous data processing code):\u001b[39;00m\n\u001b[0;32m---> 45\u001b[0m mathjax_table \u001b[38;5;241m=\u001b[39m create_mathjax_table(pivot_df, \u001b[43mformatted_df\u001b[49m)\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28mprint\u001b[39m(mathjax_table)\n",
"\u001b[0;31mNameError\u001b[0m: name 'formatted_df' is not defined"
]
}
],
"source": [
"def create_mathjax_table(pivot_df, formatted_df):\n",
" # Start the matrix environment with 4 columns\n",
" # l for left-aligned model and task, c for centered numbers\n",
" mathjax_table = \"\\\\begin{array}{llcc}\\n\"\n",
" mathjax_table += \"\\\\text{Model} & \\\\text{Task} & \\\\text{Agent} & \\\\text{Vanilla} \\\\\\\\\\n\"\n",
" mathjax_table += \"\\\\hline\\n\"\n",
"\n",
" # Sort the DataFrame by model_id and source\n",
" formatted_df = formatted_df.sort_values([\"model_id\", \"source\"])\n",
"\n",
" current_model = None\n",
" for _, row in formatted_df.iterrows():\n",
" model = row[\"model_id\"]\n",
" source = row[\"source\"]\n",
"\n",
" # Add a horizontal line between different models\n",
" if current_model is not None and current_model != model:\n",
" mathjax_table += \"\\\\hline\\n\"\n",
"\n",
" # Format model name\n",
" model_display = model.replace(\"_\", \"\\\\_\")\n",
" if \"Qwen\" in model or \"anthropic\" in model:\n",
" model_display = f\"\\\\textit{{{model_display}}}\"\n",
"\n",
" # If it's the same model as previous row, use empty space\n",
" if current_model == model:\n",
" model_display = \"\\\\;\"\n",
"\n",
" # Add the data row\n",
" mathjax_table += f\"{model_display} & {source} & {row['agent']} & {row['vanilla']} \\\\\\\\\\n\"\n",
"\n",
" current_model = model\n",
"\n",
" mathjax_table += \"\\\\hline\\n\"\n",
" mathjax_table += \"\\\\end{array}\"\n",
"\n",
" return mathjax_table\n",
"\n",
"\n",
"# Usage (after running your previous data processing code):\n",
"# mathjax_table = create_mathjax_table(pivot_df, formatted_df)\n",
"# print(mathjax_table)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "test",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}