Remove dependency on _is_package_available from transformers (#247)

This commit is contained in:
Albert Villanova del Moral 2025-01-17 18:38:33 +01:00 committed by GitHub
parent 6db75183ff
commit 58b18f5655
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 20 additions and 24 deletions

View File

@ -32,7 +32,6 @@ from transformers import (
StoppingCriteriaList, StoppingCriteriaList,
is_torch_available, is_torch_available,
) )
from transformers.utils.import_utils import _is_package_available
from .tools import Tool from .tools import Tool
@ -48,9 +47,6 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>", "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>",
} }
if _is_package_available("litellm"):
import litellm
def get_dict_from_nested_dataclasses(obj): def get_dict_from_nested_dataclasses(obj):
def convert(obj): def convert(obj):
@ -508,9 +504,11 @@ class LiteLLMModel(Model):
api_key=None, api_key=None,
**kwargs, **kwargs,
): ):
if not _is_package_available("litellm"): try:
raise ImportError( import litellm
"litellm not found. Install it with `pip install litellm`" except ModuleNotFoundError:
raise ModuleNotFoundError(
"Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
) )
super().__init__() super().__init__()
self.model_id = model_id self.model_id = model_id
@ -530,6 +528,8 @@ class LiteLLMModel(Model):
messages = get_clean_message_list( messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions messages, role_conversions=tool_role_conversions
) )
import litellm
if tools_to_call_from: if tools_to_call_from:
response = litellm.completion( response = litellm.completion(
model=self.model_id, model=self.model_id,

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib.util
import logging import logging
import os import os
import pathlib import pathlib
@ -25,7 +26,6 @@ from transformers.utils import (
is_torch_available, is_torch_available,
is_vision_available, is_vision_available,
) )
from transformers.utils.import_utils import _is_package_available
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,9 +41,6 @@ if is_torch_available():
else: else:
Tensor = object Tensor = object
if _is_package_available("soundfile"):
import soundfile as sf
class AgentType: class AgentType:
""" """
@ -187,11 +184,12 @@ class AgentAudio(AgentType, str):
""" """
def __init__(self, value, samplerate=16_000): def __init__(self, value, samplerate=16_000):
if importlib.util.find_spec("soundfile") is None:
raise ModuleNotFoundError(
"Please install 'audio' extra to use AgentAudio: `pip install 'smolagents[audio]'`"
)
super().__init__(value) super().__init__(value)
if not _is_package_available("soundfile"):
raise ImportError("soundfile must be installed in order to handle audio.")
self._path = None self._path = None
self._tensor = None self._tensor = None
@ -221,6 +219,8 @@ class AgentAudio(AgentType, str):
""" """
Returns the "raw" version of that object. It is a `torch.Tensor` object. Returns the "raw" version of that object. It is a `torch.Tensor` object.
""" """
import soundfile as sf
if self._tensor is not None: if self._tensor is not None:
return self._tensor return self._tensor
@ -239,6 +239,8 @@ class AgentAudio(AgentType, str):
Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
version of the audio. version of the audio.
""" """
import soundfile as sf
if self._path is not None: if self._path is not None:
return self._path return self._path

View File

@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import ast import ast
import importlib.util
import inspect import inspect
import json import json
import re import re
@ -22,13 +23,10 @@ import types
from typing import Dict, Tuple, Union from typing import Dict, Tuple, Union
from rich.console import Console from rich.console import Console
from transformers.utils.import_utils import _is_package_available
_pygments_available = _is_package_available("pygments")
def is_pygments_available(): def is_pygments_available():
return _pygments_available return importlib.util.find_spec("soundfile") is not None
console = Console(width=200) console = Console(width=200)

View File

@ -24,15 +24,9 @@ from transformers.testing_utils import (
require_torch, require_torch,
require_vision, require_vision,
) )
from transformers.utils.import_utils import (
_is_package_available,
)
from smolagents.types import AgentAudio, AgentImage, AgentText from smolagents.types import AgentAudio, AgentImage, AgentText
if _is_package_available("soundfile"):
import soundfile as sf
def get_new_path(suffix="") -> str: def get_new_path(suffix="") -> str:
directory = tempfile.mkdtemp() directory = tempfile.mkdtemp()
@ -43,6 +37,7 @@ def get_new_path(suffix="") -> str:
@require_torch @require_torch
class AgentAudioTests(unittest.TestCase): class AgentAudioTests(unittest.TestCase):
def test_from_tensor(self): def test_from_tensor(self):
import soundfile as sf
import torch import torch
tensor = torch.rand(12, dtype=torch.float64) - 0.5 tensor = torch.rand(12, dtype=torch.float64) - 0.5
@ -62,6 +57,7 @@ class AgentAudioTests(unittest.TestCase):
self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4)) self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))
def test_from_string(self): def test_from_string(self):
import soundfile as sf
import torch import torch
tensor = torch.rand(12, dtype=torch.float64) - 0.5 tensor = torch.rand(12, dtype=torch.float64) - 0.5