# Hugging Face Space app.py (Space status: Running; file size: 4,870 bytes).
import gradio as gr
import os
import torch
import librosa
from glob import glob
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline, AutoModelForTokenClassification, TokenClassificationPipeline, Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
# Target sampling rate (Hz) for audio handed to the ASR models.
SAMPLE_RATE = 16_000

# Cache of loaded ASR components, filled on demand by transcribe():
# lang_code -> {"processor": ..., "model": ...}
models = {}

# Language code shown in the dropdown -> Hugging Face Hub id of the matching
# wav2vec2 XLSR-53 CTC checkpoint. Insertion order matters: the dropdown lists
# the codes in this order.
_XLSR53_HUB_PREFIX = "jonatasgrosman/wav2vec2-large-xlsr-53-"
models_paths = {
    lang_code: _XLSR53_HUB_PREFIX + checkpoint_suffix
    for lang_code, checkpoint_suffix in (
        ("en-US", "english"),
        ("fr-FR", "french"),
        ("nl-NL", "dutch"),
        ("pl-PL", "polish"),
        ("it-IT", "italian"),
        ("ru-RU", "russian"),
        ("pt-PT", "portuguese"),
        ("de-DE", "german"),
        ("es-ES", "spanish"),
        ("ja-JP", "japanese"),
        ("ar-SA", "arabic"),
        ("fi-FI", "finnish"),
        ("hu-HU", "hungarian"),
        ("zh-CN", "chinese-zh-cn"),
        ("el-GR", "greek"),
    )
}
def _build_pipeline(hub_id, model_cls, pipeline_cls):
    """Load the tokenizer and model published under *hub_id* and wrap them.

    The tokenizer is loaded first, then the model via ``model_cls``, and both
    are handed to ``pipeline_cls``. Returns a ``(tokenizer, model, pipeline)``
    triple so callers can keep a module-level handle on each piece.
    """
    tok = AutoTokenizer.from_pretrained(hub_id)
    mdl = model_cls.from_pretrained(hub_id)
    return tok, mdl, pipeline_cls(model=mdl, tokenizer=tok)


# Classifier Intent: predict() reads the top label as the intent class.
model_name = 'qanastek/XLMRoberta-Alexa-Intents-Classification'
tokenizer_intent, model_intent, classifier_intent = _build_pipeline(
    model_name, AutoModelForSequenceClassification, TextClassificationPipeline
)

# Classifier Language: predict() reads the top label as the detected language.
model_name = 'qanastek/51-languages-classifier'
tokenizer_langs, model_langs, classifier_language = _build_pipeline(
    model_name, AutoModelForSequenceClassification, TextClassificationPipeline
)

# NER Extractor: token-level slot tags, grouped afterwards by getUniform().
model_name = 'qanastek/XLMRoberta-Alexa-Intents-NER-NLU'
tokenizer_ner, model_ner, predict_ner = _build_pipeline(
    model_name, AutoModelForTokenClassification, TokenClassificationPipeline
)
EXAMPLE_DIR = './wavs/'


def lang_code_from_filename(path):
    """Return the language code encoded in an example audio file name.

    Example files are expected to be named ``<lang_code>=<anything>.wav``
    (e.g. ``en-US=order.wav`` -> ``"en-US"``). A name without ``=`` is
    returned unchanged. Only the base name is inspected, so the OS path
    separator and any ``=`` inside the directory part cannot corrupt the
    result.
    """
    return os.path.basename(path).split("=")[0]


