lucytuan commited on
Commit
80ba8d4
·
1 Parent(s): 183312f

:hammer: [Fix] the used functions in module.py

Browse files
Files changed (1) hide show
  1. model/module.py +26 -4
model/module.py CHANGED
@@ -11,10 +11,10 @@ class Conv(nn.Module):
11
  out_channels,
12
  kernel_size,
13
  stride=1,
14
- padding=0,
15
  dilation=1,
16
  groups=1,
17
- act=nn.ReLU(),
18
  bias=False,
19
  auto_padding=True,
20
  padding_mode="zeros",
@@ -48,10 +48,10 @@ class Conv(nn.Module):
48
  # RepVGG
49
  class RepConv(nn.Module):
50
  # https://github.com/DingXiaoH/RepVGG
51
- def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1, act=nn.ReLU()):
52
 
53
  super().__init__()
54
-
55
  self.conv1 = Conv(in_channels, out_channels, kernel_size, stride, groups=groups, act=False)
56
  self.conv2 = Conv(in_channels, out_channels, 1, stride, groups=groups, act=False)
57
  self.act = act if isinstance(act, nn.Module) else nn.Identity()
@@ -64,6 +64,28 @@ class RepConv(nn.Module):
64
 
65
  # to be implement
66
  # def fuse_convs(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
 
69
  # ResNet
 
11
  out_channels,
12
  kernel_size,
13
  stride=1,
14
+ padding=None,
15
  dilation=1,
16
  groups=1,
17
+ act=nn.SiLU(),
18
  bias=False,
19
  auto_padding=True,
20
  padding_mode="zeros",
 
48
  # RepVGG
49
  class RepConv(nn.Module):
50
  # https://github.com/DingXiaoH/RepVGG
51
+ def __init__(self, in_channels, out_channels, kernel_size=3, padding=None, stride=1, groups=1, act=nn.SiLU(), deploy=False):
52
 
53
  super().__init__()
54
+ self.deploy = deploy
55
  self.conv1 = Conv(in_channels, out_channels, kernel_size, stride, groups=groups, act=False)
56
  self.conv2 = Conv(in_channels, out_channels, 1, stride, groups=groups, act=False)
57
  self.act = act if isinstance(act, nn.Module) else nn.Identity()
 
64
 
65
  # to be implement
66
  # def fuse_convs(self):
67
+ def fuse_conv_bn(self, conv, bn):
68
+
69
+ std = (bn.running_var + bn.eps).sqrt()
70
+ bias = bn.bias - bn.running_mean * bn.weight / std
71
+
72
+ t = (bn.weight / std).reshape(-1, 1, 1, 1)
73
+ weights = conv.weight * t
74
+
75
+ bn = nn.Identity()
76
+ conv = nn.Conv2d(in_channels = conv.in_channels,
77
+ out_channels = conv.out_channels,
78
+ kernel_size = conv.kernel_size,
79
+ stride=conv.stride,
80
+ padding = conv.padding,
81
+ dilation = conv.dilation,
82
+ groups = conv.groups,
83
+ bias = True,
84
+ padding_mode = conv.padding_mode)
85
+
86
+ conv.weight = torch.nn.Parameter(weights)
87
+ conv.bias = torch.nn.Parameter(bias)
88
+ return conv
89
 
90
 
91
  # ResNet