Remove dependency on _is_package_available from transformers (#247)
This commit is contained in:
parent
6db75183ff
commit
58b18f5655
|
@ -32,7 +32,6 @@ from transformers import (
|
|||
StoppingCriteriaList,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.utils.import_utils import _is_package_available
|
||||
|
||||
from .tools import Tool
|
||||
|
||||
|
@ -48,9 +47,6 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
|
|||
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>",
|
||||
}
|
||||
|
||||
if _is_package_available("litellm"):
|
||||
import litellm
|
||||
|
||||
|
||||
def get_dict_from_nested_dataclasses(obj):
|
||||
def convert(obj):
|
||||
|
@ -508,9 +504,11 @@ class LiteLLMModel(Model):
|
|||
api_key=None,
|
||||
**kwargs,
|
||||
):
|
||||
if not _is_package_available("litellm"):
|
||||
raise ImportError(
|
||||
"litellm not found. Install it with `pip install litellm`"
|
||||
try:
|
||||
import litellm
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(
|
||||
"Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
|
||||
)
|
||||
super().__init__()
|
||||
self.model_id = model_id
|
||||
|
@ -530,6 +528,8 @@ class LiteLLMModel(Model):
|
|||
messages = get_clean_message_list(
|
||||
messages, role_conversions=tool_role_conversions
|
||||
)
|
||||
import litellm
|
||||
|
||||
if tools_to_call_from:
|
||||
response = litellm.completion(
|
||||
model=self.model_id,
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib.util
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
|
@ -25,7 +26,6 @@ from transformers.utils import (
|
|||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.utils.import_utils import _is_package_available
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -41,9 +41,6 @@ if is_torch_available():
|
|||
else:
|
||||
Tensor = object
|
||||
|
||||
if _is_package_available("soundfile"):
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
class AgentType:
|
||||
"""
|
||||
|
@ -187,11 +184,12 @@ class AgentAudio(AgentType, str):
|
|||
"""
|
||||
|
||||
def __init__(self, value, samplerate=16_000):
|
||||
if importlib.util.find_spec("soundfile") is None:
|
||||
raise ModuleNotFoundError(
|
||||
"Please install 'audio' extra to use AgentAudio: `pip install 'smolagents[audio]'`"
|
||||
)
|
||||
super().__init__(value)
|
||||
|
||||
if not _is_package_available("soundfile"):
|
||||
raise ImportError("soundfile must be installed in order to handle audio.")
|
||||
|
||||
self._path = None
|
||||
self._tensor = None
|
||||
|
||||
|
@ -221,6 +219,8 @@ class AgentAudio(AgentType, str):
|
|||
"""
|
||||
Returns the "raw" version of that object. It is a `torch.Tensor` object.
|
||||
"""
|
||||
import soundfile as sf
|
||||
|
||||
if self._tensor is not None:
|
||||
return self._tensor
|
||||
|
||||
|
@ -239,6 +239,8 @@ class AgentAudio(AgentType, str):
|
|||
Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
|
||||
version of the audio.
|
||||
"""
|
||||
import soundfile as sf
|
||||
|
||||
if self._path is not None:
|
||||
return self._path
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import ast
|
||||
import importlib.util
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
|
@ -22,13 +23,10 @@ import types
|
|||
from typing import Dict, Tuple, Union
|
||||
|
||||
from rich.console import Console
|
||||
from transformers.utils.import_utils import _is_package_available
|
||||
|
||||
_pygments_available = _is_package_available("pygments")
|
||||
|
||||
|
||||
def is_pygments_available():
|
||||
return _pygments_available
|
||||
return importlib.util.find_spec("soundfile") is not None
|
||||
|
||||
|
||||
console = Console(width=200)
|
||||
|
|
|
@ -24,15 +24,9 @@ from transformers.testing_utils import (
|
|||
require_torch,
|
||||
require_vision,
|
||||
)
|
||||
from transformers.utils.import_utils import (
|
||||
_is_package_available,
|
||||
)
|
||||
|
||||
from smolagents.types import AgentAudio, AgentImage, AgentText
|
||||
|
||||
if _is_package_available("soundfile"):
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def get_new_path(suffix="") -> str:
|
||||
directory = tempfile.mkdtemp()
|
||||
|
@ -43,6 +37,7 @@ def get_new_path(suffix="") -> str:
|
|||
@require_torch
|
||||
class AgentAudioTests(unittest.TestCase):
|
||||
def test_from_tensor(self):
|
||||
import soundfile as sf
|
||||
import torch
|
||||
|
||||
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
||||
|
@ -62,6 +57,7 @@ class AgentAudioTests(unittest.TestCase):
|
|||
self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))
|
||||
|
||||
def test_from_string(self):
|
||||
import soundfile as sf
|
||||
import torch
|
||||
|
||||
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
||||
|
|
Loading…
Reference in New Issue