# Hugging Face Spaces status: Runtime error
import torch
import torchaudio
import gradio as gr
from transformers import Wav2Vec2BertProcessor, Wav2Vec2BertForCTC

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the processor and the CTC model once at import time so that every
# request reuses the same objects.
model_name = "cdactvm/w2v-bert-punjabi"  # Change to use a different Wav2Vec2-BERT CTC checkpoint
processor = Wav2Vec2BertProcessor.from_pretrained(model_name)
# Weights are loaded in bfloat16, the dtype used for inference inputs in this
# file; float16 weights combined with bfloat16 inputs fail with a dtype
# mismatch during the forward pass.
model = Wav2Vec2BertForCTC.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
def transcribe(audio_path):
    """Transcribe an audio file with the module-level ``processor`` and ``model``.

    Args:
        audio_path: Path of the audio file to transcribe. ``None`` or an empty
            string (nothing was uploaded) yields an empty transcription.

    Returns:
        The decoded text of the first (only) item in the batch, or ``""`` when
        no audio path was given.
    """
    # The UI can submit without a file; bail out instead of failing in the loader.
    if not audio_path:
        return ""

    target_rate = 16000

    waveform, sample_rate = torchaudio.load(audio_path)

    # Down-mix multi-channel audio to a single channel by averaging.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample to 16 kHz, the rate passed to the processor below.
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
        waveform = resampler(waveform)

    inputs = processor(waveform.squeeze(0), sampling_rate=target_rate, return_tensors="pt")
    # Move every tensor to the model's device. Only floating-point tensors are
    # cast, and they are cast to the model's own dtype: a hard-coded dtype that
    # differs from the weights' dtype makes the forward pass raise.
    inputs = {
        key: (
            val.to(device=model.device, dtype=model.dtype)
            if torch.is_floating_point(val)
            else val.to(model.device)
        )
        for key, val in inputs.items()
    }

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy CTC decoding: take the most likely token id at every frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]
# Gradio interface: a single audio-upload input wired to transcribe(), with
# the transcription shown as plain text.
app = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(sources="upload", type="filepath"),
    outputs="text",
    title="Punjabi Speech-to-Text",
    description="Upload an audio file and get the transcription in Punjabi.",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()
# Alternative implementation using the transformers ASR pipeline (kept for reference):
# import gradio as gr
# import torch
# from transformers import pipeline
# # Set device
# device = "cuda" if torch.cuda.is_available() else "cpu"
# # Load ASR pipeline
# asr_pipeline = pipeline(
#     "automatic-speech-recognition",
#     model="cdactvm/w2v-bert-punjabi",  # Replace with a Punjabi ASR model if available
#     torch_dtype=torch.bfloat16,
#     device=0 if torch.cuda.is_available() else -1  # GPU (0) or CPU (-1)
# )
# def transcribe(audio_path):
#     # Run inference
#     result = asr_pipeline(audio_path)
#     return result["text"]
# # Gradio Interface
# app = gr.Interface(
#     fn=transcribe,
#     inputs=gr.Audio(sources="upload", type="filepath"),
#     outputs="text",
#     title="Punjabi Speech-to-Text",
#     description="Upload an audio file and get the transcription in Punjabi."
# )
# if __name__ == "__main__":
#     app.launch()