File size: 4,426 Bytes
355910d
 
 
 
 
 
 
 
 
 
 
3f5668b
355910d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e9c888
355910d
 
 
 
 
 
 
 
3f5668b
 
620a618
3f5668b
 
 
1e9c888
3f5668b
 
21e3a78
3f5668b
 
 
 
 
 
 
355910d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee8c3f2
 
 
355910d
 
 
 
 
 
 
7de8e9d
21e3a78
 
355910d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38eb205
355910d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import torch.nn as nn
import re
import streamlit as st

from transformers import DistilBertModel
from tokenization_kobert import KoBertTokenizer

class SanctiMoly(nn.Module):
    """ Holy Moly News BERT """

    def __init__(self, bert_model, freeze_bert = True):
        """Wrap a (Distil)BERT encoder with an FC-BN-Tanh classification head.

        Args:
            bert_model: pretrained encoder whose forward returns an object
                exposing ``last_hidden_state`` of shape (batch, seq, 768)
                — assumed from the usage in ``forward``; TODO confirm.
            freeze_bert: when True, encoder weights are frozen so only the
                head trains; when False, the whole network is trainable.
        """
        super().__init__()
        self.encoder = bert_model
        # FC-BN-Tanh head: 768 -> 1024 -> 768 -> 120 logits.
        # (120 presumably = 10 years x 12 months, matching i2ym's 2009 base
        # elsewhere in this file — verify against training code.)
        self.linear = nn.Sequential(nn.Linear(768, 1024),
                                    nn.BatchNorm1d(1024),
                                    nn.Tanh(),
                                    nn.Dropout(),
                                    nn.Linear(1024, 768),
                                    nn.BatchNorm1d(768),
                                    nn.Tanh(),
                                    nn.Dropout(),
                                    nn.Linear(768, 120)
                                    )
        # self.softmax = nn.LogSoftmax(dim=-1)

        # Fix: the original compared `freeze_bert == True` and duplicated the
        # loop in both branches; a single flag assignment is equivalent.
        for param in self.encoder.parameters():
            param.requires_grad = not freeze_bert

    def forward(self, input_ids, input_length):
        """Return (batch, 120) logits for a batch of token-id sequences.

        Args:
            input_ids: LongTensor (batch, seq) of token ids.
            input_length: LongTensor (batch,) of true (unpadded) lengths.
        """
        # Attention mask from true lengths: position j of sample i is
        # attended iff j < input_length[i].
        attn_mask = torch.arange(input_ids.size(1))
        attn_mask = attn_mask[None, :] < input_length[:, None]

        enc_o = self.encoder(input_ids, attn_mask)

        # Classify from the first ([CLS]) token's hidden state.
        output = self.linear(enc_o.last_hidden_state[:, 0, :])
        return output

@st.cache(allow_output_mutation=True)
def get_model():
    """Load the fine-tuned SanctiMoly checkpoint and KoBERT tokenizer (CPU).

    Cached by Streamlit so the heavy downloads and the checkpoint load
    happen only once per session.
    """
    tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')
    encoder = DistilBertModel.from_pretrained('alex6095/SanctiMolyOH_Cpu')

    net = SanctiMoly(encoder, freeze_bert=False)
    # Map all tensors onto the CPU regardless of where they were saved.
    state = torch.load("./model.pt", map_location=torch.device('cpu'))
    net.load_state_dict(state['model_state_dict'])
    net.eval()

    return net, tokenizer
    
# Load (and cache) the model and tokenizer once at app start-up.
model, tokenizer = get_model()




class RegexSubstitution(object):
    """Regex substitution class for transform.

    The pattern is applied twice per string, so matches revealed by the
    first pass (e.g. parentheses nested one level deep) are also removed.
    """

    def __init__(self, regex, sub=''):
        self.regex = regex if isinstance(regex, re.Pattern) else re.compile(regex)
        self.sub = sub

    def _scrub(self, text):
        # Two passes on purpose — mirrors the original double substitution.
        first_pass = self.regex.sub(self.sub, text)
        return self.regex.sub(self.sub, first_pass)

    def __call__(self, target):
        if isinstance(target, list):
            return [self._scrub(item) for item in target]
        return self._scrub(target)
def i2ym(fl):
    """Map a flat month index (0 = Jan 2009) to ``(year, month)`` strings."""
    years, months = divmod(fl, 12)
    return str(years + 2009), str(months + 1)

