LiteLLMModel - detect message flattening based on model information (#553)
This commit is contained in:
parent
41a388dac6
commit
392fc5ade5
|
@ -22,6 +22,7 @@ import uuid
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from functools import cached_property
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from huggingface_hub import InferenceClient
|
from huggingface_hub import InferenceClient
|
||||||
|
@ -799,6 +800,16 @@ class LiteLLMModel(Model):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.custom_role_conversions = custom_role_conversions
|
self.custom_role_conversions = custom_role_conversions
|
||||||
|
|
||||||
|
@cached_property
def _flatten_messages_as_text(self) -> bool:
    """Decide whether chat messages must be flattened to plain text.

    Queries litellm's model registry for the configured ``model_id``.
    Only Ollama-served models need flattening, with the exception of
    ``llava``, which accepts structured (image) message content.

    Returns:
        bool: True when messages should be flattened to text, False otherwise.
    """
    import litellm

    model_info: dict = litellm.get_model_info(self.model_id)
    # Flatten only for Ollama models; llava is multimodal and keeps
    # structured message content, so it is excluded.
    served_by_ollama = model_info["litellm_provider"] == "ollama"
    return served_by_ollama and model_info["key"] != "llava"
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
|
@ -818,7 +829,7 @@ class LiteLLMModel(Model):
|
||||||
api_base=self.api_base,
|
api_base=self.api_base,
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
convert_images_to_image_urls=True,
|
convert_images_to_image_urls=True,
|
||||||
flatten_messages_as_text=self.model_id.startswith("ollama"),
|
flatten_messages_as_text=self._flatten_messages_as_text,
|
||||||
custom_role_conversions=self.custom_role_conversions,
|
custom_role_conversions=self.custom_role_conversions,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue