iSpr committed
Commit cadd18b · 1 Parent(s): 085c39e

Create new file

Files changed (1)
  1. app.py +234 -0
app.py ADDED
@@ -0,0 +1,234 @@
+ import streamlit as st
+ import pandas as pd
+ import sentencepiece  # required by the XLM-R tokenizer
+
+ # Prepare the model
+ from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer
+ from torch.utils.data import DataLoader, Dataset
+ import numpy as np
+ import pandas as pd
+ import torch
+ import os
+ from tqdm import tqdm
+
+ # [theme]
+ # base="dark"
+ # primaryColor="purple"
+
+ # Page title
+ st.header('한국표준산업분류 자동코딩 서비스')  # "Korean Standard Industrial Classification auto-coding service"
+
+ # Cache the model so it is not reloaded on every rerun
+ @st.experimental_memo(max_entries=20)
+ def md_loading():
+     ## CPU only
+     device = torch.device("cpu")
+
+     tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
+     model = XLMRobertaForSequenceClassification.from_pretrained('xlm-roberta-base', num_labels=493)
+
+     model_checkpoint = 'base3_44_en.bin'
+     project_path = './'
+     output_model_file = os.path.join(project_path, model_checkpoint)
+
+     # model.load_state_dict(torch.load(output_model_file))
+     model.load_state_dict(torch.load(output_model_file, map_location=torch.device('cpu')))
+     # ckpt = torch.load(output_model_file, map_location=torch.device('cpu'))
+     # model.load_state_dict(ckpt['model_state_dict'])
+
+     # device = torch.device("cuda" if torch.cuda.is_available() and not False else "cpu")
+     # device = torch.device("cpu")
+
+     model.to(device)
+
+     label_tbl = np.load('./label_table.npy')  # class index -> KSIC code
+     loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')  # KSIC code -> name ('코드', '항목명')
+
+     print('ready')
+
+     return tokenizer, model, label_tbl, loc_tbl, device
+
+ # Load the model
+ tokenizer, model, label_tbl, loc_tbl, device = md_loading()
+
+
+ # For building the dataset
+ max_len = 64  # maximum sequence length
+
+ class TVT_Dataset(Dataset):
+
+     def __init__(self, df):
+         self.df_data = df
+
+     def __getitem__(self, index):
+
+         # Pull the input columns from the dataframe
+         # sentence = self.df_data.loc[index, 'text']
+         sentence = self.df_data.loc[index, ['CMPNY_NM', 'MAJ_ACT', 'WORK_TYPE', 'POSITION', 'DEPT_NM']]
+
+         encoded_dict = tokenizer(
+             ' <s> '.join(sentence.to_list()),
+             add_special_tokens = True,
+             max_length = max_len,
+             padding='max_length',
+             truncation=True,
+             return_attention_mask = True,
+             return_tensors = 'pt')
+
+
+         padded_token_list = encoded_dict['input_ids'][0]
+         att_mask = encoded_dict['attention_mask'][0]
+
+         # Convert the numeric label to a tensor (training only)
+         # target = torch.tensor(self.df_data.loc[index, 'NEW_CD'])
+         # Bundle input_ids, attention_mask and the label into one sample
+         # sample = (padded_token_list, att_mask, target)
+         sample = (padded_token_list, att_mask)
+
+         return sample
+
+     def __len__(self):
+         return len(self.df_data)
+
+
+
+ # Text input boxes
+ business = st.text_input('사업체명')            # company name
+ business_work = st.text_input('사업체 하는일')   # what the company does
+ work_department = st.text_input('근무부서')      # department
+ work_position = st.text_input('직책')            # position
+ what_do_i = st.text_input('내가 하는 일')        # what I do
+
+
+ # Prepare the data
+
+ # Build the test dataset.
+ input_col_type = ['CMPNY_NM', 'MAJ_ACT', 'WORK_TYPE', 'POSITION', 'DEPT_NM']
+
+ def preprocess_dataset(dataset):
+     dataset.reset_index(drop=True, inplace=True)
+     dataset = dataset.fillna('')  # fillna returns a copy, so reassign to replace NaN with empty strings
+     return dataset[input_col_type]
+
+
+ ## Temporary check
+ # st.write(md_input)
+
+ # Button
+ if st.button('확인'):  # '확인' = "Confirm"
+     ## Steps run when the button is clicked
+
+     ### Prepare the data
+
+     # md_input: the input values fed to the model
+     # md_input = '|'.join([business, business_work, what_do_i, work_position, work_department])
+     md_input = [str(business), str(business_work), str(what_do_i), str(work_position), str(work_department)]
+
+     test_dataset = pd.DataFrame({
+         input_col_type[0]: md_input[0],
+         input_col_type[1]: md_input[1],
+         input_col_type[2]: md_input[2],
+         input_col_type[3]: md_input[3],
+         input_col_type[4]: md_input[4]
+     }, index=[0])
+
+     # test_dataset = pd.read_csv(DATA_IN_PATH + test_set_name, sep='|', na_filter=False)
+
+     test_dataset.reset_index(inplace=True)
+
+     test_dataset = preprocess_dataset(test_dataset)
+
+     print(len(test_dataset))
+     print(test_dataset)
+
+     print('building the test DataLoader')
+     test_data = TVT_Dataset(test_dataset)
+
+     train_batch_size = 48
+
+     # Split the data into batches of batch_size
+     test_dataloader = DataLoader(test_data,
+                                  batch_size=train_batch_size,
+                                  shuffle=False)
+
+
+     ### Run the model
+
+
+     # Put model in evaluation mode
+     model.eval()
+     model.zero_grad()
+
+     # Tracking variables
+     predictions, true_labels = [], []
+
+     # Predict
+     for batch in tqdm(test_dataloader):
+         # Move the batch to the target device (CPU here)
+         batch = tuple(t.to(device) for t in batch)
+
+         # Unpack the inputs from our dataloader
+         test_input_ids, test_attention_mask = batch
+
+         # Telling the model not to compute or store gradients, saving memory and
+         # speeding up prediction
+         with torch.no_grad():
+             # Forward pass, calculate logit predictions
+             outputs = model(test_input_ids, token_type_ids=None, attention_mask=test_attention_mask)
+
+         logits = outputs.logits
+
+         # Move logits and labels to CPU
+         # logits = logits.detach().cpu().numpy()
+
+
+     # # Single (top-1) prediction
+     # arg_idx = torch.argmax(logits, dim=1)
+     # print('arg_idx:', arg_idx)
+
+     # num_ans = label_tbl[arg_idx]
+     # str_ans = loc_tbl['항목명'][loc_tbl['코드'] == num_ans].values
+
+     # Predict the top-k candidates
+     k = 10
+     topk_idx = torch.topk(logits.flatten(), k).indices
+
+     num_ans_topk = label_tbl[topk_idx]  # KSIC codes of the top-k classes
+     str_ans_topk = [loc_tbl['항목명'][loc_tbl['코드'] == code] for code in num_ans_topk]  # code -> name rows
+
+     # print(num_ans, str_ans)
+     # print(num_ans_topk)
+
+     # print('사업체명:', query_tokens[0])
+     # print('사업체 하는일:', query_tokens[1])
+     # print('근무부서:', query_tokens[2])
+     # print('직책:', query_tokens[3])
+     # print('내가 하는일:', query_tokens[4])
+     # print('산업코드 및 분류:', num_ans, str_ans)
+
+     # ans = ''
+     # ans1, ans2, ans3 = '', '', ''
+
+     ## Display the model results
+     # st.write("산업코드 및 분류:", num_ans, str_ans[0])
+     # st.write("세분류 코드")
+     # for i in range(k):
+     #     st.write(str(i+1) + '순위:', num_ans_topk[i], str_ans_topk[i].iloc[0])
+
+     # print(num_ans)
+     # print(str_ans, type(str_ans))
+
+     str_ans_topk_list = []
+     for i in range(k):
+         str_ans_topk_list.append(str_ans_topk[i].iloc[0])
+
+     # print(str_ans_topk_list)
+
+     ans_topk_df = pd.DataFrame({
+         'NO': range(1, k+1),
+         '세분류 코드': num_ans_topk,      # sub-class code
+         '세분류 명칭': str_ans_topk_list  # sub-class name
+     })
+     ans_topk_df = ans_topk_df.set_index('NO')
+
+     st.dataframe(ans_topk_df)
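
For reviewers, the sketch below walks the same prediction path outside Streamlit: the five free-text fields are joined with the ' <s> ' separator, encoded with the XLM-R tokenizer, and the top-k logits are mapped to KSIC codes and names through label_table.npy and kisc_table.csv. It is a minimal illustration, not part of app.py: the example field values are made up, and the softmax confidence column is an addition the app does not display. It assumes the same local artifacts the app loads (base3_44_en.bin, label_table.npy, kisc_table.csv).

# Minimal offline sketch of the prediction path committed in app.py (illustration only).
# Assumes base3_44_en.bin, label_table.npy and kisc_table.csv are in the working directory.
import numpy as np
import pandas as pd
import torch
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer

tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
model = XLMRobertaForSequenceClassification.from_pretrained('xlm-roberta-base', num_labels=493)
model.load_state_dict(torch.load('base3_44_en.bin', map_location='cpu'))
model.eval()

label_tbl = np.load('label_table.npy')                     # class index -> KSIC code
loc_tbl = pd.read_csv('kisc_table.csv', encoding='utf-8')  # KSIC code -> name

# Hypothetical example input: company name, main activity, what I do, position, department
fields = ['OO전자', '반도체 제조', '회로 설계', '선임연구원', '연구개발팀']
enc = tokenizer(' <s> '.join(fields), max_length=64, padding='max_length',
                truncation=True, return_tensors='pt')

with torch.no_grad():
    logits = model(**enc).logits                           # shape (1, 493)

probs = torch.softmax(logits, dim=-1).flatten()            # confidence scores (not shown by the app)
top = torch.topk(probs, k=10)
for rank, (p, idx) in enumerate(zip(top.values.tolist(), top.indices.tolist()), start=1):
    code = label_tbl[idx]
    name = loc_tbl.loc[loc_tbl['코드'] == code, '항목명'].iloc[0]
    print(f'{rank:2d}. {code} {name} ({p:.3f})')

The committed app itself is started with streamlit run app.py from a directory containing those three files.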