Spaces:
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -27,7 +27,7 @@ def md_loading():
|
|
27 |
tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
|
28 |
model = XLMRobertaForSequenceClassification.from_pretrained('xlm-roberta-base', num_labels=493)
|
29 |
|
30 |
-
model_checkpoint = '
|
31 |
project_path = './'
|
32 |
output_model_file = os.path.join(project_path, model_checkpoint)
|
33 |
|
@@ -93,17 +93,16 @@ class TVT_Dataset(Dataset):
|
|
93 |
|
94 |
|
95 |
# 텍스트 input 박스
|
96 |
-
business = st.text_input('')
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
business_work = ''
|
103 |
-
work_department = ''
|
104 |
-
work_position = ''
|
105 |
-
what_do_i = ''
|
106 |
|
|
|
|
|
|
|
|
|
107 |
|
108 |
# data 준비
|
109 |
|
@@ -186,7 +185,9 @@ if st.button('확인'):
|
|
186 |
# Move logits and labels to CPU
|
187 |
# logits = logits.detach().cpu().numpy()
|
188 |
|
189 |
-
|
|
|
|
|
190 |
# # 단독 예측 시
|
191 |
# arg_idx = torch.argmax(logits, dim=1)
|
192 |
# print('arg_idx:', arg_idx)
|
@@ -196,11 +197,15 @@ if st.button('확인'):
|
|
196 |
|
197 |
# 상위 k번째까지 예측 시
|
198 |
k = 10
|
199 |
-
topk_idx = torch.topk(
|
|
|
|
|
200 |
|
201 |
num_ans_topk = label_tbl[topk_idx]
|
202 |
str_ans_topk = [loc_tbl['항목명'][loc_tbl['코드'] == k] for k in num_ans_topk]
|
203 |
-
|
|
|
|
|
204 |
# print(num_ans, str_ans)
|
205 |
# print(num_ans_topk)
|
206 |
|
@@ -224,16 +229,24 @@ if st.button('확인'):
|
|
224 |
# print(str_ans, type(str_ans))
|
225 |
|
226 |
str_ans_topk_list = []
|
|
|
227 |
for i in range(k):
|
228 |
str_ans_topk_list.append(str_ans_topk[i].iloc[0])
|
|
|
229 |
|
230 |
# print(str_ans_topk_list)
|
231 |
|
232 |
ans_topk_df = pd.DataFrame({
|
233 |
'NO': range(1, k+1),
|
234 |
'세분류 코드': num_ans_topk,
|
235 |
-
'세분류 명칭': str_ans_topk_list
|
|
|
236 |
})
|
237 |
ans_topk_df = ans_topk_df.set_index('NO')
|
238 |
|
239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
27 |
tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
|
28 |
model = XLMRobertaForSequenceClassification.from_pretrained('xlm-roberta-base', num_labels=493)
|
29 |
|
30 |
+
model_checkpoint = 'en_ko_4mix_proto.bin'
|
31 |
project_path = './'
|
32 |
output_model_file = os.path.join(project_path, model_checkpoint)
|
33 |
|
|
|
93 |
|
94 |
|
95 |
# 텍스트 input 박스
|
96 |
+
business = st.text_input('사업체명')
|
97 |
+
business_work = st.text_input('사업체 하는일')
|
98 |
+
work_department = st.text_input('근무부서')
|
99 |
+
work_position = st.text_input('직책')
|
100 |
+
what_do_i = st.text_input('내가 하는 일')
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
+
# business_work = ''
|
103 |
+
# work_department = ''
|
104 |
+
# work_position = ''
|
105 |
+
# what_do_i = ''
|
106 |
|
107 |
# data 준비
|
108 |
|
|
|
185 |
# Move logits and labels to CPU
|
186 |
# logits = logits.detach().cpu().numpy()
|
187 |
|
188 |
+
pred_m = torch.nn.Softmax(dim=1)
|
189 |
+
pred_ = pred_m(logits)
|
190 |
+
# st.write(logits.size())
|
191 |
# # 단독 예측 시
|
192 |
# arg_idx = torch.argmax(logits, dim=1)
|
193 |
# print('arg_idx:', arg_idx)
|
|
|
197 |
|
198 |
# 상위 k번째까지 예측 시
|
199 |
k = 10
|
200 |
+
topk_idx = torch.topk(pred_.flatten(), k).indices
|
201 |
+
topk_values = torch.topk(pred_.flatten(), k).values
|
202 |
+
|
203 |
|
204 |
num_ans_topk = label_tbl[topk_idx]
|
205 |
str_ans_topk = [loc_tbl['항목명'][loc_tbl['코드'] == k] for k in num_ans_topk]
|
206 |
+
percent_ans_topk = topk_values.numpy()
|
207 |
+
|
208 |
+
st.write(sum(torch.topk(pred_.flatten(), 493).values.numpy()))
|
209 |
# print(num_ans, str_ans)
|
210 |
# print(num_ans_topk)
|
211 |
|
|
|
229 |
# print(str_ans, type(str_ans))
|
230 |
|
231 |
str_ans_topk_list = []
|
232 |
+
percent_ans_topk_list = []
|
233 |
for i in range(k):
|
234 |
str_ans_topk_list.append(str_ans_topk[i].iloc[0])
|
235 |
+
percent_ans_topk_list.append(percent_ans_topk[i]*100)
|
236 |
|
237 |
# print(str_ans_topk_list)
|
238 |
|
239 |
ans_topk_df = pd.DataFrame({
|
240 |
'NO': range(1, k+1),
|
241 |
'세분류 코드': num_ans_topk,
|
242 |
+
'세분류 명칭': str_ans_topk_list,
|
243 |
+
'확률': percent_ans_topk_list
|
244 |
})
|
245 |
ans_topk_df = ans_topk_df.set_index('NO')
|
246 |
|
247 |
+
# ans_topk_df.style.bar(subset='확률', align='left', color='blue')
|
248 |
+
# ans_topk_df['확률'].style.applymap(color='black', font_color='blue')
|
249 |
+
|
250 |
+
# st.dataframe(ans_topk_df)
|
251 |
+
# st.dataframe(ans_topk_df.style.bar(subset='확률', align='left', color='blue'))
|
252 |
+
st.write(ans_topk_df.style.bar(subset='확률', align='left', color='blue'))
|