Test get_clean_message_list (#448)
* Test get_clean_message_list * Test get_clean_message_list
This commit is contained in:
parent
b228ffa328
commit
cedf63cde7
|
@ -209,16 +209,18 @@ def get_clean_message_list(
|
|||
message["role"] = role_conversions[role]
|
||||
# encode images if needed
|
||||
if isinstance(message["content"], list):
|
||||
for i, element in enumerate(message["content"]):
|
||||
for element in message["content"]:
|
||||
if element["type"] == "image":
|
||||
assert not flatten_messages_as_text, f"Cannot use images with {flatten_messages_as_text=}"
|
||||
if convert_images_to_image_urls:
|
||||
message["content"][i] = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": make_image_url(encode_image_base64(element["image"]))},
|
||||
}
|
||||
element.update(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": make_image_url(encode_image_base64(element.pop("image")))},
|
||||
}
|
||||
)
|
||||
else:
|
||||
message["content"][i]["image"] = encode_image_base64(element["image"])
|
||||
element["image"] = encode_image_base64(element["image"])
|
||||
|
||||
if len(output_message_list) > 0 and message["role"] == output_message_list[-1]["role"]:
|
||||
assert isinstance(message["content"], list), "Error: wrong content:" + str(message["content"])
|
||||
|
|
|
@ -17,12 +17,13 @@ import os
|
|||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
|
||||
from smolagents.models import parse_json_if_needed
|
||||
from smolagents.models import get_clean_message_list, parse_json_if_needed
|
||||
|
||||
|
||||
class ModelTests(unittest.TestCase):
|
||||
|
@ -100,3 +101,81 @@ class ModelTests(unittest.TestCase):
|
|||
args = 3
|
||||
parsed_args = parse_json_if_needed(args)
|
||||
assert parsed_args == 3
|
||||
|
||||
|
||||
def test_get_clean_message_list_basic():
|
||||
messages = [
|
||||
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": "Hi there!"}]},
|
||||
]
|
||||
result = get_clean_message_list(messages)
|
||||
assert len(result) == 2
|
||||
assert result[0]["role"] == "user"
|
||||
assert result[0]["content"][0]["text"] == "Hello!"
|
||||
assert result[1]["role"] == "assistant"
|
||||
assert result[1]["content"][0]["text"] == "Hi there!"
|
||||
|
||||
|
||||
def test_get_clean_message_list_role_conversions():
|
||||
messages = [
|
||||
{"role": "tool-call", "content": [{"type": "text", "text": "Calling tool..."}]},
|
||||
{"role": "tool-response", "content": [{"type": "text", "text": "Tool response"}]},
|
||||
]
|
||||
result = get_clean_message_list(messages, role_conversions={"tool-call": "assistant", "tool-response": "user"})
|
||||
assert len(result) == 2
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert result[0]["content"][0]["text"] == "Calling tool..."
|
||||
assert result[1]["role"] == "user"
|
||||
assert result[1]["content"][0]["text"] == "Tool response"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"convert_images_to_image_urls, expected_clean_message",
|
||||
[
|
||||
(
|
||||
False,
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": "encoded_image"},
|
||||
{"type": "image", "image": "second_encoded_image"},
|
||||
],
|
||||
},
|
||||
),
|
||||
(
|
||||
True,
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,encoded_image"}},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,second_encoded_image"}},
|
||||
],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_clean_message_list_image_encoding(convert_images_to_image_urls, expected_clean_message):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "image", "image": b"image_data"}, {"type": "image", "image": b"second_image_data"}],
|
||||
}
|
||||
]
|
||||
with patch("smolagents.models.encode_image_base64") as mock_encode:
|
||||
mock_encode.side_effect = ["encoded_image", "second_encoded_image"]
|
||||
result = get_clean_message_list(messages, convert_images_to_image_urls=convert_images_to_image_urls)
|
||||
mock_encode.assert_any_call(b"image_data")
|
||||
mock_encode.assert_any_call(b"second_image_data")
|
||||
assert len(result) == 1
|
||||
assert result[0] == expected_clean_message
|
||||
|
||||
|
||||
def test_get_clean_message_list_flatten_messages_as_text():
|
||||
messages = [
|
||||
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
|
||||
{"role": "user", "content": [{"type": "text", "text": "How are you?"}]},
|
||||
]
|
||||
result = get_clean_message_list(messages, flatten_messages_as_text=True)
|
||||
assert len(result) == 1
|
||||
assert result[0]["role"] == "user"
|
||||
assert result[0]["content"] == "Hello!How are you?"
|
||||
|
|
Loading…
Reference in New Issue