|
|
|
|
|
|
|
"""Conformer common arguments.""" |
|
|
|
|
|
from distutils.util import strtobool |
|
import logging |
|
|
|
|
|
def add_arguments_conformer_common(group):
    """Add Conformer common arguments.

    Args:
        group: Object exposing ``add_argument`` (e.g. an
            ``argparse.ArgumentParser`` or one of its argument groups)
            on which the options are registered.

    Returns:
        The same ``group`` object, with the Conformer options added.

    """
    group.add_argument(
        "--transformer-encoder-pos-enc-layer-type",
        type=str,
        default="abs_pos",
        choices=["abs_pos", "scaled_abs_pos", "rel_pos"],
        help="Transformer encoder positional encoding layer type",
    )
    group.add_argument(
        "--transformer-encoder-activation-type",
        type=str,
        default="swish",
        choices=["relu", "hardtanh", "selu", "swish"],
        help="Transformer encoder activation function type",
    )
    group.add_argument(
        "--macaron-style",
        default=False,
        type=strtobool,
        help="Whether to use macaron style for positionwise layer",
    )

    group.add_argument(
        "--zero-triu",
        default=False,
        type=strtobool,
        help="If true, zero the upper triangular part of attention matrix.",
    )

    group.add_argument(
        "--rel-pos-type",
        type=str,
        default="legacy",
        choices=["legacy", "latest"],
        # Adjacent literals are concatenated verbatim, so every piece except
        # the last must end with a space to keep the sentences separated.
        help="Whether to use the latest relative positional encoding or the legacy one. "
        "The legacy relative positional encoding will be deprecated in the future. "
        "More Details can be found in https://github.com/espnet/espnet/pull/2816.",
    )

    group.add_argument(
        "--use-cnn-module",
        default=False,
        type=strtobool,
        help="Use convolution module or not",
    )
    group.add_argument(
        "--cnn-module-kernel",
        default=31,
        type=int,
        help="Kernel size of convolution module.",
    )
    return group
|
|
|
|
|
def verify_rel_pos_type(args):
    """Map relative-position layer types to their legacy variants if needed.

    When ``args.rel_pos_type`` is missing or equal to ``"legacy"``, an
    encoder positional-encoding type of ``"rel_pos"`` is rewritten to
    ``"legacy_rel_pos"`` and a self-attention type of ``"rel_selfattn"`` is
    rewritten to ``"legacy_rel_selfattn"``; a warning is logged for each
    rewrite. Any other ``rel_pos_type`` leaves ``args`` untouched.

    Args:
        args (Namespace): original arguments

    Returns:
        args (Namespace): the same object, possibly modified in place

    """
    requested = getattr(args, "rel_pos_type", None)
    if requested is not None and requested != "legacy":
        # An explicit non-legacy choice: nothing to rewrite.
        return args

    # (attribute name, value to replace, legacy replacement, warning text)
    legacy_swaps = (
        (
            "transformer_encoder_pos_enc_layer_type",
            "rel_pos",
            "legacy_rel_pos",
            "Using legacy_rel_pos and it will be deprecated in the future.",
        ),
        (
            "transformer_encoder_selfattn_layer_type",
            "rel_selfattn",
            "legacy_rel_selfattn",
            "Using legacy_rel_selfattn and it will be deprecated in the future.",
        ),
    )
    for attr_name, current, legacy, message in legacy_swaps:
        # No default on getattr: a missing attribute must still raise.
        if getattr(args, attr_name) == current:
            setattr(args, attr_name, legacy)
            logging.warning(message)

    return args
|
|