# Copyright 2020 Kyoto University (Hirofumi Inaguma)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""
Conformer speech translation model (pytorch).

It builds on the Transformer speech translation model from
`e2e_st_transformer.py` and swaps its encoder for a Conformer encoder.

Refer to: https://arxiv.org/abs/2005.08100

"""

from espnet.nets.pytorch_backend.conformer.encoder import Encoder
from espnet.nets.pytorch_backend.e2e_st_transformer import E2E as E2ETransformer
from espnet.nets.pytorch_backend.conformer.argument import (
    add_arguments_conformer_common,  # noqa: H301
    verify_rel_pos_type,  # noqa: H301
)


class E2E(E2ETransformer):
    """Transformer E2E speech translation model with a Conformer encoder.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Register the Transformer options, then the Conformer options.

        :param parser: argument parser to extend
        :return: the same ``parser`` object that was passed in
        """
        # The base Transformer options are registered before the Conformer ones.
        for register in (E2ETransformer.add_arguments, E2E.add_conformer_arguments):
            register(parser)
        return parser

    @staticmethod
    def add_conformer_arguments(parser):
        """Add the Conformer-specific options inside their own argument group.

        :param parser: argument parser to extend
        :return: the same ``parser`` object that was passed in
        """
        conformer_group = parser.add_argument_group(
            "conformer model specific setting"
        )
        add_arguments_conformer_common(conformer_group)
        return parser

    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        The parent constructor runs first. ``self.encoder`` is then assigned a
        Conformer ``Encoder`` configured from ``args``.

        Note: when ``args.transformer_attn_dropout_rate`` is None, it is set to
        ``args.dropout_rate`` on the passed-in Namespace itself.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        :param int ignore_id: forwarded unchanged to the parent constructor
        """
        super().__init__(idim, odim, args, ignore_id)

        # No attention-specific dropout was given: fall back to the general
        # dropout rate. This writes into the caller's Namespace.
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate

        # Check the relative positional encoding type. Use whatever Namespace
        # the helper returns for the rest of construction.
        args = verify_rel_pos_type(args)

        encoder_config = dict(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            pos_enc_layer_type=args.transformer_encoder_pos_enc_layer_type,
            selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
            activation_type=args.transformer_encoder_activation_type,
            macaron_style=args.macaron_style,
            use_cnn_module=args.use_cnn_module,
            cnn_module_kernel=args.cnn_module_kernel,
        )
        self.encoder = Encoder(**encoder_config)

        # Presumably re-runs parameter initialization so the freshly built
        # encoder is covered as well (confirm against
        # E2ETransformer.reset_parameters).
        self.reset_parameters(args)