Spaces:
Sleeping
Sleeping
NithitEiEi
committed on
upload model and app
Browse files- app.py +81 -0
- bert/bert_model.pth +3 -0
- bert/model.py +49 -0
- deberta/fastai_QIQC-deberta-v3.pth +3 -0
- deberta/model.py +53 -0
- lstm/model.py +45 -0
- lstm/model_1.h5 +3 -0
- lstm/model_2.h5 +3 -0
- lstm/model_3.h5 +3 -0
- lstm/model_4.h5 +3 -0
- lstm/word_dict.pkl +3 -0
app.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from bert.model import BERT_predict
|
3 |
+
from lstm.model import BiLSTM_predict
|
4 |
+
from deberta.model import deBERTa_predict
|
5 |
+
|
6 |
+
# Sample questions passed to gr.Examples so they can be loaded into the input box.
examples = [
    "How do I train my dogs to kill raccoons?",
    "Why Indian education totally bullshit?",
    "What is it really like to be a nurse practitioner",
    "What is your prefer race? white, black or asian.",
]
|
12 |
+
|
13 |
+
def alert(prob):
    """Abort the current event with a ``gr.Error`` reporting the insincere score.

    Args:
        prob: Insincere probability; shown in the message with two decimals.

    Raises:
        gr.Error: Always.
    """
    message = f"This may not suitable for posting insincere probability {prob:.2f}"
    raise gr.Error(message)
|
15 |
+
|
16 |
+
def clear_post():
    """Return the empty string used to reset the "Add post" text box."""
    return ""
|
18 |
+
|
19 |
+
def add_post(posts, new_post, mode):
    """Classify ``new_post`` with the selected model and append it if accepted.

    Args:
        posts: List of accepted post records; each record is a dict with the
            keys ``"post"``, ``"model"`` and ``"prob"``.
        new_post: Text of the question being submitted.
        mode: Name of the model to use: ``"BiLSTM"``, ``"BERT"`` or
            ``"DeBERTaV3"``.

    Returns:
        A ``(posts, "")`` pair: the (possibly extended) post list and an empty
        string for the input box.

    Raises:
        ValueError: If ``mode`` is not one of the supported model names.
            Previously the function silently returned ``None`` in that case.
        gr.Error: Raised by ``alert`` when the post is classified as insincere.
    """
    # Validate first so an unsupported mode fails loudly before any model runs.
    if mode not in ("BiLSTM", "BERT", "DeBERTaV3"):
        raise ValueError(f"Unknown model: {mode!r}")

    if mode == "BiLSTM":
        # The BiLSTM ensemble returns only a score; 0.35 is its decision
        # threshold. float() keeps a plain Python number in the stored record.
        prob = float(BiLSTM_predict(new_post))
        insincere = prob > 0.35
    elif mode == "BERT":
        insincere, prob = BERT_predict(new_post)
    else:
        insincere, prob = deBERTa_predict(new_post)

    if insincere:
        alert(prob)
        # alert() raises today; this return keeps the post rejected even if
        # alert() is ever softened to a non-raising notification.
        return posts, ""

    return posts + [{"post": new_post, "model": mode, "prob": prob}], ""
|
42 |
+
|
43 |
+
|
44 |
+
# NOTE(review): "Quara" in the page title looks like a typo for "Quora" —
# confirm before changing the user-facing string.
with gr.Blocks(theme=gr.themes.Soft(), title="Quara Question post") as demo:

    # State holding the list of accepted post records built by add_post.
    posts = gr.State([])

    new_post = gr.Textbox(label="Add post", autofocus=True)
    mode = gr.Radio(["BiLSTM", "BERT", "DeBERTaV3"], value="BiLSTM", label="Model")

    with gr.Row():
        submit = gr.Button("submit", variant='primary')
        clear = gr.Button("clear")

    # add_post returns (posts, ""), which updates the state and empties the box.
    submit.click(add_post, inputs=[posts, new_post, mode], outputs=[posts, new_post])
    clear.click(clear_post, inputs=None, outputs=new_post)

    @gr.render(inputs=posts)
    def render_posts(post_list):
        """Draw one row (text + delete button) per accepted post."""
        gr.Markdown(f"### Question post ({len(post_list)})")
        for index, post in enumerate(post_list):
            with gr.Row():
                gr.Textbox(
                    f"{post['post']} | {post['prob']:.8f}",
                    label=f"Post{index + 1} ({post['model']})",
                    show_label=True,
                )
                delete_btn = gr.Button("Delete", scale=0, variant="stop")

                # Bind the current post as a default argument so each button
                # removes its own row rather than the last loop value.
                def delete(post=post):
                    post_list.remove(post)
                    return post_list

                delete_btn.click(delete, None, [posts])

    with gr.Row():
        # The component is not assigned back to `examples`, so the module-level
        # list of sample questions keeps its original meaning.
        gr.Examples(
            examples=examples,
            inputs=[new_post],
        )

