Link fuse() to AutoShape() for Hub models (#8599)
Browse files- hubconf.py +1 -2
- models/common.py +2 -2
hubconf.py
CHANGED
@@ -36,7 +36,6 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
|
|
36 |
|
37 |
if not verbose:
|
38 |
LOGGER.setLevel(logging.WARNING)
|
39 |
-
|
40 |
check_requirements(exclude=('tensorboard', 'thop', 'opencv-python'))
|
41 |
name = Path(name)
|
42 |
path = name.with_suffix('.pt') if name.suffix == '' and not name.is_dir() else name # checkpoint path
|
@@ -44,7 +43,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
|
|
44 |
device = select_device(device)
|
45 |
|
46 |
if pretrained and channels == 3 and classes == 80:
|
47 |
-
model = DetectMultiBackend(path, device=device) # download/load FP32 model
|
48 |
# model = models.experimental.attempt_load(path, map_location=device) # download/load FP32 model
|
49 |
else:
|
50 |
cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0] # model.yaml path
|
|
|
36 |
|
37 |
if not verbose:
|
38 |
LOGGER.setLevel(logging.WARNING)
|
|
|
39 |
check_requirements(exclude=('tensorboard', 'thop', 'opencv-python'))
|
40 |
name = Path(name)
|
41 |
path = name.with_suffix('.pt') if name.suffix == '' and not name.is_dir() else name # checkpoint path
|
|
|
43 |
device = select_device(device)
|
44 |
|
45 |
if pretrained and channels == 3 and classes == 80:
|
46 |
+
model = DetectMultiBackend(path, device=device, fuse=autoshape) # download/load FP32 model
|
47 |
# model = models.experimental.attempt_load(path, map_location=device) # download/load FP32 model
|
48 |
else:
|
49 |
cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0] # model.yaml path
|
models/common.py
CHANGED
@@ -305,7 +305,7 @@ class Concat(nn.Module):
|
|
305 |
|
306 |
class DetectMultiBackend(nn.Module):
|
307 |
# YOLOv5 MultiBackend class for python inference on various backends
|
308 |
-
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False):
|
309 |
# Usage:
|
310 |
# PyTorch: weights = *.pt
|
311 |
# TorchScript: *.torchscript
|
@@ -331,7 +331,7 @@ class DetectMultiBackend(nn.Module):
|
|
331 |
names = yaml.safe_load(f)['names']
|
332 |
|
333 |
if pt: # PyTorch
|
334 |
-
model = attempt_load(weights if isinstance(weights, list) else w, device=device)
|
335 |
stride = max(int(model.stride.max()), 32) # model stride
|
336 |
names = model.module.names if hasattr(model, 'module') else model.names # get class names
|
337 |
model.half() if fp16 else model.float()
|
|
|
305 |
|
306 |
class DetectMultiBackend(nn.Module):
|
307 |
# YOLOv5 MultiBackend class for python inference on various backends
|
308 |
+
def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
|
309 |
# Usage:
|
310 |
# PyTorch: weights = *.pt
|
311 |
# TorchScript: *.torchscript
|
|
|
331 |
names = yaml.safe_load(f)['names']
|
332 |
|
333 |
if pt: # PyTorch
|
334 |
+
model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
|
335 |
stride = max(int(model.stride.max()), 32) # model stride
|
336 |
names = model.module.names if hasattr(model, 'module') else model.names # get class names
|
337 |
model.half() if fp16 else model.float()
|