diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 0a9ec6c..4f263b7 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -494,15 +494,20 @@ You have been provided with these additional arguments, that you can access usin if is_first_step: message_prompt_facts = { "role": MessageRole.SYSTEM, - "content": SYSTEM_PROMPT_FACTS, + "content": [{"type": "text", "text": SYSTEM_PROMPT_FACTS}], } message_prompt_task = { "role": MessageRole.USER, - "content": f"""Here is the task: + "content": [ + { + "type": "text", + "text": f"""Here is the task: ``` {task} ``` Now begin!""", + } + ], } input_messages = [message_prompt_facts, message_prompt_task] @@ -511,16 +516,21 @@ Now begin!""", message_system_prompt_plan = { "role": MessageRole.SYSTEM, - "content": SYSTEM_PROMPT_PLAN, + "content": [{"type": "text", "text": SYSTEM_PROMPT_PLAN}], } message_user_prompt_plan = { "role": MessageRole.USER, - "content": USER_PROMPT_PLAN.format( - task=task, - tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template), - managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)), - answer_facts=answer_facts, - ), + "content": [ + { + "type": "text", + "text": USER_PROMPT_PLAN.format( + task=task, + tool_descriptions=get_tool_descriptions(self.tools, self.tool_description_template), + managed_agents_descriptions=(show_agents_descriptions(self.managed_agents)), + answer_facts=answer_facts, + ), + } + ], } chat_message_plan: ChatMessage = self.model( [message_system_prompt_plan, message_user_prompt_plan], diff --git a/tests/test_agents.py b/tests/test_agents.py index 53f2cfd..f1d8694 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -17,6 +17,7 @@ import tempfile import unittest import uuid from pathlib import Path +from unittest.mock import MagicMock from transformers.testing_utils import get_tests_dir @@ -25,11 +26,19 @@ from smolagents.agents import ( AgentMaxStepsError, CodeAgent, ManagedAgent, + 
class TestMultiStepAgent:
    """Checks the chat messages built by ``MultiStepAgent.planning_step``."""

    @staticmethod
    def _check_text_messages(candidate):
        """Assert that ``candidate`` is a list of two single-part chat messages.

        Each message must be a dict with a ``MessageRole`` under ``"role"`` and,
        under ``"content"``, a one-element list whose element is a dict
        carrying both a ``"type"`` and a ``"text"`` key.
        """
        assert isinstance(candidate, list)
        assert len(candidate) == 2
        for entry in candidate:
            assert isinstance(entry, dict)
            assert "role" in entry
            assert "content" in entry
            assert isinstance(entry["role"], MessageRole)
            parts = entry["content"]
            assert isinstance(parts, list)
            assert len(parts) == 1
            for part in parts:
                assert isinstance(part, dict)
                assert "type" in part
                assert "text" in part

    def test_planning_step_first_step(self):
        """Run the first planning step against a mocked model and inspect messages.

        Verifies both the messages stored on the recorded ``PlanningStep`` and
        the messages passed positionally to each of the two model calls.
        """
        model_stub = MagicMock()
        agent = MultiStepAgent(tools=[], model=model_stub)
        agent.planning_step("Test task", is_first_step=True, step=0)

        # Exactly one step is expected in memory, and it must be a PlanningStep.
        recorded_steps = agent.memory.steps
        assert len(recorded_steps) == 1
        first_step = recorded_steps[0]
        assert isinstance(first_step, PlanningStep)
        self._check_text_messages(first_step.model_input_messages)

        # Test calls to model
        model_calls = model_stub.call_args_list
        assert len(model_calls) == 2
        for call in model_calls:
            positional = call.args
            # Each call must receive the message list as its only positional arg.
            assert len(positional) == 1
            self._check_text_messages(positional[0])