RayCappola committed on
Commit
28f8152
1 Parent(s): f3f7dff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -2
app.py CHANGED
@@ -1,4 +1,60 @@
1
  import streamlit as st
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import torch
4
+ import torch.nn as nn
5
 
6
def get_hidden_states(encoded, model, layers=(-4, -3, -2, -1)):
    """Return one pooled embedding vector for an encoded text.

    The input IDs are fed to ``model`` both as the regular inputs (via
    ``**encoded``) and as ``decoder_input_ids``. The decoder hidden states
    at the indices in ``layers`` are stacked and summed, and the result is
    averaged over all token positions.

    Args:
        encoded: Mapping produced by the tokenizer; must contain
            ``'input_ids'`` and is unpacked as keyword arguments to ``model``.
        model: Callable whose output exposes ``'decoder_hidden_states'``
            when called with ``output_hidden_states=True``.
        layers: Indices into the decoder hidden states to sum. Defaults to
            the last four.

    Returns:
        A tensor holding the mean over token positions of the summed layers.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(decoder_input_ids=encoded['input_ids'], output_hidden_states=True, **encoded)

    states = output['decoder_hidden_states']
    summed = torch.stack([states[i] for i in layers]).sum(0)

    # Drop only the leading (batch) axis. A bare ``squeeze()`` would also
    # remove the token axis when the input has a single token, and the mean
    # below would then collapse the embedding to a scalar.
    # NOTE(review): assumes a batch of exactly one text -- confirm with callers.
    token_vectors = summed.squeeze(0)

    return token_vectors.mean(dim=0)
18
+
19
def get_word_vector(sent, tokenizer, model):
    """Tokenize ``sent`` and return its pooled embedding.

    Args:
        sent: The text to embed.
        tokenizer: Tokenizer callable that returns a mapping of PyTorch
            tensors when called with ``return_tensors="pt"``.
        model: Model passed through to ``get_hidden_states``.

    Returns:
        Whatever ``get_hidden_states`` returns for the encoded text.
    """
    # ``truncation=True`` caps the input at the tokenizer's maximum length,
    # so a long title + abstract does not fail inside the model.
    encoded = tokenizer(sent, return_tensors="pt", truncation=True)

    return get_hidden_states(encoded, model)
23
+
24
# --- Classifier head ---------------------------------------------------------
# NOTE(review): `Net` is neither defined nor imported in the visible part of
# this file; unless it is provided elsewhere this line raises NameError.
model = Net()
# torch.load unpickles the checkpoint, so only load files you trust.
model.load_state_dict(torch.load('dummy_model.txt', map_location=torch.device('cpu')))
model.eval()

# 1-based class index -> human-readable category name.
labels_articles = {
    1: 'Computer Science',
    2: 'Economics',
    3: "Electrical Engineering And Systems Science",
    4: "Mathematics",
    5: "Physics",
    6: "Quantitative Biology",
    7: "Quantitative Finance",
    8: "Statistics",
}

# --- Embedding model ---------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained("Callidior/bert2bert-base-arxiv-titlegen")
model_emb = AutoModelForSeq2SeqLM.from_pretrained("Callidior/bert2bert-base-arxiv-titlegen")

# --- UI ----------------------------------------------------------------------
title = st.text_area("Write title of your article")
summary = st.text_area("Write summary of your article or dont write anything (but you should press Ctrl + Enter)")

# Only classify once the user has typed something; otherwise the model would
# be run on the placeholder text '. ' and show meaningless tags.
if title.strip() or summary.strip():
    text = title + '. ' + summary

    embed = get_word_vector(text, tokenizer, model_emb)

    # Despite the name, `logits` holds softmax probabilities over the classes.
    with torch.no_grad():
        logits = torch.nn.functional.softmax(model(embed), dim=0)

    best_tags = torch.argsort(logits, descending=True)

    # Take the most probable tags until their cumulative probability exceeds
    # 0.95. A plain float accumulator avoids shadowing the builtin `sum`.
    cumulative_prob = 0.0
    chosen = []
    for tag in best_tags:
        if cumulative_prob > 0.95:
            break
        idx = tag.item()
        cumulative_prob += logits[idx].item()
        # The classifier's classes are 0-based; the label table is 1-based.
        chosen.append(labels_articles[idx + 1])

    res = ''.join(name + '\n' for name in chosen)

    st.write('best tags = \n', res)