AMP check improvements backup YOLOv5n pretrained (#7959)
Browse files* Reduce AMP check to detections verification
More robust and faster
* Update general.py
* Update general.py
- utils/general.py +17 -17
utils/general.py
CHANGED
@@ -506,27 +506,27 @@ def check_dataset(data, autodownload=True):
|
|
506 |
|
507 |
def check_amp(model):
|
508 |
# Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
|
509 |
-
from models.common import AutoShape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
510 |
|
511 |
-
if next(model.parameters()).device.type == 'cpu': # get model device
|
512 |
-
return False
|
513 |
prefix = colorstr('AMP: ')
|
514 |
-
|
515 |
-
if
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
return True
|
522 |
-
m = AutoShape(model, verbose=False) # model
|
523 |
-
a = m(im).xywhn[0] # FP32 inference
|
524 |
-
m.amp = True
|
525 |
-
b = m(im).xywhn[0] # AMP inference
|
526 |
-
if (a.shape == b.shape) and torch.allclose(a, b, atol=0.05): # close to 5% absolute tolerance
|
527 |
LOGGER.info(emojis(f'{prefix}checks passed ✅'))
|
528 |
return True
|
529 |
-
|
530 |
help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
|
531 |
LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}'))
|
532 |
return False
|
|
|
506 |
|
507 |
def check_amp(model):
|
508 |
# Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
|
509 |
+
from models.common import AutoShape, DetectMultiBackend
|
510 |
+
|
511 |
+
def amp_allclose(model, im):
|
512 |
+
# All close FP32 vs AMP results
|
513 |
+
m = AutoShape(model, verbose=False) # model
|
514 |
+
a = m(im).xywhn[0] # FP32 inference
|
515 |
+
m.amp = True
|
516 |
+
b = m(im).xywhn[0] # AMP inference
|
517 |
+
return a.shape == b.shape and torch.allclose(a, b, atol=0.1) # close to 10% absolute tolerance
|
518 |
|
|
|
|
|
519 |
prefix = colorstr('AMP: ')
|
520 |
+
device = next(model.parameters()).device # get model device
|
521 |
+
if device.type == 'cpu':
|
522 |
+
return False # AMP disabled on CPU
|
523 |
+
f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
|
524 |
+
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
|
525 |
+
try:
|
526 |
+
assert amp_allclose(model, im) or amp_allclose(DetectMultiBackend('yolov5n.pt', device), im)
|
|
|
|
|
|
|
|
|
|
|
|
|
527 |
LOGGER.info(emojis(f'{prefix}checks passed ✅'))
|
528 |
return True
|
529 |
+
except Exception:
|
530 |
help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
|
531 |
LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}'))
|
532 |
return False
|