Fix `--data from_HUB.zip` (#4732)
Browse files
train.py
CHANGED
@@ -36,7 +36,7 @@ from utils.autoanchor import check_anchors
|
|
36 |
from utils.datasets import create_dataloader
|
37 |
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
|
38 |
strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \
|
39 |
-
check_yaml, check_suffix, print_mutation, set_logging, one_cycle, colorstr, methods
|
40 |
from utils.downloads import attempt_download
|
41 |
from utils.loss import ComputeLoss
|
42 |
from utils.plots import plot_labels, plot_evolve
|
@@ -105,6 +105,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
105 |
is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset
|
106 |
|
107 |
# Model
|
|
|
108 |
pretrained = weights.endswith('.pt')
|
109 |
if pretrained:
|
110 |
with torch_distributed_zero_first(RANK):
|
@@ -484,8 +485,7 @@ def main(opt, callbacks=Callbacks()):
|
|
484 |
opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
|
485 |
LOGGER.info(f'Resuming training from {ckpt}')
|
486 |
else:
|
487 |
-
|
488 |
-
opt.data, opt.cfg, opt.hyp = check_yaml(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp) # check YAMLs
|
489 |
assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
|
490 |
if opt.evolve:
|
491 |
opt.project = 'runs/evolve'
|
|
|
36 |
from utils.datasets import create_dataloader
|
37 |
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
|
38 |
strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \
|
39 |
+
check_file, check_yaml, check_suffix, print_mutation, set_logging, one_cycle, colorstr, methods
|
40 |
from utils.downloads import attempt_download
|
41 |
from utils.loss import ComputeLoss
|
42 |
from utils.plots import plot_labels, plot_evolve
|
|
|
105 |
is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset
|
106 |
|
107 |
# Model
|
108 |
+
check_suffix(weights, '.pt') # check weights
|
109 |
pretrained = weights.endswith('.pt')
|
110 |
if pretrained:
|
111 |
with torch_distributed_zero_first(RANK):
|
|
|
485 |
opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
|
486 |
LOGGER.info(f'Resuming training from {ckpt}')
|
487 |
else:
|
488 |
+
opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp) # check YAMLs
|
|
|
489 |
assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
|
490 |
if opt.evolve:
|
491 |
opt.project = 'runs/evolve'
|