Test HfApiModel call with custom_role_conversions (#517)
This commit is contained in:
parent
ec8e830e7b
commit
c4bd41d39c
|
@ -17,13 +17,13 @@ import os
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from unittest.mock import patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from transformers.testing_utils import get_tests_dir
|
from transformers.testing_utils import get_tests_dir
|
||||||
|
|
||||||
from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
|
from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
|
||||||
from smolagents.models import get_clean_message_list, parse_json_if_needed
|
from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed
|
||||||
|
|
||||||
|
|
||||||
class ModelTests(unittest.TestCase):
|
class ModelTests(unittest.TestCase):
|
||||||
|
@ -103,6 +103,19 @@ class ModelTests(unittest.TestCase):
|
||||||
assert parsed_args == 3
|
assert parsed_args == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestHfApiModel:
|
||||||
|
def test_call_with_custom_role_conversions(self):
|
||||||
|
custom_role_conversions = {MessageRole.USER: MessageRole.SYSTEM}
|
||||||
|
model = HfApiModel(model_id="test-model", custom_role_conversions=custom_role_conversions)
|
||||||
|
model.client = MagicMock()
|
||||||
|
messages = [{"role": "user", "content": "Test message"}]
|
||||||
|
_ = model(messages)
|
||||||
|
# Verify that the role conversion was applied
|
||||||
|
assert model.client.chat_completion.call_args.kwargs["messages"][0]["role"] == "system", (
|
||||||
|
"role conversion should be applied"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_clean_message_list_basic():
|
def test_get_clean_message_list_basic():
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
|
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
|
||||||
|
|
Loading…
Reference in New Issue