LiteLLMModel - detect message flattening based on model information (#553)

Alex 2025-02-13 11:25:02 +01:00 committed by GitHub
parent 41a388dac6
commit 392fc5ade5
1 changed file with 12 additions and 1 deletion


@@ -22,6 +22,7 @@ import uuid
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from enum import Enum
+from functools import cached_property
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from huggingface_hub import InferenceClient
@@ -799,6 +800,16 @@ class LiteLLMModel(Model):
         self.api_key = api_key
         self.custom_role_conversions = custom_role_conversions
 
+    @cached_property
+    def _flatten_messages_as_text(self):
+        import litellm
+
+        model_info: dict = litellm.get_model_info(self.model_id)
+        if model_info["litellm_provider"] == "ollama":
+            return model_info["key"] != "llava"
+        return False
+
     def __call__(
         self,
         messages: List[Dict[str, str]],
@@ -818,7 +829,7 @@ class LiteLLMModel(Model):
             api_base=self.api_base,
             api_key=self.api_key,
             convert_images_to_image_urls=True,
-            flatten_messages_as_text=self.model_id.startswith("ollama"),
+            flatten_messages_as_text=self._flatten_messages_as_text,
             custom_role_conversions=self.custom_role_conversions,
             **kwargs,
         )