dsorokin committed on
Commit
2fc0e56
·
1 Parent(s): 75c83d9
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. app.py +45 -6
  3. imgs/akinator_ready.png +0 -0
  4. requirements.txt +3 -2
.gitattributes CHANGED
@@ -31,3 +31,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
31
  *.zip filter=lfs diff=lfs merge=lfs -text
32
  *.zst filter=lfs diff=lfs merge=lfs -text
33
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
31
  *.zip filter=lfs diff=lfs merge=lfs -text
32
  *.zst filter=lfs diff=lfs merge=lfs -text
33
  *tfevents* filter=lfs diff=lfs merge=lfs -text
34
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,13 +1,52 @@
1
  import streamlit as st
 
 
 
 
 
2
 
3
- st.markdown("### Hello, world!")
4
- st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
5
 
6
- from transformers import pipeline
 
7
 
8
- pipe = pipeline("ner", "Davlan/distilbert-base-multilingual-cased-ner-hrl")
 
 
9
 
10
- text = st.text_area("TEXT HERE")
 
11
 
12
- st.markdown(f"{pipe(text)}")
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ from torch.nn import functional as F
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ import json
6
+ import streamlit.components.v1 as components
7
 
 
 
8
 
9
def rank_predictions(logits, id2label):
    """Return ``(label, probability)`` pairs ordered from most to least likely.

    Args:
        logits: Tensor of shape ``(1, num_labels)`` holding the raw classifier
            scores for a single input.
        id2label: Mapping from integer class index to label name.

    Returns:
        A list of ``(label, probability)`` tuples sorted by descending
        probability; the probabilities are plain Python floats.
    """
    # Softmax over the class axis, then drop the batch dimension.
    probs = F.softmax(logits, dim=-1)[0]
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    return [
        (id2label[idx], prob)
        for idx, prob in zip(sorted_ids.tolist(), sorted_probs.tolist())
    ]


if __name__ == '__main__':
    st.markdown("### Arxiv paper classifier (No guarantees provided)")

    col1, col2 = st.columns([1, 1])
    col1.image('imgs/akinator_ready.png', width=200)
    btn = col2.button('Classify!')

    # NOTE(review): Streamlit re-runs this script on every interaction, so the
    # model and tokenizer are loaded again each time; consider wrapping the
    # loading in the caching decorator offered by the installed Streamlit
    # version.
    model = AutoModelForSequenceClassification.from_pretrained('checkpoint-3000')
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    # The loaded config already exposes id2label with integer keys, so there
    # is no need to re-read and convert checkpoint-3000/config.json by hand.
    id2label = model.config.id2label

    # NOTE(review): `height` is given in pixels according to the Streamlit
    # docs; 3 and 10 look like they were meant as row counts - confirm.
    title = st.text_area(label='Input title...', placeholder='Input title...',
                         label_visibility='hidden', height=3)
    abstract = st.text_area(label='Input abstract...', placeholder='Input abstract...',
                            label_visibility='hidden', height=10)
    text = '\n'.join([title, abstract])

    if btn:
        if not text.strip():
            # Also catches whitespace-only input, not just two empty fields.
            st.error('Title and abstract are empty!')
        else:
            # Truncate to the tokenizer's maximum length so that long
            # abstracts do not overflow the model's position embeddings.
            tokenized = tokenizer(text, truncation=True, return_tensors='pt')

            with torch.no_grad():
                # Passing the whole encoding also supplies the attention mask.
                logits = model(**tokenized)['logits']

            result = [
                f'{label} (prob = {prob})'
                for label, prob in rank_predictions(logits, id2label)
            ]
            output = '<br>'.join(result)

            components.html(
                '<div>'
                '<div style="height:120px;width:680px;'
                'border:1px solid #ccc;border-color: red;'
                'font:16px/26px Georgia, Garamond, Serif;'
                'overflow:scroll;'
                'color:white;">'
                f'{output}</div>'
                '</div>'  # close the outer wrapper the original left open
            )
imgs/akinator_ready.png ADDED
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
- transformers
2
- torch
 
3
 
 
1
+ transformers==4.15.0
2
+ torch==1.12.1
3
+
4