# Alexa-NLU-Clone / app.py
# qanastek's picture
# Update
# 51be472
# raw
# history blame
# 4.87 kB
import gradio as gr
import os
import torch
import librosa
from glob import glob
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline, AutoModelForTokenClassification, TokenClassificationPipeline, Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
# Sample-rate constant; its value matches the 16 kHz rate used when audio is
# loaded and fed to the ASR processor in this file.
SAMPLE_RATE = 16_000

# Cache of already-loaded ASR objects, keyed by language code. Starts empty
# and is filled on demand by transcribe().
models = {}

# Language code -> identifier of the wav2vec2 checkpoint used for speech
# recognition in that language. The keys also populate the language dropdown.
models_paths = {
    # Germanic
    "en-US": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
    # Romance
    "fr-FR": "jonatasgrosman/wav2vec2-large-xlsr-53-french",
    # Germanic
    "nl-NL": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
    # Slavic
    "pl-PL": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
    # Romance
    "it-IT": "jonatasgrosman/wav2vec2-large-xlsr-53-italian",
    # Slavic
    "ru-RU": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
    # Romance
    "pt-PT": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
    # Germanic
    "de-DE": "jonatasgrosman/wav2vec2-large-xlsr-53-german",
    # Romance
    "es-ES": "jonatasgrosman/wav2vec2-large-xlsr-53-spanish",
    # Others
    "ja-JP": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    "ar-SA": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
    "fi-FI": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
    "hu-HU": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
    "zh-CN": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    "el-GR": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
}
# The three text models below are loaded once, at import time, and shared by
# every call to predict(). `model_name` is reused for each of them, so after
# this section it holds the identifier of the NER model.

# Intent classifier: tokenizer + sequence-classification model wrapped in a
# text-classification pipeline (called with the transcribed text).
model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
tokenizer_intent = AutoTokenizer.from_pretrained(model_name)
model_intent = AutoModelForSequenceClassification.from_pretrained(model_name)
classifier_intent = TextClassificationPipeline(model=model_intent, tokenizer=tokenizer_intent)
# Language classifier: same pipeline type, loaded from a different checkpoint.
model_name = 'qanastek/51-languages-classifier'
tokenizer_langs = AutoTokenizer.from_pretrained(model_name)
model_langs = AutoModelForSequenceClassification.from_pretrained(model_name)
classifier_language = TextClassificationPipeline(model=model_langs, tokenizer=tokenizer_langs)
# Named-entity (slot) extractor: token-classification model and pipeline. Its
# per-token output is post-processed by getUniform().
model_name = 'qanastek/XLMRoberta-Alexa-Intents-NER-NLU'
tokenizer_ner = AutoTokenizer.from_pretrained(model_name)
model_ner = AutoModelForTokenClassification.from_pretrained(model_name)
predict_ner = TokenClassificationPipeline(model=model_ner, tokenizer=tokenizer_ner)
# Directory scanned for bundled example recordings.
EXAMPLE_DIR = './wavs/'

