import torch
import torch.nn as nn
import re
import streamlit as st
from transformers import DistilBertModel
from tokenization_kobert import KoBertTokenizer
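# Note: tokenization_kobert is a local helper module, not part of the transformers
# package; KoBertTokenizer is typically distributed with the monologg/DistilKoBERT project.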
class SanctiMoly(nn.Module):
    """ Holy Moly News BERT """
    def __init__(self, bert_model, freeze_bert=True):
        super(SanctiMoly, self).__init__()
        self.encoder = bert_model
        # Classification head (FC-BN-Tanh) on top of the encoder output
        self.linear = nn.Sequential(nn.Linear(768, 1024),
                                    nn.BatchNorm1d(1024),
                                    nn.Tanh(),
                                    nn.Dropout(),
                                    nn.Linear(1024, 768),
                                    nn.BatchNorm1d(768),
                                    nn.Tanh(),
                                    nn.Dropout(),
                                    nn.Linear(768, 120)
                                    )
        # self.softmax = nn.LogSoftmax(dim=-1)

        # Optionally freeze the BERT encoder so only the head is trainable
        if freeze_bert:
            for param in self.encoder.parameters():
                param.requires_grad = False
        else:
            for param in self.encoder.parameters():
                param.requires_grad = True

    def forward(self, input_ids, input_length):
        # Build the attention mask from the true sequence lengths
        attn_mask = torch.arange(input_ids.size(1))
        attn_mask = attn_mask[None, :] < input_length[:, None]

        enc_o = self.encoder(input_ids, attn_mask)
        # Classify from the hidden state of the first ([CLS]) token
        output = self.linear(enc_o.last_hidden_state[:, 0, :])
        # print(output.shape)
        return output
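# A quick sketch of the length-based masking above, with assumed toy values:
#   for input_ids of shape (2, 5) and input_length = tensor([3, 5]),
#   torch.arange(5)[None, :] < input_length[:, None] gives
#     [[True, True, True, False, False],
#      [True, True, True, True,  True ]]
#   so padded positions are ignored by the encoder.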
@st.cache(allow_output_mutation=True)
def get_model():
    # Load the pretrained DistilKoBERT encoder and matching tokenizer
    bert_model = DistilBertModel.from_pretrained('alex6095/SanctiMolyOH_Cpu')
    tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')

    model = SanctiMoly(bert_model, freeze_bert=False)
    device = torch.device('cpu')
    # Restore the fine-tuned weights from the local checkpoint
    checkpoint = torch.load("./model.pt", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model, tokenizer

model, tokenizer = get_model()
class RegexSubstitution(object):
    """Regex substitution class for transform"""
    def __init__(self, regex, sub=''):
        if isinstance(regex, re.Pattern):
            self.regex = regex
        else:
            self.regex = re.compile(regex)
        self.sub = sub

    def __call__(self, target):
        # The substitution is applied twice so that one level of nesting
        # (e.g. parentheses inside parentheses) is also removed
        if isinstance(target, list):
            return [self.regex.sub(self.sub, self.regex.sub(self.sub, string)) for string in target]
        else:
            return self.regex.sub(self.sub, self.regex.sub(self.sub, target))

def i2ym(fl):
    # Map a class index (0..119) to a (year, month) pair, counting months from 2009-01
    return (str(fl // 12 + 2009), str(fl % 12 + 1))
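# For example, the index-to-date mapping above gives:
#   i2ym(0)   -> ('2009', '1')
#   i2ym(119) -> ('2018', '12')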
# Default example: a Korean news article on the Constitutional Court's
# unanimous ruling (2017-03-10) removing President Park Geun-hye from office.
default_text = '''헌법재판소가 박근혜 대통령의 파면을 만장일치로 결정했다. 현직 대통령 탄핵이 인용된 것은 헌정 사상 최초다. 박 전 대통령에 대한 파면이 결정되면서 헌법과 공직선거법에 따라 앞으로 60일 이내에 차기 대통령 선거가 치러진다.
이정미 헌재소장 권한대행(재판관)은 10일 오전 11시 23분 서울 종로구 헌법재판소 대심판정에서 “피청구인 대통령 박근혜를 파면한다”고 주문을 선고했다. 그 순간 대심판정 곳곳에서 무겁고 나직한 탄성이 터져 나왔다. 이날 대심판정에는 박근혜 전 대통령 측과 국회소추위원 측 관계자들과 취재진 80명, 온라인 접수를 통해 795대 1의 경쟁률을 뚫고 선정된 일반방청객 24명이 숨을 죽이고 있었다.
'''
st.title("Date prediction")
text = st.text_area("Input news :", value=default_text)
st.markdown("## Original News Data")
st.write(text)
st.markdown("## Predict Top 3 Date")
if text:
    with st.spinner('processing..'):
        # Strip parenthesised asides and special symbols before tokenization
        text = RegexSubstitution(r'\([^()]+\)|[<>\'"△▲□■]')(text)
        encoded_dict = tokenizer(
            text=[text],
            add_special_tokens=True,
            max_length=512,
            truncation=True,
            return_tensors='pt',
            return_length=True
        )
        input_ids = encoded_dict['input_ids']
        input_ids_len = encoded_dict['length']

        pred = model(input_ids, input_ids_len)
        # Report the three most likely (year, month) classes
        _, indices = torch.topk(pred, 3)
        pred_print = []
        for i in indices.squeeze(0):
            year, month = i2ym(i.item())
            pred_print.append(year + "-" + month)
        st.write(", ".join(pred_print))