glenn-jocher commited on
Commit
da9a1b7
·
unverified ·
1 Parent(s): b7d18f3

Allow `--weights URL` (#5991)

Browse files
Files changed (2) hide show
  1. models/common.py +2 -2
  2. utils/downloads.py +6 -3
models/common.py CHANGED
@@ -296,7 +296,7 @@ class DetectMultiBackend(nn.Module):
296
  check_suffix(w, suffixes) # check weights have acceptable suffix
297
  pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
298
  stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
299
- attempt_download(w) # download if not local
300
 
301
  if jit: # TorchScript
302
  LOGGER.info(f'Loading {w} for TorchScript inference...')
@@ -306,7 +306,7 @@ class DetectMultiBackend(nn.Module):
306
  d = json.loads(extra_files['config.txt']) # extra_files dict
307
  stride, names = int(d['stride']), d['names']
308
  elif pt: # PyTorch
309
- model = attempt_load(weights, map_location=device)
310
  stride = int(model.stride.max()) # model stride
311
  names = model.module.names if hasattr(model, 'module') else model.names # get class names
312
  self.model = model # explicitly assign for to(), cpu(), cuda(), half()
 
296
  check_suffix(w, suffixes) # check weights have acceptable suffix
297
  pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
298
  stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
299
+ w = attempt_download(w) # download if not local
300
 
301
  if jit: # TorchScript
302
  LOGGER.info(f'Loading {w} for TorchScript inference...')
 
306
  d = json.loads(extra_files['config.txt']) # extra_files dict
307
  stride, names = int(d['stride']), d['names']
308
  elif pt: # PyTorch
309
+ model = attempt_load(weights if isinstance(weights, list) else w, map_location=device)
310
  stride = int(model.stride.max()) # model stride
311
  names = model.module.names if hasattr(model, 'module') else model.names # get class names
312
  self.model = model # explicitly assign for to(), cpu(), cuda(), half()
utils/downloads.py CHANGED
@@ -49,9 +49,12 @@ def attempt_download(file, repo='ultralytics/yolov5'): # from utils.downloads i
49
  name = Path(urllib.parse.unquote(str(file))).name # decode '%2F' to '/' etc.
50
  if str(file).startswith(('http:/', 'https:/')): # download
51
  url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
52
- name = name.split('?')[0] # parse authentication https://url.com/file.txt?auth...
53
- safe_download(file=name, url=url, min_bytes=1E5)
54
- return name
 
 
 
55
 
56
  # GitHub assets
57
  file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
 
49
  name = Path(urllib.parse.unquote(str(file))).name # decode '%2F' to '/' etc.
50
  if str(file).startswith(('http:/', 'https:/')): # download
51
  url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
52
+ file = name.split('?')[0] # parse authentication https://url.com/file.txt?auth...
53
+ if Path(file).is_file():
54
+ print(f'Found {url} locally at {file}') # file already exists
55
+ else:
56
+ safe_download(file=file, url=url, min_bytes=1E5)
57
+ return file
58
 
59
  # GitHub assets
60
  file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)