Fix SpeechToTextTool new instance (#478)
* Test SpeechToTextTool new instance * Fix SpeechToTextTool new method
This commit is contained in:
parent
44f94eaa2d
commit
183869de04
|
@ -257,17 +257,15 @@ class SpeechToTextTool(PipelineTool):
|
||||||
}
|
}
|
||||||
output_type = "string"
|
output_type = "string"
|
||||||
|
|
||||||
def __new__(cls):
|
def __new__(cls, *args, **kwargs):
|
||||||
from transformers.models.whisper import (
|
from transformers.models.whisper import (
|
||||||
WhisperForConditionalGeneration,
|
WhisperForConditionalGeneration,
|
||||||
WhisperProcessor,
|
WhisperProcessor,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not hasattr(cls, "pre_processor_class"):
|
|
||||||
cls.pre_processor_class = WhisperProcessor
|
cls.pre_processor_class = WhisperProcessor
|
||||||
if not hasattr(cls, "model_class"):
|
|
||||||
cls.model_class = WhisperForConditionalGeneration
|
cls.model_class = WhisperForConditionalGeneration
|
||||||
return super().__new__()
|
return super().__new__(cls, *args, **kwargs)
|
||||||
|
|
||||||
def encode(self, audio):
|
def encode(self, audio):
|
||||||
audio = AgentAudio(audio).to_raw()
|
audio = AgentAudio(audio).to_raw()
|
||||||
|
|
|
@ -17,7 +17,7 @@ import unittest
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from smolagents.agent_types import _AGENT_TYPE_MAPPING
|
from smolagents.agent_types import _AGENT_TYPE_MAPPING
|
||||||
from smolagents.default_tools import DuckDuckGoSearchTool, PythonInterpreterTool, VisitWebpageTool
|
from smolagents.default_tools import DuckDuckGoSearchTool, PythonInterpreterTool, SpeechToTextTool, VisitWebpageTool
|
||||||
|
|
||||||
from .test_tools import ToolTesterMixin
|
from .test_tools import ToolTesterMixin
|
||||||
|
|
||||||
|
@ -77,3 +77,13 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
with pytest.raises(Exception) as e:
|
with pytest.raises(Exception) as e:
|
||||||
self.tool("import sympy as sp")
|
self.tool("import sympy as sp")
|
||||||
assert "sympy" in str(e).lower()
|
assert "sympy" in str(e).lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpeechToTextTool:
|
||||||
|
def test_new_instance(self):
|
||||||
|
from transformers.models.whisper import WhisperForConditionalGeneration, WhisperProcessor
|
||||||
|
|
||||||
|
tool = SpeechToTextTool()
|
||||||
|
assert tool is not None
|
||||||
|
assert tool.pre_processor_class == WhisperProcessor
|
||||||
|
assert tool.model_class == WhisperForConditionalGeneration
|
||||||
|
|
Loading…
Reference in New Issue