Spaces:
Runtime error
Runtime error
import transformers | |
import torch | |
import streamlit as st | |
from transformers import BertTokenizer | |
# Page header: the app title followed by a decorative picture. The picture is
# raw HTML (to control its width), which is why unsafe_allow_html is needed.
st.markdown("### Из какой области статья")
link = (
    'https://www.clipartmax.com/png/middle/'
    '87-873210_akinator-with-transparent-background.png'
)
header_img_html = f"<img width=200px src='{link}'>"
st.markdown(header_img_html, unsafe_allow_html=True)

# Number of subject areas; used as the default output width of both
# classifier heads below.
num_classes = 8
class BERTClass(torch.nn.Module):
    """BERT encoder followed by a two-layer MLP head emitting one logit per class.

    Args:
        n_hid1: width of the hidden layer of the classification head.
        n_out: number of output logits (defaults to ``num_classes``).
        bert_path: identifier or path handed to ``BertModel.from_pretrained``.

    NOTE: the attribute names ``l1``..``l6`` are the keys of the saved
    state_dict loaded elsewhere, so they must not be renamed.
    """

    def __init__(self, n_hid1=1024, n_out=num_classes, bert_path='bert-base-uncased'):
        super().__init__()
        self.l1 = transformers.BertModel.from_pretrained(bert_path)
        self.l2 = torch.nn.Dropout(0.3)
        # Take the encoder width from its config instead of hard-coding 768
        # (the bert-base value), so other BERT sizes work unchanged.
        self.l3 = torch.nn.Linear(self.l1.config.hidden_size, n_hid1)
        self.l4 = torch.nn.ReLU()
        self.l5 = torch.nn.Dropout(0.2)
        self.l6 = torch.nn.Linear(n_hid1, n_out)

    def forward(self, ids, mask, token_type_ids):
        """Return raw logits, one per class (apply sigmoid for probabilities)."""
        out = self.l1(ids, attention_mask=mask, token_type_ids=token_type_ids)
        # Element 1 of the BertModel output is the pooled [CLS] representation.
        out = self.l2(out[1])
        out = self.l3(out)
        out = self.l4(out)
        out = self.l5(out)
        out = self.l6(out)
        return out
