Test HfApiModel call with custom_role_conversions (#517)
This commit is contained in:
parent
ec8e830e7b
commit
c4bd41d39c
|
@ -17,13 +17,13 @@ import os
|
|||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
|
||||
from smolagents.models import get_clean_message_list, parse_json_if_needed
|
||||
from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed
|
||||
|
||||
|
||||
class ModelTests(unittest.TestCase):
|
||||
|
@ -103,6 +103,19 @@ class ModelTests(unittest.TestCase):
|
|||
assert parsed_args == 3
|
||||
|
||||
|
||||
class TestHfApiModel:
|
||||
def test_call_with_custom_role_conversions(self):
|
||||
custom_role_conversions = {MessageRole.USER: MessageRole.SYSTEM}
|
||||
model = HfApiModel(model_id="test-model", custom_role_conversions=custom_role_conversions)
|
||||
model.client = MagicMock()
|
||||
messages = [{"role": "user", "content": "Test message"}]
|
||||
_ = model(messages)
|
||||
# Verify that the role conversion was applied
|
||||
assert model.client.chat_completion.call_args.kwargs["messages"][0]["role"] == "system", (
|
||||
"role conversion should be applied"
|
||||
)
|
||||
|
||||
|
||||
def test_get_clean_message_list_basic():
|
||||
messages = [
|
||||
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
|
||||
|
|
Loading…
Reference in New Issue