File size: 4,031 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Custom encoder definition for transducer models."""

import torch

from espnet.nets.pytorch_backend.transducer.blocks import build_blocks
from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L

from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling


class CustomEncoder(torch.nn.Module):
    """Custom encoder module for transducer models.

    Args:
        idim (int): input dim
        enc_arch (list): list of encoder blocks (type and parameters)
        input_layer (str): input layer type
        repeat_block (int): repeat provided block N times if N > 1
        self_attn_type (str): type of self-attention
        positional_encoding_type (str): positional encoding type
        positionwise_layer_type (str): positionwise layer type
        positionwise_activation_type (str): positionwise activation type
        conv_mod_activation_type (str): convolutional module activation type
        normalize_before (bool): if True, a final LayerNorm (``after_norm``)
            is applied to the encoder output and to every auxiliary output
        aux_task_layer_list (list or None): 0-based ids of encoder blocks whose
            outputs are additionally returned for auxiliary tasks.
            None (default) means no auxiliary outputs.
        padding_idx (int): padding_idx for embedding input layer (if specified)

    """

    def __init__(
        self,
        idim,
        enc_arch,
        input_layer="linear",
        repeat_block=0,
        self_attn_type="selfattn",
        positional_encoding_type="abs_pos",
        positionwise_layer_type="linear",
        positionwise_activation_type="relu",
        conv_mod_activation_type="relu",
        normalize_before=True,
        aux_task_layer_list=None,
        padding_idx=-1,
    ):
        """Construct a CustomEncoder object."""
        super().__init__()

        (
            self.embed,
            self.encoders,
            self.enc_out,
            self.conv_subsampling_factor,
        ) = build_blocks(
            "encoder",
            idim,
            input_layer,
            enc_arch,
            repeat_block=repeat_block,
            self_attn_type=self_attn_type,
            positional_encoding_type=positional_encoding_type,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_activation_type=positionwise_activation_type,
            conv_mod_activation_type=conv_mod_activation_type,
            padding_idx=padding_idx,
        )

        self.normalize_before = normalize_before

        if self.normalize_before:
            self.after_norm = LayerNorm(self.enc_out)

        # Blocks are only repeated when repeat_block > 1; otherwise each entry
        # of enc_arch is built once. Multiplying by repeat_block directly gave
        # 0 for the default (repeat_block=0), which made the auxiliary-task
        # path of forward() skip every encoder block.
        self.n_blocks = len(enc_arch) * max(repeat_block, 1)

        # None default avoids sharing one mutable list across instances.
        self.aux_task_layer_list = (
            aux_task_layer_list if aux_task_layer_list is not None else []
        )

    def _finalize_output(self, xs):
        """Unwrap a block output and apply the final normalization.

        Args:
            xs (torch.Tensor or tuple): output of an encoder block. Some blocks
                return a tuple (presumably hidden states plus a positional
                embedding — TODO confirm); only its first element is kept.

        Returns:
            torch.Tensor: hidden states, passed through ``after_norm`` when
                ``normalize_before`` is True

        """
        if isinstance(xs, tuple):
            xs = xs[0]

        if self.normalize_before:
            xs = self.after_norm(xs)

        return xs

    def forward(self, xs, masks):
        """Encode input sequence.

        Args:
            xs (torch.Tensor): input tensor
            masks (torch.Tensor): input mask

        Returns:
            xs (torch.Tensor or tuple):
                encoder output, or
                (encoder output, list of auxiliary block outputs) when
                ``aux_task_layer_list`` is non-empty
            masks (torch.Tensor): output mask (transformed by the input layer
                when it is Conv2dSubsampling or VGG2L)

        """
        if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
            # These input layers take and return the mask as well.
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)

        if not self.aux_task_layer_list:
            xs, masks = self.encoders(xs, masks)

            return self._finalize_output(xs), masks

        # Run blocks one at a time so intermediate outputs can be collected.
        aux_xs_list = []

        for b in range(self.n_blocks):
            xs, masks = self.encoders[b](xs, masks)

            if b in self.aux_task_layer_list:
                aux_xs_list.append(self._finalize_output(xs))

        return (self._finalize_output(xs), aux_xs_list), masks