File size: 2,799 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright 2020 Hirofumi Inaguma
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Conformer common arguments."""


from distutils.util import strtobool
import logging


def add_arguments_conformer_common(group):
    """Add Conformer common arguments.

    Registers the Conformer encoder options (positional encoding layer type,
    activation type, macaron style, attention masking, relative positional
    encoding version and convolution module settings) on ``group``.

    Args:
        group: Object exposing ``add_argument`` (e.g. an argparse parser or
            argument group) that receives the options.

    Returns:
        The same ``group`` object, with the arguments added.
    """
    group.add_argument(
        "--transformer-encoder-pos-enc-layer-type",
        type=str,
        default="abs_pos",
        choices=["abs_pos", "scaled_abs_pos", "rel_pos"],
        help="Transformer encoder positional encoding layer type",
    )
    group.add_argument(
        "--transformer-encoder-activation-type",
        type=str,
        default="swish",
        choices=["relu", "hardtanh", "selu", "swish"],
        help="Transformer encoder activation function type",
    )
    group.add_argument(
        "--macaron-style",
        default=False,
        type=strtobool,
        help="Whether to use macaron style for positionwise layer",
    )
    # Attention
    group.add_argument(
        "--zero-triu",
        default=False,
        type=strtobool,
        help="If true, zero the upper triangular part of attention matrix.",
    )
    # Relative positional encoding
    group.add_argument(
        "--rel-pos-type",
        type=str,
        default="legacy",
        choices=["legacy", "latest"],
        # Adjacent string literals are concatenated verbatim, so each piece
        # must carry its own trailing space to keep the sentences separated.
        help=(
            "Whether to use the latest relative positional encoding or the "
            "legacy one. The legacy relative positional encoding will be "
            "deprecated in the future. More Details can be found in "
            "https://github.com/espnet/espnet/pull/2816."
        ),
    )
    # CNN module
    group.add_argument(
        "--use-cnn-module",
        default=False,
        type=strtobool,
        help="Use convolution module or not",
    )
    group.add_argument(
        "--cnn-module-kernel",
        default=31,
        type=int,
        help="Kernel size of convolution module.",
    )
    return group


def verify_rel_pos_type(args):
    """Switch relative positional settings to their legacy variants if needed.

    When ``args`` has no ``rel_pos_type`` attribute, or it is ``None`` or
    ``"legacy"``, a ``"rel_pos"`` positional encoding layer type becomes
    ``"legacy_rel_pos"`` and a ``"rel_selfattn"`` self-attention layer type
    becomes ``"legacy_rel_selfattn"``. A warning is logged for every value
    that is rewritten. ``args`` is updated in place.

    Args:
        args (Namespace): original arguments

    Returns:
        args (Namespace): the same object, possibly modified
    """
    requested = getattr(args, "rel_pos_type", None)
    if requested is not None and requested != "legacy":
        # Any non-legacy relative positional encoding: nothing to rewrite.
        return args

    # (attribute name, value to replace, legacy replacement, warning message)
    legacy_rewrites = (
        (
            "transformer_encoder_pos_enc_layer_type",
            "rel_pos",
            "legacy_rel_pos",
            "Using legacy_rel_pos and it will be deprecated in the future.",
        ),
        (
            "transformer_encoder_selfattn_layer_type",
            "rel_selfattn",
            "legacy_rel_selfattn",
            "Using legacy_rel_selfattn and it will be deprecated in the future.",
        ),
    )
    for attr_name, current_value, legacy_value, message in legacy_rewrites:
        if getattr(args, attr_name) == current_value:
            setattr(args, attr_name, legacy_value)
            logging.warning(message)

    return args