# Final_Assignment_Template/tools/speech_recognition_tool.py
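"""smolagents Tool that transcribes speech with openai/whisper-large-v3-turbo via the
transformers automatic-speech-recognition pipeline, optionally emitting timestamps."""
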
from smolagents import Tool
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, logging
import warnings


class SpeechRecognitionTool(Tool):
name = "speech_to_text"
description = """Transcribes speech from audio."""
inputs = {
"audio": {
"type": "string",
"description": "Path to the audio file to transcribe.",
},
"with_time_markers": {
"type": "boolean",
"description": "Whether to include timestamps in the transcription output. Each timestamp appears on its own line in the format [float, float], indicating the number of seconds elapsed from the start of the audio.",
"nullable": True,
"default": False,
},
}
output_type = "string"
    chunk_length_s = 30  # seconds per pipeline window; also used to rebuild absolute timestamps

    def __new__(cls, *args, **kwargs):
        # Load the Whisper model and attach the ASR pipeline to the class before the
        # Tool instance is initialised; note this runs on every instantiation.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "openai/whisper-large-v3-turbo"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
        # Silence transformers progress logging and the FutureWarning about the
        # deprecated `inputs` argument emitted by the ASR pipeline.
        logging.set_verbosity_error()
        warnings.filterwarnings(
            "ignore",
            category=FutureWarning,
            message=r".*The input name `inputs` is deprecated.*",
        )
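        # With return_timestamps=True, each pipeline call returns the full "text" plus a
        # "chunks" list of {"timestamp": (start, end), "text": ...} dicts, which
        # _normalize_chunks() converts to absolute offsets.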
cls.pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
torch_dtype=torch_dtype,
device=device,
chunk_length_s=cls.chunk_length_s,
return_timestamps=True,
)
return super().__new__(cls, *args, **kwargs)

    def forward(self, audio: str, with_time_markers: bool = False) -> str:
"""
Transcribes speech from audio.
Args:
audio (str): Path to the audio file to transcribe.
with_time_markers (bool): Whether to include timestamps in the transcription output. Each timestamp appears on its own line in the format [float], indicating the number of seconds elapsed from the start of the audio.
        Returns:
            str: The transcribed text; when with_time_markers is True, each chunk is
                wrapped in [start] and [end] timestamp lines.
"""
result = self.pipe(audio)
if not with_time_markers:
return result["text"].strip()
txt = ""
for chunk in self._normalize_chunks(result["chunks"]):
txt += f"[{chunk['start']:.2f}]\n{chunk['text']}\n[{chunk['end']:.2f}]\n"
return txt.strip()

    def transcribe(self, audio, **kwargs):
        """Transcribe audio and return chunks with absolute start/end times in seconds."""
        result = self.pipe(audio, **kwargs)
        return self._normalize_chunks(result["chunks"])

    def _normalize_chunks(self, chunks):
        """Convert the pipeline's per-window timestamps into offsets from the start of the audio.

        The chunked pipeline restarts timestamps inside every chunk_length_s window, so
        whenever a timestamp moves backwards the absolute offset advances by one window.
        """
        chunk_length_s = self.chunk_length_s
        absolute_offset = 0.0
        chunk_offset = 0.0
        normalized = []
        for chunk in chunks:
            timestamp_start, timestamp_end = chunk["timestamp"]
            if timestamp_end is None:
                # The pipeline may leave the last segment's end open; fall back to its start.
                timestamp_end = timestamp_start
            if timestamp_start < chunk_offset:
                # Timestamps restarted: we crossed into the next window.
                absolute_offset += chunk_length_s
            chunk_offset = timestamp_start
            absolute_start = absolute_offset + timestamp_start
            if timestamp_end < timestamp_start:
                # The segment spans a window boundary; its end time belongs to the next window.
                absolute_offset += chunk_length_s
            absolute_end = absolute_offset + timestamp_end
            chunk_offset = timestamp_end
chunk_text = chunk["text"].strip()
if chunk_text:
normalized.append(
{
"start": absolute_start,
"end": absolute_end,
"text": chunk_text,
}
)
return normalized
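

# Minimal usage sketch (illustrative, not part of the original module). Assumes a local
# audio file at the hypothetical path "audio/sample.mp3" and that the tool is invoked
# directly via forward(); inside an agent, smolagents calls the tool for you.
if __name__ == "__main__":
    tool = SpeechRecognitionTool()

    # Plain transcript.
    print(tool.forward("audio/sample.mp3"))

    # Transcript with [start]/[end] second markers around each chunk.
    print(tool.forward("audio/sample.mp3", with_time_markers=True))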