iSpr commited on
Commit
3543f19
ยท
1 Parent(s): 65ae22a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -184,8 +184,9 @@ if st.button('ํ™•์ธ'):
184
 
185
  # Move logits and labels to CPU
186
  # logits = logits.detach().cpu().numpy()
187
-
188
- st.write(logits.size())
 
189
  # # ๋‹จ๋… ์˜ˆ์ธก ์‹œ
190
  # arg_idx = torch.argmax(logits, dim=1)
191
  # print('arg_idx:', arg_idx)
@@ -195,15 +196,15 @@ if st.button('ํ™•์ธ'):
195
 
196
  # ์ƒ์œ„ k๋ฒˆ์งธ๊นŒ์ง€ ์˜ˆ์ธก ์‹œ
197
  k = 10
198
- topk_idx = torch.topk(logits.flatten(), k).indices
199
- topk_values = torch.topk(logits.flatten(), k).values
200
 
201
 
202
  num_ans_topk = label_tbl[topk_idx]
203
  str_ans_topk = [loc_tbl['ํ•ญ๋ชฉ๋ช…'][loc_tbl['์ฝ”๋“œ'] == k] for k in num_ans_topk]
204
  percent_ans_topk = topk_values.numpy()
205
 
206
- st.write(sum(torch.topk(logits.flatten(), 493).values.numpy()))
207
  # print(num_ans, str_ans)
208
  # print(num_ans_topk)
209
 
 
184
 
185
  # Move logits and labels to CPU
186
  # logits = logits.detach().cpu().numpy()
187
+ pred_m = torch.nn.Softmax(dim=1)
188
+ pred_ = pred_m(logits)
189
+ # st.write(logits.size())
190
  # # ๋‹จ๋… ์˜ˆ์ธก ์‹œ
191
  # arg_idx = torch.argmax(logits, dim=1)
192
  # print('arg_idx:', arg_idx)
 
196
 
197
  # ์ƒ์œ„ k๋ฒˆ์งธ๊นŒ์ง€ ์˜ˆ์ธก ์‹œ
198
  k = 10
199
+ topk_idx = torch.topk(pred_.flatten(), k).indices
200
+ topk_values = torch.topk(pred_.flatten(), k).values
201
 
202
 
203
  num_ans_topk = label_tbl[topk_idx]
204
  str_ans_topk = [loc_tbl['ํ•ญ๋ชฉ๋ช…'][loc_tbl['์ฝ”๋“œ'] == k] for k in num_ans_topk]
205
  percent_ans_topk = topk_values.numpy()
206
 
207
+ st.write(sum(torch.topk(pred_.flatten(), 493).values.numpy()))
208
  # print(num_ans, str_ans)
209
  # print(num_ans_topk)
210