# STT_Model / app.py
# Malayalam speech-to-text (Wav2Vec2) with English translation (IndicTrans2),
# served through a Gradio interface.
import gradio as gr
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
import librosa
import numpy as np

# Constants
BATCH_SIZE = 4  # sentences per translation batch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Initialize Wav2Vec2 Model for Malayalam Speech Recognition ----
def initialize_wav2vec2_model(model_name):
    processor = Wav2Vec2Processor.from_pretrained(model_name)
    model = Wav2Vec2ForCTC.from_pretrained(model_name).to(DEVICE)
    model.eval()
    return processor, model


wav2vec2_model_name = "addy88/wav2vec2-malayalam-stt"
wav2vec2_processor, wav2vec2_model = initialize_wav2vec2_model(wav2vec2_model_name)
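
# Optional smoke test (a sketch, not part of the original app): run one second
# of silence through the ASR model to confirm the checkpoint loads and the
# logits have the expected (batch, time, vocab) shape.
def _asr_smoke_test():
    silence = np.zeros(16000, dtype=np.float32)  # 1 s of audio at 16 kHz
    features = wav2vec2_processor(silence, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = wav2vec2_model(features.input_values.to(DEVICE)).logits
    print("ASR logits shape:", tuple(logits.shape))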

# ---- IndicTrans2 Model Initialization ----
def initialize_translation_model_and_tokenizer(ckpt_dir):
    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        ckpt_dir,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    ).to(DEVICE)
    model.eval()
    return tokenizer, model


# IndicTrans2 Indic-to-English checkpoint (1B parameters).
indic_en_ckpt_dir = "ai4bharat/indictrans2-indic-en-1B"
indic_en_tokenizer, indic_en_model = initialize_translation_model_and_tokenizer(indic_en_ckpt_dir)

# IndicProcessor handles IndicTrans2's pre- and post-processing
# (language tagging, normalization, detokenization).
ip = IndicProcessor(inference=True)
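
# Language codes follow the FLORES-200 convention used by IndicTrans2:
# "mal_Mlym" is Malayalam in Malayalam script, "eng_Latn" is English in
# Latin script.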

# ---- Batch Translation Function ----
def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):
    translations = []
    for i in range(0, len(input_sentences), BATCH_SIZE):
        batch = input_sentences[i : i + BATCH_SIZE]

        # Add language tags and apply IndicTrans2-specific normalization.
        batch = ip.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)

        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode with the target-language vocabulary.
        with tokenizer.as_target_tokenizer():
            generated_tokens = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        # Undo preprocessing (detokenization, script handling) on the outputs.
        translations += ip.postprocess_batch(generated_tokens, lang=tgt_lang)

        del inputs
        torch.cuda.empty_cache()

    return translations
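
# Example usage of batch_translate (a minimal sketch; the sample sentence is
# illustrative and not part of the original app):
#
#   english = batch_translate(
#       ["ഇന്ന് നല്ല ദിവസമാണ്"], "mal_Mlym", "eng_Latn",
#       indic_en_model, indic_en_tokenizer, ip,
#   )
#   print(english[0])  # -> something like "Today is a good day."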

# ---- Gradio Function ----
def transcribe_and_translate(audio):
    try:
        # Load audio with librosa, forcing the 16 kHz sample rate Wav2Vec2 expects.
        audio_input, sample_rate = librosa.load(audio, sr=16000)

        # Peak-normalize the audio (guarding against all-zero input).
        if np.max(np.abs(audio_input)) != 0:
            audio_input = audio_input / np.max(np.abs(audio_input))
    except Exception as e:
        return f"Error reading audio: {e}", ""

    # Extract input features for the acoustic model.
    input_values = wav2vec2_processor(
        audio_input, sampling_rate=sample_rate, return_tensors="pt"
    ).input_values.to(DEVICE)

    # CTC inference: greedy decoding via the argmax of the logits.
    with torch.no_grad():
        logits = wav2vec2_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)

    # Decode the predicted token IDs into Malayalam text.
    malayalam_text = wav2vec2_processor.decode(predicted_ids[0].cpu(), skip_special_tokens=True)

    # Translate Malayalam -> English with IndicTrans2.
    malayalam_sents = [malayalam_text]
    src_lang, tgt_lang = "mal_Mlym", "eng_Latn"
    translations = batch_translate(
        malayalam_sents, src_lang, tgt_lang, indic_en_model, indic_en_tokenizer, ip
    )

    return malayalam_text, translations[0]
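
# Note: inputs are truncated by the tokenizer and generation is capped at
# max_length=256 tokens, so very long recordings can lose content. A simple
# mitigation (a sketch, not in the original app) is to split the transcript
# into sentences before translating:
#
#   sents = [s.strip() for s in malayalam_text.split(".") if s.strip()]
#   english = " ".join(batch_translate(sents, "mal_Mlym", "eng_Latn",
#                                      indic_en_model, indic_en_tokenizer, ip))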

# ---- Gradio Interface ----
iface = gr.Interface(
    fn=transcribe_and_translate,
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
    outputs=[
        gr.Textbox(label="Malayalam Transcription"),
        gr.Textbox(label="English Translation"),
    ],
    title="Malayalam Speech Recognition & Translation",
    description="Speak in Malayalam → Transcribe using Wav2Vec2 → Translate to English using IndicTrans2.",
)

iface.launch(debug=True, share=True)