Zengyf-CVer pre-commit-ci[bot] glenn-jocher commited on
Commit
f17c86b
·
unverified ·
1 Parent(s): d6ae1c8

Save *.npy features on detect.py `--visualize` (#5701)

Browse files

* Add feature map to save npy files

Add feature map to save npy files,export npy files with 32 feature maps per layer.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update plots.py

* Update plots.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update plots.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>

Files changed (2) hide show
  1. tutorial.ipynb +1 -1
  2. utils/plots.py +4 -3
tutorial.ipynb CHANGED
@@ -1104,4 +1104,4 @@
1104
  "outputs": []
1105
  }
1106
  ]
1107
- }
 
1104
  "outputs": []
1105
  }
1106
  ]
1107
+ }
utils/plots.py CHANGED
@@ -132,7 +132,7 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
132
  if 'Detect' not in module_type:
133
  batch, channels, height, width = x.shape # batch, channels, height, width
134
  if height > 1 and width > 1:
135
- f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
136
 
137
  blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
138
  n = min(n, channels) # number of plots
@@ -143,9 +143,10 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
143
  ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
144
  ax[i].axis('off')
145
 
146
- print(f'Saving {save_dir / f}... ({n}/{channels})')
147
- plt.savefig(save_dir / f, dpi=300, bbox_inches='tight')
148
  plt.close()
 
149
 
150
 
151
  def hist2d(x, y, n=100):
 
132
  if 'Detect' not in module_type:
133
  batch, channels, height, width = x.shape # batch, channels, height, width
134
  if height > 1 and width > 1:
135
+ f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
136
 
137
  blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
138
  n = min(n, channels) # number of plots
 
143
  ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
144
  ax[i].axis('off')
145
 
146
+ print(f'Saving {f}... ({n}/{channels})')
147
+ plt.savefig(f, dpi=300, bbox_inches='tight')
148
  plt.close()
149
+ np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save
150
 
151
 
152
  def hist2d(x, y, n=100):