curt-park's picture
Refactor code
1615d09
raw
history blame contribute delete
800 Bytes
import math
import torch
from isegm.utils.log import logger
def get_optimizer(model, opt_name, opt_kwargs):
    """Build a torch optimizer with one parameter group per model parameter.

    Every parameter yielded by ``model.named_parameters()`` gets its own param
    group, in that order. A parameter that requires grad and carries an
    ``lr_mult`` attribute not close to 1.0 gets a group-specific learning rate
    of ``opt_kwargs["lr"] * lr_mult``; every other group inherits the
    optimizer-level settings passed through ``opt_kwargs``.

    Args:
        model: module whose ``named_parameters()`` are handed to the optimizer.
        opt_name: case-insensitive optimizer name: "sgd", "adam" or "adamw".
        opt_kwargs: keyword arguments forwarded to the optimizer constructor.
            ``"lr"`` is read here only when some parameter has an ``lr_mult``
            to apply; otherwise it is up to the optimizer whether it is needed.

    Returns:
        The constructed ``torch.optim`` optimizer instance.

    Raises:
        KeyError: if ``opt_name`` is not one of the supported optimizers, or if
            an ``lr_mult`` must be applied but ``opt_kwargs`` has no ``"lr"``.
    """
    optimizers = {
        "sgd": torch.optim.SGD,
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
    }
    key = opt_name.lower()
    # Resolve the optimizer first so a bad name fails fast, before any logging.
    if key not in optimizers:
        raise KeyError(
            f"Unsupported optimizer {opt_name!r}; expected one of {sorted(optimizers)}"
        )
    optimizer_cls = optimizers[key]

    params = []
    for name, param in model.named_parameters():
        param_group = {"params": [param]}
        lr_mult = getattr(param, "lr_mult", 1.0)
        # Frozen parameters are still registered, but never get a scaled lr.
        if param.requires_grad and not math.isclose(lr_mult, 1.0):
            logger.info(f'Applied lr_mult={lr_mult} to "{name}" parameter.')
            # "lr" is looked up only here, where a base rate is actually needed.
            param_group["lr"] = opt_kwargs["lr"] * lr_mult
        params.append(param_group)

    return optimizer_cls(params, **opt_kwargs)