Eliminate `total_batch_size` variable (#3697)
* Eliminate `total_batch_size` variable
* cleanup
* Update train.py
train.py CHANGED

@@ -46,10 +46,11 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
           ):
-    save_dir, epochs, batch_size, total_batch_size, weights, single_cls = \
-        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.single_cls
+    save_dir, epochs, batch_size, weights, single_cls = \
+        opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls
 
     # Directories
+    save_dir = Path(save_dir)
     wdir = save_dir / 'weights'
     wdir.mkdir(parents=True, exist_ok=True)  # make dir
     last = wdir / 'last.pt'

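`train()` now receives only the global `--batch-size` and derives any per-GPU value where it is needed; `save_dir` is unpacked as a plain string and converted to a `Path` next to the directory setup. A minimal sketch of that unpacking, with hypothetical option values:

```python
# Minimal sketch of the unpacking above; the Namespace values are hypothetical.
import argparse
from pathlib import Path

opt = argparse.Namespace(save_dir='runs/train/exp', epochs=300, batch_size=64,
                         weights='yolov5s.pt', single_cls=False)
save_dir, epochs, batch_size, weights, single_cls = \
    opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls
save_dir = Path(save_dir)                # convert once, where the directories are built
print(save_dir / 'weights' / 'last.pt')  # runs/train/exp/weights/last.pt
```
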
@@ -127,8 +128,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
 
     # Optimizer
     nbs = 64  # nominal batch size
-    accumulate = max(round(nbs / total_batch_size), 1)  # accumulate loss before optimizing
-    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay
+    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
+    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
     logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
 
     pg0, pg1, pg2 = [], [], []  # optimizer parameter groups

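Since `batch_size` is now the global batch size (what `total_batch_size` used to hold), the accumulation math is unchanged in effect: gradients are accumulated until roughly `nbs = 64` images have contributed before each optimizer step, and weight decay is scaled so its effective strength matches the nominal batch. A small numeric sketch, with a hypothetical `--batch-size`:

```python
# Illustrative only: how accumulate and weight_decay scale with a hypothetical --batch-size.
nbs = 64                                      # nominal batch size the hyperparameters target
batch_size = 16                               # hypothetical global --batch-size
accumulate = max(round(nbs / batch_size), 1)  # 4 -> optimizer steps every 4 batches
weight_decay = 0.0005 * batch_size * accumulate / nbs
print(accumulate, weight_decay)               # 4 0.0005  (effective decay preserved)
```
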
@@ -205,7 +206,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
         logger.info('Using SyncBatchNorm()')
 
     # Trainloader
-    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls,
+    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
                                             hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
                                             workers=opt.workers,
                                             image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))

@@ -215,7 +216,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
 
     # Process 0
     if RANK in [-1, 0]:
-        testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls,
+        testloader = create_dataloader(test_path, imgsz_test, batch_size // WORLD_SIZE * 2, gs, single_cls,
                                        hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
                                        workers=opt.workers,
                                        pad=0.5, prefix=colorstr('val: '))[0]

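The two hunks above are the core of the change: `batch_size` stays global, and each consumer derives what it needs. The train loader on every rank gets `batch_size // WORLD_SIZE` images per step, and the rank-0 validation loader gets twice that. A minimal sketch with hypothetical numbers:

```python
# Hypothetical numbers: a global --batch-size of 64 split across 4 DDP processes.
batch_size, WORLD_SIZE = 64, 4
assert batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
train_bs = batch_size // WORLD_SIZE    # 16 images per GPU per iteration
val_bs = batch_size // WORLD_SIZE * 2  # 32 images, rank 0 only
print(train_bs, val_bs)                # 16 32
```
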
@@ -302,7 +303,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
         if ni <= nw:
             xi = [0, nw]  # x interp
             # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
-            accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
+            accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
             for j, x in enumerate(optimizer.param_groups):
                 # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                 x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])

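During warmup, `accumulate` is ramped from 1 up to `nbs / batch_size` as the integrated batch counter `ni` moves from 0 to `nw`, so early updates happen every batch and gradually settle at the nominal accumulation rate. A small sketch of that interpolation, with hypothetical warmup length and batch size:

```python
# Illustrative warmup ramp of `accumulate`; nw and batch_size are hypothetical.
import numpy as np

nbs, batch_size = 64, 16
nw = 1000                       # hypothetical number of warmup iterations
xi = [0, nw]
for ni in (0, 250, 750, 1000):  # integrated batches seen so far
    accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
    print(ni, int(accumulate))  # ramps 1 -> 2 -> 3 -> 4
```
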
@@ -371,7 +372,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
             if not opt.notest or final_epoch:  # Calculate mAP
                 wandb_logger.current_epoch = epoch + 1
                 results, maps, _ = test.test(data_dict,
-                                             batch_size=batch_size * 2,
+                                             batch_size=batch_size // WORLD_SIZE * 2,
                                              imgsz=imgsz_test,
                                              model=ema.ema,
                                              single_cls=single_cls,

@@ -439,7 +440,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
         if is_coco:  # COCO dataset
             for m in [last, best] if best.exists() else [last]:  # speed, mAP tests
                 results, _, _ = test.test(opt.data,
-                                          batch_size=batch_size * 2,
+                                          batch_size=batch_size // WORLD_SIZE * 2,
                                           imgsz=imgsz_test,
                                           conf_thres=0.001,
                                           iou_thres=0.7,

@@ -518,7 +519,7 @@ def main(opt):
         assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
         with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
             opt = argparse.Namespace(**yaml.safe_load(f))  # replace
-        opt.cfg, opt.weights, opt.resume, opt.batch_size = '', ckpt, True, opt.total_batch_size  # reinstate
+        opt.cfg, opt.weights, opt.resume = '', ckpt, True  # reinstate
         logger.info('Resuming training from %s' % ckpt)
     else:
         # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')

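On `--resume`, the saved run's `opt.yaml` replaces the whole `Namespace` and only the fields that must differ (`cfg`, `weights`, `resume`) are reinstated; there is no longer a stashed batch size to patch back in. A minimal sketch of the pattern, with hypothetical paths:

```python
# Sketch of the resume pattern; the run directory and checkpoint path are hypothetical.
import argparse
import yaml

ckpt = 'runs/train/exp/weights/last.pt'
with open('runs/train/exp/opt.yaml') as f:
    opt = argparse.Namespace(**yaml.safe_load(f))  # replace all options with the saved ones
opt.cfg, opt.weights, opt.resume = '', ckpt, True  # reinstate resume-specific fields
```
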
@@ -529,17 +530,15 @@ def main(opt):
         opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve))
 
     # DDP mode
-    opt.total_batch_size = opt.batch_size
     device = select_device(opt.device, batch_size=opt.batch_size)
     if LOCAL_RANK != -1:
         from datetime import timedelta
-        assert torch.cuda.device_count() > LOCAL_RANK, '
+        assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
         torch.cuda.set_device(LOCAL_RANK)
         device = torch.device('cuda', LOCAL_RANK)
         dist.init_process_group(backend="gloo", timeout=timedelta(seconds=60))
         assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
         assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
-        opt.batch_size = opt.total_batch_size // WORLD_SIZE
 
     # Train
     if not opt.evolve:

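With the swap removed, `opt.batch_size` keeps its global meaning on every rank; each process reads its rank from launcher-set environment variables and divides by `WORLD_SIZE` only where a per-GPU count is needed. A hedged sketch of how those globals are typically derived (the 4-GPU launch command is hypothetical):

```python
# Sketch: DDP globals derived from env vars set by the launcher, e.g.
#   python -m torch.distributed.run --nproc_per_node 4 train.py --batch-size 64
# (torch.distributed.run ships with PyTorch >= 1.9; the GPU count is hypothetical.)
import os

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # GPU index on this node, -1 = no DDP
RANK = int(os.getenv('RANK', -1))              # global process rank
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))   # total number of DDP processes
print(LOCAL_RANK, RANK, WORLD_SIZE)            # -1 -1 1 when launched without DDP
```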