Refactor and test final answer prompts (#595)

This commit is contained in:
Albert Villanova del Moral 2025-02-12 09:30:26 +01:00 committed by GitHub
parent 02b2b7ebb9
commit 833aec9198
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 132 additions and 42 deletions

View File

@@ -65,3 +65,5 @@ _This class is deprecated since 1.8.0: now you simply need to pass attributes `n
[[autodoc]] smolagents.agents.PlanningPromptTemplate
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
[[autodoc]] smolagents.agents.FinalAnswerPromptTemplate

View File

@@ -163,3 +163,5 @@ model = OpenAIServerModel(
[[autodoc]] smolagents.agents.PlanningPromptTemplate
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
[[autodoc]] smolagents.agents.FinalAnswerPromptTemplate

View File

@@ -155,3 +155,5 @@ print(model(messages))
[[autodoc]] smolagents.agents.PlanningPromptTemplate
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
[[autodoc]] smolagents.agents.FinalAnswerPromptTemplate

View File

@@ -117,19 +117,34 @@ class ManagedAgentPromptTemplate(TypedDict):
report: str
class FinalAnswerPromptTemplate(TypedDict):
    """
    Prompt templates used when asking the model for a final answer on behalf of a stuck agent.

    Args:
        pre_messages (`str`): Prompt inserted before the agent's memory messages.
        post_messages (`str`): Prompt appended after the agent's memory messages.
    """

    pre_messages: str
    post_messages: str
class PromptTemplates(TypedDict):
    """
    Prompt templates for the agent.

    Args:
        system_prompt (`str`): System prompt.
        planning ([`~agents.PlanningPromptTemplate`]): Planning prompt templates.
        managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt templates.
        final_answer ([`~agents.FinalAnswerPromptTemplate`]): Final answer prompt templates.
    """

    system_prompt: str
    planning: PlanningPromptTemplate
    managed_agent: ManagedAgentPromptTemplate
    final_answer: FinalAnswerPromptTemplate
EMPTY_PROMPT_TEMPLATES = PromptTemplates(
@@ -143,6 +158,7 @@ EMPTY_PROMPT_TEMPLATES = PromptTemplates(
update_plan_post_messages="",
),
managed_agent=ManagedAgentPromptTemplate(task="", report=""),
final_answer=FinalAnswerPromptTemplate(pre_messages="", post_messages=""),
)
@@ -290,14 +306,18 @@ class MultiStepAgent:
Returns:
`str`: Final answer to the task.
"""
messages = [{"role": MessageRole.SYSTEM, "content": []}]
if images:
messages[0]["content"] = [
messages = [
{
"role": MessageRole.SYSTEM,
"content": [
{
"type": "text",
"text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
"text": self.prompt_templates["final_answer"]["pre_messages"],
}
],
}
]
if images:
messages[0]["content"].append({"type": "image"})
messages += self.write_memory_to_messages()[1:]
messages += [
@@ -306,26 +326,9 @@ class MultiStepAgent:
"content": [
{
"type": "text",
"text": f"Based on the above, please provide an answer to the following user request:\n{task}",
}
],
}
]
else:
messages[0]["content"] = [
{
"type": "text",
"text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
}
]
messages += self.write_memory_to_messages()[1:]
messages += [
{
"role": MessageRole.USER,
"content": [
{
"type": "text",
"text": f"Based on the above, please provide an answer to the following user request:\n{task}",
"text": populate_template(
self.prompt_templates["final_answer"]["post_messages"], variables={"task": task}
),
}
],
}

View File

@@ -319,3 +319,9 @@ managed_agent:
report: |-
Here is the final answer from your managed agent '{{name}}':
{{final_answer}}
final_answer:
pre_messages: |-
An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:
post_messages: |-
Based on the above, please provide an answer to the following user request:
{{task}}

View File

@@ -262,3 +262,9 @@ managed_agent:
report: |-
Here is the final answer from your managed agent '{{name}}':
{{final_answer}}
final_answer:
pre_messages: |-
An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:
post_messages: |-
Based on the above, please provide an answer to the following user request:
{{task}}

View File

@@ -29,6 +29,7 @@ from smolagents.agents import (
MultiStepAgent,
ToolCall,
ToolCallingAgent,
populate_template,
)
from smolagents.default_tools import PythonInterpreterTool
from smolagents.memory import PlanningStep
@@ -771,6 +772,74 @@ class TestMultiStepAgent:
assert "type" in content
assert "text" in content
@pytest.mark.parametrize(
    "images, expected_messages_list",
    [
        (
            None,
            [
                [
                    {
                        "role": MessageRole.SYSTEM,
                        "content": [{"type": "text", "text": "FINAL_ANSWER_SYSTEM_PROMPT"}],
                    },
                    {"role": MessageRole.USER, "content": [{"type": "text", "text": "FINAL_ANSWER_USER_PROMPT"}]},
                ]
            ],
        ),
        (
            ["image1.png"],
            [
                [
                    {
                        "role": MessageRole.SYSTEM,
                        "content": [{"type": "text", "text": "FINAL_ANSWER_SYSTEM_PROMPT"}, {"type": "image"}],
                    },
                    {"role": MessageRole.USER, "content": [{"type": "text", "text": "FINAL_ANSWER_USER_PROMPT"}]},
                ]
            ],
        ),
    ],
)
def test_provide_final_answer(self, images, expected_messages_list):
    """Check that provide_final_answer builds the expected model messages and returns the model's answer."""
    fake_model = MagicMock()
    fake_model.return_value.content = "Final answer."
    agent = CodeAgent(
        tools=[],
        model=fake_model,
    )
    task = "Test task"

    final_answer = agent.provide_final_answer(task, images=images)

    # Map the placeholder keys used in the parametrized data to the concrete
    # prompt texts the agent should have produced from its templates.
    expected_message_texts = {
        "FINAL_ANSWER_SYSTEM_PROMPT": agent.prompt_templates["final_answer"]["pre_messages"],
        "FINAL_ANSWER_USER_PROMPT": populate_template(
            agent.prompt_templates["final_answer"]["post_messages"], variables=dict(task=task)
        ),
    }
    # Build a resolved copy instead of mutating the parametrized data in place:
    # the parametrize argument objects are shared module-level structures, and
    # overwriting the placeholder keys would break any rerun of this test
    # (e.g. via flaky-rerun plugins) with a KeyError.
    expected_messages_list = [
        [
            {
                "role": expected_message["role"],
                "content": [
                    {**content, "text": expected_message_texts[content["text"]]} if "text" in content else content
                    for content in expected_message["content"]
                ],
            }
            for expected_message in expected_messages
        ]
        for expected_messages in expected_messages_list
    ]

    assert final_answer == "Final answer."
    # Test calls to model
    assert len(fake_model.call_args_list) == 1
    for call_args, expected_messages in zip(fake_model.call_args_list, expected_messages_list):
        assert len(call_args.args) == 1
        messages = call_args.args[0]
        assert isinstance(messages, list)
        assert len(messages) == len(expected_messages)
        for message, expected_message in zip(messages, expected_messages):
            assert isinstance(message, dict)
            assert "role" in message
            assert "content" in message
            assert message["role"] in MessageRole.__members__.values()
            assert message["role"] == expected_message["role"]
            assert isinstance(message["content"], list)
            assert len(message["content"]) == len(expected_message["content"])
            for content, expected_content in zip(message["content"], expected_message["content"]):
                assert content == expected_content
class TestCodeAgent:
@pytest.mark.parametrize("provide_run_summary", [False, True])