def load_bert():
    """Build the full-text BERT classifier and its tokenizer from local files.

    Returns:
        (model, tokenizer): a ``BERTClass`` with weights from
        ``bert_pretrained.pt`` switched to eval mode, and a ``BertTokenizer``
        loaded from ``bert_tokenizer``.
    """
    model = BERTClass(bert_path='bert_pretrained')
    # map_location='cpu' lets a checkpoint saved during a GPU session load on
    # a CPU-only host; without it torch.load raises when CUDA is unavailable.
    state_dict = torch.load('bert_pretrained.pt', map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout for inference
    tokenizer = BertTokenizer.from_pretrained('bert_tokenizer')
    return model, tokenizer
def apply_bert(text, model, tokenizer, max_len=200):
    """Classify ``text`` with the BERT model and return per-class probabilities.

    Args:
        text: input text (title and abstract).
        model: callable ``model(ids, mask, token_type_ids) -> logits``.
        tokenizer: tokenizer providing ``encode_plus``.
        max_len: sequence length the input is padded / truncated to.

    Returns:
        1-D tensor of sigmoid probabilities, one per output logit.
    """
    encoded = tokenizer.encode_plus(
        text,
        None,
        add_special_tokens=True,
        max_length=max_len,
        padding='max_length',  # replaces the deprecated pad_to_max_length=True
        truncation=True,       # make truncation to max_len explicit
        return_token_type_ids=True,
    )
    # Every input gets the same (1, max_len) batch shape.
    ids = torch.tensor(encoded['input_ids']).unsqueeze(0)
    mask = torch.tensor(encoded['attention_mask']).unsqueeze(0)
    token_type_ids = torch.tensor(encoded['token_type_ids']).unsqueeze(0)
    with torch.no_grad():  # inference only: skip building the autograd graph
        logits = model(ids, mask, token_type_ids)
    return torch.sigmoid(logits).flatten().detach()
class TinyBERTClass(torch.nn.Module):
    """DistilBERT encoder followed by a two-layer MLP head emitting one logit per class.

    Args:
        n_hid1: width of the hidden layer of the classification head.
        n_out: number of output logits (defaults to ``num_classes``).
        path: identifier or path handed to ``DistilBertModel.from_pretrained``.

    NOTE: the attribute names ``l1``..``l6`` are the keys of the saved
    state_dict loaded elsewhere, so they must not be renamed.
    """

    def __init__(self, n_hid1=1024, n_out=num_classes, path='distilbert-base-uncased'):
        super().__init__()
        self.l1 = transformers.DistilBertModel.from_pretrained(path)
        self.l2 = torch.nn.Dropout(0.3)
        # DistilBertConfig stores the encoder width as `dim` (768 for the base
        # model); reading it avoids hard-coding that value.
        self.l3 = torch.nn.Linear(self.l1.config.dim, n_hid1)
        self.l4 = torch.nn.ReLU()
        self.l5 = torch.nn.Dropout(0.2)
        self.l6 = torch.nn.Linear(n_hid1, n_out)

    def forward(self, ids, mask):
        """Return raw logits, one per class (apply sigmoid for probabilities)."""
        out = self.l1(ids, attention_mask=mask)
        # Use the hidden state of the first token (position 0) as the sequence summary.
        out = self.l2(out.last_hidden_state[:, 0, :])
        out = self.l3(out)
        out = self.l4(out)
        out = self.l5(out)
        out = self.l6(out)
        return out
def load_tiny_bert():
    """Build the title-only DistilBERT classifier and its tokenizer from local files.

    Returns:
        (model, tokenizer): a ``TinyBERTClass`` with weights from
        ``tiny_bert.pt`` switched to eval mode, and a ``DistilBertTokenizer``
        loaded from ``tiny_bert_tokenizer``.
    """
    model = TinyBERTClass(path='tiny_bert_pretrained')
    # map_location='cpu' lets a checkpoint saved during a GPU session load on
    # a CPU-only host; without it torch.load raises when CUDA is unavailable.
    state_dict = torch.load('tiny_bert.pt', map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout for inference
    tokenizer = transformers.DistilBertTokenizer.from_pretrained('tiny_bert_tokenizer')
    return model, tokenizer
def apply_tiny_bert(text, model, tokenizer, max_len=512):
    """Classify ``text`` with the DistilBERT model and return per-class probabilities.

    Args:
        text: input text (the article title).
        model: callable ``model(input_ids, attention_mask) -> logits``.
        tokenizer: callable tokenizer returning PyTorch tensors.
        max_len: inputs longer than this many tokens are truncated. 512 is
            the DistilBERT position-embedding limit; longer inputs would crash.

    Returns:
        1-D tensor of sigmoid probabilities, one per output logit.
    """
    encoded = tokenizer(text, return_tensors='pt', truncation=True, max_length=max_len)
    with torch.no_grad():  # inference only: skip building the autograd graph
        logits = model(encoded['input_ids'], encoded['attention_mask'])
    return torch.sigmoid(logits).flatten().detach()
title = st.text_area("Название статьи")
# Make sure a non-empty title ends with a period; it is later used alone or
# joined with the summary.
if title and not title.endswith('.'):
    title += '.'
summary = st.text_area("Аннотация статьи")
calc_button = st.button('Угадать тематику')

# Streamlit re-executes the whole script on every widget interaction, so
# without caching both checkpoints would be re-read from disk each time.
# st.cache_resource is the current API; older Streamlit releases only have
# the legacy st.cache, used here as a fallback.
_cache_model = getattr(st, "cache_resource", None)
if _cache_model is None:
    _cache_model = st.cache(allow_output_mutation=True)
bert_model, bert_tokenizer = _cache_model(load_bert)()
tiny_bert, tiny_bert_tokenizer = _cache_model(load_tiny_bert)()
# calculate ================================================================

# Subject-area names (dative case, to follow "по"), indexed like the model's
# output probabilities.
RU_NAMES = ['компьютерным наукам'
            , 'экономике'
            , 'электротехнике и системотехнике'
            , 'математике'
            , 'физике'
            , 'количественной биологии'
            , 'количественным финансам'
            , 'статистике'
            ]


def get_classes(out, bandwidth=0.5):
    """Turn per-class probabilities into a Russian-language verdict.

    Args:
        out: 1-D tensor of per-class probabilities, indexed like ``RU_NAMES``.
        bandwidth: threshold; classes with probability >= bandwidth are reported.

    Returns:
        One line per reported class, joined with ",\\n". If no class passes
        the threshold, a fallback message is returned instead.
    """
    hits = [(RU_NAMES[i], p) for i, p in enumerate(out.tolist()) if p >= bandwidth]
    if not hits:
        return 'Не похоже на что-то научное, Вы уверены что это взято из статьи?'
    (first_name, first_p), *others = hits
    lines = [f'Это статья по {first_name} с вероятностью {first_p:.2f}']
    lines += [f'также она по {name} с вероятностью {p:.2f}' for name, p in others]
    ans = ',\n'.join(lines)
    # Probabilities come from independent sigmoids, so they may sum past 1.
    # Use a strict ">" so a single saturated class (p == 1.0) does not trigger
    # this note.
    if sum(p for _, p in hits) > 1.0:
        ans += '.\n(Решалась задача мультиклассификации, поэтому сумма вероятностей получилась больше 1.)'
    return ans


if calc_button:
    print('title')
    print(title)
    print('=' * 80)
    if not (title or summary):
        # Nothing to classify: don't run the model on an empty string.
        st.warning('Введите название и/или аннотацию статьи.')
    else:
        if summary:
            # Full text available: use the BERT model on title + abstract.
            # The space only aids readability; the BERT tokenizer splits on
            # the period anyway.
            text = f'{title} {summary}'.strip()
            out = apply_bert(text, bert_model, bert_tokenizer)
        else:
            # Title only: use the DistilBERT model.
            out = apply_tiny_bert(title, tiny_bert, tiny_bert_tokenizer)
        res = get_classes(out)
        # Markdown merges single newlines into one paragraph. Two trailing
        # spaces force a hard line break, so each class gets its own line.
        st.markdown(res.replace('\n', '  \n'))