Spaces:
Runtime error
Runtime error
Create new file
Browse files
app.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import argparse, torch, gc, os, random, json
|
3 |
+
from data import device
|
4 |
+
import numpy as np
|
5 |
+
from data import MyDataset, load_data, my_collate_fn, device
|
6 |
+
import re
|
7 |
+
def clean_str(string,use=True):
|
8 |
+
"""
|
9 |
+
Tokenization/string cleaning for all datasets except for SST.
|
10 |
+
Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
|
11 |
+
"""
|
12 |
+
if not use: return string
|
13 |
+
|
14 |
+
string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
|
15 |
+
string = re.sub(r"\'s", " \'s", string)
|
16 |
+
string = re.sub(r"\'ve", " \'ve", string)
|
17 |
+
string = re.sub(r"n\'t", " n\'t", string)
|
18 |
+
string = re.sub(r"\'re", " \'re", string)
|
19 |
+
string = re.sub(r"\'d", " \'d", string)
|
20 |
+
string = re.sub(r"\'ll", " \'ll", string)
|
21 |
+
string = re.sub(r",", " , ", string)
|
22 |
+
string = re.sub(r"!", " ! ", string)
|
23 |
+
string = re.sub(r"\(", " \( ", string)
|
24 |
+
string = re.sub(r"\)", " \) ", string)
|
25 |
+
string = re.sub(r"\?", " \? ", string)
|
26 |
+
string = re.sub(r"\s{2,}", " ", string)
|
27 |
+
return string.strip().lower()
|
28 |
+
|
29 |
+
title_list = np.load("./title_list.npy", allow_pickle=True).tolist()
|
30 |
+
data_path = os.path.join('..', 'data')
|
31 |
+
device = 'cpu'
|
32 |
+
vec_inuse = json.load(open('/Users/sauron/Desktop/Code/Finding-NLP-Papers/finding_papers_gradio/data/papers_embedding_inuse.json'))
|
33 |
+
vocab = list(vec_inuse)
|
34 |
+
vocab_size = len(vocab) + 2
|
35 |
+
word2index = dict()
|
36 |
+
index2word = list()
|
37 |
+
word2index['<PAD>'] = 0
|
38 |
+
word2index['<OOV>'] = 1
|
39 |
+
index2word.extend(['<PAD>', '<OOV>'])
|
40 |
+
index2word.extend(vocab)
|
41 |
+
word2vec = np.zeros((vocab_size, len(list(vec_inuse.values())[0])), dtype=np.float32)
|
42 |
+
for wd in vocab:
|
43 |
+
index = len(word2index)
|
44 |
+
word2index[wd] = index
|
45 |
+
word2vec[index, :] = vec_inuse[wd]
|
46 |
+
|
47 |
+
def data2index(data_x, word2index):
|
48 |
+
data_x_idx = list()
|
49 |
+
for instance in data_x:
|
50 |
+
def_word_idx = list()
|
51 |
+
def_words = clean_str(instance['question']).strip().split()
|
52 |
+
|
53 |
+
for def_word in def_words:
|
54 |
+
if def_word in word2index and def_word!=instance['answer']:
|
55 |
+
def_word_idx.append(word2index[def_word])
|
56 |
+
else:
|
57 |
+
def_word_idx.append(word2index['<OOV>'])
|
58 |
+
data_x_idx.append({'answer': word2index[instance['answer']], 'question_words': def_word_idx})
|
59 |
+
|
60 |
+
return data_x_idx
|
61 |
+
|
62 |
+
|
63 |
+
def greet(paper_str):
|
64 |
+
pred_list = []
|
65 |
+
|
66 |
+
model = torch.load("saved.model",map_location = torch.device('cpu'))
|
67 |
+
model.eval()
|
68 |
+
|
69 |
+
|
70 |
+
test_dataset = MyDataset(data2index(
|
71 |
+
[
|
72 |
+
{
|
73 |
+
'answer':'p-7241',
|
74 |
+
'question': paper_str
|
75 |
+
}
|
76 |
+
], word2index))
|
77 |
+
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=my_collate_fn)
|
78 |
+
|
79 |
+
for words_t, definition_words_t in test_dataloader:
|
80 |
+
indices = model('test', x=definition_words_t, w=words_t, mode="b")
|
81 |
+
predicted = indices[:, :10].detach().cpu().numpy().tolist()
|
82 |
+
predicted = [index2word[paper] for paper in predicted[0]]
|
83 |
+
|
84 |
+
del pred_list
|
85 |
+
gc.collect()
|
86 |
+
|
87 |
+
papers_output = []
|
88 |
+
for paper_i in predicted:
|
89 |
+
paper_i = int(paper_i.split('-')[1])
|
90 |
+
papers_output.append(title_list[paper_i])
|
91 |
+
return papers_output
|
92 |
+
|
93 |
+
|
94 |
+
with gr.Blocks() as demo:
|
95 |
+
name = gr.Textbox(label="Question")
|
96 |
+
output = gr.Textbox(label="Papers")
|
97 |
+
greet_btn = gr.Button("Submit")
|
98 |
+
greet_btn.click(fn=greet, inputs=name, outputs=output)
|
99 |
+
|
100 |
+
demo.launch()
|