glenn-jocher committed on
Commit
ba6f3f9
·
unverified ·
1 Parent(s): b78e30d

Enable direct `--weights URL` definition (#3373)

Browse files

* Enable direct `--weights URL` definition

@KalenMike this PR will enable direct --weights URL definition. Example use case:
```
python train.py --weights https://storage.googleapis.com/bucket/dir/model.pt
```

* cleanup

* bug fixes

* weights = attempt_download(weights)

* Update experimental.py

* Update hubconf.py

* return bug fix

* comment mirror

* min_bytes

Files changed (4) hide show
  1. hubconf.py +1 -2
  2. models/experimental.py +1 -2
  3. train.py +1 -1
  4. utils/google_utils.py +33 -20
hubconf.py CHANGED
@@ -41,8 +41,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
41
  cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
42
  model = Model(cfg, channels, classes) # create model
43
  if pretrained:
44
- attempt_download(fname) # download if not found locally
45
- ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
46
  msd = model.state_dict() # model state_dict
47
  csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
48
  csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
 
41
  cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
42
  model = Model(cfg, channels, classes) # create model
43
  if pretrained:
44
+ ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu')) # load
 
45
  msd = model.state_dict() # model state_dict
46
  csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
47
  csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
models/experimental.py CHANGED
@@ -116,8 +116,7 @@ def attempt_load(weights, map_location=None, inplace=True):
116
  # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
117
  model = Ensemble()
118
  for w in weights if isinstance(weights, list) else [weights]:
119
- attempt_download(w)
120
- ckpt = torch.load(w, map_location=map_location) # load
121
  model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
122
 
123
  # Compatibility updates
 
116
  # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
117
  model = Ensemble()
118
  for w in weights if isinstance(weights, list) else [weights]:
119
+ ckpt = torch.load(attempt_download(w), map_location=map_location) # load
 
120
  model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
121
 
122
  # Compatibility updates
train.py CHANGED
@@ -83,7 +83,7 @@ def train(hyp, opt, device, tb_writer=None):
83
  pretrained = weights.endswith('.pt')
84
  if pretrained:
85
  with torch_distributed_zero_first(rank):
86
- attempt_download(weights) # download if not found locally
87
  ckpt = torch.load(weights, map_location=device) # load checkpoint
88
  model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
89
  exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
 
83
  pretrained = weights.endswith('.pt')
84
  if pretrained:
85
  with torch_distributed_zero_first(rank):
86
+ weights = attempt_download(weights) # download if not found locally
87
  ckpt = torch.load(weights, map_location=device) # load checkpoint
88
  model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
89
  exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
utils/google_utils.py CHANGED
@@ -16,11 +16,37 @@ def gsutil_getsize(url=''):
16
  return eval(s.split(' ')[0]) if len(s) else 0 # bytes
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def attempt_download(file, repo='ultralytics/yolov5'):
20
  # Attempt file download if does not exist
21
  file = Path(str(file).strip().replace("'", ''))
22
 
23
  if not file.exists():
 
 
 
 
 
 
 
 
24
  file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
25
  try:
26
  response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api
@@ -34,27 +60,14 @@ def attempt_download(file, repo='ultralytics/yolov5'):
34
  except:
35
  tag = 'v5.0' # current release
36
 
37
- name = file.name
38
  if name in assets:
39
- msg = f'{file} missing, try downloading from https://github.com/{repo}/releases/'
40
- redundant = False # second download option
41
- try: # GitHub
42
- url = f'https://github.com/{repo}/releases/download/{tag}/{name}'
43
- print(f'Downloading {url} to {file}...')
44
- torch.hub.download_url_to_file(url, file)
45
- assert file.exists() and file.stat().st_size > 1E6 # check
46
- except Exception as e: # GCP
47
- print(f'Download error: {e}')
48
- assert redundant, 'No secondary mirror'
49
- url = f'https://storage.googleapis.com/{repo}/ckpt/{name}'
50
- print(f'Downloading {url} to {file}...')
51
- os.system(f"curl -L '{url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
52
- finally:
53
- if not file.exists() or file.stat().st_size < 1E6: # check
54
- file.unlink(missing_ok=True) # remove partial downloads
55
- print(f'ERROR: Download failure: {msg}')
56
- print('')
57
- return
58
 
59
 
60
  def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
 
16
  return eval(s.split(' ')[0]) if len(s) else 0 # bytes
17
 
18
 
19
+ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
20
+ # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
21
+ file = Path(file)
22
+ try: # GitHub
23
+ print(f'Downloading {url} to {file}...')
24
+ torch.hub.download_url_to_file(url, str(file))
25
+ assert file.exists() and file.stat().st_size > min_bytes # check
26
+ except Exception as e: # GCP
27
+ file.unlink(missing_ok=True) # remove partial downloads
28
+ print(f'Download error: {e}\nRe-attempting {url2 or url} to {file}...')
29
+ os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
30
+ finally:
31
+ if not file.exists() or file.stat().st_size < min_bytes: # check
32
+ file.unlink(missing_ok=True) # remove partial downloads
33
+ print(f'ERROR: Download failure: {error_msg or url}')
34
+ print('')
35
+
36
+
37
  def attempt_download(file, repo='ultralytics/yolov5'):
38
  # Attempt file download if does not exist
39
  file = Path(str(file).strip().replace("'", ''))
40
 
41
  if not file.exists():
42
+ # URL specified
43
+ name = file.name
44
+ if str(file).startswith(('http:/', 'https:/')): # download
45
+ url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
46
+ safe_download(file=name, url=url, min_bytes=1E5)
47
+ return name
48
+
49
+ # GitHub assets
50
  file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
51
  try:
52
  response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api
 
60
  except:
61
  tag = 'v5.0' # current release
62
 
 
63
  if name in assets:
64
+ safe_download(file,
65
+ url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
66
+ # url2=f'https://storage.googleapis.com/{repo}/ckpt/{name}', # backup url (optional)
67
+ min_bytes=1E5,
68
+ error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/')
69
+
70
+ return str(file)
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
 
73
  def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):