import torch import torch.nn as nn import re import streamlit as st from transformers import DistilBertModel from tokenization_kobert import KoBertTokenizer class SanctiMoly(nn.Module): """ Holy Moly News BERT """ def __init__(self, bert_model, freeze_bert = True): super(SanctiMoly, self).__init__() self.encoder = bert_model # FC-BN-Tanh self.linear = nn.Sequential(nn.Linear(768, 1024), nn.BatchNorm1d(1024), nn.Tanh(), nn.Dropout(), nn.Linear(1024, 768), nn.BatchNorm1d(768), nn.Tanh(), nn.Dropout(), nn.Linear(768, 120) ) # self.softmax = nn.LogSoftmax(dim=-1) if freeze_bert == True: for param in self.encoder.parameters(): param.requires_grad = False else: for param in self.encoder.parameters(): param.requires_grad = True def forward(self, input_ids, input_length): # calculate attention mask attn_mask = torch.arange(input_ids.size(1)) attn_mask = attn_mask[None, :] < input_length[:, None] enc_o = self.encoder(input_ids, attn_mask) output = self.linear(enc_o.last_hidden_state[:, 0, :]) # print(output.shape) return output @st.cache(allow_output_mutation=True) def get_model(): bert_model = DistilBertModel.from_pretrained('alex6095/SanctiMolyOH_Cpu') tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert') model = SanctiMoly(bert_model, freeze_bert=False) device = torch.device('cpu') checkpoint = torch.load("./model.pt", map_location=device) model.load_state_dict(checkpoint['model_state_dict']) model.eval() return model, tokenizer model, tokenizer = get_model() class RegexSubstitution(object): """Regex substitution class for transform""" def __init__(self, regex, sub=''): if isinstance(regex, re.Pattern): self.regex = regex else: self.regex = re.compile(regex) self.sub = sub def __call__(self, target): if isinstance(target, list): return [self.regex.sub(self.sub, self.regex.sub(self.sub, string)) for string in target] else: return self.regex.sub(self.sub, self.regex.sub(self.sub, target)) def i2ym(fl): return (str(fl // 12 + 2009), str(fl % 12 + 1)) default_text = '''헌법재판소가 박근혜 대통령의 파면을 만장일치로 결정했다. 현직 대통령 탄핵이 인용된 것은 헌정 사상 최초다. 박 전 대통령에 대한 파면이 결정되면서 헌법과 공직선거법에 따라 앞으로 60일 이내에 차기 대통령 선거가 치러진다. 이정미 헌재소장 권한대행(재판관)은 10일 오전 11시 23분 서울 종로구 헌법재판소 대심판정에서 “피청구인 대통령 박근혜를 파면한다”고 주문을 선고했다. 그 순간 대심판정 곳곳에서 무겁고 나직한 탄성이 터져 나왔다. 이날 대심판정에선 박근혜 전 대통령 측과 국회소추위원 측 관계자들과 취재진 80명, 온라인 접수를 통해 795대 1의 경쟁률을 뚫고 선정된 일반방청객 24명이 숨을 죽이고 있었다. ''' st.title("Date prediction") text = st.text_area("Input news :", value=default_text) st.markdown("## Original News Data") st.write(text) st.markdown("## Predict Top 3 Date") if text: with st.spinner('processing..'): text = RegexSubstitution(r'\([^()]+\)|[<>\'"△▲□■]')(text) encoded_dict = tokenizer( text=[text], add_special_tokens=True, max_length=512, truncation=True, return_tensors='pt', return_length=True ) input_ids = encoded_dict['input_ids'] input_ids_len = encoded_dict['length'] pred = model(input_ids, input_ids_len) _, indices = torch.topk(pred, 3) pred_print = [] for i in indices.squeeze(0): year, month = i2ym(i.item()) pred_print.append(year+"-"+month) st.write(", ".join(pred_print))