import torch
import torchaudio
import gradio as gr
from transformers import Wav2Vec2BertProcessor, Wav2Vec2BertForCTC
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load processor & model
model_name = "cdactvm/w2v-bert-punjabi"  # Change if using a different Punjabi ASR model
processor = Wav2Vec2BertProcessor.from_pretrained(model_name)

# Load the original (full-precision) model
original_model = Wav2Vec2BertForCTC.from_pretrained(model_name).to(device)
original_model.eval()

# Explicitly allow Wav2Vec2BertForCTC during unpickling (needed for torch.load
# with the weights_only safety checks in recent PyTorch releases)
torch.serialization.add_safe_globals([Wav2Vec2BertForCTC])

# Load the full quantized model (saved as a pickled module, not a state dict)
quantized_model = torch.load(
    "cdactvm/w2v-bert-punjabi/wav2vec2_bert_qint8.pth",
    map_location="cpu",
    weights_only=False,
)
quantized_model.eval()
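# For reference, a checkpoint like wav2vec2_bert_qint8.pth is typically produced
# with post-training dynamic quantization. The sketch below is illustrative only,
# not the exact script used for this Space; the output filename is an assumption.
#
#   fp32_model = Wav2Vec2BertForCTC.from_pretrained("cdactvm/w2v-bert-punjabi")
#   int8_model = torch.ao.quantization.quantize_dynamic(
#       fp32_model, {torch.nn.Linear}, dtype=torch.qint8
#   )
#   torch.save(int8_model, "wav2vec2_bert_qint8.pth")  # saves the full module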
#####################################################
# Recognize speech using the original model
def transcribe_original_model(audio_path):
    # Load audio file
    waveform, sample_rate = torchaudio.load(audio_path)

    # Convert stereo to mono (if needed)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample to 16 kHz
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    # Process audio and move inputs to the model's device
    inputs = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt")
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Get logits & transcribe
    with torch.no_grad():
        logits = original_model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription
# Recognize speech using the quantized model
def transcribe_quantized_model(audio_path):
    # Load audio file
    waveform, sample_rate = torchaudio.load(audio_path)

    # Convert stereo to mono (if needed)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample to 16 kHz
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    # Process audio; dynamically quantized (qint8) modules run on the CPU,
    # so the float32 inputs are left on the CPU as well
    inputs = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt")

    # Get logits & transcribe
    with torch.no_grad():
        logits = quantized_model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription
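# Example of calling the transcription functions directly (outside Gradio),
# assuming a local Punjabi audio file; the filename below is hypothetical:
#
#   print(transcribe_original_model("sample_pa.wav"))
#   print(transcribe_quantized_model("sample_pa.wav"))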
def select_lng(lng, mic=None, file=None):
    # The single gr.Audio component is passed positionally as `mic`;
    # `file` is only used if a separate upload input is wired in.
    if mic is not None:
        audio = mic
    elif file is not None:
        audio = file
    else:
        return "You must either provide a mic recording or a file"

    if lng == "original_model":
        return transcribe_original_model(audio)
    elif lng == "quantized_model":
        return transcribe_quantized_model(audio)
# Gradio Interface
demo = gr.Interface(
    fn=select_lng,
    inputs=[
        gr.Dropdown(["original_model", "quantized_model"], label="Select Model"),
        gr.Audio(sources=["microphone", "upload"], type="filepath"),
    ],
    outputs=["textbox"],
    title="Automatic Speech Recognition",
    description="Upload an audio file and get the transcription in Punjabi.",
)

if __name__ == "__main__":
    demo.launch()