Update benchmark with Hub datasets (#412)

* Use smolagents-benchmark Hub datasets

* Push results to Hub

* Fix style

* Add Constants section at the top

* Set DATE as constant
Albert Villanova del Moral 2025-01-30 19:21:32 +01:00 committed by GitHub
parent aa55f137e5
commit 6d0e4e49fc
1 changed file with 410 additions and 275 deletions


@@ -16,190 +16,43 @@
}
],
"source": [
"!pip install -e .. datasets sympy numpy matplotlib seaborn -q # Install dev version of smolagents + some packages"
]
},
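A quick import check (hypothetical, not part of the notebook) confirms that the editable install above is the one Python actually picks up:

```python
# Hypothetical sanity check: verify the dev install of smolagents is active.
import smolagents

print(smolagents.__version__)  # should report the version from the local checkout
```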
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>question</th>\n",
" <th>source</th>\n",
" <th>true_answer</th>\n",
" <th>true_reasoning</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>If Eliud Kipchoge could maintain his record-ma...</td>\n",
" <td>GAIA</td>\n",
" <td>17</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>How many studio albums were published by Merce...</td>\n",
" <td>GAIA</td>\n",
" <td>3</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Here's a fun riddle that I think you'll enjoy....</td>\n",
" <td>GAIA</td>\n",
" <td>3</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>My family reunion is this week, and I was assi...</td>\n",
" <td>GAIA</td>\n",
" <td>2</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>In Emily Midkiff's June 2014 article in a jour...</td>\n",
" <td>GAIA</td>\n",
" <td>fluffy</td>\n",
" <td>None</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>127</th>\n",
" <td>What year was the municipality of San Carlos, ...</td>\n",
" <td>SimpleQA</td>\n",
" <td>1786</td>\n",
" <td>['https://en.wikipedia.org/wiki/San_Carlos,_An...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>128</th>\n",
" <td>In which year was Maria Elena Walsh named Illu...</td>\n",
" <td>SimpleQA</td>\n",
" <td>1985</td>\n",
" <td>['https://en.wikipedia.org/wiki/Mar%C3%ADa_Ele...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>129</th>\n",
" <td>What is the durability of the Istarelle spear ...</td>\n",
" <td>SimpleQA</td>\n",
" <td>800</td>\n",
" <td>['http://demonssouls.wikidot.com/spear', 'http...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>130</th>\n",
" <td>What is the number of the executive order that...</td>\n",
" <td>SimpleQA</td>\n",
" <td>7034</td>\n",
" <td>['https://www.loc.gov/collections/federal-thea...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>131</th>\n",
" <td>Within plus or minus one minute, when was Marq...</td>\n",
" <td>SimpleQA</td>\n",
" <td>77</td>\n",
" <td>['https://www.fifa.com/fifaplus/en/match-centr...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>132 rows × 4 columns</p>\n",
"</div>"
],
"text/plain": [
" question source true_answer \\\n",
"0 If Eliud Kipchoge could maintain his record-ma... GAIA 17 \n",
"1 How many studio albums were published by Merce... GAIA 3 \n",
"2 Here's a fun riddle that I think you'll enjoy.... GAIA 3 \n",
"3 My family reunion is this week, and I was assi... GAIA 2 \n",
"4 In Emily Midkiff's June 2014 article in a jour... GAIA fluffy \n",
".. ... ... ... \n",
"127 What year was the municipality of San Carlos, ... SimpleQA 1786 \n",
"128 In which year was Maria Elena Walsh named Illu... SimpleQA 1985 \n",
"129 What is the durability of the Istarelle spear ... SimpleQA 800 \n",
"130 What is the number of the executive order that... SimpleQA 7034 \n",
"131 Within plus or minus one minute, when was Marq... SimpleQA 77 \n",
"\n",
" true_reasoning \n",
"0 None \n",
"1 None \n",
"2 None \n",
"3 None \n",
"4 None \n",
".. ... \n",
"127 ['https://en.wikipedia.org/wiki/San_Carlos,_An... \n",
"128 ['https://en.wikipedia.org/wiki/Mar%C3%ADa_Ele... \n",
"129 ['http://demonssouls.wikidot.com/spear', 'http... \n",
"130 ['https://www.loc.gov/collections/federal-thea... \n",
"131 ['https://www.fifa.com/fifaplus/en/match-centr... \n",
"\n",
"[132 rows x 4 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import datasets\n",
"import pandas as pd\n",
"\n",
"\n",
"eval_ds = datasets.load_dataset(\"m-ric/smol_agents_benchmark\")[\"test\"]\n",
"pd.DataFrame(eval_ds)"
]
},
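For context: the cell above is the loader this commit deletes, which read the old single-config `m-ric/smol_agents_benchmark` dataset; the cells added below replace it with one config per task under the `smolagents-benchmark` organization. A sketch of the before/after difference:

```python
import datasets

# Before this commit: one flat test split mixing GAIA, MATH and SimpleQA rows.
old_eval_ds = datasets.load_dataset("m-ric/smol_agents_benchmark")["test"]

# After this commit: one config per task on the smolagents-benchmark org.
new_eval_ds = {
    task: datasets.load_dataset("smolagents-benchmark/benchmark-v1", task, split="test")
    for task in ("gaia", "math", "simpleqa")
}
```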
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Constants and utilities/tools"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Benchmark date\n",
"# - set a concrete date:\n",
"DATE = \"2024-12-26\"\n",
"# - or use default: today\n",
"# DATE = None\n",
"\n",
"# Evaluation dataset\n",
"EVAL_DATASET = \"smolagents-benchmark/benchmark-v1\"\n",
"\n",
"# Answers dataset: it must be a gated dataset; required to score the answers\n",
"ANSWERS_DATASET = \"smolagents-benchmark/answers\"\n",
"# Whether to push the answers dataset to the Hub\n",
"PUSH_ANSWERS_DATASET_TO_HUB = True\n",
"\n",
"# Results dataset\n",
"RESULTS_DATASET = \"smolagents-benchmark/results\"\n",
"# Whether to push the results dataset to the Hub\n",
"PUSH_RESULTS_DATASET_TO_HUB = True\n",
"\n",
"\n",
"import datetime\n",
"import json\n", "import json\n",
"import os\n", "import os\n",
"import re\n", "import re\n",
@ -208,6 +61,7 @@
"import warnings\n", "import warnings\n",
"from typing import List\n", "from typing import List\n",
"\n", "\n",
"import datasets\n",
"from dotenv import load_dotenv\n", "from dotenv import load_dotenv\n",
"from tqdm import tqdm\n", "from tqdm import tqdm\n",
"\n", "\n",
@ -234,60 +88,85 @@
" return str(obj)\n", " return str(obj)\n",
"\n", "\n",
"\n", "\n",
"def answer_questions(eval_ds, file_name, agent, model_id, action_type, is_vanilla_llm=False):\n", "def answer_questions(\n",
" answered_questions = []\n", " eval_ds,\n",
" if os.path.exists(file_name):\n", " agent,\n",
" with open(file_name, \"r\") as f:\n", " model_id,\n",
" for line in f:\n", " action_type,\n",
" answered_questions.append(json.loads(line)[\"question\"])\n", " is_vanilla_llm=False,\n",
" date=DATE,\n",
" output_dir=\"output\",\n",
" push_to_hub_dataset=ANSWERS_DATASET if PUSH_ANSWERS_DATASET_TO_HUB else None,\n",
"):\n",
" date = date or datetime.date.today().isoformat()\n",
"\n", "\n",
" for _, example in tqdm(enumerate(eval_ds), total=len(eval_ds)):\n", " for task in eval_ds:\n",
" try:\n", " file_name = f\"output/{model_id.replace('/', '__')}__{action_type}__{task}__{date}.jsonl\"\n",
" question = example[\"question\"]\n", " answered_questions = []\n",
" if example[\"source\"] == \"SimpleQA\":\n", " if os.path.exists(file_name):\n",
" question += \" Answer with only the final number.\"\n", " with open(file_name, \"r\") as f:\n",
" if example[\"source\"] == \"MATH\":\n", " for line in f:\n",
" question += \" Write code, not latex.\"\n", " answered_questions.append(json.loads(line)[\"question\"])\n",
" if question in answered_questions:\n",
" continue\n",
" start_time = time.time()\n",
"\n", "\n",
" if is_vanilla_llm:\n", " for _, example in tqdm(enumerate(eval_ds[task]), total=len(eval_ds[task])):\n",
" llm = agent\n", " try:\n",
" answer = str(llm([{\"role\": \"user\", \"content\": question}]).content)\n", " question = example[\"question\"]\n",
" token_count = {\n", " if example[\"source\"] == \"SimpleQA\":\n",
" \"input\": llm.last_input_token_count,\n", " question += \" Answer with only the final number.\"\n",
" \"output\": llm.last_output_token_count,\n", " if example[\"source\"] == \"MATH\":\n",
" question += \" Write code, not latex.\"\n",
" if question in answered_questions:\n",
" continue\n",
" start_time = time.time()\n",
"\n",
" if is_vanilla_llm:\n",
" llm = agent\n",
" answer = str(llm([{\"role\": \"user\", \"content\": question}]).content)\n",
" token_count = {\n",
" \"input\": llm.last_input_token_count,\n",
" \"output\": llm.last_output_token_count,\n",
" }\n",
" intermediate_steps = str([])\n",
" else:\n",
" answer = str(agent.run(question))\n",
" token_count = agent.monitor.get_total_token_counts()\n",
" intermediate_steps = str(agent.logs)\n",
" # Remove memory from logs to make them more compact.\n",
" for step in agent.logs:\n",
" if isinstance(step, ActionStep):\n",
" step.agent_memory = None\n",
"\n",
" end_time = time.time()\n",
" annotated_example = {\n",
" \"model_id\": model_id,\n",
" \"agent_action_type\": action_type,\n",
" \"question\": question,\n",
" \"answer\": answer,\n",
" \"true_answer\": example[\"true_answer\"],\n",
" \"source\": example[\"source\"],\n",
" \"intermediate_steps\": intermediate_steps,\n",
" \"start_time\": start_time,\n",
" \"end_time\": end_time,\n",
" \"token_counts\": token_count,\n",
" }\n", " }\n",
" intermediate_steps = str([])\n",
" else:\n",
" answer = str(agent.run(question))\n",
" token_count = agent.monitor.get_total_token_counts()\n",
" intermediate_steps = str(agent.logs)\n",
" # Remove memory from logs to make them more compact.\n",
" for step in agent.logs:\n",
" if isinstance(step, ActionStep):\n",
" step.agent_memory = None\n",
"\n", "\n",
" end_time = time.time()\n", " with open(file_name, \"a\") as f:\n",
" annotated_example = {\n", " json.dump(annotated_example, f, default=serialize_agent_error)\n",
" \"model_id\": model_id,\n", " f.write(\"\\n\") # add a newline for JSONL format\n",
" \"agent_action_type\": action_type,\n", " except Exception as e:\n",
" \"question\": question,\n", " print(\"Failed:\", e)\n",
" \"answer\": answer,\n",
" \"true_answer\": example[\"true_answer\"],\n",
" \"source\": example[\"source\"],\n",
" \"intermediate_steps\": intermediate_steps,\n",
" \"start_time\": start_time,\n",
" \"end_time\": end_time,\n",
" \"token_counts\": token_count,\n",
" }\n",
"\n", "\n",
" with open(file_name, \"a\") as f:\n", " if push_to_hub_dataset:\n",
" json.dump(annotated_example, f, default=serialize_agent_error)\n", " ds = datasets.Dataset.from_pandas(pd.read_json(file_name, lines=True), split=\"test\", preserve_index=False)\n",
" f.write(\"\\n\") # add a newline for JSONL format\n", " config = f\"{model_id.replace('/', '__')}__{action_type}__{task}\"\n",
" except Exception as e:\n", " data_dir = f\"{model_id}/{action_type}/{task}/{date}\"\n",
" print(\"Failed:\", e)\n", " ds.push_to_hub(\n",
" push_to_hub_dataset,\n",
" config_name=config,\n",
" data_dir=data_dir,\n",
" split=\"test\",\n",
" commit_message=f\"Upload {config}\",\n",
" )\n",
"\n", "\n",
"\n", "\n",
"def normalize_number_str(number_str: str) -> float:\n", "def normalize_number_str(number_str: str) -> float:\n",
@ -382,7 +261,172 @@
" return all(comparisons)\n", " return all(comparisons)\n",
"\n", "\n",
" else: # if gt is a str\n", " else: # if gt is a str\n",
" return normalize_str(model_answer) == normalize_str(ground_truth)" " return normalize_str(model_answer) == normalize_str(ground_truth)\n",
"\n",
"\n",
"def get_correct(row):\n",
" if row[\"source\"] == \"MATH\": # Checks the last number in answer\n",
" numbers_answer = extract_numbers(str(row[\"answer\"]))\n",
" if len(numbers_answer) == 0:\n",
" return False\n",
" return float(numbers_answer[-1]) == float(row[\"true_answer\"])\n",
" else:\n",
" return get_question_score_gaia(str(row[\"answer\"]), str(row[\"true_answer\"]))\n",
"\n",
"\n",
"def score_answers(\n",
" answers_subsets,\n",
" answers_dataset=ANSWERS_DATASET,\n",
" date=DATE,\n",
" push_to_hub_dataset=RESULTS_DATASET if PUSH_RESULTS_DATASET_TO_HUB else None,\n",
" set_default=True,\n",
"):\n",
" if not answers_dataset:\n",
" raise ValueError(\"Pass 'answers_dataset' to load the answers from it\")\n",
" date = date or datetime.date.today().isoformat()\n",
" results = []\n",
" for answers_subset in answers_subsets:\n",
" *model_id, action_type, task = answers_subset.split(\"__\")\n",
" model_id = \"/\".join(model_id)\n",
" ds = datasets.load_dataset(answers_dataset, answers_subset, split=\"test\")\n",
" df = ds.to_pandas()\n",
" df[\"correct\"] = df.apply(get_correct, axis=1)\n",
" acc = df[\"correct\"].mean().item()\n",
" result = df.loc[0, [\"model_id\", \"agent_action_type\", \"source\"]].to_dict()\n",
" result[\"acc\"] = acc\n",
" results.append(result)\n",
" df = pd.DataFrame(results)\n",
"\n",
" if push_to_hub_dataset:\n",
" ds = datasets.Dataset.from_pandas(df)\n",
" config = date\n",
" set_default = set_default\n",
" ds.push_to_hub(\n",
" push_to_hub_dataset, config_name=config, set_default=set_default, commit_message=f\"Upload {config} results\"\n",
" )\n",
" return df"
]
},
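The helpers above round-trip run metadata through a double-underscore naming scheme: `answer_questions` writes one JSONL file and one Hub config per (model, action type, task), and `score_answers` splits the config name back apart. A small sketch of that convention, using illustrative values:

```python
# Illustrative values; the f-strings are the same ones used by answer_questions.
model_id = "Qwen/Qwen2.5-72B-Instruct"
action_type = "code"
task = "gaia"
date = "2024-12-26"

file_name = f"output/{model_id.replace('/', '__')}__{action_type}__{task}__{date}.jsonl"
config = f"{model_id.replace('/', '__')}__{action_type}__{task}"

# score_answers recovers the metadata by splitting on '__': the trailing two
# fields are action type and task, and the leading parts rejoin into the model id.
*model_parts, parsed_action_type, parsed_task = config.split("__")
assert "/".join(model_parts) == model_id
assert (parsed_action_type, parsed_task) == (action_type, task)
```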
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation dataset"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['gaia', 'math', 'simpleqa']\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>question</th>\n",
" <th>source</th>\n",
" <th>true_answer</th>\n",
" <th>true_reasoning</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>What year was the municipality of Ramiriquí, B...</td>\n",
" <td>SimpleQA</td>\n",
" <td>1541</td>\n",
" <td>['https://en.wikipedia.org/wiki/Ramiriqu%C3%AD...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>In what year did Hjalmar Hvam invent a mechani...</td>\n",
" <td>SimpleQA</td>\n",
" <td>1937</td>\n",
" <td>['https://www.kgw.com/article/features/portlan...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>In which year did Fayaz A. Malik (an Indian ph...</td>\n",
" <td>SimpleQA</td>\n",
" <td>2009</td>\n",
" <td>['https://en.wikipedia.org/wiki/Fayaz_A._Malik...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>In which year was John B. Goodenough elected a...</td>\n",
" <td>SimpleQA</td>\n",
" <td>2010</td>\n",
" <td>['https://en.wikipedia.org/wiki/John_B._Gooden...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>In which year did Atul Gawande earn an M.A. in...</td>\n",
" <td>SimpleQA</td>\n",
" <td>1989</td>\n",
" <td>['https://en.wikipedia.org/wiki/Atul_Gawande',...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" question source true_answer \\\n",
"0 What year was the municipality of Ramiriquí, B... SimpleQA 1541 \n",
"1 In what year did Hjalmar Hvam invent a mechani... SimpleQA 1937 \n",
"2 In which year did Fayaz A. Malik (an Indian ph... SimpleQA 2009 \n",
"3 In which year was John B. Goodenough elected a... SimpleQA 2010 \n",
"4 In which year did Atul Gawande earn an M.A. in... SimpleQA 1989 \n",
"\n",
" true_reasoning \n",
"0 ['https://en.wikipedia.org/wiki/Ramiriqu%C3%AD... \n",
"1 ['https://www.kgw.com/article/features/portlan... \n",
"2 ['https://en.wikipedia.org/wiki/Fayaz_A._Malik... \n",
"3 ['https://en.wikipedia.org/wiki/John_B._Gooden... \n",
"4 ['https://en.wikipedia.org/wiki/Atul_Gawande',... "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"# Choose the tasks to evaluate on:\n",
"# tasks = [\"gaia\"]\n",
"# or evaluate on all tasks: [\"gaia\", \"math\", \"simpleqa\"]\n",
"tasks = datasets.get_dataset_config_names(EVAL_DATASET)\n",
"print(tasks)\n",
"\n",
"\n",
"eval_ds = {task: datasets.load_dataset(EVAL_DATASET, task, split=\"test\") for task in tasks}\n",
"pd.DataFrame(eval_ds[\"simpleqa\"]).head()"
]
},
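Since each task is its own config, a single task can also be pulled without building the full dict (a sketch, assuming access to the public `smolagents-benchmark/benchmark-v1` repo):

```python
import datasets

# Hedged example: load only the "math" task, one of the configs listed above.
math_ds = datasets.load_dataset("smolagents-benchmark/benchmark-v1", "math", split="test")
print(len(math_ds), math_ds.column_names)
```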
{
@@ -412,16 +456,16 @@
"    # \"meta-llama/Llama-3.1-70B-Instruct\",\n",
"]\n",
"\n",
"\n",
"for model_id in open_model_ids:\n",
"    print(f\"Evaluating '{model_id}'...\")\n",
"    # action_type = \"tool-calling\"\n",
"    # agent = ToolCallingAgent(\n",
"    #     tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n",
"    #     model=HfApiModel(model_id),\n",
"    #     max_steps=10,\n",
"    # )\n",
"    # answer_questions(eval_ds, agent, model_id, action_type)\n",
"\n",
" action_type = \"code\"\n", " action_type = \"code\"\n",
" agent = CodeAgent(\n", " agent = CodeAgent(\n",
@ -430,21 +474,19 @@
" additional_authorized_imports=[\"numpy\", \"sympy\"],\n", " additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
" max_steps=10,\n", " max_steps=10,\n",
" )\n", " )\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", " answer_questions(eval_ds, agent, model_id, action_type)\n",
" answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
"\n", "\n",
" # Also evaluate vanilla model\n", " # Also evaluate vanilla model\n",
" action_type = \"vanilla\"\n", " action_type = \"vanilla\"\n",
" llm = HfApiModel(model_id)\n", " llm = HfApiModel(model_id)\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", " answer_questions(eval_ds, llm, model_id, action_type, is_vanilla_llm=True)"
" answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)"
] ]
}, },
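The agent loops append answers under `output/`, `HfApiModel` and the answer upload need Hub credentials, and `GoogleSearchTool` needs a SerpAPI key (a paid SerpAPI account, per the note dropped from the old markdown cell). A hedged setup step, not shown in the diff, that catches these problems before a long run starts:

```python
import os

from huggingface_hub import whoami

os.makedirs("output", exist_ok=True)  # answer_questions appends JSONL files here
print(whoami()["name"])  # raises early if no Hugging Face token is configured
print(bool(os.getenv("SERPAPI_API_KEY")))  # assumed variable name for GoogleSearchTool
```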
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Closed models"
]
},
{
@@ -458,9 +500,10 @@
"\n", "\n",
"litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n", "litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n",
"\n", "\n",
"\n",
"for model_id in litellm_model_ids:\n", "for model_id in litellm_model_ids:\n",
" print(f\"Evaluating '{model_id}'...\")\n", " print(f\"Evaluating '{model_id}'...\")\n",
" action_type = \"tool_calling\"\n", " action_type = \"tool-calling\"\n",
" agent = ToolCallingAgent(\n", " agent = ToolCallingAgent(\n",
" tools=[\n", " tools=[\n",
" GoogleSearchTool(),\n", " GoogleSearchTool(),\n",
@ -470,8 +513,7 @@
" model=LiteLLMModel(model_id),\n", " model=LiteLLMModel(model_id),\n",
" max_steps=10,\n", " max_steps=10,\n",
" )\n", " )\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", " answer_questions(eval_ds, agent, model_id, action_type)\n",
" answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
"\n", "\n",
" action_type = \"code\"\n", " action_type = \"code\"\n",
" agent = CodeAgent(\n", " agent = CodeAgent(\n",
@ -480,14 +522,12 @@
" additional_authorized_imports=[\"numpy\", \"sympy\"],\n", " additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
" max_steps=10,\n", " max_steps=10,\n",
" )\n", " )\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", " answer_questions(eval_ds, agent, model_id, action_type)\n",
" answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
"\n", "\n",
" # Also evaluate vanilla model\n", " # Also evaluate vanilla model\n",
" action_type = \"vanilla\"\n", " action_type = \"vanilla\"\n",
" llm = LiteLLMModel(model_id)\n", " llm = LiteLLMModel(model_id)\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", " answer_questions(eval_ds, llm, model_id, action_type, is_vanilla_llm=True)"
" answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)"
] ]
}, },
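LiteLLM resolves `gpt-4o` to OpenAI and `anthropic/claude-3-5-sonnet-latest` to Anthropic, so both providers' keys must be in the environment; the notebook already imports `load_dotenv`, making a local `.env` file the natural place for them (assumed setup, not part of the commit):

```python
import os

from dotenv import load_dotenv

load_dotenv()  # reads a local .env file, e.g. OPENAI_API_KEY=... lines
for key in ("OPENAI_API_KEY", "ANTHROPIC_API_KEY"):
    assert os.getenv(key), f"{key} is missing; the LiteLLM agents above would fail"
```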
{
@@ -539,58 +579,153 @@
"execution_count": 9, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of answers_subsets 54\n",
"Example of answers_subset Qwen__Qwen2.5-72B-Instruct__code__gaia\n"
]
},
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_27085/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n", "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\n" " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n"
] ]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>model_id</th>\n",
" <th>agent_action_type</th>\n",
" <th>source</th>\n",
" <th>acc</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
" <td>code</td>\n",
" <td>GAIA</td>\n",
" <td>28.12</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
" <td>code</td>\n",
" <td>MATH</td>\n",
" <td>76.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
" <td>code</td>\n",
" <td>SimpleQA</td>\n",
" <td>88.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
" <td>vanilla</td>\n",
" <td>GAIA</td>\n",
" <td>6.25</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
" <td>vanilla</td>\n",
" <td>MATH</td>\n",
" <td>30.00</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" model_id agent_action_type source acc\n",
"0 Qwen/Qwen2.5-72B-Instruct code GAIA 28.12\n",
"1 Qwen/Qwen2.5-72B-Instruct code MATH 76.00\n",
"2 Qwen/Qwen2.5-72B-Instruct code SimpleQA 88.00\n",
"3 Qwen/Qwen2.5-72B-Instruct vanilla GAIA 6.25\n",
"4 Qwen/Qwen2.5-72B-Instruct vanilla MATH 30.00"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import glob\n", "import datasets\n",
"\n",
"import pandas as pd\n", "import pandas as pd\n",
"\n", "\n",
"\n", "\n",
"res = []\n", "# Choose the answers subsets to score:\n",
"for file_path in glob.glob(\"output/*.jsonl\"):\n", "# answers_subsets = [\"meta-llama__Llama-3.1-8B-Instruct__code__gaia\"]\n",
" data = []\n", "# or get all the answers subsets present in the ANSWERS_DATASET\n",
" with open(file_path) as f:\n", "answers_subsets = datasets.get_dataset_config_names(ANSWERS_DATASET)\n",
" for line in f:\n", "print(\"Number of answers_subsets\", len(answers_subsets))\n",
" try:\n", "print(\"Example of answers_subset\", answers_subsets[0])\n",
" # Use standard json module instead of pandas.json to handle large numbers better\n",
" record = json.loads(line)\n",
" data.append(record)\n",
" except json.JSONDecodeError as e:\n",
" print(f\"Error parsing line in {file_path}: {e}\")\n",
" continue\n",
"\n",
" try:\n",
" smoldf = pd.DataFrame(data)\n",
" smoldf[\"action_type\"] = \"vanilla\" if \"-vanilla-\" in file_path else \"code\"\n",
" res.append(smoldf)\n",
" except Exception as e:\n",
" print(f\"Error creating DataFrame from {file_path}: {e}\")\n",
" continue\n",
"\n",
"result_df = pd.concat(res)\n",
"\n", "\n",
"\n", "\n",
"def get_correct(row):\n", "result_df = score_answers(answers_subsets)\n",
" if row[\"source\"] == \"MATH\": # Checks the last number in answer\n", "result_df[\"acc\"] = (result_df[\"acc\"] * 100).round(2)\n",
" numbers_answer = extract_numbers(str(row[\"answer\"]))\n", "result_df.head()"
" if len(numbers_answer) == 0:\n",
" return False\n",
" return float(numbers_answer[-1]) == float(row[\"true_answer\"])\n",
" else:\n",
" return get_question_score_gaia(str(row[\"answer\"]), str(row[\"true_answer\"]))\n",
"\n",
"\n",
"result_df[\"correct\"] = result_df.apply(get_correct, axis=1)\n",
"\n",
"result_df = (result_df.groupby([\"model_id\", \"source\", \"action_type\"])[[\"correct\"]].mean() * 100).round(1).reset_index()"
]
},
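With `acc` rescaled to percentages, pivoting the results gives the model × task comparison at a glance (an optional follow-up, not part of the commit):

```python
# Hedged example: one row per model/task pair, one column per agent action type.
pivot_df = result_df.pivot_table(
    index=["model_id", "source"],
    columns="agent_action_type",
    values="acc",
).reset_index()
print(pivot_df)
```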
{