File size: 4,125 Bytes
d0ebd00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
import librosa
import numpy as np

# Constants
# Number of sentences processed per chunk by batch_translate().
BATCH_SIZE = 4
# Torch device string: "cuda" when torch reports CUDA as available, else "cpu".
# Models and tokenized inputs in this file are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Initialize Wav2Vec2 Model for Malayalam Speech Recognition ----
def initialize_wav2vec2_model(model_name):
    """Load the Wav2Vec2 processor and CTC model published under ``model_name``.

    The model is moved to ``DEVICE`` and switched to evaluation mode before
    being handed back.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.

    Returns:
        A ``(processor, model)`` tuple.
    """
    asr_processor = Wav2Vec2Processor.from_pretrained(model_name)

    asr_model = Wav2Vec2ForCTC.from_pretrained(model_name)
    asr_model = asr_model.to(DEVICE)
    asr_model.eval()  # inference only

    return asr_processor, asr_model

# Checkpoint used for Malayalam speech-to-text. The processor/model pair is
# loaded once here, at module import, and reused by transcribe_and_translate().
wav2vec2_model_name = "addy88/wav2vec2-malayalam-stt"
wav2vec2_processor, wav2vec2_model = initialize_wav2vec2_model(wav2vec2_model_name)

# ---- IndicTrans2 Model Initialization ----
def initialize_translation_model_and_tokenizer(ckpt_dir):
    """Load the seq2seq translation model and its tokenizer from ``ckpt_dir``.

    Both loaders are called with ``trust_remote_code=True``; the model is also
    loaded with ``low_cpu_mem_usage=True``, moved to ``DEVICE`` and put into
    evaluation mode.

    Args:
        ckpt_dir: Checkpoint identifier/path passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # NOTE(review): trust_remote_code=True lets the checkpoint supply its own
    # Python code (see the transformers docs) -- only use trusted checkpoints.
    mt_tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)

    mt_model = AutoModelForSeq2SeqLM.from_pretrained(
        ckpt_dir, trust_remote_code=True, low_cpu_mem_usage=True
    )
    mt_model = mt_model.to(DEVICE)
    mt_model.eval()  # inference only

    return mt_tokenizer, mt_model

# NOTE: despite the "en_indic" prefix on these names, the checkpoint is the
# indic-en one; transcribe_and_translate() uses it with src "mal_Mlym" and
# tgt "eng_Latn". The names are kept because other code refers to them.
en_indic_ckpt_dir = "ai4bharat/indictrans2-indic-en-1B"
# Loaded once at module import and shared by every call.
en_indic_tokenizer, en_indic_model = initialize_translation_model_and_tokenizer(en_indic_ckpt_dir)
# Processor whose preprocess_batch()/postprocess_batch() are called around
# tokenization and decoding in batch_translate().
ip = IndicProcessor(inference=True)

# ---- Batch Translation Function ----
def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):
    """Translate ``input_sentences`` from ``src_lang`` to ``tgt_lang``.

    Sentences are handled in chunks of ``BATCH_SIZE``. Each chunk is
    preprocessed with ``ip``, tokenized, run through ``model.generate`` with
    beam search, decoded, and postprocessed with ``ip``.

    Args:
        input_sentences: Sequence of source-language sentences.
        src_lang: Source language tag passed to ``ip.preprocess_batch``.
        tgt_lang: Target language tag passed to ``ip`` for pre/postprocessing.
        model: Seq2seq model exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        ip: Processor exposing ``preprocess_batch`` and ``postprocess_batch``.

    Returns:
        A list with the postprocessed translations, in input order.
    """
    results = []

    for offset in range(0, len(input_sentences), BATCH_SIZE):
        chunk = input_sentences[offset : offset + BATCH_SIZE]
        prepared = ip.preprocess_batch(chunk, src_lang=src_lang, tgt_lang=tgt_lang)

        encoded = tokenizer(
            prepared,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        )
        encoded = encoded.to(DEVICE)

        with torch.no_grad():
            output_ids = model.generate(
                **encoded,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Pull the ids onto the host as plain lists, then drop the tensor so
        # it is no longer referenced when the CUDA cache is emptied below.
        id_lists = output_ids.detach().cpu().tolist()
        del output_ids

        with tokenizer.as_target_tokenizer():
            decoded = tokenizer.batch_decode(
                id_lists,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        results.extend(ip.postprocess_batch(decoded, lang=tgt_lang))

        del encoded
        torch.cuda.empty_cache()

    return results

# ---- Gradio Function ----
def transcribe_and_translate(audio):
    """Transcribe a Malayalam audio file and translate the transcript to English.

    Args:
        audio: Path to the audio file to process. May be ``None`` or empty
            when the caller has no audio to offer (e.g. the form was
            submitted without recording or uploading anything).

    Returns:
        A ``(malayalam_text, english_text)`` tuple. When the audio is missing,
        unreadable or empty, the first element is a message starting with
        ``"Error reading audio: "`` and the second is ``""``. When nothing
        was transcribed, the translation is ``""``.
    """
    # Fail fast with a readable message instead of letting the loader choke
    # on a missing path.
    if not audio:
        return "Error reading audio: no audio was provided", ""

    try:
        # Load audio using librosa and force sample rate to 16kHz
        audio_input, sample_rate = librosa.load(audio, sr=16000)
    except Exception as e:  # UI boundary: report any loader failure to the user
        return f"Error reading audio: {e}", ""

    if audio_input.size == 0:
        # np.max() below would raise on a zero-length signal.
        return "Error reading audio: the audio clip is empty", ""

    # Peak-normalize; skip for an all-zero signal to avoid dividing by zero.
    peak = np.max(np.abs(audio_input))
    if peak > 0:
        audio_input = audio_input / peak

    # Process audio
    input_values = wav2vec2_processor(
        audio_input, sampling_rate=sample_rate, return_tensors="pt"
    ).input_values.to(DEVICE)

    # Perform inference (greedy CTC: most likely token per frame)
    with torch.no_grad():
        logits = wav2vec2_model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)

    # Decode transcription
    malayalam_text = wav2vec2_processor.decode(predicted_ids[0].cpu(), skip_special_tokens=True)

    # Nothing recognized: there is nothing to translate, so skip generation.
    if not malayalam_text.strip():
        return malayalam_text, ""

    # Translation (Malayalam -> English)
    src_lang, tgt_lang = "mal_Mlym", "eng_Latn"
    translations = batch_translate(
        [malayalam_text], src_lang, tgt_lang, en_indic_model, en_indic_tokenizer, ip
    )

    return malayalam_text, translations[0]

# ---- Gradio Interface ----
# Single-function UI: one audio input (microphone or upload, delivered to the
# callback as a file path) and two text outputs matching the callback's
# (transcription, translation) return tuple.
iface = gr.Interface(
    fn=transcribe_and_translate,
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
    outputs=[
        gr.Textbox(label="Malayalam Transcription"),
        gr.Textbox(label="English Translation")
    ],
    title="Malayalam Speech Recognition & Translation",
    # The arrows were stored as mis-decoded bytes ("β†’"); restored to "→".
    description="Speak in Malayalam → Transcribe using Wav2Vec2 → Translate to English using IndicTrans2."
)

# Start the app with debug output and a public share link requested.
iface.launch(debug=True, share=True)