"""Gradio demo: Malayalam speech -> Wav2Vec2 transcription -> IndicTrans2 English translation."""

import gradio as gr
import torch
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
from IndicTransToolkit import IndicProcessor
import librosa
import numpy as np

# Constants
BATCH_SIZE = 4  # number of sentences translated per generate() call
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TARGET_SAMPLE_RATE = 16000  # Hz; audio is resampled to this rate before recognition


# ---- Initialize Wav2Vec2 Model for Malayalam Speech Recognition ----
def initialize_wav2vec2_model(model_name):
    """Load a Wav2Vec2 CTC model and its processor, move the model to DEVICE, set eval mode.

    Args:
        model_name: Hugging Face Hub model id or local checkpoint directory.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = Wav2Vec2Processor.from_pretrained(model_name)
    model = Wav2Vec2ForCTC.from_pretrained(model_name).to(DEVICE)
    model.eval()  # inference only: disable dropout
    return processor, model


wav2vec2_model_name = "addy88/wav2vec2-malayalam-stt"
wav2vec2_processor, wav2vec2_model = initialize_wav2vec2_model(wav2vec2_model_name)


# ---- IndicTrans2 Model Initialization ----
def initialize_translation_model_and_tokenizer(ckpt_dir):
    """Load an IndicTrans2 seq2seq model and tokenizer, move the model to DEVICE, set eval mode.

    Args:
        ckpt_dir: Hugging Face Hub model id or local checkpoint directory.
            IndicTrans2 ships custom code, hence ``trust_remote_code=True``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        ckpt_dir,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    ).to(DEVICE)
    model.eval()
    return tokenizer, model


# NOTE: despite the ``en_indic`` prefix (kept so existing references keep working),
# this checkpoint is the Indic -> English direction; it is used below with
# src_lang="mal_Mlym" and tgt_lang="eng_Latn".
en_indic_ckpt_dir = "ai4bharat/indictrans2-indic-en-1B"
en_indic_tokenizer, en_indic_model = initialize_translation_model_and_tokenizer(en_indic_ckpt_dir)
ip = IndicProcessor(inference=True)


# ---- Batch Translation Function ----
def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):
    """Translate ``input_sentences`` in chunks of BATCH_SIZE using an IndicTrans2 model.

    Args:
        input_sentences: List of source-language strings.
        src_lang: Source language tag in IndicTrans2 format (e.g. ``"mal_Mlym"``).
        tgt_lang: Target language tag (e.g. ``"eng_Latn"``).
        model: Seq2seq model already placed on DEVICE.
        tokenizer: Tokenizer matching ``model``.
        ip: IndicProcessor used for pre-/post-processing around the model.

    Returns:
        List of translated strings, accumulated batch by batch in input order.
    """
    translations = []
    for start in range(0, len(input_sentences), BATCH_SIZE):
        batch = input_sentences[start : start + BATCH_SIZE]
        batch = ip.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)

        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode with the tokenizer switched to its target-side settings.
        with tokenizer.as_target_tokenizer():
            decoded = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        translations += ip.postprocess_batch(decoded, lang=tgt_lang)

        # Drop this batch's tensors before the next batch to keep peak GPU memory down.
        del inputs, generated_tokens
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    return translations


# ---- Gradio Function ----
def _load_audio(audio_path):
    """Load ``audio_path`` resampled to TARGET_SAMPLE_RATE and peak-normalise it to [-1, 1].

    Returns:
        ``(samples, sample_rate)``. An empty clip is returned unchanged (size 0).
    """
    samples, sample_rate = librosa.load(audio_path, sr=TARGET_SAMPLE_RATE)
    # np.max raises on an empty array, so only compute the peak when there is data.
    peak = np.max(np.abs(samples)) if samples.size else 0
    if peak != 0:
        samples = samples / peak
    return samples, sample_rate


def _transcribe(samples, sample_rate):
    """Run the Malayalam Wav2Vec2 model with greedy (argmax) decoding and return the text."""
    input_values = wav2vec2_processor(
        samples, sampling_rate=sample_rate, return_tensors="pt"
    ).input_values.to(DEVICE)
    with torch.no_grad():
        logits = wav2vec2_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return wav2vec2_processor.decode(predicted_ids[0].cpu(), skip_special_tokens=True)


def transcribe_and_translate(audio):
    """Gradio handler: transcribe a Malayalam audio file and translate it to English.

    Args:
        audio: File path from ``gr.Audio(type="filepath")``. May be ``None`` when
            the user submits without recording or uploading anything.

    Returns:
        ``(malayalam_text, english_text)``. If the audio cannot be read, the first
        element holds an error message and the second is an empty string.
    """
    if audio is None:
        return "Error reading audio: no audio was provided.", ""
    try:
        audio_input, sample_rate = _load_audio(audio)
    except Exception as e:  # UI boundary: report any decode/IO failure to the user
        return f"Error reading audio: {e}", ""
    if audio_input.size == 0:
        return "Error reading audio: the recording is empty.", ""

    malayalam_text = _transcribe(audio_input, sample_rate)

    # Nothing recognised (e.g. silence): skip translation instead of running
    # beam search on an empty sentence.
    if not malayalam_text.strip():
        return malayalam_text, ""

    ml_sents = [malayalam_text]
    src_lang, tgt_lang = "mal_Mlym", "eng_Latn"
    translations = batch_translate(
        ml_sents, src_lang, tgt_lang, en_indic_model, en_indic_tokenizer, ip
    )
    return malayalam_text, translations[0]


# ---- Gradio Interface ----
iface = gr.Interface(
    fn=transcribe_and_translate,
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
    outputs=[
        gr.Textbox(label="Malayalam Transcription"),
        gr.Textbox(label="English Translation"),
    ],
    title="Malayalam Speech Recognition & Translation",
    description="Speak in Malayalam → Transcribe using Wav2Vec2 → Translate to English using IndicTrans2.",
)

if __name__ == "__main__":
    iface.launch(debug=True, share=True)