Spaces:
Runtime error
Runtime error
Create new file
Browse files
app.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import sentencepiece
|
4 |
+
|
5 |
+
# 모델 준비하기
|
6 |
+
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
|
7 |
+
from torch.utils.data import DataLoader, Dataset
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
import torch
|
11 |
+
import os
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
# [theme]
|
15 |
+
# base="dark"
|
16 |
+
# primaryColor="purple"
|
17 |
+
|
18 |
+
# Page title shown at the top of the app.
APP_TITLE = '한국표준산업분류 자동코딩 서비스'
st.header(APP_TITLE)
|
20 |
+
|
21 |
+
# Avoid reloading the model on every Streamlit rerun.
# A loaded model is a shared resource rather than serialisable data, so prefer
# the resource cache. Fall back to the older APIs so the app still starts on
# Streamlit releases that predate `st.cache_resource`.
_cache_model = (
    getattr(st, 'cache_resource', None)
    or getattr(st, 'experimental_singleton', None)
    or st.experimental_memo
)


@_cache_model
def md_loading():
    """Load the tokenizer, classifier weights and lookup tables on the CPU.

    Reads ``./base3_44_en.bin``, ``./label_table.npy`` and
    ``./kisc_table.csv`` from the working directory and prints ``ready``
    once everything is loaded.

    Returns:
        A 5-tuple ``(tokenizer, model, label_tbl, loc_tbl, device)``.

    NOTE: with the resource cache the returned objects are shared between
    reruns, so callers should treat them as read-only.
    """
    device = torch.device("cpu")

    tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
    model = XLMRobertaForSequenceClassification.from_pretrained('xlm-roberta-base', num_labels=493)

    model_checkpoint = 'base3_44_en.bin'
    project_path = './'
    output_model_file = os.path.join(project_path, model_checkpoint)

    # map_location keeps the checkpoint loadable on machines without a GPU.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    model.load_state_dict(torch.load(output_model_file, map_location=device))
    model.to(device)

    label_tbl = np.load('./label_table.npy')
    loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')

    print('ready')

    return tokenizer, model, label_tbl, loc_tbl, device
|
50 |
+
|
51 |
+
# Load the model, tokenizer and lookup tables used by the rest of the script.
_loaded_resources = md_loading()
tokenizer, model, label_tbl, loc_tbl, device = _loaded_resources
|
53 |
+
|
54 |
+
|
55 |
+
# Dataset preparation
max_len = 64  # token length every encoded sample is padded/truncated to


class TVT_Dataset(Dataset):
    """Dataset that tokenizes one DataFrame row per item.

    Uses the module-level ``tokenizer`` and ``max_len``.
    """

    def __init__(self, df):
        """Keep a reference to the DataFrame holding the text columns."""
        self.df_data = df

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.df_data)

    def __getitem__(self, index):
        """Encode row ``index`` and return ``(input_ids, attention_mask)``."""
        text_columns = ['CMPNY_NM', 'MAJ_ACT', 'WORK_TYPE', 'POSITION', 'DEPT_NM']
        fields = self.df_data.loc[index, text_columns].to_list()

        # The five fields are joined into a single text with ' <s> ' between them.
        joined_text = ' <s> '.join(fields)

        encoding = tokenizer(
            joined_text,
            add_special_tokens=True,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # Index 0 picks the entry for the single text encoded above
        # (presumably dropping a leading batch axis - confirm against the
        # tokenizer documentation).
        input_ids = encoding['input_ids'][0]
        attention_mask = encoding['attention_mask'][0]

        # No label is returned: this dataset is only used for prediction.
        return input_ids, attention_mask
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
# Free-text input boxes, created in the order they appear on the page.
_input_labels = ('사업체명', '사업체 하는일', '근무부서', '직책', '내가 하는 일')
(
    business,
    business_work,
    work_department,
    work_position,
    what_do_i,
) = [st.text_input(label) for label in _input_labels]
|
101 |
+
|
102 |
+
|
103 |
+
# Data preparation

# Text columns selected as model input, in the order they are used.
input_col_type = ['CMPNY_NM', 'MAJ_ACT', 'WORK_TYPE', 'POSITION', 'DEPT_NM']


def preprocess_dataset(dataset):
    """Return the model input columns of ``dataset`` with missing values blanked.

    Args:
        dataset: DataFrame containing at least the columns in
            ``input_col_type``. Its index is reset in place (the old index
            is dropped), as in the original implementation.

    Returns:
        A new DataFrame restricted to ``input_col_type`` in which every
        missing value has been replaced by an empty string.

    Raises:
        KeyError: if any column in ``input_col_type`` is absent.
    """
    dataset.reset_index(drop=True, inplace=True)
    # fillna returns a new frame; the previous code discarded that result,
    # so missing values were never actually replaced.
    return dataset[input_col_type].fillna('')
|
112 |
+
|
113 |
+
|
114 |
+
# Run a prediction when the confirm button is pressed.
if st.button('확인'):
    # --- Prepare the data ---

    # Model inputs, ordered to line up with input_col_type:
    # company name, main activity, work done, position, department.
    md_input = [str(business), str(business_work), str(what_do_i), str(work_position), str(work_department)]

    # Single-row frame; preprocess_dataset resets the index and selects the
    # input columns, so no separate reset_index is needed here.
    test_dataset = pd.DataFrame(dict(zip(input_col_type, md_input)), index=[0])
    test_dataset = preprocess_dataset(test_dataset)

    print(len(test_dataset))
    print(test_dataset)

    print('base_data_loader 사용 시점점')
    test_data = TVT_Dataset(test_dataset)

    # Split the data into batches (there is only one row, hence one batch).
    batch_size = 48
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # --- Run the model ---

    # Put the model in evaluation mode.
    model.eval()

    # Inference only: do not compute or store gradients.
    with torch.no_grad():
        for batch in tqdm(test_dataloader):
            test_input_ids, test_attention_mask = (t.to(device) for t in batch)

            # Forward pass, calculate logit predictions.
            outputs = model(test_input_ids, token_type_ids=None, attention_mask=test_attention_mask)
            logits = outputs.logits

    # --- Top-k prediction ---

    # Rank the classes of the single submitted row. Indexing row 0 instead
    # of flattening keeps the indices inside the label range.
    k = min(10, logits.shape[-1])
    topk_idx = torch.topk(logits[0], k).indices.cpu().numpy()

    # Map class indices to codes, then codes to their names.
    num_ans_topk = label_tbl[topk_idx]

    str_ans_topk_list = []
    for code in num_ans_topk:
        matches = loc_tbl.loc[loc_tbl['코드'] == code, '항목명']
        # A code missing from the lookup table gets an empty name instead of
        # aborting the whole request with an IndexError.
        str_ans_topk_list.append(matches.iloc[0] if len(matches) else '')

    # --- Show the result ---
    ans_topk_df = pd.DataFrame({
        'NO': range(1, k + 1),
        '세분류 코드': num_ans_topk,
        '세분류 명칭': str_ans_topk_list,
    })
    ans_topk_df = ans_topk_df.set_index('NO')

    st.dataframe(ans_topk_df)
|