import torch
import torchaudio
import gradio as gr
from transformers import Wav2Vec2BertProcessor, Wav2Vec2BertForCTC

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load processor & model
model_name = "cdactvm/w2v-bert-punjabi"  # Punjabi Wav2Vec2-BERT ASR checkpoint
processor = Wav2Vec2BertProcessor.from_pretrained(model_name)

# Load the original (full-precision) model
original_model = Wav2Vec2BertForCTC.from_pretrained(model_name).to(device)
original_model.eval()

# Explicitly allow Wav2Vec2BertForCTC during unpickling
torch.serialization.add_safe_globals([Wav2Vec2BertForCTC])

# Load the full quantized model (saved as a complete pickled module, hence weights_only=False).
# A dynamically quantized (qint8) model runs on CPU, so it is kept there.
quantized_model = torch.load("cdactvm/w2v-bert-punjabi/wav2vec2_bert_qint8.pth", weights_only=False)
quantized_model.eval()


#####################################################
# Recognize speech using the original model
def transcribe_original_model(audio_path):
    # Load audio file
    waveform, sample_rate = torchaudio.load(audio_path)
    # Convert stereo to mono (if needed)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # Resample to 16 kHz
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Extract input features and move them to the model's device
    inputs = processor(waveform.squeeze(0).numpy(), sampling_rate=16000, return_tensors="pt")
    inputs = {key: val.to(device) for key, val in inputs.items()}
    # Get logits & transcribe
    with torch.no_grad():
        logits = original_model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription


# Recognize speech using the quantized model
def transcribe_quantized_model(audio_path):
    # Load audio file
    waveform, sample_rate = torchaudio.load(audio_path)
    # Convert stereo to mono (if needed)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # Resample to 16 kHz
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Extract input features (kept on CPU: dynamically quantized qint8 modules run on CPU)
    inputs = processor(waveform.squeeze(0).numpy(), sampling_rate=16000, return_tensors="pt")
    # Get logits & transcribe
    with torch.no_grad():
        logits = quantized_model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription


def select_lng(model_choice, audio):
    if audio is None:
        return "You must either provide a mic recording or a file"
    if model_choice == "original_model":
        return transcribe_original_model(audio)
    elif model_choice == "quantized_model":
        return transcribe_quantized_model(audio)
    return "Unknown model selection"


# Gradio Interface
demo = gr.Interface(
    fn=select_lng,
    inputs=[
        gr.Dropdown(["original_model", "quantized_model"], label="Select Model"),
        gr.Audio(sources=["microphone", "upload"], type="filepath"),
    ],
    outputs=["textbox"],
    title="Automatic Speech Recognition",
    description="Upload an audio file or record from the microphone to get the transcription in Punjabi.",
)

if __name__ == "__main__":
    demo.launch()
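

# ---------------------------------------------------------------------------
# Note: the script above only loads "wav2vec2_bert_qint8.pth"; it does not show
# how that checkpoint was produced. The helper below is a minimal sketch of one
# plausible way to create it, assuming post-training dynamic quantization of the
# Linear layers to qint8 via torch.ao.quantization.quantize_dynamic, with the
# whole module saved by torch.save (matching the weights_only=False load above).
# The function name and output path are illustrative only and the function is
# never called in this script.
def export_quantized_checkpoint(output_path="wav2vec2_bert_qint8.pth"):
    model = Wav2Vec2BertForCTC.from_pretrained(model_name)
    model.eval()
    # Dynamically quantize the Linear layers to 8-bit integers (CPU inference).
    qmodel = torch.ao.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
    # Save the full pickled module so it can be restored with torch.load(..., weights_only=False).
    torch.save(qmodel, output_path)
    return qmodel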