File size: 2,019 Bytes
78f91a7 0d3225e 78f91a7 049c790 51f97f8 3de3a79 5ca6967 51f97f8 3de3a79 02af193 3de3a79 fc3e0b1 3de3a79 795b9e0 3de3a79 fc3e0b1 3de3a79 fc3e0b1 3de3a79 02af193 78f91a7 8cb1d65 78f91a7 3de3a79 2460eee 30c461d 78f91a7 7d57f72 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import streamlit as st
import torch
@st.cache(allow_output_mutation=True)
def Model():
from transformers import DebertaTokenizer, DebertaForSequenceClassification
tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
model = DebertaForSequenceClassification.from_pretrained("microsoft/deberta-base", num_labels=8)
bn_state_dict = torch.load('model_weights.pt', map_location=torch.device('cpu'))
model.load_state_dict(bn_state_dict)
return model, tokenizer
def Predict(model, tokenizer, text):
res = tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=512)
res = model(**res)
logits = res.logits.softmax(dim=1)
logits = logits.detach().numpy()[0]#.cpu().detach().numpy()[0]
return logits
def Print(logits, dictionary):
z = zip(logits, np.arange(0, 8))
z = sorted(z, key=lambda x: x[0], reverse=True)
summ, idx = 0, 0
while summ < 0.95:
st.markdown(f"{idx + 1}. ", dictionary[z[idx][1]])
summ += z[idx][0]
idx += 1
def filter(title, abstract):
return True
st.title('Классификация статьи по названию и описанию')
# ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
title = st.text_area("Введите название статьи:")
abstract = st.text_area("Введите описание статьи:")
# ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
text = title + '. ' + abstract
dictionary = ['computer science', 'economics', 'Electrical Engineering and Systems Science',
'math', 'physics', 'quantitative biology', 'quantitative finance',
'statistics']
if filter(title, abstract):
model, tokenizer = Model()
logits = Predict(model, tokenizer, text)
Print(logits, dictionary)
st.markdown(f"{model}")
|