import gradio as gr
import soundfile
import torch
from espnet2.bin.asr_inference import Speech2Text
# ESC-50 audio event class names; the model predicts a numeric class index
# that is mapped back to a human-readable label via this list.
audio_class_str = '0."dog", 1."rooster", 2."pig", 3."cow", 4."frog", 5."cat", 6."hen", 7."insects", 8."sheep", 9."crow", 10."rain", 11."sea waves", 12."crackling fire", 13."crickets", 14."chirping birds", 15."water drops", 16."wind", 17."pouring water", 18."toilet flush", 19."thunderstorm", 20."crying baby", 21."sneezing", 22."clapping", 23."breathing", 24."coughing", 25."footsteps", 26."laughing", 27."brushing teeth", 28."snoring", 29."drinking sipping", 30."door wood knock", 31."mouse click", 32."keyboard typing", 33."door wood creaks", 34."can opening", 35."washing machine", 36."vacuum cleaner", 37."clock alarm", 38."clock tick", 39."glass breaking", 40."helicopter", 41."chainsaw", 42."siren", 43."car horn", 44."engine", 45."train", 46."church bells", 47."airplane", 48."fireworks", 49."hand saw".'
audio_class_arr = [k.split('"')[1] for k in audio_class_str.split(", ")]
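
# Every task below uses the same UniverSLU-17 Task Specifier checkpoint and
# prompt-token file; hoist the shared paths so each branch stays readable.
ASR_TRAIN_CONFIG = "UniverSLU-17-Task-Specifier/exp/asr_train_asr_whisper_full_correct_specaug2_copy_raw_en_whisper_multilingual/config.yaml"
ASR_MODEL_FILE = "UniverSLU-17-Task-Specifier/exp/asr_train_asr_whisper_full_correct_specaug2_copy_raw_en_whisper_multilingual/valid.acc.ave_10best.pth"
PROMPT_TOKEN_FILE = "UniverSLU-17-Task-Specifier/add_tokens-Copy1.txt"
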
def inference(wav, data):
    """Run UniverSLU inference on the input audio for the selected task."""
    with torch.no_grad():
        speech, rate = soundfile.read(wav)
        # Downmix multi-channel audio by keeping only the first channel.
        if len(speech.shape) == 2:
            speech = speech[:, 0]
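        # SLURP: intent classification + named-entity recognition (English).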
if data == "english_slurp":
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lang_prompt_token="<|en|> <|ner|> <|SLURP|>",
                prompt_token_file=PROMPT_TOKEN_FILE,
beam_size=20,
ctc_weight=0.0,
penalty=0.1,
nbest=1
)
            nbests = speech2text(speech)
            text, *_ = nbests[0]
            # Drop the prompt tokens, then parse "in:<scenario>_<action>" followed
            # by "SEP"-delimited entity spans of the form "sl:<slot> FILL <value>".
            text = text.split("|>")[-1]
            intent = text.split(" ")[0].replace("in:", "")
            # Split at the first underscore only: action names may themselves
            # contain underscores (e.g. "volume_mute").
            scenario, action = intent.split("_", 1)
            ner_text = text.split(" SEP ")[1:-1]
            text = "INTENT: {scenario: " + scenario + ", action: " + action + "}\n"
            text = text + "NAMED ENTITIES: {"
            for k in ner_text:
                slot_name = k.split(" FILL ")[0].replace("sl:", "")
                slot_val = k.split(" FILL ")[1]
                text = text + " " + slot_name + " : " + slot_val + ","
            text = text + "}"
elif data == "english_fsc":
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lang_prompt_token="<|en|> <|ic|> <|fsc|>",
                prompt_token_file=PROMPT_TOKEN_FILE,
ctc_weight=0.0,
nbest=1
)
            nbests = speech2text(speech)
            text, *_ = nbests[0]
            # Drop the prompt tokens, then parse "in:<action>_<object>_<location>".
            text = text.split("|>")[-1]
            intent = text.split(" ")[0].replace("in:", "")
            action = intent.split("_")[0]
            objects = intent.split("_")[1]
            location = intent.split("_")[2]
            text = "INTENT: {action: " + action + ", object: " + objects + ", location: " + location + "}"
elif data == "english_snips":
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lang_prompt_token="<|en|> <|ic|> <|SNIPS|>",
                prompt_token_file=PROMPT_TOKEN_FILE,
ctc_weight=0.0,
nbest=1
)
            nbests = speech2text(speech)
            text, *_ = nbests[0]
            text = text.split("|>")[-1]
            intent = text.split(" ")[0].replace("in:", "")
            text = "INTENT: " + intent
elif data == "dutch_scr":
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lang_prompt_token="<|nl|> <|scr|> <|grabo_scr|>",
                prompt_token_file=PROMPT_TOKEN_FILE,
ctc_weight=0.0,
beam_size=20,
nbest=1
)
            nbests = speech2text(speech)
            text, *_ = nbests[0]
            text = text.split("|>")[-1]
            intent = text.split(" ")[0]
            text = "SPEECH COMMAND: " + intent
elif data == "english_scr":
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lang_prompt_token="<|en|> <|scr|> <|google_scr|>",
                prompt_token_file=PROMPT_TOKEN_FILE,
ctc_weight=0.0,
beam_size=1,
nbest=1
)
            nbests = speech2text(speech)
            text, *_ = nbests[0]
            text = text.split("|>")[-1]
            intent = text.split(" ")[0].replace("command:", "")
            text = "SPEECH COMMAND: " + intent
elif data == "lithuanian_scr":
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lang_prompt_token="<|lt|> <|scr|> <|lt_scr|>",
                prompt_token_file=PROMPT_TOKEN_FILE,
ctc_weight=0.0,
beam_size=1,
nbest=1
)
            nbests = speech2text(speech)
            text, *_ = nbests[0]
            text = text.split("|>")[-1]
            intent = text
            text = "SPEECH COMMAND: " + intent
elif data == "arabic_scr":
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lang_prompt_token="<|ar|> <|scr|> <|ar_scr|>",
                prompt_token_file=PROMPT_TOKEN_FILE,
ctc_weight=0.0,
beam_size=1,
nbest=1
)
            nbests = speech2text(speech)
            text, *_ = nbests[0]
            text = text.split("|>")[-1]
            intent = text.split(" ")[0].replace("command:", "")
            text = "SPEECH COMMAND: " + intent
elif data == "lid_voxforge":
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lid_prompt=True,
                prompt_token_file=PROMPT_TOKEN_FILE,
ctc_weight=0.0,
beam_size=1,
nbest=1
)
            nbests = speech2text(speech)
            # The language is read off the decoded Whisper language token, e.g. "<|en|>".
            lang = speech2text.converter.tokenizer.tokenizer.convert_ids_to_tokens(
                nbests[0][2][0]
            ).replace("|>", "").replace("<|", "")
            text = "LANG: " + lang
elif data == "fake_speech_detection_asvspoof":
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lang_prompt_token="<|en|> <|fsd|> <|asvspoof|>",
                prompt_token_file=PROMPT_TOKEN_FILE,
ctc_weight=0.0,
beam_size=1,
nbest=1
)
            nbests = speech2text(speech)
            text, *_ = nbests[0]
            text = text.split("|>")[-1]
            intent = text.split(" ")[0].replace("class:", "")
            text = "SPEECH CLASS: " + intent
elif data == "emotion_rec_iemocap":
            # Map model emotion labels to human-readable names.
            replace_dict = {
                "em:neu": "Neutral",
                "em:ang": "Angry",
                "em:sad": "Sad",
                "em:hap": "Happy",
            }
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lang_prompt_token="<|en|> <|er|> <|iemocap|>",
                prompt_token_file=PROMPT_TOKEN_FILE,
ctc_weight=0.0,
beam_size=1,
nbest=1
)
            nbests = speech2text(speech)
            text, *_ = nbests[0]
            text = text.split("|>")[-1]
            intent = replace_dict[text.split(" ")[0]]
            text = "EMOTION: " + intent
elif data == "accent_classify_accentdb":
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lang_prompt_token="<|en|> <|accent_rec|> <|accentdb|>",
                prompt_token_file=PROMPT_TOKEN_FILE,
ctc_weight=0.0,
beam_size=1,
nbest=1
)
            nbests = speech2text(speech)
            text, *_ = nbests[0]
            text = text.split("|>")[-1]
            intent = text.split(" ")[0].replace("accent:", "")
            text = "ACCENT: " + intent
elif data == "sarcasm_mustard":
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lang_prompt_token="<|en|> <|scd|> <|mustard|>",
                prompt_token_file=PROMPT_TOKEN_FILE,
ctc_weight=0.0,
beam_size=1,
nbest=1
)
            nbests = speech2text(speech)
            text, *_ = nbests[0]
            text = text.split("|>")[-1]
            intent = text.split(" ")[0].replace("class:", "")
            text = "SARCASM CLASS: " + intent
elif data == "sarcasm_mustard_plus":
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lang_prompt_token="<|en|> <|scd|> <|mustard_plus_plus|>",
                prompt_token_file=PROMPT_TOKEN_FILE,
ctc_weight=0.0,
beam_size=1,
nbest=1
)
            nbests = speech2text(speech)
            text, *_ = nbests[0]
            text = text.split("|>")[-1]
            intent = text.split(" ")[0].replace("class:", "")
            text = "SARCASM CLASS: " + intent
elif data == "gender_voxceleb1":
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lang_prompt_token="<|en|> <|gid|> <|voxceleb|>",
                prompt_token_file=PROMPT_TOKEN_FILE,
ctc_weight=0.0,
beam_size=1,
nbest=1
)
            nbests = speech2text(speech)
            text, *_ = nbests[0]
            text = text.split("|>")[-1]
            intent = text.split(" ")[0].replace("gender:f", "female").replace("gender:m", "male")
            text = "GENDER: " + intent
elif data == "audio_classification_esc50":
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lang_prompt_token="<|audio|> <|auc|> <|esc50|>",
                prompt_token_file=PROMPT_TOKEN_FILE,
ctc_weight=0.0,
beam_size=1,
nbest=1
)
            nbests = speech2text(speech)
            text, *_ = nbests[0]
            text = text.split("|>")[-1]
            # The model emits a numeric class index; map it back to its label.
            intent = text.split(" ")[0].replace("audio_class:", "")
            text = "AUDIO EVENT CLASS: " + audio_class_arr[int(intent)]
elif data == "semantic_parsing_stop":
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lang_prompt_token="<|en|> <|sp|> <|STOP|>",
                prompt_token_file=PROMPT_TOKEN_FILE,
ctc_weight=0.0,
beam_size=20,
penalty=0.1,
nbest=1
)
            nbests = speech2text(speech)
            text, *_ = nbests[0]
            text = text.split("|>")[-1].replace("_STOP", "")
            text = "SEMANTIC PARSE SEQUENCE: " + text
elif data == "vad_freesound":
speech2text = Speech2Text.from_pretrained(
                asr_train_config=ASR_TRAIN_CONFIG,
                asr_model_file=ASR_MODEL_FILE,
                # Decoding parameters are not included in the model file
                lid_prompt=True,
                prompt_token_file=PROMPT_TOKEN_FILE,
ctc_weight=0.0,
beam_size=1,
nbest=1
)
            nbests = speech2text(speech)
            # The model emits the "<|nospeech|>" token when no speech is detected.
            lang = speech2text.converter.tokenizer.tokenizer.convert_ids_to_tokens(nbests[0][2][0])
            if lang == "<|nospeech|>":
                text = "VAD: no speech"
            else:
                text = "VAD: speech"
# if lang == "chinese":
# wav = text2speechch(text)["wav"]
# scipy.io.wavfile.write("out.wav",text2speechch.fs , wav.view(-1).cpu().numpy())
# if lang == "japanese":
# wav = text2speechjp(text)["wav"]
# scipy.io.wavfile.write("out.wav",text2speechjp.fs , wav.view(-1).cpu().numpy())
return text
title = "UniverSLU"
description = "Gradio demo for UniverSLU Task Specifier (https://huggingface.co/espnet/UniverSLU-17-Task-Specifier). UniverSLU-17 Task Specifier is a Multi-task Spoken Language Understanding model from CMU WAVLab. It adapts Whisper to additional tasks using single-token task specifiers. To use it, simply record your audio or click one of the examples to load them. More details about the SLU tasks that the model is trained on and it's performance on these tasks can be found in our paper: https://aclanthology.org/2024.naacl-long.151/"
article = "<p style='text-align: center'><a href='https://github.com/espnet/espnet' target='_blank'>Github Repo</a></p>"
examples = [
    ['audio_slurp_ner.flac', "english_slurp"],
    ['audio_fsc.wav', "english_fsc"],
    ['audio_grabo.wav', "dutch_scr"],
    ['audio_english_scr.wav', "english_scr"],
    ['audio_lt_scr.wav', "lithuanian_scr"],
    ['audio_ar_scr.wav', "arabic_scr"],
    ['audio_snips.wav', "english_snips"],
    ['audio_lid.wav', "lid_voxforge"],
    ['audio_fsd.wav', "fake_speech_detection_asvspoof"],
    ['audio_er.wav', "emotion_rec_iemocap"],
    ['audio_acc.wav', "accent_classify_accentdb"],
    ['audio_mustard.wav', "sarcasm_mustard"],
    ['audio_mustard_plus.wav', "sarcasm_mustard_plus"],
    ['audio_voxceleb1.wav', "gender_voxceleb1"],
    ['audio_esc50.wav', "audio_classification_esc50"],
    ['audio_stop.wav', "semantic_parsing_stop"],
    ['audio_freesound.wav', "vad_freesound"],
]
gr.Interface(
    inference,
    [
        gr.Audio(label="input audio", sources=["microphone"], type="filepath"),
        gr.Radio(
            choices=[
                "english_slurp", "english_fsc", "dutch_scr", "english_scr",
                "lithuanian_scr", "arabic_scr", "english_snips", "lid_voxforge",
                "fake_speech_detection_asvspoof", "emotion_rec_iemocap",
                "accent_classify_accentdb", "sarcasm_mustard", "sarcasm_mustard_plus",
                "gender_voxceleb1", "audio_classification_esc50",
                "semantic_parsing_stop", "vad_freesound",
            ],
            type="value",
            label="Task",
        ),
    ],
    gr.Textbox(type="text", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch(debug=True)