# Each Gradio example row is [audio path, language code], matching the
# (wav_file, lang_code) inputs of predict().
examples = [
    [path, lang_code_from_filename(path)]
    for path in sorted(glob(os.path.join(EXAMPLE_DIR, '*.wav')))
]
def transcribe(audio_path, lang_code):
    """Transcribe one audio file with the wav2vec2 CTC checkpoint for *lang_code*.

    The audio is resampled to ``SAMPLE_RATE``. The processor/model pair for a
    language is downloaded on its first use and cached in the module-level
    ``models`` dict. Decoding is greedy (argmax over the logits).

    Args:
        audio_path: Path to an audio file readable by ``librosa.load``.
        lang_code: Key of ``models_paths`` selecting the checkpoint.

    Returns:
        The decoded transcription (str).

    Raises:
        KeyError: If ``lang_code`` is not a key of ``models_paths``.
    """
    speech_array, _ = librosa.load(audio_path, sr=SAMPLE_RATE)

    if lang_code not in models:
        # Build the cache entry completely before publishing it: if either
        # download fails, no half-filled entry is left behind that would make
        # every later call for this language raise KeyError.
        model_path = models_paths[lang_code]
        models[lang_code] = {
            "processor": Wav2Vec2Processor.from_pretrained(model_path),
            "model": Wav2Vec2ForCTC.from_pretrained(model_path),
        }

    processor_asr = models[lang_code]["processor"]
    model_asr = models[lang_code]["model"]

    inputs = processor_asr(speech_array, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
    with torch.no_grad():
        # Not every wav2vec2 feature extractor returns an attention mask
        # (it depends on its return_attention_mask setting); .get() passes
        # None in that case instead of raising AttributeError.
        logits = model_asr(inputs.input_values, attention_mask=inputs.get("attention_mask")).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor_asr.batch_decode(predicted_ids)[0]
def getUniform(text):
    """Merge token-level NER predictions into ``(entity_type, pieces)`` spans.

    Args:
        text: Sequence of token predictions (as produced by the NER pipeline),
            each a dict with at least ``"entity"`` -- a tag expected to carry a
            ``B-`` / ``I-`` prefix such as ``"B-date"`` -- and ``"word"``.

    Returns:
        A list of ``(label, [word pieces])`` tuples in order of appearance.
        A ``B-`` tag always opens a new span. An ``I-`` tag extends the
        previous span when it has the same label; otherwise it opens a new
        span rather than failing. The SentencePiece word-boundary marker
        (U+2581) is stripped from every piece.
    """
    spans = []
    for token in text:
        tag = token["entity"]
        if tag.startswith(("B-", "I-")):
            prefix, label = tag[:2], tag[2:]
        else:
            prefix, label = "", tag
        # U+2581 is SentencePiece's word-boundary marker; written as an escape
        # so a mis-encoded file cannot turn it into an unrelated letter.
        word = token["word"].replace("\u2581", "")
        if prefix == "I-" and spans and spans[-1][0] == label:
            spans[-1][1].append(word)
        else:
            spans.append((label, [word]))
    return spans
def predict(wav_file, lang_code):
    """Run the full ASR + NLU chain on one recording.

    Steps: transcription (``transcribe``), intent classification, language
    identification, then slot extraction grouped by ``getUniform``.

    Args:
        wav_file: Audio file path from the Gradio Audio input
            (``type='filepath'``); ``None`` when nothing was recorded.
        lang_code: Key of ``models_paths`` selecting the ASR checkpoint.

    Returns:
        On success, a dict with keys ``text``, ``language``,
        ``intent_class`` and ``named_entities``. On invalid input, a dict with
        a single ``error`` message. A dict is required in both cases so the
        JSON output component can serialize it.
    """
    if lang_code not in models_paths:
        return {"error": "The language code is unknown!"}
    if wav_file is None:
        return {"error": "No audio was provided!"}

    # NOTE(review): "apizza" -> "a pizza" looks like an ad-hoc fix for a known
    # ASR merge; the trailing " ." presumably gives the text classifiers a
    # sentence terminator -- confirm both are still wanted.
    text = transcribe(wav_file, lang_code).replace("apizza", "a pizza") + " ."

    # Each text-classification pipeline call returns a list of predictions;
    # only the top label is kept.
    intent_class = classifier_intent(text)[0]["label"]
    language_class = classifier_language(text)[0]["label"]
    named_entities = getUniform(predict_ner(text))

    return {
        "text": text,
        "language": language_class,
        "intent_class": intent_class,
        "named_entities": named_entities,
    }
# Gradio UI: record audio + pick a language -> JSON with transcription,
# intent, detected language and grouped slots (see predict()).
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
# namespaces and were removed in Gradio 4 -- migrate to gr.Audio /
# gr.Dropdown / gr.JSON when upgrading.
iface = gr.Interface(
    predict,
    # NOTE(review): the emoji in this title look like UTF-8 bytes decoded as
    # ISO-8859-7; the intended glyphs cannot be recovered unambiguously from
    # this text -- confirm and restore them.
    title='Alexa Clone π©βπΌ πͺ π€ Multilingual NLU',
    description='Upload your wav file to test the models (<i>First execution take about 20s to 30s, then next run in less than 1s</i>)',
    inputs=[
        # The recording is passed to predict() as a file path.
        gr.inputs.Audio(label='wav file', source='microphone', type='filepath'),
        # Language choices come straight from the ASR checkpoint table.
        gr.inputs.Dropdown(choices=list(models_paths.keys())),
    ],
    outputs=[
        gr.outputs.JSON(label='ASR -> Slot Recognition + Intent Classification + Language Classification'),
    ],
    examples=examples,
    # Emoji written as escapes (red heart, hugging face) so the file's
    # encoding cannot garble them.
    article='Made with \u2764\ufe0f by <a href="https://www.linkedin.com/in/yanis-labrak-8a7412145/" target="_blank">Yanis Labrak</a> thanks to \U0001F917',
)
iface.launch()