astoken commited on
Commit
333f678
·
1 Parent(s): a448c3b

add update default hyp dict with provided yaml

Browse files
Files changed (1) hide show
  1. train.py +10 -12
train.py CHANGED
@@ -42,17 +42,6 @@ hyp = {'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
42
  # Don't need to be printing every time
43
  #print(hyp)
44
 
45
- # Overwrite hyp with hyp*.txt (optional)
46
- if f:
47
- print('Using %s' % f[0])
48
- for k, v in zip(hyp.keys(), np.loadtxt(f[0])):
49
- hyp[k] = v
50
-
51
- # Print focal loss if gamma > 0
52
- if hyp['fl_gamma']:
53
- print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma'])
54
-
55
-
56
  def train(hyp):
57
  #write all results to the tb log_dir, so all data from one run is together
58
  log_dir = tb_writer.log_dir
@@ -410,7 +399,7 @@ if __name__ == '__main__':
410
  print(f'WARNING: No run provided to resume from. Resuming from most recent run found at {last}')
411
  else:
412
  last = ''
413
-
414
  # if resuming, check for hyp file
415
  if last:
416
  last_hyp = last.replace('last.pt', 'hyp.yaml')
@@ -430,7 +419,16 @@ if __name__ == '__main__':
430
  # Train
431
  if not opt.evolve:
432
  tb_writer = SummaryWriter(comment=opt.name)
 
 
 
 
 
 
 
 
433
  print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
 
434
  train(hyp)
435
 
436
  # Evolve hyperparameters (optional)
 
42
  # Don't need to be printing every time
43
  #print(hyp)
44
 
 
 
 
 
 
 
 
 
 
 
 
45
  def train(hyp):
46
  #write all results to the tb log_dir, so all data from one run is together
47
  log_dir = tb_writer.log_dir
 
399
  print(f'WARNING: No run provided to resume from. Resuming from most recent run found at {last}')
400
  else:
401
  last = ''
402
+
403
  # if resuming, check for hyp file
404
  if last:
405
  last_hyp = last.replace('last.pt', 'hyp.yaml')
 
419
  # Train
420
  if not opt.evolve:
421
  tb_writer = SummaryWriter(comment=opt.name)
422
+
423
+ #updates hyp defaults from hyp.yaml
424
+ if opt.hyp: hyp.update(opt.hyp)
425
+
426
+ # Print focal loss if gamma > 0
427
+ if hyp['fl_gamma']:
428
+ print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma'])
429
+ print(f'Beginning training with {hyp}\n\n')
430
  print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
431
+
432
  train(hyp)
433
 
434
  # Evolve hyperparameters (optional)