# P-DFD / optimizer / __init__.py
# Uploaded by mrneuralnet - commit 982865f ("Initial commit"), 750 bytes.
from torch.optim import SGD
from torch.optim import Adam
from torch.optim import ASGD
from torch.optim import Adamax
from torch.optim import Adadelta
from torch.optim import Adagrad
from torch.optim import RMSprop
# Registry of supported optimizers: lower-case config name -> torch.optim class.
# get_optimizer() resolves names against these keys.
key2opt = dict(
    sgd=SGD,
    adam=Adam,
    asgd=ASGD,
    adamax=Adamax,
    adadelta=Adadelta,
    adagrad=Adagrad,
    rmsprop=RMSprop,
)
def get_optimizer(optimizer_name=None):
    """Return the torch.optim optimizer class registered under *optimizer_name*.

    The class itself is returned, not an instance.

    Args:
        optimizer_name: Registry key such as ``'sgd'`` or ``'adam'``. For
            string names, matching ignores letter case and surrounding
            whitespace, so ``'Adam'`` and ``' SGD '`` are accepted.
            ``None`` selects ``SGD``.

    Returns:
        The optimizer class looked up in ``key2opt``, or ``SGD`` when
        *optimizer_name* is ``None``.

    Raises:
        NotImplementedError: if the (normalized) name is not a key of
            ``key2opt``.
    """
    if optimizer_name is None:
        print("Using default 'SGD' optimizer")
        return SGD

    # Registry keys are lower-case; normalize string names so config values
    # like 'Adam' or 'RMSprop' resolve. Non-strings are passed through
    # unchanged and rejected by the membership test, as before.
    if isinstance(optimizer_name, str):
        key = optimizer_name.strip().lower()
    else:
        key = optimizer_name

    if key not in key2opt:
        available = ", ".join(sorted(key2opt))
        raise NotImplementedError(
            f"Optimizer '{optimizer_name}' not implemented "
            f"(available: {available})"
        )

    print(f"Using optimizer: '{optimizer_name}'")
    return key2opt[key]