File size: 2,799 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
# Copyright 2020 Hirofumi Inaguma
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Conformer common arguments."""
from distutils.util import strtobool
import logging
def add_arguments_conformer_common(group):
    """Add Conformer common arguments to an argparse parser or argument group.

    Registers the encoder positional-encoding / activation options, the
    macaron-style and attention toggles, the relative positional encoding
    variant, and the convolution-module options.

    Args:
        group: An ``argparse.ArgumentParser`` or argument group; only its
            ``add_argument`` method is used.

    Returns:
        The same ``group`` object, with the arguments registered on it.

    """
    # NOTE: with ``--rel-pos-type legacy`` (the default), ``verify_rel_pos_type``
    # rewrites a "rel_pos" value of this option to "legacy_rel_pos".
    group.add_argument(
        "--transformer-encoder-pos-enc-layer-type",
        type=str,
        default="abs_pos",
        choices=["abs_pos", "scaled_abs_pos", "rel_pos"],
        help="Transformer encoder positional encoding layer type",
    )
    group.add_argument(
        "--transformer-encoder-activation-type",
        type=str,
        default="swish",
        choices=["relu", "hardtanh", "selu", "swish"],
        help="Transformer encoder activation function type",
    )
    # Boolean flags use strtobool: an explicit CLI value becomes 1/0, while the
    # non-string default False is left as-is by argparse (not passed to type).
    group.add_argument(
        "--macaron-style",
        default=False,
        type=strtobool,
        help="Whether to use macaron style for positionwise layer",
    )
    # Attention
    group.add_argument(
        "--zero-triu",
        default=False,
        type=strtobool,
        help="If true, zero the upper triangular part of attention matrix.",
    )
    # Relative positional encoding
    # Adjacent literals are concatenated verbatim, so each piece must carry its
    # own trailing space to keep sentences separated in the rendered help.
    group.add_argument(
        "--rel-pos-type",
        type=str,
        default="legacy",
        choices=["legacy", "latest"],
        help="Whether to use the latest relative positional encoding or the "
        "legacy one. The legacy relative positional encoding will be "
        "deprecated in the future. More Details can be found in "
        "https://github.com/espnet/espnet/pull/2816.",
    )
    # CNN module
    group.add_argument(
        "--use-cnn-module",
        default=False,
        type=strtobool,
        help="Use convolution module or not",
    )
    group.add_argument(
        "--cnn-module-kernel",
        default=31,
        type=int,
        help="Kernel size of convolution module.",
    )
    return group
def verify_rel_pos_type(args):
    """Downgrade relative-position settings to their legacy variants if needed.

    When ``args.rel_pos_type`` is missing/None or equal to ``"legacy"``, the
    ``"rel_pos"`` positional encoding and the ``"rel_selfattn"`` self-attention
    type are rewritten in place to ``"legacy_rel_pos"`` and
    ``"legacy_rel_selfattn"`` respectively. A deprecation warning is logged for
    each rewrite. Any other ``rel_pos_type`` leaves ``args`` untouched.

    Args:
        args (Namespace): original arguments; modified in place.

    Returns:
        args (Namespace): the same object, possibly modified.

    """
    # Older configs may predate --rel-pos-type, so a missing value means legacy.
    rel_pos_type = getattr(args, "rel_pos_type", None)
    wants_legacy = rel_pos_type is None or rel_pos_type == "legacy"
    if not wants_legacy:
        return args

    # Each entry: (attribute, modern value, legacy replacement, warning text).
    # Both attributes are read unconditionally in legacy mode, so a missing one
    # raises AttributeError.
    legacy_fallbacks = (
        (
            "transformer_encoder_pos_enc_layer_type",
            "rel_pos",
            "legacy_rel_pos",
            "Using legacy_rel_pos and it will be deprecated in the future.",
        ),
        (
            "transformer_encoder_selfattn_layer_type",
            "rel_selfattn",
            "legacy_rel_selfattn",
            "Using legacy_rel_selfattn and it will be deprecated in the future.",
        ),
    )
    for attr_name, modern_value, legacy_value, message in legacy_fallbacks:
        if getattr(args, attr_name) == modern_value:
            setattr(args, attr_name, legacy_value)
            logging.warning(message)
    return args
|