From 6d0e4e49fc329a17e6ff33bd1b7933154361442d Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Thu, 30 Jan 2025 19:21:32 +0100 Subject: [PATCH] Update benchmark with Hub datasets (#412) * Use smolagents-benchmark Hub datasets * Push results to Hub * Fix style * Add Constants section at the top * Set DATE as constant --- examples/benchmark.ipynb | 685 +++++++++++++++++++++++---------------- 1 file changed, 410 insertions(+), 275 deletions(-) diff --git a/examples/benchmark.ipynb b/examples/benchmark.ipynb index 065adce..bd3e11a 100644 --- a/examples/benchmark.ipynb +++ b/examples/benchmark.ipynb @@ -16,190 +16,43 @@ } ], "source": [ - "!pip install -e .. datasets sympy numpy matplotlib seaborn -q # Install dev version of smolagents + some packages" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
questionsourcetrue_answertrue_reasoning
0If Eliud Kipchoge could maintain his record-ma...GAIA17None
1How many studio albums were published by Merce...GAIA3None
2Here's a fun riddle that I think you'll enjoy....GAIA3None
3My family reunion is this week, and I was assi...GAIA2None
4In Emily Midkiff's June 2014 article in a jour...GAIAfluffyNone
...............
127What year was the municipality of San Carlos, ...SimpleQA1786['https://en.wikipedia.org/wiki/San_Carlos,_An...
128In which year was Maria Elena Walsh named Illu...SimpleQA1985['https://en.wikipedia.org/wiki/Mar%C3%ADa_Ele...
129What is the durability of the Istarelle spear ...SimpleQA800['http://demonssouls.wikidot.com/spear', 'http...
130What is the number of the executive order that...SimpleQA7034['https://www.loc.gov/collections/federal-thea...
131Within plus or minus one minute, when was Marq...SimpleQA77['https://www.fifa.com/fifaplus/en/match-centr...
\n", - "

132 rows × 4 columns

\n", - "
" - ], - "text/plain": [ - " question source true_answer \\\n", - "0 If Eliud Kipchoge could maintain his record-ma... GAIA 17 \n", - "1 How many studio albums were published by Merce... GAIA 3 \n", - "2 Here's a fun riddle that I think you'll enjoy.... GAIA 3 \n", - "3 My family reunion is this week, and I was assi... GAIA 2 \n", - "4 In Emily Midkiff's June 2014 article in a jour... GAIA fluffy \n", - ".. ... ... ... \n", - "127 What year was the municipality of San Carlos, ... SimpleQA 1786 \n", - "128 In which year was Maria Elena Walsh named Illu... SimpleQA 1985 \n", - "129 What is the durability of the Istarelle spear ... SimpleQA 800 \n", - "130 What is the number of the executive order that... SimpleQA 7034 \n", - "131 Within plus or minus one minute, when was Marq... SimpleQA 77 \n", - "\n", - " true_reasoning \n", - "0 None \n", - "1 None \n", - "2 None \n", - "3 None \n", - "4 None \n", - ".. ... \n", - "127 ['https://en.wikipedia.org/wiki/San_Carlos,_An... \n", - "128 ['https://en.wikipedia.org/wiki/Mar%C3%ADa_Ele... \n", - "129 ['http://demonssouls.wikidot.com/spear', 'http... \n", - "130 ['https://www.loc.gov/collections/federal-thea... \n", - "131 ['https://www.fifa.com/fifaplus/en/match-centr... \n", - "\n", - "[132 rows x 4 columns]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import datasets\n", - "import pandas as pd\n", - "\n", - "\n", - "eval_ds = datasets.load_dataset(\"m-ric/smol_agents_benchmark\")[\"test\"]\n", - "pd.DataFrame(eval_ds)" + "!pip install -e .. datasets sympy numpy matplotlib seaborn -q # Install dev version of smolagents + some packages" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Define utilities and tools\n", - "To run the SERPAPI tool, you will need to have a [SerpAPI](https://serpapi.com/dashboard) API key: for this you need a paid account." 
+ "## Constants and utilities/tools" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ + "# Benchmark date\n", + "# - set a concrete date:\n", + "DATE = \"2024-12-26\"\n", + "# - or use default: today\n", + "# DATE = None\n", + "\n", + "# Evaluation dataset\n", + "EVAL_DATASET = \"smolagents-benchmark/benchmark-v1\"\n", + "\n", + "# Answers dataset: it must be a gated dataset; required to score the answers\n", + "ANSWERS_DATASET = \"smolagents-benchmark/answers\"\n", + "# Whether to push the answers dataset to the Hub\n", + "PUSH_ANSWERS_DATASET_TO_HUB = True\n", + "\n", + "# Results dataset\n", + "RESULTS_DATASET = \"smolagents-benchmark/results\"\n", + "# Whether to push the results dataset to the Hub\n", + "PUSH_RESULTS_DATASET_TO_HUB = True\n", + "\n", + "\n", + "import datetime\n", "import json\n", "import os\n", "import re\n", @@ -208,6 +61,7 @@ "import warnings\n", "from typing import List\n", "\n", + "import datasets\n", "from dotenv import load_dotenv\n", "from tqdm import tqdm\n", "\n", @@ -234,60 +88,85 @@ " return str(obj)\n", "\n", "\n", - "def answer_questions(eval_ds, file_name, agent, model_id, action_type, is_vanilla_llm=False):\n", - " answered_questions = []\n", - " if os.path.exists(file_name):\n", - " with open(file_name, \"r\") as f:\n", - " for line in f:\n", - " answered_questions.append(json.loads(line)[\"question\"])\n", + "def answer_questions(\n", + " eval_ds,\n", + " agent,\n", + " model_id,\n", + " action_type,\n", + " is_vanilla_llm=False,\n", + " date=DATE,\n", + " output_dir=\"output\",\n", + " push_to_hub_dataset=ANSWERS_DATASET if PUSH_ANSWERS_DATASET_TO_HUB else None,\n", + "):\n", + " date = date or datetime.date.today().isoformat()\n", "\n", - " for _, example in tqdm(enumerate(eval_ds), total=len(eval_ds)):\n", - " try:\n", - " question = example[\"question\"]\n", - " if example[\"source\"] == \"SimpleQA\":\n", - " question += \" Answer with only the final number.\"\n", - " if example[\"source\"] == \"MATH\":\n", - " question += \" Write code, not latex.\"\n", - " if question in answered_questions:\n", - " continue\n", - " start_time = time.time()\n", + " for task in eval_ds:\n", + " file_name = f\"output/{model_id.replace('/', '__')}__{action_type}__{task}__{date}.jsonl\"\n", + " answered_questions = []\n", + " if os.path.exists(file_name):\n", + " with open(file_name, \"r\") as f:\n", + " for line in f:\n", + " answered_questions.append(json.loads(line)[\"question\"])\n", "\n", - " if is_vanilla_llm:\n", - " llm = agent\n", - " answer = str(llm([{\"role\": \"user\", \"content\": question}]).content)\n", - " token_count = {\n", - " \"input\": llm.last_input_token_count,\n", - " \"output\": llm.last_output_token_count,\n", + " for _, example in tqdm(enumerate(eval_ds[task]), total=len(eval_ds[task])):\n", + " try:\n", + " question = example[\"question\"]\n", + " if example[\"source\"] == \"SimpleQA\":\n", + " question += \" Answer with only the final number.\"\n", + " if example[\"source\"] == \"MATH\":\n", + " question += \" Write code, not latex.\"\n", + " if question in answered_questions:\n", + " continue\n", + " start_time = time.time()\n", + "\n", + " if is_vanilla_llm:\n", + " llm = agent\n", + " answer = str(llm([{\"role\": \"user\", \"content\": question}]).content)\n", + " token_count = {\n", + " \"input\": llm.last_input_token_count,\n", + " \"output\": llm.last_output_token_count,\n", + " }\n", + " intermediate_steps = str([])\n", + " else:\n", + " answer = 
str(agent.run(question))\n", + " token_count = agent.monitor.get_total_token_counts()\n", + " intermediate_steps = str(agent.logs)\n", + " # Remove memory from logs to make them more compact.\n", + " for step in agent.logs:\n", + " if isinstance(step, ActionStep):\n", + " step.agent_memory = None\n", + "\n", + " end_time = time.time()\n", + " annotated_example = {\n", + " \"model_id\": model_id,\n", + " \"agent_action_type\": action_type,\n", + " \"question\": question,\n", + " \"answer\": answer,\n", + " \"true_answer\": example[\"true_answer\"],\n", + " \"source\": example[\"source\"],\n", + " \"intermediate_steps\": intermediate_steps,\n", + " \"start_time\": start_time,\n", + " \"end_time\": end_time,\n", + " \"token_counts\": token_count,\n", " }\n", - " intermediate_steps = str([])\n", - " else:\n", - " answer = str(agent.run(question))\n", - " token_count = agent.monitor.get_total_token_counts()\n", - " intermediate_steps = str(agent.logs)\n", - " # Remove memory from logs to make them more compact.\n", - " for step in agent.logs:\n", - " if isinstance(step, ActionStep):\n", - " step.agent_memory = None\n", "\n", - " end_time = time.time()\n", - " annotated_example = {\n", - " \"model_id\": model_id,\n", - " \"agent_action_type\": action_type,\n", - " \"question\": question,\n", - " \"answer\": answer,\n", - " \"true_answer\": example[\"true_answer\"],\n", - " \"source\": example[\"source\"],\n", - " \"intermediate_steps\": intermediate_steps,\n", - " \"start_time\": start_time,\n", - " \"end_time\": end_time,\n", - " \"token_counts\": token_count,\n", - " }\n", + " with open(file_name, \"a\") as f:\n", + " json.dump(annotated_example, f, default=serialize_agent_error)\n", + " f.write(\"\\n\") # add a newline for JSONL format\n", + " except Exception as e:\n", + " print(\"Failed:\", e)\n", "\n", - " with open(file_name, \"a\") as f:\n", - " json.dump(annotated_example, f, default=serialize_agent_error)\n", - " f.write(\"\\n\") # add a newline for JSONL format\n", - " except Exception as e:\n", - " print(\"Failed:\", e)\n", + " if push_to_hub_dataset:\n", + " ds = datasets.Dataset.from_pandas(pd.read_json(file_name, lines=True), split=\"test\", preserve_index=False)\n", + " config = f\"{model_id.replace('/', '__')}__{action_type}__{task}\"\n", + " data_dir = f\"{model_id}/{action_type}/{task}/{date}\"\n", + " ds.push_to_hub(\n", + " push_to_hub_dataset,\n", + " config_name=config,\n", + " data_dir=data_dir,\n", + " split=\"test\",\n", + " commit_message=f\"Upload {config}\",\n", + " )\n", "\n", "\n", "def normalize_number_str(number_str: str) -> float:\n", @@ -382,7 +261,172 @@ " return all(comparisons)\n", "\n", " else: # if gt is a str\n", - " return normalize_str(model_answer) == normalize_str(ground_truth)" + " return normalize_str(model_answer) == normalize_str(ground_truth)\n", + "\n", + "\n", + "def get_correct(row):\n", + " if row[\"source\"] == \"MATH\": # Checks the last number in answer\n", + " numbers_answer = extract_numbers(str(row[\"answer\"]))\n", + " if len(numbers_answer) == 0:\n", + " return False\n", + " return float(numbers_answer[-1]) == float(row[\"true_answer\"])\n", + " else:\n", + " return get_question_score_gaia(str(row[\"answer\"]), str(row[\"true_answer\"]))\n", + "\n", + "\n", + "def score_answers(\n", + " answers_subsets,\n", + " answers_dataset=ANSWERS_DATASET,\n", + " date=DATE,\n", + " push_to_hub_dataset=RESULTS_DATASET if PUSH_RESULTS_DATASET_TO_HUB else None,\n", + " set_default=True,\n", + "):\n", + " if not answers_dataset:\n", + " raise 
ValueError(\"Pass 'answers_dataset' to load the answers from it\")\n", + " date = date or datetime.date.today().isoformat()\n", + " results = []\n", + " for answers_subset in answers_subsets:\n", + " *model_id, action_type, task = answers_subset.split(\"__\")\n", + " model_id = \"/\".join(model_id)\n", + " ds = datasets.load_dataset(answers_dataset, answers_subset, split=\"test\")\n", + " df = ds.to_pandas()\n", + " df[\"correct\"] = df.apply(get_correct, axis=1)\n", + " acc = df[\"correct\"].mean().item()\n", + " result = df.loc[0, [\"model_id\", \"agent_action_type\", \"source\"]].to_dict()\n", + " result[\"acc\"] = acc\n", + " results.append(result)\n", + " df = pd.DataFrame(results)\n", + "\n", + " if push_to_hub_dataset:\n", + " ds = datasets.Dataset.from_pandas(df)\n", + " config = date\n", + " set_default = set_default\n", + " ds.push_to_hub(\n", + " push_to_hub_dataset, config_name=config, set_default=set_default, commit_message=f\"Upload {config} results\"\n", + " )\n", + " return df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluation dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['gaia', 'math', 'simpleqa']\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
questionsourcetrue_answertrue_reasoning
0What year was the municipality of Ramiriquí, B...SimpleQA1541['https://en.wikipedia.org/wiki/Ramiriqu%C3%AD...
1In what year did Hjalmar Hvam invent a mechani...SimpleQA1937['https://www.kgw.com/article/features/portlan...
2In which year did Fayaz A. Malik (an Indian ph...SimpleQA2009['https://en.wikipedia.org/wiki/Fayaz_A._Malik...
3In which year was John B. Goodenough elected a...SimpleQA2010['https://en.wikipedia.org/wiki/John_B._Gooden...
4In which year did Atul Gawande earn an M.A. in...SimpleQA1989['https://en.wikipedia.org/wiki/Atul_Gawande',...
\n", + "
" + ], + "text/plain": [ + " question source true_answer \\\n", + "0 What year was the municipality of Ramiriquí, B... SimpleQA 1541 \n", + "1 In what year did Hjalmar Hvam invent a mechani... SimpleQA 1937 \n", + "2 In which year did Fayaz A. Malik (an Indian ph... SimpleQA 2009 \n", + "3 In which year was John B. Goodenough elected a... SimpleQA 2010 \n", + "4 In which year did Atul Gawande earn an M.A. in... SimpleQA 1989 \n", + "\n", + " true_reasoning \n", + "0 ['https://en.wikipedia.org/wiki/Ramiriqu%C3%AD... \n", + "1 ['https://www.kgw.com/article/features/portlan... \n", + "2 ['https://en.wikipedia.org/wiki/Fayaz_A._Malik... \n", + "3 ['https://en.wikipedia.org/wiki/John_B._Gooden... \n", + "4 ['https://en.wikipedia.org/wiki/Atul_Gawande',... " + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "\n", + "# Choose the tasks to evaluate on:\n", + "# tasks = [\"gaia\"]\n", + "# or evaluate on all tasks: [\"gaia\", \"math\", \"simpleqa\"]\n", + "tasks = datasets.get_dataset_config_names(EVAL_DATASET)\n", + "print(tasks)\n", + "\n", + "\n", + "eval_ds = {task: datasets.load_dataset(EVAL_DATASET, task, split=\"test\") for task in tasks}\n", + "pd.DataFrame(eval_ds[\"simpleqa\"]).head()" ] }, { @@ -412,16 +456,16 @@ " # \"meta-llama/Llama-3.1-70B-Instruct\",\n", "]\n", "\n", + "\n", "for model_id in open_model_ids:\n", " print(f\"Evaluating '{model_id}'...\")\n", - " # action_type = \"tool_calling\"\n", + " # action_type = \"tool-calling\"\n", " # agent = ToolCallingAgent(\n", " # tools=[GoogleSearchTool(), VisitWebpageTool(), PythonInterpreterTool()],\n", " # model=HfApiModel(model_id),\n", " # max_steps=10,\n", " # )\n", - " # file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", - " # answer_questions(eval_ds, file_name, agent, model_id, action_type)\n", + " # answer_questions(eval_ds, agent, model_id, action_type)\n", "\n", " action_type = \"code\"\n", " agent = CodeAgent(\n", @@ -430,21 +474,19 @@ " additional_authorized_imports=[\"numpy\", \"sympy\"],\n", " max_steps=10,\n", " )\n", - " file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", - " answer_questions(eval_ds, file_name, agent, model_id, action_type)\n", + " answer_questions(eval_ds, agent, model_id, action_type)\n", "\n", " # Also evaluate vanilla model\n", " action_type = \"vanilla\"\n", " llm = HfApiModel(model_id)\n", - " file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", - " answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)" + " answer_questions(eval_ds, llm, model_id, action_type, is_vanilla_llm=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Closed models" + "### Closed models" ] }, { @@ -458,9 +500,10 @@ "\n", "litellm_model_ids = [\"gpt-4o\", \"anthropic/claude-3-5-sonnet-latest\"]\n", "\n", + "\n", "for model_id in litellm_model_ids:\n", " print(f\"Evaluating '{model_id}'...\")\n", - " action_type = \"tool_calling\"\n", + " action_type = \"tool-calling\"\n", " agent = ToolCallingAgent(\n", " tools=[\n", " GoogleSearchTool(),\n", @@ -470,8 +513,7 @@ " model=LiteLLMModel(model_id),\n", " max_steps=10,\n", " )\n", - " file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", - " answer_questions(eval_ds, file_name, agent, model_id, action_type)\n", + " answer_questions(eval_ds, agent, model_id, action_type)\n", "\n", " 
action_type = \"code\"\n", " agent = CodeAgent(\n", @@ -480,14 +522,12 @@ " additional_authorized_imports=[\"numpy\", \"sympy\"],\n", " max_steps=10,\n", " )\n", - " file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", - " answer_questions(eval_ds, file_name, agent, model_id, action_type)\n", + " answer_questions(eval_ds, agent, model_id, action_type)\n", "\n", " # Also evaluate vanilla model\n", " action_type = \"vanilla\"\n", " llm = LiteLLMModel(model_id)\n", - " file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n", - " answer_questions(eval_ds, file_name, llm, model_id, action_type, is_vanilla_llm=True)" + " answer_questions(eval_ds, llm, model_id, action_type, is_vanilla_llm=True)" ] }, { @@ -539,58 +579,153 @@ "execution_count": 9, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of answers_subsets 54\n", + "Example of answers_subset Qwen__Qwen2.5-72B-Instruct__code__gaia\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_27085/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n", - " warnings.warn(\n" + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", 
UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n", + "/tmp/ipykernel_640885/2542893079.py:194: UserWarning: Answer lists have different lengths, returning False.\n", + " warnings.warn(\"Answer lists have different lengths, returning False.\", UserWarning)\n" ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
model_idagent_action_typesourceacc
0Qwen/Qwen2.5-72B-InstructcodeGAIA28.12
1Qwen/Qwen2.5-72B-InstructcodeMATH76.00
2Qwen/Qwen2.5-72B-InstructcodeSimpleQA88.00
3Qwen/Qwen2.5-72B-InstructvanillaGAIA6.25
4Qwen/Qwen2.5-72B-InstructvanillaMATH30.00
\n", + "
" + ], + "text/plain": [ + " model_id agent_action_type source acc\n", + "0 Qwen/Qwen2.5-72B-Instruct code GAIA 28.12\n", + "1 Qwen/Qwen2.5-72B-Instruct code MATH 76.00\n", + "2 Qwen/Qwen2.5-72B-Instruct code SimpleQA 88.00\n", + "3 Qwen/Qwen2.5-72B-Instruct vanilla GAIA 6.25\n", + "4 Qwen/Qwen2.5-72B-Instruct vanilla MATH 30.00" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "import glob\n", - "\n", + "import datasets\n", "import pandas as pd\n", "\n", "\n", - "res = []\n", - "for file_path in glob.glob(\"output/*.jsonl\"):\n", - " data = []\n", - " with open(file_path) as f:\n", - " for line in f:\n", - " try:\n", - " # Use standard json module instead of pandas.json to handle large numbers better\n", - " record = json.loads(line)\n", - " data.append(record)\n", - " except json.JSONDecodeError as e:\n", - " print(f\"Error parsing line in {file_path}: {e}\")\n", - " continue\n", - "\n", - " try:\n", - " smoldf = pd.DataFrame(data)\n", - " smoldf[\"action_type\"] = \"vanilla\" if \"-vanilla-\" in file_path else \"code\"\n", - " res.append(smoldf)\n", - " except Exception as e:\n", - " print(f\"Error creating DataFrame from {file_path}: {e}\")\n", - " continue\n", - "\n", - "result_df = pd.concat(res)\n", + "# Choose the answers subsets to score:\n", + "# answers_subsets = [\"meta-llama__Llama-3.1-8B-Instruct__code__gaia\"]\n", + "# or get all the answers subsets present in the ANSWERS_DATASET\n", + "answers_subsets = datasets.get_dataset_config_names(ANSWERS_DATASET)\n", + "print(\"Number of answers_subsets\", len(answers_subsets))\n", + "print(\"Example of answers_subset\", answers_subsets[0])\n", "\n", "\n", - "def get_correct(row):\n", - " if row[\"source\"] == \"MATH\": # Checks the last number in answer\n", - " numbers_answer = extract_numbers(str(row[\"answer\"]))\n", - " if len(numbers_answer) == 0:\n", - " return False\n", - " return float(numbers_answer[-1]) == float(row[\"true_answer\"])\n", - " else:\n", - " return get_question_score_gaia(str(row[\"answer\"]), str(row[\"true_answer\"]))\n", - "\n", - "\n", - "result_df[\"correct\"] = result_df.apply(get_correct, axis=1)\n", - "\n", - "result_df = (result_df.groupby([\"model_id\", \"source\", \"action_type\"])[[\"correct\"]].mean() * 100).round(1).reset_index()" + "result_df = score_answers(answers_subsets)\n", + "result_df[\"acc\"] = (result_df[\"acc\"] * 100).round(2)\n", + "result_df.head()" ] }, {