glenn-jocher commited on
Commit
a34b376
·
unverified ·
1 Parent(s): 6e86af3

Link fuse() to AutoShape() for Hub models (#8599)

Browse files
Files changed (2) hide show
  1. hubconf.py +1 -2
  2. models/common.py +2 -2
hubconf.py CHANGED
@@ -36,7 +36,6 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
36
 
37
  if not verbose:
38
  LOGGER.setLevel(logging.WARNING)
39
-
40
  check_requirements(exclude=('tensorboard', 'thop', 'opencv-python'))
41
  name = Path(name)
42
  path = name.with_suffix('.pt') if name.suffix == '' and not name.is_dir() else name # checkpoint path
@@ -44,7 +43,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
44
  device = select_device(device)
45
 
46
  if pretrained and channels == 3 and classes == 80:
47
- model = DetectMultiBackend(path, device=device) # download/load FP32 model
48
  # model = models.experimental.attempt_load(path, map_location=device) # download/load FP32 model
49
  else:
50
  cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0] # model.yaml path
 
36
 
37
  if not verbose:
38
  LOGGER.setLevel(logging.WARNING)
 
39
  check_requirements(exclude=('tensorboard', 'thop', 'opencv-python'))
40
  name = Path(name)
41
  path = name.with_suffix('.pt') if name.suffix == '' and not name.is_dir() else name # checkpoint path
 
43
  device = select_device(device)
44
 
45
  if pretrained and channels == 3 and classes == 80:
46
+ model = DetectMultiBackend(path, device=device, fuse=autoshape) # download/load FP32 model
47
  # model = models.experimental.attempt_load(path, map_location=device) # download/load FP32 model
48
  else:
49
  cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0] # model.yaml path
models/common.py CHANGED
@@ -305,7 +305,7 @@ class Concat(nn.Module):
305
 
306
  class DetectMultiBackend(nn.Module):
307
  # YOLOv5 MultiBackend class for python inference on various backends
308
- def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False):
309
  # Usage:
310
  # PyTorch: weights = *.pt
311
  # TorchScript: *.torchscript
@@ -331,7 +331,7 @@ class DetectMultiBackend(nn.Module):
331
  names = yaml.safe_load(f)['names']
332
 
333
  if pt: # PyTorch
334
- model = attempt_load(weights if isinstance(weights, list) else w, device=device)
335
  stride = max(int(model.stride.max()), 32) # model stride
336
  names = model.module.names if hasattr(model, 'module') else model.names # get class names
337
  model.half() if fp16 else model.float()
 
305
 
306
  class DetectMultiBackend(nn.Module):
307
  # YOLOv5 MultiBackend class for python inference on various backends
308
+ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
309
  # Usage:
310
  # PyTorch: weights = *.pt
311
  # TorchScript: *.torchscript
 
331
  names = yaml.safe_load(f)['names']
332
 
333
  if pt: # PyTorch
334
+ model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
335
  stride = max(int(model.stride.max()), 32) # model stride
336
  names = model.module.names if hasattr(model, 'module') else model.names # get class names
337
  model.half() if fp16 else model.float()