699 lines
24 KiB
Plaintext
699 lines
24 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import datasets\n",
|
|
"\n",
|
|
"eval_ds = datasets.load_dataset(\"m-ric/agents_medium_benchmark_2\")[\"train\"]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Define utilities and tools\n",
|
|
"To run the search tool (`GoogleSearchTool`), you will need a [SerpAPI](https://serpapi.com/dashboard) API key, which requires a paid account."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import time\n",
|
|
"import json\n",
|
|
"import os\n",
|
|
"import re\n",
|
|
"import string\n",
|
|
"import warnings\n",
|
|
"from tqdm import tqdm\n",
|
|
"from typing import List\n",
|
|
"\n",
|
|
"from smolagents import (\n",
|
|
" GoogleSearchTool,\n",
|
|
" CodeAgent,\n",
|
|
" ToolCallingAgent,\n",
|
|
" HfApiModel,\n",
|
|
" AgentError,\n",
|
|
" VisitWebpageTool,\n",
|
|
" PythonInterpreterTool,\n",
|
|
")\n",
|
|
"from smolagents.agents import ActionStep\n",
|
|
"from dotenv import load_dotenv\n",
|
|
"\n",
|
|
"load_dotenv()\n",
|
|
"os.makedirs(\"output\", exist_ok=True)\n",
|
|
"\n",
|
|
"\n",
|
|
"def serialize_agent_error(obj):\n",
|
|
" if isinstance(obj, AgentError):\n",
|
|
" return {\"error_type\": obj.__class__.__name__, \"message\": obj.message}\n",
|
|
" else:\n",
|
|
" return str(obj)\n",
|
|
"\n",
|
|
"\n",
|
|
"def answer_questions(eval_ds, file_name, agent, model_id, action_type):\n",
|
|
" answered_questions = []\n",
|
|
" if os.path.exists(file_name):\n",
|
|
" with open(file_name, \"r\") as f:\n",
|
|
" for line in f:\n",
|
|
" answered_questions.append(json.loads(line)[\"question\"])\n",
|
|
"\n",
|
|
" for _, example in tqdm(enumerate(eval_ds), total=len(eval_ds)):\n",
|
|
" try:\n",
|
|
" question = example[\"question\"]\n",
|
|
" if example[\"source\"] == \"SimpleQA\":\n",
|
|
" question += \" Answer with only the final number.\"\n",
|
|
" if question in answered_questions:\n",
|
|
" continue\n",
|
|
" start_time = time.time()\n",
|
|
" answer = agent.run(question)\n",
|
|
" end_time = time.time()\n",
|
|
" for step_log in agent.logs:\n",
|
|
" if hasattr(step_log, \"memory\"):\n",
|
|
" step_log.memory = None\n",
|
|
"\n",
|
|
" # Remove memory from logs to make them more compact.\n",
|
|
" for step in agent.logs:\n",
|
|
" if isinstance(step, ActionStep):\n",
|
|
" step.agent_memory = None\n",
|
|
"\n",
|
|
" annotated_example = {\n",
|
|
" \"model_id\": model_id,\n",
|
|
" \"agent_action_type\": action_type,\n",
|
|
" \"question\": question,\n",
|
|
" \"answer\": answer,\n",
|
|
" \"true_answer\": example[\"true_answer\"],\n",
|
|
" \"source\": example[\"source\"],\n",
|
|
" \"intermediate_steps\": str(agent.logs),\n",
|
|
" \"start_time\": start_time,\n",
|
|
" \"end_time\": end_time,\n",
|
|
" \"token_counts\": agent.monitor.get_total_token_counts(),\n",
|
|
" }\n",
|
|
"\n",
|
|
" with open(file_name, \"a\") as f:\n",
|
|
" json.dump(annotated_example, f, default=serialize_agent_error)\n",
|
|
" f.write(\"\\n\") # add a newline for JSONL format\n",
|
|
" except Exception as e:\n",
|
|
" print(\"Failed:\", e)\n",
|
|
"\n",
|
|
"\n",
|
|
"def normalize_number_str(number_str: str) -> float:\n",
|
|
" # we replace these common units and commas to allow\n",
|
|
" # conversion to float\n",
|
|
" for char in [\"$\", \"%\", \",\"]:\n",
|
|
" number_str = number_str.replace(char, \"\")\n",
|
|
" try:\n",
|
|
" return float(number_str)\n",
|
|
" except ValueError:\n",
|
|
" return float(\"inf\")\n",
|
|
"\n",
|
|
"\n",
|
|
"def split_string(\n",
|
|
" s: str,\n",
|
|
" char_list: list[str] = [\",\", \";\"],\n",
|
|
") -> list[str]:\n",
|
|
" pattern = f\"[{''.join(char_list)}]\"\n",
|
|
" return re.split(pattern, s)\n",
|
|
"\n",
|
|
"\n",
|
|
"def is_float(element: any) -> bool:\n",
|
|
" try:\n",
|
|
" float(element)\n",
|
|
" return True\n",
|
|
" except ValueError:\n",
|
|
" return False\n",
|
|
"\n",
|
|
"\n",
|
|
"def normalize_str(input_str, remove_punct=True) -> str:\n",
|
|
" \"\"\"\n",
|
|
" Normalize a string by:\n",
|
|
" - Removing all white spaces\n",
|
|
" - Optionally removing punctuation (if remove_punct is True)\n",
|
|
" - Converting to lowercase\n",
|
|
" Parameters:\n",
|
|
" - input_str: str, the string to normalize\n",
|
|
" - remove_punct: bool, whether to remove punctuation (default: True)\n",
|
|
" Returns:\n",
|
|
" - str, the normalized string\n",
|
|
" \"\"\"\n",
|
|
" # Remove all white spaces. Required e.g for seagull vs. sea gull\n",
|
|
" no_spaces = re.sub(r\"\\s\", \"\", input_str)\n",
|
|
"\n",
|
|
" # Remove punctuation, if specified.\n",
|
|
" if remove_punct:\n",
|
|
" translator = str.maketrans(\"\", \"\", string.punctuation)\n",
|
|
" return no_spaces.lower().translate(translator)\n",
|
|
" else:\n",
|
|
" return no_spaces.lower()\n",
|
|
"\n",
|
|
"\n",
|
|
"def extract_numbers(text: str) -> List[str]:\n",
|
|
" \"\"\"This pattern matches:\n",
|
|
" - Optional negative sign\n",
|
|
" - Numbers with optional comma thousand separators\n",
|
|
" - Optional decimal points with decimal numbers\n",
|
|
" \"\"\"\n",
|
|
" pattern = r\"-?(?:\\d{1,3}(?:,\\d{3})+|\\d+)(?:\\.\\d+)?\"\n",
|
|
"\n",
|
|
" return [el.replace(\",\", \"\") for el in re.findall(pattern, text)]\n",
|
|
"\n",
|
|
"\n",
|
|
"def get_question_score_gaia(\n",
|
|
" model_answer: str,\n",
|
|
" ground_truth: str,\n",
|
|
") -> bool:\n",
|
|
" if is_float(ground_truth):\n",
|
|
" normalized_answer = normalize_number_str(str(model_answer))\n",
|
|
" return normalized_answer == float(ground_truth)\n",
|
|
"\n",
|
|
" elif any(char in ground_truth for char in [\",\", \";\"]): # if gt is a list\n",
|
|
" # question with the fish: normalization removes punct\n",
|
|
" gt_elems = split_string(ground_truth)\n",
|
|
" ma_elems = split_string(model_answer)\n",
|
|
"\n",
|
|
" if len(gt_elems) != len(ma_elems): # check length is the same\n",
|
|
" warnings.warn(\n",
|
|
" \"Answer lists have different lengths, returning False.\", UserWarning\n",
|
|
" )\n",
|
|
" return False\n",
|
|
"\n",
|
|
" comparisons = []\n",
|
|
" for ma_elem, gt_elem in zip(\n",
|
|
" ma_elems, gt_elems\n",
|
|
" ): # compare each element as float or str\n",
|
|
" if is_float(gt_elem):\n",
|
|
" normalized_ma_elem = normalize_number_str(ma_elem)\n",
|
|
" comparisons.append(normalized_ma_elem == float(gt_elem))\n",
|
|
" else:\n",
|
|
" # we do not remove punct since comparisons can include punct\n",
|
|
" comparisons.append(\n",
|
|
" normalize_str(ma_elem, remove_punct=False)\n",
|
|
" == normalize_str(gt_elem, remove_punct=False)\n",
|
|
" )\n",
|
|
" return all(comparisons)\n",
|
|
"\n",
|
|
" else: # if gt is a str\n",
|
|
" return normalize_str(model_answer) == normalize_str(ground_truth)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Run benchmark\n",
|
|
"\n",
|
|
"### Open models"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"open_model_ids = [\n",
|
|
" \"meta-llama/Llama-3.3-70B-Instruct\",\n",
|
|
" # \"Qwen/QwQ-32B-Preview\",\n",
|
|
" \"Qwen/Qwen2.5-72B-Instruct\",\n",
|
|
" \"Qwen/Qwen2.5-Coder-32B-Instruct\",\n",
|
|
" \"meta-llama/Llama-3.2-3B-Instruct\",\n",
|
|
" # \"HuggingFaceTB/SmolLM2-1.7B-Instruct\",\n",
|
|
" # \"meta-llama/Llama-3.1-70B-Instruct\",\n",
|
|
"]\n",
|
|
"\n",
|
|
"for model_id in open_model_ids:\n",
|
|
" print(f\"Evaluating '{model_id}'...\")\n",
|
|
" action_type = \"tool_calling\"\n",
|
|
" agent = ToolCallingAgent(\n",
|
|
" tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n",
|
|
" model=HfApiModel(model_id),\n",
|
|
" max_iterations=10,\n",
|
|
" )\n",
|
|
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
|
" answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
|
|
"\n",
|
|
" action_type = \"code\"\n",
|
|
" agent = CodeAgent(\n",
|
|
" tools=[GoogleSearchTool(), VisitWebpageTool()],\n",
|
|
" model=HfApiModel(model_id),\n",
|
|
" additional_authorized_imports=[\"numpy\"],\n",
|
|
" max_iterations=10,\n",
|
|
" )\n",
|
|
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
|
" answer_questions(eval_ds, file_name, agent, model_id, action_type)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Closed models"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from smolagents import LiteLLMModel\n",
|
|
"\n",
|
|
"litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n",
|
|
"\n",
|
|
"for model_id in litellm_model_ids:\n",
|
|
" print(f\"Evaluating '{model_id}'...\")\n",
|
|
" action_type = \"tool_calling\"\n",
|
|
" agent = ToolCallingAgent(\n",
|
|
" tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n",
|
|
" model=LiteLLMModel(model_id),\n",
|
|
" max_iterations=10,\n",
|
|
" )\n",
|
|
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
|
" answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
|
|
"\n",
|
|
" action_type = \"code\"\n",
|
|
" agent = CodeAgent(\n",
|
|
" tools=[GoogleSearchTool(), VisitWebpageTool()],\n",
|
|
" model=LiteLLMModel(model_id),\n",
|
|
" additional_authorized_imports=[\"numpy\"],\n",
|
|
" max_iterations=10,\n",
|
|
" )\n",
|
|
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
|
" answer_questions(eval_ds, file_name, agent, model_id, action_type)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# import glob\n",
|
|
"# import json\n",
|
|
"# jsonl_files = glob.glob(f\"output/*.jsonl\")\n",
|
|
"\n",
|
|
"# for file_path in jsonl_files:\n",
|
|
"# print(file_path)\n",
|
|
"# # Read all lines and filter out SimpleQA sources\n",
|
|
"# filtered_lines = []\n",
|
|
"# removed = 0\n",
|
|
"# with open(file_path, 'r', encoding='utf-8') as f:\n",
|
|
"# for line in f:\n",
|
|
"# try:\n",
|
|
"# data = json.loads(line.strip())\n",
|
|
"# if data[\"source\"] == \"SimpleQA\" and \"Answer with only the final number.\" not in data[\"question\"]:\n",
|
|
"# removed +=1\n",
|
|
"# else:\n",
|
|
"# filtered_lines.append(line)\n",
|
|
"# except json.JSONDecodeError:\n",
|
|
"# print(\"Invalid line:\", line)\n",
|
|
"# continue # Skip invalid JSON lines\n",
|
|
"# print(f\"Removed {removed} lines.\")\n",
|
|
"# # Write filtered content back to the same file\n",
|
|
"# with open(file_path, 'w', encoding='utf-8') as f:\n",
|
|
"# f.writelines(filtered_lines)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Score answers"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_37227/1724525657.py:154: UserWarning: Answer lists have different lengths, returning False.\n",
|
|
" warnings.warn(\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"import glob\n",
|
|
"\n",
|
|
"res = []\n",
|
|
"for f in glob.glob(f\"output/*.jsonl\"):\n",
|
|
" res.append(pd.read_json(f, lines=True))\n",
|
|
"result_df = pd.concat(res)\n",
|
|
"\n",
|
|
"\n",
|
|
"def get_correct(row):\n",
|
|
" if row[\"source\"] == \"GSM8K\":\n",
|
|
" numbers_answer = extract_numbers(str(row[\"answer\"]))\n",
|
|
" if len(numbers_answer) == 0:\n",
|
|
" return False\n",
|
|
" return float(numbers_answer[-1]) == float(row[\"true_answer\"])\n",
|
|
" else:\n",
|
|
" return get_question_score_gaia(str(row[\"answer\"]), str(row[\"true_answer\"]))\n",
|
|
"\n",
|
|
"\n",
|
|
"result_df[\"correct\"] = result_df.apply(get_correct, axis=1)\n",
|
|
"\n",
|
|
"result_df = result_df.loc[\n",
|
|
" (result_df[\"agent_action_type\"] == \"code\")\n",
|
|
" & (\n",
|
|
" ~result_df[\"model_id\"].isin(\n",
|
|
" [\n",
|
|
" \"meta-llama/Llama-3.2-3B-Instruct\",\n",
|
|
" \"meta-llama/Llama-3.1-70B-Instruct\",\n",
|
|
" \"HuggingFaceTB/SmolLM2-1.7B-Instruct\",\n",
|
|
" ]\n",
|
|
" )\n",
|
|
" )\n",
|
|
"]\n",
|
|
"result_df = (\n",
|
|
" (result_df.groupby([\"model_id\", \"source\"])[[\"correct\"]].mean() * 100)\n",
|
|
" .round(1)\n",
|
|
" .reset_index()\n",
|
|
")\n",
|
|
"result_df[\"type\"] = \"agent\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"vanilla_data = [\n",
|
|
" [\"gpt-4o\", \"SimpleQA\", 38.2],\n",
|
|
" [\"gpt-4o\", \"GAIA\", 9.3],\n",
|
|
" [\"Qwen/Qwen2.5-72B-Instruct\", \"SimpleQA\", 9.1],\n",
|
|
" [\"anthropic/claude-3-5-sonnet-latest\", \"SimpleQA\", 28.4],\n",
|
|
" [\"gpt-4o\", \"GSM8K\", 94.3],\n",
|
|
" [\"anthropic/claude-3-5-sonnet-latest\", \"GSM8K\", 96.4],\n",
|
|
" [\"meta-llama/Llama-3.3-70B-Instruct\", \"GSM8K\", 95.1],\n",
|
|
"]\n",
|
|
"\n",
|
|
"df2 = pd.DataFrame(vanilla_data, columns=[\"model_id\", \"source\", \"correct\"])\n",
|
|
"df2[\"type\"] = \"vanilla\"\n",
|
|
"\n",
|
|
"combined_df = pd.concat([result_df, df2], ignore_index=True)\n",
|
|
"\n",
|
|
"pivot_df = combined_df.pivot_table(\n",
|
|
" index=[\"model_id\", \"source\"],\n",
|
|
" columns=[\"type\"],\n",
|
|
" values=\"correct\",\n",
|
|
" fill_value=float(\"nan\"),\n",
|
|
").reset_index()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Display results"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th>type</th>\n",
|
|
" <th>model_id</th>\n",
|
|
" <th>source</th>\n",
|
|
" <th>agent</th>\n",
|
|
" <th>vanilla</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
|
|
" <td>GAIA</td>\n",
|
|
" <td>12.5</td>\n",
|
|
" <td>NaN</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
|
|
" <td>GSM8K</td>\n",
|
|
" <td>82.9</td>\n",
|
|
" <td>NaN</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
|
|
" <td>SimpleQA</td>\n",
|
|
" <td>42.5</td>\n",
|
|
" <td>9.1</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
|
|
" <td>GAIA</td>\n",
|
|
" <td>28.1</td>\n",
|
|
" <td>NaN</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
|
|
" <td>GSM8K</td>\n",
|
|
" <td>92.9</td>\n",
|
|
" <td>NaN</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>5</th>\n",
|
|
" <td>Qwen/Qwen2.5-Coder-32B-Instruct</td>\n",
|
|
" <td>SimpleQA</td>\n",
|
|
" <td>42.5</td>\n",
|
|
" <td>NaN</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>6</th>\n",
|
|
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
|
|
" <td>GAIA</td>\n",
|
|
" <td>43.8</td>\n",
|
|
" <td>NaN</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>7</th>\n",
|
|
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
|
|
" <td>GSM8K</td>\n",
|
|
" <td>91.4</td>\n",
|
|
" <td>96.4</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>8</th>\n",
|
|
" <td>anthropic/claude-3-5-sonnet-latest</td>\n",
|
|
" <td>SimpleQA</td>\n",
|
|
" <td>47.5</td>\n",
|
|
" <td>28.4</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>9</th>\n",
|
|
" <td>gpt-4o</td>\n",
|
|
" <td>GAIA</td>\n",
|
|
" <td>25.0</td>\n",
|
|
" <td>9.3</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>10</th>\n",
|
|
" <td>gpt-4o</td>\n",
|
|
" <td>GSM8K</td>\n",
|
|
" <td>91.4</td>\n",
|
|
" <td>94.3</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>11</th>\n",
|
|
" <td>gpt-4o</td>\n",
|
|
" <td>SimpleQA</td>\n",
|
|
" <td>60.0</td>\n",
|
|
" <td>38.2</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>12</th>\n",
|
|
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
|
|
" <td>GAIA</td>\n",
|
|
" <td>21.9</td>\n",
|
|
" <td>NaN</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>13</th>\n",
|
|
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
|
|
" <td>GSM8K</td>\n",
|
|
" <td>95.7</td>\n",
|
|
" <td>95.1</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>14</th>\n",
|
|
" <td>meta-llama/Llama-3.3-70B-Instruct</td>\n",
|
|
" <td>SimpleQA</td>\n",
|
|
" <td>30.0</td>\n",
|
|
" <td>NaN</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
"type model_id source agent vanilla\n",
|
|
"0 Qwen/Qwen2.5-72B-Instruct GAIA 12.5 NaN\n",
|
|
"1 Qwen/Qwen2.5-72B-Instruct GSM8K 82.9 NaN\n",
|
|
"2 Qwen/Qwen2.5-72B-Instruct SimpleQA 42.5 9.1\n",
|
|
"3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 28.1 NaN\n",
|
|
"4 Qwen/Qwen2.5-Coder-32B-Instruct GSM8K 92.9 NaN\n",
|
|
"5 Qwen/Qwen2.5-Coder-32B-Instruct SimpleQA 42.5 NaN\n",
|
|
"6 anthropic/claude-3-5-sonnet-latest GAIA 43.8 NaN\n",
|
|
"7 anthropic/claude-3-5-sonnet-latest GSM8K 91.4 96.4\n",
|
|
"8 anthropic/claude-3-5-sonnet-latest SimpleQA 47.5 28.4\n",
|
|
"9 gpt-4o GAIA 25.0 9.3\n",
|
|
"10 gpt-4o GSM8K 91.4 94.3\n",
|
|
"11 gpt-4o SimpleQA 60.0 38.2\n",
|
|
"12 meta-llama/Llama-3.3-70B-Instruct GAIA 21.9 NaN\n",
|
|
"13 meta-llama/Llama-3.3-70B-Instruct GSM8K 95.7 95.1\n",
|
|
"14 meta-llama/Llama-3.3-70B-Instruct SimpleQA 30.0 NaN"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"display(pivot_df)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\\begin{array}{llcc}\n",
|
|
"\\text{Model} & \\text{Task} & \\text{Agent} & \\text{Vanilla} \\\\\n",
|
|
"\\hline\n",
|
|
"\\textit{Qwen/Qwen2.5-72B-Instruct} & GAIA & 12.500 & - \\\\\n",
|
|
"\\; & GSM8K & 82.900 & - \\\\\n",
|
|
"\\; & SimpleQA & \\textbf{42.500} & 9.100 \\\\\n",
|
|
"\\hline\n",
|
|
"\\textit{Qwen/Qwen2.5-Coder-32B-Instruct} & GAIA & 28.100 & - \\\\\n",
|
|
"\\; & GSM8K & 92.900 & - \\\\\n",
|
|
"\\; & SimpleQA & 42.500 & - \\\\\n",
|
|
"\\hline\n",
|
|
"\\textit{anthropic/claude-3-5-sonnet-latest} & GAIA & 43.800 & - \\\\\n",
|
|
"\\; & GSM8K & 91.400 & \\textbf{96.400} \\\\\n",
|
|
"\\; & SimpleQA & \\textbf{47.500} & 28.400 \\\\\n",
|
|
"\\hline\n",
|
|
"gpt-4o & GAIA & \\textbf{25.000} & 9.300 \\\\\n",
|
|
"\\; & GSM8K & 91.400 & \\textbf{94.300} \\\\\n",
|
|
"\\; & SimpleQA & \\textbf{60.000} & 38.200 \\\\\n",
|
|
"\\hline\n",
|
|
"meta-llama/Llama-3.3-70B-Instruct & GAIA & 21.900 & - \\\\\n",
|
|
"\\; & GSM8K & \\textbf{95.700} & 95.100 \\\\\n",
|
|
"\\; & SimpleQA & 30.000 & - \\\\\n",
|
|
"\\hline\n",
|
|
"\\end{array}\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def create_mathjax_table(pivot_df, formatted_df):\n",
|
|
" # Start the matrix environment with 4 columns\n",
|
|
" # l for left-aligned model and task, c for centered numbers\n",
|
|
" mathjax_table = \"\\\\begin{array}{llcc}\\n\"\n",
|
|
" mathjax_table += (\n",
|
|
" \"\\\\text{Model} & \\\\text{Task} & \\\\text{Agent} & \\\\text{Vanilla} \\\\\\\\\\n\"\n",
|
|
" )\n",
|
|
" mathjax_table += \"\\\\hline\\n\"\n",
|
|
"\n",
|
|
" # Sort the DataFrame by model_id and source\n",
|
|
" formatted_df = formatted_df.sort_values([\"model_id\", \"source\"])\n",
|
|
"\n",
|
|
" current_model = None\n",
|
|
" for _, row in formatted_df.iterrows():\n",
|
|
" model = row[\"model_id\"]\n",
|
|
" source = row[\"source\"]\n",
|
|
"\n",
|
|
" # Add a horizontal line between different models\n",
|
|
" if current_model is not None and current_model != model:\n",
|
|
" mathjax_table += \"\\\\hline\\n\"\n",
|
|
"\n",
|
|
" # Format model name\n",
|
|
" model_display = model.replace(\"_\", \"\\\\_\")\n",
|
|
" if \"Qwen\" in model or \"anthropic\" in model:\n",
|
|
" model_display = f\"\\\\textit{{{model_display}}}\"\n",
|
|
"\n",
|
|
" # If it's the same model as previous row, use empty space\n",
|
|
" if current_model == model:\n",
|
|
" model_display = \"\\\\;\"\n",
|
|
"\n",
|
|
" # Add the data row\n",
|
|
" mathjax_table += (\n",
|
|
" f\"{model_display} & {source} & {row['agent']} & {row['vanilla']} \\\\\\\\\\n\"\n",
|
|
" )\n",
|
|
"\n",
|
|
" current_model = model\n",
|
|
"\n",
|
|
" mathjax_table += \"\\\\hline\\n\"\n",
|
|
" mathjax_table += \"\\\\end{array}\"\n",
|
|
"\n",
|
|
" return mathjax_table\n",
|
|
"\n",
|
|
"\n",
|
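|
"# Usage: run the data-processing cells above first. NOTE: `formatted_df` is not defined in any cell of this notebook as shown and must be created before this call.\n",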
|
|
"mathjax_table = create_mathjax_table(pivot_df, formatted_df)\n",
|
|
"print(mathjax_table)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "compare-agents",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|