developer0hye pre-commit-ci[bot] glenn-jocher committed on
Commit
7204c1c
·
unverified ·
1 Parent(s): 72a81e7

Explicitly set `weight_decay` value (#8592)

Browse files

* explicitly set weight_decay value

The default weight_decay value of AdamW is 1e-2, so we should explicitly set it to zero for the bias parameter group to avoid unintended weight decay.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleanup

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>

Files changed (1) hide show
  1. train.py +2 -2
train.py CHANGED
@@ -163,12 +163,12 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
163
  if opt.optimizer == 'Adam':
164
  optimizer = Adam(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
165
  elif opt.optimizer == 'AdamW':
166
- optimizer = AdamW(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
167
  else:
168
  optimizer = SGD(g[2], lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
169
 
170
  optimizer.add_param_group({'params': g[0], 'weight_decay': hyp['weight_decay']}) # add g0 with weight_decay
171
- optimizer.add_param_group({'params': g[1]}) # add g1 (BatchNorm2d weights)
172
  LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
173
  f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
174
  del g
 
163
  if opt.optimizer == 'Adam':
164
  optimizer = Adam(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
165
  elif opt.optimizer == 'AdamW':
166
+ optimizer = AdamW(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999), weight_decay=0.0)
167
  else:
168
  optimizer = SGD(g[2], lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
169
 
170
  optimizer.add_param_group({'params': g[0], 'weight_decay': hyp['weight_decay']}) # add g0 with weight_decay
171
+ optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
172
  LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
173
  f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
174
  del g