Commit
·
89655a8
1
Parent(s):
c4cb785
Bug fix: .fuse() was unintentionally introducing gradients into the fused layer
Browse files- utils/torch_utils.py +22 -22
utils/torch_utils.py
CHANGED
@@ -104,28 +104,28 @@ def prune(model, amount=0.3):
|
|
104 |
|
105 |
|
106 |
def fuse_conv_and_bn(conv, bn):
|
107 |
-
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
|
130 |
|
131 |
def model_info(model, verbose=False):
|
|
|
104 |
|
105 |
|
106 |
def fuse_conv_and_bn(conv, bn):
|
107 |
+
# Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
108 |
+
|
109 |
+
# init
|
110 |
+
fusedconv = nn.Conv2d(conv.in_channels,
|
111 |
+
conv.out_channels,
|
112 |
+
kernel_size=conv.kernel_size,
|
113 |
+
stride=conv.stride,
|
114 |
+
padding=conv.padding,
|
115 |
+
groups=conv.groups,
|
116 |
+
bias=True).requires_grad_(False).to(conv.weight.device)
|
117 |
+
|
118 |
+
# prepare filters
|
119 |
+
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
120 |
+
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
121 |
+
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
|
122 |
+
|
123 |
+
# prepare spatial bias
|
124 |
+
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
|
125 |
+
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
|
126 |
+
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
|
127 |
+
|
128 |
+
return fusedconv
|
129 |
|
130 |
|
131 |
def model_info(model, verbose=False):
|