ThePixOne commited on
Commit
7d71416
·
1 Parent(s): 0571512

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -0
app.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import time
4
+ import hashlib
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModel, AutoModelForQuestionAnswering, pipeline
7
+ from tqdm import tqdm
8
+ import os
9
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
10
+ import textract
11
+ from scipy.special import softmax
12
+ import pandas as pd
13
+ from datetime import datetime
14
+ tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/multi-qa-mpnet-base-dot-v1")
15
+ model = AutoModel.from_pretrained("sentence-transformers/multi-qa-mpnet-base-dot-v1").to(device).eval()
16
+ tokenizer_ans = AutoTokenizer.from_pretrained("deepset/roberta-large-squad2")
17
+ model_ans = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-large-squad2").to(device).eval()
18
+ if device == 'cuda:0':
19
+ pipe = pipeline("question-answering",model_ans,tokenizer =tokenizer_ans,device = 0)
20
+ else:
21
+ pipe = pipeline("question-answering",model_ans,tokenizer =tokenizer_ans)
22
+
23
+ def cls_pooling(model_output):
24
+ return model_output.last_hidden_state[:,0]
25
+
26
+ def encode_query(query):
27
+ encoded_input = tokenizer(query, truncation=True, return_tensors='pt').to(device)
28
+
29
+ with torch.no_grad():
30
+ model_output = model(**encoded_input, return_dict=True)
31
+
32
+ embeddings = cls_pooling(model_output)
33
+
34
+ return embeddings.cpu()
35
+
36
+
37
+ def encode_docs(docs,maxlen = 64, stride = 32):
38
+ encoded_input = []
39
+ embeddings = []
40
+ spans = []
41
+ file_names = []
42
+ name, text = docs
43
+
44
+ text = text.split(" ")
45
+ if len(text) < maxlen:
46
+ text = " ".join(text)
47
+
48
+ encoded_input.append(tokenizer(temp_text, return_tensors='pt', truncation = True).to(device))
49
+ spans.append(temp_text)
50
+ file_names.append(name)
51
+
52
+ else:
53
+ num_iters = int(len(text)/maxlen)+1
54
+ for i in range(num_iters):
55
+ if i == 0:
56
+ temp_text = " ".join(text[i*maxlen:(i+1)*maxlen+stride])
57
+ else:
58
+ temp_text = " ".join(text[(i-1)*maxlen:(i)*maxlen][-stride:] + text[i*maxlen:(i+1)*maxlen])
59
+
60
+ encoded_input.append(tokenizer(temp_text, return_tensors='pt', truncation = True).to(device))
61
+ spans.append(temp_text)
62
+ file_names.append(name)
63
+
64
+ with torch.no_grad():
65
+ for encoded in tqdm(encoded_input):
66
+ model_output = model(**encoded, return_dict=True)
67
+ embeddings.append(cls_pooling(model_output))
68
+
69
+ embeddings = np.float32(torch.stack(embeddings).transpose(0, 1).cpu())
70
+
71
+ np.save("encoded_gradio/emb_{}.npy".format(name),dict(zip(list(range(len(embeddings))),embeddings)))
72
+ np.save("encoded_gradio/spans_{}.npy".format(name),dict(zip(list(range(len(spans))),spans)))
73
+ np.save("encoded_gradio/file_{}.npy".format(name),dict(zip(list(range(len(file_names))),file_names)))
74
+
75
+ return embeddings, spans, file_names
76
+
77
+ def predict(query,data):
78
+ name_to_save = data.name.split("\\")[-1].split(".")[0][:-8]
79
+ st = str([query,name_to_save])
80
+ hist = st + " " + str(hashlib.sha256(st.encode()).hexdigest())
81
+ now = datetime.now()
82
+ current_time = now.strftime("%H:%M:%S")
83
+ try:
84
+ df = pd.read_csv("HISTORY/{}.csv".format(hash(st)))
85
+ return df
86
+ except Exception as e:
87
+ print(e)
88
+ print(st)
89
+
90
+ if name_to_save+".txt" in os.listdir("text_gradio"):
91
+ doc_emb = np.load('encoded_gradio/emb_{}.npy'.format(name_to_save),allow_pickle='TRUE').item()
92
+ doc_text = np.load('encoded_gradio/spans_{}.npy'.format(name_to_save),allow_pickle='TRUE').item()
93
+ file_names_dicto = np.load('encoded_gradio/file_{}.npy'.format(name_to_save),allow_pickle='TRUE').item()
94
+
95
+ doc_emb = np.array(list(doc_emb.values())).reshape(-1,768)
96
+ doc_text = list(doc_text.values())
97
+ file_names = list(file_names_dicto.values())
98
+
99
+ else:
100
+ text = textract.process("{}".format(data.name)).decode('utf8')
101
+ text = text.replace("\r", " ")
102
+ text = text.replace("\n", " ")
103
+ text = text.replace(" . "," ")
104
+
105
+ doc_emb, doc_text, file_names = encode_docs((name_to_save,text),maxlen = 64, stride = 32)
106
+
107
+ doc_emb = doc_emb.reshape(-1, 768)
108
+ with open("text_gradio/{}.txt".format(name_to_save),"w",encoding="utf-8") as f:
109
+ f.write(text)
110
+ start = time.time()
111
+ query_emb = encode_query(query)
112
+
113
+ scores = np.matmul(query_emb, doc_emb.transpose(1,0))[0].tolist()
114
+ doc_score_pairs = list(zip(doc_text, scores, file_names))
115
+ doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
116
+ k = 5
117
+ probs_sum = 0
118
+ probs = softmax(sorted(scores,reverse = True)[:k])
119
+ table = {"Passage":[],"Answer":[],"Probabilities":[],"Source":[]}
120
+
121
+ for i, (passage, _, names) in enumerate(doc_score_pairs[:k]):
122
+ passage = passage.replace("\n","")
123
+ passage = passage.replace(" . "," ")
124
+
125
+ if probs[i] > 0.1 or (i < 3 and probs[i] > 0.05): #generate answers for more likely passages but no less than 2
126
+ QA = {'question':query,'context':passage}
127
+ ans = pipe(QA)
128
+ probabilities = "P(a|p): {}, P(a|p,q): {}, P(p|q): {}".format(round(ans["score"],5),
129
+ round(ans["score"]*probs[i],5),
130
+ round(probs[i],5))
131
+ passage = passage.replace(str(ans["answer"]),str(ans["answer"]).upper())
132
+ table["Passage"].append(passage)
133
+ table["Passage"].append("---")
134
+ table["Answer"].append(str(ans["answer"]).upper())
135
+ table["Answer"].append("---")
136
+ table["Probabilities"].append(probabilities)
137
+ table["Probabilities"].append("---")
138
+ table["Source"].append(names)
139
+ table["Source"].append("---")
140
+ else:
141
+ table["Passage"].append(passage)
142
+ table["Passage"].append("---")
143
+ table["Answer"].append("no_answer_calculated")
144
+ table["Answer"].append("---")
145
+ table["Probabilities"].append("P(p|q): {}".format(round(probs[i],5)))
146
+ table["Probabilities"].append("---")
147
+ table["Source"].append(names)
148
+ table["Source"].append("---")
149
+ df = pd.DataFrame(table)
150
+ print("time: "+ str(time.time()-start))
151
+
152
+ with open("HISTORY.txt","a", encoding = "utf-8") as f:
153
+ f.write(hist)
154
+ f.write(" " + str(current_time))
155
+ f.write("\n")
156
+ f.close()
157
+ df.to_csv("HISTORY/{}.csv".format(hash(st)), index=False)
158
+
159
+ return df
160
+
161
+ iface = gr.Interface(
162
+
163
+ fn =predict,
164
+ inputs = [gr.inputs.Textbox(default="What is Open-domain question answering?"),
165
+ gr.inputs.Checkbox(default=True),
166
+ gr.inputs.File(),
167
+ ],
168
+ outputs = [
169
+ gr.outputs.Dataframe(),
170
+ ],
171
+
172
+ allow_flagging ="manual",flagging_options = ["correct","wrong"],
173
+ allow_screenshot=False)
174
+
175
+ iface.launch(share = True,enable_queue=True, show_error =True)