Spaces:
Runtime error
Runtime error
File size: 1,135 Bytes
3953219 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import torch
import monai
from .utils import load_config
def get_optimizer(model: torch.nn.Module,
                  config: dict) -> torch.optim.Optimizer:
    """Create the optimizer named in ``config`` with its keyword arguments.

    The optimizer section of ``config`` maps an optimizer class name to the
    keyword arguments for its constructor. The class is looked up first in
    ``torch.optim`` and then in ``monai.optimizers``. It is then
    instantiated over ``model.parameters()``.

    Args:
        model: Module whose parameters the optimizer will update.
        config: Configuration exposing an ``optimizer`` section, either as an
            attribute (``config.optimizer``) or as a mapping key
            (``config['optimizer']``). If the section lists several
            optimizers, only the first one (in iteration order) is used.
            A ``None`` value for the keyword arguments is treated as
            "no keyword arguments".

    Returns:
        The instantiated optimizer.

    Raises:
        ValueError: If the optimizer section is empty or ``None``, or if the
            named optimizer exists in neither ``torch.optim`` nor
            ``monai.optimizers``.

    Example:
        >>> config.optimizer  # doctest: +SKIP
        {'Novograd': {'lr': 0.001, 'weight_decay': 0.01}}
        >>> get_optimizer(model, config)  # doctest: +SKIP
        Novograd (
        Parameter Group 0
            ...
            lr: 0.001
            weight_decay: 0.01
        )
    """
    # Support attribute-style configs (e.g. Munch/namespace objects) as well
    # as plain dicts, which the `dict` annotation suggests callers may pass.
    try:
        section = config.optimizer
    except AttributeError:
        section = config['optimizer']

    # Fail with a clear message instead of IndexError/AttributeError when the
    # section is missing its content.
    if not section:
        raise ValueError('No optimizer specified in config.optimizer')

    # Only the first entry is honoured; any further entries are ignored.
    optimizer_type = next(iter(section))
    opt_config = section[optimizer_type]
    if opt_config is None:
        # e.g. YAML `Novograd:` with no arguments -> use constructor defaults.
        opt_config = {}

    # torch.optim takes precedence; monai.optimizers is only consulted when
    # torch does not provide the name.
    if hasattr(torch.optim, optimizer_type):
        optimizer_fun = getattr(torch.optim, optimizer_type)
    elif hasattr(monai.optimizers, optimizer_type):
        optimizer_fun = getattr(monai.optimizers, optimizer_type)
    else:
        raise ValueError(f'Optimizer {optimizer_type} not found')

    optimizer = optimizer_fun(model.parameters(), **opt_config)
    return optimizer
|