|
|
|
|
|
|
|
"""Conformer common arguments.""" |
|
|
|
|
|
def add_arguments_rnn_encoder_common(group):
    """Register the command-line options shared by RNN encoders.

    The options ``--etype``, ``--elayers``, ``--eunits`` (short form ``-u``),
    ``--eprojs`` and ``--subsample`` are added, in that order, through
    ``group.add_argument``.

    Args:
        group: Object providing an argparse-style ``add_argument`` method.

    Returns:
        The ``group`` object that was passed in.

    """
    # Kept as an explicit list so the order of the accepted values is fixed.
    encoder_types = [
        "lstm",
        "blstm",
        "lstmp",
        "blstmp",
        "vgglstmp",
        "vggblstmp",
        "vgglstm",
        "vggblstm",
        "gru",
        "bgru",
        "grup",
        "bgrup",
        "vgggrup",
        "vggbgrup",
        "vgggru",
        "vggbgru",
    ]
    subsample_help = (
        "Subsample input frames x_y_z means subsample every x frame "
        "at 1st layer, every y frame at 2nd layer etc."
    )

    # Each entry is (option strings, add_argument keyword settings); the
    # entries are registered top to bottom.
    option_table = (
        (
            ("--etype",),
            {
                "default": "blstmp",
                "type": str,
                "choices": encoder_types,
                "help": "Type of encoder network architecture",
            },
        ),
        (
            ("--elayers",),
            {"default": 4, "type": int, "help": "Number of encoder layers"},
        ),
        (
            ("--eunits", "-u"),
            {"default": 300, "type": int, "help": "Number of encoder hidden units"},
        ),
        (
            ("--eprojs",),
            {
                "default": 320,
                "type": int,
                "help": "Number of encoder projection units",
            },
        ),
        (
            ("--subsample",),
            {"default": "1", "type": str, "help": subsample_help},
        ),
    )
    for flags, settings in option_table:
        group.add_argument(*flags, **settings)

    return group
|
|
|
|
|
def add_arguments_rnn_decoder_common(group):
    """Register the command-line options shared by RNN decoders.

    The options ``--dtype``, ``--dlayers``, ``--dunits``,
    ``--dropout-rate-decoder``, ``--sampling-probability`` and ``--lsm-type``
    are added, in that order, through ``group.add_argument``.

    Args:
        group: Object providing an argparse-style ``add_argument`` method.

    Returns:
        The ``group`` object that was passed in.

    """
    register = group.add_argument

    # Decoder network shape.
    register(
        "--dtype",
        choices=["lstm", "gru"],
        type=str,
        default="lstm",
        help="Type of decoder network architecture",
    )
    register("--dlayers", type=int, default=1, help="Number of decoder layers")
    register(
        "--dunits",
        type=int,
        default=320,
        help="Number of decoder hidden units",
    )

    # Float-valued decoder options.
    register(
        "--dropout-rate-decoder",
        type=float,
        default=0.0,
        help="Dropout rate for the decoder",
    )
    register(
        "--sampling-probability",
        type=float,
        default=0.0,
        help="Ratio of predicted labels fed back to decoder",
    )

    # nargs="?" with const="" lets the flag be given without a value; that
    # yields the empty string, the same value as the default.
    register(
        "--lsm-type",
        nargs="?",
        const="",
        default="",
        type=str,
        choices=["", "unigram"],
        help="Apply label smoothing with a specified distribution type",
    )

    return group
|
|
|
|
|
def add_arguments_rnn_attention_common(group):
    """Define common arguments for RNN attention.

    Adds ``--atype``, ``--adim``, ``--awin``, ``--aheads``, ``--aconv-chans``,
    ``--aconv-filts`` and ``--dropout-rate``, in that order, through
    ``group.add_argument``.

    Args:
        group: Object providing an argparse-style ``add_argument`` method.

    Returns:
        The ``group`` object that was passed in.

    """
    group.add_argument(
        "--atype",
        default="dot",
        type=str,
        choices=[
            "noatt",
            "dot",
            "add",
            "location",
            "coverage",
            "coverage_location",
            "location2d",
            "location_recurrent",
            "multi_head_dot",
            "multi_head_add",
            "multi_head_loc",
            "multi_head_multi_res_loc",
        ],
        help="Type of attention architecture",
    )
    group.add_argument(
        "--adim",
        default=320,
        type=int,
        help="Number of attention transformation dimensions",
    )
    group.add_argument(
        "--awin", default=5, type=int, help="Window size for location2d attention"
    )
    group.add_argument(
        "--aheads",
        default=4,
        type=int,
        help="Number of heads for multi head attention",
    )
    # The two help texts below are built from adjacent string literals rather
    # than a backslash continuation inside one literal, so the indentation of
    # the source line never ends up inside the help string.
    group.add_argument(
        "--aconv-chans",
        default=-1,
        type=int,
        help="Number of attention convolution channels "
        "(negative value indicates no location-aware attention)",
    )
    group.add_argument(
        "--aconv-filts",
        default=100,
        type=int,
        help="Number of attention convolution filters "
        "(negative value indicates no location-aware attention)",
    )
    # NOTE: per its help text this is an encoder option; it is registered here
    # and stays here so existing callers keep getting it from this function.
    group.add_argument(
        "--dropout-rate",
        default=0.0,
        type=float,
        help="Dropout rate for the encoder",
    )
    return group
|
|