from __future__ import annotations

import os

import gradio as gr
import torch
import torchaudio
import spaces
import nemo.collections.asr as nemo_asr

LANGUAGE_NAME_TO_CODE = {
    "Assamese": "as",
    "Bengali": "bn",
    "Bodo": "br",
    "Dogri": "doi",
    "Gujarati": "gu",
    "Hindi": "hi",
    "Kannada": "kn",
    "Kashmiri": "ks",
    "Konkani": "kok",
    "Maithili": "mai",
    "Malayalam": "ml",
    "Manipuri": "mni",
    "Marathi": "mr",
    "Nepali": "ne",
    "Odia": "or",
    "Punjabi": "pa",
    "Sanskrit": "sa",
    "Santali": "sat",
    "Sindhi": "sd",
    "Tamil": "ta",
    "Telugu": "te",
    "Urdu": "ur",
}

DESCRIPTION = """\
### **IndicConformer: Speech Recognition for Indian Languages** 🎙️➡️📜

This Gradio demo showcases **IndicConformer**, a speech recognition model for **22 Indian languages**. The model operates in two modes: **CTC (Connectionist Temporal Classification)** and **RNNT (Recurrent Neural Network Transducer)**, providing robust and accurate transcriptions across diverse linguistic and acoustic conditions.

#### **How to Use:**
1. **Upload or record** an audio clip in any supported Indian language.
2. Select the **mode** (CTC or RNNT) for transcription.
3. Click **"Transcribe"** to generate the text in the selected language.
4. View or copy the output for further use.

🚀 Try it out and experience seamless speech recognition for Indian languages!
"""

hf_token = os.getenv("HF_TOKEN")

device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32

model_name_or_path = "ai4bharat/IndicConformer"
# Note: the `model.cur_decoder` switching below assumes a hybrid CTC/RNNT checkpoint.
model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name_or_path).to(device)
# model = nemo_asr.models.EncDecCTCModel.restore_from("indicconformer_stt_bn_hybrid_rnnt_large.nemo").to(device)
model.eval()

CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()

AUDIO_SAMPLE_RATE = 16000
MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
DEFAULT_TARGET_LANGUAGE = "Bengali"


def preprocess_audio(input_audio: str) -> torch.Tensor:
    """Load an audio file, downmix to mono, and resample to 16 kHz."""
    audio_tensor, orig_freq = torchaudio.load(input_audio)  # shape: [channels, time]
    # Downmix multi-channel recordings to mono.
    if audio_tensor.shape[0] > 1:
        audio_tensor = torch.mean(audio_tensor, dim=0, keepdim=True)
    # Drop the channel dimension: transcribe() takes one 1-D waveform per utterance.
    audio_tensor = audio_tensor.squeeze(0)
    # Resample to the model's expected sample rate.
    return torchaudio.functional.resample(audio_tensor, orig_freq=orig_freq, new_freq=AUDIO_SAMPLE_RATE)


@spaces.GPU
def run_asr_ctc(input_audio: str, target_language: str) -> str:
    lang_id = LANGUAGE_NAME_TO_CODE[target_language]
    audio_tensor = preprocess_audio(input_audio)
    model.cur_decoder = "ctc"
    ctc_text = model.transcribe([audio_tensor.numpy()], batch_size=1, logprobs=False, language_id=lang_id)[0]
    return ctc_text[0]
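
# Quick local sanity check for the function above (a sketch, not executed by the
# app; it assumes one of the bundled example clips is present, and any audio file
# readable by torchaudio works):
#
#   print(run_asr_ctc("assets/Bengali.wav", "Bengali"))
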
@spaces.GPU
def run_asr_rnnt(input_audio: str, target_language: str) -> str:
    lang_id = LANGUAGE_NAME_TO_CODE[target_language]
    audio_tensor = preprocess_audio(input_audio)
    model.cur_decoder = "rnnt"
    rnnt_text = model.transcribe([audio_tensor.numpy()], batch_size=1, logprobs=False, language_id=lang_id)[0]
    return rnnt_text[0]


with gr.Blocks() as demo_asr_ctc:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_audio = gr.Audio(label="Input speech", type="filepath")
                target_language = gr.Dropdown(
                    label="Target language",
                    choices=list(LANGUAGE_NAME_TO_CODE),
                    value=DEFAULT_TARGET_LANGUAGE,
                )
                btn = gr.Button("Transcribe")
        with gr.Column():
            output_text = gr.Textbox(label="Transcribed text")

    gr.Examples(
        examples=[
            ["assets/Bengali.wav", "Bengali"],
            ["assets/Gujarati.wav", "Gujarati"],
            ["assets/Punjabi.wav", "Punjabi"],
        ],
        inputs=[input_audio, target_language],
        outputs=output_text,
        fn=run_asr_ctc,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    btn.click(
        fn=run_asr_ctc,
        inputs=[input_audio, target_language],
        outputs=output_text,
        api_name="asr",
    )

with gr.Blocks() as demo_asr_rnnt:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_audio = gr.Audio(label="Input speech", type="filepath")
                target_language = gr.Dropdown(
                    label="Target language",
                    choices=list(LANGUAGE_NAME_TO_CODE),
                    value=DEFAULT_TARGET_LANGUAGE,
                )
                btn = gr.Button("Transcribe")
        with gr.Column():
            output_text = gr.Textbox(label="Transcribed text")

    gr.Examples(
        examples=[
            ["assets/Bengali.wav", "Bengali"],
            ["assets/Gujarati.wav", "Gujarati"],
            ["assets/Punjabi.wav", "Punjabi"],
        ],
        inputs=[input_audio, target_language],
        outputs=output_text,
        fn=run_asr_rnnt,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )

    btn.click(
        fn=run_asr_rnnt,
        inputs=[input_audio, target_language],
        outputs=output_text,
        api_name="asr",
    )

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Tabs():
        with gr.Tab(label="CTC"):
            demo_asr_ctc.render()
        with gr.Tab(label="RNNT"):
            demo_asr_rnnt.render()

if __name__ == "__main__":
    demo.queue(max_size=50).launch()
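
# Programmatic access to the deployed demo (a sketch, not executed by the app).
# The CTC tab's `btn.click` is exposed at api_name="asr"; since the RNNT tab
# registers the same name, check the deployed Space's API page for the exact
# endpoint paths. "your-username/your-space" is a placeholder, and `handle_file`
# assumes a recent gradio_client release:
#
#   from gradio_client import Client, handle_file
#
#   client = Client("your-username/your-space")
#   text = client.predict(handle_file("sample.wav"), "Hindi", api_name="/asr")
#   print(text)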