SauronLee committed on
Commit 586b853 · 1 Parent(s): 24882ff

Create new file

Files changed (1)
  1. app.py +100 -0
app.py ADDED
@@ -0,0 +1,100 @@
+ import gradio as gr
+ import gc, json, re
+ import torch
+ import numpy as np
+ from data import MyDataset, my_collate_fn
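+ 
+ # Gradio demo: given a free-text description of a paper, this app retrieves
+ # the ten most relevant NLP paper titles with a pretrained retrieval model.
+ 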
+ def clean_str(string, use=True):
+     """
+     Tokenization/string cleaning for all datasets except for SST.
+     Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
+     """
+     if not use: return string
+ 
+     string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
+     string = re.sub(r"\'s", " \'s", string)
+     string = re.sub(r"\'ve", " \'ve", string)
+     string = re.sub(r"n\'t", " n\'t", string)
+     string = re.sub(r"\'re", " \'re", string)
+     string = re.sub(r"\'d", " \'d", string)
+     string = re.sub(r"\'ll", " \'ll", string)
+     string = re.sub(r",", " , ", string)
+     string = re.sub(r"!", " ! ", string)
+     string = re.sub(r"\(", r" \( ", string)
+     string = re.sub(r"\)", r" \) ", string)
+     string = re.sub(r"\?", r" \? ", string)
+     string = re.sub(r"\s{2,}", " ", string)
+     return string.strip().lower()
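+ 
+ # Build lookup tables from the embedding file: ids 0 and 1 are reserved for
+ # <PAD> and <OOV>, and every vocabulary entry gets a row in word2vec.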
+ title_list = np.load("./title_list.npy", allow_pickle=True).tolist()
+ # NOTE: assumed path, relative to app.py; adjust if the embeddings live elsewhere.
+ vec_inuse = json.load(open("./data/papers_embedding_inuse.json"))
+ vocab = list(vec_inuse)
+ vocab_size = len(vocab) + 2
+ word2index = dict()
+ index2word = list()
+ word2index['<PAD>'] = 0
+ word2index['<OOV>'] = 1
+ index2word.extend(['<PAD>', '<OOV>'])
+ index2word.extend(vocab)
+ word2vec = np.zeros((vocab_size, len(list(vec_inuse.values())[0])), dtype=np.float32)
+ for wd in vocab:
+     index = len(word2index)
+     word2index[wd] = index
+     word2vec[index, :] = vec_inuse[wd]
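+ 
+ # Convert raw {'answer', 'question'} instances into index lists; words not in
+ # the vocabulary (or equal to the answer id) map to <OOV>.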
+ def data2index(data_x, word2index):
+     data_x_idx = list()
+     for instance in data_x:
+         def_word_idx = list()
+         def_words = clean_str(instance['question']).strip().split()
+ 
+         for def_word in def_words:
+             if def_word in word2index and def_word != instance['answer']:
+                 def_word_idx.append(word2index[def_word])
+             else:
+                 def_word_idx.append(word2index['<OOV>'])
+         data_x_idx.append({'answer': word2index[instance['answer']], 'question_words': def_word_idx})
+ 
+     return data_x_idx
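+ 
+ # Inference entry point wired to the Gradio button below: encodes the query
+ # and returns the titles of the ten best-ranked papers.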
+ def greet(paper_str):
+     # Load the trained retrieval model onto the CPU.
+     model = torch.load("saved.model", map_location=torch.device('cpu'))
+     model.eval()
+ 
+     # Wrap the single query as one instance. 'p-7241' is a placeholder answer
+     # id that data2index requires; it only has to exist in the vocabulary.
+     test_dataset = MyDataset(data2index(
+         [
+             {
+                 'answer': 'p-7241',
+                 'question': paper_str
+             }
+         ], word2index))
+     test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=my_collate_fn)
+ 
+     predicted = []
+     with torch.no_grad():
+         for words_t, definition_words_t in test_dataloader:
+             indices = model('test', x=definition_words_t, w=words_t, mode="b")
+             predicted = indices[:, :10].detach().cpu().numpy().tolist()
+             predicted = [index2word[paper] for paper in predicted[0]]
+     gc.collect()
+ 
+     # Map predicted ids of the form 'p-<n>' back to their titles.
+     papers_output = []
+     for paper_i in predicted:
+         paper_i = int(paper_i.split('-')[1])
+         papers_output.append(title_list[paper_i])
+     return papers_output
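+ 
+ # Minimal Blocks UI: a question box, an output box, and a submit button.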
+ with gr.Blocks() as demo:
+     name = gr.Textbox(label="Question")
+     output = gr.Textbox(label="Papers")
+     greet_btn = gr.Button("Submit")
+     greet_btn.click(fn=greet, inputs=name, outputs=output)
+ 
+ demo.launch()