Refactor and test final answer prompts (#595)

This commit is contained in:
Albert Villanova del Moral 2025-02-12 09:30:26 +01:00 committed by GitHub
parent 02b2b7ebb9
commit 833aec9198
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 132 additions and 42 deletions

View File

@ -65,3 +65,5 @@ _This class is deprecated since 1.8.0: now you simply need to pass attributes `n
[[autodoc]] smolagents.agents.PlanningPromptTemplate
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
[[autodoc]] smolagents.agents.FinalAnswerPromptTemplate

View File

@ -163,3 +163,5 @@ model = OpenAIServerModel(
[[autodoc]] smolagents.agents.PlanningPromptTemplate
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
[[autodoc]] smolagents.agents.FinalAnswerPromptTemplate

View File

@ -155,3 +155,5 @@ print(model(messages))
[[autodoc]] smolagents.agents.PlanningPromptTemplate
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
[[autodoc]] smolagents.agents.FinalAnswerPromptTemplate

View File

@ -117,19 +117,34 @@ class ManagedAgentPromptTemplate(TypedDict):
report: str report: str
class FinalAnswerPromptTemplate(TypedDict):
    """
    Prompt templates used when asking the model for a final answer.

    Args:
        pre_messages (`str`): Prompt inserted before the agent's memory messages.
        post_messages (`str`): Prompt appended after the agent's memory messages.
    """

    pre_messages: str
    post_messages: str
class PromptTemplates(TypedDict):
    """
    Prompt templates for the agent.

    Args:
        system_prompt (`str`): System prompt.
        planning ([`~agents.PlanningPromptTemplate`]): Planning prompt templates.
        managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt templates.
        final_answer ([`~agents.FinalAnswerPromptTemplate`]): Final answer prompt templates.
    """

    system_prompt: str
    planning: PlanningPromptTemplate
    managed_agent: ManagedAgentPromptTemplate
    final_answer: FinalAnswerPromptTemplate
EMPTY_PROMPT_TEMPLATES = PromptTemplates( EMPTY_PROMPT_TEMPLATES = PromptTemplates(
@ -143,6 +158,7 @@ EMPTY_PROMPT_TEMPLATES = PromptTemplates(
update_plan_post_messages="", update_plan_post_messages="",
), ),
managed_agent=ManagedAgentPromptTemplate(task="", report=""), managed_agent=ManagedAgentPromptTemplate(task="", report=""),
final_answer=FinalAnswerPromptTemplate(pre_messages="", post_messages=""),
) )
@ -290,46 +306,33 @@ class MultiStepAgent:
Returns: Returns:
`str`: Final answer to the task. `str`: Final answer to the task.
""" """
messages = [{"role": MessageRole.SYSTEM, "content": []}] messages = [
{
"role": MessageRole.SYSTEM,
"content": [
{
"type": "text",
"text": self.prompt_templates["final_answer"]["pre_messages"],
}
],
}
]
if images: if images:
messages[0]["content"] = [
{
"type": "text",
"text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
}
]
messages[0]["content"].append({"type": "image"}) messages[0]["content"].append({"type": "image"})
messages += self.write_memory_to_messages()[1:] messages += self.write_memory_to_messages()[1:]
messages += [ messages += [
{ {
"role": MessageRole.USER, "role": MessageRole.USER,
"content": [ "content": [
{ {
"type": "text", "type": "text",
"text": f"Based on the above, please provide an answer to the following user request:\n{task}", "text": populate_template(
} self.prompt_templates["final_answer"]["post_messages"], variables={"task": task}
], ),
} }
] ],
else: }
messages[0]["content"] = [ ]
{
"type": "text",
"text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
}
]
messages += self.write_memory_to_messages()[1:]
messages += [
{
"role": MessageRole.USER,
"content": [
{
"type": "text",
"text": f"Based on the above, please provide an answer to the following user request:\n{task}",
}
],
}
]
try: try:
chat_message: ChatMessage = self.model(messages) chat_message: ChatMessage = self.model(messages)
return chat_message.content return chat_message.content

View File

@ -318,4 +318,10 @@ managed_agent:
    And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
  report: |-
    Here is the final answer from your managed agent '{{name}}':
    {{final_answer}}
final_answer:
  pre_messages: |-
    An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:
  post_messages: |-
    Based on the above, please provide an answer to the following user request:
    {{task}}

View File

@ -261,4 +261,10 @@ managed_agent:
    And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
  report: |-
    Here is the final answer from your managed agent '{{name}}':
    {{final_answer}}
final_answer:
  pre_messages: |-
    An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:
  post_messages: |-
    Based on the above, please provide an answer to the following user request:
    {{task}}

View File

@ -29,6 +29,7 @@ from smolagents.agents import (
MultiStepAgent, MultiStepAgent,
ToolCall, ToolCall,
ToolCallingAgent, ToolCallingAgent,
populate_template,
) )
from smolagents.default_tools import PythonInterpreterTool from smolagents.default_tools import PythonInterpreterTool
from smolagents.memory import PlanningStep from smolagents.memory import PlanningStep
@ -771,6 +772,74 @@ class TestMultiStepAgent:
assert "type" in content assert "type" in content
assert "text" in content assert "text" in content
# Parametrized over `images`: without images the system message carries only the
# pre_messages text; with images an extra {"type": "image"} entry is appended to
# the system message content. The "text" values in the expected messages are
# placeholder keys, resolved against the agent's prompt templates inside the test.
@pytest.mark.parametrize(
    "images, expected_messages_list",
    [
        (
            None,
            [
                [
                    {
                        "role": MessageRole.SYSTEM,
                        "content": [{"type": "text", "text": "FINAL_ANSWER_SYSTEM_PROMPT"}],
                    },
                    {"role": MessageRole.USER, "content": [{"type": "text", "text": "FINAL_ANSWER_USER_PROMPT"}]},
                ]
            ],
        ),
        (
            ["image1.png"],
            [
                [
                    {
                        "role": MessageRole.SYSTEM,
                        "content": [{"type": "text", "text": "FINAL_ANSWER_SYSTEM_PROMPT"}, {"type": "image"}],
                    },
                    {"role": MessageRole.USER, "content": [{"type": "text", "text": "FINAL_ANSWER_USER_PROMPT"}]},
                ]
            ],
        ),
    ],
)
def test_provide_final_answer(self, images, expected_messages_list):
    """Check that provide_final_answer builds its messages from the final_answer prompt templates."""
    fake_model = MagicMock()
    fake_model.return_value.content = "Final answer."
    agent = CodeAgent(
        tools=[],
        model=fake_model,
    )
    task = "Test task"
    final_answer = agent.provide_final_answer(task, images=images)
    # Map each placeholder key to the text the agent's own templates should produce.
    expected_message_texts = {
        "FINAL_ANSWER_SYSTEM_PROMPT": agent.prompt_templates["final_answer"]["pre_messages"],
        "FINAL_ANSWER_USER_PROMPT": populate_template(
            agent.prompt_templates["final_answer"]["post_messages"], variables=dict(task=task)
        ),
    }
    # Resolve the placeholders in the parametrized expected messages in place.
    for expected_messages in expected_messages_list:
        for expected_message in expected_messages:
            for expected_content in expected_message["content"]:
                if "text" in expected_content:
                    expected_content["text"] = expected_message_texts[expected_content["text"]]
    assert final_answer == "Final answer."
    # Test calls to model
    assert len(fake_model.call_args_list) == 1
    for call_args, expected_messages in zip(fake_model.call_args_list, expected_messages_list):
        assert len(call_args.args) == 1
        messages = call_args.args[0]
        assert isinstance(messages, list)
        assert len(messages) == len(expected_messages)
        for message, expected_message in zip(messages, expected_messages):
            assert isinstance(message, dict)
            assert "role" in message
            assert "content" in message
            assert message["role"] in MessageRole.__members__.values()
            assert message["role"] == expected_message["role"]
            assert isinstance(message["content"], list)
            assert len(message["content"]) == len(expected_message["content"])
            for content, expected_content in zip(message["content"], expected_message["content"]):
                assert content == expected_content
class TestCodeAgent: class TestCodeAgent:
@pytest.mark.parametrize("provide_run_summary", [False, True]) @pytest.mark.parametrize("provide_run_summary", [False, True])