Andrei-Iulian SĂCELEANU committed on
Commit
02768a2
1 Parent(s): de475ce

added freematch test

Browse files
app.py CHANGED
@@ -1,7 +1,80 @@
 
1
  import gradio as gr
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
  import gradio as gr
3
+ from transformers import AutoTokenizer
4
+ from unidecode import unidecode
5
+ from models import *
6
 
 
 
7
 
8
+ tok = AutoTokenizer.from_pretrained("readerbench/RoBERT-base")
9
+
10
def preprocess(x):
    """Normalize the raw input string ``x`` before tokenization.

    The text is transliterated with ``unidecode`` and lowercased, then a
    fixed sequence of regex substitutions is applied in order:

    1. remove bracketed lowercase tags such as ``[tag]``;
    2. remove asterisks (without inserting a space in their place);
    3. replace each run of non-alphanumeric characters with one space;
    4. squeeze runs of spaces into one space;
    5. collapse any run of one repeated character to a single occurrence
       (e.g. ``"cool"`` becomes ``"col"``).

    Returns the normalized string.
    """
    text = unidecode(x).lower()

    # Order matters: later patterns operate on the output of earlier ones.
    substitutions = (
        (r"\[[a-z]+\]", ""),
        (r"\*", ""),
        (r"[^a-zA-Z0-9]+", " "),
        (r" +", " "),
        (r"(.)\1+", r"\1"),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)

    return text
22
+
23
# Class names; ssl_predict pairs them positionally with the model's
# output probabilities, so the order here must match the classifier head.
label_names = ["ABUSE", "INSULT", "OTHER", "PROFANITY"]

# Supported SSL variants: model_type -> (encoder name, classifier-head checkpoint).
_MODEL_SPECS = {
    "freematch": (
        "andrei-saceleanu/ro-offense-freematch",
        "./checkpoints/freematch_tune",
    ),
}

# Models already built in this process, keyed by model_type, so the encoder
# and head weights are not reloaded on every request.
_MODEL_CACHE = {}


def _get_model(model_type):
    """Return the model for ``model_type``, building and caching it on first use."""
    model = _MODEL_CACHE.get(model_type)
    if model is None:
        encoder_name, head_checkpoint = _MODEL_SPECS[model_type]
        model = FixMatchTune(encoder_name=encoder_name)
        model.cls_head.load_weights(head_checkpoint)
        _MODEL_CACHE[model_type] = model
    return model


def ssl_predict(in_text, model_type):
    """Classify ``in_text`` with the model selected by ``model_type``.

    Args:
        in_text: Raw input text; it is passed through ``preprocess`` and
            tokenized to a fixed length of 96 tokens (padded/truncated).
        model_type: Name of the training method whose model should be used.
            Must be a key of ``_MODEL_SPECS`` (currently only "freematch").

    Returns:
        A dict mapping each entry of ``label_names`` to its predicted
        probability as a plain Python float.

    Raises:
        ValueError: If ``model_type`` is not a supported model (including
            ``None`` when nothing was selected).
    """
    # Validate first: previously an unsupported choice fell through to the
    # prediction step with no model bound and failed with an unbound-name error.
    if model_type not in _MODEL_SPECS:
        supported = ", ".join(sorted(_MODEL_SPECS))
        raise ValueError(
            f"Unsupported model type {model_type!r}; available: {supported}"
        )

    preprocessed = preprocess(in_text)
    toks = tok(
        preprocessed,
        padding="max_length",
        max_length=96,
        truncation=True,
        return_tensors="tf",
    )

    model = _get_model(model_type)

    # The model returns (weak_preds, strong_preds); only the first is used.
    preds, _ = model([toks["input_ids"], toks["attention_mask"]], training=False)

    # Single input, so take row 0; convert NumPy scalars to plain floats.
    return {label: float(prob) for label, prob in zip(label_names, preds[0].numpy())}
45
+
46
+
47
+
48
+
49
def _reset_fields():
    """Return one ``None`` per component wired to the Clear button."""
    return [None, None]


# NOTE(review): the nesting below is reconstructed from the component order
# (input column with its button row, then the output column) -- confirm
# against the intended layout.
with gr.Blocks() as ssl_interface:
    with gr.Row():
        # Left column: text input, model selector and action buttons.
        with gr.Column():
            in_text = gr.Textbox(label="Input text")
            model_list = gr.Dropdown(
                choices=["fixmatch", "freematch", "mixmatch"],
                max_choices=1,
                label="Training method",
                allow_custom_value=False,
                info="Select trained model according to different SSL techniques from paper",
            )

            with gr.Row():
                clear_btn = gr.Button(value="Clear")
                submit_btn = gr.Button(value="Submit")

        # Right column: predicted label scores.
        with gr.Column():
            out_field = gr.Label(num_top_classes=4, label="Prediction")

    # Submit runs the prediction on the current text and selected model.
    submit_btn.click(
        fn=ssl_predict,
        inputs=[in_text, model_list],
        outputs=[out_field],
    )

    # Clear resets both the input text and the prediction output.
    clear_btn.click(
        fn=_reset_fields,
        inputs=None,
        outputs=[in_text, out_field],
    )

ssl_interface.launch(server_name="0.0.0.0", server_port=7860)
checkpoints/freematch_tune.data-00000-of-00001 ADDED
Binary file (855 kB). View file
 
checkpoints/freematch_tune.index ADDED
Binary file (518 Bytes). View file
 
models.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from transformers import TFAutoModel
3
+
4
+
5
class FixMatchTune(tf.keras.Model):
    """Pretrained transformer encoder followed by a dense softmax head.

    The encoder's pooled output is passed through two ``GaussianNoise``
    layers (stddev 0.5 for the "weak" view, stddev 5 for the "strong" view)
    and each view is scored by the same classification head.
    """

    def __init__(
        self,
        encoder_name="readerbench/RoBERT-base",
        num_classes=4,
        **kwargs
    ):
        """Build the encoder, the two noise layers and the classifier head.

        Args:
            encoder_name: Identifier handed to ``TFAutoModel.from_pretrained``.
            num_classes: Width of the final softmax layer.
            **kwargs: Forwarded to ``tf.keras.Model.__init__``.
        """
        super().__init__(**kwargs)

        self.bert = TFAutoModel.from_pretrained(encoder_name)
        self.num_classes = num_classes
        self.weak_augment = tf.keras.layers.GaussianNoise(stddev=0.5)
        self.strong_augment = tf.keras.layers.GaussianNoise(stddev=5)

        # Head: 256 -> dropout -> 64 -> num_classes (softmax).
        head_layers = [
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(self.num_classes, activation="softmax"),
        ]
        self.cls_head = tf.keras.Sequential(head_layers)

    def call(self, inputs, training):
        """Score the weakly and strongly perturbed views of the inputs.

        Args:
            inputs: A pair ``(input_ids, attention_mask)``.
            training: Flag forwarded unchanged to the encoder, both noise
                layers and the classification head.

        Returns:
            A tuple ``(weak_preds, strong_preds)`` of head outputs.
        """
        input_ids, attention_mask = inputs

        encoder_out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            training=training,
        )
        pooled = encoder_out.pooler_output

        # Strong view is produced before the weak one, as in the original order.
        noisy_strong = self.strong_augment(pooled, training=training)
        noisy_weak = self.weak_augment(pooled, training=training)

        strong_scores = self.cls_head(noisy_strong, training=training)
        weak_scores = self.cls_head(noisy_weak, training=training)

        return weak_scores, strong_scores