glenn-jocher commited on
Commit
4ffd977
·
1 Parent(s): 1e95337

plotting improvements (#471)

Browse files

Signed-off-by: Glenn Jocher <[email protected]>

Files changed (1) hide show
  1. utils/utils.py +8 -6
utils/utils.py CHANGED
@@ -958,13 +958,14 @@ def plot_wh_methods(): # from utils.utils import *; plot_wh_methods()
958
  yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
959
 
960
  fig = plt.figure(figsize=(6, 3), dpi=150)
961
- plt.plot(x, ya, '.-', label='yolo method')
962
- plt.plot(x, yb ** 2, '.-', label='^2 power method')
963
- plt.plot(x, yb ** 2.5, '.-', label='^2.5 power method')
964
  plt.xlim(left=-4, right=4)
965
  plt.ylim(bottom=0, top=6)
966
  plt.xlabel('input')
967
  plt.ylabel('output')
 
968
  plt.legend()
969
  fig.tight_layout()
970
  fig.savefig('comparison.png', dpi=200)
@@ -1134,8 +1135,6 @@ def plot_study_txt(f='study.txt', x=None): # from utils.utils import *; plot_st
1134
 
1135
  def plot_labels(labels, save_dir=''):
1136
  # plot dataset labels
1137
- c, b = labels[:, 0], labels[:, 1:].transpose() # classees, boxes
1138
-
1139
  def hist2d(x, y, n=100):
1140
  xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
1141
  hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
@@ -1143,9 +1142,12 @@ def plot_labels(labels, save_dir=''):
1143
  yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
1144
  return np.log(hist[xidx, yidx])
1145
 
 
 
 
1146
  fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
1147
  ax = ax.ravel()
1148
- ax[0].hist(c, bins=int(c.max() + 1))
1149
  ax[0].set_xlabel('classes')
1150
  ax[1].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
1151
  ax[1].set_xlabel('x')
 
958
  yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2
959
 
960
  fig = plt.figure(figsize=(6, 3), dpi=150)
961
+ plt.plot(x, ya, '.-', label='YOLOv3')
962
+ plt.plot(x, yb ** 2, '.-', label='YOLOv5 ^2')
963
+ plt.plot(x, yb ** 1.6, '.-', label='YOLOv5 ^1.6')
964
  plt.xlim(left=-4, right=4)
965
  plt.ylim(bottom=0, top=6)
966
  plt.xlabel('input')
967
  plt.ylabel('output')
968
+ plt.grid()
969
  plt.legend()
970
  fig.tight_layout()
971
  fig.savefig('comparison.png', dpi=200)
 
1135
 
1136
  def plot_labels(labels, save_dir=''):
1137
  # plot dataset labels
 
 
1138
  def hist2d(x, y, n=100):
1139
  xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
1140
  hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
 
1142
  yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
1143
  return np.log(hist[xidx, yidx])
1144
 
1145
+ c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
1146
+ nc = int(c.max() + 1) # number of classes
1147
+
1148
  fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
1149
  ax = ax.ravel()
1150
+ ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
1151
  ax[0].set_xlabel('classes')
1152
  ax[1].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
1153
  ax[1].set_xlabel('x')