Speed profiling improvements (#2648)
Browse files
* Speed profiling improvements
* Update torch_utils.py
deepcopy() required to avoid adding elements to model.
* Update torch_utils.py
- hubconf.py +4 -3
- utils/torch_utils.py +1 -1
hubconf.py
CHANGED
@@ -38,9 +38,10 @@ def create(name, pretrained, channels, classes, autoshape):
|
|
38 |
fname = f'{name}.pt' # checkpoint filename
|
39 |
attempt_download(fname) # download if not found locally
|
40 |
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
44 |
if len(ckpt['model'].names) == classes:
|
45 |
model.names = ckpt['model'].names # set class names attribute
|
46 |
if autoshape:
|
|
|
38 |
fname = f'{name}.pt' # checkpoint filename
|
39 |
attempt_download(fname) # download if not found locally
|
40 |
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
|
41 |
+
msd = model.state_dict() # model state_dict
|
42 |
+
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
|
43 |
+
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
|
44 |
+
model.load_state_dict(csd, strict=False) # load
|
45 |
if len(ckpt['model'].names) == classes:
|
46 |
model.names = ckpt['model'].names # set class names attribute
|
47 |
if autoshape:
|
utils/torch_utils.py
CHANGED
@@ -191,7 +191,7 @@ def fuse_conv_and_bn(conv, bn):
|
|
191 |
# prepare filters
|
192 |
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
193 |
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
194 |
-
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
|
195 |
|
196 |
# prepare spatial bias
|
197 |
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
|
|
|
191 |
# prepare filters
|
192 |
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
193 |
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
194 |
+
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
|
195 |
|
196 |
# prepare spatial bias
|
197 |
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
|