from __future__ import annotations
import os
import gradio as gr
import torch
import torchaudio
import spaces
import nemo.collections.asr as nemo_asr
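# Display names mapped to the language IDs passed as `language_id` to model.transcribe().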
LANGUAGE_NAME_TO_CODE = {
"Assamese": "as",
"Bengali": "bn",
"Bodo": "br",
"Dogri": "doi",
"Gujarati": "gu",
"Hindi": "hi",
"Kannada": "kn",
"Kashmiri": "ks",
"Konkani": "kok",
"Maithili": "mai",
"Malayalam": "ml",
"Manipuri": "mni",
"Marathi": "mr",
"Nepali": "ne",
"Odia": "or",
"Punjabi": "pa",
"Sanskrit": "sa",
"Santali": "sat",
"Sindhi": "sd",
"Tamil": "ta",
"Telugu": "te",
"Urdu": "ur"
}
DESCRIPTION = """\
### **IndicConformer: Speech Recognition for Indian Languages** 🎙️➡️📜
This Gradio demo showcases **IndicConformer**, a speech recognition model for **22 Indian languages**. The model operates in two modes: **CTC (Connectionist Temporal Classification)** and **RNNT (Recurrent Neural Network Transducer)**, providing robust and accurate transcriptions across diverse linguistic and acoustic conditions.
#### **How to Use:**
1. **Upload or record** an audio clip in any supported Indian language, then pick that language from the dropdown.
2. Choose the decoding **mode** by switching to the **CTC** or **RNNT** tab.
3. Click **"Transcribe"** to generate the transcription.
4. View or copy the output text for further use.
🚀 Try it out and experience seamless speech recognition for Indian languages!
"""
hf_token = os.getenv("HF_TOKEN")  # token for gated checkpoints; not passed explicitly below (huggingface_hub also reads HF_TOKEN from the environment)
device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
model_name_or_path = "ai4bharat/IndicConformer"
# Load through the generic ASRModel entry point so the hybrid CTC/RNNT checkpoint
# is restored as its correct concrete class; a plain EncDecCTCModel would not
# expose the `cur_decoder` switch used in the handlers below.
model = nemo_asr.models.ASRModel.from_pretrained(model_name_or_path).to(device)
# Alternatively, restore from a local checkpoint:
# model = nemo_asr.models.ASRModel.restore_from("indicconformer_stt_bn_hybrid_rnnt_large.nemo").to(device)
model.eval()
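# The checkpoint is a hybrid CTC/RNNT model: the active decoder is selected at
# inference time by setting `model.cur_decoder` to "ctc" or "rnnt".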
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()
AUDIO_SAMPLE_RATE = 16000
MAX_INPUT_AUDIO_LENGTH = 60 # in seconds
DEFAULT_TARGET_LANGUAGE = "Bengali"
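
# Shared audio preprocessing, consolidating the identical loading logic that was
# previously duplicated in both transcription handlers. A minimal sketch: it also
# enforces MAX_INPUT_AUDIO_LENGTH, which was defined above but never applied
# (assumption: truncating overlong clips is acceptable for this demo).
def preprocess_audio(path: str):
    """Load `path` and return a mono, 16 kHz, 1-D float waveform as a numpy array."""
    waveform, orig_freq = torchaudio.load(path)  # [channels, T]
    if waveform.shape[0] > 1:
        # Downmix multi-channel audio to mono.
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    waveform = waveform.squeeze(0)  # collapse to 1-D [T], as transcribe() expects
    waveform = torchaudio.functional.resample(
        waveform, orig_freq=orig_freq, new_freq=AUDIO_SAMPLE_RATE
    )
    max_samples = AUDIO_SAMPLE_RATE * MAX_INPUT_AUDIO_LENGTH
    if waveform.shape[0] > max_samples:
        waveform = waveform[:max_samples]  # truncation is an addition, see note above
    return waveform.numpy()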
@spaces.GPU
def run_asr_ctc(input_audio: str, target_language: str) -> str:
    lang_id = LANGUAGE_NAME_TO_CODE[target_language]
    audio = preprocess_audio(input_audio)
    # Transcribe with the CTC decoder.
    model.cur_decoder = "ctc"
    ctc_text = model.transcribe([audio], batch_size=1, logprobs=False, language_id=lang_id)[0]
    return ctc_text[0]
@spaces.GPU
def run_asr_rnnt(input_audio: str, target_language: str) -> str:
    lang_id = LANGUAGE_NAME_TO_CODE[target_language]
    audio = preprocess_audio(input_audio)
    # Transcribe with the RNNT decoder.
    model.cur_decoder = "rnnt"
    rnnt_text = model.transcribe([audio], batch_size=1, logprobs=False, language_id=lang_id)[0]
    return rnnt_text[0]
with gr.Blocks() as demo_asr_ctc:
with gr.Row():
with gr.Column():
with gr.Group():
input_audio = gr.Audio(label="Input speech", type="filepath")
                target_language = gr.Dropdown(
                    label="Audio language",
                    choices=list(LANGUAGE_NAME_TO_CODE),
                    value=DEFAULT_TARGET_LANGUAGE,
                )
btn = gr.Button("Transcribe")
with gr.Column():
output_text = gr.Textbox(label="Transcribed text")
gr.Examples(
        examples=[
            ["assets/Bengali.wav", "Bengali"],
            ["assets/Gujarati.wav", "Gujarati"],
            ["assets/Punjabi.wav", "Punjabi"],
        ],
inputs=[input_audio, target_language],
outputs=output_text,
fn=run_asr_ctc,
cache_examples=CACHE_EXAMPLES,
api_name=False,
)
btn.click(
fn=run_asr_ctc,
inputs=[input_audio, target_language],
outputs=output_text,
api_name="asr",
)
with gr.Blocks() as demo_asr_rnnt:
with gr.Row():
with gr.Column():
with gr.Group():
input_audio = gr.Audio(label="Input speech", type="filepath")
                target_language = gr.Dropdown(
                    label="Audio language",
                    choices=list(LANGUAGE_NAME_TO_CODE),
                    value=DEFAULT_TARGET_LANGUAGE,
                )
btn = gr.Button("Transcribe")
with gr.Column():
output_text = gr.Textbox(label="Transcribed text")
gr.Examples(
        examples=[
            ["assets/Bengali.wav", "Bengali"],
            ["assets/Gujarati.wav", "Gujarati"],
            ["assets/Punjabi.wav", "Punjabi"],
        ],
inputs=[input_audio, target_language],
outputs=output_text,
fn=run_asr_rnnt,
cache_examples=CACHE_EXAMPLES,
api_name=False,
)
btn.click(
fn=run_asr_rnnt,
inputs=[input_audio, target_language],
outputs=output_text,
api_name="asr",
)
with gr.Blocks(css="style.css") as demo:
gr.Markdown(DESCRIPTION)
gr.DuplicateButton(
value="Duplicate Space for private use",
elem_id="duplicate-button",
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
)
with gr.Tabs():
with gr.Tab(label="CTC"):
demo_asr_ctc.render()
with gr.Tab(label="RNNT"):
demo_asr_rnnt.render()
if __name__ == "__main__":
demo.queue(max_size=50).launch()
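
# Example client usage, as a hedged sketch: the Space ID below is hypothetical,
# and a recent `gradio_client` (one that provides `handle_file`) is assumed. The
# endpoint name matches the `api_name="asr_ctc"` registered above.
#
#   from gradio_client import Client, handle_file
#   client = Client("AshwinSankar/IndicConformer")  # hypothetical Space ID
#   text = client.predict(
#       handle_file("assets/Bengali.wav"), "Bengali", api_name="/asr_ctc"
#   )
#   print(text)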