Add `autobatch` feature for best `batch-size` estimation (#5092)
Browse files* Autobatch
* fix mem
* fix mem2
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Update train.py
* print result
* Cleanup print result
* swap fix in call
* to 64
* use total
* fix
* fix
* fix
* fix
* fix
* Update
* Update
* Update
* Update
* Update
* Update
* Update
* Cleanup printing
* Update final printout
* Update autobatch.py
* Update autobatch.py
* Update autobatch.py
- train.py +11 -6
- utils/autobatch.py +56 -0
- utils/torch_utils.py +1 -1
train.py
CHANGED
@@ -36,6 +36,7 @@ import val # for end-of-epoch mAP
|
|
36 |
from models.experimental import attempt_load
|
37 |
from models.yolo import Model
|
38 |
from utils.autoanchor import check_anchors
|
|
|
39 |
from utils.datasets import create_dataloader
|
40 |
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
|
41 |
strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \
|
@@ -131,6 +132,14 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
131 |
print(f'freezing {k}')
|
132 |
v.requires_grad = False
|
133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
# Optimizer
|
135 |
nbs = 64 # nominal batch size
|
136 |
accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
|
@@ -190,11 +199,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
190 |
|
191 |
del ckpt, csd
|
192 |
|
193 |
-
# Image sizes
|
194 |
-
gs = max(int(model.stride.max()), 32) # grid size (max stride)
|
195 |
-
nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
|
196 |
-
imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
|
197 |
-
|
198 |
# DP mode
|
199 |
if cuda and RANK == -1 and torch.cuda.device_count() > 1:
|
200 |
logging.warning('DP not recommended, instead use torch.distributed.run for best DDP Multi-GPU results.\n'
|
@@ -242,6 +246,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
242 |
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
243 |
|
244 |
# Model parameters
|
|
|
245 |
hyp['box'] *= 3. / nl # scale to layers
|
246 |
hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
|
247 |
hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
|
@@ -440,7 +445,7 @@ def parse_opt(known=False):
|
|
440 |
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
|
441 |
parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch.yaml', help='hyperparameters path')
|
442 |
parser.add_argument('--epochs', type=int, default=300)
|
443 |
-
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
|
444 |
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
|
445 |
parser.add_argument('--rect', action='store_true', help='rectangular training')
|
446 |
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
|
|
|
36 |
from models.experimental import attempt_load
|
37 |
from models.yolo import Model
|
38 |
from utils.autoanchor import check_anchors
|
39 |
+
from utils.autobatch import check_train_batch_size
|
40 |
from utils.datasets import create_dataloader
|
41 |
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
|
42 |
strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \
|
|
|
132 |
print(f'freezing {k}')
|
133 |
v.requires_grad = False
|
134 |
|
135 |
+
# Image size
|
136 |
+
gs = max(int(model.stride.max()), 32) # grid size (max stride)
|
137 |
+
imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
|
138 |
+
|
139 |
+
# Batch size
|
140 |
+
if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
|
141 |
+
batch_size = check_train_batch_size(model, imgsz)
|
142 |
+
|
143 |
# Optimizer
|
144 |
nbs = 64 # nominal batch size
|
145 |
accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
|
|
|
199 |
|
200 |
del ckpt, csd
|
201 |
|
|
|
|
|
|
|
|
|
|
|
202 |
# DP mode
|
203 |
if cuda and RANK == -1 and torch.cuda.device_count() > 1:
|
204 |
logging.warning('DP not recommended, instead use torch.distributed.run for best DDP Multi-GPU results.\n'
|
|
|
246 |
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
247 |
|
248 |
# Model parameters
|
249 |
+
nl = model.model[-1].nl # number of detection layers (to scale hyps)
|
250 |
hyp['box'] *= 3. / nl # scale to layers
|
251 |
hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
|
252 |
hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
|
|
|
445 |
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
|
446 |
parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch.yaml', help='hyperparameters path')
|
447 |
parser.add_argument('--epochs', type=int, default=300)
|
448 |
+
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
|
449 |
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
|
450 |
parser.add_argument('--rect', action='store_true', help='rectangular training')
|
451 |
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
|
utils/autobatch.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
2 |
+
"""
|
3 |
+
Auto-batch utils
|
4 |
+
"""
|
5 |
+
|
6 |
+
from copy import deepcopy
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from torch.cuda import amp
|
11 |
+
|
12 |
+
from utils.general import colorstr
|
13 |
+
from utils.torch_utils import profile
|
14 |
+
|
15 |
+
|
16 |
+
def check_train_batch_size(model, imgsz=640):
|
17 |
+
# Check YOLOv5 training batch size
|
18 |
+
with amp.autocast():
|
19 |
+
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
|
20 |
+
|
21 |
+
|
22 |
+
def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
|
23 |
+
# Automatically estimate best batch size to use `fraction` of available CUDA memory
|
24 |
+
# Usage:
|
25 |
+
# import torch
|
26 |
+
# from utils.autobatch import autobatch
|
27 |
+
# model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)
|
28 |
+
# print(autobatch(model))
|
29 |
+
|
30 |
+
prefix = colorstr('autobatch: ')
|
31 |
+
print(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
|
32 |
+
device = next(model.parameters()).device # get model device
|
33 |
+
if device.type == 'cpu':
|
34 |
+
print(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
|
35 |
+
return batch_size
|
36 |
+
|
37 |
+
d = str(device).upper() # 'CUDA:0'
|
38 |
+
t = torch.cuda.get_device_properties(device).total_memory / 1024 ** 3 # (GB)
|
39 |
+
r = torch.cuda.memory_reserved(device) / 1024 ** 3 # (GB)
|
40 |
+
a = torch.cuda.memory_allocated(device) / 1024 ** 3 # (GB)
|
41 |
+
f = t - (r + a) # free inside reserved
|
42 |
+
print(f'{prefix}{d} {t:.3g}G total, {r:.3g}G reserved, {a:.3g}G allocated, {f:.3g}G free')
|
43 |
+
|
44 |
+
batch_sizes = [1, 2, 4, 8, 16]
|
45 |
+
try:
|
46 |
+
img = [torch.zeros(b, 3, imgsz, imgsz) for b in batch_sizes]
|
47 |
+
y = profile(img, model, n=3, device=device)
|
48 |
+
except Exception as e:
|
49 |
+
print(f'{prefix}{e}')
|
50 |
+
|
51 |
+
y = [x[2] for x in y if x] # memory [2]
|
52 |
+
batch_sizes = batch_sizes[:len(y)]
|
53 |
+
p = np.polyfit(batch_sizes, y, deg=1) # first degree polynomial fit
|
54 |
+
b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
|
55 |
+
print(f'{prefix}Using colorstr(batch-size {b}) for {d} {t * fraction:.3g}G/{t:.3g}G ({fraction * 100:.0f}%)')
|
56 |
+
return b
|
utils/torch_utils.py
CHANGED
@@ -126,7 +126,7 @@ def profile(input, ops, n=10, device=None):
|
|
126 |
_ = (sum([yi.sum() for yi in y]) if isinstance(y, list) else y).sum().backward()
|
127 |
t[2] = time_sync()
|
128 |
except Exception as e: # no backward method
|
129 |
-
print(e)
|
130 |
t[2] = float('nan')
|
131 |
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
132 |
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
|
|
|
126 |
_ = (sum([yi.sum() for yi in y]) if isinstance(y, list) else y).sum().backward()
|
127 |
t[2] = time_sync()
|
128 |
except Exception as e: # no backward method
|
129 |
+
# print(e) # for debug
|
130 |
t[2] = float('nan')
|
131 |
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
132 |
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
|