astoken commited on
Commit
25e51bc
·
1 Parent(s): 490f1e7

add util function to get most recent last.pt file

Browse files

added logic in train.py __main__ to handle resuming from a run

Files changed (2) hide show
  1. train.py +10 -3
  2. utils/utils.py +6 -0
train.py CHANGED
@@ -198,10 +198,10 @@ def train(hyp):
198
  model.names = data_dict['names']
199
 
200
  #save hyperparamter and training options in run folder
201
- with open(os.path.join(log_dir, 'hyp.yaml', 'w')) as f:
202
  yaml.dump(hyp, f)
203
 
204
- with open(os.path.join(log_dir, 'opt.yaml', 'w')) as f:
205
  yaml.dump(opt, f)
206
 
207
  # Class frequency
@@ -294,7 +294,7 @@ def train(hyp):
294
 
295
  # Plot
296
  if ni < 3:
297
- f = 'train_batch%g.jpg' % i # filename
298
  res = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
299
  if tb_writer:
300
  tb_writer.add_image(f, res, dataformats='HWC', global_step=epoch)
@@ -385,6 +385,7 @@ if __name__ == '__main__':
385
  parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
386
  parser.add_argument('--rect', action='store_true', help='rectangular training')
387
  parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
 
388
  parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
389
  parser.add_argument('--notest', action='store_true', help='only test final epoch')
390
  parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
@@ -398,6 +399,12 @@ if __name__ == '__main__':
398
  parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
399
  parser.add_argument('--hyp', type=str, default='', help ='path to hyp yaml file')
400
  opt = parser.parse_args()
 
 
 
 
 
 
401
  opt.weights = last if opt.resume else opt.weights
402
  opt.cfg = check_file(opt.cfg) # check file
403
  opt.data = check_file(opt.data) # check file
 
198
  model.names = data_dict['names']
199
 
200
  #save hyperparamter and training options in run folder
201
+ with open(os.path.join(log_dir, 'hyp.yaml'), 'w') as f:
202
  yaml.dump(hyp, f)
203
 
204
+ with open(os.path.join(log_dir, 'opt.yaml'), 'w') as f:
205
  yaml.dump(opt, f)
206
 
207
  # Class frequency
 
294
 
295
  # Plot
296
  if ni < 3:
297
+ f = os.path.join(log_dir, 'train_batch%g.jpg' % i) # filename
298
  res = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
299
  if tb_writer:
300
  tb_writer.add_image(f, res, dataformats='HWC', global_step=epoch)
 
385
  parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
386
  parser.add_argument('--rect', action='store_true', help='rectangular training')
387
  parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
388
+ parser.add_argument('--resume_from_run', type=str, default='', 'resume training from last.pt in this dir')
389
  parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
390
  parser.add_argument('--notest', action='store_true', help='only test final epoch')
391
  parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
 
399
  parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
400
  parser.add_argument('--hyp', type=str, default='', help ='path to hyp yaml file')
401
  opt = parser.parse_args()
402
+
403
+ if opt.resume and not opt.resume_from_run:
404
+ last = get_latest_run()
405
+ print(f'WARNING: No run provided to resume from. Resuming from most recent run found at {last}')
406
+ else:
407
+ last = opt.resume_from_run
408
  opt.weights = last if opt.resume else opt.weights
409
  opt.cfg = check_file(opt.cfg) # check file
410
  opt.data = check_file(opt.data) # check file
utils/utils.py CHANGED
@@ -36,6 +36,12 @@ def init_seeds(seed=0):
36
  np.random.seed(seed)
37
  torch_utils.init_seeds(seed=seed)
38
 
 
 
 
 
 
 
39
 
40
  def check_git_status():
41
  # Suggest 'git pull' if repo is out of date
 
36
  np.random.seed(seed)
37
  torch_utils.init_seeds(seed=seed)
38
 
39
+ def get_latest_run(search_dir = './runs/'):
40
+ # get path to most recent 'last.pt' in run dirs
41
+ # assumes most recently saved 'last.pt' is the desired weights to --resume from
42
+ last_list = glob.glob('runs/*/last.pt')
43
+ latest = max(last_list, key = os.path.getctime)
44
+ return latest
45
 
46
  def check_git_status():
47
  # Suggest 'git pull' if repo is out of date