Fix MultiStepAgent.planning_step message content (#437)

* Fix MultiStepAgent.planning_step message content
This commit is contained in:
Albert Villanova del Moral 2025-01-30 21:18:30 +01:00 committed by GitHub
parent 6d0e4e49fc
commit 42d97716fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 75 additions and 10 deletions

View File

@ -494,16 +494,21 @@ You have been provided with these additional arguments, that you can access usin
if is_first_step:
message_prompt_facts = {
"role": MessageRole.SYSTEM,
"content": SYSTEM_PROMPT_FACTS,
"content": [{"type": "text", "text": SYSTEM_PROMPT_FACTS}],
}
message_prompt_task = {
"role": MessageRole.USER,
"content": f"""Here is the task:
"content": [
{
"type": "text",
"text": f"""Here is the task:
```
{task}
```
Now begin!""",
}
],
}
input_messages = [message_prompt_facts, message_prompt_task]
chat_message_facts: ChatMessage = self.model(input_messages)
@ -511,17 +516,22 @@ Now begin!""",
message_system_prompt_plan = {
"role": MessageRole.SYSTEM,
"content": SYSTEM_PROMPT_PLAN,
"content": [{"type": "text", "text": SYSTEM_PROMPT_PLAN}],
}
message_user_prompt_plan = {
"role": MessageRole.USER,
"content": USER_PROMPT_PLAN.format(
"content": [
{
"type": "text",
"text": USER_PROMPT_PLAN.format(
task=task,
tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template),
managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)),
answer_facts=answer_facts,
),
}
],
}
chat_message_plan: ChatMessage = self.model(
[message_system_prompt_plan, message_user_prompt_plan],
stop_sequences=["<end_plan>"],

View File

@ -17,6 +17,7 @@ import tempfile
import unittest
import uuid
from pathlib import Path
from unittest.mock import MagicMock
from transformers.testing_utils import get_tests_dir
@ -25,11 +26,19 @@ from smolagents.agents import (
AgentMaxStepsError,
CodeAgent,
ManagedAgent,
MultiStepAgent,
ToolCall,
ToolCallingAgent,
)
from smolagents.default_tools import PythonInterpreterTool
from smolagents.models import ChatMessage, ChatMessageToolCall, ChatMessageToolCallDefinition, TransformersModel
from smolagents.memory import PlanningStep
from smolagents.models import (
ChatMessage,
ChatMessageToolCall,
ChatMessageToolCallDefinition,
MessageRole,
TransformersModel,
)
from smolagents.tools import tool
from smolagents.utils import BASE_BUILTIN_MODULES
@ -644,3 +653,49 @@ nested_answer()
assert step_memory_dict["model_output_message"].tool_calls[0].function.name == "get_weather"
assert step_memory_dict["model_output_message"].raw["completion_kwargs"]["max_new_tokens"] == 100
assert "model_input_messages" in agent.memory.get_full_steps()[1]
class TestMultiStepAgent:
def test_planning_step_first_step(self):
fake_model = MagicMock()
agent = MultiStepAgent(
tools=[],
model=fake_model,
)
task = "Test task"
agent.planning_step(task, is_first_step=True, step=0)
assert len(agent.memory.steps) == 1
planning_step = agent.memory.steps[0]
assert isinstance(planning_step, PlanningStep)
messages = planning_step.model_input_messages
assert isinstance(messages, list)
assert len(messages) == 2
for message in messages:
assert isinstance(message, dict)
assert "role" in message
assert "content" in message
assert isinstance(message["role"], MessageRole)
assert isinstance(message["content"], list)
assert len(message["content"]) == 1
for content in message["content"]:
assert isinstance(content, dict)
assert "type" in content
assert "text" in content
# Test calls to model
assert len(fake_model.call_args_list) == 2
for call_args in fake_model.call_args_list:
assert len(call_args.args) == 1
messages = call_args.args[0]
assert isinstance(messages, list)
assert len(messages) == 2
for message in messages:
assert isinstance(message, dict)
assert "role" in message
assert "content" in message
assert isinstance(message["role"], MessageRole)
assert isinstance(message["content"], list)
assert len(message["content"]) == 1
for content in message["content"]:
assert isinstance(content, dict)
assert "type" in content
assert "text" in content