"""Gradio demo for classifiers trained with semi-supervised (SSL) techniques.

Two tabs are exposed:

* "Text (RO-Offense)": offensive-language classification into
  ``label_names`` using RoBERT-based models trained with FixMatch,
  FreeMatch, MixMatch, contrastive regularisation or label propagation.
* "Audio (VocalSound)": vocal-sound classification into
  ``audio_label_names`` using ViT models fed mel spectrograms, trained with
  FixMatch, FreeMatch or MixMatch.
"""
import functools
import re

import gradio as gr
import librosa
import numpy as np
from transformers import AutoTokenizer, ViTImageProcessor
from unidecode import unidecode

# Project module; expected to provide FixMatchTune, MixMatch, LPModel,
# AudioFixMatch and AudioMixMatch, which are used below.
from models import *  # noqa: F401,F403

tok = AutoTokenizer.from_pretrained("readerbench/RoBERT-base")
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

label_names = ["ABUSE", "INSULT", "OTHER", "PROFANITY"]
audio_label_names = ["Laughter", "Sigh", "Cough", "Throat clearing", "Sneeze", "Sniff"]

TEXT_MODEL_TYPES = ["fixmatch", "freematch", "mixmatch", "contrastive_reg", "label_propagation"]
AUDIO_MODEL_TYPES = ["fixmatch", "freematch", "mixmatch"]

# Audio is resampled to SAMPLE_RATE and zero-padded / truncated to CLIP_SECONDS.
SAMPLE_RATE = 16000
CLIP_SECONDS = 5


def preprocess(x):
    """Normalise raw input text before tokenisation.

    Steps, in order: transliterate to ASCII, lowercase, remove bracketed
    lowercase tags such as ``[user]``, remove asterisks, replace every run of
    non-alphanumeric characters with a single space, collapse repeated
    spaces, and squeeze any repeated character (``"heeey"`` -> ``"hey"``).
    """
    s = unidecode(x).lower()
    s = re.sub(r"\[[a-z]+\]", "", s)
    s = re.sub(r"\*", "", s)
    s = re.sub(r"[^a-zA-Z0-9]+", " ", s)
    s = re.sub(r" +", " ", s)
    # NOTE: this also collapses legitimate double letters ("cool" -> "col");
    # kept as-is so inputs match what the models were presumably trained on.
    s = re.sub(r"(.)\1+", r"\1", s)
    return s


@functools.lru_cache(maxsize=2)
def _load_text_model(model_type):
    """Build the text model for ``model_type`` and restore its weights.

    Returns ``(model, returns_tuple)``; ``returns_tuple`` is True when the
    model's forward pass returns ``(preds, extra)`` rather than bare
    predictions. Cached so repeated requests do not rebuild the encoder and
    re-read checkpoints; ``maxsize`` bounds how many models stay in memory.

    Raises:
        gr.Error: if ``model_type`` is not a known text training method.
    """
    if model_type == "fixmatch":
        model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
        model.load_weights("./checkpoints/fixmatch_tune")
        return model, True
    if model_type == "freematch":
        model = FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch")
        model.cls_head.load_weights("./checkpoints/freematch_tune")
        return model, True
    if model_type == "mixmatch":
        model = MixMatch(bert_model="andrei-saceleanu/ro-offense-mixmatch")
        model.cls_head.load_weights("./checkpoints/mixmatch")
        return model, False
    if model_type == "contrastive_reg":
        model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
        model.cls_head.load_weights("./checkpoints/contrastive")
        return model, True
    if model_type == "label_propagation":
        model = LPModel()
        model.cls_head.load_weights("./checkpoints/label_prop")
        return model, False
    raise gr.Error(f"Unknown text training method: {model_type!r}")


def ssl_predict(in_text, model_type):
    """Classify ``in_text`` with the text model trained via ``model_type``.

    Returns a ``{label: score}`` dict over ``label_names`` for ``gr.Label``
    (scores are whatever the model's first output contains — presumably
    class probabilities).

    Raises:
        gr.Error: if no training method is selected or it is unknown.
    """
    if not model_type:
        raise gr.Error("Please select a training method.")
    model, returns_tuple = _load_text_model(model_type)
    toks = tok(
        preprocess(in_text or ""),
        padding="max_length",
        max_length=96,
        truncation=True,
        return_tensors="tf",
    )
    outputs = model([toks["input_ids"], toks["attention_mask"]], training=False)
    preds = outputs[0] if returns_tuple else outputs
    return {name: float(p) for name, p in zip(label_names, preds[0].numpy())}


@functools.lru_cache(maxsize=2)
def _load_audio_model(model_type):
    """Build the audio model for ``model_type`` and restore its head weights.

    Returns ``(model, returns_tuple)`` with the same meaning as
    ``_load_text_model``; cached for the same reason.

    Raises:
        gr.Error: if ``model_type`` is not a known audio training method.
    """
    if model_type == "fixmatch":
        model = AudioFixMatch(encoder_name="andrei-saceleanu/vit-base-fixmatch")
        model.cls_head.load_weights("./checkpoints/audio_fixmatch")
        return model, True
    if model_type == "freematch":
        model = AudioFixMatch(encoder_name="andrei-saceleanu/vit-base-freematch")
        model.cls_head.load_weights("./checkpoints/audio_freematch")
        return model, True
    if model_type == "mixmatch":
        model = AudioMixMatch(bert_model="andrei-saceleanu/vit-base-mixmatch")
        model.cls_head.load_weights("./checkpoints/audio_mixmatch")
        return model, False
    raise gr.Error(f"Unknown audio training method: {model_type!r}")


