From db2df942414d9cf7db67aff873057816fc9c72e7 Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Sat, 15 Feb 2025 11:58:39 +0100 Subject: [PATCH] Fix litellm flatten_messages_as_text detection (#659) --- src/smolagents/models.py | 16 ++++++---------- tests/test_models.py | 4 ++-- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/smolagents/models.py b/src/smolagents/models.py index cb825b4..3159cb4 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -842,17 +842,8 @@ class LiteLLMModel(Model): custom_role_conversions: Optional[Dict[str, str]] = None, **kwargs, ): - try: - import litellm - except ModuleNotFoundError: - raise ModuleNotFoundError( - "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`" - ) - super().__init__(**kwargs) self.model_id = model_id - # IMPORTANT - Set this to TRUE to add the function to the prompt for Non OpenAI LLMs - litellm.add_function_to_prompt = True self.api_base = api_base self.api_key = api_key self.custom_role_conversions = custom_role_conversions @@ -870,7 +861,12 @@ class LiteLLMModel(Model): tools_to_call_from: Optional[List[Tool]] = None, **kwargs, ) -> ChatMessage: - import litellm + try: + import litellm + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`" + ) completion_kwargs = self._prepare_completion_kwargs( messages=messages, diff --git a/tests/test_models.py b/tests/test_models.py index a77c956..10f36ae 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -148,8 +148,8 @@ class TestLiteLLMModel: "model_id, error_flag", [ ("groq/llama-3.3-70b", "Missing API Key"), - ("cerebras/llama-3.3-70b", "Wrong API Key"), - ("ollama/llama2", "not found"), + ("cerebras/llama-3.3-70b", "The api_key client option must be set"), + ("mistral/mistral-tiny", "The api_key client option must be set"), ], ) def test_call_different_providers_without_key(self, model_id, error_flag):