From 58b18f56552160842a3783cc7f4560f2069c282b Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Fri, 17 Jan 2025 18:38:33 +0100 Subject: [PATCH] Remove dependency on _is_package_available from transformers (#247) --- src/smolagents/models.py | 14 +++++++------- src/smolagents/types.py | 16 +++++++++------- src/smolagents/utils.py | 6 ++---- tests/test_types.py | 8 ++------ 4 files changed, 20 insertions(+), 24 deletions(-) diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 0910c98..ca234f2 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -32,7 +32,6 @@ from transformers import ( StoppingCriteriaList, is_torch_available, ) -from transformers.utils.import_utils import _is_package_available from .tools import Tool @@ -48,9 +47,6 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = { "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```", } -if _is_package_available("litellm"): - import litellm - def get_dict_from_nested_dataclasses(obj): def convert(obj): @@ -508,9 +504,11 @@ class LiteLLMModel(Model): api_key=None, **kwargs, ): - if not _is_package_available("litellm"): - raise ImportError( - "litellm not found. Install it with `pip install litellm`" + try: + import litellm + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`" ) super().__init__() self.model_id = model_id @@ -530,6 +528,8 @@ class LiteLLMModel(Model): messages = get_clean_message_list( messages, role_conversions=tool_role_conversions ) + import litellm + if tools_to_call_from: response = litellm.completion( model=self.model_id, diff --git a/src/smolagents/types.py b/src/smolagents/types.py index d88293f..038885f 100644 --- a/src/smolagents/types.py +++ b/src/smolagents/types.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib.util import logging import os import pathlib @@ -25,7 +26,6 @@ from transformers.utils import ( is_torch_available, is_vision_available, ) -from transformers.utils.import_utils import _is_package_available logger = logging.getLogger(__name__) @@ -41,9 +41,6 @@ if is_torch_available(): else: Tensor = object -if _is_package_available("soundfile"): - import soundfile as sf - class AgentType: """ @@ -187,11 +184,12 @@ class AgentAudio(AgentType, str): """ def __init__(self, value, samplerate=16_000): + if importlib.util.find_spec("soundfile") is None: + raise ModuleNotFoundError( + "Please install 'audio' extra to use AgentAudio: `pip install 'smolagents[audio]'`" + ) super().__init__(value) - if not _is_package_available("soundfile"): - raise ImportError("soundfile must be installed in order to handle audio.") - self._path = None self._tensor = None @@ -221,6 +219,8 @@ class AgentAudio(AgentType, str): """ Returns the "raw" version of that object. It is a `torch.Tensor` object. """ + import soundfile as sf + if self._tensor is not None: return self._tensor @@ -239,6 +239,8 @@ class AgentAudio(AgentType, str): Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized version of the audio. """ + import soundfile as sf + if self._path is not None: return self._path diff --git a/src/smolagents/utils.py b/src/smolagents/utils.py index e3ea23c..8e9ccc8 100644 --- a/src/smolagents/utils.py +++ b/src/smolagents/utils.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import ast +import importlib.util import inspect import json import re @@ -22,13 +23,10 @@ import types from typing import Dict, Tuple, Union from rich.console import Console -from transformers.utils.import_utils import _is_package_available - -_pygments_available = _is_package_available("pygments") def is_pygments_available(): - return _pygments_available + return importlib.util.find_spec("soundfile") is not None console = Console(width=200) diff --git a/tests/test_types.py b/tests/test_types.py index aa58a8f..244875c 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -24,15 +24,9 @@ from transformers.testing_utils import ( require_torch, require_vision, ) -from transformers.utils.import_utils import ( - _is_package_available, -) from smolagents.types import AgentAudio, AgentImage, AgentText -if _is_package_available("soundfile"): - import soundfile as sf - def get_new_path(suffix="") -> str: directory = tempfile.mkdtemp() @@ -43,6 +37,7 @@ def get_new_path(suffix="") -> str: @require_torch class AgentAudioTests(unittest.TestCase): def test_from_tensor(self): + import soundfile as sf import torch tensor = torch.rand(12, dtype=torch.float64) - 0.5 @@ -62,6 +57,7 @@ class AgentAudioTests(unittest.TestCase): self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4)) def test_from_string(self): + import soundfile as sf import torch tensor = torch.rand(12, dtype=torch.float64) - 0.5