import os
import re

import gradio as gr
import librosa
import numpy as np
from transformers import AutoTokenizer, ViTImageProcessor
from unidecode import unidecode

from models import *

# Shared preprocessing components, downloaded/instantiated once at startup.
tok = AutoTokenizer.from_pretrained("readerbench/RoBERT-base")
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')


def preprocess(x):
    """Normalize raw user text before tokenization.

    Steps, in order: transliterate diacritics to ASCII, lowercase, drop
    bracketed tags like "[name]", strip asterisks, replace any
    non-alphanumeric run with a single space, squeeze repeated spaces,
    and collapse runs of a repeated character ("cooool" -> "col").

    Args:
        x: Raw input string.

    Returns:
        The cleaned string.
    """
    s = unidecode(x)
    s = s.lower()
    s = re.sub(r"\[[a-z]+\]", "", s)
    s = re.sub(r"\*", "", s)
    s = re.sub(r"[^a-zA-Z0-9]+", " ", s)
    s = re.sub(r" +", " ", s)
    s = re.sub(r"(.)\1+", r"\1", s)
    return s


# Class names for the text task (RO-Offense) and the audio task (VocalSound).
label_names = ["ABUSE", "INSULT", "OTHER", "PROFANITY"]
audio_label_names = ["Laughter", "Sigh", "Cough", "Throat clearing", "Sneeze", "Sniff"]


def ssl_predict(in_text, model_type):
    """Classify a Romanian text with the selected SSL-trained model.

    Args:
        in_text: Raw input text (preprocessed internally).
        model_type: One of "fixmatch", "freematch", "mixmatch",
            "contrastive_reg", "label_propagation".

    Returns:
        Dict mapping each label in ``label_names`` to its probability.

    Raises:
        ValueError: If ``model_type`` is not a known training method.
    """
    preprocessed = preprocess(in_text)
    toks = tok(
        preprocessed,
        padding="max_length",
        max_length=96,
        truncation=True,
        return_tensors="tf",
    )
    model_inputs = [toks["input_ids"], toks["attention_mask"]]

    # Each SSL technique uses a different wrapper class / checkpoint layout;
    # some models return (logits, extra) tuples, others plain logits.
    if model_type == "fixmatch":
        model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
        model.load_weights("./checkpoints/fixmatch_tune")
        preds, _ = model(model_inputs, training=False)
    elif model_type == "freematch":
        model = FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch")
        model.cls_head.load_weights("./checkpoints/freematch_tune")
        preds, _ = model(model_inputs, training=False)
    elif model_type == "mixmatch":
        model = MixMatch(bert_model="andrei-saceleanu/ro-offense-mixmatch")
        model.cls_head.load_weights("./checkpoints/mixmatch")
        preds = model(model_inputs, training=False)
    elif model_type == "contrastive_reg":
        model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
        model.load_weights("./checkpoints/contrastive")
        preds, _ = model(model_inputs, training=False)
    elif model_type == "label_propagation":
        model = LPModel()
        model.load_weights("./checkpoints/label_prop")
        preds = model(model_inputs, training=False)
    else:
        # Fail fast with a clear message instead of the original behavior of
        # crashing later on `preds[0]` with "'NoneType' is not subscriptable".
        raise ValueError(f"Unknown model_type: {model_type!r}")

    probs = preds[0].numpy()
    return {k: float(v) for k, v in zip(label_names, probs)}


def ssl_predict2(audio_file, model_type):
    """Classify a vocal sound clip with the selected SSL-trained model.

    The audio is resampled to 16 kHz, padded/truncated to exactly 5 seconds,
    converted to a min-max normalized log-mel spectrogram, replicated to
    3 channels, and fed through the ViT image processor.

    Args:
        audio_file: Uploaded file object (``.name`` holds the path).
        model_type: One of "fixmatch", "freematch", "mixmatch".

    Returns:
        Dict mapping each label in ``audio_label_names`` to its probability.

    Raises:
        ValueError: If ``model_type`` is not a known training method.
    """
    signal, sr = librosa.load(audio_file.name, sr=16000)

    # Fix the clip length to 5 seconds at 16 kHz: zero-pad short clips,
    # truncate long ones.
    length = 5 * 16000
    if len(signal) < length:
        signal = np.pad(signal, (0, length - len(signal)), 'constant')
    else:
        signal = signal[:length]

    spectrogram = librosa.feature.melspectrogram(y=signal, sr=sr, n_mels=128)
    spectrogram = librosa.power_to_db(S=spectrogram, ref=np.max)

    # Min-max normalize to [0, 1]; guard against a constant spectrogram
    # (e.g. digital silence), which would otherwise divide by zero.
    spectrogram_min, spectrogram_max = spectrogram.min(), spectrogram.max()
    denom = spectrogram_max - spectrogram_min
    if denom > 0:
        spectrogram = (spectrogram - spectrogram_min) / denom
    else:
        spectrogram = np.zeros_like(spectrogram)
    spectrogram = spectrogram.astype("float32")

    # Replicate the single-channel spectrogram into 3 channels for ViT; the
    # mean/std values match the normalization used at training time.
    inputs = processor.preprocess(
        np.repeat(spectrogram[np.newaxis, :, :, np.newaxis], 3, -1),
        image_mean=(-3.05, -3.05, -3.05),
        image_std=(2.33, 2.33, 2.33),
        return_tensors="tf",
    )

    if model_type == "fixmatch":
        model = AudioFixMatch(encoder_name="andrei-saceleanu/vit-base-fixmatch")
        model.cls_head.load_weights("./checkpoints/audio_fixmatch")
        preds, _ = model(inputs["pixel_values"], training=False)
    elif model_type == "freematch":
        model = AudioFixMatch(encoder_name="andrei-saceleanu/vit-base-freematch")
        model.cls_head.load_weights("./checkpoints/audio_freematch")
        preds, _ = model(inputs["pixel_values"], training=False)
    elif model_type == "mixmatch":
        model = AudioMixMatch(encoder_name="andrei-saceleanu/vit-base-mixmatch")
        model.cls_head.load_weights("./checkpoints/audio_mixmatch")
        preds = model(inputs["pixel_values"], training=False)
    else:
        # Same fail-fast rationale as in ssl_predict.
        raise ValueError(f"Unknown model_type: {model_type!r}")

    probs = preds[0].numpy()
    return {k: float(v) for k, v in zip(audio_label_names, probs)}


