from __future__ import annotations

import os

import gradio as gr
import torch
import torchaudio
import spaces
import nemo.collections.asr as nemo_asr
LANGUAGE_NAME_TO_CODE = {
    "Assamese": "as",
    "Bengali": "bn",
    "Bodo": "brx",
    "Dogri": "doi",
    "Gujarati": "gu",
    "Hindi": "hi",
    "Kannada": "kn",
    "Kashmiri": "ks",
    "Konkani": "kok",
    "Maithili": "mai",
    "Malayalam": "ml",
    "Manipuri": "mni",
    "Marathi": "mr",
    "Nepali": "ne",
    "Odia": "or",
    "Punjabi": "pa",
    "Sanskrit": "sa",
    "Santali": "sat",
    "Sindhi": "sd",
    "Tamil": "ta",
    "Telugu": "te",
    "Urdu": "ur",
}
DESCRIPTION = """\
### **IndicConformer: Speech Recognition for Indian Languages** 🎙️➡️📜
This Gradio demo showcases **IndicConformer**, a speech recognition model for **22 Indian languages**. The model operates in two modes: **CTC (Connectionist Temporal Classification)** and **RNNT (Recurrent Neural Network Transducer)**, providing robust and accurate transcriptions across diverse linguistic and acoustic conditions.
#### **How to Use:**
1. **Upload or record** an audio clip in any supported Indian language.
2. Select the **language** of the audio and pick a decoding **mode** (CTC or RNNT tab).
3. Click **"Transcribe"** to generate the corresponding text.
4. View or copy the output for further use.
🚀 Try it out and experience seamless speech recognition for Indian languages!
"""
hf_token = os.getenv("HF_TOKEN")
device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
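# The IndicConformer checkpoint is a hybrid CTC/RNNT model: one encoder with
# two decoders, switched at transcription time via `model.cur_decoder` below.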
model_name_or_path = "ai4bharat/IndicConformer"
# ASRModel.from_pretrained resolves the concrete model class (here, the
# hybrid RNNT/CTC variant) from the checkpoint config.
model = nemo_asr.models.ASRModel.from_pretrained(model_name_or_path).to(device)
# model = nemo_asr.models.ASRModel.restore_from("indicconformer_stt_bn_hybrid_rnnt_large.nemo").to(device)
model.eval()
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()
AUDIO_SAMPLE_RATE = 16000
MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
DEFAULT_TARGET_LANGUAGE = "Bengali"
def run_asr_ctc(input_audio: str, target_language: str) -> str:
    lang_id = LANGUAGE_NAME_TO_CODE[target_language]
    # Load the audio and mix down to mono if it has multiple channels
    audio_tensor, orig_freq = torchaudio.load(input_audio)
    if audio_tensor.shape[0] > 1:
        audio_tensor = torch.mean(audio_tensor, dim=0, keepdim=True)
    # Flatten to a 1-D [T] waveform, which `transcribe` expects
    audio_tensor = audio_tensor.squeeze(0)
    # Resample to the 16 kHz rate the model was trained on
    audio_tensor = torchaudio.functional.resample(audio_tensor, orig_freq=orig_freq, new_freq=AUDIO_SAMPLE_RATE)
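    # Assumption, not in the original code: cap clips at MAX_INPUT_AUDIO_LENGTH
    # (the constant is defined above but was never enforced) so oversized
    # uploads cannot stall the Space; a minimal truncation guard.
    audio_tensor = audio_tensor[: MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE]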
    model.cur_decoder = "ctc"
    # `transcribe` here returns (best_hypotheses, all_hypotheses); index twice
    # to get the single best transcription for the one input clip
    ctc_text = model.transcribe([audio_tensor.numpy()], batch_size=1, logprobs=False, language_id=lang_id)[0]
    return ctc_text[0]
def run_asr_rnnt(input_audio: str, target_language: str) -> str:
    lang_id = LANGUAGE_NAME_TO_CODE[target_language]
    # Load the audio and mix down to mono if it has multiple channels
    audio_tensor, orig_freq = torchaudio.load(input_audio)
    if audio_tensor.shape[0] > 1:
        audio_tensor = torch.mean(audio_tensor, dim=0, keepdim=True)
    # Flatten to a 1-D [T] waveform, which `transcribe` expects
    audio_tensor = audio_tensor.squeeze(0)
    # Resample to the 16 kHz rate the model was trained on
    audio_tensor = torchaudio.functional.resample(audio_tensor, orig_freq=orig_freq, new_freq=AUDIO_SAMPLE_RATE)
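    # Assumption, not in the original code: enforce MAX_INPUT_AUDIO_LENGTH here
    # as well; a minimal truncation guard mirroring the CTC path.
    audio_tensor = audio_tensor[: MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE]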
    model.cur_decoder = "rnnt"
    # Same return shape as the CTC path: (best_hypotheses, all_hypotheses)
    rnnt_text = model.transcribe([audio_tensor.numpy()], batch_size=1, logprobs=False, language_id=lang_id)[0]
    return rnnt_text[0]
with gr.Blocks() as demo_asr_ctc:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_audio = gr.Audio(label="Input speech", type="filepath")
                target_language = gr.Dropdown(
                    label="Target language",
                    choices=list(LANGUAGE_NAME_TO_CODE),
                    value=DEFAULT_TARGET_LANGUAGE,
                )
                btn = gr.Button("Transcribe")
        with gr.Column():
            output_text = gr.Textbox(label="Transcribed text")
    gr.Examples(
        examples=[
            ["assets/Bengali.wav", "Bengali"],
            ["assets/Gujarati.wav", "Gujarati"],
            ["assets/Punjabi.wav", "Punjabi"],
        ],
        inputs=[input_audio, target_language],
        outputs=output_text,
        fn=run_asr_ctc,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )
    btn.click(
        fn=run_asr_ctc,
        inputs=[input_audio, target_language],
        outputs=output_text,
        api_name="asr_ctc",
    )
with gr.Blocks() as demo_asr_rnnt:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_audio = gr.Audio(label="Input speech", type="filepath")
                target_language = gr.Dropdown(
                    label="Target language",
                    choices=list(LANGUAGE_NAME_TO_CODE),
                    value=DEFAULT_TARGET_LANGUAGE,
                )
                btn = gr.Button("Transcribe")
        with gr.Column():
            output_text = gr.Textbox(label="Transcribed text")
    gr.Examples(
        examples=[
            ["assets/Bengali.wav", "Bengali"],
            ["assets/Gujarati.wav", "Gujarati"],
            ["assets/Punjabi.wav", "Punjabi"],
        ],
        inputs=[input_audio, target_language],
        outputs=output_text,
        fn=run_asr_rnnt,
        cache_examples=CACHE_EXAMPLES,
        api_name=False,
    )
    btn.click(
        fn=run_asr_rnnt,
        inputs=[input_audio, target_language],
        outputs=output_text,
        api_name="asr_rnnt",
    )
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Tabs():
        with gr.Tab(label="CTC"):
            demo_asr_ctc.render()
        with gr.Tab(label="RNNT"):
            demo_asr_rnnt.render()

if __name__ == "__main__":
    demo.queue(max_size=50).launch()