File size: 858 Bytes
1f53a4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Optional

import torch.optim as optim
import torch.nn as nn


def set_optimizer(
    optimizer_name: str, network: nn.Module, lr: Optional[float] = None
) -> optim.Optimizer:
    """
    Build an optimizer over all parameters of a network.

    Args:
        optimizer_name (str): optimizer name; one of 'SGD', 'Adadelta',
            'Adam', 'RMSprop', 'RAdam'
        network (torch.nn.Module): network whose ``parameters()`` are optimized
        lr (float, optional): learning rate. If None, no ``lr`` is passed and
            the optimizer class's own default learning rate is used.
    Returns:
        torch.optim.Optimizer: optimizer
    Raises:
        ValueError: if ``optimizer_name`` is not one of the supported names.
    """
    optimizers = {
        'SGD': optim.SGD,
        'Adadelta': optim.Adadelta,
        'Adam': optim.Adam,
        'RMSprop': optim.RMSprop,
        'RAdam': optim.RAdam,
    }

    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would turn a bad name into an opaque KeyError.
    _optim = optimizers.get(optimizer_name)
    if _optim is None:
        raise ValueError(
            f"No specified optimizer: {optimizer_name}. "
            f"Choose from: {', '.join(optimizers)}."
        )

    # Forward `lr` only when given, so each optimizer keeps its own default.
    if lr is None:
        return _optim(network.parameters())
    return _optim(network.parameters(), lr=lr)