cdactvm's picture
Update app.py
8472c6f verified
raw
history blame
1.57 kB
import torch
import torchaudio
import gradio as gr
from transformers import Wav2Vec2BertProcessor, Wav2Vec2BertForCTC
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Hugging Face Hub checkpoint: a Wav2Vec2-BERT model with a CTC head
# (loaded through Wav2Vec2BertForCTC) for Punjabi speech recognition.
model_name = "cdactvm/w2v-bert-punjabi"
processor = Wav2Vec2BertProcessor.from_pretrained(model_name)

# Weights are loaded in bfloat16, so floating-point model inputs must be cast
# to the same dtype before the forward pass.
# NOTE(review): bfloat16 is used on CPU too, where it may be slow on hardware
# without native bf16 support — consider float32 there (inputs must then be
# cast to match).
model = Wav2Vec2BertForCTC.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)
model = model.to(device)
def transcribe(audio_path):
    """Transcribe a Punjabi speech recording to text.

    The audio is down-mixed to mono and resampled to the feature extractor's
    sampling rate. It is then run through the CTC model and decoded greedily
    (argmax per frame).

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``
            (a ``str`` path, as produced by a Gradio ``Audio(type="filepath")``
            input). ``None`` when the user submits without providing audio.

    Returns:
        The decoded transcription string.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if audio_path is None:
        # Gradio calls fn with None when Submit is pressed before uploading;
        # surface a readable message instead of an opaque torchaudio failure.
        raise gr.Error("Please upload an audio file.")

    waveform, sample_rate = torchaudio.load(audio_path)

    # Down-mix multi-channel audio to one channel by averaging across channels.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # The feature extractor only accepts audio at its own sampling rate, so
    # resample to exactly that rate rather than a hard-coded constant.
    target_sr = processor.feature_extractor.sampling_rate
    if sample_rate != target_sr:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=target_sr
        )

    inputs = processor(waveform.squeeze(0), sampling_rate=target_sr, return_tensors="pt")

    # Move every tensor to the model's device, but cast only floating-point
    # tensors to the model's dtype; integer tensors (e.g. attention_mask)
    # keep their original dtype.
    inputs = {
        key: val.to(device, dtype=model.dtype) if val.is_floating_point() else val.to(device)
        for key, val in inputs.items()
    }

    with torch.inference_mode():
        logits = model(**inputs).logits

    # Greedy CTC decoding: most likely token per frame; batch_decode collapses
    # the ids into text. There is a single utterance, so take element 0.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]
# Web UI: a single uploaded audio file in, the transcription text out.
# type="filepath" makes Gradio hand transcribe() a path to the uploaded file.
_audio_input = gr.Audio(sources="upload", type="filepath")

app = gr.Interface(
    fn=transcribe,
    inputs=_audio_input,
    outputs="text",
    title="Punjabi Speech-to-Text",
    description="Upload an audio file and get the transcription in Punjabi.",
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    app.launch()