From 6d0e4e49fc329a17e6ff33bd1b7933154361442d Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral
<8515462+albertvillanova@users.noreply.github.com>
Date: Thu, 30 Jan 2025 19:21:32 +0100
Subject: [PATCH] Update benchmark with Hub datasets (#412)
* Use smolagents-benchmark Hub datasets
* Push results to Hub
* Fix style
* Add Constants section at the top
* Set DATE as constant
---
examples/benchmark.ipynb | 685 +++++++++++++++++++++++----------------
1 file changed, 410 insertions(+), 275 deletions(-)
diff --git a/examples/benchmark.ipynb b/examples/benchmark.ipynb
index 065adce..bd3e11a 100644
--- a/examples/benchmark.ipynb
+++ b/examples/benchmark.ipynb
@@ -16,190 +16,43 @@
}
],
"source": [
- "!pip install -e .. datasets sympy numpy matplotlib seaborn -q # Install dev version of smolagents + some packages"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " question | \n",
- " source | \n",
- " true_answer | \n",
- " true_reasoning | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " If Eliud Kipchoge could maintain his record-ma... | \n",
- " GAIA | \n",
- " 17 | \n",
- " None | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " How many studio albums were published by Merce... | \n",
- " GAIA | \n",
- " 3 | \n",
- " None | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " Here's a fun riddle that I think you'll enjoy.... | \n",
- " GAIA | \n",
- " 3 | \n",
- " None | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " My family reunion is this week, and I was assi... | \n",
- " GAIA | \n",
- " 2 | \n",
- " None | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " In Emily Midkiff's June 2014 article in a jour... | \n",
- " GAIA | \n",
- " fluffy | \n",
- " None | \n",
- "
\n",
- " \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- " ... | \n",
- "
\n",
- " \n",
- " 127 | \n",
- " What year was the municipality of San Carlos, ... | \n",
- " SimpleQA | \n",
- " 1786 | \n",
- " ['https://en.wikipedia.org/wiki/San_Carlos,_An... | \n",
- "
\n",
- " \n",
- " 128 | \n",
- " In which year was Maria Elena Walsh named Illu... | \n",
- " SimpleQA | \n",
- " 1985 | \n",
- " ['https://en.wikipedia.org/wiki/Mar%C3%ADa_Ele... | \n",
- "
\n",
- " \n",
- " 129 | \n",
- " What is the durability of the Istarelle spear ... | \n",
- " SimpleQA | \n",
- " 800 | \n",
- " ['http://demonssouls.wikidot.com/spear', 'http... | \n",
- "
\n",
- " \n",
- " 130 | \n",
- " What is the number of the executive order that... | \n",
- " SimpleQA | \n",
- " 7034 | \n",
- " ['https://www.loc.gov/collections/federal-thea... | \n",
- "
\n",
- " \n",
- " 131 | \n",
- " Within plus or minus one minute, when was Marq... | \n",
- " SimpleQA | \n",
- " 77 | \n",
- " ['https://www.fifa.com/fifaplus/en/match-centr... | \n",
- "
\n",
- " \n",
- "
\n",
- "
132 rows × 4 columns
\n",
- "
"
- ],
- "text/plain": [
- " question source true_answer \\\n",
- "0 If Eliud Kipchoge could maintain his record-ma... GAIA 17 \n",
- "1 How many studio albums were published by Merce... GAIA 3 \n",
- "2 Here's a fun riddle that I think you'll enjoy.... GAIA 3 \n",
- "3 My family reunion is this week, and I was assi... GAIA 2 \n",
- "4 In Emily Midkiff's June 2014 article in a jour... GAIA fluffy \n",
- ".. ... ... ... \n",
- "127 What year was the municipality of San Carlos, ... SimpleQA 1786 \n",
- "128 In which year was Maria Elena Walsh named Illu... SimpleQA 1985 \n",
- "129 What is the durability of the Istarelle spear ... SimpleQA 800 \n",
- "130 What is the number of the executive order that... SimpleQA 7034 \n",
- "131 Within plus or minus one minute, when was Marq... SimpleQA 77 \n",
- "\n",
- " true_reasoning \n",
- "0 None \n",
- "1 None \n",
- "2 None \n",
- "3 None \n",
- "4 None \n",
- ".. ... \n",
- "127 ['https://en.wikipedia.org/wiki/San_Carlos,_An... \n",
- "128 ['https://en.wikipedia.org/wiki/Mar%C3%ADa_Ele... \n",
- "129 ['http://demonssouls.wikidot.com/spear', 'http... \n",
- "130 ['https://www.loc.gov/collections/federal-thea... \n",
- "131 ['https://www.fifa.com/fifaplus/en/match-centr... \n",
- "\n",
- "[132 rows x 4 columns]"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "import datasets\n",
- "import pandas as pd\n",
- "\n",
- "\n",
- "eval_ds = datasets.load_dataset(\"m-ric/smol_agents_benchmark\")[\"test\"]\n",
- "pd.DataFrame(eval_ds)"
+ "!pip install -e .. datasets sympy numpy matplotlib seaborn -q # Install dev version of smolagents + some packages"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Define utilities and tools\n",
- "To run the SERPAPI tool, you will need to have a [SerpAPI](https://serpapi.com/dashboard) API key: for this you need a paid account."
+ "## Constants and utilities/tools"
]
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
+ "# Benchmark date\n",
+ "# - set a concrete date:\n",
+ "DATE = \"2024-12-26\"\n",
+ "# - or use default: today\n",
+ "# DATE = None\n",
+ "\n",
+ "# Evaluation dataset\n",
+ "EVAL_DATASET = \"smolagents-benchmark/benchmark-v1\"\n",
+ "\n",
+ "# Answers dataset: it must be a gated dataset; required to score the answers\n",
+ "ANSWERS_DATASET = \"smolagents-benchmark/answers\"\n",
+ "# Whether to push the answers dataset to the Hub\n",
+ "PUSH_ANSWERS_DATASET_TO_HUB = True\n",
+ "\n",
+ "# Results dataset\n",
+ "RESULTS_DATASET = \"smolagents-benchmark/results\"\n",
+ "# Whether to push the results dataset to the Hub\n",
+ "PUSH_RESULTS_DATASET_TO_HUB = True\n",
+ "\n",
+ "\n",
+ "import datetime\n",
"import json\n",
"import os\n",
"import re\n",
@@ -208,6 +61,7 @@
"import warnings\n",
"from typing import List\n",
"\n",
+ "import datasets\n",
"from dotenv import load_dotenv\n",
"from tqdm import tqdm\n",
"\n",
@@ -234,60 +88,85 @@
" return str(obj)\n",
"\n",
"\n",
- "def answer_questions(eval_ds, file_name, agent, model_id, action_type, is_vanilla_llm=False):\n",
- " answered_questions = []\n",
- " if os.path.exists(file_name):\n",
- " with open(file_name, \"r\") as f:\n",
- " for line in f:\n",
- " answered_questions.append(json.loads(line)[\"question\"])\n",
+ "def answer_questions(\n",
+ " eval_ds,\n",
+ " agent,\n",
+ " model_id,\n",
+ " action_type,\n",
+ " is_vanilla_llm=False,\n",
+ " date=DATE,\n",
+ " output_dir=\"output\",\n",
+ " push_to_hub_dataset=ANSWERS_DATASET if PUSH_ANSWERS_DATASET_TO_HUB else None,\n",
+ "):\n",
+ " date = date or datetime.date.today().isoformat()\n",
"\n",
- " for _, example in tqdm(enumerate(eval_ds), total=len(eval_ds)):\n",
- " try:\n",
- " question = example[\"question\"]\n",
- " if example[\"source\"] == \"SimpleQA\":\n",
- " question += \" Answer with only the final number.\"\n",
- " if example[\"source\"] == \"MATH\":\n",
- " question += \" Write code, not latex.\"\n",
- " if question in answered_questions:\n",
- " continue\n",
- " start_time = time.time()\n",
+ " for task in eval_ds:\n",
+ " file_name = f\"output/{model_id.replace('/', '__')}__{action_type}__{task}__{date}.jsonl\"\n",
+ " answered_questions = []\n",
+ " if os.path.exists(file_name):\n",
+ " with open(file_name, \"r\") as f:\n",
+ " for line in f:\n",
+ " answered_questions.append(json.loads(line)[\"question\"])\n",
"\n",
- " if is_vanilla_llm:\n",
- " llm = agent\n",
- " answer = str(llm([{\"role\": \"user\", \"content\": question}]).content)\n",
- " token_count = {\n",
- " \"input\": llm.last_input_token_count,\n",
- " \"output\": llm.last_output_token_count,\n",
+ " for _, example in tqdm(enumerate(eval_ds[task]), total=len(eval_ds[task])):\n",
+ " try:\n",
+ " question = example[\"question\"]\n",
+ " if example[\"source\"] == \"SimpleQA\":\n",
+ " question += \" Answer with only the final number.\"\n",
+ " if example[\"source\"] == \"MATH\":\n",
+ " question += \" Write code, not latex.\"\n",
+ " if question in answered_questions:\n",
+ " continue\n",
+ " start_time = time.time()\n",
+ "\n",
+ " if is_vanilla_llm:\n",
+ " llm = agent\n",
+ " answer = str(llm([{\"role\": \"user\", \"content\": question}]).content)\n",
+ " token_count = {\n",
+ " \"input\": llm.last_input_token_count,\n",
+ " \"output\": llm.last_output_token_count,\n",
+ " }\n",
+ " intermediate_steps = str([])\n",
+ " else:\n",
+ " answer = str(agent.run(question))\n",
+ " token_count = agent.monitor.get_total_token_counts()\n",
+ " intermediate_steps = str(agent.logs)\n",
+ " # Remove memory from logs to make them more compact.\n",
+ " for step in agent.logs:\n",
+ " if isinstance(step, ActionStep):\n",
+ " step.agent_memory = None\n",
+ "\n",
+ " end_time = time.time()\n",
+ " annotated_example = {\n",
+ " \"model_id\": model_id,\n",
+ " \"agent_action_type\": action_type,\n",
+ " \"question\": question,\n",
+ " \"answer\": answer,\n",
+ " \"true_answer\": example[\"true_answer\"],\n",
+ " \"source\": example[\"source\"],\n",
+ " \"intermediate_steps\": intermediate_steps,\n",
+ " \"start_time\": start_time,\n",
+ " \"end_time\": end_time,\n",
+ " \"token_counts\": token_count,\n",
" }\n",
- " intermediate_steps = str([])\n",
- " else:\n",
- " answer = str(agent.run(question))\n",
- " token_count = agent.monitor.get_total_token_counts()\n",
- " intermediate_steps = str(agent.logs)\n",
- " # Remove memory from logs to make them more compact.\n",
- " for step in agent.logs:\n",
- " if isinstance(step, ActionStep):\n",
- " step.agent_memory = None\n",
"\n",
- " end_time = time.time()\n",
- " annotated_example = {\n",
- " \"model_id\": model_id,\n",
- " \"agent_action_type\": action_type,\n",
- " \"question\": question,\n",
- " \"answer\": answer,\n",
- " \"true_answer\": example[\"true_answer\"],\n",
- " \"source\": example[\"source\"],\n",
- " \"intermediate_steps\": intermediate_steps,\n",
- " \"start_time\": start_time,\n",
- " \"end_time\": end_time,\n",
- " \"token_counts\": token_count,\n",
- " }\n",
+ " with open(file_name, \"a\") as f:\n",
+ " json.dump(annotated_example, f, default=serialize_agent_error)\n",
+ " f.write(\"\\n\") # add a newline for JSONL format\n",
+ " except Exception as e:\n",
+ " print(\"Failed:\", e)\n",
"\n",
- " with open(file_name, \"a\") as f:\n",
- " json.dump(annotated_example, f, default=serialize_agent_error)\n",
- " f.write(\"\\n\") # add a newline for JSONL format\n",
- " except Exception as e:\n",
- " print(\"Failed:\", e)\n",
+ " if push_to_hub_dataset:\n",
+ " ds = datasets.Dataset.from_pandas(pd.read_json(file_name, lines=True), split=\"test\", preserve_index=False)\n",
+ " config = f\"{model_id.replace('/', '__')}__{action_type}__{task}\"\n",
+ " data_dir = f\"{model_id}/{action_type}/{task}/{date}\"\n",
+ " ds.push_to_hub(\n",
+ " push_to_hub_dataset,\n",
+ " config_name=config,\n",
+ " data_dir=data_dir,\n",
+ " split=\"test\",\n",
+ " commit_message=f\"Upload {config}\",\n",
+ " )\n",
"\n",
"\n",
"def normalize_number_str(number_str: str) -> float:\n",
@@ -382,7 +261,172 @@
" return all(comparisons)\n",
"\n",
" else: # if gt is a str\n",
- " return normalize_str(model_answer) == normalize_str(ground_truth)"
+ " return normalize_str(model_answer) == normalize_str(ground_truth)\n",
+ "\n",
+ "\n",
+ "def get_correct(row):\n",
+ " if row[\"source\"] == \"MATH\": # Checks the last number in answer\n",
+ " numbers_answer = extract_numbers(str(row[\"answer\"]))\n",
+ " if len(numbers_answer) == 0:\n",
+ " return False\n",
+ " return float(numbers_answer[-1]) == float(row[\"true_answer\"])\n",
+ " else:\n",
+ " return get_question_score_gaia(str(row[\"answer\"]), str(row[\"true_answer\"]))\n",
+ "\n",
+ "\n",
+ "def score_answers(\n",
+ " answers_subsets,\n",
+ " answers_dataset=ANSWERS_DATASET,\n",
+ " date=DATE,\n",
+ " push_to_hub_dataset=RESULTS_DATASET if PUSH_RESULTS_DATASET_TO_HUB else None,\n",
+ " set_default=True,\n",
+ "):\n",
+ " if not answers_dataset:\n",
+ " raise ValueError(\"Pass 'answers_dataset' to load the answers from it\")\n",
+ " date = date or datetime.date.today().isoformat()\n",
+ " results = []\n",
+ " for answers_subset in answers_subsets:\n",
+ " *model_id, action_type, task = answers_subset.split(\"__\")\n",
+ " model_id = \"/\".join(model_id)\n",
+ " ds = datasets.load_dataset(answers_dataset, answers_subset, split=\"test\")\n",
+ " df = ds.to_pandas()\n",
+ " df[\"correct\"] = df.apply(get_correct, axis=1)\n",
+ " acc = df[\"correct\"].mean().item()\n",
+ " result = df.loc[0, [\"model_id\", \"agent_action_type\", \"source\"]].to_dict()\n",
+ " result[\"acc\"] = acc\n",
+ " results.append(result)\n",
+ " df = pd.DataFrame(results)\n",
+ "\n",
+ " if push_to_hub_dataset:\n",
+ " ds = datasets.Dataset.from_pandas(df)\n",
+ " config = date\n",
+ " set_default = set_default\n",
+ " ds.push_to_hub(\n",
+ " push_to_hub_dataset, config_name=config, set_default=set_default, commit_message=f\"Upload {config} results\"\n",
+ " )\n",
+ " return df"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Evaluation dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['gaia', 'math', 'simpleqa']\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " question | \n",
+ " source | \n",
+ " true_answer | \n",
+ " true_reasoning | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " What year was the municipality of Ramiriquí, B... | \n",
+ " SimpleQA | \n",
+ " 1541 | \n",
+ " ['https://en.wikipedia.org/wiki/Ramiriqu%C3%AD... | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " In what year did Hjalmar Hvam invent a mechani... | \n",
+ " SimpleQA | \n",
+ " 1937 | \n",
+ " ['https://www.kgw.com/article/features/portlan... | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " In which year did Fayaz A. Malik (an Indian ph... | \n",
+ " SimpleQA | \n",
+ " 2009 | \n",
+ " ['https://en.wikipedia.org/wiki/Fayaz_A._Malik... | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " In which year was John B. Goodenough elected a... | \n",
+ " SimpleQA | \n",
+ " 2010 | \n",
+ " ['https://en.wikipedia.org/wiki/John_B._Gooden... | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " In which year did Atul Gawande earn an M.A. in... | \n",
+ " SimpleQA | \n",
+ " 1989 | \n",
+ " ['https://en.wikipedia.org/wiki/Atul_Gawande',... | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " question source true_answer \\\n",
+ "0 What year was the municipality of Ramiriquí, B... SimpleQA 1541 \n",
+ "1 In what year did Hjalmar Hvam invent a mechani... SimpleQA 1937 \n",
+ "2 In which year did Fayaz A. Malik (an Indian ph... SimpleQA 2009 \n",
+ "3 In which year was John B. Goodenough elected a... SimpleQA 2010 \n",
+ "4 In which year did Atul Gawande earn an M.A. in... SimpleQA 1989 \n",
+ "\n",
+ " true_reasoning \n",
+ "0 ['https://en.wikipedia.org/wiki/Ramiriqu%C3%AD... \n",
+ "1 ['https://www.kgw.com/article/features/portlan... \n",
+ "2 ['https://en.wikipedia.org/wiki/Fayaz_A._Malik... \n",
+ "3 ['https://en.wikipedia.org/wiki/John_B._Gooden... \n",
+ "4 ['https://en.wikipedia.org/wiki/Atul_Gawande',... "
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "\n",
+ "# Choose the tasks to evaluate on:\n",
+ "# tasks = [\"gaia\"]\n",
+ "# or evaluate on all tasks: [\"gaia\", \"math\", \"simpleqa\"]\n",
+ "tasks = datasets.get_dataset_config_names(EVAL_DATASET)\n",
+ "print(tasks)\n",
+ "\n",
+ "\n",
+ "eval_ds = {task: datasets.load_dataset(EVAL_DATASET, task, split=\"test\") for task in tasks}\n",
+ "pd.DataFrame(eval_ds[\"simpleqa\"]).head()"
]
},
{
@@ -412,16 +456,16 @@
" # \"meta-llama/Llama-3.1-70B-Instruct\",\n",
"]\n",
"\n",
+ "\n",
"for model_id in open_model_ids:\n",
" print(f\"Evaluating '{model_id}'...\")\n",
- " # action_type = \"tool_calling\"\n",
+ " # action_type = \"tool-calling\"\n",
" # agent = ToolCallingAgent(\n",
" # tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n",
" # model=HfApiModel(model_id),\n",
" # max_steps=10,\n",
" # )\n",
- " # file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
- " # answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
+ " # answer_questions(eval_ds, agent, model_id, action_type)\n",
"\n",
" action_type = \"code\"\n",
" agent = CodeAgent(\n",
@@ -430,21 +474,19 @@
" additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
" max_steps=10,\n",
" )\n",
- " file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
- " answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
+ " answer_questions(eval_ds, agent, model_id, action_type)\n",
"\n",
" # Also evaluate vanilla model\n",
" action_type = \"vanilla\"\n",
" llm = HfApiModel(model_id)\n",
- " file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
- " answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)"
+ " answer_questions(eval_ds, llm, model_id, action_type, is_vanilla_llm=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Closed models"
+ "### Closed models"
]
},
{
@@ -458,9 +500,10 @@
"\n",
"litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n",
"\n",
+ "\n",
"for model_id in litellm_model_ids:\n",
" print(f\"Evaluating '{model_id}'...\")\n",
- " action_type = \"tool_calling\"\n",
+ " action_type = \"tool-calling\"\n",
" agent = ToolCallingAgent(\n",
" tools=[\n",
" GoogleSearchTool(),\n",
@@ -470,8 +513,7 @@
" model=LiteLLMModel(model_id),\n",
" max_steps=10,\n",
" )\n",
- " file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
- " answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
+ " answer_questions(eval_ds, agent, model_id, action_type)\n",
"\n",
" action_type = \"code\"\n",
" agent = CodeAgent(\n",
@@ -480,14 +522,12 @@
" additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
" max_steps=10,\n",
" )\n",
- " file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
- " answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
+ " answer_questions(eval_ds, agent, model_id, action_type)\n",
"\n",
" # Also evaluate vanilla model\n",
" action_type = \"vanilla\"\n",
" llm = LiteLLMModel(model_id)\n",
- " file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
- " answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)"
+ " answer_questions(eval_ds, llm, model_id, action_type, is_vanilla_llm=True)"
]
},
{
@@ -539,58 +579,153 @@
"execution_count": 9,
"metadata": {},
"outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of answers_subsets 54\n",
+ "Example of answers_subset Qwen__Qwen2.5-72B-Instruct__code__gaia\n"
+ ]
+ },
{
"name": "stderr",
"output_type": "stream",
"text": [
- "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_27085/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n",
- " warnings.warn(\n"
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
+ "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
+ " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n"
]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " model_id | \n",
+ " agent_action_type | \n",
+ " source | \n",
+ " acc | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " Qwen/Qwen2.5-72B-Instruct | \n",
+ " code | \n",
+ " GAIA | \n",
+ " 28.12 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " Qwen/Qwen2.5-72B-Instruct | \n",
+ " code | \n",
+ " MATH | \n",
+ " 76.00 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " Qwen/Qwen2.5-72B-Instruct | \n",
+ " code | \n",
+ " SimpleQA | \n",
+ " 88.00 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " Qwen/Qwen2.5-72B-Instruct | \n",
+ " vanilla | \n",
+ " GAIA | \n",
+ " 6.25 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " Qwen/Qwen2.5-72B-Instruct | \n",
+ " vanilla | \n",
+ " MATH | \n",
+ " 30.00 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " model_id agent_action_type source acc\n",
+ "0 Qwen/Qwen2.5-72B-Instruct code GAIA 28.12\n",
+ "1 Qwen/Qwen2.5-72B-Instruct code MATH 76.00\n",
+ "2 Qwen/Qwen2.5-72B-Instruct code SimpleQA 88.00\n",
+ "3 Qwen/Qwen2.5-72B-Instruct vanilla GAIA 6.25\n",
+ "4 Qwen/Qwen2.5-72B-Instruct vanilla MATH 30.00"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
- "import glob\n",
- "\n",
+ "import datasets\n",
"import pandas as pd\n",
"\n",
"\n",
- "res = []\n",
- "for file_path in glob.glob(\"output/*.jsonl\"):\n",
- " data = []\n",
- " with open(file_path) as f:\n",
- " for line in f:\n",
- " try:\n",
- " # Use standard json module instead of pandas.json to handle large numbers better\n",
- " record = json.loads(line)\n",
- " data.append(record)\n",
- " except json.JSONDecodeError as e:\n",
- " print(f\"Error parsing line in {file_path}: {e}\")\n",
- " continue\n",
- "\n",
- " try:\n",
- " smoldf = pd.DataFrame(data)\n",
- " smoldf[\"action_type\"] = \"vanilla\" if \"-vanilla-\" in file_path else \"code\"\n",
- " res.append(smoldf)\n",
- " except Exception as e:\n",
- " print(f\"Error creating DataFrame from {file_path}: {e}\")\n",
- " continue\n",
- "\n",
- "result_df = pd.concat(res)\n",
+ "# Choose the answers subsets to score:\n",
+ "# answers_subsets = [\"meta-llama__Llama-3.1-8B-Instruct__code__gaia\"]\n",
+ "# or get all the answers subsets present in the ANSWERS_DATASET\n",
+ "answers_subsets = datasets.get_dataset_config_names(ANSWERS_DATASET)\n",
+ "print(\"Number of answers_subsets\", len(answers_subsets))\n",
+ "print(\"Example of answers_subset\", answers_subsets[0])\n",
"\n",
"\n",
- "def get_correct(row):\n",
- " if row[\"source\"] == \"MATH\": # Checks the last number in answer\n",
- " numbers_answer = extract_numbers(str(row[\"answer\"]))\n",
- " if len(numbers_answer) == 0:\n",
- " return False\n",
- " return float(numbers_answer[-1]) == float(row[\"true_answer\"])\n",
- " else:\n",
- " return get_question_score_gaia(str(row[\"answer\"]), str(row[\"true_answer\"]))\n",
- "\n",
- "\n",
- "result_df[\"correct\"] = result_df.apply(get_correct, axis=1)\n",
- "\n",
- "result_df = (result_df.groupby([\"model_id\", \"source\", \"action_type\"])[[\"correct\"]].mean() * 100).round(1).reset_index()"
+ "result_df = score_answers(answers_subsets)\n",
+ "result_df[\"acc\"] = (result_df[\"acc\"] * 100).round(2)\n",
+ "result_df.head()"
]
},
{