glenn-jocher and pre-commit-ci[bot] committed
Commit eb1217f · unverified · 1 Parent(s): 547c89b

Add PyTorch AMP check (#7917)


* Add PyTorch AMP check

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleanup

* Cleanup

* Cleanup

* Robust for DDP

* Fixes

* Add amp enabled boolean to check_train_batch_size

* Simplify

* space to prefix

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Files changed (4)
  1. models/common.py +3 -2
  2. train.py +8 -8
  3. utils/autobatch.py +2 -3
  4. utils/general.py +24 -1
models/common.py CHANGED
@@ -524,9 +524,10 @@ class AutoShape(nn.Module):
     max_det = 1000  # maximum number of detections per image
     amp = False  # Automatic Mixed Precision (AMP) inference
 
-    def __init__(self, model):
+    def __init__(self, model, verbose=True):
         super().__init__()
-        LOGGER.info('Adding AutoShape... ')
+        if verbose:
+            LOGGER.info('Adding AutoShape... ')
         copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=())  # copy attributes
         self.dmb = isinstance(model, DetectMultiBackend)  # DetectMultiBackend() instance
         self.pt = not self.dmb or model.pt  # PyTorch model
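
The new verbose flag lets the check_amp helper (added in utils/general.py below) wrap a model silently instead of logging 'Adding AutoShape... ' every time the check runs. A minimal sketch of the changed constructor in use, assuming model is an already-loaded YOLOv5 detection model and the repo is on sys.path:

    from models.common import AutoShape

    m = AutoShape(model, verbose=False)   # wrap without the 'Adding AutoShape...' log line
    m.amp = True                          # opt in to AMP inference (only used on CUDA devices)
    results = m('data/images/bus.jpg')    # AutoShape handles pre/post-processing and NMS
    results.print()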
train.py CHANGED
@@ -27,7 +27,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import yaml
-from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD, Adam, AdamW, lr_scheduler
 from tqdm import tqdm
@@ -46,10 +45,10 @@ from utils.autobatch import check_train_batch_size
 from utils.callbacks import Callbacks
 from utils.dataloaders import create_dataloader
 from utils.downloads import attempt_download
-from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
-                           check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
-                           init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
-                           one_cycle, print_args, print_mutation, strip_optimizer)
+from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
+                           check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run,
+                           increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
+                           labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer)
 from utils.loggers import Loggers
 from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.loss import ComputeLoss
@@ -126,6 +125,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
         LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  # report
     else:
         model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+    amp = check_amp(model)  # check AMP
 
     # Freeze
     freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
@@ -141,7 +141,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
 
     # Batch size
     if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
-        batch_size = check_train_batch_size(model, imgsz)
+        batch_size = check_train_batch_size(model, imgsz, amp)
         loggers.on_params_update({"batch_size": batch_size})
 
     # Optimizer
@@ -293,7 +293,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
-    scaler = amp.GradScaler(enabled=cuda)
+    scaler = torch.cuda.amp.GradScaler(enabled=amp)
     stopper = EarlyStopping(patience=opt.patience)
     compute_loss = ComputeLoss(model)  # init loss class
     callbacks.run('on_train_start')
@@ -348,7 +348,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
                     imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
 
             # Forward
-            with amp.autocast(enabled=cuda):
+            with torch.cuda.amp.autocast(amp):
                 pred = model(imgs)  # forward
                 loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                 if RANK != -1:
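
Net effect in the training loop: both the GradScaler and the autocast context are now gated on the result of check_amp(model) rather than on bare CUDA availability, so a GPU whose AMP output is wrong silently falls back to full precision. A minimal sketch of that pattern (dataloader, optimizer, model and compute_loss are stand-ins, not the exact train.py objects):

    import torch

    amp = check_amp(model)                          # False if the AMP sanity check fails
    scaler = torch.cuda.amp.GradScaler(enabled=amp)

    for imgs, targets in dataloader:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(amp):          # mixed-precision forward only when amp=True
            loss, loss_items = compute_loss(model(imgs), targets)
        scaler.scale(loss).backward()               # scale loss to avoid FP16 gradient underflow
        scaler.step(optimizer)                      # unscale grads, skip the step if they are not finite
        scaler.update()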
utils/autobatch.py CHANGED
@@ -7,15 +7,14 @@ from copy import deepcopy
 
 import numpy as np
 import torch
-from torch.cuda import amp
 
 from utils.general import LOGGER, colorstr
 from utils.torch_utils import profile
 
 
-def check_train_batch_size(model, imgsz=640):
+def check_train_batch_size(model, imgsz=640, amp=True):
     # Check YOLOv5 training batch size
-    with amp.autocast():
+    with torch.cuda.amp.autocast(amp):
         return autobatch(deepcopy(model).train(), imgsz)  # compute optimal batch size
 
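check_train_batch_size now receives the AMP decision as an argument, so the batch-size probe runs under the same precision the training loop will actually use. A usage sketch, assuming model is a YOLOv5 model already on a CUDA device:

    from utils.autobatch import check_train_batch_size
    from utils.general import check_amp

    amp = check_amp(model)                                           # AMP sanity check first
    batch_size = check_train_batch_size(model, imgsz=640, amp=amp)   # estimated optimal batch size
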
utils/general.py CHANGED
@@ -36,9 +36,11 @@ import yaml
 from utils.downloads import gsutil_getsize
 from utils.metrics import box_iou, fitness
 
-# Settings
 FILE = Path(__file__).resolve()
 ROOT = FILE.parents[1]  # YOLOv5 root directory
+RANK = int(os.getenv('RANK', -1))
+
+# Settings
 DATASETS_DIR = ROOT.parent / 'datasets'  # YOLOv5 datasets directory
 NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads
 AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
@@ -505,6 +507,27 @@ def check_dataset(data, autodownload=True):
     return data  # dictionary
 
 
+def check_amp(model):
+    # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
+    from models.common import AutoShape
+
+    if next(model.parameters()).device.type == 'cpu':  # get model device
+        return False
+    prefix = colorstr('AMP: ')
+    im = cv2.imread(ROOT / 'data' / 'images' / 'bus.jpg')[..., ::-1]  # OpenCV image (BGR to RGB)
+    m = AutoShape(model, verbose=False)  # model
+    a = m(im).xyxy[0]  # FP32 inference
+    m.amp = True
+    b = m(im).xyxy[0]  # AMP inference
+    if (a.shape == b.shape) and torch.allclose(a, b, atol=1.0):  # close to 1.0 pixel bounding box
+        LOGGER.info(emojis(f'{prefix}checks passed ✅'))
+        return True
+    else:
+        help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
+        LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}'))
+        return False
+
+
 def url2file(url):
     # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
     url = str(Path(url)).replace(':/', '://')  # Pathlib turns :// -> :/
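
The pass/fail criterion in check_amp is a single torch.allclose over the two [N, 6] xyxy detection tensors (x1, y1, x2, y2, confidence, class) with atol=1.0, provided both runs also return the same number of boxes: FP32 and AMP results must agree to within roughly one pixel per coordinate. A toy illustration of that tolerance with made-up values:

    import torch

    a = torch.tensor([[50.0, 100.0, 200.0, 300.0, 0.90, 0.0]])  # FP32 detections
    b = torch.tensor([[50.4,  99.7, 200.6, 299.5, 0.89, 0.0]])  # AMP detections, sub-pixel drift
    print(torch.allclose(a, b, atol=1.0))      # True  -> AMP considered safe, check_amp returns True
    print(torch.allclose(a, b * 2, atol=1.0))  # False -> AMP disabled for training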