glenn-jocher commited on
Commit
38f5c1a
·
1 Parent(s): 997ba7b

pruning and sparsity initial commit

Browse files
Files changed (2) hide show
  1. models/yolo.py +1 -1
  2. utils/torch_utils.py +20 -0
models/yolo.py CHANGED
@@ -48,7 +48,7 @@ class Model(nn.Module):
48
  if type(model_cfg) is dict:
49
  self.md = model_cfg # model dict
50
  else: # is *.yaml
51
- import yaml
52
  with open(model_cfg) as f:
53
  self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict
54
 
 
48
  if type(model_cfg) is dict:
49
  self.md = model_cfg # model dict
50
  else: # is *.yaml
51
+ import yaml # for torch hub
52
  with open(model_cfg) as f:
53
  self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict
54
 
utils/torch_utils.py CHANGED
@@ -76,6 +76,26 @@ def find_modules(model, mclass=nn.Conv2d):
76
  return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def fuse_conv_and_bn(conv, bn):
80
  # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
81
  with torch.no_grad():
 
76
  return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
77
 
78
 
79
+ def sparsity(model):
80
+ # Return global model sparsity
81
+ a, b = 0., 0.
82
+ for p in model.parameters():
83
+ a += p.numel()
84
+ b += (p == 0).sum()
85
+ return b / a
86
+
87
+
88
+ def prune(model, amount=0.3):
89
+ # Prune model to requested global sparsity
90
+ import torch.nn.utils.prune as prune
91
+ print('Pruning model... ', end='')
92
+ for name, m in model.named_modules():
93
+ if isinstance(m, torch.nn.Conv2d):
94
+ prune.l1_unstructured(m, name='weight', amount=amount) # prune
95
+ prune.remove(m, 'weight') # make permanent
96
+ print(' %.3g global sparsity' % sparsity(model))
97
+
98
+
99
  def fuse_conv_and_bn(conv, bn):
100
  # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
101
  with torch.no_grad():