Test HfApiModel call with custom_role_conversions (#517)

This commit is contained in:
Albert Villanova del Moral 2025-02-07 10:56:42 +01:00 committed by GitHub
parent ec8e830e7b
commit c4bd41d39c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 15 additions and 2 deletions

View File

@ -17,13 +17,13 @@ import os
import unittest
from pathlib import Path
from typing import Optional
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from transformers.testing_utils import get_tests_dir
from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
from smolagents.models import get_clean_message_list, parse_json_if_needed
from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed
class ModelTests(unittest.TestCase):
@ -103,6 +103,19 @@ class ModelTests(unittest.TestCase):
assert parsed_args == 3
class TestHfApiModel:
def test_call_with_custom_role_conversions(self):
custom_role_conversions = {MessageRole.USER: MessageRole.SYSTEM}
model = HfApiModel(model_id="test-model", custom_role_conversions=custom_role_conversions)
model.client = MagicMock()
messages = [{"role": "user", "content": "Test message"}]
_ = model(messages)
# Verify that the role conversion was applied
assert model.client.chat_completion.call_args.kwargs["messages"][0]["role"] == "system", (
"role conversion should be applied"
)
def test_get_clean_message_list_basic():
messages = [
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},