"""Cutom encoder definition for transducer models."""
import torch
from espnet.nets.pytorch_backend.transducer.blocks import build_blocks
from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling


class CustomEncoder(torch.nn.Module):
    """Custom encoder module for transducer models.

    Args:
        idim (int): Input dimension.
        enc_arch (list): List of encoder blocks (type and parameters).
        input_layer (str): Input layer type.
        repeat_block (int): Repeat the provided blocks N times if N > 1.
        self_attn_type (str): Type of self-attention.
        positional_encoding_type (str): Positional encoding type.
        positionwise_layer_type (str): Positionwise layer type.
        positionwise_activation_type (str): Positionwise activation type.
        conv_mod_activation_type (str): Convolutional module activation type.
        normalize_before (bool): Whether to apply layer_norm before the first block.
        aux_task_layer_list (list): Layer IDs for intermediate (auxiliary task) outputs.
        padding_idx (int): padding_idx for the embedding input layer (if specified).

    """

    def __init__(
        self,
        idim,
        enc_arch,
        input_layer="linear",
        repeat_block=0,
        self_attn_type="selfattn",
        positional_encoding_type="abs_pos",
        positionwise_layer_type="linear",
        positionwise_activation_type="relu",
        conv_mod_activation_type="relu",
        normalize_before=True,
        aux_task_layer_list=[],
        padding_idx=-1,
    ):
        """Construct a CustomEncoder object."""
        super().__init__()
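
        # build_blocks assembles the input layer (embed), the stack of
        # encoder blocks, the encoder output dimension, and the total
        # subsampling factor of the convolutional front-end.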
        (
            self.embed,
            self.encoders,
            self.enc_out,
            self.conv_subsampling_factor,
        ) = build_blocks(
            "encoder",
            idim,
            input_layer,
            enc_arch,
            repeat_block=repeat_block,
            self_attn_type=self_attn_type,
            positional_encoding_type=positional_encoding_type,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_activation_type=positionwise_activation_type,
            conv_mod_activation_type=conv_mod_activation_type,
            padding_idx=padding_idx,
        )

        self.normalize_before = normalize_before

        if self.normalize_before:
            self.after_norm = LayerNorm(self.enc_out)

        # repeat_block <= 1 means the block list is used once, so guard
        # against a zero block count when indexing the encoder stack.
        self.n_blocks = len(enc_arch) * max(repeat_block, 1)

        self.aux_task_layer_list = aux_task_layer_list

    def forward(self, xs, masks):
        """Encode input sequence.

        Args:
            xs (torch.Tensor): Input tensor.
            masks (torch.Tensor): Input mask.

        Returns:
            xs (torch.Tensor or tuple): Position-embedded output, or
                (position-embedded output, auxiliary outputs) if
                aux_task_layer_list is non-empty.
            masks (torch.Tensor): Position-embedded mask.

        """
        if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)
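
        # When auxiliary task layers are requested, the blocks are run one
        # by one so that intermediate outputs can be tapped (normalized
        # first, under pre-norm); otherwise the whole stack runs at once.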
        if self.aux_task_layer_list:
            aux_xs_list = []

            for b in range(self.n_blocks):
                xs, masks = self.encoders[b](xs, masks)

                if b in self.aux_task_layer_list:
                    if isinstance(xs, tuple):
                        aux_xs = xs[0]
                    else:
                        aux_xs = xs

                    if self.normalize_before:
                        aux_xs_list.append(self.after_norm(aux_xs))
                    else:
                        aux_xs_list.append(aux_xs)
        else:
            xs, masks = self.encoders(xs, masks)

        if isinstance(xs, tuple):
            xs = xs[0]

        if self.normalize_before:
            xs = self.after_norm(xs)
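
        # Auxiliary outputs, if collected, are returned alongside the main
        # encoder output as a tuple.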
        if self.aux_task_layer_list:
            return (xs, aux_xs_list), masks

        return xs, masks
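

# Usage sketch (illustrative only; the enc_arch block parameters below are
# hypothetical and depend on the block options supported by build_blocks):
#
#     enc_arch = [
#         {
#             "type": "transformer",
#             "d_hidden": 256,
#             "d_ff": 1024,
#             "heads": 4,
#         }
#     ]
#     encoder = CustomEncoder(80, enc_arch, input_layer="vgg2l", repeat_block=6)
#     out, out_masks = encoder(feats, feats_masks)  # feats: (B, T, 80)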