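# NOTE: the following lines appear to be repository web-page metadata captured
# together with this file; they are kept as comments so the module parses.
# SSL_demo / app.py
# Andrei-Iulian SĂCELEANU
# fix preprocess
# 1de0133
# raw
# history blame
# 6.52 kB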
import re
import gradio as gr
import librosa
import numpy as np
from transformers import AutoTokenizer,ViTImageProcessor
from unidecode import unidecode
from models import *
# Hub identifiers of the pretrained preprocessing assets.
_TEXT_TOKENIZER_ID = "readerbench/RoBERT-base"
_IMAGE_PROCESSOR_ID = "google/vit-base-patch16-224"

# Both objects are created once at import time and shared by every request:
# `tok` tokenizes text in ssl_predict, `processor` prepares spectrogram
# images in ssl_predict2.
tok = AutoTokenizer.from_pretrained(_TEXT_TOKENIZER_ID)
processor = ViTImageProcessor.from_pretrained(_IMAGE_PROCESSOR_ID)
def preprocess(x):
    """Normalize a raw input string before tokenization.

    The text is transliterated to ASCII with ``unidecode`` and lower-cased,
    then a fixed sequence of regex substitutions is applied in order:

    1. drop bracketed lowercase tags such as ``[tag]``;
    2. drop asterisks;
    3. replace every run of non-alphanumeric characters with one space;
    4. squeeze runs of spaces into one space;
    5. collapse any run of the same character into a single occurrence.

    Returns the cleaned string.
    """
    # (pattern, replacement) pairs; the order matters and mirrors the steps
    # listed in the docstring.
    substitutions = (
        (r"\[[a-z]+\]", ""),
        (r"\*", ""),
        (r"[^a-zA-Z0-9]+", " "),
        (r" +", " "),
        (r"(.)\1+", r"\1"),
    )

    cleaned = unidecode(x).lower()
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned
# Class names for the text models. ssl_predict pairs them positionally with
# the entries of the predicted score vector, so the order must not change.
label_names = [
    "ABUSE",
    "INSULT",
    "OTHER",
    "PROFANITY",
]

# Class names for the audio models, paired positionally with the predicted
# score vector in ssl_predict2.
audio_label_names = [
    "Laughter",
    "Sigh",
    "Cough",
    "Throat clearing",
    "Sneeze",
    "Sniff",
]
def ssl_predict(in_text, model_type):
    """Score a text with the model trained by the selected SSL method.

    Args:
        in_text: Raw text from the UI. ``None`` (e.g. a cleared textbox) is
            treated as an empty string.
        model_type: One of ``"fixmatch"``, ``"freematch"``, ``"mixmatch"``,
            ``"contrastive_reg"`` or ``"label_propagation"``.

    Returns:
        A dict mapping each entry of ``label_names`` to a float score taken
        from the first row of the model output.

    Raises:
        gr.Error: if no training method is selected or it is not recognised,
            so the UI shows a readable message instead of an opaque
            ``TypeError`` on a ``None`` prediction.
    """
    # model_type -> (zero-argument model factory,
    #                checkpoint path,
    #                load weights into model.cls_head only?,
    #                does calling the model return a (preds, extra) pair?)
    variants = {
        "fixmatch": (
            lambda: FixMatchTune(encoder_name="readerbench/RoBERT-base"),
            "./checkpoints/fixmatch_tune",
            # NOTE(review): this is the only variant whose checkpoint is
            # loaded into the whole model rather than cls_head - confirm
            # that this is intentional.
            False,
            True,
        ),
        "freematch": (
            lambda: FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch"),
            "./checkpoints/freematch_tune",
            True,
            True,
        ),
        "mixmatch": (
            lambda: MixMatch(bert_model="andrei-saceleanu/ro-offense-mixmatch"),
            "./checkpoints/mixmatch",
            True,
            False,
        ),
        "contrastive_reg": (
            lambda: FixMatchTune(encoder_name="readerbench/RoBERT-base"),
            "./checkpoints/contrastive",
            True,
            True,
        ),
        "label_propagation": (
            lambda: LPModel(),
            "./checkpoints/label_prop",
            True,
            False,
        ),
    }

    # Fail fast, before any tokenization or model loading happens.
    if model_type is None:
        raise gr.Error("Please select a training method.")
    if model_type not in variants:
        raise gr.Error(
            f"Unknown training method {model_type!r}. "
            f"Choose one of: {', '.join(variants)}."
        )

    preprocessed = preprocess("" if in_text is None else in_text)
    toks = tok(
        preprocessed,
        padding="max_length",
        max_length=96,
        truncation=True,
        return_tensors="tf",
    )

    build_model, checkpoint, head_only, returns_pair = variants[model_type]
    model = build_model()
    if head_only:
        model.cls_head.load_weights(checkpoint)
    else:
        model.load_weights(checkpoint)

    outputs = model([toks["input_ids"], toks["attention_mask"]], training=False)
    if returns_pair:
        preds, _ = outputs
    else:
        preds = outputs

    # Only one text is submitted, so the first row holds its scores.
    return {name: float(score) for name, score in zip(label_names, preds[0].numpy())}
def ssl_predict2(audio_file, model_type):
    """Score an audio clip with the model trained by the selected SSL method.

    The clip is loaded at 16 kHz, zero-padded or truncated to 5 seconds,
    converted to a 128-band mel spectrogram in dB, min-max scaled to [0, 1],
    replicated over three channels and passed through the ViT image
    processor before being fed to the model.

    Args:
        audio_file: The uploaded file from the UI: either an object exposing
            the file path as ``.name`` or the path itself.
        model_type: One of ``"fixmatch"``, ``"freematch"`` or ``"mixmatch"``.

    Returns:
        A dict mapping each entry of ``audio_label_names`` to a float score
        taken from the first row of the model output.

    Raises:
        gr.Error: if no file was uploaded, or the training method is missing
            or not recognised.
    """
    sample_rate = 16000
    clip_seconds = 5

    # model_type -> (zero-argument model factory,
    #                checkpoint path for model.cls_head,
    #                does calling the model return a (preds, extra) pair?)
    variants = {
        "fixmatch": (
            lambda: AudioFixMatch(encoder_name="andrei-saceleanu/vit-base-fixmatch"),
            "./checkpoints/audio_fixmatch",
            True,
        ),
        "freematch": (
            lambda: AudioFixMatch(encoder_name="andrei-saceleanu/vit-base-freematch"),
            "./checkpoints/audio_freematch",
            True,
        ),
        "mixmatch": (
            lambda: AudioMixMatch(bert_model="andrei-saceleanu/vit-base-mixmatch"),
            "./checkpoints/audio_mixmatch",
            False,
        ),
    }

    # Fail fast with readable messages instead of AttributeError/TypeError.
    if audio_file is None:
        raise gr.Error("Please upload an audio file.")
    if model_type is None:
        raise gr.Error("Please select a training method.")
    if model_type not in variants:
        raise gr.Error(
            f"Unknown training method {model_type!r}. "
            f"Choose one of: {', '.join(variants)}."
        )

    # Accept both a file wrapper exposing `.name` and a plain path string.
    audio_path = getattr(audio_file, "name", audio_file)
    signal, sr = librosa.load(audio_path, sr=sample_rate)

    # Force a fixed-length clip: zero-pad short signals, truncate long ones.
    length = clip_seconds * sample_rate
    if len(signal) < length:
        signal = np.pad(signal, (0, length - len(signal)), "constant")
    else:
        signal = signal[:length]

    spectrogram = librosa.feature.melspectrogram(y=signal, sr=sr, n_mels=128)
    spectrogram = librosa.power_to_db(S=spectrogram, ref=np.max)

    # Min-max scale to [0, 1]. A constant spectrogram (e.g. a silent clip)
    # has zero range; map it to zeros instead of dividing by zero, which
    # would feed NaNs to the model.
    spectrogram_min, spectrogram_max = spectrogram.min(), spectrogram.max()
    value_range = spectrogram_max - spectrogram_min
    if value_range > 0:
        spectrogram = (spectrogram - spectrogram_min) / value_range
    else:
        spectrogram = np.zeros_like(spectrogram)
    spectrogram = spectrogram.astype("float32")

    # Add a batch axis and replicate the single channel three times.
    # NOTE(review): the mean/std constants are presumably the statistics
    # used during training - confirm before changing them.
    inputs = processor.preprocess(
        np.repeat(spectrogram[np.newaxis, :, :, np.newaxis], 3, -1),
        image_mean=(-3.05, -3.05, -3.05),
        image_std=(2.33, 2.33, 2.33),
        return_tensors="tf",
    )

    build_model, checkpoint, returns_pair = variants[model_type]
    model = build_model()
    model.cls_head.load_weights(checkpoint)

    outputs = model(inputs["pixel_values"], training=False)
    if returns_pair:
        preds, _ = outputs
    else:
        preds = outputs

    # Only one clip is submitted, so the first row holds its scores.
    return {name: float(score) for name, score in zip(audio_label_names, preds[0].numpy())}
def _reset_two_components():
    """Return one ``None`` per output component wired to a Clear button."""
    return [None] * 2


# Help text shown under both "Training method" dropdowns.
_METHOD_INFO = "Select trained model according to different SSL techniques from paper"

# Two-tab demo: one tab per modality, each with an input, a training-method
# selector, Clear/Submit buttons and a label output.
with gr.Blocks() as ssl_interface:

    # ---- Text tab -------------------------------------------------------
    with gr.Tab("Text (RO-Offense)"):
        with gr.Row():
            with gr.Column():
                in_text = gr.Textbox(label="Input text")
                model_list = gr.Dropdown(
                    label="Training method",
                    info=_METHOD_INFO,
                    choices=[
                        "fixmatch",
                        "freematch",
                        "mixmatch",
                        "contrastive_reg",
                        "label_propagation",
                    ],
                    max_choices=1,
                    allow_custom_value=False,
                )
                with gr.Row():
                    clear_btn = gr.Button(value="Clear")
                    submit_btn = gr.Button(value="Submit")
            with gr.Column():
                out_field = gr.Label(num_top_classes=4, label="Prediction")

        # Submit runs the text classifier; Clear blanks the input and output.
        submit_btn.click(
            fn=ssl_predict,
            inputs=[in_text, model_list],
            outputs=[out_field],
        )
        clear_btn.click(
            fn=_reset_two_components,
            inputs=None,
            outputs=[in_text, out_field],
        )

    # ---- Audio tab ------------------------------------------------------
    with gr.Tab("Audio (VocalSound)"):
        with gr.Row():
            with gr.Column():
                audio_file = gr.File(
                    label="Input audio",
                    file_count="single",
                    file_types=["audio"],
                )
                model_list2 = gr.Dropdown(
                    label="Training method",
                    info=_METHOD_INFO,
                    choices=["fixmatch", "freematch", "mixmatch"],
                    max_choices=1,
                    allow_custom_value=False,
                )
                with gr.Row():
                    clear_btn2 = gr.Button(value="Clear")
                    submit_btn2 = gr.Button(value="Submit")
            with gr.Column():
                out_field2 = gr.Label(num_top_classes=6, label="Prediction")

        # Submit runs the audio classifier; Clear blanks the file and output.
        submit_btn2.click(
            fn=ssl_predict2,
            inputs=[audio_file, model_list2],
            outputs=[out_field2],
        )
        clear_btn2.click(
            fn=_reset_two_components,
            inputs=None,
            outputs=[audio_file, out_field2],
        )

# Listen on all interfaces on port 7860.
ssl_interface.launch(server_name="0.0.0.0", server_port=7860)