astoken commited on
Commit
9b7386f
·
1 Parent(s): 945307b

Add save_dir arg to test.test, use arg as location for saving batch jpgs

Browse files
Files changed (2) hide show
  1. test.py +5 -4
  2. train.py +2 -1
test.py CHANGED
@@ -20,7 +20,8 @@ def test(data,
20
  model=None,
21
  dataloader=None,
22
  fast=False,
23
- verbose=False):
 
24
  # Initialize/load model and set device
25
  if model is None:
26
  training = False
@@ -28,7 +29,7 @@ def test(data,
28
  half = device.type != 'cpu' # half precision only supported on CUDA
29
 
30
  # Remove previous
31
- for f in glob.glob('test_batch*.jpg'):
32
  os.remove(f)
33
 
34
  # Load model
@@ -177,9 +178,9 @@ def test(data,
177
 
178
  # Plot images
179
  if batch_i < 1:
180
- f = 'test_batch%g_gt.jpg' % batch_i # filename
181
  plot_images(img, targets, paths, f, names) # ground truth
182
- f = 'test_batch%g_pred.jpg' % batch_i
183
  plot_images(img, output_to_target(output, width, height), paths, f, names) # predictions
184
 
185
  # Compute statistics
 
20
  model=None,
21
  dataloader=None,
22
  fast=False,
23
+ verbose=False,
24
+ save_dir='.'):
25
  # Initialize/load model and set device
26
  if model is None:
27
  training = False
 
29
  half = device.type != 'cpu' # half precision only supported on CUDA
30
 
31
  # Remove previous
32
+ for f in glob.glob(f'{save_dir}/test_batch*.jpg'):
33
  os.remove(f)
34
 
35
  # Load model
 
178
 
179
  # Plot images
180
  if batch_i < 1:
181
+ f = os.path.join(save_dir, 'test_batch%g_gt.jpg' % batch_i) # filename
182
  plot_images(img, targets, paths, f, names) # ground truth
183
+ f = os.path.join(save_dir,'test_batch%g_pred.jpg' % batch_i)
184
  plot_images(img, output_to_target(output, width, height), paths, f, names) # predictions
185
 
186
  # Compute statistics
train.py CHANGED
@@ -303,7 +303,8 @@ def train(hyp):
303
  model=ema.ema,
304
  single_cls=opt.single_cls,
305
  dataloader=testloader,
306
- fast=epoch < epochs / 2)
 
307
 
308
  # Write
309
  with open(results_file, 'a') as f:
 
303
  model=ema.ema,
304
  single_cls=opt.single_cls,
305
  dataloader=testloader,
306
+ fast=epoch < epochs / 2
307
+ save_dir=log_dir)
308
 
309
  # Write
310
  with open(results_file, 'a') as f: