glenn-jocher commited on
Commit
08d4918
·
unverified ·
1 Parent(s): f419721

labels.jpg class names (#2454)

Browse files

* labels.png class names

* fontsize=10

Files changed (2) hide show
  1. train.py +1 -1
  2. utils/plots.py +7 -2
train.py CHANGED
@@ -203,7 +203,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
203
  # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
204
  # model._initialize_biases(cf.to(device))
205
  if plots:
206
- plot_labels(labels, save_dir, loggers)
207
  if tb_writer:
208
  tb_writer.add_histogram('classes', c, 0)
209
 
 
203
  # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
204
  # model._initialize_biases(cf.to(device))
205
  if plots:
206
+ plot_labels(labels, names, save_dir, loggers)
207
  if tb_writer:
208
  tb_writer.add_histogram('classes', c, 0)
209
 
utils/plots.py CHANGED
@@ -269,7 +269,7 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx
269
  plt.savefig(str(Path(path).name) + '.png', dpi=300)
270
 
271
 
272
- def plot_labels(labels, save_dir=Path(''), loggers=None):
273
  # plot dataset labels
274
  print('Plotting labels... ')
275
  c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
@@ -286,7 +286,12 @@ def plot_labels(labels, save_dir=Path(''), loggers=None):
286
  matplotlib.use('svg') # faster
287
  ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
288
  ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
289
- ax[0].set_xlabel('classes')
 
 
 
 
 
290
  sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
291
  sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
292
 
 
269
  plt.savefig(str(Path(path).name) + '.png', dpi=300)
270
 
271
 
272
+ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
273
  # plot dataset labels
274
  print('Plotting labels... ')
275
  c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
 
286
  matplotlib.use('svg') # faster
287
  ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
288
  ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
289
+ ax[0].set_ylabel('instances')
290
+ if 0 < len(names) < 30:
291
+ ax[0].set_xticks(range(len(names)))
292
+ ax[0].set_xticklabels(names, rotation=90, fontsize=10)
293
+ else:
294
+ ax[0].set_xlabel('classes')
295
  sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
296
  sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
297