Commit
·
8fe299f
1
Parent(s):
c672bef
model fuse
Browse files- utils/torch_utils.py +1 -1
utils/torch_utils.py
CHANGED
@@ -90,7 +90,7 @@ def fuse_conv_and_bn(conv, bn):
|
|
90 |
if conv.bias is not None:
|
91 |
b_conv = conv.bias
|
92 |
else:
|
93 |
-
b_conv = torch.zeros(conv.weight.size(0))
|
94 |
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
|
95 |
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
|
96 |
|
|
|
90 |
if conv.bias is not None:
|
91 |
b_conv = conv.bias
|
92 |
else:
|
93 |
+
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
|
94 |
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
|
95 |
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
|
96 |
|