yxNONG commited on
Commit
cdb9bde
·
unverified ·
1 Parent(s): bfd51f6

Unify the check point of single and multi GPU

Browse files

save the model.hyp etc to checkpoint when use multi GPU training

Files changed (1) hide show
  1. train.py +10 -1
train.py CHANGED
@@ -79,7 +79,7 @@ def train(hyp):
79
  # Create model
80
  model = Model(opt.cfg).to(device)
81
  assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
82
- model.names = data_dict['names']
83
 
84
  # Image sizes
85
  gs = int(max(model.stride)) # grid size (max stride)
@@ -172,6 +172,7 @@ def train(hyp):
172
  model.hyp = hyp # attach hyperparameters to model
173
  model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
174
  model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
 
175
 
176
  # Class frequency
177
  labels = np.concatenate(dataset.labels, 0)
@@ -314,6 +315,14 @@ def train(hyp):
314
  # Save model
315
  save = (not opt.nosave) or (final_epoch and not opt.evolve)
316
  if save:
 
 
 
 
 
 
 
 
317
  with open(results_file, 'r') as f: # create checkpoint
318
  ckpt = {'epoch': epoch,
319
  'best_fitness': best_fitness,
 
79
  # Create model
80
  model = Model(opt.cfg).to(device)
81
  assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
82
+
83
 
84
  # Image sizes
85
  gs = int(max(model.stride)) # grid size (max stride)
 
172
  model.hyp = hyp # attach hyperparameters to model
173
  model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
174
  model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
175
+ model.names = data_dict['names']
176
 
177
  # Class frequency
178
  labels = np.concatenate(dataset.labels, 0)
 
315
  # Save model
316
  save = (not opt.nosave) or (final_epoch and not opt.evolve)
317
  if save:
318
+ if hasattr(model, 'module'):
319
+ # Duplicate Model parameters for Multi-GPU save
320
+ ema.ema.module.nc = model.nc # attach number of classes to model
321
+ ema.ema.module.hyp = model.hyp # attach hyperparameters to model
322
+ ema.ema.module.gr = model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
323
+ ema.ema.module.class_weights = model.class_weights # attach class weights
324
+ ema.ema.module.names = data_dict['names']
325
+
326
  with open(results_file, 'r') as f: # create checkpoint
327
  ckpt = {'epoch': epoch,
328
  'best_fitness': best_fitness,