import warnings

import torch
from smolagents import Tool
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, logging, pipeline


class SpeechRecognitionTool(Tool):
    name = "speech_to_text"
    description = """Transcribes speech from an audio file."""

    inputs = {
        "audio": {
            "type": "string",
            "description": "Path to the audio file to transcribe.",
        },
        "with_time_markers": {
            "type": "boolean",
            "description": (
                "Whether to include timestamps in the transcription output. Each timestamp"
                " appears on its own line in the format [float], indicating the number of"
                " seconds elapsed from the start of the audio."
            ),
            "nullable": True,
            "default": False,
        },
    }
    output_type = "string"

    # Length in seconds of the windows the ASR pipeline splits long audio into.
    chunk_length_s = 30

    def __new__(cls, *args, **kwargs):
        # Load Whisper and build the ASR pipeline at instantiation time; the pipeline
        # is stored on the class so that forward() and transcribe() can reuse it.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        model_id = "openai/whisper-large-v3-turbo"
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )
        model.to(device)
        processor = AutoProcessor.from_pretrained(model_id)

        # Silence the deprecation warning the ASR pipeline emits for long-form audio.
        logging.set_verbosity_error()
        warnings.filterwarnings(
            "ignore",
            category=FutureWarning,
            message=r".*The input name `inputs` is deprecated.*",
        )
        cls.pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
            chunk_length_s=cls.chunk_length_s,
            return_timestamps=True,
        )

        # object.__new__() rejects extra arguments when __new__ is overridden, so only
        # the class is forwarded here; *args and **kwargs still reach __init__ as usual.
        return super().__new__(cls)

    def forward(self, audio: str, with_time_markers: bool = False) -> str:
        """
        Transcribes speech from audio.

        Args:
            audio (str): Path to the audio file to transcribe.
            with_time_markers (bool): Whether to include timestamps in the transcription
                output. Each timestamp appears on its own line in the format [float],
                indicating the number of seconds elapsed from the start of the audio.

        Returns:
            str: The transcribed text, with bracketed start/end markers around each
            segment when `with_time_markers` is True.
        """
        result = self.pipe(audio)
        if not with_time_markers:
            return result["text"].strip()

        txt = ""
        for chunk in self._normalize_chunks(result["chunks"]):
            txt += f"[{chunk['start']:.2f}]\n{chunk['text']}\n[{chunk['end']:.2f}]\n"
        return txt.strip()

    def transcribe(self, audio, **kwargs):
        """Transcribes audio and returns normalized chunks as {"start", "end", "text"} dicts."""
        result = self.pipe(audio, **kwargs)
        return self._normalize_chunks(result["chunks"])

    def _normalize_chunks(self, chunks):
        # The timestamps returned by the pipeline restart at every processing window
        # (chunk_length_s seconds long), so they are converted here into absolute
        # offsets from the start of the audio by detecting wrap-arounds.
        chunk_length_s = self.chunk_length_s
        absolute_offset = 0.0
        chunk_offset = 0.0
        normalized = []

        for chunk in chunks:
            timestamp_start = chunk["timestamp"][0]
            timestamp_end = chunk["timestamp"][1]
            # A start timestamp smaller than the previous one means a new window began.
            if timestamp_start < chunk_offset:
                absolute_offset += chunk_length_s
            chunk_offset = timestamp_start
            absolute_start = absolute_offset + timestamp_start

            # The window can also roll over in the middle of a chunk.
            if timestamp_end < timestamp_start:
                absolute_offset += chunk_length_s
            absolute_end = absolute_offset + timestamp_end
            chunk_offset = timestamp_end

            # Drop chunks whose text is empty after stripping whitespace.
            chunk_text = chunk["text"].strip()
            if chunk_text:
                normalized.append(
                    {
                        "start": absolute_start,
                        "end": absolute_end,
                        "text": chunk_text,
                    }
                )

        return normalized
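

# Example usage (a minimal sketch): "sample.wav" below is a placeholder path, and the
# tool is invoked via forward() directly; in an agent setup the instance would instead
# be passed to the agent's `tools` list.
if __name__ == "__main__":
    tool = SpeechRecognitionTool()

    # Plain transcription.
    print(tool.forward("sample.wav"))

    # Transcription with [start] / [end] second markers around each segment.
    print(tool.forward("sample.wav", with_time_markers=True))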
|