Move plan user prompt to YAML and test text of plan prompts (#591)
This commit is contained in:
parent
2797f2fb3b
commit
1516ce8d74
|
@ -88,7 +88,8 @@ class PlanningPromptTemplate(TypedDict):
|
||||||
Prompt templates for the planning step.
|
Prompt templates for the planning step.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
initial_facts (`str`): Initial facts prompt.
|
initial_facts_pre_task (`str`): Initial facts pre-task prompt.
|
||||||
|
initial_facts_task (`str`): Initial facts task prompt.
|
||||||
initial_plan (`str`): Initial plan prompt.
|
initial_plan (`str`): Initial plan prompt.
|
||||||
update_facts_pre_messages (`str`): Update facts pre-messages prompt.
|
update_facts_pre_messages (`str`): Update facts pre-messages prompt.
|
||||||
update_facts_post_messages (`str`): Update facts post-messages prompt.
|
update_facts_post_messages (`str`): Update facts post-messages prompt.
|
||||||
|
@ -96,7 +97,8 @@ class PlanningPromptTemplate(TypedDict):
|
||||||
update_plan_post_messages (`str`): Update plan post-messages prompt.
|
update_plan_post_messages (`str`): Update plan post-messages prompt.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
initial_facts: str
|
initial_facts_pre_task: str
|
||||||
|
initial_facts_task: str
|
||||||
initial_plan: str
|
initial_plan: str
|
||||||
update_facts_pre_messages: str
|
update_facts_pre_messages: str
|
||||||
update_facts_post_messages: str
|
update_facts_post_messages: str
|
||||||
|
@ -524,26 +526,19 @@ You have been provided with these additional arguments, that you can access usin
|
||||||
step (`int`): The number of the current step, used as an indication for the LLM.
|
step (`int`): The number of the current step, used as an indication for the LLM.
|
||||||
"""
|
"""
|
||||||
if is_first_step:
|
if is_first_step:
|
||||||
message_prompt_facts = {
|
input_messages = [
|
||||||
"role": MessageRole.SYSTEM,
|
{
|
||||||
"content": [{"type": "text", "text": self.prompt_templates["planning"]["initial_facts"]}],
|
"role": MessageRole.USER,
|
||||||
}
|
"content": [
|
||||||
message_prompt_task = {
|
{
|
||||||
"role": MessageRole.USER,
|
"type": "text",
|
||||||
"content": [
|
"text": populate_template(
|
||||||
{
|
self.prompt_templates["planning"]["initial_facts"], variables={"task": task}
|
||||||
"type": "text",
|
),
|
||||||
"text": textwrap.dedent(
|
}
|
||||||
f"""Here is the task:
|
],
|
||||||
```
|
},
|
||||||
{task}
|
]
|
||||||
```
|
|
||||||
Now begin!"""
|
|
||||||
),
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
input_messages = [message_prompt_facts, message_prompt_task]
|
|
||||||
|
|
||||||
chat_message_facts: ChatMessage = self.model(input_messages)
|
chat_message_facts: ChatMessage = self.model(input_messages)
|
||||||
answer_facts = chat_message_facts.content
|
answer_facts = chat_message_facts.content
|
||||||
|
|
|
@ -196,6 +196,12 @@ planning:
|
||||||
### 2. Facts to look up
|
### 2. Facts to look up
|
||||||
### 3. Facts to derive
|
### 3. Facts to derive
|
||||||
Do not add anything else.
|
Do not add anything else.
|
||||||
|
|
||||||
|
Here is the task:
|
||||||
|
```
|
||||||
|
{{task}}
|
||||||
|
```
|
||||||
|
Now begin!
|
||||||
initial_plan : |-
|
initial_plan : |-
|
||||||
You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
|
You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
|
||||||
|
|
||||||
|
|
|
@ -139,6 +139,12 @@ planning:
|
||||||
### 2. Facts to look up
|
### 2. Facts to look up
|
||||||
### 3. Facts to derive
|
### 3. Facts to derive
|
||||||
Do not add anything else.
|
Do not add anything else.
|
||||||
|
|
||||||
|
Here is the task:
|
||||||
|
```
|
||||||
|
{{task}}
|
||||||
|
```
|
||||||
|
Now begin!
|
||||||
initial_plan : |-
|
initial_plan : |-
|
||||||
You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
|
You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
|
||||||
|
|
||||||
|
|
|
@ -697,11 +697,8 @@ class TestMultiStepAgent:
|
||||||
(
|
(
|
||||||
1,
|
1,
|
||||||
[
|
[
|
||||||
[
|
[{"role": MessageRole.USER, "content": [{"type": "text", "text": "INITIAL_FACTS_USER_PROMPT"}]}],
|
||||||
{"role": MessageRole.SYSTEM, "content": [{"type": "text", "text": "FACTS_SYSTEM_PROMPT"}]},
|
[{"role": MessageRole.USER, "content": [{"type": "text", "text": "INITIAL_PLAN_USER_PROMPT"}]}],
|
||||||
{"role": MessageRole.USER, "content": [{"type": "text", "text": "FACTS_USER_PROMPT"}]},
|
|
||||||
],
|
|
||||||
[{"role": MessageRole.USER, "content": [{"type": "text", "text": "PLAN_USER_PROMPT"}]}],
|
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
|
@ -710,22 +707,22 @@ class TestMultiStepAgent:
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"role": MessageRole.SYSTEM,
|
"role": MessageRole.SYSTEM,
|
||||||
"content": [{"type": "text", "text": "FACTS_UPDATE_SYSTEM_PROMPT"}],
|
"content": [{"type": "text", "text": "UPDATE_FACTS_SYSTEM_PROMPT"}],
|
||||||
},
|
},
|
||||||
{"role": MessageRole.USER, "content": [{"type": "text", "text": "FACTS_UPDATE_USER_PROMPT"}]},
|
{"role": MessageRole.USER, "content": [{"type": "text", "text": "UPDATE_FACTS_USER_PROMPT"}]},
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"role": MessageRole.SYSTEM,
|
"role": MessageRole.SYSTEM,
|
||||||
"content": [{"type": "text", "text": "PLAN_UPDATE_SYSTEM_PROMPT"}],
|
"content": [{"type": "text", "text": "UPDATE_PLAN_SYSTEM_PROMPT"}],
|
||||||
},
|
},
|
||||||
{"role": MessageRole.USER, "content": [{"type": "text", "text": "PLAN_UPDATE_USER_PROMPT"}]},
|
{"role": MessageRole.USER, "content": [{"type": "text", "text": "UPDATE_PLAN_USER_PROMPT"}]},
|
||||||
],
|
],
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_planning_step_first_step(self, step, expected_messages_list):
|
def test_planning_step(self, step, expected_messages_list):
|
||||||
fake_model = MagicMock()
|
fake_model = MagicMock()
|
||||||
agent = CodeAgent(
|
agent = CodeAgent(
|
||||||
tools=[],
|
tools=[],
|
||||||
|
@ -733,6 +730,39 @@ class TestMultiStepAgent:
|
||||||
)
|
)
|
||||||
task = "Test task"
|
task = "Test task"
|
||||||
agent.planning_step(task, is_first_step=(step == 1), step=step)
|
agent.planning_step(task, is_first_step=(step == 1), step=step)
|
||||||
|
expected_message_texts = {
|
||||||
|
"INITIAL_FACTS_USER_PROMPT": populate_template(
|
||||||
|
agent.prompt_templates["planning"]["initial_facts"], variables=dict(task=task)
|
||||||
|
),
|
||||||
|
"INITIAL_PLAN_USER_PROMPT": populate_template(
|
||||||
|
agent.prompt_templates["planning"]["initial_plan"],
|
||||||
|
variables=dict(
|
||||||
|
task=task,
|
||||||
|
tools=agent.tools,
|
||||||
|
managed_agents=agent.managed_agents,
|
||||||
|
answer_facts=agent.memory.steps[0].model_output_message_facts.content,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"UPDATE_FACTS_SYSTEM_PROMPT": agent.prompt_templates["planning"]["update_facts_pre_messages"],
|
||||||
|
"UPDATE_FACTS_USER_PROMPT": agent.prompt_templates["planning"]["update_facts_post_messages"],
|
||||||
|
"UPDATE_PLAN_SYSTEM_PROMPT": populate_template(
|
||||||
|
agent.prompt_templates["planning"]["update_plan_pre_messages"], variables=dict(task=task)
|
||||||
|
),
|
||||||
|
"UPDATE_PLAN_USER_PROMPT": populate_template(
|
||||||
|
agent.prompt_templates["planning"]["update_plan_post_messages"],
|
||||||
|
variables=dict(
|
||||||
|
task=task,
|
||||||
|
tools=agent.tools,
|
||||||
|
managed_agents=agent.managed_agents,
|
||||||
|
facts_update=agent.memory.steps[0].model_output_message_facts.content,
|
||||||
|
remaining_steps=agent.max_steps - step,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
for expected_messages in expected_messages_list:
|
||||||
|
for expected_message in expected_messages:
|
||||||
|
for expected_content in expected_message["content"]:
|
||||||
|
expected_content["text"] = expected_message_texts[expected_content["text"]]
|
||||||
assert len(agent.memory.steps) == 1
|
assert len(agent.memory.steps) == 1
|
||||||
planning_step = agent.memory.steps[0]
|
planning_step = agent.memory.steps[0]
|
||||||
assert isinstance(planning_step, PlanningStep)
|
assert isinstance(planning_step, PlanningStep)
|
||||||
|
@ -744,14 +774,12 @@ class TestMultiStepAgent:
|
||||||
assert isinstance(message, dict)
|
assert isinstance(message, dict)
|
||||||
assert "role" in message
|
assert "role" in message
|
||||||
assert "content" in message
|
assert "content" in message
|
||||||
assert isinstance(message["role"], MessageRole)
|
assert message["role"] in MessageRole.__members__.values()
|
||||||
assert message["role"] == expected_message["role"]
|
assert message["role"] == expected_message["role"]
|
||||||
assert isinstance(message["content"], list)
|
assert isinstance(message["content"], list)
|
||||||
assert len(message["content"]) == 1
|
assert len(message["content"]) == 1
|
||||||
for content in message["content"]:
|
for content, expected_content in zip(message["content"], expected_message["content"]):
|
||||||
assert isinstance(content, dict)
|
assert content == expected_content
|
||||||
assert "type" in content
|
|
||||||
assert "text" in content
|
|
||||||
# Test calls to model
|
# Test calls to model
|
||||||
assert len(fake_model.call_args_list) == 2
|
assert len(fake_model.call_args_list) == 2
|
||||||
for call_args, expected_messages in zip(fake_model.call_args_list, expected_messages_list):
|
for call_args, expected_messages in zip(fake_model.call_args_list, expected_messages_list):
|
||||||
|
@ -763,14 +791,12 @@ class TestMultiStepAgent:
|
||||||
assert isinstance(message, dict)
|
assert isinstance(message, dict)
|
||||||
assert "role" in message
|
assert "role" in message
|
||||||
assert "content" in message
|
assert "content" in message
|
||||||
assert isinstance(message["role"], MessageRole)
|
assert message["role"] in MessageRole.__members__.values()
|
||||||
assert message["role"] == expected_message["role"]
|
assert message["role"] == expected_message["role"]
|
||||||
assert isinstance(message["content"], list)
|
assert isinstance(message["content"], list)
|
||||||
assert len(message["content"]) == 1
|
assert len(message["content"]) == 1
|
||||||
for content in message["content"]:
|
for content, expected_content in zip(message["content"], expected_message["content"]):
|
||||||
assert isinstance(content, dict)
|
assert content == expected_content
|
||||||
assert "type" in content
|
|
||||||
assert "text" in content
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"images, expected_messages_list",
|
"images, expected_messages_list",
|
||||||
|
|
Loading…
Reference in New Issue