File size: 2,182 Bytes
02768a2
de475ce
02768a2
 
 
de475ce
 
02768a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
674a3ea
 
 
 
02768a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import re
import gradio as gr
from transformers import AutoTokenizer
from unidecode import unidecode
from models import *


# Shared tokenizer, created once at import time from the
# "readerbench/RoBERT-base" checkpoint and reused by every prediction call.
tok = AutoTokenizer.from_pretrained(
    "readerbench/RoBERT-base"
)

def preprocess(x):
    """Normalize the raw input string x before tokenization.

    The text is passed through unidecode and lower-cased, then a fixed
    sequence of regex substitutions is applied in order:

    1. drop bracketed lowercase tags such as "[url]";
    2. drop asterisks;
    3. replace every run of non-alphanumeric characters with one space;
    4. squeeze runs of spaces into one space;
    5. collapse any run of the same character into a single occurrence
       (this also shortens legitimate doubles, e.g. "good" -> "god").

    Leading/trailing spaces are not stripped. Returns the normalized string.
    """

    # (pattern, replacement) pairs; the order is significant because each
    # step operates on the output of the previous one.
    substitutions = (
        (r"\[[a-z]+\]", ""),
        (r"\*", ""),
        (r"[^a-zA-Z0-9]+", " "),
        (r" +", " "),
        (r"(.)\1+", r"\1"),
    )

    text = unidecode(x).lower()
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)

    return text

# Class labels reported to the UI. ssl_predict pairs them positionally with
# the model's output scores, so the order here must not be changed casually.
label_names = [
    "ABUSE",
    "INSULT",
    "OTHER",
    "PROFANITY",
]

# Models that can actually be served, keyed by the `model_type` value coming
# from the UI: model_type -> (encoder name, classification-head checkpoint).
# Only "freematch" has weights wired up; add entries here to enable more.
_MODEL_SPECS = {
    "freematch": (
        "andrei-saceleanu/ro-offense-freematch",
        "./checkpoints/freematch_tune",
    ),
}

# Already-built models keyed by model_type, so the encoder and head weights
# are loaded once per process instead of on every request.
_MODEL_CACHE = {}


def _load_model(model_type):
    """Return the model registered for model_type, building it on first use."""
    if model_type not in _MODEL_CACHE:
        encoder_name, checkpoint = _MODEL_SPECS[model_type]
        model = FixMatchTune(encoder_name=encoder_name)
        model.cls_head.load_weights(checkpoint)
        _MODEL_CACHE[model_type] = model
    return _MODEL_CACHE[model_type]


def ssl_predict(in_text, model_type):
    """Classify in_text with the model selected by model_type.

    Args:
        in_text: Raw text to classify. None is treated as an empty string.
        model_type: Name of the SSL variant to use; must be a key of
            _MODEL_SPECS (currently only "freematch").

    Returns:
        A dict mapping each entry of label_names to the float score the
        model produced at the same position for this input.

    Raises:
        gr.Error: If model_type is missing or has no registered model, so
            the UI shows a readable message instead of an internal
            UnboundLocalError.
    """

    # Validate first: previously any value other than "freematch" (including
    # no selection at all) crashed because `model` was never assigned.
    if model_type not in _MODEL_SPECS:
        available = ", ".join(sorted(_MODEL_SPECS))
        raise gr.Error(
            f"Model {model_type!r} is not available; choose one of: {available}."
        )

    # A cleared textbox may deliver None rather than "".
    preprocessed = preprocess(in_text or "")
    toks = tok(
        preprocessed,
        padding="max_length",
        max_length=96,
        truncation=True,
        return_tensors="tf",
    )

    model = _load_model(model_type)
    preds, _ = model(
        [toks["input_ids"], toks["attention_mask"]],
        training=False,
    )

    # Single input -> take the first (only) row of the batch.
    probs = preds[0].numpy()
    return {name: float(score) for name, score in zip(label_names, probs)}




# Gradio UI: a two-column layout with the inputs (text, model selector,
# buttons) on the left and the prediction label on the right.
with gr.Blocks() as ssl_interface:
    with gr.Row():
        # Left column: user inputs and action buttons.
        with gr.Column():
            in_text = gr.Textbox(label="Input text")
            # NOTE(review): ssl_predict only has a model wired up for
            # "freematch"; the other two choices are listed but not served.
            # No default value is set, so the initial selection is empty.
            model_list = gr.Dropdown(
                choices=["fixmatch", "freematch", "mixmatch"],
                max_choices=1,
                label="Training method",
                allow_custom_value=False,
                info="Select trained model according to different SSL techniques from paper",
            )

            with gr.Row():
                clear_btn = gr.Button(value="Clear")
                submit_btn = gr.Button(value="Submit")

        # Right column: per-class scores returned by ssl_predict.
        with gr.Column():
            out_field = gr.Label(num_top_classes=4,label="Prediction")

    # Submit runs the classifier on the current text and selected model.
    submit_btn.click(
        fn=ssl_predict,
        inputs=[in_text, model_list],
        outputs=[out_field]
    )

    # Clear resets the text box and the prediction (one None per output);
    # the dropdown selection is left as it is.
    clear_btn.click(
        fn=lambda: [None for _ in range(2)],
        inputs=None,
        outputs=[in_text, out_field]
    )

# Starts the web server at import time, listening on all interfaces
# (0.0.0.0) on port 7860.
ssl_interface.launch(server_name="0.0.0.0", server_port=7860)