Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 858 Bytes
1f53a4c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Optional

import torch.nn as nn
import torch.optim as optim
def set_optimizer(optimizer_name: str, network: nn.Module, lr: Optional[float]) -> optim.Optimizer:
    """
    Build an optimizer over the parameters of the given network.

    Args:
        optimizer_name (str): optimizer name; one of 'SGD', 'Adadelta',
            'Adam', 'RMSprop' or 'RAdam'
        network (torch.nn.Module): network whose parameters are optimized
        lr (Optional[float]): learning rate; when None, no learning rate is
            passed and the optimizer class falls back to its own default

    Returns:
        torch.optim.Optimizer: optimizer bound to network.parameters()

    Raises:
        ValueError: if optimizer_name is not one of the supported names
    """
    optimizers = {
        'SGD': optim.SGD,
        'Adadelta': optim.Adadelta,
        'Adam': optim.Adam,
        'RMSprop': optim.RMSprop,
        'RAdam': optim.RAdam,
    }

    # Raise explicitly instead of asserting: asserts are stripped under
    # `python -O`, which would turn a bad name into an unexplained KeyError.
    if optimizer_name not in optimizers:
        raise ValueError(f"No specified optimizer: {optimizer_name}.")

    optimizer_cls = optimizers[optimizer_name]

    # Only forward lr when one was given, so that each optimizer class keeps
    # its own default learning rate otherwise.
    if lr is None:
        return optimizer_cls(network.parameters())
    return optimizer_cls(network.parameters(), lr=lr)
|