Update benchmark with Hub datasets (#412)
* Use smolagents-benchmark Hub datasets * Push results to Hub * Fix style * Add Constants section at the top * Set DATE as constant
This commit is contained in:
parent
aa55f137e5
commit
6d0e4e49fc
|
@ -19,187 +19,40 @@
|
||||||
"!pip install -e .. datasets sympy numpy matplotlib seaborn -q # Install dev version of smolagents + some packages"
|
"!pip install -e .. datasets sympy numpy matplotlib seaborn -q # Install dev version of smolagents + some packages"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
|
||||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/html": [
|
|
||||||
"<div>\n",
|
|
||||||
"<style scoped>\n",
|
|
||||||
" .dataframe tbody tr th:only-of-type {\n",
|
|
||||||
" vertical-align: middle;\n",
|
|
||||||
" }\n",
|
|
||||||
"\n",
|
|
||||||
" .dataframe tbody tr th {\n",
|
|
||||||
" vertical-align: top;\n",
|
|
||||||
" }\n",
|
|
||||||
"\n",
|
|
||||||
" .dataframe thead th {\n",
|
|
||||||
" text-align: right;\n",
|
|
||||||
" }\n",
|
|
||||||
"</style>\n",
|
|
||||||
"<table border=\"1\" class=\"dataframe\">\n",
|
|
||||||
" <thead>\n",
|
|
||||||
" <tr style=\"text-align: right;\">\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th>question</th>\n",
|
|
||||||
" <th>source</th>\n",
|
|
||||||
" <th>true_answer</th>\n",
|
|
||||||
" <th>true_reasoning</th>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" </thead>\n",
|
|
||||||
" <tbody>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>0</th>\n",
|
|
||||||
" <td>If Eliud Kipchoge could maintain his record-ma...</td>\n",
|
|
||||||
" <td>GAIA</td>\n",
|
|
||||||
" <td>17</td>\n",
|
|
||||||
" <td>None</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>1</th>\n",
|
|
||||||
" <td>How many studio albums were published by Merce...</td>\n",
|
|
||||||
" <td>GAIA</td>\n",
|
|
||||||
" <td>3</td>\n",
|
|
||||||
" <td>None</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>2</th>\n",
|
|
||||||
" <td>Here's a fun riddle that I think you'll enjoy....</td>\n",
|
|
||||||
" <td>GAIA</td>\n",
|
|
||||||
" <td>3</td>\n",
|
|
||||||
" <td>None</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>3</th>\n",
|
|
||||||
" <td>My family reunion is this week, and I was assi...</td>\n",
|
|
||||||
" <td>GAIA</td>\n",
|
|
||||||
" <td>2</td>\n",
|
|
||||||
" <td>None</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>4</th>\n",
|
|
||||||
" <td>In Emily Midkiff's June 2014 article in a jour...</td>\n",
|
|
||||||
" <td>GAIA</td>\n",
|
|
||||||
" <td>fluffy</td>\n",
|
|
||||||
" <td>None</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>...</th>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>127</th>\n",
|
|
||||||
" <td>What year was the municipality of San Carlos, ...</td>\n",
|
|
||||||
" <td>SimpleQA</td>\n",
|
|
||||||
" <td>1786</td>\n",
|
|
||||||
" <td>['https://en.wikipedia.org/wiki/San_Carlos,_An...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>128</th>\n",
|
|
||||||
" <td>In which year was Maria Elena Walsh named Illu...</td>\n",
|
|
||||||
" <td>SimpleQA</td>\n",
|
|
||||||
" <td>1985</td>\n",
|
|
||||||
" <td>['https://en.wikipedia.org/wiki/Mar%C3%ADa_Ele...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>129</th>\n",
|
|
||||||
" <td>What is the durability of the Istarelle spear ...</td>\n",
|
|
||||||
" <td>SimpleQA</td>\n",
|
|
||||||
" <td>800</td>\n",
|
|
||||||
" <td>['http://demonssouls.wikidot.com/spear', 'http...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>130</th>\n",
|
|
||||||
" <td>What is the number of the executive order that...</td>\n",
|
|
||||||
" <td>SimpleQA</td>\n",
|
|
||||||
" <td>7034</td>\n",
|
|
||||||
" <td>['https://www.loc.gov/collections/federal-thea...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>131</th>\n",
|
|
||||||
" <td>Within plus or minus one minute, when was Marq...</td>\n",
|
|
||||||
" <td>SimpleQA</td>\n",
|
|
||||||
" <td>77</td>\n",
|
|
||||||
" <td>['https://www.fifa.com/fifaplus/en/match-centr...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" </tbody>\n",
|
|
||||||
"</table>\n",
|
|
||||||
"<p>132 rows × 4 columns</p>\n",
|
|
||||||
"</div>"
|
|
||||||
],
|
|
||||||
"text/plain": [
|
|
||||||
" question source true_answer \\\n",
|
|
||||||
"0 If Eliud Kipchoge could maintain his record-ma... GAIA 17 \n",
|
|
||||||
"1 How many studio albums were published by Merce... GAIA 3 \n",
|
|
||||||
"2 Here's a fun riddle that I think you'll enjoy.... GAIA 3 \n",
|
|
||||||
"3 My family reunion is this week, and I was assi... GAIA 2 \n",
|
|
||||||
"4 In Emily Midkiff's June 2014 article in a jour... GAIA fluffy \n",
|
|
||||||
".. ... ... ... \n",
|
|
||||||
"127 What year was the municipality of San Carlos, ... SimpleQA 1786 \n",
|
|
||||||
"128 In which year was Maria Elena Walsh named Illu... SimpleQA 1985 \n",
|
|
||||||
"129 What is the durability of the Istarelle spear ... SimpleQA 800 \n",
|
|
||||||
"130 What is the number of the executive order that... SimpleQA 7034 \n",
|
|
||||||
"131 Within plus or minus one minute, when was Marq... SimpleQA 77 \n",
|
|
||||||
"\n",
|
|
||||||
" true_reasoning \n",
|
|
||||||
"0 None \n",
|
|
||||||
"1 None \n",
|
|
||||||
"2 None \n",
|
|
||||||
"3 None \n",
|
|
||||||
"4 None \n",
|
|
||||||
".. ... \n",
|
|
||||||
"127 ['https://en.wikipedia.org/wiki/San_Carlos,_An... \n",
|
|
||||||
"128 ['https://en.wikipedia.org/wiki/Mar%C3%ADa_Ele... \n",
|
|
||||||
"129 ['http://demonssouls.wikidot.com/spear', 'http... \n",
|
|
||||||
"130 ['https://www.loc.gov/collections/federal-thea... \n",
|
|
||||||
"131 ['https://www.fifa.com/fifaplus/en/match-centr... \n",
|
|
||||||
"\n",
|
|
||||||
"[132 rows x 4 columns]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 4,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import datasets\n",
|
|
||||||
"import pandas as pd\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"eval_ds = datasets.load_dataset(\"m-ric/smol_agents_benchmark\")[\"test\"]\n",
|
|
||||||
"pd.DataFrame(eval_ds)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"### Define utilities and tools\n",
|
"## Constants and utilities/tools"
|
||||||
"To run the SERPAPI tool, you will need to have a [SerpAPI](https://serpapi.com/dashboard) API key: for this you need a paid account."
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# Benchmark date\n",
|
||||||
|
"# - set a concrete date:\n",
|
||||||
|
"DATE = \"2024-12-26\"\n",
|
||||||
|
"# - or use default: today\n",
|
||||||
|
"# DATE = None\n",
|
||||||
|
"\n",
|
||||||
|
"# Evaluation dataset\n",
|
||||||
|
"EVAL_DATASET = \"smolagents-benchmark/benchmark-v1\"\n",
|
||||||
|
"\n",
|
||||||
|
"# Answers dataset: it must be a gated dataset; required to score the answers\n",
|
||||||
|
"ANSWERS_DATASET = \"smolagents-benchmark/answers\"\n",
|
||||||
|
"# Whether to push the answers dataset to the Hub\n",
|
||||||
|
"PUSH_ANSWERS_DATASET_TO_HUB = True\n",
|
||||||
|
"\n",
|
||||||
|
"# Results dataset\n",
|
||||||
|
"RESULTS_DATASET = \"smolagents-benchmark/results\"\n",
|
||||||
|
"# Whether to push the results dataset to the Hub\n",
|
||||||
|
"PUSH_RESULTS_DATASET_TO_HUB = True\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"import datetime\n",
|
||||||
"import json\n",
|
"import json\n",
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"import re\n",
|
"import re\n",
|
||||||
|
@ -208,6 +61,7 @@
|
||||||
"import warnings\n",
|
"import warnings\n",
|
||||||
"from typing import List\n",
|
"from typing import List\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"import datasets\n",
|
||||||
"from dotenv import load_dotenv\n",
|
"from dotenv import load_dotenv\n",
|
||||||
"from tqdm import tqdm\n",
|
"from tqdm import tqdm\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -234,14 +88,27 @@
|
||||||
" return str(obj)\n",
|
" return str(obj)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def answer_questions(eval_ds, file_name, agent, model_id, action_type, is_vanilla_llm=False):\n",
|
"def answer_questions(\n",
|
||||||
|
" eval_ds,\n",
|
||||||
|
" agent,\n",
|
||||||
|
" model_id,\n",
|
||||||
|
" action_type,\n",
|
||||||
|
" is_vanilla_llm=False,\n",
|
||||||
|
" date=DATE,\n",
|
||||||
|
" output_dir=\"output\",\n",
|
||||||
|
" push_to_hub_dataset=ANSWERS_DATASET if PUSH_ANSWERS_DATASET_TO_HUB else None,\n",
|
||||||
|
"):\n",
|
||||||
|
" date = date or datetime.date.today().isoformat()\n",
|
||||||
|
"\n",
|
||||||
|
" for task in eval_ds:\n",
|
||||||
|
" file_name = f\"output/{model_id.replace('/', '__')}__{action_type}__{task}__{date}.jsonl\"\n",
|
||||||
" answered_questions = []\n",
|
" answered_questions = []\n",
|
||||||
" if os.path.exists(file_name):\n",
|
" if os.path.exists(file_name):\n",
|
||||||
" with open(file_name, \"r\") as f:\n",
|
" with open(file_name, \"r\") as f:\n",
|
||||||
" for line in f:\n",
|
" for line in f:\n",
|
||||||
" answered_questions.append(json.loads(line)[\"question\"])\n",
|
" answered_questions.append(json.loads(line)[\"question\"])\n",
|
||||||
"\n",
|
"\n",
|
||||||
" for _, example in tqdm(enumerate(eval_ds), total=len(eval_ds)):\n",
|
" for _, example in tqdm(enumerate(eval_ds[task]), total=len(eval_ds[task])):\n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
" question = example[\"question\"]\n",
|
" question = example[\"question\"]\n",
|
||||||
" if example[\"source\"] == \"SimpleQA\":\n",
|
" if example[\"source\"] == \"SimpleQA\":\n",
|
||||||
|
@ -289,6 +156,18 @@
|
||||||
" except Exception as e:\n",
|
" except Exception as e:\n",
|
||||||
" print(\"Failed:\", e)\n",
|
" print(\"Failed:\", e)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
" if push_to_hub_dataset:\n",
|
||||||
|
" ds = datasets.Dataset.from_pandas(pd.read_json(file_name, lines=True), split=\"test\", preserve_index=False)\n",
|
||||||
|
" config = f\"{model_id.replace('/', '__')}__{action_type}__{task}\"\n",
|
||||||
|
" data_dir = f\"{model_id}/{action_type}/{task}/{date}\"\n",
|
||||||
|
" ds.push_to_hub(\n",
|
||||||
|
" push_to_hub_dataset,\n",
|
||||||
|
" config_name=config,\n",
|
||||||
|
" data_dir=data_dir,\n",
|
||||||
|
" split=\"test\",\n",
|
||||||
|
" commit_message=f\"Upload {config}\",\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def normalize_number_str(number_str: str) -> float:\n",
|
"def normalize_number_str(number_str: str) -> float:\n",
|
||||||
" # we replace these common units and commas to allow\n",
|
" # we replace these common units and commas to allow\n",
|
||||||
|
@ -382,7 +261,172 @@
|
||||||
" return all(comparisons)\n",
|
" return all(comparisons)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" else: # if gt is a str\n",
|
" else: # if gt is a str\n",
|
||||||
" return normalize_str(model_answer) == normalize_str(ground_truth)"
|
" return normalize_str(model_answer) == normalize_str(ground_truth)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def get_correct(row):\n",
|
||||||
|
" if row[\"source\"] == \"MATH\": # Checks the last number in answer\n",
|
||||||
|
" numbers_answer = extract_numbers(str(row[\"answer\"]))\n",
|
||||||
|
" if len(numbers_answer) == 0:\n",
|
||||||
|
" return False\n",
|
||||||
|
" return float(numbers_answer[-1]) == float(row[\"true_answer\"])\n",
|
||||||
|
" else:\n",
|
||||||
|
" return get_question_score_gaia(str(row[\"answer\"]), str(row[\"true_answer\"]))\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def score_answers(\n",
|
||||||
|
" answers_subsets,\n",
|
||||||
|
" answers_dataset=ANSWERS_DATASET,\n",
|
||||||
|
" date=DATE,\n",
|
||||||
|
" push_to_hub_dataset=RESULTS_DATASET if PUSH_RESULTS_DATASET_TO_HUB else None,\n",
|
||||||
|
" set_default=True,\n",
|
||||||
|
"):\n",
|
||||||
|
" if not answers_dataset:\n",
|
||||||
|
" raise ValueError(\"Pass 'answers_dataset' to load the answers from it\")\n",
|
||||||
|
" date = date or datetime.date.today().isoformat()\n",
|
||||||
|
" results = []\n",
|
||||||
|
" for answers_subset in answers_subsets:\n",
|
||||||
|
" *model_id, action_type, task = answers_subset.split(\"__\")\n",
|
||||||
|
" model_id = \"/\".join(model_id)\n",
|
||||||
|
" ds = datasets.load_dataset(answers_dataset, answers_subset, split=\"test\")\n",
|
||||||
|
" df = ds.to_pandas()\n",
|
||||||
|
" df[\"correct\"] = df.apply(get_correct, axis=1)\n",
|
||||||
|
" acc = df[\"correct\"].mean().item()\n",
|
||||||
|
" result = df.loc[0, [\"model_id\", \"agent_action_type\", \"source\"]].to_dict()\n",
|
||||||
|
" result[\"acc\"] = acc\n",
|
||||||
|
" results.append(result)\n",
|
||||||
|
" df = pd.DataFrame(results)\n",
|
||||||
|
"\n",
|
||||||
|
" if push_to_hub_dataset:\n",
|
||||||
|
" ds = datasets.Dataset.from_pandas(df)\n",
|
||||||
|
" config = date\n",
|
||||||
|
" set_default = set_default\n",
|
||||||
|
" ds.push_to_hub(\n",
|
||||||
|
" push_to_hub_dataset, config_name=config, set_default=set_default, commit_message=f\"Upload {config} results\"\n",
|
||||||
|
" )\n",
|
||||||
|
" return df"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Evaluation dataset"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"['gaia', 'math', 'simpleqa']\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
"<div>\n",
|
||||||
|
"<style scoped>\n",
|
||||||
|
" .dataframe tbody tr th:only-of-type {\n",
|
||||||
|
" vertical-align: middle;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe tbody tr th {\n",
|
||||||
|
" vertical-align: top;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe thead th {\n",
|
||||||
|
" text-align: right;\n",
|
||||||
|
" }\n",
|
||||||
|
"</style>\n",
|
||||||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
||||||
|
" <thead>\n",
|
||||||
|
" <tr style=\"text-align: right;\">\n",
|
||||||
|
" <th></th>\n",
|
||||||
|
" <th>question</th>\n",
|
||||||
|
" <th>source</th>\n",
|
||||||
|
" <th>true_answer</th>\n",
|
||||||
|
" <th>true_reasoning</th>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </thead>\n",
|
||||||
|
" <tbody>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>0</th>\n",
|
||||||
|
" <td>What year was the municipality of Ramiriquí, B...</td>\n",
|
||||||
|
" <td>SimpleQA</td>\n",
|
||||||
|
" <td>1541</td>\n",
|
||||||
|
" <td>['https://en.wikipedia.org/wiki/Ramiriqu%C3%AD...</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>1</th>\n",
|
||||||
|
" <td>In what year did Hjalmar Hvam invent a mechani...</td>\n",
|
||||||
|
" <td>SimpleQA</td>\n",
|
||||||
|
" <td>1937</td>\n",
|
||||||
|
" <td>['https://www.kgw.com/article/features/portlan...</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>2</th>\n",
|
||||||
|
" <td>In which year did Fayaz A. Malik (an Indian ph...</td>\n",
|
||||||
|
" <td>SimpleQA</td>\n",
|
||||||
|
" <td>2009</td>\n",
|
||||||
|
" <td>['https://en.wikipedia.org/wiki/Fayaz_A._Malik...</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>3</th>\n",
|
||||||
|
" <td>In which year was John B. Goodenough elected a...</td>\n",
|
||||||
|
" <td>SimpleQA</td>\n",
|
||||||
|
" <td>2010</td>\n",
|
||||||
|
" <td>['https://en.wikipedia.org/wiki/John_B._Gooden...</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>4</th>\n",
|
||||||
|
" <td>In which year did Atul Gawande earn an M.A. in...</td>\n",
|
||||||
|
" <td>SimpleQA</td>\n",
|
||||||
|
" <td>1989</td>\n",
|
||||||
|
" <td>['https://en.wikipedia.org/wiki/Atul_Gawande',...</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </tbody>\n",
|
||||||
|
"</table>\n",
|
||||||
|
"</div>"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
" question source true_answer \\\n",
|
||||||
|
"0 What year was the municipality of Ramiriquí, B... SimpleQA 1541 \n",
|
||||||
|
"1 In what year did Hjalmar Hvam invent a mechani... SimpleQA 1937 \n",
|
||||||
|
"2 In which year did Fayaz A. Malik (an Indian ph... SimpleQA 2009 \n",
|
||||||
|
"3 In which year was John B. Goodenough elected a... SimpleQA 2010 \n",
|
||||||
|
"4 In which year did Atul Gawande earn an M.A. in... SimpleQA 1989 \n",
|
||||||
|
"\n",
|
||||||
|
" true_reasoning \n",
|
||||||
|
"0 ['https://en.wikipedia.org/wiki/Ramiriqu%C3%AD... \n",
|
||||||
|
"1 ['https://www.kgw.com/article/features/portlan... \n",
|
||||||
|
"2 ['https://en.wikipedia.org/wiki/Fayaz_A._Malik... \n",
|
||||||
|
"3 ['https://en.wikipedia.org/wiki/John_B._Gooden... \n",
|
||||||
|
"4 ['https://en.wikipedia.org/wiki/Atul_Gawande',... "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# Choose the tasks to evaluate on:\n",
|
||||||
|
"# tasks = [\"gaia\"]\n",
|
||||||
|
"# or evaluate on all tasks: [\"gaia\", \"math\", \"simpleqa\"]\n",
|
||||||
|
"tasks = datasets.get_dataset_config_names(EVAL_DATASET)\n",
|
||||||
|
"print(tasks)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"eval_ds = {task: datasets.load_dataset(EVAL_DATASET, task, split=\"test\") for task in tasks}\n",
|
||||||
|
"pd.DataFrame(eval_ds[\"simpleqa\"]).head()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -412,16 +456,16 @@
|
||||||
" # \"meta-llama/Llama-3.1-70B-Instruct\",\n",
|
" # \"meta-llama/Llama-3.1-70B-Instruct\",\n",
|
||||||
"]\n",
|
"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
"for model_id in open_model_ids:\n",
|
"for model_id in open_model_ids:\n",
|
||||||
" print(f\"Evaluating '{model_id}'...\")\n",
|
" print(f\"Evaluating '{model_id}'...\")\n",
|
||||||
" # action_type = \"tool_calling\"\n",
|
" # action_type = \"tool-calling\"\n",
|
||||||
" # agent = ToolCallingAgent(\n",
|
" # agent = ToolCallingAgent(\n",
|
||||||
" # tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n",
|
" # tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n",
|
||||||
" # model=HfApiModel(model_id),\n",
|
" # model=HfApiModel(model_id),\n",
|
||||||
" # max_steps=10,\n",
|
" # max_steps=10,\n",
|
||||||
" # )\n",
|
" # )\n",
|
||||||
" # file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
" # answer_questions(eval_ds, agent, model_id, action_type)\n",
|
||||||
" # answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" action_type = \"code\"\n",
|
" action_type = \"code\"\n",
|
||||||
" agent = CodeAgent(\n",
|
" agent = CodeAgent(\n",
|
||||||
|
@ -430,21 +474,19 @@
|
||||||
" additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
|
" additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
|
||||||
" max_steps=10,\n",
|
" max_steps=10,\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
" answer_questions(eval_ds, agent, model_id, action_type)\n",
|
||||||
" answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" # Also evaluate vanilla model\n",
|
" # Also evaluate vanilla model\n",
|
||||||
" action_type = \"vanilla\"\n",
|
" action_type = \"vanilla\"\n",
|
||||||
" llm = HfApiModel(model_id)\n",
|
" llm = HfApiModel(model_id)\n",
|
||||||
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
" answer_questions(eval_ds, llm, model_id, action_type, is_vanilla_llm=True)"
|
||||||
" answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Closed models"
|
"### Closed models"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -458,9 +500,10 @@
|
||||||
"\n",
|
"\n",
|
||||||
"litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n",
|
"litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
"for model_id in litellm_model_ids:\n",
|
"for model_id in litellm_model_ids:\n",
|
||||||
" print(f\"Evaluating '{model_id}'...\")\n",
|
" print(f\"Evaluating '{model_id}'...\")\n",
|
||||||
" action_type = \"tool_calling\"\n",
|
" action_type = \"tool-calling\"\n",
|
||||||
" agent = ToolCallingAgent(\n",
|
" agent = ToolCallingAgent(\n",
|
||||||
" tools=[\n",
|
" tools=[\n",
|
||||||
" GoogleSearchTool(),\n",
|
" GoogleSearchTool(),\n",
|
||||||
|
@ -470,8 +513,7 @@
|
||||||
" model=LiteLLMModel(model_id),\n",
|
" model=LiteLLMModel(model_id),\n",
|
||||||
" max_steps=10,\n",
|
" max_steps=10,\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
" answer_questions(eval_ds, agent, model_id, action_type)\n",
|
||||||
" answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" action_type = \"code\"\n",
|
" action_type = \"code\"\n",
|
||||||
" agent = CodeAgent(\n",
|
" agent = CodeAgent(\n",
|
||||||
|
@ -480,14 +522,12 @@
|
||||||
" additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
|
" additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
|
||||||
" max_steps=10,\n",
|
" max_steps=10,\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
" answer_questions(eval_ds, agent, model_id, action_type)\n",
|
||||||
" answer_questions(eval_ds, file_name, agent, model_id, action_type)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" # Also evaluate vanilla model\n",
|
" # Also evaluate vanilla model\n",
|
||||||
" action_type = \"vanilla\"\n",
|
" action_type = \"vanilla\"\n",
|
||||||
" llm = LiteLLMModel(model_id)\n",
|
" llm = LiteLLMModel(model_id)\n",
|
||||||
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
|
" answer_questions(eval_ds, llm, model_id, action_type, is_vanilla_llm=True)"
|
||||||
" answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -539,58 +579,153 @@
|
||||||
"execution_count": 9,
|
"execution_count": 9,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Number of answers_subsets 54\n",
|
||||||
|
"Example of answers_subset Qwen__Qwen2.5-72B-Instruct__code__gaia\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_27085/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n",
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
" warnings.warn(\n"
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n",
|
||||||
|
"/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n",
|
||||||
|
" warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
"<div>\n",
|
||||||
|
"<style scoped>\n",
|
||||||
|
" .dataframe tbody tr th:only-of-type {\n",
|
||||||
|
" vertical-align: middle;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe tbody tr th {\n",
|
||||||
|
" vertical-align: top;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe thead th {\n",
|
||||||
|
" text-align: right;\n",
|
||||||
|
" }\n",
|
||||||
|
"</style>\n",
|
||||||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
||||||
|
" <thead>\n",
|
||||||
|
" <tr style=\"text-align: right;\">\n",
|
||||||
|
" <th></th>\n",
|
||||||
|
" <th>model_id</th>\n",
|
||||||
|
" <th>agent_action_type</th>\n",
|
||||||
|
" <th>source</th>\n",
|
||||||
|
" <th>acc</th>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </thead>\n",
|
||||||
|
" <tbody>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>0</th>\n",
|
||||||
|
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
|
||||||
|
" <td>code</td>\n",
|
||||||
|
" <td>GAIA</td>\n",
|
||||||
|
" <td>28.12</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>1</th>\n",
|
||||||
|
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
|
||||||
|
" <td>code</td>\n",
|
||||||
|
" <td>MATH</td>\n",
|
||||||
|
" <td>76.00</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>2</th>\n",
|
||||||
|
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
|
||||||
|
" <td>code</td>\n",
|
||||||
|
" <td>SimpleQA</td>\n",
|
||||||
|
" <td>88.00</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>3</th>\n",
|
||||||
|
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
|
||||||
|
" <td>vanilla</td>\n",
|
||||||
|
" <td>GAIA</td>\n",
|
||||||
|
" <td>6.25</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>4</th>\n",
|
||||||
|
" <td>Qwen/Qwen2.5-72B-Instruct</td>\n",
|
||||||
|
" <td>vanilla</td>\n",
|
||||||
|
" <td>MATH</td>\n",
|
||||||
|
" <td>30.00</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </tbody>\n",
|
||||||
|
"</table>\n",
|
||||||
|
"</div>"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
" model_id agent_action_type source acc\n",
|
||||||
|
"0 Qwen/Qwen2.5-72B-Instruct code GAIA 28.12\n",
|
||||||
|
"1 Qwen/Qwen2.5-72B-Instruct code MATH 76.00\n",
|
||||||
|
"2 Qwen/Qwen2.5-72B-Instruct code SimpleQA 88.00\n",
|
||||||
|
"3 Qwen/Qwen2.5-72B-Instruct vanilla GAIA 6.25\n",
|
||||||
|
"4 Qwen/Qwen2.5-72B-Instruct vanilla MATH 30.00"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import glob\n",
|
"import datasets\n",
|
||||||
"\n",
|
|
||||||
"import pandas as pd\n",
|
"import pandas as pd\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"res = []\n",
|
"# Choose the answers subsets to score:\n",
|
||||||
"for file_path in glob.glob(\"output/*.jsonl\"):\n",
|
"# answers_subsets = [\"meta-llama__Llama-3.1-8B-Instruct__code__gaia\"]\n",
|
||||||
" data = []\n",
|
"# or get all the answers subsets present in the ANSWERS_DATASET\n",
|
||||||
" with open(file_path) as f:\n",
|
"answers_subsets = datasets.get_dataset_config_names(ANSWERS_DATASET)\n",
|
||||||
" for line in f:\n",
|
"print(\"Number of answers_subsets\", len(answers_subsets))\n",
|
||||||
" try:\n",
|
"print(\"Example of answers_subset\", answers_subsets[0])\n",
|
||||||
" # Use standard json module instead of pandas.json to handle large numbers better\n",
|
|
||||||
" record = json.loads(line)\n",
|
|
||||||
" data.append(record)\n",
|
|
||||||
" except json.JSONDecodeError as e:\n",
|
|
||||||
" print(f\"Error parsing line in {file_path}: {e}\")\n",
|
|
||||||
" continue\n",
|
|
||||||
"\n",
|
|
||||||
" try:\n",
|
|
||||||
" smoldf = pd.DataFrame(data)\n",
|
|
||||||
" smoldf[\"action_type\"] = \"vanilla\" if \"-vanilla-\" in file_path else \"code\"\n",
|
|
||||||
" res.append(smoldf)\n",
|
|
||||||
" except Exception as e:\n",
|
|
||||||
" print(f\"Error creating DataFrame from {file_path}: {e}\")\n",
|
|
||||||
" continue\n",
|
|
||||||
"\n",
|
|
||||||
"result_df = pd.concat(res)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def get_correct(row):\n",
|
"result_df = score_answers(answers_subsets)\n",
|
||||||
" if row[\"source\"] == \"MATH\": # Checks the last number in answer\n",
|
"result_df[\"acc\"] = (result_df[\"acc\"] * 100).round(2)\n",
|
||||||
" numbers_answer = extract_numbers(str(row[\"answer\"]))\n",
|
"result_df.head()"
|
||||||
" if len(numbers_answer) == 0:\n",
|
|
||||||
" return False\n",
|
|
||||||
" return float(numbers_answer[-1]) == float(row[\"true_answer\"])\n",
|
|
||||||
" else:\n",
|
|
||||||
" return get_question_score_gaia(str(row[\"answer\"]), str(row[\"true_answer\"]))\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"result_df[\"correct\"] = result_df.apply(get_correct, axis=1)\n",
|
|
||||||
"\n",
|
|
||||||
"result_df = (result_df.groupby([\"model_id\", \"source\", \"action_type\"])[[\"correct\"]].mean() * 100).round(1).reset_index()"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
Loading…
Reference in New Issue