import torch
import torchaudio
import gradio as gr
from transformers import Wav2Vec2BertProcessor, Wav2Vec2BertForCTC
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load processor & model
model_name = "cdactvm/w2v-bert-punjabi"  # Punjabi Wav2Vec2-BERT ASR model
processor = Wav2Vec2BertProcessor.from_pretrained(model_name)

# Load the original (full-precision) model
original_model = Wav2Vec2BertForCTC.from_pretrained(model_name).to(device)
original_model.eval()

# Explicitly allow Wav2Vec2BertForCTC during unpickling (PyTorch 2.6+ safe loading)
torch.serialization.add_safe_globals([Wav2Vec2BertForCTC])

# Load the full quantized model (saved as a pickled module, not a state dict)
quantized_model = torch.load("cdactvm/w2v-bert-punjabi/wav2vec2_bert_qint8.pth", weights_only=False)
quantized_model.eval()
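# The checkpoint above is assumed to be an int8-quantized copy of the same model.
# A minimal sketch of how such a file could be produced, assuming standard
# PyTorch dynamic quantization of the Linear layers (filename and approach are
# assumptions, not confirmed by this Space):
#
#   fp32_model = Wav2Vec2BertForCTC.from_pretrained("cdactvm/w2v-bert-punjabi")
#   int8_model = torch.ao.quantization.quantize_dynamic(
#       fp32_model, {torch.nn.Linear}, dtype=torch.qint8
#   )
#   torch.save(int8_model, "wav2vec2_bert_qint8.pth")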
#####################################################
# Recognize speech using the original model
def transcribe_original_model(audio_path):
    # Load audio file
    waveform, sample_rate = torchaudio.load(audio_path)
    # Convert stereo to mono (if needed)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # Resample to 16 kHz
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Process audio
    inputs = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt")
    # Keep inputs in the model's dtype (float32); casting them to bfloat16
    # while the weights stay float32 causes a dtype-mismatch runtime error
    inputs = {key: val.to(device) for key, val in inputs.items()}
    # Get logits & transcribe
    with torch.no_grad():
        logits = original_model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription
# Recognize speech using the quantized model
def transcribe_quantized_model(audio_path):
    # Load audio file
    waveform, sample_rate = torchaudio.load(audio_path)
    # Convert stereo to mono (if needed)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # Resample to 16 kHz
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Process audio; the int8 model is assumed to be dynamically quantized,
    # which runs on CPU only, so inputs stay on CPU in float32
    inputs = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt")
    # Get logits & transcribe
    with torch.no_grad():
        logits = quantized_model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription
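# Quick local sanity check (the file "sample.wav" is illustrative; substitute
# any Punjabi speech recording available on disk):
#
#   print(transcribe_original_model("sample.wav"))
#   print(transcribe_quantized_model("sample.wav"))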
def select_lng(lng, mic=None, file=None):
    if mic is not None:
        audio = mic
    elif file is not None:
        audio = file
    else:
        return "You must either provide a mic recording or a file"
    if lng == "original_model":
        return transcribe_original_model(audio)
    elif lng == "quantized_model":
        return transcribe_quantized_model(audio)
    return "Please select a model"
# Gradio Interface
demo = gr.Interface(
    fn=select_lng,
    inputs=[
        gr.Dropdown(["original_model", "quantized_model"], label="Select Model"),
        gr.Audio(sources=["microphone", "upload"], type="filepath"),
    ],
    outputs=["textbox"],
    title="Automatic Speech Recognition",
    description="Upload an audio file and get the transcription in Punjabi.",
)
if __name__ == "__main__":
    demo.launch()