Fix SpeechToTextTool new instance (#478)

* Test SpeechToTextTool new instance

* Fix SpeechToTextTool new method
This commit is contained in:
Albert Villanova del Moral 2025-02-03 11:42:03 +01:00 committed by GitHub
parent 44f94eaa2d
commit 183869de04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 7 deletions

View File

@ -257,17 +257,15 @@ class SpeechToTextTool(PipelineTool):
}
output_type = "string"
def __new__(cls, *args, **kwargs):
    """Bind the Whisper classes onto ``cls`` and allocate a new instance.

    The ``transformers`` import is performed inside this method, so it only
    runs when an instance is actually created.

    Args:
        *args: Positional constructor arguments. Accepted so that
            instantiating with arguments does not fail at ``__new__``; they
            are not used here.
        **kwargs: Keyword constructor arguments. Accepted for the same
            reason; they are not used here.

    Returns:
        A new, not yet initialised instance of ``cls``.
    """
    from transformers.models.whisper import (
        WhisperForConditionalGeneration,
        WhisperProcessor,
    )

    # Assigned unconditionally. NOTE(review): presumably a base class already
    # declares these attributes, in which case a ``hasattr`` guard would
    # always skip the assignment -- confirm against the base class.
    cls.pre_processor_class = WhisperProcessor
    cls.model_class = WhisperForConditionalGeneration

    # Only ``cls`` is passed on: Python hands the constructor arguments to
    # ``__init__`` by itself, and ``object.__new__`` raises ``TypeError`` when
    # it receives extra arguments from a class that overrides ``__new__``.
    # NOTE(review): assumes no base class defines a ``__new__`` that consumes
    # these arguments -- confirm.
    return super().__new__(cls)
def encode(self, audio):
audio = AgentAudio(audio).to_raw()

View File

@ -17,7 +17,7 @@ import unittest
import pytest
from smolagents.agent_types import _AGENT_TYPE_MAPPING
from smolagents.default_tools import DuckDuckGoSearchTool, PythonInterpreterTool, VisitWebpageTool
from smolagents.default_tools import DuckDuckGoSearchTool, PythonInterpreterTool, SpeechToTextTool, VisitWebpageTool
from .test_tools import ToolTesterMixin
@ -77,3 +77,13 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
with pytest.raises(Exception) as e:
self.tool("import sympy as sp")
assert "sympy" in str(e).lower()
class TestSpeechToTextTool:
    """Checks for creating a ``SpeechToTextTool`` instance."""

    def test_new_instance(self):
        """A tool built with no arguments exposes the Whisper processor and model classes."""
        from transformers.models.whisper import (
            WhisperForConditionalGeneration,
            WhisperProcessor,
        )

        instance = SpeechToTextTool()

        assert instance is not None
        assert instance.pre_processor_class == WhisperProcessor
        assert instance.model_class == WhisperForConditionalGeneration