diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 753c7c1..2dab05a 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -209,16 +209,18 @@ def get_clean_message_list( message["role"] = role_conversions[role] # encode images if needed if isinstance(message["content"], list): - for i, element in enumerate(message["content"]): + for element in message["content"]: if element["type"] == "image": assert not flatten_messages_as_text, f"Cannot use images with {flatten_messages_as_text=}" if convert_images_to_image_urls: - message["content"][i] = { - "type": "image_url", - "image_url": {"url": make_image_url(encode_image_base64(element["image"]))}, - } + element.update( + { + "type": "image_url", + "image_url": {"url": make_image_url(encode_image_base64(element.pop("image")))}, + } + ) else: - message["content"][i]["image"] = encode_image_base64(element["image"]) + element["image"] = encode_image_base64(element["image"]) if len(output_message_list) > 0 and message["role"] == output_message_list[-1]["role"]: assert isinstance(message["content"], list), "Error: wrong content:" + str(message["content"]) diff --git a/tests/test_models.py b/tests/test_models.py index aa9024c..4369844 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -17,12 +17,13 @@ import os import unittest from pathlib import Path from typing import Optional +from unittest.mock import patch import pytest from transformers.testing_utils import get_tests_dir from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool -from smolagents.models import parse_json_if_needed +from smolagents.models import get_clean_message_list, parse_json_if_needed class ModelTests(unittest.TestCase): @@ -100,3 +101,81 @@ class ModelTests(unittest.TestCase): args = 3 parsed_args = parse_json_if_needed(args) assert parsed_args == 3 + + +def test_get_clean_message_list_basic(): + messages = [ + {"role": "user", "content": [{"type": "text", "text": "Hello!"}]}, + {"role": "assistant", "content": [{"type": "text", "text": "Hi there!"}]}, + ] + result = get_clean_message_list(messages) + assert len(result) == 2 + assert result[0]["role"] == "user" + assert result[0]["content"][0]["text"] == "Hello!" + assert result[1]["role"] == "assistant" + assert result[1]["content"][0]["text"] == "Hi there!" + + +def test_get_clean_message_list_role_conversions(): + messages = [ + {"role": "tool-call", "content": [{"type": "text", "text": "Calling tool..."}]}, + {"role": "tool-response", "content": [{"type": "text", "text": "Tool response"}]}, + ] + result = get_clean_message_list(messages, role_conversions={"tool-call": "assistant", "tool-response": "user"}) + assert len(result) == 2 + assert result[0]["role"] == "assistant" + assert result[0]["content"][0]["text"] == "Calling tool..." + assert result[1]["role"] == "user" + assert result[1]["content"][0]["text"] == "Tool response" + + +@pytest.mark.parametrize( + "convert_images_to_image_urls, expected_clean_message", + [ + ( + False, + { + "role": "user", + "content": [ + {"type": "image", "image": "encoded_image"}, + {"type": "image", "image": "second_encoded_image"}, + ], + }, + ), + ( + True, + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "_image"}}, + {"type": "image_url", "image_url": {"url": "_encoded_image"}}, + ], + }, + ), + ], +) +def test_get_clean_message_list_image_encoding(convert_images_to_image_urls, expected_clean_message): + messages = [ + { + "role": "user", + "content": [{"type": "image", "image": b"image_data"}, {"type": "image", "image": b"second_image_data"}], + } + ] + with patch("smolagents.models.encode_image_base64") as mock_encode: + mock_encode.side_effect = ["encoded_image", "second_encoded_image"] + result = get_clean_message_list(messages, convert_images_to_image_urls=convert_images_to_image_urls) + mock_encode.assert_any_call(b"image_data") + mock_encode.assert_any_call(b"second_image_data") + assert len(result) == 1 + assert result[0] == expected_clean_message + + +def test_get_clean_message_list_flatten_messages_as_text(): + messages = [ + {"role": "user", "content": [{"type": "text", "text": "Hello!"}]}, + {"role": "user", "content": [{"type": "text", "text": "How are you?"}]}, + ] + result = get_clean_message_list(messages, flatten_messages_as_text=True) + assert len(result) == 1 + assert result[0]["role"] == "user" + assert result[0]["content"] == "Hello!How are you?"