SSL_demo / app.py
Andrei-Iulian SĂCELEANU
ceva
674a3ea
raw
history blame
2.18 kB
import re
import gradio as gr
from transformers import AutoTokenizer
from unidecode import unidecode
from models import *
# Tokenizer for the "readerbench/RoBERT-base" checkpoint, created once at
# import time and reused by every ssl_predict call.
tok = AutoTokenizer.from_pretrained("readerbench/RoBERT-base")
def preprocess(x):
    """Normalize the raw input string x before tokenization.

    The text is transliterated with unidecode and lowercased, then a fixed
    sequence of regex substitutions is applied. The order of the
    substitutions matters: each one operates on the previous one's output.
    """
    # (pattern, replacement) pairs, applied top to bottom.
    substitutions = (
        (r"\[[a-z]+\]", ""),      # drop bracketed lowercase tags, e.g. "[url]"
        (r"\*", ""),              # drop asterisks
        (r"[^a-zA-Z0-9]+", " "),  # every other non-alphanumeric run -> one space
        (r" +", " "),             # squeeze runs of spaces
        (r"(.)\1+", r"\1"),       # collapse any repeated character to a single one
    )

    text = unidecode(x).lower()
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text
# Class names, paired positionally (via zip) with the scores the model returns
# in ssl_predict; the order must therefore match the model's output order.
label_names = ["ABUSE", "INSULT", "OTHER", "PROFANITY"]
# Models that have already been built, keyed by training-method name, so the
# encoder and head weights are loaded once per process rather than per request.
_MODEL_CACHE = {}


def _load_model(model_type):
    """Return the classifier for model_type, building and caching it on first use."""
    if model_type not in _MODEL_CACHE:
        if model_type == "freematch":
            model = FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch")
            model.cls_head.load_weights("./checkpoints/freematch_tune")
        else:
            # Only "freematch" has an encoder/checkpoint wired up here. Fail
            # with a message the UI can show instead of crashing later on an
            # unbound `model` variable. This also covers model_type=None,
            # i.e. nothing selected in the dropdown.
            raise gr.Error(
                f"Model {model_type!r} is not available; "
                "please select 'freematch'."
            )
        _MODEL_CACHE[model_type] = model
    return _MODEL_CACHE[model_type]


def ssl_predict(in_text, model_type):
    """Classify in_text with the model trained by the selected SSL method.

    Args:
        in_text: Raw text entered by the user. None is treated as an empty
            string.
        model_type: Training-method name selected in the UI.

    Returns:
        A dict mapping each entry of label_names to the corresponding model
        output converted to a Python float.

    Raises:
        gr.Error: If no model is available for model_type.
    """
    # Resolve the model first so an unsupported choice fails before any
    # tokenization work is done.
    model = _load_model(model_type)

    preprocessed = preprocess(in_text or "")
    toks = tok(
        preprocessed,
        padding="max_length",
        max_length=96,
        truncation=True,
        return_tensors="tf",
    )

    # The model returns a pair; only the first element (the predictions) is used.
    preds, _ = model([toks["input_ids"], toks["attention_mask"]], training=False)

    # preds[0] holds the scores for the single input in the batch.
    return {
        name: float(score)
        for name, score in zip(label_names, preds[0].numpy())
    }
# Two-column layout: inputs and buttons on the left, prediction on the right.
with gr.Blocks() as ssl_interface:
    with gr.Row():
        # Left column: text input, model selector and action buttons.
        with gr.Column():
            in_text = gr.Textbox(label="Input text")
            model_list = gr.Dropdown(
                label="Training method",
                info="Select trained model according to different SSL techniques from paper",
                choices=["fixmatch", "freematch", "mixmatch"],
                max_choices=1,
                allow_custom_value=False,
            )
            with gr.Row():
                clear_btn = gr.Button(value="Clear")
                submit_btn = gr.Button(value="Submit")

        # Right column: per-class prediction scores.
        with gr.Column():
            out_field = gr.Label(label="Prediction", num_top_classes=4)

    # Submit runs the classifier on the current text and selected method.
    submit_btn.click(
        fn=ssl_predict,
        inputs=[in_text, model_list],
        outputs=[out_field],
    )

    # Clear resets both the input text box and the prediction label.
    clear_btn.click(
        fn=lambda: [None, None],
        inputs=None,
        outputs=[in_text, out_field],
    )

# Listen on all interfaces on port 7860.
ssl_interface.launch(server_name="0.0.0.0", server_port=7860)