# One [wav path, language code] row per example file, in sorted path order.
# The language code is the part of the file name before the first "=",
# i.e. files are expected to be named like "<lang-code>=<anything>.wav".
examples = [
    [wav, wav.partition("=")[0].rsplit("/", 1)[-1]]
    for wav in sorted(glob(os.path.join(EXAMPLE_DIR, '*.wav')))
]
def transcribe(audio_path, lang_code):
    """Transcribe an audio file with the ASR model registered for ``lang_code``.

    The processor/model pair for a language is loaded the first time that
    language is requested and kept in the module-level ``models`` cache, so
    later calls for the same language reuse it.

    Args:
        audio_path: Path of the audio file; passed to ``librosa.load``.
        lang_code: Key of ``models_paths`` selecting the ASR checkpoint.

    Returns:
        The decoded transcription (first and only item of the batch).

    Raises:
        KeyError: If ``lang_code`` has no entry in ``models_paths``.
    """
    # Resample while loading so the waveform matches the rate announced to
    # the processor below.
    speech_array, _ = librosa.load(audio_path, sr=SAMPLE_RATE)

    if lang_code not in models:
        checkpoint = models_paths[lang_code]
        # Build the complete entry before storing it: if either load raises,
        # nothing is cached and the next call retries from scratch instead of
        # finding a half-filled entry and failing on a missing key.
        models[lang_code] = {
            "processor": Wav2Vec2Processor.from_pretrained(checkpoint),
            "model": Wav2Vec2ForCTC.from_pretrained(checkpoint),
        }

    processor_asr = models[lang_code]["processor"]
    model_asr = models[lang_code]["model"]

    inputs = processor_asr(
        speech_array,
        sampling_rate=SAMPLE_RATE,
        return_tensors="pt",
        padding=True,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model_asr(inputs.input_values, attention_mask=inputs.attention_mask).logits

    # Greedy decoding: keep the highest-scoring token id at every frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor_asr.batch_decode(predicted_ids)[0]
def getUniform(text):
    """Group per-token NER predictions into entities.

    Each item of ``text`` is a mapping with at least an ``"entity"`` tag
    (``B-<label>`` or ``I-<label>``) and a ``"word"`` string. A ``B-`` tag
    always opens a new group. Any other tag is appended to the most recently
    opened group when the labels match; if there is no such group (the
    sequence starts with an ``I-`` tag, or the label differs) a new group is
    opened instead of raising.

    Args:
        text: Iterable of token predictions, in sentence order.

    Returns:
        A list of ``(label, [word, ...])`` tuples in order of appearance.
    """
    groups = []

    for token in text:
        tag = token["entity"]
        # Drop the "▁" marker from the token text (presumably the
        # SentencePiece word-boundary prefix emitted by the tokenizer).
        word = token["word"].replace("▁", "")

        # Strip the BIO prefix to get the bare entity label.
        label = tag[2:] if tag.startswith(("B-", "I-")) else tag

        continues_previous = (
            not tag.startswith("B-")
            and groups
            and groups[-1][0] == label
        )
        if continues_previous:
            groups[-1][1].append(word)
        else:
            groups.append((label, [word]))

    return groups
def predict(wav_file, lang_code):
    """Transcribe a recording and run intent, language and slot prediction on it.

    Args:
        wav_file: Path of the audio file to analyse, or ``None`` when no
            audio was supplied.
        lang_code: Language code selecting the ASR model; must be a key of
            ``models_paths``.

    Returns:
        A JSON-serialisable dict. On success it contains ``text``,
        ``language``, ``intent_class`` and ``named_entities``. On invalid
        input it contains a single ``error`` key with a message.
    """
    if lang_code not in models_paths:
        # Must be a dict, not a set: the result is rendered by a JSON output
        # component and sets cannot be serialised to JSON.
        return {"error": "The language code is unknown!"}

    if wav_file is None:
        # Nothing was recorded/uploaded; report it instead of failing while
        # trying to load a missing file.
        return {"error": "No audio was provided!"}

    # The "apizza" replacement presumably patches a recurring ASR artefact,
    # and the trailing " ." presumably gives the text models a sentence-final
    # token -- NOTE(review): confirm both against the models' training data.
    text = transcribe(wav_file, lang_code).replace("apizza", "a pizza") + " ."

    # Each classification pipeline returns a list of predictions; keep the
    # label of the first one.
    intent_class = classifier_intent(text)[0]["label"]
    language_class = classifier_language(text)[0]["label"]

    named_entities = getUniform(predict_ner(text))

    return {
        "text": text,
        "language": language_class,
        "intent_class": intent_class,
        "named_entities": named_entities,
    }
# Web UI: an audio input (microphone recording, handed to predict() as a file
# path) plus a language dropdown, with the prediction shown as JSON. Each row
# of `examples` supplies one value per input, in the same order.
iface = gr.Interface(
    predict,
    title='Alexa Clone 👩‍💼 🗪 🤖 Multilingual NLU',
    description='Upload your wav file to test the models (<i>First execution take about 20s to 30s, then next run in less than 1s</i>)',
    # thumbnail="",
    inputs=[
        gr.inputs.Audio(label='wav file', source='microphone', type='filepath'),
        # Only languages that have an ASR checkpoint can be selected.
        gr.inputs.Dropdown(choices=list(models_paths.keys())),
    ],
    outputs=[
        gr.outputs.JSON(label='ASR -> Slot Recognition + Intent Classification + Language Classification'),
    ],
    examples=examples,
    article='Made with ❤️ by <a href="https://www.linkedin.com/in/yanis-labrak-8a7412145/" target="_blank">Yanis Labrak</a> thanks to 🤗',
)

# Start the Gradio server as soon as the module is executed.
iface.launch()