Load checkpoint on CPU instead of on GPU (#6516)
* Load checkpoint on CPU instead of on GPU
* refactor: simplify code
* Cleanup
* Update train.py
Co-authored-by: Glenn Jocher <[email protected]>
train.py CHANGED
@@ -120,7 +120,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     if pretrained:
         with torch_distributed_zero_first(LOCAL_RANK):
             weights = attempt_download(weights)  # download if not found locally
-        ckpt = torch.load(weights, map_location=device)  # load checkpoint
+        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
         model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
         exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
         csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
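The rationale behind the one-line change: `torch.load(..., map_location=device)` deserializes every checkpoint tensor directly onto the GPU, and that copy can linger in CUDA memory even after the weights have been transferred into the freshly created model. Mapping to CPU first means only the final model occupies the device. Below is a minimal sketch of the same load-on-CPU pattern outside of train.py; the helper name `load_pretrained` and the fallback handling are illustrative and not part of the repository.

```python
import torch

def load_pretrained(model, weights_path, device):
    # Deserialize the checkpoint onto CPU so no extra copy of the weights
    # is allocated on the GPU; only the model itself ends up on the device.
    ckpt = torch.load(weights_path, map_location='cpu')

    # YOLOv5-style checkpoints store the model under the 'model' key
    # (see csd = ckpt['model'].float().state_dict() in the diff above);
    # fall back to treating the file as a plain state_dict otherwise.
    state_dict = ckpt['model'].float().state_dict() if 'model' in ckpt else ckpt

    # Keep only keys whose shapes match the current model, then load non-strictly.
    model_sd = model.state_dict()
    state_dict = {k: v for k, v in state_dict.items()
                  if k in model_sd and v.shape == model_sd[k].shape}
    model.load_state_dict(state_dict, strict=False)
    return model.to(device)
```

Only after the intersected state dict is loaded is the model moved to `device`, which mirrors the new behavior of train.py: the GPU never holds both the checkpoint tensors and the model weights at the same time.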