Remove dependency on _is_package_available from transformers (#247)

This commit is contained in:
Albert Villanova del Moral 2025-01-17 18:38:33 +01:00 committed by GitHub
parent 6db75183ff
commit 58b18f5655
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 20 additions and 24 deletions

View File

@ -32,7 +32,6 @@ from transformers import (
StoppingCriteriaList,
is_torch_available,
)
from transformers.utils.import_utils import _is_package_available
from .tools import Tool
@ -48,9 +47,6 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>",
}
if _is_package_available("litellm"):
import litellm
def get_dict_from_nested_dataclasses(obj):
def convert(obj):
@ -508,9 +504,11 @@ class LiteLLMModel(Model):
api_key=None,
**kwargs,
):
if not _is_package_available("litellm"):
raise ImportError(
"litellm not found. Install it with `pip install litellm`"
try:
import litellm
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
)
super().__init__()
self.model_id = model_id
@ -530,6 +528,8 @@ class LiteLLMModel(Model):
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
import litellm
if tools_to_call_from:
response = litellm.completion(
model=self.model_id,

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import logging
import os
import pathlib
@ -25,7 +26,6 @@ from transformers.utils import (
is_torch_available,
is_vision_available,
)
from transformers.utils.import_utils import _is_package_available
logger = logging.getLogger(__name__)
@ -41,9 +41,6 @@ if is_torch_available():
else:
Tensor = object
if _is_package_available("soundfile"):
import soundfile as sf
class AgentType:
"""
@ -187,11 +184,12 @@ class AgentAudio(AgentType, str):
"""
def __init__(self, value, samplerate=16_000):
if importlib.util.find_spec("soundfile") is None:
raise ModuleNotFoundError(
"Please install 'audio' extra to use AgentAudio: `pip install 'smolagents[audio]'`"
)
super().__init__(value)
if not _is_package_available("soundfile"):
raise ImportError("soundfile must be installed in order to handle audio.")
self._path = None
self._tensor = None
@ -221,6 +219,8 @@ class AgentAudio(AgentType, str):
"""
Returns the "raw" version of that object. It is a `torch.Tensor` object.
"""
import soundfile as sf
if self._tensor is not None:
return self._tensor
@ -239,6 +239,8 @@ class AgentAudio(AgentType, str):
Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
version of the audio.
"""
import soundfile as sf
if self._path is not None:
return self._path

View File

@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import importlib.util
import inspect
import json
import re
@ -22,13 +23,10 @@ import types
from typing import Dict, Tuple, Union
from rich.console import Console
from transformers.utils.import_utils import _is_package_available
_pygments_available = _is_package_available("pygments")
def is_pygments_available():
return _pygments_available
return importlib.util.find_spec("soundfile") is not None
console = Console(width=200)

View File

@ -24,15 +24,9 @@ from transformers.testing_utils import (
require_torch,
require_vision,
)
from transformers.utils.import_utils import (
_is_package_available,
)
from smolagents.types import AgentAudio, AgentImage, AgentText
if _is_package_available("soundfile"):
import soundfile as sf
def get_new_path(suffix="") -> str:
directory = tempfile.mkdtemp()
@ -43,6 +37,7 @@ def get_new_path(suffix="") -> str:
@require_torch
class AgentAudioTests(unittest.TestCase):
def test_from_tensor(self):
import soundfile as sf
import torch
tensor = torch.rand(12, dtype=torch.float64) - 0.5
@ -62,6 +57,7 @@ class AgentAudioTests(unittest.TestCase):
self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))
def test_from_string(self):
import soundfile as sf
import torch
tensor = torch.rand(12, dtype=torch.float64) - 0.5