Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -184,8 +184,9 @@ if st.button('ํ์ธ'):
|
|
184 |
|
185 |
# Move logits and labels to CPU
|
186 |
# logits = logits.detach().cpu().numpy()
|
187 |
-
|
188 |
-
|
|
|
189 |
# # ๋จ๋
์์ธก ์
|
190 |
# arg_idx = torch.argmax(logits, dim=1)
|
191 |
# print('arg_idx:', arg_idx)
|
@@ -195,15 +196,15 @@ if st.button('ํ์ธ'):
|
|
195 |
|
196 |
# ์์ k๋ฒ์งธ๊น์ง ์์ธก ์
|
197 |
k = 10
|
198 |
-
topk_idx = torch.topk(
|
199 |
-
topk_values = torch.topk(
|
200 |
|
201 |
|
202 |
num_ans_topk = label_tbl[topk_idx]
|
203 |
str_ans_topk = [loc_tbl['ํญ๋ชฉ๋ช
'][loc_tbl['์ฝ๋'] == k] for k in num_ans_topk]
|
204 |
percent_ans_topk = topk_values.numpy()
|
205 |
|
206 |
-
st.write(sum(torch.topk(
|
207 |
# print(num_ans, str_ans)
|
208 |
# print(num_ans_topk)
|
209 |
|
|
|
184 |
|
185 |
# Move logits and labels to CPU
|
186 |
# logits = logits.detach().cpu().numpy()
|
187 |
+
pred_m = torch.nn.Softmax(dim=1)
|
188 |
+
pred_ = pred_m(logits)
|
189 |
+
# st.write(logits.size())
|
190 |
# # ๋จ๋
์์ธก ์
|
191 |
# arg_idx = torch.argmax(logits, dim=1)
|
192 |
# print('arg_idx:', arg_idx)
|
|
|
196 |
|
197 |
# ์์ k๋ฒ์งธ๊น์ง ์์ธก ์
|
198 |
k = 10
|
199 |
+
topk_idx = torch.topk(pred_.flatten(), k).indices
|
200 |
+
topk_values = torch.topk(pred_.flatten(), k).values
|
201 |
|
202 |
|
203 |
num_ans_topk = label_tbl[topk_idx]
|
204 |
str_ans_topk = [loc_tbl['ํญ๋ชฉ๋ช
'][loc_tbl['์ฝ๋'] == k] for k in num_ans_topk]
|
205 |
percent_ans_topk = topk_values.numpy()
|
206 |
|
207 |
+
st.write(sum(torch.topk(pred_.flatten(), 493).values.numpy()))
|
208 |
# print(num_ans, str_ans)
|
209 |
# print(num_ans_topk)
|
210 |
|