Fix MultiStepAgent.planning_step message content (#437)
* Fix MultiStepAgent.planning_step message content
This commit is contained in:
parent
6d0e4e49fc
commit
42d97716fe
|
@ -494,15 +494,20 @@ You have been provided with these additional arguments, that you can access usin
|
||||||
if is_first_step:
|
if is_first_step:
|
||||||
message_prompt_facts = {
|
message_prompt_facts = {
|
||||||
"role": MessageRole.SYSTEM,
|
"role": MessageRole.SYSTEM,
|
||||||
"content": SYSTEM_PROMPT_FACTS,
|
"content": [{"type": "text", "text": SYSTEM_PROMPT_FACTS}],
|
||||||
}
|
}
|
||||||
message_prompt_task = {
|
message_prompt_task = {
|
||||||
"role": MessageRole.USER,
|
"role": MessageRole.USER,
|
||||||
"content": f"""Here is the task:
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"""Here is the task:
|
||||||
```
|
```
|
||||||
{task}
|
{task}
|
||||||
```
|
```
|
||||||
Now begin!""",
|
Now begin!""",
|
||||||
|
}
|
||||||
|
],
|
||||||
}
|
}
|
||||||
input_messages = [message_prompt_facts, message_prompt_task]
|
input_messages = [message_prompt_facts, message_prompt_task]
|
||||||
|
|
||||||
|
@ -511,16 +516,21 @@ Now begin!""",
|
||||||
|
|
||||||
message_system_prompt_plan = {
|
message_system_prompt_plan = {
|
||||||
"role": MessageRole.SYSTEM,
|
"role": MessageRole.SYSTEM,
|
||||||
"content": SYSTEM_PROMPT_PLAN,
|
"content": [{"type": "text", "text": SYSTEM_PROMPT_PLAN}],
|
||||||
}
|
}
|
||||||
message_user_prompt_plan = {
|
message_user_prompt_plan = {
|
||||||
"role": MessageRole.USER,
|
"role": MessageRole.USER,
|
||||||
"content": USER_PROMPT_PLAN.format(
|
"content": [
|
||||||
task=task,
|
{
|
||||||
tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template),
|
"type": "text",
|
||||||
managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)),
|
"text": USER_PROMPT_PLAN.format(
|
||||||
answer_facts=answer_facts,
|
task=task,
|
||||||
),
|
tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template),
|
||||||
|
managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)),
|
||||||
|
answer_facts=answer_facts,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
],
|
||||||
}
|
}
|
||||||
chat_message_plan: ChatMessage = self.model(
|
chat_message_plan: ChatMessage = self.model(
|
||||||
[message_system_prompt_plan, message_user_prompt_plan],
|
[message_system_prompt_plan, message_user_prompt_plan],
|
||||||
|
|
|
@ -17,6 +17,7 @@ import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from transformers.testing_utils import get_tests_dir
|
from transformers.testing_utils import get_tests_dir
|
||||||
|
|
||||||
|
@ -25,11 +26,19 @@ from smolagents.agents import (
|
||||||
AgentMaxStepsError,
|
AgentMaxStepsError,
|
||||||
CodeAgent,
|
CodeAgent,
|
||||||
ManagedAgent,
|
ManagedAgent,
|
||||||
|
MultiStepAgent,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
ToolCallingAgent,
|
ToolCallingAgent,
|
||||||
)
|
)
|
||||||
from smolagents.default_tools import PythonInterpreterTool
|
from smolagents.default_tools import PythonInterpreterTool
|
||||||
from smolagents.models import ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition, TransformersModel
|
from smolagents.memory import PlanningStep
|
||||||
|
from smolagents.models import (
|
||||||
|
ChatMessage,
|
||||||
|
ChatMessageToolCall,
|
||||||
|
ChatMessageToolCallDefinition,
|
||||||
|
MessageRole,
|
||||||
|
TransformersModel,
|
||||||
|
)
|
||||||
from smolagents.tools import tool
|
from smolagents.tools import tool
|
||||||
from smolagents.utils import BASE_BUILTIN_MODULES
|
from smolagents.utils import BASE_BUILTIN_MODULES
|
||||||
|
|
||||||
|
@ -644,3 +653,49 @@ nested_answer()
|
||||||
assert step_memory_dict["model_output_message"].tool_calls[0].function.name == "get_weather"
|
assert step_memory_dict["model_output_message"].tool_calls[0].function.name == "get_weather"
|
||||||
assert step_memory_dict["model_output_message"].raw["completion_kwargs"]["max_new_tokens"] == 100
|
assert step_memory_dict["model_output_message"].raw["completion_kwargs"]["max_new_tokens"] == 100
|
||||||
assert "model_input_messages" in agent.memory.get_full_steps()[1]
|
assert "model_input_messages" in agent.memory.get_full_steps()[1]
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiStepAgent:
|
||||||
|
def test_planning_step_first_step(self):
|
||||||
|
fake_model = MagicMock()
|
||||||
|
agent = MultiStepAgent(
|
||||||
|
tools=[],
|
||||||
|
model=fake_model,
|
||||||
|
)
|
||||||
|
task = "Test task"
|
||||||
|
agent.planning_step(task, is_first_step=True, step=0)
|
||||||
|
assert len(agent.memory.steps) == 1
|
||||||
|
planning_step = agent.memory.steps[0]
|
||||||
|
assert isinstance(planning_step, PlanningStep)
|
||||||
|
messages = planning_step.model_input_messages
|
||||||
|
assert isinstance(messages, list)
|
||||||
|
assert len(messages) == 2
|
||||||
|
for message in messages:
|
||||||
|
assert isinstance(message, dict)
|
||||||
|
assert "role" in message
|
||||||
|
assert "content" in message
|
||||||
|
assert isinstance(message["role"], MessageRole)
|
||||||
|
assert isinstance(message["content"], list)
|
||||||
|
assert len(message["content"]) == 1
|
||||||
|
for content in message["content"]:
|
||||||
|
assert isinstance(content, dict)
|
||||||
|
assert "type" in content
|
||||||
|
assert "text" in content
|
||||||
|
# Test calls to model
|
||||||
|
assert len(fake_model.call_args_list) == 2
|
||||||
|
for call_args in fake_model.call_args_list:
|
||||||
|
assert len(call_args.args) == 1
|
||||||
|
messages = call_args.args[0]
|
||||||
|
assert isinstance(messages, list)
|
||||||
|
assert len(messages) == 2
|
||||||
|
for message in messages:
|
||||||
|
assert isinstance(message, dict)
|
||||||
|
assert "role" in message
|
||||||
|
assert "content" in message
|
||||||
|
assert isinstance(message["role"], MessageRole)
|
||||||
|
assert isinstance(message["content"], list)
|
||||||
|
assert len(message["content"]) == 1
|
||||||
|
for content in message["content"]:
|
||||||
|
assert isinstance(content, dict)
|
||||||
|
assert "type" in content
|
||||||
|
assert "text" in content
|
||||||
|
|
Loading…
Reference in New Issue