Spaces:
Paused
Paused
File size: 2,182 Bytes
02768a2 de475ce 02768a2 de475ce 02768a2 674a3ea 02768a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import re
import gradio as gr
from transformers import AutoTokenizer
from unidecode import unidecode
from models import *
# Shared tokenizer, created once at module load from the
# "readerbench/RoBERT-base" checkpoint and reused by every prediction call.
tok = AutoTokenizer.from_pretrained("readerbench/RoBERT-base")
def preprocess(x):
    """Normalize raw input text before tokenization.

    The text is passed through ``unidecode`` and lowercased, then a fixed
    sequence of regex substitutions is applied in order:

    1. remove bracketed lowercase tags such as ``[tag]``;
    2. remove asterisks;
    3. replace each run of non-alphanumeric characters with one space;
    4. squeeze runs of spaces into one space;
    5. collapse any run of the same character into a single occurrence.

    Returns the normalized string (leading/trailing spaces are not stripped).
    """
    # (pattern, replacement) pairs; the order is significant.
    substitutions = (
        (r"\[[a-z]+\]", ""),
        (r"\*", ""),
        (r"[^a-zA-Z0-9]+", " "),
        (r" +", " "),
        (r"(.)\1+", r"\1"),
    )

    text = unidecode(x).lower()
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text
# Class names paired positionally (via zip) with the model's output scores in
# ssl_predict.
# NOTE(review): the order presumably matches the classifier head's output
# indices — confirm against the training code before reordering.
label_names = ["ABUSE", "INSULT", "OTHER", "PROFANITY"]
# Models that have already been built, keyed by training method, so the
# encoder and classification-head weights are loaded once per process instead
# of on every request.
_MODEL_CACHE = {}


def _load_model(model_type):
    """Return the classifier for *model_type*, building it on first use.

    Raises:
        ValueError: if no checkpoint is wired up for *model_type*.
    """
    # Only the FreeMatch checkpoint is configured here; anything else
    # (including an empty dropdown selection, i.e. None) is rejected with an
    # explicit message instead of failing later on an unbound variable.
    if model_type != "freematch":
        raise ValueError(
            f"Unsupported model type {model_type!r}; available: freematch"
        )

    model = _MODEL_CACHE.get(model_type)
    if model is None:
        model = FixMatchTune(encoder_name="andrei-saceleanu/ro-offense-freematch")
        model.cls_head.load_weights("./checkpoints/freematch_tune")
        _MODEL_CACHE[model_type] = model
    return model


def ssl_predict(in_text, model_type):
    """Classify *in_text* with the model trained by the chosen SSL method.

    Args:
        in_text: raw text from the UI; ``None`` is treated as an empty string.
        model_type: training-method name selected in the UI.

    Returns:
        A dict mapping each entry of ``label_names`` to a float score taken
        from the first row of the model's prediction output.

    Raises:
        ValueError: if *model_type* is not a supported training method.
    """
    # Resolve the model first so an unsupported choice fails fast, before any
    # preprocessing or tokenization work is done.
    model = _load_model(model_type)

    preprocessed = preprocess(in_text or "")
    # Pad/truncate to a fixed length of 96 tokens and return TensorFlow tensors.
    toks = tok(
        preprocessed,
        padding="max_length",
        max_length=96,
        truncation=True,
        return_tensors="tf",
    )

    # The model returns a pair; only the first element (predictions) is used.
    preds, _ = model([toks["input_ids"], toks["attention_mask"]], training=False)

    # A single text was submitted, so row 0 holds its per-class scores.
    return {
        name: float(score)
        for name, score in zip(label_names, preds[0].numpy())
    }
# Demo page: text box, training-method dropdown and buttons, plus a label
# widget that shows the prediction.
# NOTE(review): the nesting below was reconstructed from a listing whose
# indentation had been lost — confirm the intended layout.
with gr.Blocks() as ssl_interface:
    with gr.Row():
        # Input side: text, model selector and the action buttons.
        with gr.Column():
            in_text = gr.Textbox(label="Input text")
            model_list = gr.Dropdown(
                choices=["fixmatch", "freematch", "mixmatch"],
                max_choices=1,
                label="Training method",
                allow_custom_value=False,
                info="Select trained model according to different SSL techniques from paper",
            )
            with gr.Row():
                clear_btn = gr.Button(value="Clear")
                submit_btn = gr.Button(value="Submit")

        # Output side: scores for the four classes.
        with gr.Column():
            out_field = gr.Label(num_top_classes=4, label="Prediction")

    # Submit runs the classifier on the current text and selected method.
    submit_btn.click(
        fn=ssl_predict,
        inputs=[in_text, model_list],
        outputs=[out_field],
    )

    # Clear resets both the input text and the prediction (one None each).
    clear_btn.click(
        fn=lambda: [None, None],
        inputs=None,
        outputs=[in_text, out_field],
    )

# Serve on all interfaces, port 7860.
ssl_interface.launch(server_name="0.0.0.0", server_port=7860)
|