import torchaudio
import torch
from transformers import (
    WhisperProcessor,
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    AutoModelForCTC,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC
)
import numpy as np
# Load processors and models for all supported checkpoints at import time
models_info = {
"openai/whisper-small-uzbek": {
"processor": WhisperProcessor.from_pretrained("openai/whisper-small", language="uzbek", task="transcribe"),
"model": AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small"),
"ctc_model": False
},
"ixxan/whisper-small-thugy20": {
"processor": AutoProcessor.from_pretrained("ixxan/whisper-small-thugy20"),
"model": AutoModelForSpeechSeq2Seq.from_pretrained("ixxan/whisper-small-thugy20"),
"ctc_model": False
},
"ixxan/whisper-small-uyghur-common-voice": {
"processor": AutoProcessor.from_pretrained("ixxan/whisper-small-uyghur-common-voice"),
"model": AutoModelForSpeechSeq2Seq.from_pretrained("ixxan/whisper-small-uyghur-common-voice"),
"ctc_model": False
},
"facebook/mms-1b-all": {
"processor": AutoProcessor.from_pretrained("facebook/mms-1b-all", target_lang='uig-script_arabic'),
"model": AutoModelForCTC.from_pretrained("facebook/mms-1b-all", target_lang='uig-script_arabic', ignore_mismatched_sizes=True),
"ctc_model": True
},
# "ixxan/wav2vec2-large-mms-1b-uyghur-latin": {
# "processor": Wav2Vec2Processor.from_pretrained("ixxan/wav2vec2-large-mms-1b-uyghur-latin", target_lang='uig-script_latin'),
# "model": Wav2Vec2ForCTC.from_pretrained("ixxan/wav2vec2-large-mms-1b-uyghur-latin"),
# "ctc_model": True
# },
}
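
# Sketch (assumption): the keys of models_info double as the model choices exposed by
# the app's UI; a hypothetical helper like this could be used to populate a dropdown.
def available_models() -> list:
    """Return the model ids that transcribe() accepts."""
    return list(models_info.keys())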
def transcribe(audio_data, model_id) -> str:
"""
Transcribes audio to text using the Whisper model for Uyghur.
Args:
- audio_data: Gradio audio input
Returns:
- str: The transcription of the audio.
"""
    # Load audio from either input type
    if not audio_data:
        return "<<ERROR: Empty Audio Input>>"
    if isinstance(audio_data, tuple):
        # Microphone input: (sampling_rate, int16 numpy array) -> normalized float32 tensor
        sampling_rate, audio_input = audio_data
        audio_input = torch.from_numpy((audio_input / 32768.0).astype(np.float32))
    elif isinstance(audio_data, str):
        # File upload: torchaudio returns a float tensor and its sampling rate
        audio_input, sampling_rate = torchaudio.load(audio_data)
    else:
        return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data))
    model = models_info[model_id]["model"]
    processor = models_info[model_id]["processor"]
    target_sr = processor.feature_extractor.sampling_rate
    ctc_model = models_info[model_id]["ctc_model"]
    # Resample to the sampling rate the processor expects, if needed
    if sampling_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sampling_rate, target_sr)
        audio_input = resampler(audio_input)

    # Preprocess the audio input
    inputs = processor(audio_input.squeeze(), sampling_rate=target_sr, return_tensors="pt", padding=True)
    # Move model and inputs to GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = {key: val.to(device) for key, val in inputs.items()}
    # Generate transcription
    with torch.no_grad():
        if ctc_model:
            # CTC models (e.g. MMS): argmax over frame-level logits, then decode
            logits = model(**inputs).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = processor.batch_decode(predicted_ids)[0]
        else:
            # Seq2seq models (Whisper): autoregressive decoding of the input features
            generated_ids = model.generate(inputs["input_features"], max_length=225)
            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return transcription
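
# Example usage (sketch, not part of the app): exercise transcribe() directly from the
# command line. "example.wav" is a hypothetical local audio file; any model id from
# models_info can be used instead of looping over all of them.
if __name__ == "__main__":
    sample_path = "example.wav"  # hypothetical test clip
    for model_id in models_info:
        print(model_id, "->", transcribe(sample_path, model_id))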