Spaces:
Runtime error
Runtime error
# Copyright 2020 Hirofumi Inaguma | |
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) | |
"""RNN common arguments."""
def add_arguments_rnn_encoder_common(group):
    """Register the options shared by RNN encoders on ``group``.

    Args:
        group: Object with an argparse-style ``add_argument`` method, such as
            an ``argparse.ArgumentParser`` or one of its argument groups.

    Returns:
        The ``group`` object that was passed in.

    """
    encoder_types = [
        "lstm",
        "blstm",
        "lstmp",
        "blstmp",
        "vgglstmp",
        "vggblstmp",
        "vgglstm",
        "vggblstm",
        "gru",
        "bgru",
        "grup",
        "bgrup",
        "vgggrup",
        "vggbgrup",
        "vgggru",
        "vggbgru",
    ]
    # Each entry pairs the option flags with the keyword arguments handed to
    # ``add_argument``; entries are registered in the order listed.
    option_table = (
        (
            ("--etype",),
            {
                "default": "blstmp",
                "type": str,
                "choices": encoder_types,
                "help": "Type of encoder network architecture",
            },
        ),
        (
            ("--elayers",),
            {
                "default": 4,
                "type": int,
                "help": "Number of encoder layers",
            },
        ),
        (
            ("--eunits", "-u"),
            {
                "default": 300,
                "type": int,
                "help": "Number of encoder hidden units",
            },
        ),
        (
            ("--eprojs",),
            {
                "default": 320,
                "type": int,
                "help": "Number of encoder projection units",
            },
        ),
        (
            ("--subsample",),
            {
                "default": "1",
                "type": str,
                "help": (
                    "Subsample input frames x_y_z means "
                    "subsample every x frame at 1st layer, "
                    "every y frame at 2nd layer etc."
                ),
            },
        ),
    )
    for flags, settings in option_table:
        group.add_argument(*flags, **settings)
    return group
def add_arguments_rnn_decoder_common(group):
    """Register the options shared by RNN decoders on ``group``.

    Args:
        group: Object with an argparse-style ``add_argument`` method, such as
            an ``argparse.ArgumentParser`` or one of its argument groups.

    Returns:
        The ``group`` object that was passed in.

    """
    # (flag, keyword arguments) pairs, registered in the order listed.
    decoder_options = (
        (
            "--dtype",
            dict(
                default="lstm",
                type=str,
                choices=["lstm", "gru"],
                help="Type of decoder network architecture",
            ),
        ),
        (
            "--dlayers",
            dict(default=1, type=int, help="Number of decoder layers"),
        ),
        (
            "--dunits",
            dict(default=320, type=int, help="Number of decoder hidden units"),
        ),
        (
            "--dropout-rate-decoder",
            dict(default=0.0, type=float, help="Dropout rate for the decoder"),
        ),
        (
            "--sampling-probability",
            dict(
                default=0.0,
                type=float,
                help="Ratio of predicted labels fed back to decoder",
            ),
        ),
        (
            "--lsm-type",
            # nargs="?" with const="" lets the flag appear without a value,
            # which yields the empty string.
            dict(
                const="",
                default="",
                type=str,
                nargs="?",
                choices=["", "unigram"],
                help="Apply label smoothing with a specified distribution type",
            ),
        ),
    )
    for flag, settings in decoder_options:
        group.add_argument(flag, **settings)
    return group
def add_arguments_rnn_attention_common(group):
    """Define common arguments for RNN attention.

    Args:
        group: Object with an argparse-style ``add_argument`` method, such as
            an ``argparse.ArgumentParser`` or one of its argument groups.

    Returns:
        The ``group`` object that was passed in.

    """
    group.add_argument(
        "--atype",
        default="dot",
        type=str,
        choices=[
            "noatt",
            "dot",
            "add",
            "location",
            "coverage",
            "coverage_location",
            "location2d",
            "location_recurrent",
            "multi_head_dot",
            "multi_head_add",
            "multi_head_loc",
            "multi_head_multi_res_loc",
        ],
        help="Type of attention architecture",
    )
    group.add_argument(
        "--adim",
        default=320,
        type=int,
        help="Number of attention transformation dimensions",
    )
    group.add_argument(
        "--awin", default=5, type=int, help="Window size for location2d attention"
    )
    group.add_argument(
        "--aheads",
        default=4,
        type=int,
        help="Number of heads for multi head attention",
    )
    # The two multi-line help texts below are built from adjacent string
    # literals rather than a backslash continuation inside the literal, so the
    # runtime text does not pick up the source indentation of the second line.
    group.add_argument(
        "--aconv-chans",
        default=-1,
        type=int,
        help="Number of attention convolution channels "
        "(negative value indicates no location-aware attention)",
    )
    group.add_argument(
        "--aconv-filts",
        default=100,
        type=int,
        help="Number of attention convolution filters "
        "(negative value indicates no location-aware attention)",
    )
    # NOTE(review): the help text describes an encoder dropout rate, yet the
    # option is registered with the attention arguments; kept here unchanged
    # so existing callers keep getting --dropout-rate from this function.
    group.add_argument(
        "--dropout-rate",
        default=0.0,
        type=float,
        help="Dropout rate for the encoder",
    )
    return group