glenn-jocher commited on
Commit
c5360f6
·
unverified ·
1 Parent(s): 4a025ae

Fix `--data from_HUB.zip` (#4732)

Browse files
Files changed (1) hide show
  1. train.py +3 -3
train.py CHANGED
@@ -36,7 +36,7 @@ from utils.autoanchor import check_anchors
36
  from utils.datasets import create_dataloader
37
  from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
38
  strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \
39
- check_yaml, check_suffix, print_mutation, set_logging, one_cycle, colorstr, methods
40
  from utils.downloads import attempt_download
41
  from utils.loss import ComputeLoss
42
  from utils.plots import plot_labels, plot_evolve
@@ -105,6 +105,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
105
  is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset
106
 
107
  # Model
 
108
  pretrained = weights.endswith('.pt')
109
  if pretrained:
110
  with torch_distributed_zero_first(RANK):
@@ -484,8 +485,7 @@ def main(opt, callbacks=Callbacks()):
484
  opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
485
  LOGGER.info(f'Resuming training from {ckpt}')
486
  else:
487
- check_suffix(opt.weights, '.pt') # check weights
488
- opt.data, opt.cfg, opt.hyp = check_yaml(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp) # check YAMLs
489
  assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
490
  if opt.evolve:
491
  opt.project = 'runs/evolve'
 
36
  from utils.datasets import create_dataloader
37
  from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
38
  strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \
39
+ check_file, check_yaml, check_suffix, print_mutation, set_logging, one_cycle, colorstr, methods
40
  from utils.downloads import attempt_download
41
  from utils.loss import ComputeLoss
42
  from utils.plots import plot_labels, plot_evolve
 
105
  is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset
106
 
107
  # Model
108
+ check_suffix(weights, '.pt') # check weights
109
  pretrained = weights.endswith('.pt')
110
  if pretrained:
111
  with torch_distributed_zero_first(RANK):
 
485
  opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
486
  LOGGER.info(f'Resuming training from {ckpt}')
487
  else:
488
+ opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp) # check YAMLs
 
489
  assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
490
  if opt.evolve:
491
  opt.project = 'runs/evolve'