from torch.optim import SGD, Adam, ASGD, Adamax, Adadelta, Adagrad, RMSprop

# Map lowercase optimizer names to their torch.optim classes.
key2opt = {
    'sgd': SGD,
    'adam': Adam,
    'asgd': ASGD,
    'adamax': Adamax,
    'adadelta': Adadelta,
    'adagrad': Adagrad,
    'rmsprop': RMSprop,
}


def get_optimizer(optimizer_name=None):
    """Return the torch.optim class registered under ``optimizer_name``.

    Names are expected in lowercase (see ``key2opt``). Falls back to SGD
    when no name is given; raises NotImplementedError for unknown names.
    """
    if optimizer_name is None:
        print("Using default 'SGD' optimizer")
        return SGD

    if optimizer_name not in key2opt:
        raise NotImplementedError(f"Optimizer '{optimizer_name}' not implemented")

    print(f"Using optimizer: '{optimizer_name}'")
    return key2opt[optimizer_name]
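

if __name__ == '__main__':
    # Minimal usage sketch of the lookup above. The linear layer and the
    # learning rate are illustrative assumptions, not part of this module:
    # get_optimizer returns a class, which is then instantiated with the
    # model's parameters as usual for torch.optim.
    import torch.nn as nn

    model = nn.Linear(4, 2)
    optimizer_cls = get_optimizer('adam')
    optimizer = optimizer_cls(model.parameters(), lr=1e-3)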