File size: 2,658 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
"""Chainer optimizer builders."""
import argparse
import chainer
from chainer.optimizer_hooks import WeightDecay
from espnet.optimizer.factory import OptimizerFactoryInterface
from espnet.optimizer.parser import adadelta
from espnet.optimizer.parser import adam
from espnet.optimizer.parser import sgd
class AdamFactory(OptimizerFactoryInterface):
    """Build a chainer ``Adam`` optimizer from parsed command-line options."""

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Attach the Adam option set to ``parser`` and hand it back."""
        extended = adam(parser)
        return extended

    @staticmethod
    def from_args(target, args: argparse.Namespace):
        """Create an Adam optimizer bound to ``target``.

        Args:
            target: object handed to ``optimizer.setup``; in this chainer
                module that is the model (the shared factory interface
                takes ``model.parameters()`` on the pytorch side).
            args (argparse.Namespace): parsed options; this reads
                ``lr``, ``beta1``, ``beta2`` and ``weight_decay``.

        Returns:
            The configured ``chainer.optimizers.Adam`` instance, with a
            ``WeightDecay`` hook using ``args.weight_decay``.

        """
        # The learning rate is passed as chainer's ``alpha`` argument.
        hyper = dict(alpha=args.lr, beta1=args.beta1, beta2=args.beta2)
        optimizer = chainer.optimizers.Adam(**hyper)
        # Chainer optimizers must be set up on the target before hooks are added.
        optimizer.setup(target)
        optimizer.add_hook(WeightDecay(args.weight_decay))
        return optimizer
class SGDFactory(OptimizerFactoryInterface):
    """Build a chainer ``SGD`` optimizer from parsed command-line options."""

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Attach the SGD option set to ``parser`` and hand it back."""
        extended = sgd(parser)
        return extended

    @staticmethod
    def from_args(target, args: argparse.Namespace):
        """Create a plain SGD optimizer bound to ``target``.

        Args:
            target: object handed to ``optimizer.setup``; in this chainer
                module that is the model (the shared factory interface
                takes ``model.parameters()`` on the pytorch side).
            args (argparse.Namespace): parsed options; this reads
                ``lr`` and ``weight_decay``.

        Returns:
            The configured ``chainer.optimizers.SGD`` instance, with a
            ``WeightDecay`` hook using ``args.weight_decay``.

        """
        learning_rate = args.lr
        optimizer = chainer.optimizers.SGD(lr=learning_rate)
        # Chainer optimizers must be set up on the target before hooks are added.
        optimizer.setup(target)
        optimizer.add_hook(WeightDecay(args.weight_decay))
        return optimizer
class AdadeltaFactory(OptimizerFactoryInterface):
    """Build a chainer ``AdaDelta`` optimizer from parsed command-line options."""

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Attach the Adadelta option set to ``parser`` and hand it back."""
        extended = adadelta(parser)
        return extended

    @staticmethod
    def from_args(target, args: argparse.Namespace):
        """Create an AdaDelta optimizer bound to ``target``.

        Args:
            target: object handed to ``optimizer.setup``; in this chainer
                module that is the model (the shared factory interface
                takes ``model.parameters()`` on the pytorch side).
            args (argparse.Namespace): parsed options; this reads
                ``rho``, ``eps`` and ``weight_decay``. No learning rate is
                passed, so AdaDelta uses chainer's own default.

        Returns:
            The configured ``chainer.optimizers.AdaDelta`` instance, with a
            ``WeightDecay`` hook using ``args.weight_decay``.

        """
        settings = {"rho": args.rho, "eps": args.eps}
        optimizer = chainer.optimizers.AdaDelta(**settings)
        # Chainer optimizers must be set up on the target before hooks are added.
        optimizer.setup(target)
        optimizer.add_hook(WeightDecay(args.weight_decay))
        return optimizer
# Registry mapping each optimizer name to its chainer factory class.
OPTIMIZER_FACTORY_DICT = dict(
    adam=AdamFactory,
    sgd=SGDFactory,
    adadelta=AdadeltaFactory,
)
|