File size: 3,990 Bytes
1a0448c
 
 
 
 
 
 
 
 
96f8331
1a0448c
 
 
7d1dc00
 
1a0448c
7f9da5e
 
1a0448c
3ddfa0c
b65b827
 
 
 
 
 
 
 
 
24a231f
96f8331
 
1a0448c
96f8331
 
 
24a231f
1409048
3ddfa0c
1409048
96f8331
3ddfa0c
 
 
 
 
1409048
96f8331
1409048
 
96f8331
3ddfa0c
1409048
3ddfa0c
1409048
 
 
96f8331
1409048
96f8331
1409048
96f8331
1409048
96f8331
1409048
96f8331
3ddfa0c
1409048
 
 
 
 
96f8331
1409048
3ddfa0c
1409048
3ddfa0c
1409048
 
96f8331
1409048
3ddfa0c
96f8331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
975e866
96f8331
 
1a0448c
 
96f8331
1a0448c
 
7d1dc00
1a0448c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import numpy as np
import torch
import streamlit as st
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from sklearn.preprocessing import LabelEncoder
from keras.utils import pad_sequences
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

st.markdown("### Paper category classification")
st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
# ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter

title = st.text_area("INPUT TITLE HERE")
abstract = st.text_area("INPUT ABSTRACT HERE")
# ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
if len(title) == 0 and len(abstract):
    st.markdown(f"Could you input paper title/abstrac :)")

@st.cache
def load_model_and_tokenizer():    
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab.
        num_labels = 44,)
    model.load_state_dict(torch.load("model_last_version.pt", map_location=torch.device('cpu')))
    return model, tokenizer

model, tokenizer = load_model_and_tokenizer()
MAX_LEN = 64
# Преобразуем название статьи в токены
tokens = tokenizer(title, padding=True, truncation=True, return_tensors="pt")

# Получаем предсказание модели для названия статьи и абстракта (если есть)
input_ids = tokens['input_ids']
attention_mask = tokens['attention_mask']
logits = model(input_ids, attention_mask)[0]

tags_names = ['Accelerator Physics',
 'adap-org',
 "adap-org",
 'Algebra-Geometry',
 'Astro-physics',
 "Astro-physics",
 'Chao-dynamics',
 'Chemistry-physics',
 'cmp-lg',
 "cmp-lg",
 'comp-gas',
 'cond-mat',
 "cond-mat",
 'Computer Science',
 'dg-ga',
 'Economics',
 'eess',
 'funct-an',
 'gr-qc',
 "gr-qc",
 'hep-ex',
 "hep-ex",
 'hep-lat',
 "hep-lat",
 'hep-ph',
 "hep-ph",
 'hep-th',
 "hep-th",
 'Math',
 'math-ph',
 'mtrl-th',
 'nlin',
 'nucl-ex',
 'nucl-th',
 "nucl-th",
 'patt-sol',
 'Physics',
 'q-alg',
 'Quantitie-biology',
 'q-fin',
 'quant-ph',
 "quant-ph",
 'solv-int',
 'Statistics']

if abstract:
    abstract_tokens = tokenizer(abstract, padding=True, truncation=True, return_tensors="pt")
    abstract_input_ids = abstract_tokens['input_ids']
    abstract_attention_mask = abstract_tokens['attention_mask']
    abstract_logits = model(abstract_input_ids, abstract_attention_mask)[0]
    logits += abstract_logits

# Получаем вероятности и сортируем их в порядке убывания
probs = torch.softmax(logits, dim=-1).squeeze()
sorted_probs, sorted_indices = torch.sort(probs, descending=True)

# Считаем сумму вероятностей
sum_probs = 0.0
top_classes = []
for i in range(len(sorted_probs)):
    sum_probs += sorted_probs[i]
    if sum_probs > 0.95 or sorted_probs[i] < 0.001:
        break
    top_classes.append((tags_names[sorted_indices[i].item()], sorted_probs[i].item()))

# Выводим список тем с их вероятностями
# from transformers import pipeline
# pipe = pipeline("ner", "Davlan/distilbert-base-multilingual-cased-ner-hrl")
raw_predictions = top_classes#le.inverse_transform(prediction)#pipe(text)
# тут уже знакомый вам код с huggingface.transformers -- его можно заменить на что угодно от fairseq до catboost

st.markdown(f"Possible categories with their probabilities for this paper : {raw_predictions}")
# выводим результаты модели в текстовое поле, на потеху пользователю