Enable direct `--weights URL` definition (#3373)
Browse files* Enable direct `--weights URL` definition
@KalenMike this PR will enable direct --weights URL definition. Example use case:
```
python train.py --weights https://storage.googleapis.com/bucket/dir/model.pt
```
* cleanup
* bug fixes
* weights = attempt_download(weights)
* Update experimental.py
* Update hubconf.py
* return bug fix
* comment mirror
* min_bytes
- hubconf.py +1 -2
- models/experimental.py +1 -2
- train.py +1 -1
- utils/google_utils.py +33 -20
hubconf.py
CHANGED
@@ -41,8 +41,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
|
|
41 |
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
|
42 |
model = Model(cfg, channels, classes) # create model
|
43 |
if pretrained:
|
44 |
-
attempt_download(fname) #
|
45 |
-
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
|
46 |
msd = model.state_dict() # model state_dict
|
47 |
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
|
48 |
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
|
|
|
41 |
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
|
42 |
model = Model(cfg, channels, classes) # create model
|
43 |
if pretrained:
|
44 |
+
ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu')) # load
|
|
|
45 |
msd = model.state_dict() # model state_dict
|
46 |
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
|
47 |
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
|
models/experimental.py
CHANGED
@@ -116,8 +116,7 @@ def attempt_load(weights, map_location=None, inplace=True):
|
|
116 |
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
117 |
model = Ensemble()
|
118 |
for w in weights if isinstance(weights, list) else [weights]:
|
119 |
-
attempt_download(w)
|
120 |
-
ckpt = torch.load(w, map_location=map_location) # load
|
121 |
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
|
122 |
|
123 |
# Compatibility updates
|
|
|
116 |
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
117 |
model = Ensemble()
|
118 |
for w in weights if isinstance(weights, list) else [weights]:
|
119 |
+
ckpt = torch.load(attempt_download(w), map_location=map_location) # load
|
|
|
120 |
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
|
121 |
|
122 |
# Compatibility updates
|
train.py
CHANGED
@@ -83,7 +83,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|
83 |
pretrained = weights.endswith('.pt')
|
84 |
if pretrained:
|
85 |
with torch_distributed_zero_first(rank):
|
86 |
-
attempt_download(weights) # download if not found locally
|
87 |
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
88 |
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
89 |
exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
|
|
|
83 |
pretrained = weights.endswith('.pt')
|
84 |
if pretrained:
|
85 |
with torch_distributed_zero_first(rank):
|
86 |
+
weights = attempt_download(weights) # download if not found locally
|
87 |
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
88 |
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
89 |
exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
|
utils/google_utils.py
CHANGED
@@ -16,11 +16,37 @@ def gsutil_getsize(url=''):
|
|
16 |
return eval(s.split(' ')[0]) if len(s) else 0 # bytes
|
17 |
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
def attempt_download(file, repo='ultralytics/yolov5'):
|
20 |
# Attempt file download if does not exist
|
21 |
file = Path(str(file).strip().replace("'", ''))
|
22 |
|
23 |
if not file.exists():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
|
25 |
try:
|
26 |
response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api
|
@@ -34,27 +60,14 @@ def attempt_download(file, repo='ultralytics/yolov5'):
|
|
34 |
except:
|
35 |
tag = 'v5.0' # current release
|
36 |
|
37 |
-
name = file.name
|
38 |
if name in assets:
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
except Exception as e: # GCP
|
47 |
-
print(f'Download error: {e}')
|
48 |
-
assert redundant, 'No secondary mirror'
|
49 |
-
url = f'https://storage.googleapis.com/{repo}/ckpt/{name}'
|
50 |
-
print(f'Downloading {url} to {file}...')
|
51 |
-
os.system(f"curl -L '{url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
|
52 |
-
finally:
|
53 |
-
if not file.exists() or file.stat().st_size < 1E6: # check
|
54 |
-
file.unlink(missing_ok=True) # remove partial downloads
|
55 |
-
print(f'ERROR: Download failure: {msg}')
|
56 |
-
print('')
|
57 |
-
return
|
58 |
|
59 |
|
60 |
def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
|
|
|
16 |
return eval(s.split(' ')[0]) if len(s) else 0 # bytes
|
17 |
|
18 |
|
19 |
+
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
|
20 |
+
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
|
21 |
+
file = Path(file)
|
22 |
+
try: # GitHub
|
23 |
+
print(f'Downloading {url} to {file}...')
|
24 |
+
torch.hub.download_url_to_file(url, str(file))
|
25 |
+
assert file.exists() and file.stat().st_size > min_bytes # check
|
26 |
+
except Exception as e: # GCP
|
27 |
+
file.unlink(missing_ok=True) # remove partial downloads
|
28 |
+
print(f'Download error: {e}\nRe-attempting {url2 or url} to {file}...')
|
29 |
+
os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
|
30 |
+
finally:
|
31 |
+
if not file.exists() or file.stat().st_size < min_bytes: # check
|
32 |
+
file.unlink(missing_ok=True) # remove partial downloads
|
33 |
+
print(f'ERROR: Download failure: {error_msg or url}')
|
34 |
+
print('')
|
35 |
+
|
36 |
+
|
37 |
def attempt_download(file, repo='ultralytics/yolov5'):
|
38 |
# Attempt file download if does not exist
|
39 |
file = Path(str(file).strip().replace("'", ''))
|
40 |
|
41 |
if not file.exists():
|
42 |
+
# URL specified
|
43 |
+
name = file.name
|
44 |
+
if str(file).startswith(('http:/', 'https:/')): # download
|
45 |
+
url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
|
46 |
+
safe_download(file=name, url=url, min_bytes=1E5)
|
47 |
+
return name
|
48 |
+
|
49 |
+
# GitHub assets
|
50 |
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
|
51 |
try:
|
52 |
response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api
|
|
|
60 |
except:
|
61 |
tag = 'v5.0' # current release
|
62 |
|
|
|
63 |
if name in assets:
|
64 |
+
safe_download(file,
|
65 |
+
url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
|
66 |
+
# url2=f'https://storage.googleapis.com/{repo}/ckpt/{name}', # backup url (optional)
|
67 |
+
min_bytes=1E5,
|
68 |
+
error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/')
|
69 |
+
|
70 |
+
return str(file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
|
73 |
def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
|