Spaces:
Runtime error
Runtime error
import torch | |
import math | |
from isegm.utils.log import logger | |
def get_optimizer(model, opt_name, opt_kwargs):
    """Build a torch optimizer with optional per-parameter learning-rate multipliers.

    Every parameter of ``model`` is placed in its own param group. When a
    trainable parameter carries an ``lr_mult`` attribute whose value is not
    (approximately) 1.0, that group's learning rate is set to
    ``opt_kwargs['lr'] * lr_mult`` and the adjustment is logged. Parameters with
    ``requires_grad=False`` are still passed to the optimizer, with no per-group
    learning-rate override.

    Args:
        model: module whose ``named_parameters()`` are optimized.
        opt_name: optimizer name, case-insensitive; one of ``'sgd'``,
            ``'adam'`` or ``'adamw'``.
        opt_kwargs: keyword arguments forwarded unchanged to the optimizer
            constructor; must contain ``'lr'``, the base learning rate.

    Returns:
        The constructed ``torch.optim`` optimizer instance.

    Raises:
        KeyError: if ``opt_name`` is not a supported optimizer, or if
            ``opt_kwargs`` has no ``'lr'`` entry.
    """
    optimizer_classes = {
        'sgd': torch.optim.SGD,
        'adam': torch.optim.Adam,
        'adamw': torch.optim.AdamW,
    }
    # Resolve the optimizer before touching the parameters so a bad name fails
    # fast (no lr_mult log lines for a doomed call) with a message listing the
    # valid choices. KeyError is kept for callers that already catch it.
    try:
        optimizer_cls = optimizer_classes[opt_name.lower()]
    except KeyError:
        raise KeyError(
            f'Unsupported optimizer {opt_name!r}; expected one of '
            f'{sorted(optimizer_classes)}.'
        ) from None

    base_lr = opt_kwargs['lr']
    params = []
    for name, param in model.named_parameters():
        param_group = {'params': [param]}
        # lr_mult is an optional attribute presumably attached to parameters by
        # model-construction code elsewhere; absent means a multiplier of 1.0.
        lr_mult = getattr(param, 'lr_mult', 1.0)
        # Frozen parameters never get an lr override; isclose avoids treating
        # float noise around 1.0 as a real multiplier.
        if param.requires_grad and not math.isclose(lr_mult, 1.0):
            logger.info(f'Applied lr_mult={lr_mult} to "{name}" parameter.')
            param_group['lr'] = base_lr * lr_mult
        params.append(param_group)

    return optimizer_cls(params, **opt_kwargs)