SixOpen commited on
Commit
b2b5573
·
verified ·
1 Parent(s): 4576e40

Update Scripts/model.py

Browse files
Files changed (1) hide show
  1. Scripts/model.py +14 -10
Scripts/model.py CHANGED
@@ -4,24 +4,27 @@ from efficientnet_pytorch import EfficientNet
4
  from pytorch_grad_cam import GradCAMElementWise
5
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
6
 
7
-
8
  class Detector(nn.Module):
9
  def __init__(self):
10
  super(Detector, self).__init__()
11
- self.net=EfficientNet.from_pretrained("efficientnet-b4",advprop=True,num_classes=2)
12
 
13
- def forward(self,x):
14
- x=self.net(x)
15
  return x
16
-
17
 
18
  def create_model(path="Weights/94_0.9485_val.tar", device=torch.device('cpu')):
19
- model=Detector()
20
- model=model.to(device)
 
 
 
 
 
21
  if device == torch.device('cpu'):
22
- cnn_sd=torch.load(path, map_location=torch.device('cpu') )["model"]
23
  else:
24
- cnn_sd=torch.load(path)["model"]
25
  model.load_state_dict(cnn_sd)
26
  model.eval()
27
  return model
@@ -30,5 +33,6 @@ def create_cam(model):
30
  target_layers = [model.net._blocks[-1]]
31
  targets = [ClassifierOutputTarget(1)]
32
  cam_algorithm = GradCAMElementWise
33
- cam = cam_algorithm(model=model,target_layers=target_layers,use_cuda=False)
 
34
  return cam
 
4
  from pytorch_grad_cam import GradCAMElementWise
5
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
6
 
 
7
  class Detector(nn.Module):
8
  def __init__(self):
9
  super(Detector, self).__init__()
10
+ self.net = EfficientNet.from_pretrained("efficientnet-b4", advprop=True, num_classes=2)
11
 
12
+ def forward(self, x):
13
+ x = self.net(x)
14
  return x
 
15
 
16
  def create_model(path="Weights/94_0.9485_val.tar", device=torch.device('cpu')):
17
+ model = Detector()
18
+ try:
19
+ if device.type == 'cuda':
20
+ model = model.half()
21
+ except:
22
+ model = model.float()
23
+ model = model.to(device)
24
  if device == torch.device('cpu'):
25
+ cnn_sd = torch.load(path, map_location=torch.device('cpu'))["model"]
26
  else:
27
+ cnn_sd = torch.load(path)["model"]
28
  model.load_state_dict(cnn_sd)
29
  model.eval()
30
  return model
 
33
  target_layers = [model.net._blocks[-1]]
34
  targets = [ClassifierOutputTarget(1)]
35
  cam_algorithm = GradCAMElementWise
36
+ use_cuda = torch.cuda.is_available() and model.device.type == 'cuda'
37
+ cam = cam_algorithm(model=model, target_layers=target_layers, use_cuda=use_cuda)
38
  return cam