Sajjo's picture
Add .ipynb_checkpoints to .gitignore
1500574
raw
history blame
1.29 kB
import gradio as gr
import torch
import torchaudio
from quanto import freeze, qint8, quantize
from transformers import AutoModel, AutoModelForCTC, AutoProcessor
# Load the checkpoint together with its processor.
# AutoModelForCTC (not AutoModel) is required: AutoModel builds the bare
# encoder, whose output has no `.logits`, and transcribe() reads `.logits`.
model_name = "cdactvm/w2v-bert-punjabi"
model = AutoModelForCTC.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)

# Quantize the model weights to int8 (activations=None leaves activations
# unquantized), then freeze the quantized weights.
# NOTE(review): the standalone `quanto` package has been superseded by
# `optimum.quanto`; consider migrating the import.
quantize(model, weights=qint8, activations=None)
freeze(model)
def transcribe(audio):
    """Transcribe an audio file to Punjabi text with the quantized model.

    Args:
        audio: Path to an audio file, as supplied by
            ``gr.Audio(type="filepath")``. Gradio passes ``None`` when the
            form is submitted without any audio.

    Returns:
        The decoded transcription string, or ``""`` when no audio was given.
    """
    # torchaudio.load would raise on None; return an empty result instead.
    if audio is None:
        return ""

    target_rate = 16000  # sampling rate handed to the processor below

    # torchaudio.load yields a (channels, samples) tensor.
    waveform, sample_rate = torchaudio.load(audio)

    # Average the channels into one 1-D mono signal. squeeze(0) only works
    # for mono files; for stereo it leaves a 2-D tensor behind.
    waveform = waveform.mean(dim=0)

    if sample_rate != target_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_rate)

    inputs = processor(waveform.numpy(), sampling_rate=target_rate, return_tensors="pt")

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy decoding: highest-scoring token id per frame, then detokenize.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]
# Gradio UI: one uploaded audio file in, one transcription string out.
iface = gr.Interface(
    fn=transcribe,
    # Gradio 4 replaced Audio(source=...) with the list-valued `sources=`;
    # the old keyword is rejected there.
    inputs=gr.Audio(sources=["upload"], type="filepath"),
    outputs="text",
    title="Punjabi Speech Recognition",
    description="Upload an audio file and get a Punjabi transcription using a quantized model.",
)

# Start the server only when executed as a script (e.g. `python app.py`),
# not when the module is imported.
if __name__ == "__main__":
    iface.launch()