Add feature map visualization (#3804)
Browse files* Add feature map visualization
Add a feature_visualization function to visualize the mid feature map of the model.
* Update yolo.py
* remove boolean from forward and reorder if statement
* remove print from forward
* General cleanup
* Indent
* Update plots.py
Co-authored-by: Glenn Jocher <[email protected]>
- models/yolo.py +5 -1
- utils/plots.py +28 -2
models/yolo.py
CHANGED
@@ -17,6 +17,7 @@ from models.common import *
|
|
17 |
from models.experimental import *
|
18 |
from utils.autoanchor import check_anchor_order
|
19 |
from utils.general import make_divisible, check_file, set_logging
|
|
|
20 |
from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
|
21 |
select_device, copy_attr
|
22 |
|
@@ -135,7 +136,7 @@ class Model(nn.Module):
|
|
135 |
y.append(yi)
|
136 |
return torch.cat(y, 1), None # augmented inference, train
|
137 |
|
138 |
-
def forward_once(self, x, profile=False):
|
139 |
y, dt = [], [] # outputs
|
140 |
for m in self.model:
|
141 |
if m.f != -1: # if not from previous layer
|
@@ -153,6 +154,9 @@ class Model(nn.Module):
|
|
153 |
|
154 |
x = m(x) # run
|
155 |
y.append(x if m.i in self.save else None) # save output
|
|
|
|
|
|
|
156 |
|
157 |
if profile:
|
158 |
logger.info('%.1fms total' % sum(dt))
|
|
|
17 |
from models.experimental import *
|
18 |
from utils.autoanchor import check_anchor_order
|
19 |
from utils.general import make_divisible, check_file, set_logging
|
20 |
+
from utils.plots import feature_visualization
|
21 |
from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
|
22 |
select_device, copy_attr
|
23 |
|
|
|
136 |
y.append(yi)
|
137 |
return torch.cat(y, 1), None # augmented inference, train
|
138 |
|
139 |
+
def forward_once(self, x, profile=False, feature_vis=False):
|
140 |
y, dt = [], [] # outputs
|
141 |
for m in self.model:
|
142 |
if m.f != -1: # if not from previous layer
|
|
|
154 |
|
155 |
x = m(x) # run
|
156 |
y.append(x if m.i in self.save else None) # save output
|
157 |
+
|
158 |
+
if feature_vis and m.type == 'models.common.SPP':
|
159 |
+
feature_visualization(x, m.type, m.i)
|
160 |
|
161 |
if profile:
|
162 |
logger.info('%.1fms total' % sum(dt))
|
utils/plots.py
CHANGED
@@ -15,8 +15,9 @@ import seaborn as sn
|
|
15 |
import torch
|
16 |
import yaml
|
17 |
from PIL import Image, ImageDraw, ImageFont
|
|
|
18 |
|
19 |
-
from utils.general import xywh2xyxy, xyxy2xywh
|
20 |
from utils.metrics import fitness
|
21 |
|
22 |
# Settings
|
@@ -299,7 +300,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
|
|
299 |
matplotlib.use('svg') # faster
|
300 |
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
|
301 |
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
302 |
-
# [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
|
303 |
ax[0].set_ylabel('instances')
|
304 |
if 0 < len(names) < 30:
|
305 |
ax[0].set_xticks(range(len(names)))
|
@@ -445,3 +446,28 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
|
|
445 |
|
446 |
ax[1].legend()
|
447 |
fig.savefig(Path(save_dir) / 'results.png', dpi=200)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
import torch
|
16 |
import yaml
|
17 |
from PIL import Image, ImageDraw, ImageFont
|
18 |
+
from torchvision import transforms
|
19 |
|
20 |
+
from utils.general import increment_path, xywh2xyxy, xyxy2xywh
|
21 |
from utils.metrics import fitness
|
22 |
|
23 |
# Settings
|
|
|
300 |
matplotlib.use('svg') # faster
|
301 |
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
|
302 |
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
303 |
+
# [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
|
304 |
ax[0].set_ylabel('instances')
|
305 |
if 0 < len(names) < 30:
|
306 |
ax[0].set_xticks(range(len(names)))
|
|
|
446 |
|
447 |
ax[1].legend()
|
448 |
fig.savefig(Path(save_dir) / 'results.png', dpi=200)
|
449 |
+
|
450 |
+
|
451 |
+
def feature_visualization(features, module_type, module_idx, n=64):
|
452 |
+
"""
|
453 |
+
features: Features to be visualized
|
454 |
+
module_type: Module type
|
455 |
+
module_idx: Module layer index within model
|
456 |
+
n: Maximum number of feature maps to plot
|
457 |
+
"""
|
458 |
+
project, name = 'runs/features', 'exp'
|
459 |
+
save_dir = increment_path(Path(project) / name) # increment run
|
460 |
+
save_dir.mkdir(parents=True, exist_ok=True) # make dir
|
461 |
+
|
462 |
+
plt.figure(tight_layout=True)
|
463 |
+
blocks = torch.chunk(features, features.shape[1], dim=1) # block by channel dimension
|
464 |
+
n = min(n, len(blocks))
|
465 |
+
for i in range(n):
|
466 |
+
feature = transforms.ToPILImage()(blocks[i].squeeze())
|
467 |
+
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
|
468 |
+
ax.axis('off')
|
469 |
+
plt.imshow(feature) # cmap='gray'
|
470 |
+
|
471 |
+
f = f"layer_{module_idx}_{module_type.split('.')[-1]}_features.png"
|
472 |
+
print(f'Saving {save_dir / f}...')
|
473 |
+
plt.savefig(save_dir / f, dpi=300)
|