LiteLLMModel - detect message flattening based on model information (#553)

This commit is contained in:
Alex 2025-02-13 11:25:02 +01:00 committed by GitHub
parent 41a388dac6
commit 392fc5ade5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 12 additions and 1 deletions

View File

@ -22,6 +22,7 @@ import uuid
from copy import deepcopy from copy import deepcopy
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from enum import Enum from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from huggingface_hub import InferenceClient from huggingface_hub import InferenceClient
@ -799,6 +800,16 @@ class LiteLLMModel(Model):
self.api_key = api_key self.api_key = api_key
self.custom_role_conversions = custom_role_conversions self.custom_role_conversions = custom_role_conversions
@cached_property
def _flatten_messages_as_text(self) -> bool:
    """Decide whether chat messages should be flattened to plain text.

    The answer is computed on first access and then cached on the instance
    (``cached_property``), so litellm is queried at most once per model.

    Returns:
        ``True`` when litellm reports the ``ollama`` provider for
        ``self.model_id`` and the model key is not ``"llava"``; ``False``
        for every other provider. If litellm cannot resolve the model,
        the result is ``self.model_id.startswith("ollama")``.
    """
    import litellm

    try:
        model_info: dict = litellm.get_model_info(self.model_id)
    except Exception:
        # litellm raises a plain Exception for model ids that are missing
        # from its model map (e.g. custom or locally built models). Do not
        # let that break every call: fall back to the model-id prefix check.
        return self.model_id.startswith("ollama")

    if model_info.get("litellm_provider") == "ollama":
        # "llava" is the one ollama key exempted from flattening here.
        # NOTE(review): presumably because it accepts image content --
        # confirm against the message-cleaning code.
        return model_info.get("key") != "llava"
    return False
def __call__( def __call__(
self, self,
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
@ -818,7 +829,7 @@ class LiteLLMModel(Model):
api_base=self.api_base, api_base=self.api_base,
api_key=self.api_key, api_key=self.api_key,
convert_images_to_image_urls=True, convert_images_to_image_urls=True,
flatten_messages_as_text=self.model_id.startswith("ollama"), flatten_messages_as_text=self._flatten_messages_as_text,
custom_role_conversions=self.custom_role_conversions, custom_role_conversions=self.custom_role_conversions,
**kwargs, **kwargs,
) )