Spaces:
Paused
Paused
import re | |
import gradio as gr | |
from transformers import AutoTokenizer | |
from unidecode import unidecode | |
from models import * | |
# Tokenizer shared by every prediction request; created once at import time
# from the "readerbench/RoBERT-base" checkpoint.
tok = AutoTokenizer.from_pretrained(
    "readerbench/RoBERT-base",
)
def preprocess(x):
    """Normalize the raw input string ``x`` before tokenization.

    The text is passed through ``unidecode`` and lowercased, then a fixed
    sequence of regex substitutions is applied in order:

    1. drop bracketed lowercase tags such as ``[tag]``;
    2. drop asterisks;
    3. replace each run of non-alphanumeric characters with one space;
    4. squeeze runs of spaces into a single space;
    5. collapse any run of one repeated character to a single occurrence.

    Returns the normalized string (leading/trailing spaces are not stripped).
    """
    # (pattern, replacement) pairs; the order matters because later rules
    # operate on the output of earlier ones.
    substitutions = (
        (r"\[[a-z]+\]", ""),
        (r"\*", ""),
        (r"[^a-zA-Z0-9]+", " "),
        (r" +", " "),
        (r"(.)\1+", r"\1"),
    )

    text = unidecode(x).lower()
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text
# Class names, in the order the scores are paired with them in ssl_predict.
label_names = ["ABUSE", "INSULT", "OTHER", "PROFANITY"]


def ssl_predict(in_text, model_type):
    """Classify ``in_text`` with the model for the selected training method.

    Args:
        in_text: Raw text typed into the UI textbox.
        model_type: Training method chosen in the dropdown.

    Returns:
        A dict mapping each entry of ``label_names`` to the score the model
        produced for it, as plain Python floats.

    Raises:
        gr.Error: If ``model_type`` is anything other than ``"freematch"``,
            the only method this function knows how to load.
    """
    # Only a "freematch" loading branch exists. Previously any other
    # selection (or no selection at all) fell through to an unbound `model`
    # and crashed with UnboundLocalError; fail early with a readable message
    # instead, before doing any tokenization work.
    if model_type != "freematch":
        raise gr.Error(
            f"Model '{model_type}' is not available; please select 'freematch'."
        )

    preprocessed = preprocess(in_text)

    # Pad/truncate to a fixed length of 96 tokens and return TensorFlow
    # tensors, which is what the model call below consumes.
    toks = tok(
        preprocessed,
        padding="max_length",
        max_length=96,
        truncation=True,
        return_tensors="tf",
    )

    model = FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch")
    model.cls_head.load_weights("./checkpoints/freematch_tune")

    # The model returns a pair; only the first element is used here.
    preds, _ = model(
        [toks["input_ids"], toks["attention_mask"]],
        training=False,
    )

    # preds[0] is the row for the single input text; convert the NumPy
    # scalars to built-in floats so the result is plain Python data.
    scores = preds[0].numpy()
    return {name: float(score) for name, score in zip(label_names, scores)}
# Gradio UI: a text box and training-method dropdown on the left, the
# predicted label scores on the right.
with gr.Blocks() as ssl_interface:
    with gr.Row():
        with gr.Column():
            in_text = gr.Textbox(label="Input text")
            model_list = gr.Dropdown(
                choices=["fixmatch", "freematch", "mixmatch"],
                # Default to "freematch": without a default an untouched
                # dropdown submits None, and ssl_predict only knows how to
                # load the freematch model.
                value="freematch",
                max_choices=1,
                label="Training method",
                allow_custom_value=False,
                info="Select trained model according to different SSL techniques from paper",
            )
            with gr.Row():
                clear_btn = gr.Button(value="Clear")
                submit_btn = gr.Button(value="Submit")
        with gr.Column():
            out_field = gr.Label(num_top_classes=4, label="Prediction")

    # Submit runs the classifier on the current text and selected method.
    submit_btn.click(
        fn=ssl_predict,
        inputs=[in_text, model_list],
        outputs=[out_field],
    )

    # Clear resets the input text and the prediction; the dropdown selection
    # is left as it is.
    clear_btn.click(
        fn=lambda: [None, None],
        inputs=None,
        outputs=[in_text, out_field],
    )

# Listen on all interfaces on port 7860.
ssl_interface.launch(server_name="0.0.0.0", server_port=7860)