File size: 5,133 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""TDNN modules definition for transformer encoder."""

import logging
from typing import Tuple
from typing import Union

import torch


class TDNN(torch.nn.Module):
    """TDNN implementation with symmetric context.

    The layer is an unpadded 1-D convolution over the time axis, followed
    (in this order) by an optional ReLU, dropout and an optional batch
    normalization. Because no padding is applied, the time axis is shortened
    by ``(ctx_size - 1) * dilation`` frames and then subsampled by ``stride``;
    masks and positional embeddings are reduced accordingly.

    Args:
        idim: Dimension of inputs
        odim: Dimension of outputs
        ctx_size: Size of context window
        dilation: Parameter to control the stride of
                  elements within the neighborhood
        stride: Stride of the sliding blocks
        batch_norm: Whether to use batch normalization
        relu: Whether to use non-linearity layer (ReLU)
        dropout_rate: Dropout probability applied after the non-linearity

    """

    def __init__(
        self,
        idim: int,
        odim: int,
        ctx_size: int = 5,
        dilation: int = 1,
        stride: int = 1,
        batch_norm: bool = False,
        relu: bool = True,
        dropout_rate: float = 0.0,
    ):
        """Construct a TDNN object."""
        super().__init__()

        self.idim = idim
        self.odim = odim

        self.ctx_size = ctx_size
        self.stride = stride
        self.dilation = dilation

        self.batch_norm = batch_norm
        self.relu = relu

        self.tdnn = torch.nn.Conv1d(
            idim, odim, ctx_size, stride=stride, dilation=dilation
        )

        if self.relu:
            self.relu_func = torch.nn.ReLU()

        if self.batch_norm:
            self.bn = torch.nn.BatchNorm1d(odim)

        self.dropout = torch.nn.Dropout(p=dropout_rate)

        # Set once the bidirectional-encoding notice has been emitted, so that
        # forward() does not flood the log with one warning per call.
        self._bidirect_pos_warned = False

    def forward(
        self,
        x_input: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        masks: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor]:
        """Forward TDNN.

        Args:
            x_input: Input tensor (B, T, idim) or ((B, T, idim), (B, T, att_dim))
            or ((B, T, idim), (B, 2*T-1, att_dim))
            masks: Input mask (B, 1, T)

        Returns:
            x_output: Output tensor (B, sub(T), odim)
                          or ((B, sub(T), odim), (B, sub(T), att_dim))
            mask: Output mask (B, 1, sub(T))

        """
        if isinstance(x_input, tuple):
            xs, pos_emb = x_input[0], x_input[1]
        else:
            xs, pos_emb = x_input, None

        # The bidirect_pos is used to distinguish legacy_rel_pos and rel_pos in
        # Conformer model. Note the `legacy_rel_pos` will be deprecated in the future.
        # Details can be found in https://github.com/espnet/espnet/pull/2816.
        bidirect_pos = pos_emb is not None and pos_emb.size(1) == 2 * xs.size(1) - 1

        # Warn only once per module instance. getattr() keeps instances that
        # were created without running __init__ (e.g. unpickled) working.
        if bidirect_pos and not getattr(self, "_bidirect_pos_warned", False):
            logging.warning("Using bidirectional relative positional encoding.")
            self._bidirect_pos_warned = True

        # Conv1d expects (B, idim, T), hence the transposes around it.
        xs = xs.transpose(1, 2)
        xs = self.tdnn(xs)

        if self.relu:
            xs = self.relu_func(xs)

        xs = self.dropout(xs)

        if self.batch_norm:
            xs = self.bn(xs)

        xs = xs.transpose(1, 2)

        return self.create_outputs(xs, pos_emb, masks, bidirect_pos=bidirect_pos)

    def _subsample(self, tensor: torch.Tensor, dim: int) -> torch.Tensor:
        """Drop the trailing context frames of `tensor` along `dim`, then stride."""
        sub = (self.ctx_size - 1) * self.dilation

        index = [slice(None)] * tensor.dim()
        # A stop of `-0` would select nothing, so use an open end when sub == 0.
        index[dim] = slice(None, -sub if sub != 0 else None, self.stride)

        return tensor[tuple(index)]

    def create_outputs(
        self,
        xs: torch.Tensor,
        pos_emb: Union[torch.Tensor, None],
        masks: Union[torch.Tensor, None],
        bidirect_pos: bool = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor]:
        """Create outputs with subsampled version of pos_emb and masks.

        Args:
            xs: Output tensor (B, sub(T), odim)
            pos_emb: Input positional embedding tensor (B, T, att_dim)
            or (B, 2*T-1, att_dim), or None when no embedding is used
            masks: Input mask (B, 1, T), or None
            bidirect_pos: whether to use bidirectional positional embedding

        Returns:
            xs: Output tensor (B, sub(T), odim)
            pos_emb: Output positional embedding tensor (B, sub(T), att_dim)
            or (B, 2*sub(T)-1, att_dim); only returned (paired with xs)
            when pos_emb is not None
            masks: Output mask (B, 1, sub(T))

        """
        if masks is not None:
            masks = self._subsample(masks, dim=2)

        if pos_emb is None:
            return xs, masks

        if bidirect_pos:
            # If the bidirect_pos is true, the pos_emb will include both positive and
            # negative embeddings. Refer to https://github.com/espnet/espnet/pull/2816.
            # The two halves overlap on the centre element.
            center = pos_emb.size(1) // 2
            pos_emb_positive = self._subsample(pos_emb[:, : center + 1, :], dim=1)
            pos_emb_negative = self._subsample(pos_emb[:, center:, :], dim=1)

            # Skip the first element of the second half so the shared centre
            # element is not duplicated: the result has 2*sub(T)-1 frames.
            pos_emb = torch.cat([pos_emb_positive, pos_emb_negative[:, 1:, :]], dim=1)
        else:
            pos_emb = self._subsample(pos_emb, dim=1)

        return (xs, pos_emb), masks