glenn-jocher commited on
Commit
c4cb785
·
1 Parent(s): 5a9c5c1

add NMS to pretrained pytorch hub models

Browse files
Files changed (2) hide show
  1. hubconf.py +7 -0
  2. models/common.py +14 -0
hubconf.py CHANGED
@@ -10,6 +10,7 @@ import os
10
 
11
  import torch
12
 
 
13
  from models.yolo import Model
14
  from utils.google_utils import attempt_download
15
 
@@ -35,6 +36,12 @@ def create(name, pretrained, channels, classes):
35
  state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict() # to FP32
36
  state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
37
  model.load_state_dict(state_dict, strict=False) # load
 
 
 
 
 
 
38
  return model
39
 
40
  except Exception as e:
 
10
 
11
  import torch
12
 
13
+ from models.common import NMS
14
  from models.yolo import Model
15
  from utils.google_utils import attempt_download
16
 
 
36
  state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict() # to FP32
37
  state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
38
  model.load_state_dict(state_dict, strict=False) # load
39
+
40
+ m = NMS()
41
+ m.f = -1 # from
42
+ m.i = model.model[-1].i + 1 # index
43
+ model.model.add_module(name='%s' % m.i, module=m) # add NMS
44
+ model.eval()
45
  return model
46
 
47
  except Exception as e:
models/common.py CHANGED
@@ -3,6 +3,7 @@ import math
3
 
4
  import torch
5
  import torch.nn as nn
 
6
 
7
 
8
  def autopad(k, p=None): # kernel, padding
@@ -98,6 +99,19 @@ class Concat(nn.Module):
98
  return torch.cat(x, self.d)
99
 
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  class Flatten(nn.Module):
102
  # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
103
  @staticmethod
 
3
 
4
  import torch
5
  import torch.nn as nn
6
+ from utils.general import non_max_suppression
7
 
8
 
9
  def autopad(k, p=None): # kernel, padding
 
99
  return torch.cat(x, self.d)
100
 
101
 
102
+ class NMS(nn.Module):
103
+ # Non-Maximum Suppression (NMS) module
104
+ conf = 0.3 # confidence threshold
105
+ iou = 0.6 # IoU threshold
106
+ classes = None # (optional list) filter by class
107
+
108
+ def __init__(self, dimension=1):
109
+ super(NMS, self).__init__()
110
+
111
+ def forward(self, x):
112
+ return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
113
+
114
+
115
  class Flatten(nn.Module):
116
  # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
117
  @staticmethod