from torch.optim import SGD
from torch.optim import Adam
from torch.optim import ASGD
from torch.optim import Adamax
from torch.optim import Adadelta
from torch.optim import Adagrad
from torch.optim import RMSprop

# Registry mapping lower-case optimizer names to torch.optim classes.
key2opt = {
    'sgd': SGD,
    'adam': Adam,
    'asgd': ASGD,
    'adamax': Adamax,
    'adadelta': Adadelta,
    'adagrad': Adagrad,
    'rmsprop': RMSprop,
}


def get_optimizer(optimizer_name=None):
    """Return the optimizer class registered under *optimizer_name*.

    The class itself is returned, not an instance; the caller is expected
    to construct it with its own parameters.

    Args:
        optimizer_name: Key into ``key2opt`` (e.g. ``'adam'``). String
            names are matched case-insensitively, so ``'Adam'`` and
            ``'ADAM'`` resolve to the same class. ``None`` selects the
            default, ``SGD``.

    Returns:
        The optimizer class stored in ``key2opt`` for the given name, or
        ``SGD`` when no name is given.

    Raises:
        NotImplementedError: If the name is not a key of ``key2opt``.
    """
    if optimizer_name is None:
        print("Using default 'SGD' optimizer")
        return SGD

    # All registry keys are lower-case, so folding the requested name can
    # never make two distinct registered names collide.
    lookup_key = optimizer_name
    if isinstance(optimizer_name, str):
        lookup_key = optimizer_name.lower()

    if lookup_key not in key2opt:
        raise NotImplementedError(f"Optimizer '{optimizer_name}' not implemented")

    print(f"Using optimizer: '{optimizer_name}'")
    return key2opt[lookup_key]