# Hugging Face Space file: app.py (commit 22fe498, 1.96 kB) — header text below is the original source.
import gradio as gr
from transformers import Wav2Vec2ForCTC, AutoProcessor, Wav2Vec2Processor
import torch
import librosa
import json
import os
import huggingface_hub
from transformers import pipeline
# with open('ISO_codes.json', 'r') as file:
# iso_codes = json.load(file)
# languages = ["lug", "ach", "nyn", "teo"]
# Hugging Face access token, read from the Space's environment (None when unset).
auth_token = os.environ.get("HF_TOKEN")

# Display name -> ISO 639-3 code for every language this demo supports.
target_lang_options = {
    "English": "eng",
    "Luganda": "lug",
    "Acholi": "ach",
    "Runyankole": "nyn",
    "Lugbara": "lgg",
}

# Dropdown choices are the human-readable names, in insertion order.
languages = list(target_lang_options)
# Transcribe audio using custom model
def transcribe_audio(input_file, language, chunk_length_s=10,
                     stride_length_s=(4, 2), return_timestamps="word"):
    """Transcribe an audio file with an MMS-based ASR pipeline.

    Parameters
    ----------
    input_file : str
        Path to the audio file (gr.Audio uses type="filepath", so this is a
        filesystem path string, not a file object).
    language : str
        Human-readable language name; must be a key of `target_lang_options`.
    chunk_length_s : int
        Chunk length in seconds for long-form chunked inference.
    stride_length_s : tuple[int, int]
        Left/right stride (overlap) in seconds between chunks.
    return_timestamps : str
        Timestamp granularity passed to the pipeline ("word").

    Returns
    -------
    dict
        Pipeline output: {"text": ..., "chunks": [...]} when timestamps are
        requested.

    Raises
    ------
    KeyError
        If `language` is not one of the supported display names.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    target_lang_code = target_lang_options[language]
    # English uses the base Meta MMS checkpoint; the Ugandan languages use
    # the Sunbird fine-tuned adapters.
    if target_lang_code == "eng":
        model_id = "facebook/mms-1b-all"
    else:
        model_id = "Sunbird/sunbird-mms"
    # BUG FIX: was `token=hf_auth_token`, an undefined name (NameError at every
    # call). The module-level variable is `auth_token`.
    pipe = pipeline(model=model_id, device=device, token=auth_token)
    pipe.tokenizer.set_target_lang(target_lang_code)
    pipe.model.load_adapter(target_lang_code)
    # BUG FIX: was `input_file.read()`, but type="filepath" delivers a path
    # string (str has no .read() -> AttributeError). The ASR pipeline accepts
    # a filepath directly and handles decoding/resampling itself.
    output = pipe(input_file, chunk_length_s=chunk_length_s,
                  stride_length_s=stride_length_s,
                  return_timestamps=return_timestamps)
    return output
description = '''ASR with salt-mms'''

def _transcribe_either(mic_file, uploaded_file, language):
    """Adapter: forward whichever audio source was provided to transcribe_audio.

    BUG FIX: the Interface wires THREE inputs (mic audio, uploaded audio,
    language) but `transcribe_audio` takes only (input_file, language) — the
    uploaded path used to land in `language` and the language string in
    `chunk_length_s`. Prefer the microphone recording when both are given.
    """
    return transcribe_audio(mic_file or uploaded_file, language)

iface = gr.Interface(
    fn=_transcribe_either,
    inputs=[
        gr.Audio(source="microphone", type="filepath", label="Record Audio"),
        gr.Audio(source="upload", type="filepath", label="Upload Audio"),
        gr.Dropdown(choices=languages, label="Language", value="English"),
    ],
    outputs=gr.Textbox(label="Transcription"),
    description=description,
)
iface.launch()