# luganda-asr / app.py
# cahya's picture
# add KenLM
# 37c396e
# raw
# history blame
# 3.03 kB
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from pyctcdecode import build_ctcdecoder
import gradio as gr
import sox
import os
from multiprocessing import Pool
class KenLM:
    """CTC decoder that combines acoustic logits with a language model.

    The decoder is built once from the tokenizer's vocabulary via
    ``build_ctcdecoder``; each ``decode`` call runs ``decode_batch`` inside a
    freshly created multiprocessing pool.
    """

    def __init__(self, tokenizer, model_name, num_workers=8, beam_width=128):
        """Build the underlying decoder.

        Args:
            tokenizer: Object exposing ``get_vocab()`` that returns a
                token -> id mapping.
            model_name: Forwarded as the second argument of
                ``build_ctcdecoder``.
            num_workers: Number of processes in the pool created by
                ``decode``.
            beam_width: Stored on the instance.
        """
        self.num_workers = num_workers
        # NOTE(review): beam_width is stored but never forwarded to
        # decode_batch, so the decoder's own default applies — confirm
        # whether that is intended before wiring it through.
        self.beam_width = beam_width
        vocab_dict = tokenizer.get_vocab()
        # Order the tokens by ascending id so list positions match ids.
        ordered_tokens = sorted(vocab_dict, key=vocab_dict.get)
        # Workaround for wrong number of vocabularies: drop the two tokens
        # with the highest ids.
        self.vocabulary = ordered_tokens[:-2]
        self.decoder = build_ctcdecoder(self.vocabulary, model_name)

    @staticmethod
    def lm_postprocess(text):
        """Drop single-character tokens from a decoded string.

        Tokens are split on whitespace; the survivors are joined with exactly
        one space, so a removed token never leaves a double space behind.

        Args:
            text: Decoded transcription.

        Returns:
            The transcription without single-character tokens.
        """
        return " ".join(word for word in text.split() if len(word) > 1)

    def decode(self, logits):
        """Decode a batch of logits into post-processed transcriptions.

        Args:
            logits: Object supporting ``.cpu().numpy()`` (the caller in this
                file passes the ``logits`` of a CTC model output).

        Returns:
            A list with one post-processed string per decoded item.
        """
        batch = logits.cpu().numpy()
        with Pool(self.num_workers) as pool:
            decoded = self.decoder.decode_batch(pool, batch)
        return [self.lm_postprocess(item) for item in decoded]
def convert(inputfile, outfile):
    """Re-encode an audio file with sox as mono 16-bit signed 16 kHz WAV.

    Args:
        inputfile: Path of the audio file to read.
        outfile: Path the converted WAV file is written to.
    """
    output_format = {
        "file_type": "wav",
        "channels": 1,
        "encoding": "signed-integer",
        "rate": 16000,
        "bits": 16,
    }
    transformer = sox.Transformer()
    transformer.set_output_format(**output_format)
    transformer.build(inputfile, outfile)
# Module-level setup: everything below runs once at import time and the
# resulting globals are used by parse_transcription.

# Access token read from the API_TOKEN environment variable (None when unset).
api_token = os.getenv("API_TOKEN")
# Checkpoint identifier shared by the processor and the CTC model.
model_name = "indonesian-nlp/wav2vec2-luganda"
processor = Wav2Vec2Processor.from_pretrained(model_name, use_auth_token=api_token)
model = Wav2Vec2ForCTC.from_pretrained(model_name, use_auth_token=api_token)
# Decoder built from the processor's tokenizer vocabulary. "5gram.bin" is
# handed to build_ctcdecoder as its second argument — presumably a KenLM
# binary located in the working directory (TODO confirm).
kenlm = KenLM(processor.tokenizer, "5gram.bin")
def parse_transcription(wav_file):
    """Transcribe a recorded audio file.

    The recording is converted to 16 kHz mono WAV, passed through the
    module-level processor and CTC model, and decoded with the module-level
    ``kenlm`` decoder.

    Args:
        wav_file: Object whose ``name`` attribute is the path of the
            recording.

    Returns:
        The decoded transcription of the first item in the batch.
    """
    # splitext removes only the final extension, so a dot elsewhere in the
    # path (for example in a directory name or a leading "./") no longer
    # truncates the output path.
    base, _ = os.path.splitext(wav_file.name)
    resampled_path = base + "16k.wav"
    convert(wav_file.name, resampled_path)
    try:
        speech, _ = sf.read(resampled_path)
    finally:
        # The converted copy is only needed for this read; delete it so
        # repeated requests do not pile up files on disk.
        os.remove(resampled_path)
    input_values = processor(speech, sampling_rate=16_000, return_tensors="pt").input_values
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(input_values).logits
    return kenlm.decode(logits)[0]
# Gradio UI wiring: a single text box shows the transcript returned by
# parse_transcription.
output = gr.outputs.Textbox(label="The transcript")
# Microphone recording delivered with type="file"; parse_transcription reads
# the recording's path from the received object's `name` attribute.
input_ = gr.inputs.Audio(source="microphone", type="file")
# Build the interface around parse_transcription and launch it immediately
# (this runs at import time; there is no __main__ guard).
gr.Interface(parse_transcription, inputs=input_, outputs=[output],
analytics_enabled=False,
show_tips=False,
theme='huggingface',
layout='vertical',
title="Automatic Speech Recognition for Luganda",
description="Speech Recognition Live Demo for Luganda",
# The article is one HTML string assembled from adjacent string literals.
article="This demo was built for the "
"<a href='https://zindi.africa/competitions/mozilla-luganda-automatic-speech-recognition' target='_blank'>Mozilla Luganda Automatic Speech Recognition Competition</a>. "
"It uses the <a href='https://huggingface.co/indonesian-nlp/wav2vec2-luganda' target='_blank'>indonesian-nlp/wav2vec2-luganda</a> model "
"which was fine-tuned on Luganda Common Voice speech datasets.",
enable_queue=True).launch( inline=False)