Spaces:
Paused
Paused
import functools
import os
import re

import gradio as gr
import librosa
import numpy as np
from transformers import AutoTokenizer, ViTImageProcessor
from unidecode import unidecode

from models import *
# Shared preprocessing components, built once when the module is imported and
# reused by every request: the tokenizer feeds the text classifier
# (ssl_predict) and the image processor prepares audio spectrograms for the
# ViT-based audio classifier (ssl_predict2).
_TEXT_TOKENIZER_ID = "readerbench/RoBERT-base"
_AUDIO_PROCESSOR_ID = "google/vit-base-patch16-224"

tok = AutoTokenizer.from_pretrained(_TEXT_TOKENIZER_ID)
processor = ViTImageProcessor.from_pretrained(_AUDIO_PROCESSOR_ID)
def preprocess(x):
    """Normalize raw input text before tokenization.

    The steps are applied in this order:

    1. transliterate to ASCII with ``unidecode``, then lowercase;
    2. delete bracketed lowercase tags such as ``[user]``;
    3. delete asterisks;
    4. replace every run of non-alphanumeric characters with one space;
    5. squeeze runs of spaces (kept for parity with the original pipeline);
    6. collapse any repeated character to a single one (``"heeey"`` -> ``"hey"``).

    Leading/trailing spaces are not stripped.
    """
    text = unidecode(x).lower()
    substitutions = (
        (r"\[[a-z]+\]", ""),
        (r"\*", ""),
        (r"[^a-zA-Z0-9]+", " "),
        (r" +", " "),
        (r"(.)\1+", r"\1"),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text
# Class names, listed in the order of each classifier's output scores: the
# predict functions zip these lists with a model output row, so entry i names
# output column i.
# Text tab (RO-Offense):
label_names = [
    "ABUSE",
    "INSULT",
    "OTHER",
    "PROFANITY",
]
# Audio tab (VocalSound):
audio_label_names = [
    "Laughter",
    "Sigh",
    "Cough",
    "Throat clearing",
    "Sneeze",
    "Sniff",
]
# Training methods offered for the text (RO-Offense) classifier.
_TEXT_MODEL_TYPES = (
    "fixmatch",
    "freematch",
    "mixmatch",
    "contrastive_reg",
    "label_propagation",
)
# Methods whose model call returns a pair (scores, second output); the second
# output is discarded. The remaining methods return the scores directly.
_TEXT_TUPLE_OUTPUT = frozenset({"fixmatch", "freematch", "contrastive_reg"})


@functools.lru_cache(maxsize=1)
def _load_text_model(model_type):
    """Build the text classifier for ``model_type`` and load its checkpoint.

    The most recently used model is cached, so repeated requests with the same
    training method do not rebuild the network and re-read its weights from
    ``./checkpoints`` on every click. ``maxsize=1`` keeps at most one text
    model resident in memory.

    Raises:
        ValueError: if ``model_type`` is not one of ``_TEXT_MODEL_TYPES``.
    """
    if model_type == "fixmatch":
        model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
        model.load_weights("./checkpoints/fixmatch_tune")
    elif model_type == "freematch":
        model = FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch")
        model.cls_head.load_weights("./checkpoints/freematch_tune")
    elif model_type == "mixmatch":
        model = MixMatch(bert_model="andrei-saceleanu/ro-offense-mixmatch")
        model.cls_head.load_weights("./checkpoints/mixmatch")
    elif model_type == "contrastive_reg":
        model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
        model.load_weights("./checkpoints/contrastive")
    elif model_type == "label_propagation":
        model = LPModel()
        model.load_weights("./checkpoints/label_prop")
    else:
        raise ValueError(f"Unknown text training method: {model_type!r}")
    return model


def ssl_predict(in_text, model_type):
    """Classify ``in_text`` with the text model trained by ``model_type``.

    Args:
        in_text: raw text from the input box. It is normalized with
            ``preprocess`` and tokenized to exactly 96 tokens (padded or
            truncated). ``None`` (e.g. after the Clear button) is treated
            as an empty string.
        model_type: selected training method, one of ``_TEXT_MODEL_TYPES``.

    Returns:
        dict mapping each name in ``label_names`` to its score as a float,
        the format ``gr.Label`` displays.

    Raises:
        gr.Error: if no training method (or an unknown one) is selected, so the
            UI shows a readable message instead of a ``TypeError`` traceback.
    """
    if model_type not in _TEXT_MODEL_TYPES:
        raise gr.Error(
            "Please select a training method: " + ", ".join(_TEXT_MODEL_TYPES)
        )

    toks = tok(
        preprocess(in_text or ""),
        padding="max_length",
        max_length=96,
        truncation=True,
        return_tensors="tf",
    )
    model = _load_text_model(model_type)
    outputs = model([toks["input_ids"], toks["attention_mask"]], training=False)
    if model_type in _TEXT_TUPLE_OUTPUT:
        preds, _ = outputs
    else:
        preds = outputs

    # preds[0]: the single example in the batch; its columns follow the order
    # of label_names.
    return {name: float(score) for name, score in zip(label_names, preds[0].numpy())}
# Training methods offered for the audio (VocalSound) classifier.
_AUDIO_MODEL_TYPES = ("fixmatch", "freematch", "mixmatch")
# Methods whose model call returns a pair (scores, second output); the second
# output is discarded. "mixmatch" returns the scores directly.
_AUDIO_TUPLE_OUTPUT = frozenset({"fixmatch", "freematch"})


@functools.lru_cache(maxsize=1)
def _load_audio_model(model_type):
    """Build the audio classifier for ``model_type`` and load its head weights.

    The most recently used model is cached, so repeated requests do not
    rebuild the network and re-read weights from ``./checkpoints`` each time.
    ``maxsize=1`` keeps at most one audio model resident in memory.

    Raises:
        ValueError: if ``model_type`` is not one of ``_AUDIO_MODEL_TYPES``.
    """
    if model_type == "fixmatch":
        model = AudioFixMatch(encoder_name="andrei-saceleanu/vit-base-fixmatch")
        model.cls_head.load_weights("./checkpoints/audio_fixmatch")
    elif model_type == "freematch":
        model = AudioFixMatch(encoder_name="andrei-saceleanu/vit-base-freematch")
        model.cls_head.load_weights("./checkpoints/audio_freematch")
    elif model_type == "mixmatch":
        model = AudioMixMatch(encoder_name="andrei-saceleanu/vit-base-mixmatch")
        model.cls_head.load_weights("./checkpoints/audio_mixmatch")
    else:
        raise ValueError(f"Unknown audio training method: {model_type!r}")
    return model


def _audio_to_pixel_values(path, sr=16000, seconds=5):
    """Turn an audio file into ``pixel_values`` for the ViT-based classifier.

    The clip is loaded at ``sr`` Hz and zero-padded or truncated to exactly
    ``seconds`` seconds. It is then converted to a 128-band mel spectrogram in
    dB (relative to its maximum), min-max scaled to [0, 1], and replicated to
    three channels. Finally it goes through the shared ``processor`` with the
    normalization constants the models were used with.
    """
    signal, sr = librosa.load(path, sr=sr)
    length = seconds * sr
    signal = signal[:length]
    if len(signal) < length:
        signal = np.pad(signal, (0, length - len(signal)), "constant")

    spectrogram = librosa.feature.melspectrogram(y=signal, sr=sr, n_mels=128)
    spectrogram = librosa.power_to_db(S=spectrogram, ref=np.max)
    lo, hi = spectrogram.min(), spectrogram.max()
    if hi > lo:
        spectrogram = (spectrogram - lo) / (hi - lo)
    else:
        # A constant spectrogram (e.g. an all-silent clip) would make min-max
        # scaling compute 0/0 and feed NaNs to the model.
        spectrogram = np.zeros_like(spectrogram)
    spectrogram = spectrogram.astype("float32")

    # Add a leading batch axis and a trailing channel axis, then repeat the
    # channel 3 times so the spectrogram is one 3-channel image.
    image = np.repeat(spectrogram[np.newaxis, :, :, np.newaxis], 3, -1)
    inputs = processor.preprocess(
        image,
        image_mean=(-3.05, -3.05, -3.05),
        image_std=(2.33, 2.33, 2.33),
        return_tensors="tf",
    )
    return inputs["pixel_values"]


def ssl_predict2(audio_file, model_type):
    """Classify a vocal sound with the audio model trained by ``model_type``.

    Args:
        audio_file: value of the ``gr.File`` input. Objects exposing a
            ``.name`` path attribute and plain path strings are both accepted.
        model_type: selected training method, one of ``_AUDIO_MODEL_TYPES``.

    Returns:
        dict mapping each name in ``audio_label_names`` to its score as a
        float, the format ``gr.Label`` displays.

    Raises:
        gr.Error: if no file was provided, or no (or an unknown) training
            method is selected, so the UI shows a readable message.
    """
    if audio_file is None:
        raise gr.Error("Please upload an audio file.")
    if model_type not in _AUDIO_MODEL_TYPES:
        raise gr.Error(
            "Please select a training method: " + ", ".join(_AUDIO_MODEL_TYPES)
        )

    path = getattr(audio_file, "name", audio_file)
    pixel_values = _audio_to_pixel_values(path)
    model = _load_audio_model(model_type)
    outputs = model(pixel_values, training=False)
    if model_type in _AUDIO_TUPLE_OUTPUT:
        preds, _ = outputs
    else:
        preds = outputs

    # preds[0]: the single example in the batch; its columns follow the order
    # of audio_label_names.
    return {
        name: float(score)
        for name, score in zip(audio_label_names, preds[0].numpy())
    }
# Textbox "type" indexed by int(safe_view checked): unchecked -> plain text,
# checked -> masked like a password.
text_types = ["text", "password"]

# Text examples, one per line; the examples table loads the part of a line
# before "##" into the input box.
with open(file="examples.txt", mode="r", encoding="UTF-8") as fin:
    # rstrip rather than [:-1]: if the last line has no trailing newline,
    # slicing would cut off its final character.
    lines = [elem.rstrip("\n") for elem in fin]

DATA_DIR = os.path.abspath("./audio_data")
# Audio examples: one file name per line, resolved against DATA_DIR.
with open(file="audio_examples.txt", mode="r", encoding="UTF-8") as fin:
    lines2 = [os.path.join(DATA_DIR, elem.strip()) for elem in fin]
with gr.Blocks() as ssl_interface:
    # ---- Text tab: offensive-language classification (RO-Offense) ----
    with gr.Tab("Text (RO-Offense)"):
        with gr.Row():
            with gr.Column():
                # Starts masked because "Safe view" starts checked.
                in_text = gr.Textbox(label="Input text", type="password")
                safe_view = gr.Checkbox(value=True, label="Safe view")
                # A default selection keeps ssl_predict from receiving None.
                model_list = gr.Dropdown(
                    choices=["fixmatch", "freematch", "mixmatch", "contrastive_reg", "label_propagation"],
                    value="fixmatch",
                    label="Training method",
                    allow_custom_value=False,
                    info="Select trained model according to different SSL techniques from paper",
                )
                with gr.Row():
                    clear_btn = gr.Button(value="Clear")
                    submit_btn = gr.Button(value="Submit")
                # Example table; type="index" makes a click deliver the row index.
                ds = gr.Dataset(
                    components=[gr.Textbox(visible=False), gr.Textbox(visible=False)],
                    headers=["Id", "Expected class"],
                    samples=[["1", "ABUSE"], ["2", "INSULT"], ["3", "PROFANITY"], ["4", "OTHER"]],
                    type="index",
                )
            with gr.Column():
                out_field = gr.Label(num_top_classes=4, label="Prediction")

        # Toggle masking of the input: checked -> "password", unchecked -> "text".
        safe_view.change(
            fn=lambda checked: gr.update(type=text_types[int(checked)]),
            inputs=safe_view,
            outputs=in_text,
        )
        # Load the clicked example: the text before "##" on that line.
        ds.click(
            fn=lambda idx: gr.update(value=lines[idx].split("##")[0]),
            inputs=ds,
            outputs=in_text,
        )
        submit_btn.click(
            fn=ssl_predict,
            inputs=[in_text, model_list],
            outputs=[out_field],
        )
        clear_btn.click(
            fn=lambda: [None, None],
            inputs=None,
            outputs=[in_text, out_field],
            queue=False,
        )

    # ---- Audio tab: vocal-sound classification (VocalSound) ----
    with gr.Tab("Audio (VocalSound)"):
        with gr.Row():
            with gr.Column():
                audio_file = gr.File(
                    label="Input audio",
                    file_count="single",
                    file_types=["audio"],
                )
                # A default selection keeps ssl_predict2 from receiving None.
                model_list2 = gr.Dropdown(
                    choices=["fixmatch", "freematch", "mixmatch"],
                    value="fixmatch",
                    label="Training method",
                    allow_custom_value=False,
                    info="Select trained model according to different SSL techniques from paper",
                )
                with gr.Row():
                    clear_btn2 = gr.Button(value="Clear")
                    submit_btn2 = gr.Button(value="Submit")
                # "Expected class" values match the names in audio_label_names
                # so they can be compared directly with the prediction.
                ds2 = gr.Dataset(
                    components=[gr.Textbox(visible=False), gr.Textbox(visible=False)],
                    headers=["Id", "Expected class"],
                    samples=[["1", "Laughter"], ["2", "Cough"], ["3", "Sneeze"], ["4", "Throat clearing"]],
                    type="index",
                )
            with gr.Column():
                out_field2 = gr.Label(num_top_classes=6, label="Prediction")

        submit_btn2.click(
            fn=ssl_predict2,
            inputs=[audio_file, model_list2],
            outputs=[out_field2],
        )
        clear_btn2.click(
            fn=lambda: [None, None],
            inputs=None,
            outputs=[audio_file, out_field2],
            queue=False,
        )
        # Load the clicked example file (absolute path from audio_examples.txt).
        ds2.click(
            fn=lambda idx: gr.update(value=lines2[idx]),
            inputs=ds2,
            outputs=audio_file,
        )
# Start the server only when run as a script, so importing this module (for
# tests or tooling) does not block on a running web server.
if __name__ == "__main__":
    # 0.0.0.0 listens on all network interfaces, not just localhost.
    ssl_interface.launch(server_name="0.0.0.0", server_port=7860)