Increase plot_labels() speed (#1736)
Browse files- train.py +1 -1
- utils/plots.py +9 -17
train.py
CHANGED
@@ -205,7 +205,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|
205 |
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
|
206 |
# model._initialize_biases(cf.to(device))
|
207 |
if plots:
|
208 |
-
|
209 |
if tb_writer:
|
210 |
tb_writer.add_histogram('classes', c, 0)
|
211 |
|
|
|
205 |
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
|
206 |
# model._initialize_biases(cf.to(device))
|
207 |
if plots:
|
208 |
+
plot_labels(labels, save_dir, loggers)
|
209 |
if tb_writer:
|
210 |
tb_writer.add_histogram('classes', c, 0)
|
211 |
|
utils/plots.py
CHANGED
@@ -11,6 +11,8 @@ import cv2
|
|
11 |
import matplotlib
|
12 |
import matplotlib.pyplot as plt
|
13 |
import numpy as np
|
|
|
|
|
14 |
import torch
|
15 |
import yaml
|
16 |
from PIL import Image, ImageDraw
|
@@ -253,34 +255,24 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx
|
|
253 |
|
254 |
def plot_labels(labels, save_dir=Path(''), loggers=None):
|
255 |
# plot dataset labels
|
|
|
256 |
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
|
257 |
nc = int(c.max() + 1) # number of classes
|
258 |
colors = color_list()
|
|
|
259 |
|
260 |
# seaborn correlogram
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
|
265 |
-
sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
|
266 |
-
plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
|
267 |
-
diag_kws=dict(bins=50))
|
268 |
-
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
|
269 |
-
plt.close()
|
270 |
-
except Exception as e:
|
271 |
-
pass
|
272 |
|
273 |
# matplotlib labels
|
274 |
matplotlib.use('svg') # faster
|
275 |
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
|
276 |
ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
277 |
ax[0].set_xlabel('classes')
|
278 |
-
|
279 |
-
ax[
|
280 |
-
ax[2].set_ylabel('y')
|
281 |
-
ax[3].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
|
282 |
-
ax[3].set_xlabel('width')
|
283 |
-
ax[3].set_ylabel('height')
|
284 |
|
285 |
# rectangles
|
286 |
labels[:, 1:3] = 0.5 # center
|
|
|
11 |
import matplotlib
|
12 |
import matplotlib.pyplot as plt
|
13 |
import numpy as np
|
14 |
+
import pandas as pd
|
15 |
+
import seaborn as sns
|
16 |
import torch
|
17 |
import yaml
|
18 |
from PIL import Image, ImageDraw
|
|
|
255 |
|
256 |
def plot_labels(labels, save_dir=Path(''), loggers=None):
|
257 |
# plot dataset labels
|
258 |
+
print('Plotting labels... ')
|
259 |
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
|
260 |
nc = int(c.max() + 1) # number of classes
|
261 |
colors = color_list()
|
262 |
+
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
|
263 |
|
264 |
# seaborn correlogram
|
265 |
+
sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
|
266 |
+
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
|
267 |
+
plt.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
|
269 |
# matplotlib labels
|
270 |
matplotlib.use('svg') # faster
|
271 |
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
|
272 |
ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
273 |
ax[0].set_xlabel('classes')
|
274 |
+
sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
|
275 |
+
sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
|
|
|
|
|
|
|
|
|
276 |
|
277 |
# rectangles
|
278 |
labels[:, 1:3] = 0.5 # center
|