Allow `--weights URL` (#5991)
Browse files- models/common.py +2 -2
- utils/downloads.py +6 -3
models/common.py
CHANGED
@@ -296,7 +296,7 @@ class DetectMultiBackend(nn.Module):
|
|
296 |
check_suffix(w, suffixes) # check weights have acceptable suffix
|
297 |
pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
|
298 |
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
|
299 |
-
attempt_download(w) # download if not local
|
300 |
|
301 |
if jit: # TorchScript
|
302 |
LOGGER.info(f'Loading {w} for TorchScript inference...')
|
@@ -306,7 +306,7 @@ class DetectMultiBackend(nn.Module):
|
|
306 |
d = json.loads(extra_files['config.txt']) # extra_files dict
|
307 |
stride, names = int(d['stride']), d['names']
|
308 |
elif pt: # PyTorch
|
309 |
-
model = attempt_load(weights, map_location=device)
|
310 |
stride = int(model.stride.max()) # model stride
|
311 |
names = model.module.names if hasattr(model, 'module') else model.names # get class names
|
312 |
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
|
|
|
296 |
check_suffix(w, suffixes) # check weights have acceptable suffix
|
297 |
pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
|
298 |
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
|
299 |
+
w = attempt_download(w) # download if not local
|
300 |
|
301 |
if jit: # TorchScript
|
302 |
LOGGER.info(f'Loading {w} for TorchScript inference...')
|
|
|
306 |
d = json.loads(extra_files['config.txt']) # extra_files dict
|
307 |
stride, names = int(d['stride']), d['names']
|
308 |
elif pt: # PyTorch
|
309 |
+
model = attempt_load(weights if isinstance(weights, list) else w, map_location=device)
|
310 |
stride = int(model.stride.max()) # model stride
|
311 |
names = model.module.names if hasattr(model, 'module') else model.names # get class names
|
312 |
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
|
utils/downloads.py
CHANGED
@@ -49,9 +49,12 @@ def attempt_download(file, repo='ultralytics/yolov5'): # from utils.downloads i
|
|
49 |
name = Path(urllib.parse.unquote(str(file))).name # decode '%2F' to '/' etc.
|
50 |
if str(file).startswith(('http:/', 'https:/')): # download
|
51 |
url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
55 |
|
56 |
# GitHub assets
|
57 |
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
|
|
|
49 |
name = Path(urllib.parse.unquote(str(file))).name # decode '%2F' to '/' etc.
|
50 |
if str(file).startswith(('http:/', 'https:/')): # download
|
51 |
url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
|
52 |
+
file = name.split('?')[0] # parse authentication https://url.com/file.txt?auth...
|
53 |
+
if Path(file).is_file():
|
54 |
+
print(f'Found {url} locally at {file}') # file already exists
|
55 |
+
else:
|
56 |
+
safe_download(file=file, url=url, min_bytes=1E5)
|
57 |
+
return file
|
58 |
|
59 |
# GitHub assets
|
60 |
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
|