SixOpen commited on
Commit
a299f2e
·
verified ·
1 Parent(s): b2b5573

Update Scripts/model.py

Browse files
Files changed (1) hide show
  1. Scripts/model.py +1 -1
Scripts/model.py CHANGED
@@ -33,6 +33,6 @@ def create_cam(model):
33
  target_layers = [model.net._blocks[-1]]
34
  targets = [ClassifierOutputTarget(1)]
35
  cam_algorithm = GradCAMElementWise
36
- use_cuda = torch.cuda.is_available() and model.device.type == 'cuda'
37
  cam = cam_algorithm(model=model, target_layers=target_layers, use_cuda=use_cuda)
38
  return cam
 
33
  target_layers = [model.net._blocks[-1]]
34
  targets = [ClassifierOutputTarget(1)]
35
  cam_algorithm = GradCAMElementWise
36
+ use_cuda = torch.cuda.is_available() and next(model.parameters()).is_cuda
37
  cam = cam_algorithm(model=model, target_layers=target_layers, use_cuda=use_cuda)
38
  return cam