Update gradcam.py
Browse files- gradcam.py +1 -1
gradcam.py
CHANGED
@@ -71,7 +71,7 @@ class GradCam():
|
|
71 |
return np.hstack(results)
|
72 |
|
73 |
|
74 |
-
def
|
75 |
logits = model(img_tensor.unsqueeze(0)).logits
|
76 |
probabilities = torch.nn.functional.softmax(logits, dim=1)
|
77 |
topIdx = logits.cpu()[0, :].detach().numpy().argsort()[-1]
|
|
|
71 |
return np.hstack(results)
|
72 |
|
73 |
|
74 |
+
def get_top_category(self, model, img_tensor, top_k=5):
|
75 |
logits = model(img_tensor.unsqueeze(0)).logits
|
76 |
probabilities = torch.nn.functional.softmax(logits, dim=1)
|
77 |
topIdx = logits.cpu()[0, :].detach().numpy().argsort()[-1]
|