File size: 1,135 Bytes
3953219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import monai
from .utils import load_config


def _lookup_optimizer(namespace, name: str):
    """Return ``namespace.<name>`` if it is an optimizer class, else ``None``."""
    candidate = getattr(namespace, name, None)
    # Reject non-optimizer attributes (e.g. the ``torch.optim.lr_scheduler``
    # module) that ``hasattr`` alone would accept.
    if isinstance(candidate, type) and issubclass(candidate, torch.optim.Optimizer):
        return candidate
    return None


def get_optimizer(model: torch.nn.Module, 
                  config: dict): 
    """Create an optimizer from the ``optimizer`` section of ``config``.

    The section maps an optimizer class name to the keyword arguments passed
    to that class's constructor. The name is looked up first in
    ``torch.optim`` and then in ``monai.optimizers``. The optimizer is built
    over ``model.parameters()``.

    Example:
        With ``config.optimizer == {'Novograd': {'lr': 0.001, 'weight_decay': 0.01}}``,
        ``get_optimizer(model, config)`` returns a ``Novograd`` optimizer
        whose parameter group uses ``lr=0.001`` and ``weight_decay=0.01``.
        All other hyper-parameters keep the class defaults.

    Args:
        model: Module whose parameters the optimizer will update.
        config: Configuration holding the ``optimizer`` section, either as an
            attribute (``config.optimizer``) or, for a plain ``dict``, under
            the ``'optimizer'`` key. Only the first entry of the section is
            used. An entry without arguments (e.g. ``{'Adam': None}``, as
            produced by an empty YAML mapping) uses the constructor defaults.

    Returns:
        The instantiated optimizer.

    Raises:
        ValueError: If the ``optimizer`` section is empty, or if no optimizer
            class with the given name exists in ``torch.optim`` or
            ``monai.optimizers``.
    """
    # Support both attribute-style configs and the plain dict the annotation
    # advertises.
    if isinstance(config, dict) and 'optimizer' in config:
        optimizer_section = config['optimizer']
    else:
        optimizer_section = config.optimizer
    if not optimizer_section:
        raise ValueError('config.optimizer is empty; expected {name: kwargs}')

    # Only the first entry (insertion order) is honoured; any others are ignored.
    optimizer_type = next(iter(optimizer_section))
    # A YAML key with no value (``Adam:``) loads as None; treat it as no kwargs.
    opt_config = optimizer_section[optimizer_type] or {}

    # torch.optim takes precedence; monai.optimizers is only consulted on a miss.
    optimizer_fun = _lookup_optimizer(torch.optim, optimizer_type)
    if optimizer_fun is None:
        optimizer_fun = _lookup_optimizer(monai.optimizers, optimizer_type)
    if optimizer_fun is None:
        raise ValueError(f'Optimizer {optimizer_type} not found')

    return optimizer_fun(model.parameters(), **opt_config)