demo.launch()
|
bert/bert_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7d7423dcada747073f22d892f938150cfe737b5c4a46aabd334f20959d604db1
|
3 |
+
size 436484297
|
bert/model.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
5 |
+
|
6 |
+
class BertClassifier(nn.Module):
    """Wrapper module that returns only the logits of the wrapped model."""

    def __init__(self, bert):
        super().__init__()
        # Kept under the attribute name "bert" so parameter names are unchanged.
        self.bert = bert

    def forward(self, input_id, attention_mask):
        """Run the wrapped model and return the ``logits`` of its output."""
        return self.bert(input_ids=input_id, attention_mask=attention_mask).logits
|
14 |
+
|
15 |
+
# Tokenizer for the 'bert-base-cased' checkpoint used by BERT_predict.
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

# NOTE(review): .train() here is overridden by model.eval() at the end of this
# setup, so it looks redundant for inference — confirm before removing.
bert = AutoModelForSequenceClassification.from_pretrained('bert-base-cased').train()

# Replacement classification head: 768 -> 1024 -> 2 with ReLU and dropout.
# Layer order and sizes must match the saved state dict loaded below.
classifier = nn.Sequential(
    nn.Linear(768, 1024),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(1024, 2)
)

bert.classifier = classifier

model = BertClassifier(bert)

# Load the saved weights onto the CPU (strict key matching is the default).
model.load_state_dict(torch.load("./bert/bert_model.pth", map_location=torch.device('cpu'), weights_only=True))

# Inference mode: disables dropout in the wrapper and all of its submodules.
model.eval()
|
33 |
+
|
34 |
+
def BERT_predict(text):
    """Score ``text`` with the module-level BERT model.

    The text is tokenized to exactly 30 tokens (padded or truncated).

    Returns:
        ``(label, prob)`` where ``label`` is the argmax class index and
        ``prob`` is the softmax probability of class index 1.
    """
    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=30,
        return_tensors="pt",
    )
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']

    model.eval()
    with torch.no_grad():
        logits = model(input_ids, attention_mask)

    class_probs = F.softmax(logits, dim=-1)
    label = torch.argmax(class_probs, dim=-1).item()
    positive_prob = class_probs[0][1].item()
    return label, positive_prob
|
deberta/fastai_QIQC-deberta-v3.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0d2b64c078a7de8ffe57b1ff767e3bfac6bdb52bb3f5977b5e06f3ce9993b873
|
3 |
+
size 740942321
|
deberta/model.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
5 |
+
|
6 |
+
class BertClassifier(nn.Module):
    """Wrapper module that returns only the logits of the wrapped model."""

    def __init__(self, bert):
        super().__init__()
        # Kept under the attribute name "bert" so parameter names are unchanged.
        self.bert = bert

    def forward(self, input_id, attention_mask):
        """Run the wrapped model and return the ``logits`` of its output."""
        return self.bert(input_ids=input_id, attention_mask=attention_mask).logits
|
14 |
+
|
15 |
+
import warnings

# Tokenizer for the 'microsoft/deberta-v3-base' checkpoint used by deBERTa_predict.
tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-base')

# NOTE(review): .train() here is overridden by model.eval() at the end of this
# setup, so it looks redundant for inference — confirm before removing.
bert = AutoModelForSequenceClassification.from_pretrained('microsoft/deberta-v3-base').train()

# Replacement classification head: 768 -> 1024 -> 2 with ReLU and dropout.
# Layer order and sizes must match the saved state dict loaded below.
classifier = nn.Sequential(
    nn.Linear(768, 1024),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(1024, 2)
)

bert.classifier = classifier

model = BertClassifier(bert)
state_dict = torch.load(
    "./deberta/fastai_QIQC-deberta-v3.pth", map_location=torch.device('cpu'),
    weights_only=True
)

# strict=False tolerates key mismatches, which would otherwise go unnoticed and
# leave (part of) the model at its pretrained/random initialisation. Keep the
# lenient load but surface any mismatch at startup.
_load_result = model.load_state_dict(state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "DeBERTa checkpoint did not match the model exactly: "
        f"{len(_load_result.missing_keys)} missing key(s), "
        f"{len(_load_result.unexpected_keys)} unexpected key(s)",
        RuntimeWarning,
    )

