Explicitly set `weight_decay` value (#8592)
Browse files* explicitly set weight_decay value
The default weight_decay value of AdamW is 1e-2, so we should set it to zero.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Cleanup
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>
train.py
CHANGED
@@ -163,12 +163,12 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
|
163 |
if opt.optimizer == 'Adam':
|
164 |
optimizer = Adam(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
|
165 |
elif opt.optimizer == 'AdamW':
|
166 |
-
optimizer = AdamW(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))
|
167 |
else:
|
168 |
optimizer = SGD(g[2], lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
|
169 |
|
170 |
optimizer.add_param_group({'params': g[0], 'weight_decay': hyp['weight_decay']}) # add g0 with weight_decay
|
171 |
-
optimizer.add_param_group({'params': g[1]}) # add g1 (BatchNorm2d weights)
|
172 |
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
|
173 |
f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
|
174 |
del g
|
|
|
163 |
if opt.optimizer == 'Adam':
|
164 |
optimizer = Adam(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
|
165 |
elif opt.optimizer == 'AdamW':
|
166 |
+
optimizer = AdamW(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999), weight_decay=0.0)
|
167 |
else:
|
168 |
optimizer = SGD(g[2], lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
|
169 |
|
170 |
optimizer.add_param_group({'params': g[0], 'weight_decay': hyp['weight_decay']}) # add g0 with weight_decay
|
171 |
+
optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
|
172 |
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
|
173 |
f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
|
174 |
del g
|