Simpler code for DWConvClass (#4310)
Browse files* more simpler code for DWConvClass
more simpler code for DWConvClass
* remove DWConv function
* Replace DWConvClass with DWConv
- models/common.py +2 -8
- models/yolo.py +1 -1
models/common.py
CHANGED
@@ -29,11 +29,6 @@ def autopad(k, p=None): # kernel, padding
|
|
29 |
return p
|
30 |
|
31 |
|
32 |
-
def DWConv(c1, c2, k=1, s=1, act=True):
|
33 |
-
# Depth-wise convolution function
|
34 |
-
return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
|
35 |
-
|
36 |
-
|
37 |
class Conv(nn.Module):
|
38 |
# Standard convolution
|
39 |
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
|
@@ -49,11 +44,10 @@ class Conv(nn.Module):
|
|
49 |
return self.act(self.conv(x))
|
50 |
|
51 |
|
52 |
-
class
|
53 |
# Depth-wise convolution class
|
54 |
def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
|
55 |
-
super().__init__(c1, c2, k, s, act)
|
56 |
-
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k), groups=math.gcd(c1, c2), bias=False)
|
57 |
|
58 |
|
59 |
class TransformerLayer(nn.Module):
|
|
|
29 |
return p
|
30 |
|
31 |
|
|
|
|
|
|
|
|
|
|
|
32 |
class Conv(nn.Module):
|
33 |
# Standard convolution
|
34 |
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
|
|
|
44 |
return self.act(self.conv(x))
|
45 |
|
46 |
|
47 |
+
class DWConv(Conv):
|
48 |
# Depth-wise convolution class
|
49 |
def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
|
50 |
+
super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
|
|
|
51 |
|
52 |
|
53 |
class TransformerLayer(nn.Module):
|
models/yolo.py
CHANGED
@@ -202,7 +202,7 @@ class Model(nn.Module):
|
|
202 |
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
|
203 |
LOGGER.info('Fusing layers... ')
|
204 |
for m in self.model.modules():
|
205 |
-
if isinstance(m, (Conv,
|
206 |
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
207 |
delattr(m, 'bn') # remove batchnorm
|
208 |
m.forward = m.forward_fuse # update forward
|
|
|
202 |
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
|
203 |
LOGGER.info('Fusing layers... ')
|
204 |
for m in self.model.modules():
|
205 |
+
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
|
206 |
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
207 |
delattr(m, 'bn') # remove batchnorm
|
208 |
m.forward = m.forward_fuse # update forward
|