default_text = '''ํ—Œ๋ฒ•์žฌํŒ์†Œ๊ฐ€ ๋ฐ•๊ทผํ˜œ ๋Œ€ํ†ต๋ น์˜ ํŒŒ๋ฉด์„ ๋งŒ์žฅ์ผ์น˜๋กœ ๊ฒฐ์ •ํ–ˆ๋‹ค. ํ˜„์ง ๋Œ€ํ†ต๋ น ํƒ„ํ•ต์ด ์ธ์šฉ๋œ ๊ฒƒ์€ ํ—Œ์ • ์‚ฌ์ƒ ์ตœ์ดˆ๋‹ค. ๋ฐ• ์ „ ๋Œ€ํ†ต๋ น์— ๋Œ€ํ•œ ํŒŒ๋ฉด์ด ๊ฒฐ์ •๋˜๋ฉด์„œ ํ—Œ๋ฒ•๊ณผ ๊ณต์ง์„ ๊ฑฐ๋ฒ•์— ๋”ฐ๋ผ ์•ž์œผ๋กœ 60์ผ ์ด๋‚ด์— ์ฐจ๊ธฐ ๋Œ€ํ†ต๋ น ์„ ๊ฑฐ๊ฐ€ ์น˜๋Ÿฌ์ง„๋‹ค.

์ด์ •๋ฏธ ํ—Œ์žฌ์†Œ์žฅ ๊ถŒํ•œ๋Œ€ํ–‰(์žฌํŒ๊ด€)์€ 10์ผ ์˜ค์ „ 11์‹œ 23๋ถ„ ์„œ์šธ ์ข…๋กœ๊ตฌ ํ—Œ๋ฒ•์žฌํŒ์†Œ ๋Œ€์‹ฌํŒ์ •์—์„œ โ€œํ”ผ์ฒญ๊ตฌ์ธ ๋Œ€ํ†ต๋ น ๋ฐ•๊ทผํ˜œ๋ฅผ ํŒŒ๋ฉดํ•œ๋‹คโ€๊ณ  ์ฃผ๋ฌธ์„ ์„ ๊ณ ํ–ˆ๋‹ค. ๊ทธ ์ˆœ๊ฐ„ ๋Œ€์‹ฌํŒ์ • ๊ณณ๊ณณ์—์„œ ๋ฌด๊ฒ๊ณ  ๋‚˜์งํ•œ ํƒ„์„ฑ์ด ํ„ฐ์ ธ ๋‚˜์™”๋‹ค. ์ด๋‚  ๋Œ€์‹ฌํŒ์ •์—์„  ๋ฐ•๊ทผํ˜œ ์ „ ๋Œ€ํ†ต๋ น ์ธก๊ณผ ๊ตญํšŒ์†Œ์ถ”์œ„์› ์ธก ๊ด€๊ณ„์ž๋“ค๊ณผ ์ทจ์žฌ์ง„ 80๋ช…, ์˜จ๋ผ์ธ ์ ‘์ˆ˜๋ฅผ ํ†ตํ•ด 795๋Œ€ 1์˜ ๊ฒฝ์Ÿ๋ฅ ์„ ๋šซ๊ณ  ์„ ์ •๋œ ์ผ๋ฐ˜๋ฐฉ์ฒญ๊ฐ 24๋ช…์ด ์ˆจ์„ ์ฃฝ์ด๊ณ  ์žˆ์—ˆ๋‹ค.
'''


st.title("Date prediction")
text = st.text_area("Input news :", value=default_text)
st.markdown("## Original News Data")
st.write(text)
st.markdown("## Predict Top 3 Date")


if text:
    with st.spinner('processing..'):
        text = RegexSubstitution(r'\([^()]+\)|[<>\'"โ–ณโ–ฒโ–กโ– ]')(text)
        encoded_dict = tokenizer(
            text=[text],
            add_special_tokens=True,
            max_length=512,
            truncation=True,
            return_tensors='pt',
            return_length=True
        )
        input_ids = encoded_dict['input_ids']
        input_ids_len = encoded_dict['length']
        
        pred = model(input_ids, input_ids_len)
        
    _, indices = torch.topk(pred, 3)
    pred_print = []
    for i in indices.squeeze(0):
        year, month = i2ym(i.item())
        pred_print.append(year+"-"+month)
    st.write(", ".join(pred_print))