# Inference mode: disables dropout in the wrapper and all of its submodules.
model.eval()
|
37 |
+
|
38 |
+
def deBERTa_predict(text):
    """Score ``text`` with the module-level DeBERTa model.

    The text is tokenized to exactly 30 tokens (padded or truncated).

    Returns:
        ``(label, prob)`` where ``label`` is the argmax class index and
        ``prob`` is the softmax probability of class index 1.
    """
    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=30,
        return_tensors="pt",
    )
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']

    model.eval()
    with torch.no_grad():
        logits = model(input_ids, attention_mask)

    class_probs = F.softmax(logits, dim=-1)
    label = torch.argmax(class_probs, dim=-1).item()
    positive_prob = class_probs[0][1].item()
    return label, positive_prob
|
lstm/model.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os

# Must be set before TensorFlow is imported: setting it after the import (as
# this module did before) is too late for the flag to take effect.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

import pickle

import numpy as np
import spacy
import tensorflow as tf

# Batch size passed to every Keras predict() call in BiLSTM_predict.
BATCH_SIZE = 512
|
9 |
+
|
10 |
+
def preprocess_text(text):
    """Preprocess the input text using SpaCy and return word indices.

    Args:
        text: Raw question text.

    Returns:
        List of integer indices from ``word_dict``, one per kept token, in
        token order. Tokens missing from ``word_dict`` map to 0 (OOV index).
    """
    word_seq = []
    for doc in nlp.pipe([text], n_process=1):
        for token in doc:
            # NOTE(review): the pipeline in this module is loaded with the
            # tagger disabled, so pos_ may never equal "PUNCT" — confirm this
            # filter matches how the training data was preprocessed.
            if token.pos_ != "PUNCT":
                # Read-only lookup: unknown words are no longer inserted into
                # the shared word_dict, which grew with every unseen token.
                word_seq.append(word_dict.get(token.text, 0))
    return word_seq
|
21 |
+
|
22 |
+
def BiLSTM_predict(text):
    """Return the weighted ensemble score of the four Keras models for ``text``.

    The text is converted to word indices, padded to length 55, and scored by
    ``model_1`` .. ``model_4`` with weights 0.15, 0.35, 0.15 and 0.35.
    """
    indices = preprocess_text(text)
    batch = tf.keras.preprocessing.sequence.pad_sequences([indices], maxlen=55)

    ensemble = (
        (0.15, model_1),
        (0.35, model_2),
        (0.15, model_3),
        (0.35, model_4),
    )
    weighted = [
        weight * np.squeeze(net.predict(batch, batch_size=BATCH_SIZE, verbose=2))
        for weight, net in ensemble
    ]

    # Add in the same left-to-right order as the original expression.
    pred = weighted[0]
    for part in weighted[1:]:
        pred = pred + part
    return pred
|
32 |
+
|
33 |
+
|
34 |
+
import subprocess
import sys

# The four Keras models combined by BiLSTM_predict.
model_1 = tf.keras.models.load_model("./lstm/model_1.h5")
model_2 = tf.keras.models.load_model("./lstm/model_2.h5")
model_3 = tf.keras.models.load_model("./lstm/model_3.h5")
model_4 = tf.keras.models.load_model("./lstm/model_4.h5")

# NOTE(security): pickle.load can execute arbitrary code from the file; this is
# only acceptable because word_dict.pkl ships with the repository. Never point
# this at user-supplied data.
with open('./lstm/word_dict.pkl', 'rb') as f:
    word_dict = pickle.load(f)

# Fetch the spaCy model only when it is missing, using the interpreter that is
# running this app (not whatever "python" resolves to), without a shell, and
# failing immediately if the download does not succeed.
if not spacy.util.is_package("en_core_web_lg"):
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", "en_core_web_lg"],
        check=True,
    )

nlp = spacy.load('en_core_web_lg', disable=['parser', 'ner', 'tagger'])
nlp.vocab.add_flag(lambda s: s.lower() in spacy.lang.en.stop_words.STOP_WORDS, spacy.attrs.IS_STOP)
|
lstm/model_1.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e36dfda896de06192843447fdf71b4bc5a72f46a4fc788dfb080a767af6b974c
|
3 |
+
size 749650112
|
lstm/model_2.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a1d0c6ce351d7ba21b6ae5392768abf2ca44bfe22261d5d0a54109dedb6ed6c3
|
3 |
+
size 749650112
|
lstm/model_3.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0a96ee5741ecf16149d3ca66e82634a5c46b42e42d939321bd1468f856c00d90
|
3 |
+
size 749650016
|
lstm/model_4.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0b579d0565f632ae2a1fb53e02e0d0f452d85db7a7238bcff445644cde92b9c4
|
3 |
+
size 749650016
|
lstm/word_dict.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1abb8d9104762746b16fa989592b247332fb563b1c8be89edc2829c4d2aec513
|
3 |
+
size 4555634
|