Save *.npy features on detect.py `--visualize` (#5701)
Browse files* Add feature map to save npy files
Add feature map to save npy files,export npy files with 32 feature maps per layer.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update plots.py
* Update plots.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update plots.py
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>
- tutorial.ipynb +1 -1
- utils/plots.py +4 -3
tutorial.ipynb
CHANGED
@@ -1104,4 +1104,4 @@
|
|
1104 |
"outputs": []
|
1105 |
}
|
1106 |
]
|
1107 |
-
}
|
|
|
1104 |
"outputs": []
|
1105 |
}
|
1106 |
]
|
1107 |
+
}
|
utils/plots.py
CHANGED
@@ -132,7 +132,7 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
|
|
132 |
if 'Detect' not in module_type:
|
133 |
batch, channels, height, width = x.shape # batch, channels, height, width
|
134 |
if height > 1 and width > 1:
|
135 |
-
f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
|
136 |
|
137 |
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
|
138 |
n = min(n, channels) # number of plots
|
@@ -143,9 +143,10 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
|
|
143 |
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
|
144 |
ax[i].axis('off')
|
145 |
|
146 |
-
print(f'Saving {
|
147 |
-
plt.savefig(
|
148 |
plt.close()
|
|
|
149 |
|
150 |
|
151 |
def hist2d(x, y, n=100):
|
|
|
132 |
if 'Detect' not in module_type:
|
133 |
batch, channels, height, width = x.shape # batch, channels, height, width
|
134 |
if height > 1 and width > 1:
|
135 |
+
f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
|
136 |
|
137 |
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
|
138 |
n = min(n, channels) # number of plots
|
|
|
143 |
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
|
144 |
ax[i].axis('off')
|
145 |
|
146 |
+
print(f'Saving {f}... ({n}/{channels})')
|
147 |
+
plt.savefig(f, dpi=300, bbox_inches='tight')
|
148 |
plt.close()
|
149 |
+
np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save
|
150 |
|
151 |
|
152 |
def hist2d(x, y, n=100):
|