NithitEiEi committed on
Commit d8f4336 · verified · 1 Parent(s): 19fca01

upload model and app

app.py ADDED
@@ -0,0 +1,84 @@
+ import gradio as gr
+ from bert.model import BERT_predict
+ from lstm.model import BiLSTM_predict
+ from deberta.model import deBERTa_predict
+
+ examples = [
+     "How do I train my dogs to kill raccoons?",
+     "Why Indian education totally bullshit?",
+     "What is it really like to be a nurse practitioner",
+     "What is your prefer race? white, black or asian."
+ ]
+
+ def alert(prob):
+     raise gr.Error(f"This post may not be suitable for posting (insincere probability: {prob:.2f})")
+
+ def clear_post():
+     return ""
+
+ def add_post(posts, new_post, mode):
+     if mode == "BiLSTM":
+         # BiLSTM_predict returns a raw probability; apply the 0.35 threshold here.
+         prob = BiLSTM_predict(new_post)
+         insincere = 1 if prob > 0.35 else 0
+         if insincere:
+             alert(prob)
+             return posts, ""
+         else:
+             return posts + [{"post": new_post, "model": "BiLSTM", "prob": prob}], ""
+     elif mode == "BERT":
+         # BERT_predict returns (predicted class, insincere probability).
+         insincere, prob = BERT_predict(new_post)
+         if insincere:
+             alert(prob)
+             return posts, ""
+         else:
+             return posts + [{"post": new_post, "model": "BERT", "prob": prob}], ""
+     elif mode == "DeBERTaV3":
+         insincere, prob = deBERTa_predict(new_post)
+         if insincere:
+             alert(prob)
+             return posts, ""
+         else:
+             return posts + [{"post": new_post, "model": "DeBERTaV3", "prob": prob}], ""
+
+
+ with gr.Blocks(theme=gr.themes.Soft(), title="Quora Question post") as demo:
+
+     posts = gr.State([])
+
+     new_post = gr.Textbox(label="Add post", autofocus=True)
+     mode = gr.Radio(["BiLSTM", "BERT", "DeBERTaV3"], value="BiLSTM", label="Model")
+
+     with gr.Row():
+         submit = gr.Button("submit", variant='primary')
+         clear = gr.Button("clear")
+
+     submit.click(add_post, inputs=[posts, new_post, mode], outputs=[posts, new_post])
+     clear.click(clear_post, inputs=None, outputs=new_post)
+
+     @gr.render(inputs=posts)
+     def render_posts(post_list):
+         gr.Markdown(f"### Question post ({len(post_list)})")
+         for index, post in enumerate(post_list):
+             with gr.Row():
+                 gr.Textbox(
+                     f"{post['post']} | {post['prob']:.8f}",
+                     label=f"Post {index + 1} ({post['model']})",
+                     show_label=True
+                 )
+                 delete_btn = gr.Button("Delete", scale=0, variant="stop")
+
+                 # Default argument binds this row's post so each button deletes its own entry.
+                 def delete(post=post):
+                     post_list.remove(post)
+                     return post_list
+                 delete_btn.click(delete, None, [posts])
+
+     with gr.Row():
+         gr.Examples(
+             examples=examples,
+             inputs=[new_post],
+         )
+
+ demo.launch()
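For reference, the predict functions wired in above follow two slightly different contracts, which is why add_post treats BiLSTM separately. A minimal sketch of the expected return shapes (the example string is hypothetical, and the comments describe the contracts implied by add_post, not real model output):

    prob = BiLSTM_predict("Is water wet?")              # bare float in [0, 1]; caller applies the 0.35 cutoff
    insincere = 1 if prob > 0.35 else 0

    insincere, prob = BERT_predict("Is water wet?")      # (argmax class, P(class 1 = insincere))
    insincere, prob = deBERTa_predict("Is water wet?")   # same contract as BERT_predict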
bert/bert_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7d7423dcada747073f22d892f938150cfe737b5c4a46aabd334f20959d604db1
+ size 436484297
bert/model.py ADDED
@@ -0,0 +1,50 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+ class BertClassifier(nn.Module):
+     def __init__(self, bert):
+         super(BertClassifier, self).__init__()
+         self.bert = bert
+
+     def forward(self, input_id, attention_mask):
+         output = self.bert(input_ids=input_id, attention_mask=attention_mask)
+         return output.logits
+
+ tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
+
+ bert = AutoModelForSequenceClassification.from_pretrained('bert-base-cased')
+
+ # Swap in the custom MLP head the checkpoint was trained with.
+ classifier = nn.Sequential(
+     nn.Linear(768, 1024),
+     nn.ReLU(),
+     nn.Dropout(0.5),
+     nn.Linear(1024, 2)
+ )
+
+ bert.classifier = classifier
+
+ model = BertClassifier(bert)
+
+ model.load_state_dict(torch.load("./bert/bert_model.pth", map_location=torch.device('cpu'), weights_only=True))
+
+ model.eval()
+
+ def BERT_predict(text):
+     tokenized_input = tokenizer(text,
+                                 padding="max_length",
+                                 truncation=True,
+                                 max_length=30,
+                                 return_tensors="pt")
+
+     with torch.no_grad():
+         logits = model(tokenized_input['input_ids'], tokenized_input['attention_mask'])
+
+     probabilities = F.softmax(logits, dim=-1)
+
+     # Class 1 is "insincere"; return its probability alongside the argmax label.
+     prediction = torch.argmax(probabilities, dim=-1).item()
+
+     return prediction, probabilities[0][1].item()
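As a quick smoke test, this module can be exercised on its own; a hedged sketch, assuming it is run from the repository root so the ./bert/bert_model.pth path resolves (the question text is hypothetical):

    from bert.model import BERT_predict

    label, prob = BERT_predict("How do I start learning Python?")
    print(label, round(prob, 4))  # 0 (sincere) or 1 (insincere), plus P(insincere)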
deberta/fastai_QIQC-deberta-v3.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0d2b64c078a7de8ffe57b1ff767e3bfac6bdb52bb3f5977b5e06f3ce9993b873
+ size 740942321
deberta/model.py ADDED
@@ -0,0 +1,54 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+ class BertClassifier(nn.Module):
+     def __init__(self, bert):
+         super(BertClassifier, self).__init__()
+         self.bert = bert
+
+     def forward(self, input_id, attention_mask):
+         output = self.bert(input_ids=input_id, attention_mask=attention_mask)
+         return output.logits
+
+ tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-base')
+
+ bert = AutoModelForSequenceClassification.from_pretrained('microsoft/deberta-v3-base')
+
+ # Same custom MLP head as bert/model.py.
+ classifier = nn.Sequential(
+     nn.Linear(768, 1024),
+     nn.ReLU(),
+     nn.Dropout(0.5),
+     nn.Linear(1024, 2)
+ )
+
+ bert.classifier = classifier
+
+ model = BertClassifier(bert)
+ state_dict = torch.load(
+     "./deberta/fastai_QIQC-deberta-v3.pth", map_location=torch.device('cpu'),
+     weights_only=True
+ )
+
+ # strict=False tolerates key mismatches between the saved checkpoint and this module.
+ model.load_state_dict(state_dict, strict=False)
+
+ model.eval()
+
+ def deBERTa_predict(text):
+     tokenized_input = tokenizer(text,
+                                 padding="max_length",
+                                 truncation=True,
+                                 max_length=30,
+                                 return_tensors="pt")
+
+     with torch.no_grad():
+         logits = model(tokenized_input['input_ids'], tokenized_input['attention_mask'])
+
+     probabilities = F.softmax(logits, dim=-1)
+
+     prediction = torch.argmax(probabilities, dim=-1).item()
+
+     return prediction, probabilities[0][1].item()
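Because the checkpoint is loaded with strict=False, key mismatches do not raise; a small sketch for auditing what actually loaded, using only the torch API already in use above:

    # With strict=False, load_state_dict reports rather than raises on mismatches.
    result = model.load_state_dict(state_dict, strict=False)
    print("missing keys:", result.missing_keys)
    print("unexpected keys:", result.unexpected_keys)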
lstm/model.py ADDED
@@ -0,0 +1,47 @@
+ import os
+ import spacy
+ import pickle
+ import numpy as np
+ import tensorflow as tf
+
+ BATCH_SIZE = 512
+ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
+
+ def preprocess_text(text):
+     """Preprocess the input text with spaCy and return word indices."""
+     docs = nlp.pipe([text], n_process=1)
+     word_seq = []
+     for doc in docs:
+         for token in doc:
+             if token.pos_ != "PUNCT":
+                 if token.text not in word_dict:
+                     word_dict[token.text] = 0  # OOV index
+                 word_seq.append(word_dict[token.text])
+     return word_seq
+
+ def BiLSTM_predict(text):
+     seq = preprocess_text(text)
+     padded_seq = tf.keras.preprocessing.sequence.pad_sequences([seq], maxlen=55)
+     # Weighted average of four checkpoints; the weights sum to 1.0.
+     pred1 = 0.15 * np.squeeze(model_1.predict(padded_seq, batch_size=BATCH_SIZE, verbose=2))
+     pred2 = 0.35 * np.squeeze(model_2.predict(padded_seq, batch_size=BATCH_SIZE, verbose=2))
+     pred3 = 0.15 * np.squeeze(model_3.predict(padded_seq, batch_size=BATCH_SIZE, verbose=2))
+     pred4 = 0.35 * np.squeeze(model_4.predict(padded_seq, batch_size=BATCH_SIZE, verbose=2))
+     pred = pred1 + pred2 + pred3 + pred4
+
+     return pred
+
+
+ model_1 = tf.keras.models.load_model("./lstm/model_1.h5")
+ model_2 = tf.keras.models.load_model("./lstm/model_2.h5")
+ model_3 = tf.keras.models.load_model("./lstm/model_3.h5")
+ model_4 = tf.keras.models.load_model("./lstm/model_4.h5")
+
+ with open('./lstm/word_dict.pkl', 'rb') as f:
+     word_dict = pickle.load(f)
+
+ # Fetch the spaCy model at startup if it is not already installed.
+ os.system("python -m spacy download en_core_web_lg")
+
+ nlp = spacy.load('en_core_web_lg', disable=['parser', 'ner', 'tagger'])
+ nlp.vocab.add_flag(lambda s: s.lower() in spacy.lang.en.stop_words.STOP_WORDS, spacy.attrs.IS_STOP)
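Unlike the transformer modules, BiLSTM_predict returns only the ensembled probability, so callers apply the decision threshold themselves; a short usage sketch mirroring the 0.35 cutoff in app.py (the input string is taken from the examples list there):

    from lstm.model import BiLSTM_predict

    prob = float(BiLSTM_predict("What is it really like to be a nurse practitioner"))
    print("insincere" if prob > 0.35 else "sincere", round(prob, 4))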
lstm/model_1.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e36dfda896de06192843447fdf71b4bc5a72f46a4fc788dfb080a767af6b974c
+ size 749650112
lstm/model_2.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a1d0c6ce351d7ba21b6ae5392768abf2ca44bfe22261d5d0a54109dedb6ed6c3
+ size 749650112
lstm/model_3.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0a96ee5741ecf16149d3ca66e82634a5c46b42e42d939321bd1468f856c00d90
+ size 749650016
lstm/model_4.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b579d0565f632ae2a1fb53e02e0d0f452d85db7a7238bcff445644cde92b9c4
+ size 749650016
lstm/word_dict.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1abb8d9104762746b16fa989592b247332fb563b1c8be89edc2829c4d2aec513
+ size 4555634