Test get_clean_message_list (#448)

* Test get_clean_message_list

* Test get_clean_message_list
This commit is contained in:
Albert Villanova del Moral 2025-01-31 13:47:02 +01:00 committed by GitHub
parent b228ffa328
commit cedf63cde7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 88 additions and 7 deletions

View File

@ -209,16 +209,18 @@ def get_clean_message_list(
message["role"] = role_conversions[role]
# encode images if needed
if isinstance(message["content"], list):
for i, element in enumerate(message["content"]):
for element in message["content"]:
if element["type"] == "image":
assert not flatten_messages_as_text, f"Cannot use images with {flatten_messages_as_text=}"
if convert_images_to_image_urls:
message["content"][i] = {
"type": "image_url",
"image_url": {"url": make_image_url(encode_image_base64(element["image"]))},
}
element.update(
{
"type": "image_url",
"image_url": {"url": make_image_url(encode_image_base64(element.pop("image")))},
}
)
else:
message["content"][i]["image"] = encode_image_base64(element["image"])
element["image"] = encode_image_base64(element["image"])
if len(output_message_list) > 0 and message["role"] == output_message_list[-1]["role"]:
assert isinstance(message["content"], list), "Error: wrong content:" + str(message["content"])

View File

@ -17,12 +17,13 @@ import os
import unittest
from pathlib import Path
from typing import Optional
from unittest.mock import patch
import pytest
from transformers.testing_utils import get_tests_dir
from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
from smolagents.models import parse_json_if_needed
from smolagents.models import get_clean_message_list, parse_json_if_needed
class ModelTests(unittest.TestCase):
@ -100,3 +101,81 @@ class ModelTests(unittest.TestCase):
args = 3
parsed_args = parse_json_if_needed(args)
assert parsed_args == 3
def test_get_clean_message_list_basic():
messages = [
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
{"role": "assistant", "content": [{"type": "text", "text": "Hi there!"}]},
]
result = get_clean_message_list(messages)
assert len(result) == 2
assert result[0]["role"] == "user"
assert result[0]["content"][0]["text"] == "Hello!"
assert result[1]["role"] == "assistant"
assert result[1]["content"][0]["text"] == "Hi there!"
def test_get_clean_message_list_role_conversions():
messages = [
{"role": "tool-call", "content": [{"type": "text", "text": "Calling tool..."}]},
{"role": "tool-response", "content": [{"type": "text", "text": "Tool response"}]},
]
result = get_clean_message_list(messages, role_conversions={"tool-call": "assistant", "tool-response": "user"})
assert len(result) == 2
assert result[0]["role"] == "assistant"
assert result[0]["content"][0]["text"] == "Calling tool..."
assert result[1]["role"] == "user"
assert result[1]["content"][0]["text"] == "Tool response"
@pytest.mark.parametrize(
"convert_images_to_image_urls, expected_clean_message",
[
(
False,
{
"role": "user",
"content": [
{"type": "image", "image": "encoded_image"},
{"type": "image", "image": "second_encoded_image"},
],
},
),
(
True,
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": "_image"}},
{"type": "image_url", "image_url": {"url": "_encoded_image"}},
],
},
),
],
)
def test_get_clean_message_list_image_encoding(convert_images_to_image_urls, expected_clean_message):
messages = [
{
"role": "user",
"content": [{"type": "image", "image": b"image_data"}, {"type": "image", "image": b"second_image_data"}],
}
]
with patch("smolagents.models.encode_image_base64") as mock_encode:
mock_encode.side_effect = ["encoded_image", "second_encoded_image"]
result = get_clean_message_list(messages, convert_images_to_image_urls=convert_images_to_image_urls)
mock_encode.assert_any_call(b"image_data")
mock_encode.assert_any_call(b"second_image_data")
assert len(result) == 1
assert result[0] == expected_clean_message
def test_get_clean_message_list_flatten_messages_as_text():
messages = [
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
{"role": "user", "content": [{"type": "text", "text": "How are you?"}]},
]
result = get_clean_message_list(messages, flatten_messages_as_text=True)
assert len(result) == 1
assert result[0]["role"] == "user"
assert result[0]["content"] == "Hello!How are you?"