Refactor and test final answer prompts (#595)
This commit is contained in:
parent
02b2b7ebb9
commit
833aec9198
|
@ -65,3 +65,5 @@ _This class is deprecated since 1.8.0: now you simply need to pass attributes `n
|
||||||
[[autodoc]] smolagents.agents.PlanningPromptTemplate
|
[[autodoc]] smolagents.agents.PlanningPromptTemplate
|
||||||
|
|
||||||
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
|
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
|
||||||
|
|
||||||
|
[[autodoc]] smolagents.agents.FinalAnswerPromptTemplate
|
||||||
|
|
|
@ -163,3 +163,5 @@ model = OpenAIServerModel(
|
||||||
[[autodoc]] smolagents.agents.PlanningPromptTemplate
|
[[autodoc]] smolagents.agents.PlanningPromptTemplate
|
||||||
|
|
||||||
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
|
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
|
||||||
|
|
||||||
|
[[autodoc]] smolagents.agents.FinalAnswerPromptTemplate
|
||||||
|
|
|
@ -155,3 +155,5 @@ print(model(messages))
|
||||||
[[autodoc]] smolagents.agents.PlanningPromptTemplate
|
[[autodoc]] smolagents.agents.PlanningPromptTemplate
|
||||||
|
|
||||||
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
|
[[autodoc]] smolagents.agents.ManagedAgentPromptTemplate
|
||||||
|
|
||||||
|
[[autodoc]] smolagents.agents.FinalAnswerPromptTemplate
|
||||||
|
|
|
@ -117,19 +117,34 @@ class ManagedAgentPromptTemplate(TypedDict):
|
||||||
report: str
|
report: str
|
||||||
|
|
||||||
|
|
||||||
|
class FinalAnswerPromptTemplate(TypedDict):
|
||||||
|
"""
|
||||||
|
Prompt templates for the final answer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pre_messages (`str`): Pre-messages prompt.
|
||||||
|
post_messages (`str`): Post-messages prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pre_messages: str
|
||||||
|
post_messages: str
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplates(TypedDict):
|
class PromptTemplates(TypedDict):
|
||||||
"""
|
"""
|
||||||
Prompt templates for the agent.
|
Prompt templates for the agent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
system_prompt (`str`): System prompt.
|
system_prompt (`str`): System prompt.
|
||||||
planning ([`~agents.PlanningPromptTemplate`]): Planning prompt template.
|
planning ([`~agents.PlanningPromptTemplate`]): Planning prompt templates.
|
||||||
managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt template.
|
managed_agent ([`~agents.ManagedAgentPromptTemplate`]): Managed agent prompt templates.
|
||||||
|
final_answer ([`~agents.FinalAnswerPromptTemplate`]): Final answer prompt templates.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
system_prompt: str
|
system_prompt: str
|
||||||
planning: PlanningPromptTemplate
|
planning: PlanningPromptTemplate
|
||||||
managed_agent: ManagedAgentPromptTemplate
|
managed_agent: ManagedAgentPromptTemplate
|
||||||
|
final_answer: FinalAnswerPromptTemplate
|
||||||
|
|
||||||
|
|
||||||
EMPTY_PROMPT_TEMPLATES = PromptTemplates(
|
EMPTY_PROMPT_TEMPLATES = PromptTemplates(
|
||||||
|
@ -143,6 +158,7 @@ EMPTY_PROMPT_TEMPLATES = PromptTemplates(
|
||||||
update_plan_post_messages="",
|
update_plan_post_messages="",
|
||||||
),
|
),
|
||||||
managed_agent=ManagedAgentPromptTemplate(task="", report=""),
|
managed_agent=ManagedAgentPromptTemplate(task="", report=""),
|
||||||
|
final_answer=FinalAnswerPromptTemplate(pre_messages="", post_messages=""),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -290,46 +306,33 @@ class MultiStepAgent:
|
||||||
Returns:
|
Returns:
|
||||||
`str`: Final answer to the task.
|
`str`: Final answer to the task.
|
||||||
"""
|
"""
|
||||||
messages = [{"role": MessageRole.SYSTEM, "content": []}]
|
messages = [
|
||||||
|
{
|
||||||
|
"role": MessageRole.SYSTEM,
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": self.prompt_templates["final_answer"]["pre_messages"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
if images:
|
if images:
|
||||||
messages[0]["content"] = [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
messages[0]["content"].append({"type": "image"})
|
messages[0]["content"].append({"type": "image"})
|
||||||
messages += self.write_memory_to_messages()[1:]
|
messages += self.write_memory_to_messages()[1:]
|
||||||
messages += [
|
messages += [
|
||||||
{
|
{
|
||||||
"role": MessageRole.USER,
|
"role": MessageRole.USER,
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"Based on the above, please provide an answer to the following user request:\n{task}",
|
"text": populate_template(
|
||||||
}
|
self.prompt_templates["final_answer"]["post_messages"], variables={"task": task}
|
||||||
],
|
),
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
else:
|
}
|
||||||
messages[0]["content"] = [
|
]
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
messages += self.write_memory_to_messages()[1:]
|
|
||||||
messages += [
|
|
||||||
{
|
|
||||||
"role": MessageRole.USER,
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": f"Based on the above, please provide an answer to the following user request:\n{task}",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
try:
|
try:
|
||||||
chat_message: ChatMessage = self.model(messages)
|
chat_message: ChatMessage = self.model(messages)
|
||||||
return chat_message.content
|
return chat_message.content
|
||||||
|
|
|
@ -318,4 +318,10 @@ managed_agent:
|
||||||
And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
|
And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
|
||||||
report: |-
|
report: |-
|
||||||
Here is the final answer from your managed agent '{{name}}':
|
Here is the final answer from your managed agent '{{name}}':
|
||||||
{{final_answer}}
|
{{final_answer}}
|
||||||
|
final_answer:
|
||||||
|
pre_messages: |-
|
||||||
|
An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:
|
||||||
|
post_messages: |-
|
||||||
|
Based on the above, please provide an answer to the following user request:
|
||||||
|
{{task}}
|
||||||
|
|
|
@ -261,4 +261,10 @@ managed_agent:
|
||||||
And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
|
And even if your task resolution is not successful, please return as much context as possible, so that your manager can act upon this feedback.
|
||||||
report: |-
|
report: |-
|
||||||
Here is the final answer from your managed agent '{{name}}':
|
Here is the final answer from your managed agent '{{name}}':
|
||||||
{{final_answer}}
|
{{final_answer}}
|
||||||
|
final_answer:
|
||||||
|
pre_messages: |-
|
||||||
|
An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:
|
||||||
|
post_messages: |-
|
||||||
|
Based on the above, please provide an answer to the following user request:
|
||||||
|
{{task}}
|
||||||
|
|
|
@ -29,6 +29,7 @@ from smolagents.agents import (
|
||||||
MultiStepAgent,
|
MultiStepAgent,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
ToolCallingAgent,
|
ToolCallingAgent,
|
||||||
|
populate_template,
|
||||||
)
|
)
|
||||||
from smolagents.default_tools import PythonInterpreterTool
|
from smolagents.default_tools import PythonInterpreterTool
|
||||||
from smolagents.memory import PlanningStep
|
from smolagents.memory import PlanningStep
|
||||||
|
@ -771,6 +772,74 @@ class TestMultiStepAgent:
|
||||||
assert "type" in content
|
assert "type" in content
|
||||||
assert "text" in content
|
assert "text" in content
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"images, expected_messages_list",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
None,
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": MessageRole.SYSTEM,
|
||||||
|
"content": [{"type": "text", "text": "FINAL_ANSWER_SYSTEM_PROMPT"}],
|
||||||
|
},
|
||||||
|
{"role": MessageRole.USER, "content": [{"type": "text", "text": "FINAL_ANSWER_USER_PROMPT"}]},
|
||||||
|
]
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
["image1.png"],
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": MessageRole.SYSTEM,
|
||||||
|
"content": [{"type": "text", "text": "FINAL_ANSWER_SYSTEM_PROMPT"}, {"type": "image"}],
|
||||||
|
},
|
||||||
|
{"role": MessageRole.USER, "content": [{"type": "text", "text": "FINAL_ANSWER_USER_PROMPT"}]},
|
||||||
|
]
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_provide_final_answer(self, images, expected_messages_list):
|
||||||
|
fake_model = MagicMock()
|
||||||
|
fake_model.return_value.content = "Final answer."
|
||||||
|
agent = CodeAgent(
|
||||||
|
tools=[],
|
||||||
|
model=fake_model,
|
||||||
|
)
|
||||||
|
task = "Test task"
|
||||||
|
final_answer = agent.provide_final_answer(task, images=images)
|
||||||
|
expected_message_texts = {
|
||||||
|
"FINAL_ANSWER_SYSTEM_PROMPT": agent.prompt_templates["final_answer"]["pre_messages"],
|
||||||
|
"FINAL_ANSWER_USER_PROMPT": populate_template(
|
||||||
|
agent.prompt_templates["final_answer"]["post_messages"], variables=dict(task=task)
|
||||||
|
),
|
||||||
|
}
|
||||||
|
for expected_messages in expected_messages_list:
|
||||||
|
for expected_message in expected_messages:
|
||||||
|
for expected_content in expected_message["content"]:
|
||||||
|
if "text" in expected_content:
|
||||||
|
expected_content["text"] = expected_message_texts[expected_content["text"]]
|
||||||
|
assert final_answer == "Final answer."
|
||||||
|
# Test calls to model
|
||||||
|
assert len(fake_model.call_args_list) == 1
|
||||||
|
for call_args, expected_messages in zip(fake_model.call_args_list, expected_messages_list):
|
||||||
|
assert len(call_args.args) == 1
|
||||||
|
messages = call_args.args[0]
|
||||||
|
assert isinstance(messages, list)
|
||||||
|
assert len(messages) == len(expected_messages)
|
||||||
|
for message, expected_message in zip(messages, expected_messages):
|
||||||
|
assert isinstance(message, dict)
|
||||||
|
assert "role" in message
|
||||||
|
assert "content" in message
|
||||||
|
assert message["role"] in MessageRole.__members__.values()
|
||||||
|
assert message["role"] == expected_message["role"]
|
||||||
|
assert isinstance(message["content"], list)
|
||||||
|
assert len(message["content"]) == len(expected_message["content"])
|
||||||
|
for content, expected_content in zip(message["content"], expected_message["content"]):
|
||||||
|
assert content == expected_content
|
||||||
|
|
||||||
|
|
||||||
class TestCodeAgent:
|
class TestCodeAgent:
|
||||||
@pytest.mark.parametrize("provide_run_summary", [False, True])
|
@pytest.mark.parametrize("provide_run_summary", [False, True])
|
||||||
|
|
Loading…
Reference in New Issue