Commit
·
4fce009
1
Parent(s):
2f77cf3
model.add_nms() method
Browse files- hubconf.py +1 -4
- models/yolo.py +10 -1
hubconf.py
CHANGED
@@ -37,10 +37,7 @@ def create(name, pretrained, channels, classes):
|
|
37 |
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
|
38 |
model.load_state_dict(state_dict, strict=False) # load
|
39 |
|
40 |
-
|
41 |
-
m.f = -1 # from
|
42 |
-
m.i = model.model[-1].i + 1 # index
|
43 |
-
model.model.add_module(name='%s' % m.i, module=m) # add NMS
|
44 |
model.eval()
|
45 |
return model
|
46 |
|
|
|
37 |
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
|
38 |
model.load_state_dict(state_dict, strict=False) # load
|
39 |
|
40 |
+
model.add_nms() # add NMS module
|
|
|
|
|
|
|
41 |
model.eval()
|
42 |
return model
|
43 |
|
models/yolo.py
CHANGED
@@ -7,7 +7,7 @@ from pathlib import Path
|
|
7 |
import torch
|
8 |
import torch.nn as nn
|
9 |
|
10 |
-
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat
|
11 |
from models.experimental import MixConv2d, CrossConv, C3
|
12 |
from utils.general import check_anchor_order, make_divisible, check_file, set_logging
|
13 |
from utils.torch_utils import (
|
@@ -168,6 +168,15 @@ class Model(nn.Module):
|
|
168 |
self.info()
|
169 |
return self
|
170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
def info(self, verbose=False): # print model information
|
172 |
model_info(self, verbose)
|
173 |
|
|
|
7 |
import torch
|
8 |
import torch.nn as nn
|
9 |
|
10 |
+
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS
|
11 |
from models.experimental import MixConv2d, CrossConv, C3
|
12 |
from utils.general import check_anchor_order, make_divisible, check_file, set_logging
|
13 |
from utils.torch_utils import (
|
|
|
168 |
self.info()
|
169 |
return self
|
170 |
|
171 |
+
def add_nms(self): # fuse model Conv2d() + BatchNorm2d() layers
|
172 |
+
if type(self.model[-1]) is not NMS: # if missing NMS
|
173 |
+
print('Adding NMS module... ')
|
174 |
+
m = NMS() # module
|
175 |
+
m.f = -1 # from
|
176 |
+
m.i = self.model[-1].i + 1 # index
|
177 |
+
self.model.add_module(name='%s' % m.i, module=m) # add
|
178 |
+
return self
|
179 |
+
|
180 |
def info(self, verbose=False): # print model information
|
181 |
model_info(self, verbose)
|
182 |
|