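# Streamlit demo: classify the topic of a Korean news article with a DistilKoBERT
# classifier ('alex6095/SanctiMolyTopic') or generate its publication date with a
# KoBART model ('alex6095/SanctiMoly-Bart'), selectable from the sidebar.
# Assumed entry point (filename is illustrative): streamlit run app.py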
import torch
import re
import streamlit as st
import pandas as pd
from transformers import PreTrainedTokenizerFast, DistilBertForSequenceClassification, BartForConditionalGeneration

from tokenization_kobert import KoBertTokenizer
from tokenizers import SentencePieceBPETokenizer


# Model loaders are cached so Streamlit does not reload the weights on every rerun.
@st.cache(allow_output_mutation=True)
def get_topic():
    # Topic classifier: DistilKoBERT fine-tuned for 9-way news topic classification.
    model = DistilBertForSequenceClassification.from_pretrained(
        'alex6095/SanctiMolyTopic', problem_type="multi_label_classification", num_labels=9)
    model.eval()
    tokenizer = KoBertTokenizer.from_pretrained('monologg/distilkobert')
    return model, tokenizer

@st.cache(allow_output_mutation=True)
def get_date():
    # Date predictor: KoBART model that generates a publication date for the article.
    model = BartForConditionalGeneration.from_pretrained('alex6095/SanctiMoly-Bart')
    model.eval()
    tokenizer = PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-summarization')
    return model, tokenizer

class RegexSubstitution(object):
    """Regex substitution transform: replaces every match of `regex` with `sub` (empty by default).

    The substitution is applied twice so that matches exposed by the first pass
    (e.g. nested parentheses) are removed as well.
    """
    def __init__(self, regex, sub=''):
        if isinstance(regex, re.Pattern):
            self.regex = regex
        else:
            self.regex = re.compile(regex)
        self.sub = sub

    def __call__(self, target):
        if isinstance(target, list):
            return [self.regex.sub(self.sub, self.regex.sub(self.sub, string)) for string in target]
        else:
            return self.regex.sub(self.sub, self.regex.sub(self.sub, target))
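
# Minimal usage sketch (illustrative input): the pattern used below removes
# parenthesised spans and a few special characters, e.g.
#   RegexSubstitution(r'\([^()]+\)|[<>\'"โ–ณโ–ฒโ–กโ– ]')('์„œ์šธ(์—ฐํ•ฉ๋‰ด์Šค) "์†๋ณด"')  ->  '์„œ์šธ ์†๋ณด'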
            
default_text = '''์งˆ๋ณ‘๊ด€๋ฆฌ์ฒญ์€ 23์ผ ์ง€๋ฐฉ์ž์น˜๋‹จ์ฒด๊ฐ€ ๋ณด๊ฑด๋‹น๊ตญ๊ณผ ํ˜‘์˜ ์—†์ด ๋‹จ๋…์œผ๋กœ ์ธํ”Œ๋ฃจ์—”์ž(๋…๊ฐ) ๋ฐฑ์‹  ์ ‘์ข… ์ค‘๋‹จ์„ ๊ฒฐ์ •ํ•ด์„œ๋Š” ์•ˆ ๋œ๋‹ค๋Š” ์ž…์žฅ์„ ๋ฐํ˜”๋‹ค.
    ์งˆ๋ณ‘์ฒญ์€ ์ด๋‚  ์ฐธ๊ณ ์ž๋ฃŒ๋ฅผ ๋ฐฐํฌํ•˜๊ณ  โ€œํ–ฅํ›„ ์ „์ฒด ๊ตญ๊ฐ€ ์˜ˆ๋ฐฉ์ ‘์ข…์‚ฌ์—…์ด ์ฐจ์งˆ ์—†์ด ์ง„ํ–‰๋˜๋„๋ก ์ง€์ž์ฒด๊ฐ€ ์ž์ฒด์ ์œผ๋กœ ์ ‘์ข… ์œ ๋ณด ์—ฌ๋ถ€๋ฅผ ๊ฒฐ์ •ํ•˜์ง€ ์•Š๋„๋ก ์•ˆ๋‚ด๋ฅผ ํ–ˆ๋‹คโ€๊ณ  ์„ค๋ช…ํ–ˆ๋‹ค.
    ๋…๊ฐ๋ฐฑ์‹ ์„ ์ ‘์ข…ํ•œ ํ›„ ๊ณ ๋ น์ธต์„ ์ค‘์‹ฌ์œผ๋กœ ์ „๊ตญ์—์„œ ์‚ฌ๋ง์ž๊ฐ€ ์ž‡๋”ฐ๋ฅด์ž ์„œ์šธ ์˜๋“ฑํฌ๊ตฌ๋ณด๊ฑด์†Œ๋Š” ์ „๋‚ , ๊ฒฝ๋ถ ํฌํ•ญ์‹œ๋Š” ์ด๋‚  ๊ด€๋‚ด ์˜๋ฃŒ๊ธฐ๊ด€์— ์ ‘์ข…์„ ๋ณด๋ฅ˜ํ•ด๋‹ฌ๋ผ๋Š” ๊ณต๋ฌธ์„ ๋‚ด๋ ค๋ณด๋ƒˆ๋‹ค. ์ด๋Š” ์˜ˆ๋ฐฉ์ ‘์ข…๊ณผ ์‚ฌ๋ง ๊ฐ„ ์ง์ ‘์  ์—ฐ๊ด€์„ฑ์ด ๋‚ฎ์•„ ์ ‘์ข…์„ ์ค‘๋‹จํ•  ์ƒํ™ฉ์€ ์•„๋‹ˆ๋ผ๋Š” ์งˆ๋ณ‘์ฒญ์˜ ํŒ๋‹จ๊ณผ๋Š” ๋‹ค๋ฅธ ๊ฒƒ์ด๋‹ค.
    ์งˆ๋ณ‘์ฒญ์€ ์ง€๋‚œ 21์ผ ์ „๋ฌธ๊ฐ€ ๋“ฑ์ด ์ฐธ์—ฌํ•œ โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ํ”ผํ•ด์กฐ์‚ฌ๋ฐ˜โ€™์˜ ๋ถ„์„ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๋…๊ฐ ์˜ˆ๋ฐฉ์ ‘์ข… ์‚ฌ์—…์„ ์ผ์ •๋Œ€๋กœ ์ง„ํ–‰ํ•˜๊ธฐ๋กœ ํ–ˆ๋‹ค. ํŠนํžˆ ๊ณ ๋ น ์–ด๋ฅด์‹ ๊ณผ ์–ด๋ฆฐ์ด, ์ž„์‹ ๋ถ€ ๋“ฑ ๋…๊ฐ ๊ณ ์œ„ํ—˜๊ตฐ์€ ๋ฐฑ์‹ ์„ ์ ‘์ข…ํ•˜์ง€ ์•Š์•˜์„ ๋•Œ ํ•ฉ๋ณ‘์ฆ ํ”ผํ•ด๊ฐ€ ํด ์ˆ˜ ์žˆ๋‹ค๋ฉด์„œ ์ ‘์ข…์„ ๋…๋ คํ–ˆ๋‹ค. ํ•˜์ง€๋งŒ ์ ‘์ข…์‚ฌ์—… ์œ ์ง€ ๋ฐœํ‘œ ์ดํ›„์—๋„ ์‚ฌ๋ง ๋ณด๊ณ ๊ฐ€ ์ž‡๋”ฐ๋ฅด์ž ์งˆ๋ณ‘์ฒญ์€ ์ด๋‚  โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ํ”ผํ•ด์กฐ์‚ฌ๋ฐ˜ ํšŒ์˜โ€™์™€ โ€˜์˜ˆ๋ฐฉ์ ‘์ข… ์ „๋ฌธ์œ„์›ํšŒโ€™๋ฅผ ๊ฐœ์ตœํ•ด ๋…๊ฐ๋ฐฑ์‹ ๊ณผ ์‚ฌ๋ง ๊ฐ„ ๊ด€๋ จ์„ฑ, ์ ‘์ข…์‚ฌ์—… ์œ ์ง€ ์—ฌ๋ถ€ ๋“ฑ์— ๋Œ€ํ•ด ๋‹ค์‹œ ๊ฒฐ๋ก  ๋‚ด๋ฆฌ๊ธฐ๋กœ ํ–ˆ๋‹ค. ํšŒ์˜ ๊ฒฐ๊ณผ๋Š” ์ด๋‚  ์˜คํ›„ 7์‹œ ๋„˜์–ด ๋ฐœํ‘œ๋  ์˜ˆ์ •์ด๋‹ค.
'''
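# Nine topic labels; index i corresponds to class id i predicted by the classifier (num_labels=9).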
topics_raw = ['IT/๊ณผํ•™', '๊ฒฝ์ œ', '๋ฌธํ™”', '๋ฏธ์šฉ/๊ฑด๊ฐ•', '์‚ฌํšŒ', '์ƒํ™œ', '์Šคํฌ์ธ ', '์—ฐ์˜ˆ', '์ •์น˜']


#topic_model, topic_tokenizer = get_topic()
#date_model, date_tokenizer = get_date()

name = st.sidebar.selectbox('Model', ['Topic Classification', 'Date Prediction'])
if name == 'Topic Classification':
    title = 'News Topic Classification'
    model, tokenizer = get_topic()
elif name == 'Date Prediction':
    title = 'News Date Prediction'
    model, tokenizer = get_date()

st.title(title)

text = st.text_area("Input news:", value=default_text)
st.markdown("## Original News Data")
st.write(text)

if name == 'Topic Classification':
    st.markdown("## Predict Topic")
    col1, col2 = st.columns(2)
    if text:
        with st.spinner('processing..'):
            text = RegexSubstitution(r'\([^()]+\)|[<>\'"โ–ณโ–ฒโ–กโ– ]')(text)
            encoded_dict = tokenizer(
                text=text,
                add_special_tokens=True,
                max_length=512,
                truncation=True,
                return_tensors='pt',
                return_length=True
            )
            input_ids = encoded_dict['input_ids']
            input_ids_len = encoded_dict['length']
            # Build a (batch, seq_len) attention mask: positions before the reported
            # sequence length are attended to, the rest are masked out.
            attn_mask = torch.arange(input_ids.size(1))
            attn_mask = attn_mask[None, :] < input_ids_len[:, None]
            outputs = model(input_ids=input_ids, attention_mask=attn_mask)
            _, preds = torch.max(outputs.logits, 1)
        col1.write(topics_raw[preds.item()])
        softmax = torch.nn.Softmax(dim=1)
        prob = softmax(outputs.logits).squeeze(0).detach()
        chart_data = pd.DataFrame({
            'Topic': topics_raw,
            'Probability': prob
        })
        chart_data = chart_data.set_index('Topic')
        col2.bar_chart(chart_data)
        
elif name == 'Date Prediction':
    st.markdown("## Predict Date")
    if text:
        with st.spinner('processing..'):
            text = RegexSubstitution(r'\([^()]+\)|[<>\'"โ–ณโ–ฒโ–กโ– ]')(text)
            raw_input_ids = tokenizer.encode(text)
            input_ids = [tokenizer.bos_token_id] + \
                raw_input_ids + [tokenizer.eos_token_id]
            outputs = model.generate(torch.tensor([input_ids]),
                                     early_stopping=True,
                                     repetition_penalty=2.0,
                                     do_sample=True,          # sample instead of greedy decoding
                                     max_length=50,           # decode at most 50 tokens
                                     top_k=50,                # sample only from the 50 most probable tokens
                                     top_p=0.95,              # nucleus sampling over the top 95% cumulative probability
                                     num_return_sequences=3   # decode three candidate sequences
                                     )
        for output in outputs:
            pred_print = tokenizer.decode(output.squeeze().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
            st.write(pred_print)