def _audio_to_pixel_values(path):
    """Turn the audio file at ``path`` into ViT ``pixel_values`` (TF tensor).

    The clip is resampled to SAMPLE_RATE, zero-padded or truncated to
    CLIP_SECONDS, converted to a 128-band mel spectrogram in dB (relative to
    its max), min-max scaled to [0, 1], replicated to 3 channels and passed
    through the ViT image processor with a hard-coded mean/std (presumably
    the statistics used at training time — TODO confirm).
    """
    signal, sr = librosa.load(path, sr=SAMPLE_RATE)
    length = CLIP_SECONDS * SAMPLE_RATE
    if len(signal) < length:
        signal = np.pad(signal, (0, length - len(signal)), 'constant')
    else:
        signal = signal[:length]

    spectrogram = librosa.feature.melspectrogram(y=signal, sr=sr, n_mels=128)
    spectrogram = librosa.power_to_db(S=spectrogram, ref=np.max)
    spec_min, spec_max = spectrogram.min(), spectrogram.max()
    spec_range = spec_max - spec_min
    if spec_range > 0:
        spectrogram = (spectrogram - spec_min) / spec_range
    else:
        # Constant spectrogram (e.g. a silent clip): min-max scaling would
        # compute 0/0 and feed NaNs to the model.
        spectrogram = np.zeros_like(spectrogram)
    spectrogram = spectrogram.astype("float32")

    inputs = processor.preprocess(
        np.repeat(spectrogram[np.newaxis, :, :, np.newaxis], 3, -1),
        image_mean=(-3.05, -3.05, -3.05),
        image_std=(2.33, 2.33, 2.33),
        return_tensors="tf",
    )
    return inputs["pixel_values"]


def ssl_predict2(audio_file, model_type):
    """Classify an uploaded audio file with the model trained via ``model_type``.

    ``audio_file`` is the value of the ``gr.File`` input: depending on the
    Gradio version it is a tempfile wrapper exposing ``.name`` or a plain
    path string; both are accepted. Returns a ``{label: score}`` dict over
    ``audio_label_names`` for ``gr.Label``.

    Raises:
        gr.Error: if no file is uploaded or no/an unknown method is selected.
    """
    if audio_file is None:
        raise gr.Error("Please upload an audio file.")
    if not model_type:
        raise gr.Error("Please select a training method.")
    path = getattr(audio_file, "name", audio_file)
    pixel_values = _audio_to_pixel_values(path)
    model, returns_tuple = _load_audio_model(model_type)
    outputs = model(pixel_values, training=False)
    preds = outputs[0] if returns_tuple else outputs
    return {name: float(p) for name, p in zip(audio_label_names, preds[0].numpy())}


with gr.Blocks() as ssl_interface:
    with gr.Tab("Text (RO-Offense)"):
        with gr.Row():
            with gr.Column():
                in_text = gr.Textbox(label="Input text")
                model_list = gr.Dropdown(
                    choices=TEXT_MODEL_TYPES,
                    max_choices=1,
                    label="Training method",
                    allow_custom_value=False,
                    info="Select trained model according to different SSL techniques from paper",
                )
                with gr.Row():
                    clear_btn = gr.Button(value="Clear")
                    submit_btn = gr.Button(value="Submit")
            with gr.Column():
                out_field = gr.Label(num_top_classes=4, label="Prediction")
        submit_btn.click(
            fn=ssl_predict,
            inputs=[in_text, model_list],
            outputs=[out_field],
        )
        clear_btn.click(
            fn=lambda: [None, None],
            inputs=None,
            outputs=[in_text, out_field],
        )
    with gr.Tab("Audio (VocalSound)"):
        with gr.Row():
            with gr.Column():
                audio_file = gr.File(
                    label="Input audio",
                    file_count="single",
                    file_types=["audio"],
                )
                model_list2 = gr.Dropdown(
                    choices=AUDIO_MODEL_TYPES,
                    max_choices=1,
                    label="Training method",
                    allow_custom_value=False,
                    info="Select trained model according to different SSL techniques from paper",
                )
                with gr.Row():
                    clear_btn2 = gr.Button(value="Clear")
                    submit_btn2 = gr.Button(value="Submit")
            with gr.Column():
                out_field2 = gr.Label(num_top_classes=6, label="Prediction")
        submit_btn2.click(
            fn=ssl_predict2,
            inputs=[audio_file, model_list2],
            outputs=[out_field2],
        )
        clear_btn2.click(
            fn=lambda: [None, None],
            inputs=None,
            outputs=[audio_file, out_field2],
        )

ssl_interface.launch(server_name="0.0.0.0", server_port=7860)