import math

import torch

from isegm.utils.log import logger


def get_optimizer(model, opt_name, opt_kwargs):
    """Build a torch optimizer with per-parameter learning-rate multipliers.

    Any parameter that carries an ``lr_mult`` attribute has its learning
    rate scaled by that factor; every other trainable parameter uses the
    base ``lr`` from ``opt_kwargs``.
    """
    params = []
    base_lr = opt_kwargs["lr"]
    for name, param in model.named_parameters():
        param_group = {"params": [param]}
        # Frozen parameters are still registered with the optimizer, but no
        # custom learning rate is applied to them.
        if not param.requires_grad:
            params.append(param_group)
            continue

        # Scale the base learning rate for parameters tagged with a
        # non-trivial lr_mult attribute.
        if not math.isclose(getattr(param, "lr_mult", 1.0), 1.0):
            logger.info(f'Applied lr_mult={param.lr_mult} to "{name}" parameter.')
            param_group["lr"] = param_group.get("lr", base_lr) * param.lr_mult

        params.append(param_group)

    # Look up the optimizer class by name and instantiate it with the
    # per-parameter groups; opt_kwargs supplies lr, weight_decay, etc.
    optimizer = {
        "sgd": torch.optim.SGD,
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
    }[opt_name.lower()](params, **opt_kwargs)

    return optimizer
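

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how get_optimizer might be called. The toy model and
# hyperparameter values below are hypothetical placeholders; only the lr_mult
# mechanism mirrors what the function above actually reads.
if __name__ == "__main__":
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 2),
    )
    # Tag one parameter with a learning-rate multiplier; get_optimizer picks
    # this attribute up and scales that parameter group's lr accordingly.
    model[0].weight.lr_mult = 0.1

    optimizer = get_optimizer(model, "adamw", {"lr": 5e-4, "weight_decay": 1e-2})
    print(optimizer)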