bilzard glenn-jocher committed on
Commit
aff0281
·
unverified ·
1 Parent(s): 8fcdf3b

Load checkpoint on CPU instead of on GPU (#6516)

Browse files

* Load checkpoint on CPU instead of on GPU

* refactor: simplify code

* Cleanup

* Update train.py

Co-authored-by: Glenn Jocher <[email protected]>

Files changed (1) hide show
  1. train.py +1 -1
train.py CHANGED
@@ -120,7 +120,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
120
  if pretrained:
121
  with torch_distributed_zero_first(LOCAL_RANK):
122
  weights = attempt_download(weights) # download if not found locally
123
- ckpt = torch.load(weights, map_location=device) # load checkpoint
124
  model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
125
  exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
126
  csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
 
120
  if pretrained:
121
  with torch_distributed_zero_first(LOCAL_RANK):
122
  weights = attempt_download(weights) # download if not found locally
123
+ ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
124
  model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
125
  exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
126
  csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32