glenn-jocher committed
Commit 866bc7d (unverified)
Parent(s): 1e8ab3f

Speed profiling improvements (#2648)


* Speed profiling improvements
* Update torch_utils.py
  deepcopy() required to avoid adding elements to the model.
* Update torch_utils.py
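The deepcopy() note reflects a general profiling pattern: time a copy of the model so that anything the profiling step attaches (hooks, fused modules, extra attributes) never leaks back into the original. Below is a minimal sketch of that idea, not the repository's profiling code; the helper name `profile_forward` and the timing loop are illustrative.

```python
import copy
import time

import torch


def profile_forward(model, x, runs=10):
    """Time the forward pass of a deep copy so profiling never mutates the original model."""
    m = copy.deepcopy(model).eval()  # anything added to m (hooks, attributes) stays on the copy
    with torch.no_grad():
        m(x)  # warmup pass
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(runs):
            m(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    return (time.time() - t0) / runs  # mean seconds per forward pass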

Files changed (2)
  1. hubconf.py +4 -3
  2. utils/torch_utils.py +1 -1
hubconf.py CHANGED
@@ -38,9 +38,10 @@ def create(name, pretrained, channels, classes, autoshape):
             fname = f'{name}.pt'  # checkpoint filename
             attempt_download(fname)  # download if not found locally
             ckpt = torch.load(fname, map_location=torch.device('cpu'))  # load
-            state_dict = ckpt['model'].float().state_dict()  # to FP32
-            state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape}  # filter
-            model.load_state_dict(state_dict, strict=False)  # load
+            msd = model.state_dict()  # model state_dict
+            csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
+            csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape}  # filter
+            model.load_state_dict(csd, strict=False)  # load
             if len(ckpt['model'].names) == classes:
                 model.names = ckpt['model'].names  # set class names attribute
             if autoshape:
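The hubconf.py change keeps the same behaviour but reads the model state_dict once (msd) and intersects the checkpoint state_dict (csd) with it by tensor shape, so pretrained weights still load when a different class count changes the head shapes. A standalone sketch of that filtering pattern follows; the two Sequential models are illustrative stand-ins for a checkpointed and a freshly built network.

```python
import torch.nn as nn

old = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 80))  # stand-in for the checkpointed model (80 classes)
new = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 20))  # stand-in for the rebuilt model (20 classes)

msd = new.state_dict()  # model state_dict
csd = old.float().state_dict()  # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if k in msd and msd[k].shape == v.shape}  # keep shape-compatible tensors
new.load_state_dict(csd, strict=False)  # load what matches, skip the rest
print(f'transferred {len(csd)}/{len(msd)} tensors')  # mismatched head tensors are dropped
```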
utils/torch_utils.py CHANGED
@@ -191,7 +191,7 @@ def fuse_conv_and_bn(conv, bn):
     # prepare filters
     w_conv = conv.weight.clone().view(conv.out_channels, -1)
     w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
-    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
+    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

     # prepare spatial bias
     b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
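For context, the changed line is the weight half of conv+BN fusion: in eval mode a BatchNorm layer is a per-channel affine transform, so it can be folded into the preceding convolution's weight and bias to speed up inference. Below is a condensed, self-contained sketch of that fusion under the same math (it mirrors the approach, not the exact fuse_conv_and_bn() in utils/torch_utils.py), followed by a quick numerical check.

```python
import torch
import torch.nn as nn


def fuse_conv_bn(conv, bn):
    """Fold a BatchNorm2d into the preceding Conv2d (valid for inference only)."""
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding, groups=conv.groups,
                      bias=True).requires_grad_(False).to(conv.weight.device)

    # weights: W_fused = diag(gamma / sqrt(running_var + eps)) @ W_conv
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fused.weight.copy_(torch.mm(w_bn, w_conv).view(fused.weight.shape))

    # bias: b_fused = gamma * (b_conv - running_mean) / sqrt(running_var + eps) + beta
    b_conv = torch.zeros(conv.out_channels, device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fused.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
    return fused


# quick check: the fused layer reproduces conv -> bn in eval mode
conv, bn = nn.Conv2d(3, 8, 3, bias=False), nn.BatchNorm2d(8)
bn(conv(torch.randn(4, 3, 16, 16)))  # one training-mode pass to populate running statistics
bn.eval()
x = torch.randn(1, 3, 16, 16)
print(torch.allclose(bn(conv(x)), fuse_conv_bn(conv, bn)(x), atol=1e-5))
```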