# Textbox display modes toggled by the "Safe view" checkbox:
# index 0 = plain text, index 1 = masked ("password").
text_types = ["text", "password"]

with open(file="examples.txt", mode="r", encoding="UTF-8") as fin:
    # rstrip("\n") instead of the original elem[:-1], which silently chopped
    # the last character of the final line when the file had no trailing newline.
    lines = [elem.rstrip("\n") for elem in fin.readlines()]

DATA_DIR = os.path.abspath("./audio_data")
with open(file="audio_examples.txt", mode="r", encoding="UTF-8") as fin:
    lines2 = [os.path.join(DATA_DIR, elem.strip()) for elem in fin.readlines()]
# Two-tab demo UI: text offense classification (RO-Offense) and vocal sound
# classification (VocalSound). Each tab wires a model selector, example
# dataset, clear/submit buttons and a label output to the predict functions.
with gr.Blocks() as ssl_interface:
    with gr.Tab("Text (RO-Offense)"):
        with gr.Row():
            with gr.Column():
                # Starts masked ("password") because example texts are offensive;
                # the "Safe view" checkbox below toggles visibility.
                in_text = gr.Textbox(label="Input text", type="password")
                safe_view = gr.Checkbox(value=True, label="Safe view")
                model_list = gr.Dropdown(
                    choices=["fixmatch", "freematch", "mixmatch", "contrastive_reg", "label_propagation"],
                    max_choices=1,
                    label="Training method",
                    allow_custom_value=False,
                    info="Select trained model according to different SSL techniques from paper",
                )
                with gr.Row():
                    clear_btn = gr.Button(value="Clear")
                    submit_btn = gr.Button(value="Submit")
                # type="index": clicking a row passes its row index to the
                # callback, which looks up the matching example text.
                ds = gr.Dataset(
                    components=[gr.Textbox(visible=False), gr.Textbox(visible=False)],
                    headers=["Id", "Expected class"],
                    samples=[["1", "ABUSE"], ["2", "INSULT"], ["3", "PROFANITY"], ["4", "OTHER"]],
                    type="index",
                )
            with gr.Column():
                out_field = gr.Label(num_top_classes=4, label="Prediction")

        # Checked -> text_types[1] == "password" (masked); unchecked -> plain.
        safe_view.change(
            fn=lambda checked: gr.update(type=text_types[int(checked)]),
            inputs=safe_view,
            outputs=in_text,
        )
        # Example lines are stored as "text##label"; only the text is loaded.
        ds.click(
            fn=lambda idx: gr.update(value=lines[idx].split("##")[0]),
            inputs=ds,
            outputs=in_text,
        )
        submit_btn.click(
            fn=ssl_predict,
            inputs=[in_text, model_list],
            outputs=[out_field],
        )
        clear_btn.click(
            fn=lambda: [None for _ in range(2)],
            inputs=None,
            outputs=[in_text, out_field],
            queue=False,
        )

    with gr.Tab("Audio (VocalSound)"):
        with gr.Row():
            with gr.Column():
                audio_file = gr.File(
                    label="Input audio",
                    file_count="single",
                    file_types=["audio"],
                )
                model_list2 = gr.Dropdown(
                    choices=["fixmatch", "freematch", "mixmatch"],
                    max_choices=1,
                    label="Training method",
                    allow_custom_value=False,
                    info="Select trained model according to different SSL techniques from paper",
                )
                with gr.Row():
                    clear_btn2 = gr.Button(value="Clear")
                    submit_btn2 = gr.Button(value="Submit")
                # "Throat clearing" (was "Throatclearing") to match the class
                # names the model reports in audio_label_names.
                ds2 = gr.Dataset(
                    components=[gr.Textbox(visible=False), gr.Textbox(visible=False)],
                    headers=["Id", "Expected class"],
                    samples=[["1", "Laughter"], ["2", "Cough"], ["3", "Sneeze"], ["4", "Throat clearing"]],
                    type="index",
                )
            with gr.Column():
                out_field2 = gr.Label(num_top_classes=6, label="Prediction")

        submit_btn2.click(
            fn=ssl_predict2,
            inputs=[audio_file, model_list2],
            outputs=[out_field2],
        )
        clear_btn2.click(
            fn=lambda: [None for _ in range(2)],
            inputs=None,
            outputs=[audio_file, out_field2],
            queue=False,
        )
        ds2.click(
            fn=lambda idx: gr.update(value=lines2[idx]),
            inputs=ds2,
            outputs=audio_file,
        )

# Bind on all interfaces so the app is reachable from outside a container.
ssl_interface.launch(server_name="0.0.0.0", server_port=7860)