From 1516ce8d74b3a97ef26b3f6a4d40aedf9f78ebb7 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Thu, 13 Feb 2025 10:13:25 +0100 Subject: [PATCH] Move plan user prompt to YAML and test text of plan prompts (#591) --- src/smolagents/agents.py | 39 +++++------ src/smolagents/prompts/code_agent.yaml | 6 ++ src/smolagents/prompts/toolcalling_agent.yaml | 6 ++ tests/test_agents.py | 66 +++++++++++++------ 4 files changed, 75 insertions(+), 42 deletions(-) diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 72db9d2..753ae18 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -88,7 +88,8 @@ class PlanningPromptTemplate(TypedDict): Prompt templates for the planning step. Args: - initial_facts (`str`): Initial facts prompt. + initial_facts_pre_task (`str`): Initial facts pre-task prompt. + initial_facts_task (`str`): Initial facts task prompt. initial_plan (`str`): Initial plan prompt. update_facts_pre_messages (`str`): Update facts pre-messages prompt. update_facts_post_messages (`str`): Update facts post-messages prompt. @@ -96,7 +97,8 @@ class PlanningPromptTemplate(TypedDict): update_plan_post_messages (`str`): Update plan post-messages prompt. """ - initial_facts: str + initial_facts_pre_task: str + initial_facts_task: str initial_plan: str update_facts_pre_messages: str update_facts_post_messages: str @@ -524,26 +526,19 @@ You have been provided with these additional arguments, that you can access usin step (`int`): The number of the current step, used as an indication for the LLM. """ if is_first_step: - message_prompt_facts = { - "role": MessageRole.SYSTEM, - "content": [{"type": "text", "text": self.prompt_templates["planning"]["initial_facts"]}], - } - message_prompt_task = { - "role": MessageRole.USER, - "content": [ - { - "type": "text", - "text": textwrap.dedent( - f"""Here is the task: - ``` - {task} - ``` - Now begin!""" - ), - }, - ], - } - input_messages = [message_prompt_facts, message_prompt_task] + input_messages = [ + { + "role": MessageRole.USER, + "content": [ + { + "type": "text", + "text": populate_template( + self.prompt_templates["planning"]["initial_facts"], variables={"task": task} + ), + } + ], + }, + ] chat_message_facts: ChatMessage = self.model(input_messages) answer_facts = chat_message_facts.content diff --git a/src/smolagents/prompts/code_agent.yaml b/src/smolagents/prompts/code_agent.yaml index 852c4cf..dc9aa0b 100644 --- a/src/smolagents/prompts/code_agent.yaml +++ b/src/smolagents/prompts/code_agent.yaml @@ -196,6 +196,12 @@ planning: ### 2. Facts to look up ### 3. Facts to derive Do not add anything else. + + Here is the task: + ``` + {{task}} + ``` + Now begin! initial_plan : |- You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools. diff --git a/src/smolagents/prompts/toolcalling_agent.yaml b/src/smolagents/prompts/toolcalling_agent.yaml index 19d67da..7489b9f 100644 --- a/src/smolagents/prompts/toolcalling_agent.yaml +++ b/src/smolagents/prompts/toolcalling_agent.yaml @@ -139,6 +139,12 @@ planning: ### 2. Facts to look up ### 3. Facts to derive Do not add anything else. + + Here is the task: + ``` + {{task}} + ``` + Now begin! initial_plan : |- You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools. diff --git a/tests/test_agents.py b/tests/test_agents.py index 58c6315..d4ce8d7 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -697,11 +697,8 @@ class TestMultiStepAgent: ( 1, [ - [ - {"role": MessageRole.SYSTEM, "content": [{"type": "text", "text": "FACTS_SYSTEM_PROMPT"}]}, - {"role": MessageRole.USER, "content": [{"type": "text", "text": "FACTS_USER_PROMPT"}]}, - ], - [{"role": MessageRole.USER, "content": [{"type": "text", "text": "PLAN_USER_PROMPT"}]}], + [{"role": MessageRole.USER, "content": [{"type": "text", "text": "INITIAL_FACTS_USER_PROMPT"}]}], + [{"role": MessageRole.USER, "content": [{"type": "text", "text": "INITIAL_PLAN_USER_PROMPT"}]}], ], ), ( @@ -710,22 +707,22 @@ class TestMultiStepAgent: [ { "role": MessageRole.SYSTEM, - "content": [{"type": "text", "text": "FACTS_UPDATE_SYSTEM_PROMPT"}], + "content": [{"type": "text", "text": "UPDATE_FACTS_SYSTEM_PROMPT"}], }, - {"role": MessageRole.USER, "content": [{"type": "text", "text": "FACTS_UPDATE_USER_PROMPT"}]}, + {"role": MessageRole.USER, "content": [{"type": "text", "text": "UPDATE_FACTS_USER_PROMPT"}]}, ], [ { "role": MessageRole.SYSTEM, - "content": [{"type": "text", "text": "PLAN_UPDATE_SYSTEM_PROMPT"}], + "content": [{"type": "text", "text": "UPDATE_PLAN_SYSTEM_PROMPT"}], }, - {"role": MessageRole.USER, "content": [{"type": "text", "text": "PLAN_UPDATE_USER_PROMPT"}]}, + {"role": MessageRole.USER, "content": [{"type": "text", "text": "UPDATE_PLAN_USER_PROMPT"}]}, ], ], ), ], ) - def test_planning_step_first_step(self, step, expected_messages_list): + def test_planning_step(self, step, expected_messages_list): fake_model = MagicMock() agent = CodeAgent( tools=[], @@ -733,6 +730,39 @@ class TestMultiStepAgent: ) task = "Test task" agent.planning_step(task, is_first_step=(step == 1), step=step) + expected_message_texts = { + "INITIAL_FACTS_USER_PROMPT": populate_template( + agent.prompt_templates["planning"]["initial_facts"], variables=dict(task=task) + ), + "INITIAL_PLAN_USER_PROMPT": populate_template( + agent.prompt_templates["planning"]["initial_plan"], + variables=dict( + task=task, + tools=agent.tools, + managed_agents=agent.managed_agents, + answer_facts=agent.memory.steps[0].model_output_message_facts.content, + ), + ), + "UPDATE_FACTS_SYSTEM_PROMPT": agent.prompt_templates["planning"]["update_facts_pre_messages"], + "UPDATE_FACTS_USER_PROMPT": agent.prompt_templates["planning"]["update_facts_post_messages"], + "UPDATE_PLAN_SYSTEM_PROMPT": populate_template( + agent.prompt_templates["planning"]["update_plan_pre_messages"], variables=dict(task=task) + ), + "UPDATE_PLAN_USER_PROMPT": populate_template( + agent.prompt_templates["planning"]["update_plan_post_messages"], + variables=dict( + task=task, + tools=agent.tools, + managed_agents=agent.managed_agents, + facts_update=agent.memory.steps[0].model_output_message_facts.content, + remaining_steps=agent.max_steps - step, + ), + ), + } + for expected_messages in expected_messages_list: + for expected_message in expected_messages: + for expected_content in expected_message["content"]: + expected_content["text"] = expected_message_texts[expected_content["text"]] assert len(agent.memory.steps) == 1 planning_step = agent.memory.steps[0] assert isinstance(planning_step, PlanningStep) @@ -744,14 +774,12 @@ class TestMultiStepAgent: assert isinstance(message, dict) assert "role" in message assert "content" in message - assert isinstance(message["role"], MessageRole) + assert message["role"] in MessageRole.__members__.values() assert message["role"] == expected_message["role"] assert isinstance(message["content"], list) assert len(message["content"]) == 1 - for content in message["content"]: - assert isinstance(content, dict) - assert "type" in content - assert "text" in content + for content, expected_content in zip(message["content"], expected_message["content"]): + assert content == expected_content # Test calls to model assert len(fake_model.call_args_list) == 2 for call_args, expected_messages in zip(fake_model.call_args_list, expected_messages_list): @@ -763,14 +791,12 @@ class TestMultiStepAgent: assert isinstance(message, dict) assert "role" in message assert "content" in message - assert isinstance(message["role"], MessageRole) + assert message["role"] in MessageRole.__members__.values() assert message["role"] == expected_message["role"] assert isinstance(message["content"], list) assert len(message["content"]) == 1 - for content in message["content"]: - assert isinstance(content, dict) - assert "type" in content - assert "text" in content + for content, expected_content in zip(message["content"], expected_message["content"]): + assert content == expected_content @pytest.mark.parametrize( "images, expected_messages_list",