File size: 5,133 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
"""TDNN modules definition for transformer encoder."""
import logging
from typing import Tuple
from typing import Union
import torch
class TDNN(torch.nn.Module):
    """TDNN implementation with symmetric context.

    The layer is a 1-D convolution over the time axis (no padding),
    optionally followed by ReLU, then dropout, then optional batch
    normalization, in that order.

    Args:
        idim: Dimension of inputs
        odim: Dimension of outputs
        ctx_size: Size of context window
        dilation: Parameter to control the stride of
                  elements within the neighborhood
        stride: Stride of the sliding blocks
        batch_norm: Whether to use batch normalization
        relu: Whether to use non-linearity layer (ReLU)
        dropout_rate: Dropout probability applied after the convolution
                      (and ReLU, if enabled)

    """

    def __init__(
        self,
        idim: int,
        odim: int,
        ctx_size: int = 5,
        dilation: int = 1,
        stride: int = 1,
        batch_norm: bool = False,
        relu: bool = True,
        dropout_rate: float = 0.0,
    ):
        """Construct a TDNN object."""
        super().__init__()

        self.idim = idim
        self.odim = odim

        self.ctx_size = ctx_size
        self.stride = stride
        self.dilation = dilation

        self.batch_norm = batch_norm
        self.relu = relu

        # No padding: every output frame consumes (ctx_size - 1) * dilation + 1
        # input frames, so the time axis shrinks; create_outputs crops the
        # masks / positional embeddings by the same amount.
        self.tdnn = torch.nn.Conv1d(
            idim, odim, ctx_size, stride=stride, dilation=dilation
        )

        if self.relu:
            self.relu_func = torch.nn.ReLU()

        if self.batch_norm:
            self.bn = torch.nn.BatchNorm1d(odim)

        self.dropout = torch.nn.Dropout(p=dropout_rate)

        # The bidirectional-encoding notice is informational; emit it once per
        # layer instead of on every forward call so it does not flood the logs.
        self._warned_bidirect_pos = False

    def forward(
        self,
        x_input: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        masks: Union[torch.Tensor, None],
    ) -> Tuple[
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        Union[torch.Tensor, None],
    ]:
        """Forward TDNN.

        Args:
            x_input: Input tensor (B, T, idim) or ((B, T, idim), (B, T, att_dim))
                     or ((B, T, idim), (B, 2*T-1, att_dim))
            masks: Input mask (B, 1, T), or None

        Returns:
            x_output: Output tensor (B, sub(T), odim)
                      or ((B, sub(T), odim), (B, sub(T), att_dim))
                      or ((B, sub(T), odim), (B, 2*sub(T)-1, att_dim))
            mask: Output mask (B, 1, sub(T)), or None if `masks` was None

        """
        if isinstance(x_input, tuple):
            xs, pos_emb = x_input[0], x_input[1]
        else:
            xs, pos_emb = x_input, None

        # The bidirect_pos is used to distinguish legacy_rel_pos and rel_pos in
        # Conformer model. Note the `legacy_rel_pos` will be deprecated in the future.
        # Details can be found in https://github.com/espnet/espnet/pull/2816.
        if pos_emb is not None and pos_emb.size(1) == 2 * xs.size(1) - 1:
            # getattr keeps modules pickled before this flag existed working.
            if not getattr(self, "_warned_bidirect_pos", False):
                logging.warning("Using bidirectional relative positional encoding.")
                self._warned_bidirect_pos = True
            bidirect_pos = True
        else:
            bidirect_pos = False

        # Conv1d expects (B, C, T): move features to the channel axis and back.
        xs = xs.transpose(1, 2)
        xs = self.tdnn(xs)

        if self.relu:
            xs = self.relu_func(xs)

        xs = self.dropout(xs)

        if self.batch_norm:
            xs = self.bn(xs)

        xs = xs.transpose(1, 2)

        return self.create_outputs(xs, pos_emb, masks, bidirect_pos=bidirect_pos)

    def create_outputs(
        self,
        xs: torch.Tensor,
        pos_emb: Union[torch.Tensor, None],
        masks: Union[torch.Tensor, None],
        bidirect_pos: bool = False,
    ) -> Tuple[
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        Union[torch.Tensor, None],
    ]:
        """Create outputs with subsampled version of pos_emb and masks.

        Args:
            xs: Output tensor (B, sub(T), odim)
            pos_emb: Input positional embedding tensor (B, T, att_dim)
                     or (B, 2*T-1, att_dim), or None
            masks: Input mask (B, 1, T), or None
            bidirect_pos: whether to use bidirectional positional embedding

        Returns:
            xs: Output tensor (B, sub(T), odim)
            pos_emb: Output positional embedding tensor (B, sub(T), att_dim)
                     or (B, 2*sub(T)-1, att_dim)
                     (only returned, paired with xs, when pos_emb is not None)
            masks: Output mask (B, 1, sub(T)), or None

        """
        # Number of frames the unpadded convolution removes from the time axis.
        sub = (self.ctx_size - 1) * self.dilation

        if masks is not None:
            # Crop then stride so the mask length equals the Conv1d output
            # length: floor((T - sub - 1) / stride) + 1.
            if sub != 0:
                masks = masks[:, :, :-sub]

            masks = masks[:, :, :: self.stride]

        if pos_emb is None:
            return xs, masks

        # If the bidirect_pos is true, the pos_emb will include both positive and
        # negative embeddings. Refer to https://github.com/espnet/espnet/pull/2816.
        if bidirect_pos:
            # Split the (2*T-1) embedding at its centre into two length-T
            # halves that share the centre element.
            centre = pos_emb.size(1) // 2
            pos_emb_positive = pos_emb[:, : centre + 1, :]
            pos_emb_negative = pos_emb[:, centre:, :]

            if sub != 0:
                pos_emb_positive = pos_emb_positive[:, :-sub, :]
                pos_emb_negative = pos_emb_negative[:, :-sub, :]

            pos_emb_positive = pos_emb_positive[:, :: self.stride, :]
            pos_emb_negative = pos_emb_negative[:, :: self.stride, :]

            # Drop the duplicated centre element: 2*sub(T)-1 in total.
            pos_emb = torch.cat(
                [pos_emb_positive, pos_emb_negative[:, 1:, :]], dim=1
            )
        else:
            if sub != 0:
                pos_emb = pos_emb[:, :-sub, :]

            pos_emb = pos_emb[:, :: self.stride, :]

        return (xs, pos_emb), masks
|