import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig, AutoTokenizer
from IndicTransToolkit import IndicProcessor
import speech_recognition as sr

# Constants
BATCH_SIZE = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
quantization = None  # set to "4-bit" or "8-bit" to load a quantized model

# ---- IndicTrans2 Model Initialization ----
def initialize_model_and_tokenizer(ckpt_dir, quantization):
    if quantization == "4-bit":
        qconfig = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif quantization == "8-bit":
        qconfig = BitsAndBytesConfig(
            load_in_8bit=True,
            bnb_8bit_use_double_quant=True,
            bnb_8bit_compute_dtype=torch.bfloat16,
        )
    else:
        qconfig = None

    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        ckpt_dir,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )

    # bitsandbytes places quantized weights itself; only move unquantized models
    if qconfig is None:
        model = model.to(DEVICE)
        if DEVICE == "cuda":
            model.half()

    model.eval()
    return tokenizer, model
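
# Usage sketch (illustrative; not executed here): quantization is optional and
# controlled by the string constant above, e.g. a 4-bit load would look like
#   tokenizer, model = initialize_model_and_tokenizer(ckpt_dir, "4-bit")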

def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):
    translations = []
    for i in range(0, len(input_sentences), BATCH_SIZE):
        batch = input_sentences[i : i + BATCH_SIZE]

        # Preprocess the batch (normalization, entity handling)
        batch = ip.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)

        # Tokenize the batch and generate input encodings
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Generate translations using beam search
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode the generated tokens into text
        with tokenizer.as_target_tokenizer():
            generated_tokens = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        # Postprocess the translations, including entity replacement
        translations += ip.postprocess_batch(generated_tokens, lang=tgt_lang)

        del inputs
        torch.cuda.empty_cache()

    return translations
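
# Usage sketch (illustrative; not executed here): batch_translate accepts any
# list of source-language sentences once the model, tokenizer, and processor
# below are initialized, e.g.
#   batch_translate(["<Malayalam sentence>"], "mal_Mlym", "eng_Latn",
#                   indic_en_model, indic_en_tokenizer, ip)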

# Initialize IndicTrans2 (Indic-to-English direction)
indic_en_ckpt_dir = "ai4bharat/indictrans2-indic-en-1B"
indic_en_tokenizer, indic_en_model = initialize_model_and_tokenizer(indic_en_ckpt_dir, quantization)
ip = IndicProcessor(inference=True)

# ---- Gradio Function ----
def transcribe_and_translate(audio):
    recognizer = sr.Recognizer()
    with sr.AudioFile(audio) as source:
        audio_data = recognizer.record(source)

    try:
        # Malayalam transcription using the Google Web Speech API
        malayalam_text = recognizer.recognize_google(audio_data, language="ml-IN")
    except sr.UnknownValueError:
        return "Could not understand audio", ""
    except sr.RequestError as e:
        return f"Google API Error: {e}", ""

    # Translate Malayalam -> English with IndicTrans2
    ml_sents = [malayalam_text]
    src_lang, tgt_lang = "mal_Mlym", "eng_Latn"
    translations = batch_translate(ml_sents, src_lang, tgt_lang, indic_en_model, indic_en_tokenizer, ip)

    return malayalam_text, translations[0]

# ---- Gradio Interface ----
iface = gr.Interface(
    fn=transcribe_and_translate,
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
    outputs=[
        gr.Textbox(label="Malayalam Transcription"),
        gr.Textbox(label="English Translation"),
    ],
    title="Malayalam Speech Recognition & Translation",
    description="Speak in Malayalam → Transcribe using Google Speech Recognition → Translate to English using IndicTrans2.",
)

iface.launch(debug=True, share=True)
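
# Note: share=True requests a temporary public Gradio URL in addition to the
# local server, and debug=True keeps the process attached so errors are printed.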