File size: 4,080 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Copyright 2020 Hirofumi Inaguma
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""RNN common arguments."""


def add_arguments_rnn_encoder_common(group):
    """Register the encoder options shared by RNN models on ``group``.

    Adds ``--etype``, ``--elayers``, ``--eunits`` (short alias ``-u``),
    ``--eprojs`` and ``--subsample``, in that order.

    Args:
        group: Object exposing an argparse-style ``add_argument`` method.

    Returns:
        The same ``group`` object that was passed in.

    """
    # Each cell type is offered in the same eight variants; building the list
    # per cell keeps the two families in sync and preserves the listing order.
    encoder_types = []
    for cell in ("lstm", "gru"):
        encoder_types.extend(
            [
                cell,
                "b" + cell,
                cell + "p",
                "b" + cell + "p",
                "vgg" + cell + "p",
                "vggb" + cell + "p",
                "vgg" + cell,
                "vggb" + cell,
            ]
        )

    subsample_help = (
        "Subsample input frames x_y_z means "
        "subsample every x frame at 1st layer, "
        "every y frame at 2nd layer etc."
    )

    # (option strings, keyword arguments) pairs, applied in declaration order.
    option_specs = (
        (
            ("--etype",),
            {
                "default": "blstmp",
                "type": str,
                "choices": encoder_types,
                "help": "Type of encoder network architecture",
            },
        ),
        (
            ("--elayers",),
            {"default": 4, "type": int, "help": "Number of encoder layers"},
        ),
        (
            ("--eunits", "-u"),
            {"default": 300, "type": int, "help": "Number of encoder hidden units"},
        ),
        (
            ("--eprojs",),
            {
                "default": 320,
                "type": int,
                "help": "Number of encoder projection units",
            },
        ),
        (
            ("--subsample",),
            {"default": "1", "type": str, "help": subsample_help},
        ),
    )
    for flags, settings in option_specs:
        group.add_argument(*flags, **settings)
    return group


def add_arguments_rnn_decoder_common(group):
    """Register the decoder options shared by RNN models on ``group``.

    Adds ``--dtype``, ``--dlayers``, ``--dunits``, ``--dropout-rate-decoder``,
    ``--sampling-probability`` and ``--lsm-type``, in that order.

    Args:
        group: Object exposing an argparse-style ``add_argument`` method.

    Returns:
        The same ``group`` object that was passed in.

    """
    # (flag, keyword arguments) pairs, applied in declaration order.
    option_specs = (
        (
            "--dtype",
            {
                "default": "lstm",
                "type": str,
                "choices": ["lstm", "gru"],
                "help": "Type of decoder network architecture",
            },
        ),
        (
            "--dlayers",
            {"default": 1, "type": int, "help": "Number of decoder layers"},
        ),
        (
            "--dunits",
            {"default": 320, "type": int, "help": "Number of decoder hidden units"},
        ),
        (
            "--dropout-rate-decoder",
            {"default": 0.0, "type": float, "help": "Dropout rate for the decoder"},
        ),
        (
            "--sampling-probability",
            {
                "default": 0.0,
                "type": float,
                "help": "Ratio of predicted labels fed back to decoder",
            },
        ),
        (
            # nargs="?" with const="" lets the flag appear without a value,
            # which yields the same empty string as omitting it entirely.
            "--lsm-type",
            {
                "const": "",
                "default": "",
                "type": str,
                "nargs": "?",
                "choices": ["", "unigram"],
                "help": "Apply label smoothing with a specified distribution type",
            },
        ),
    )
    for flag, settings in option_specs:
        group.add_argument(flag, **settings)
    return group


def add_arguments_rnn_attention_common(group):
    """Define common arguments for RNN attention.

    Adds ``--atype``, ``--adim``, ``--awin``, ``--aheads``, ``--aconv-chans``,
    ``--aconv-filts`` and ``--dropout-rate``, in that order.

    Args:
        group: Object exposing an argparse-style ``add_argument`` method.

    Returns:
        The same ``group`` object that was passed in.

    """
    group.add_argument(
        "--atype",
        default="dot",
        type=str,
        choices=[
            "noatt",
            "dot",
            "add",
            "location",
            "coverage",
            "coverage_location",
            "location2d",
            "location_recurrent",
            "multi_head_dot",
            "multi_head_add",
            "multi_head_loc",
            "multi_head_multi_res_loc",
        ],
        help="Type of attention architecture",
    )
    group.add_argument(
        "--adim",
        default=320,
        type=int,
        help="Number of attention transformation dimensions",
    )
    group.add_argument(
        "--awin", default=5, type=int, help="Window size for location2d attention"
    )
    group.add_argument(
        "--aheads",
        default=4,
        type=int,
        help="Number of heads for multi head attention",
    )
    # The two help texts below are built from adjacent string literals. A
    # backslash continuation inside a single literal would embed the next
    # line's indentation in the help text.
    group.add_argument(
        "--aconv-chans",
        default=-1,
        type=int,
        help="Number of attention convolution channels "
        "(negative value indicates no location-aware attention)",
    )
    group.add_argument(
        "--aconv-filts",
        default=100,
        type=int,
        help="Number of attention convolution filters "
        "(negative value indicates no location-aware attention)",
    )
    # NOTE(review): this option is described as an encoder dropout rate but is
    # registered with the attention arguments; left here because callers may
    # rely on this function defining it. Confirm the intended home.
    group.add_argument(
        "--dropout-rate",
        default=0.0,
        type=float,
        help="Dropout rate for the encoder",
    )
    return group