Fix litellm flatten_messages_as_text detection (#659)

This commit is contained in:
Aymeric Roucher 2025-02-15 11:58:39 +01:00 committed by GitHub
parent 161ef452c3
commit db2df94241
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 12 deletions

View File

@ -842,17 +842,8 @@ class LiteLLMModel(Model):
custom_role_conversions: Optional[Dict[str, str]] = None, custom_role_conversions: Optional[Dict[str, str]] = None,
**kwargs, **kwargs,
): ):
try:
import litellm
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
)
super().__init__(**kwargs) super().__init__(**kwargs)
self.model_id = model_id self.model_id = model_id
# IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs
litellm.add_function_to_prompt = True
self.api_base = api_base self.api_base = api_base
self.api_key = api_key self.api_key = api_key
self.custom_role_conversions = custom_role_conversions self.custom_role_conversions = custom_role_conversions
@ -870,7 +861,12 @@ class LiteLLMModel(Model):
tools_to_call_from: Optional[List[Tool]] = None, tools_to_call_from: Optional[List[Tool]] = None,
**kwargs, **kwargs,
) -> ChatMessage: ) -> ChatMessage:
import litellm try:
import litellm
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
)
completion_kwargs = self._prepare_completion_kwargs( completion_kwargs = self._prepare_completion_kwargs(
messages=messages, messages=messages,

View File

@ -148,8 +148,8 @@ class TestLiteLLMModel:
"model_id, error_flag", "model_id, error_flag",
[ [
("groq/llama-3.3-70b", "Missing API Key"), ("groq/llama-3.3-70b", "Missing API Key"),
("cerebras/llama-3.3-70b", "Wrong API Key"), ("cerebras/llama-3.3-70b", "The api_key client option must be set"),
("ollama/llama2", "not found"), ("mistral/mistral-tiny", "The api_key client option must be set"),
], ],
) )
def test_call_different_providers_without_key(self, model_id, error_flag): def test_call_different_providers_without_key(self, model_id, error_flag):