diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py
index d3cf1e6..d290e6f 100644
--- a/src/smolagents/default_tools.py
+++ b/src/smolagents/default_tools.py
@@ -257,17 +257,24 @@ class SpeechToTextTool(PipelineTool):
     }
     output_type = "string"
 
-    def __new__(cls):
+    def __new__(cls, *args, **kwargs):
+        """Bind the Whisper processor and model classes on ``cls``, then allocate the instance.
+
+        ``*args`` and ``**kwargs`` are accepted because Python passes the constructor
+        arguments to ``__new__`` as well as ``__init__``; they are not used for allocation.
+        """
+        # Function-scope import: transformers is imported when an instance is created.
         from transformers.models.whisper import (
             WhisperForConditionalGeneration,
             WhisperProcessor,
         )
 
-        if not hasattr(cls, "pre_processor_class"):
-            cls.pre_processor_class = WhisperProcessor
-        if not hasattr(cls, "model_class"):
-            cls.model_class = WhisperForConditionalGeneration
-        return super().__new__()
+        cls.pre_processor_class = WhisperProcessor
+        cls.model_class = WhisperForConditionalGeneration
+        # object.__new__ raises TypeError on extra arguments when __new__ is overridden,
+        # so only cls is forwarded. NOTE(review): assumes no base class defines an
+        # argument-taking __new__ -- confirm against PipelineTool / Tool.
+        return super().__new__(cls)
 
     def encode(self, audio):
         audio = AgentAudio(audio).to_raw()
diff --git a/tests/test_default_tools.py b/tests/test_default_tools.py
index e3cb4a6..5ff436e 100644
--- a/tests/test_default_tools.py
+++ b/tests/test_default_tools.py
@@ -17,7 +17,7 @@ import unittest
 import pytest
 
 from smolagents.agent_types import _AGENT_TYPE_MAPPING
-from smolagents.default_tools import DuckDuckGoSearchTool, PythonInterpreterTool, VisitWebpageTool
+from smolagents.default_tools import DuckDuckGoSearchTool, PythonInterpreterTool, SpeechToTextTool, VisitWebpageTool
 
 from .test_tools import ToolTesterMixin
 
@@ -77,3 +77,18 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
         with pytest.raises(Exception) as e:
             self.tool("import sympy as sp")
         assert "sympy" in str(e).lower()
+
+
+class TestSpeechToTextTool:
+    def test_new_instance(self):
+        from transformers.models.whisper import WhisperForConditionalGeneration, WhisperProcessor
+
+        tool = SpeechToTextTool()
+        assert tool is not None
+        assert tool.pre_processor_class == WhisperProcessor
+        assert tool.model_class == WhisperForConditionalGeneration
+
+    def test_new_accepts_constructor_arguments(self):
+        # Regression: arguments meant for __init__ must not be forwarded to object.__new__.
+        tool = SpeechToTextTool.__new__(SpeechToTextTool, "positional", keyword="value")
+        assert isinstance(tool, SpeechToTextTool)