Remove dependency on _is_package_available from transformers (#247)
This commit is contained in:
parent
6db75183ff
commit
58b18f5655
|
@ -32,7 +32,6 @@ from transformers import (
|
||||||
StoppingCriteriaList,
|
StoppingCriteriaList,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
)
|
)
|
||||||
from transformers.utils.import_utils import _is_package_available
|
|
||||||
|
|
||||||
from .tools import Tool
|
from .tools import Tool
|
||||||
|
|
||||||
|
@ -48,9 +47,6 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
|
||||||
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>",
|
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>",
|
||||||
}
|
}
|
||||||
|
|
||||||
if _is_package_available("litellm"):
|
|
||||||
import litellm
|
|
||||||
|
|
||||||
|
|
||||||
def get_dict_from_nested_dataclasses(obj):
|
def get_dict_from_nested_dataclasses(obj):
|
||||||
def convert(obj):
|
def convert(obj):
|
||||||
|
@ -508,9 +504,11 @@ class LiteLLMModel(Model):
|
||||||
api_key=None,
|
api_key=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if not _is_package_available("litellm"):
|
try:
|
||||||
raise ImportError(
|
import litellm
|
||||||
"litellm not found. Install it with `pip install litellm`"
|
except ModuleNotFoundError:
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
"Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
|
||||||
)
|
)
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
|
@ -530,6 +528,8 @@ class LiteLLMModel(Model):
|
||||||
messages = get_clean_message_list(
|
messages = get_clean_message_list(
|
||||||
messages, role_conversions=tool_role_conversions
|
messages, role_conversions=tool_role_conversions
|
||||||
)
|
)
|
||||||
|
import litellm
|
||||||
|
|
||||||
if tools_to_call_from:
|
if tools_to_call_from:
|
||||||
response = litellm.completion(
|
response = litellm.completion(
|
||||||
model=self.model_id,
|
model=self.model_id,
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import importlib.util
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
|
@ -25,7 +26,6 @@ from transformers.utils import (
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
from transformers.utils.import_utils import _is_package_available
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -41,9 +41,6 @@ if is_torch_available():
|
||||||
else:
|
else:
|
||||||
Tensor = object
|
Tensor = object
|
||||||
|
|
||||||
if _is_package_available("soundfile"):
|
|
||||||
import soundfile as sf
|
|
||||||
|
|
||||||
|
|
||||||
class AgentType:
|
class AgentType:
|
||||||
"""
|
"""
|
||||||
|
@ -187,11 +184,12 @@ class AgentAudio(AgentType, str):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, value, samplerate=16_000):
|
def __init__(self, value, samplerate=16_000):
|
||||||
|
if importlib.util.find_spec("soundfile") is None:
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
"Please install 'audio' extra to use AgentAudio: `pip install 'smolagents[audio]'`"
|
||||||
|
)
|
||||||
super().__init__(value)
|
super().__init__(value)
|
||||||
|
|
||||||
if not _is_package_available("soundfile"):
|
|
||||||
raise ImportError("soundfile must be installed in order to handle audio.")
|
|
||||||
|
|
||||||
self._path = None
|
self._path = None
|
||||||
self._tensor = None
|
self._tensor = None
|
||||||
|
|
||||||
|
@ -221,6 +219,8 @@ class AgentAudio(AgentType, str):
|
||||||
"""
|
"""
|
||||||
Returns the "raw" version of that object. It is a `torch.Tensor` object.
|
Returns the "raw" version of that object. It is a `torch.Tensor` object.
|
||||||
"""
|
"""
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
if self._tensor is not None:
|
if self._tensor is not None:
|
||||||
return self._tensor
|
return self._tensor
|
||||||
|
|
||||||
|
@ -239,6 +239,8 @@ class AgentAudio(AgentType, str):
|
||||||
Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
|
Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
|
||||||
version of the audio.
|
version of the audio.
|
||||||
"""
|
"""
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
if self._path is not None:
|
if self._path is not None:
|
||||||
return self._path
|
return self._path
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import ast
|
import ast
|
||||||
|
import importlib.util
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
@ -22,13 +23,10 @@ import types
|
||||||
from typing import Dict, Tuple, Union
|
from typing import Dict, Tuple, Union
|
||||||
|
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from transformers.utils.import_utils import _is_package_available
|
|
||||||
|
|
||||||
_pygments_available = _is_package_available("pygments")
|
|
||||||
|
|
||||||
|
|
||||||
def is_pygments_available():
|
def is_pygments_available():
|
||||||
return _pygments_available
|
return importlib.util.find_spec("soundfile") is not None
|
||||||
|
|
||||||
|
|
||||||
console = Console(width=200)
|
console = Console(width=200)
|
||||||
|
|
|
@ -24,15 +24,9 @@ from transformers.testing_utils import (
|
||||||
require_torch,
|
require_torch,
|
||||||
require_vision,
|
require_vision,
|
||||||
)
|
)
|
||||||
from transformers.utils.import_utils import (
|
|
||||||
_is_package_available,
|
|
||||||
)
|
|
||||||
|
|
||||||
from smolagents.types import AgentAudio, AgentImage, AgentText
|
from smolagents.types import AgentAudio, AgentImage, AgentText
|
||||||
|
|
||||||
if _is_package_available("soundfile"):
|
|
||||||
import soundfile as sf
|
|
||||||
|
|
||||||
|
|
||||||
def get_new_path(suffix="") -> str:
|
def get_new_path(suffix="") -> str:
|
||||||
directory = tempfile.mkdtemp()
|
directory = tempfile.mkdtemp()
|
||||||
|
@ -43,6 +37,7 @@ def get_new_path(suffix="") -> str:
|
||||||
@require_torch
|
@require_torch
|
||||||
class AgentAudioTests(unittest.TestCase):
|
class AgentAudioTests(unittest.TestCase):
|
||||||
def test_from_tensor(self):
|
def test_from_tensor(self):
|
||||||
|
import soundfile as sf
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
||||||
|
@ -62,6 +57,7 @@ class AgentAudioTests(unittest.TestCase):
|
||||||
self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))
|
self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))
|
||||||
|
|
||||||
def test_from_string(self):
|
def test_from_string(self):
|
||||||
|
import soundfile as sf
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
||||||
|
|
Loading…
Reference in New Issue