Add PyTorch AMP check (#7917)
Browse files
* Add PyTorch AMP check
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Cleanup
* Cleanup
* Cleanup
* Robust for DDP
* Fixes
* Add amp enabled boolean to check_train_batch_size
* Simplify
* space to prefix
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- models/common.py +3 -2
- train.py +8 -8
- utils/autobatch.py +2 -3
- utils/general.py +24 -1
models/common.py
CHANGED
@@ -524,9 +524,10 @@ class AutoShape(nn.Module):
|
|
524 |
max_det = 1000 # maximum number of detections per image
|
525 |
amp = False # Automatic Mixed Precision (AMP) inference
|
526 |
|
527 |
-
def __init__(self, model):
|
528 |
super().__init__()
|
529 |
-
|
|
|
530 |
copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
|
531 |
self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
|
532 |
self.pt = not self.dmb or model.pt # PyTorch model
|
|
|
524 |
max_det = 1000 # maximum number of detections per image
|
525 |
amp = False # Automatic Mixed Precision (AMP) inference
|
526 |
|
527 |
+
def __init__(self, model, verbose=True):
|
528 |
super().__init__()
|
529 |
+
if verbose:
|
530 |
+
LOGGER.info('Adding AutoShape... ')
|
531 |
copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
|
532 |
self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
|
533 |
self.pt = not self.dmb or model.pt # PyTorch model
|
train.py
CHANGED
@@ -27,7 +27,6 @@ import torch
|
|
27 |
import torch.distributed as dist
|
28 |
import torch.nn as nn
|
29 |
import yaml
|
30 |
-
from torch.cuda import amp
|
31 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
32 |
from torch.optim import SGD, Adam, AdamW, lr_scheduler
|
33 |
from tqdm import tqdm
|
@@ -46,10 +45,10 @@ from utils.autobatch import check_train_batch_size
|
|
46 |
from utils.callbacks import Callbacks
|
47 |
from utils.dataloaders import create_dataloader
|
48 |
from utils.downloads import attempt_download
|
49 |
-
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size,
|
50 |
-
check_suffix, check_version, check_yaml, colorstr, get_latest_run,
|
51 |
-
init_seeds, intersect_dicts, labels_to_class_weights,
|
52 |
-
one_cycle, print_args, print_mutation, strip_optimizer)
|
53 |
from utils.loggers import Loggers
|
54 |
from utils.loggers.wandb.wandb_utils import check_wandb_resume
|
55 |
from utils.loss import ComputeLoss
|
@@ -126,6 +125,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|
126 |
LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
|
127 |
else:
|
128 |
model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
|
|
129 |
|
130 |
# Freeze
|
131 |
freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
|
@@ -141,7 +141,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|
141 |
|
142 |
# Batch size
|
143 |
if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
|
144 |
-
batch_size = check_train_batch_size(model, imgsz)
|
145 |
loggers.on_params_update({"batch_size": batch_size})
|
146 |
|
147 |
# Optimizer
|
@@ -293,7 +293,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|
293 |
maps = np.zeros(nc) # mAP per class
|
294 |
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
|
295 |
scheduler.last_epoch = start_epoch - 1 # do not move
|
296 |
-
scaler = amp.GradScaler(enabled=cuda)
|
297 |
stopper = EarlyStopping(patience=opt.patience)
|
298 |
compute_loss = ComputeLoss(model) # init loss class
|
299 |
callbacks.run('on_train_start')
|
@@ -348,7 +348,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|
348 |
imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
|
349 |
|
350 |
# Forward
|
351 |
-
with amp.autocast(enabled=cuda):
|
352 |
pred = model(imgs) # forward
|
353 |
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
|
354 |
if RANK != -1:
|
|
|
27 |
import torch.distributed as dist
|
28 |
import torch.nn as nn
|
29 |
import yaml
|
|
|
30 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
31 |
from torch.optim import SGD, Adam, AdamW, lr_scheduler
|
32 |
from tqdm import tqdm
|
|
|
45 |
from utils.callbacks import Callbacks
|
46 |
from utils.dataloaders import create_dataloader
|
47 |
from utils.downloads import attempt_download
|
48 |
+
from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
|
49 |
+
check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run,
|
50 |
+
increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
|
51 |
+
labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer)
|
52 |
from utils.loggers import Loggers
|
53 |
from utils.loggers.wandb.wandb_utils import check_wandb_resume
|
54 |
from utils.loss import ComputeLoss
|
|
|
125 |
LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
|
126 |
else:
|
127 |
model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
128 |
+
amp = check_amp(model) # check AMP
|
129 |
|
130 |
# Freeze
|
131 |
freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
|
|
|
141 |
|
142 |
# Batch size
|
143 |
if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
|
144 |
+
batch_size = check_train_batch_size(model, imgsz, amp)
|
145 |
loggers.on_params_update({"batch_size": batch_size})
|
146 |
|
147 |
# Optimizer
|
|
|
293 |
maps = np.zeros(nc) # mAP per class
|
294 |
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
|
295 |
scheduler.last_epoch = start_epoch - 1 # do not move
|
296 |
+
scaler = torch.cuda.amp.GradScaler(enabled=amp)
|
297 |
stopper = EarlyStopping(patience=opt.patience)
|
298 |
compute_loss = ComputeLoss(model) # init loss class
|
299 |
callbacks.run('on_train_start')
|
|
|
348 |
imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
|
349 |
|
350 |
# Forward
|
351 |
+
with torch.cuda.amp.autocast(amp):
|
352 |
pred = model(imgs) # forward
|
353 |
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
|
354 |
if RANK != -1:
|
utils/autobatch.py
CHANGED
@@ -7,15 +7,14 @@ from copy import deepcopy
|
|
7 |
|
8 |
import numpy as np
|
9 |
import torch
|
10 |
-
from torch.cuda import amp
|
11 |
|
12 |
from utils.general import LOGGER, colorstr
|
13 |
from utils.torch_utils import profile
|
14 |
|
15 |
|
16 |
-
def check_train_batch_size(model, imgsz=640):
|
17 |
# Check YOLOv5 training batch size
|
18 |
-
with amp.autocast():
|
19 |
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
|
20 |
|
21 |
|
|
|
7 |
|
8 |
import numpy as np
|
9 |
import torch
|
|
|
10 |
|
11 |
from utils.general import LOGGER, colorstr
|
12 |
from utils.torch_utils import profile
|
13 |
|
14 |
|
15 |
+
def check_train_batch_size(model, imgsz=640, amp=True):
|
16 |
# Check YOLOv5 training batch size
|
17 |
+
with torch.cuda.amp.autocast(amp):
|
18 |
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
|
19 |
|
20 |
|
utils/general.py
CHANGED
@@ -36,9 +36,11 @@ import yaml
|
|
36 |
from utils.downloads import gsutil_getsize
|
37 |
from utils.metrics import box_iou, fitness
|
38 |
|
39 |
-
# Settings
|
40 |
FILE = Path(__file__).resolve()
|
41 |
ROOT = FILE.parents[1] # YOLOv5 root directory
|
|
|
|
|
|
|
42 |
DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory
|
43 |
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
|
44 |
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
|
@@ -505,6 +507,27 @@ def check_dataset(data, autodownload=True):
|
|
505 |
return data # dictionary
|
506 |
|
507 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
508 |
def url2file(url):
|
509 |
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
|
510 |
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
|
|
|
36 |
from utils.downloads import gsutil_getsize
|
37 |
from utils.metrics import box_iou, fitness
|
38 |
|
|
|
39 |
FILE = Path(__file__).resolve()
|
40 |
ROOT = FILE.parents[1] # YOLOv5 root directory
|
41 |
+
RANK = int(os.getenv('RANK', -1))
|
42 |
+
|
43 |
+
# Settings
|
44 |
DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory
|
45 |
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
|
46 |
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
|
|
|
507 |
return data # dictionary
|
508 |
|
509 |
|
510 |
+
def check_amp(model):
|
511 |
+
# Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
|
512 |
+
from models.common import AutoShape
|
513 |
+
|
514 |
+
if next(model.parameters()).device.type == 'cpu': # get model device
|
515 |
+
return False
|
516 |
+
prefix = colorstr('AMP: ')
|
517 |
+
im = cv2.imread(ROOT / 'data' / 'images' / 'bus.jpg')[..., ::-1] # OpenCV image (BGR to RGB)
|
518 |
+
m = AutoShape(model, verbose=False) # model
|
519 |
+
a = m(im).xyxy[0] # FP32 inference
|
520 |
+
m.amp = True
|
521 |
+
b = m(im).xyxy[0] # AMP inference
|
522 |
+
if (a.shape == b.shape) and torch.allclose(a, b, atol=1.0): # close to 1.0 pixel bounding box
|
523 |
+
LOGGER.info(emojis(f'{prefix}checks passed ✅'))
|
524 |
+
return True
|
525 |
+
else:
|
526 |
+
help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
|
527 |
+
LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}'))
|
528 |
+
return False
|
529 |
+
|
530 |
+
|
531 |
def url2file(url):
|
532 |
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
|
533 |
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
|