glenn-jocher commited on
Commit
453acde
·
1 Parent(s): 7f16406

Update tensorboard logging

Browse files
Files changed (3) hide show
  1. test.py +2 -2
  2. train.py +4 -4
  3. utils/general.py +3 -2
test.py CHANGED
@@ -191,9 +191,9 @@ def test(data,
191
 
192
  # Plot images
193
  if plots and batch_i < 1:
194
- f = save_dir / ('test_batch%g_gt.jpg' % batch_i) # filename
195
  plot_images(img, targets, paths, str(f), names) # ground truth
196
- f = save_dir / ('test_batch%g_pred.jpg' % batch_i)
197
  plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions
198
 
199
  # Compute statistics
 
191
 
192
  # Plot images
193
  if plots and batch_i < 1:
194
+ f = save_dir / f'test_batch{batch_i}_gt.jpg' # filename
195
  plot_images(img, targets, paths, str(f), names) # ground truth
196
+ f = save_dir / f'test_batch{batch_i}_pred.jpg'
197
  plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions
198
 
199
  # Compute statistics
train.py CHANGED
@@ -291,11 +291,11 @@ def train(hyp, opt, device, tb_writer=None):
291
 
292
  # Plot
293
  if ni < 3:
294
- f = str(log_dir / ('train_batch%g.jpg' % ni)) # filename
295
  result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
296
- if tb_writer and result is not None:
297
- tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
298
- # tb_writer.add_graph(model, imgs) # add model to tensorboard
299
 
300
  # end batch ------------------------------------------------------------------------------------------------
301
 
 
291
 
292
  # Plot
293
  if ni < 3:
294
+ f = str(log_dir / f'train_batch{ni}.jpg') # filename
295
  result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
296
+ # if tb_writer and result is not None:
297
+ # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
298
+ # tb_writer.add_graph(model, imgs) # add model to tensorboard
299
 
300
  # end batch ------------------------------------------------------------------------------------------------
301
 
utils/general.py CHANGED
@@ -19,6 +19,7 @@ import numpy as np
19
  import torch
20
  import torch.nn as nn
21
  import yaml
 
22
  from scipy.cluster.vq import kmeans
23
  from scipy.signal import butter, filtfilt
24
  from tqdm import tqdm
@@ -1096,8 +1097,8 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
1096
 
1097
  if fname is not None:
1098
  mosaic = cv2.resize(mosaic, (int(ns * w * 0.5), int(ns * h * 0.5)), interpolation=cv2.INTER_AREA)
1099
- cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB))
1100
-
1101
  return mosaic
1102
 
1103
 
 
19
  import torch
20
  import torch.nn as nn
21
  import yaml
22
+ from PIL import Image
23
  from scipy.cluster.vq import kmeans
24
  from scipy.signal import butter, filtfilt
25
  from tqdm import tqdm
 
1097
 
1098
  if fname is not None:
1099
  mosaic = cv2.resize(mosaic, (int(ns * w * 0.5), int(ns * h * 0.5)), interpolation=cv2.INTER_AREA)
1100
+ # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
1101
+ Image.fromarray(mosaic).save(fname) # PIL save
1102
  return mosaic
1103
 
1104