import argparse
import math
from collections.abc import Iterable

import torch
import torch.nn as nn
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq import utils
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqEncoderModel,
    FairseqIncrementalDecoder,
    register_model,
    register_model_architecture,
)
from fairseq.modules import (
    LinearizedConvolution,
    TransformerDecoderLayer,
    TransformerEncoderLayer,
    VGGBlock,
)


@register_model("asr_vggtransformer")
class VGGTransformerModel(FairseqEncoderDecoderModel):
    """
    Transformers with convolutional context for ASR
    https://arxiv.org/abs/1904.11660
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--input-feat-per-channel",
            type=int,
            metavar="N",
            help="encoder input dimension per input channel",
        )
        parser.add_argument(
            "--vggblock-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    an array of tuples, each containing the configuration of one vggblock:
    [(out_channels,
      conv_kernel_size,
      pooling_kernel_size,
      num_conv_layers,
      use_layer_norm), ...]
            """,
        )
        parser.add_argument(
            "--transformer-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    a tuple containing the configuration of the encoder transformer layers:
    [(input_dim,
      num_heads,
      ffn_dim,
      normalize_before,
      dropout,
      attention_dropout,
      relu_dropout), ...]
            """,
        )
        parser.add_argument(
            "--enc-output-dim",
            type=int,
            metavar="N",
            help="""
    encoder output dimension (may be None); if specified, the transformer
    output is projected to this dimension""",
        )
        parser.add_argument(
            "--in-channels",
            type=int,
            metavar="N",
            help="number of encoder input channels",
        )
        parser.add_argument(
            "--tgt-embed-dim",
            type=int,
            metavar="N",
            help="embedding dimension of the decoder target tokens",
        )
        parser.add_argument(
            "--transformer-dec-config",
            type=str,
            metavar="EXPR",
            help="""
    a tuple containing the configuration of the decoder transformer layers:
    [(input_dim,
      num_heads,
      ffn_dim,
      normalize_before,
      dropout,
      attention_dropout,
      relu_dropout), ...]
            """,
        )
        parser.add_argument(
            "--conv-dec-config",
            type=str,
            metavar="EXPR",
            help="""
    an array of tuples for the decoder 1-D convolution config
    [(out_channels, conv_kernel_size, use_layer_norm), ...]""",
        )

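    # Example (hypothetical command-line values; they mirror the vggtransformer_1
    # defaults registered at the bottom of this file):
    #   --vggblock-enc-config    "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
    #   --transformer-enc-config "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 14"
    #   --conv-dec-config        "((256, 3, True),) * 4"
    # The strings are parsed with eval() in build_encoder()/build_decoder() below.
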
    @classmethod
    def build_encoder(cls, args, task):
        return VGGTransformerEncoder(
            input_feat_per_channel=args.input_feat_per_channel,
            vggblock_config=eval(args.vggblock_enc_config),
            transformer_config=eval(args.transformer_enc_config),
            encoder_output_dim=args.enc_output_dim,
            in_channels=args.in_channels,
        )

    @classmethod
    def build_decoder(cls, args, task):
        return TransformerDecoder(
            dictionary=task.target_dictionary,
            embed_dim=args.tgt_embed_dim,
            transformer_config=eval(args.transformer_dec_config),
            conv_config=eval(args.conv_dec_config),
            encoder_output_dim=args.enc_output_dim,
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        base_architecture(args)

        encoder = cls.build_encoder(args, task)
        decoder = cls.build_decoder(args, task)
        return cls(encoder, decoder)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        lprobs = super().get_normalized_probs(net_output, log_probs, sample)
        lprobs.batch_first = True
        return lprobs


DEFAULT_ENC_VGGBLOCK_CONFIG = ((32, 3, 2, 2, False),) * 2
DEFAULT_ENC_TRANSFORMER_CONFIG = ((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2
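# Field order in the config tuples above and below (see the prepare_transformer_*
# helpers and the VGGBlock constructor for where each field is consumed):
#   vggblock config:     (out_channels, conv_kernel_size, pooling_kernel_size,
#                         num_conv_layers, use_layer_norm)
#   transformer config:  (input_dim, num_heads, ffn_dim, normalize_before,
#                         dropout, attention_dropout, relu_dropout)
#   decoder conv config: (out_channels, conv_kernel_size, use_layer_norm)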
DEFAULT_DEC_TRANSFORMER_CONFIG = ((256, 2, 1024, True, 0.2, 0.2, 0.2),) * 2
DEFAULT_DEC_CONV_CONFIG = ((256, 3, True),) * 2


def prepare_transformer_encoder_params(
    input_dim,
    num_heads,
    ffn_dim,
    normalize_before,
    dropout,
    attention_dropout,
    relu_dropout,
):
    args = argparse.Namespace()
    args.encoder_embed_dim = input_dim
    args.encoder_attention_heads = num_heads
    args.attention_dropout = attention_dropout
    args.dropout = dropout
    args.activation_dropout = relu_dropout
    args.encoder_normalize_before = normalize_before
    args.encoder_ffn_embed_dim = ffn_dim
    return args


def prepare_transformer_decoder_params(
    input_dim,
    num_heads,
    ffn_dim,
    normalize_before,
    dropout,
    attention_dropout,
    relu_dropout,
):
    args = argparse.Namespace()
    args.encoder_embed_dim = None
    args.decoder_embed_dim = input_dim
    args.decoder_attention_heads = num_heads
    args.attention_dropout = attention_dropout
    args.dropout = dropout
    args.activation_dropout = relu_dropout
    args.decoder_normalize_before = normalize_before
    args.decoder_ffn_embed_dim = ffn_dim
    return args


class VGGTransformerEncoder(FairseqEncoder):
    """VGG + Transformer encoder"""

    def __init__(
        self,
        input_feat_per_channel,
        vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
        transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
        encoder_output_dim=512,
        in_channels=1,
        transformer_context=None,
        transformer_sampling=None,
    ):
"""constructor for VGGTransformerEncoder |
|
|
|
Args: |
|
- input_feat_per_channel: feature dim (not including stacked, |
|
just base feature) |
|
- in_channel: # input channels (e.g., if stack 8 feature vector |
|
together, this is 8) |
|
- vggblock_config: configuration of vggblock, see comments on |
|
DEFAULT_ENC_VGGBLOCK_CONFIG |
|
- transformer_config: configuration of transformer layer, see comments |
|
on DEFAULT_ENC_TRANSFORMER_CONFIG |
|
- encoder_output_dim: final transformer output embedding dimension |
|
- transformer_context: (left, right) if set, self-attention will be focused |
|
on (t-left, t+right) |
|
- transformer_sampling: an iterable of int, must match with |
|
len(transformer_config), transformer_sampling[i] indicates sampling |
|
factor for i-th transformer layer, after multihead att and feedfoward |
|
part |
|
""" |
|
        super().__init__(None)

        self.num_vggblocks = 0
        if vggblock_config is not None:
            if not isinstance(vggblock_config, Iterable):
                raise ValueError("vggblock_config is not iterable")
            self.num_vggblocks = len(vggblock_config)

        self.conv_layers = nn.ModuleList()
        self.in_channels = in_channels
        self.input_dim = input_feat_per_channel
        self.pooling_kernel_sizes = []

        if vggblock_config is not None:
            for _, config in enumerate(vggblock_config):
                (
                    out_channels,
                    conv_kernel_size,
                    pooling_kernel_size,
                    num_conv_layers,
                    layer_norm,
                ) = config
                self.conv_layers.append(
                    VGGBlock(
                        in_channels,
                        out_channels,
                        conv_kernel_size,
                        pooling_kernel_size,
                        num_conv_layers,
                        input_dim=input_feat_per_channel,
                        layer_norm=layer_norm,
                    )
                )
                self.pooling_kernel_sizes.append(pooling_kernel_size)
                in_channels = out_channels
                input_feat_per_channel = self.conv_layers[-1].output_dim

        transformer_input_dim = self.infer_conv_output_dim(
            self.in_channels, self.input_dim
        )

        self.validate_transformer_config(transformer_config)
        self.transformer_context = self.parse_transformer_context(transformer_context)
        self.transformer_sampling = self.parse_transformer_sampling(
            transformer_sampling, len(transformer_config)
        )

        self.transformer_layers = nn.ModuleList()

        if transformer_input_dim != transformer_config[0][0]:
            self.transformer_layers.append(
                Linear(transformer_input_dim, transformer_config[0][0])
            )
        self.transformer_layers.append(
            TransformerEncoderLayer(
                prepare_transformer_encoder_params(*transformer_config[0])
            )
        )

        for i in range(1, len(transformer_config)):
            if transformer_config[i - 1][0] != transformer_config[i][0]:
                self.transformer_layers.append(
                    Linear(transformer_config[i - 1][0], transformer_config[i][0])
                )
            self.transformer_layers.append(
                TransformerEncoderLayer(
                    prepare_transformer_encoder_params(*transformer_config[i])
                )
            )

        self.encoder_output_dim = encoder_output_dim
        self.transformer_layers.extend(
            [
                Linear(transformer_config[-1][0], encoder_output_dim),
                LayerNorm(encoder_output_dim),
            ]
        )

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        src_tokens: padded tensor (B, T, C * feat)
        src_lengths: tensor of original lengths of input utterances (B,)
        """
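        # (B, T, C * feat) -> (B, C, T, feat): the VGG blocks expect a channel
        # dimension; afterwards the output is flattened back to (T, B, -1) for
        # the time-first transformer layers.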
        bsz, max_seq_len, _ = src_tokens.size()
        x = src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim)
        x = x.transpose(1, 2).contiguous()

        for layer_idx in range(len(self.conv_layers)):
            x = self.conv_layers[layer_idx](x)

        bsz, _, output_seq_len, _ = x.size()

        x = x.transpose(1, 2).transpose(0, 1)
        x = x.contiguous().view(output_seq_len, bsz, -1)

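        # Each VGG block's pooling layer shortens the time axis; track the
        # resulting lengths with ceil division so the padding mask matches the
        # subsampled output.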
        input_lengths = src_lengths.clone()
        for s in self.pooling_kernel_sizes:
            input_lengths = (input_lengths.float() / s).ceil().long()

        encoder_padding_mask, _ = lengths_to_encoder_padding_mask(
            input_lengths, batch_first=True
        )
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
        attn_mask = self.lengths_to_attn_mask(input_lengths, subsampling_factor)

        transformer_layer_idx = 0

        for layer_idx in range(len(self.transformer_layers)):

            if isinstance(self.transformer_layers[layer_idx], TransformerEncoderLayer):
                x = self.transformer_layers[layer_idx](
                    x, encoder_padding_mask, attn_mask
                )

                if self.transformer_sampling[transformer_layer_idx] != 1:
                    sampling_factor = self.transformer_sampling[transformer_layer_idx]
                    x, encoder_padding_mask, attn_mask = self.slice(
                        x, encoder_padding_mask, attn_mask, sampling_factor
                    )

                transformer_layer_idx += 1

            else:
                x = self.transformer_layers[layer_idx](x)

        return {
            "encoder_out": x,
            "encoder_padding_mask": encoder_padding_mask.t()
            if encoder_padding_mask is not None
            else None,
        }

    def infer_conv_output_dim(self, in_channels, input_dim):
        sample_seq_len = 200
        sample_bsz = 10
        x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim)
        for i, _ in enumerate(self.conv_layers):
            x = self.conv_layers[i](x)
        x = x.transpose(1, 2)
        mb, seq = x.size()[:2]
        return x.contiguous().view(mb, seq, -1).size(-1)

    def validate_transformer_config(self, transformer_config):
        for config in transformer_config:
            input_dim, num_heads = config[:2]
            if input_dim % num_heads != 0:
                msg = (
                    "ERROR in transformer config {}: ".format(config)
                    + "input dimension {} ".format(input_dim)
                    + "not divisible by number of heads {}".format(num_heads)
                )
                raise ValueError(msg)

    def parse_transformer_context(self, transformer_context):
        """
        transformer_context can be the following:
        - None; indicates no context is used, i.e., the transformer can access
          the full context
        - a tuple/list of two ints; indicates left and right context,
          any number < 0 indicates infinite context
            * e.g., (5, 6) indicates that for a query at x_t, the transformer
              can access [t-5, t+6] (inclusive)
            * e.g., (-1, 6) indicates that for a query at x_t, the transformer
              can access [0, t+6] (inclusive)
        """
        if transformer_context is None:
            return None

        if not isinstance(transformer_context, Iterable):
            raise ValueError("transformer context must be Iterable if it is not None")

        if len(transformer_context) != 2:
            raise ValueError("transformer context must have length 2")

        left_context = transformer_context[0]
        if left_context < 0:
            left_context = None

        right_context = transformer_context[1]
        if right_context < 0:
            right_context = None

        if left_context is None and right_context is None:
            return None

        return (left_context, right_context)

    def parse_transformer_sampling(self, transformer_sampling, num_layers):
        """
        parsing transformer sampling configuration

        Args:
            - transformer_sampling, accepted input:
                * None, indicating no sampling
                * an Iterable with int (> 0) as elements
            - num_layers, expected number of transformer layers; must match the
              length of transformer_sampling if it is not None

        Returns:
            - A tuple with length num_layers
        """
        if transformer_sampling is None:
            return (1,) * num_layers

        if not isinstance(transformer_sampling, Iterable):
            raise ValueError(
                "transformer_sampling must be an iterable if it is not None"
            )

        if len(transformer_sampling) != num_layers:
            raise ValueError(
                "transformer_sampling {} does not match with the number "
                "of layers {}".format(transformer_sampling, num_layers)
            )

        for layer, value in enumerate(transformer_sampling):
            if not isinstance(value, int):
                raise ValueError(
                    "Invalid value in transformer_sampling: {}".format(value)
                )
            if value < 1:
                raise ValueError(
                    "Layer {} has sampling factor {}; sampling factors "
                    "must be >= 1.".format(layer, value)
                )
        return transformer_sampling

    def slice(self, embedding, padding_mask, attn_mask, sampling_factor):
        """
        embedding is a (T, B, D) tensor
        padding_mask is a (B, T) tensor or None
        attn_mask is a (T, T) tensor or None
        """
        embedding = embedding[::sampling_factor, :, :]
        if padding_mask is not None:
            padding_mask = padding_mask[:, ::sampling_factor]
        if attn_mask is not None:
            attn_mask = attn_mask[::sampling_factor, ::sampling_factor]

        return embedding, padding_mask, attn_mask

    def lengths_to_attn_mask(self, input_lengths, subsampling_factor=1):
        """
        create attention mask according to sequence lengths and transformer
        context

        Args:
            - input_lengths: (B,)-shaped Int/Long tensor; input_lengths[b] is
              the length of the b-th sequence
            - subsampling_factor: int
                * Note that left_context and right_context are specified at the
                  input frame level, while the input to the transformer may
                  already have gone through subsampling (e.g., pooling/striding
                  in the vggblocks); subsampling_factor is used to scale the
                  left/right context accordingly

        Return:
            - a (T, T) binary tensor or None, where T is max(input_lengths)
                * if self.transformer_context is None, returns None
                * if left_context is None,
                    * attn_mask[t, t + right_context + 1:] = 1
                    * others = 0
                * if right_context is None,
                    * attn_mask[t, 0:t - left_context] = 1
                    * others = 0
                * else
                    * attn_mask[t, t - left_context: t + right_context + 1] = 0
                    * others = 1
        """
        if self.transformer_context is None:
            return None

        maxT = torch.max(input_lengths).item()
        attn_mask = torch.zeros(maxT, maxT)

        left_context = self.transformer_context[0]
        right_context = self.transformer_context[1]
        if left_context is not None:
            left_context = math.ceil(self.transformer_context[0] / subsampling_factor)
        if right_context is not None:
            right_context = math.ceil(self.transformer_context[1] / subsampling_factor)

        for t in range(maxT):
            if left_context is not None:
                st = 0
                en = max(st, t - left_context)
                attn_mask[t, st:en] = 1
            if right_context is not None:
                st = t + right_context + 1
                st = min(st, maxT - 1)
                attn_mask[t, st:] = 1

        return attn_mask.to(input_lengths.device)

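    # encoder_out is (T, B, C) and the padding mask returned by forward() was
    # transposed to (T, B), so both are reordered along dim 1 (the batch dim).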
    def reorder_encoder_out(self, encoder_out, new_order):
        encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
            1, new_order
        )
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(1, new_order)
        return encoder_out


class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *len(transformer_config)* layers. Each
    layer is a :class:`TransformerDecoderLayer`, optionally preceded by 1-D
    convolutional layers.

    Args:
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_dim (int): embedding dimension of the decoder target tokens
        transformer_config: per-layer transformer configuration; see the
            comments on DEFAULT_DEC_TRANSFORMER_CONFIG
        conv_config: decoder 1-D convolution configuration; see the comments
            on DEFAULT_DEC_CONV_CONFIG
        encoder_output_dim (int): dimension of the encoder output
    """

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
        conv_config=DEFAULT_DEC_CONV_CONFIG,
        encoder_output_dim=512,
    ):
        super().__init__(dictionary)
        vocab_size = len(dictionary)
        self.padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(vocab_size, embed_dim, self.padding_idx)

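        # With padding=kernel_size - 1, LinearizedConvolution behaves as a causal
        # convolution over the target sequence (no future target positions leak
        # in), which is what makes step-by-step incremental decoding possible.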
        self.conv_layers = nn.ModuleList()
        for i in range(len(conv_config)):
            out_channels, kernel_size, layer_norm = conv_config[i]
            if i == 0:
                conv_layer = LinearizedConv1d(
                    embed_dim, out_channels, kernel_size, padding=kernel_size - 1
                )
            else:
                conv_layer = LinearizedConv1d(
                    conv_config[i - 1][0],
                    out_channels,
                    kernel_size,
                    padding=kernel_size - 1,
                )
            self.conv_layers.append(conv_layer)
            if layer_norm:
                self.conv_layers.append(nn.LayerNorm(out_channels))
            self.conv_layers.append(nn.ReLU())

        self.layers = nn.ModuleList()
        if conv_config[-1][0] != transformer_config[0][0]:
            self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0]))
        self.layers.append(
            TransformerDecoderLayer(
                prepare_transformer_decoder_params(*transformer_config[0])
            )
        )

        for i in range(1, len(transformer_config)):
            if transformer_config[i - 1][0] != transformer_config[i][0]:
                self.layers.append(
                    Linear(transformer_config[i - 1][0], transformer_config[i][0])
                )
            self.layers.append(
                TransformerDecoderLayer(
                    prepare_transformer_decoder_params(*transformer_config[i])
                )
            )
        self.fc_out = Linear(transformer_config[-1][0], vocab_size)

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_out (dict, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
        Returns:
            tuple:
                - the last decoder layer's output of shape `(batch, tgt_len,
                  vocab)`
                - `None` (attention weights are not returned)
        """
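        # During incremental (step-by-step) decoding only the newest target token
        # is embedded; the convolution and attention layers keep their own caches
        # in incremental_state, so no padding or future mask is needed.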
        target_padding_mask = (
            (prev_output_tokens == self.padding_idx).to(prev_output_tokens.device)
            if incremental_state is None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]

        x = self.embed_tokens(prev_output_tokens)

        x = self._transpose_if_training(x, incremental_state)

        for layer in self.conv_layers:
            if isinstance(layer, LinearizedConvolution):
                x = layer(x, incremental_state)
            else:
                x = layer(x)

        x = self._transpose_if_inference(x, incremental_state)

        for layer in self.layers:
            if isinstance(layer, TransformerDecoderLayer):
                x, *_ = layer(
                    x,
                    (encoder_out["encoder_out"] if encoder_out is not None else None),
                    (
                        encoder_out["encoder_padding_mask"].t()
                        if encoder_out is not None
                        and encoder_out["encoder_padding_mask"] is not None
                        else None
                    ),
                    incremental_state,
                    self_attn_mask=(
                        self.buffered_future_mask(x)
                        if incremental_state is None
                        else None
                    ),
                    self_attn_padding_mask=(
                        target_padding_mask if incremental_state is None else None
                    ),
                )
            else:
                x = layer(x)

        x = x.transpose(0, 1)

        x = self.fc_out(x)

        return x, None

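    # buffered_future_mask caches an upper-triangular matrix of -inf that blocks
    # self-attention to future target positions; it is rebuilt when the cached
    # mask is too small or lives on the wrong device.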
    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def _transpose_if_training(self, x, incremental_state):
        if incremental_state is None:
            x = x.transpose(0, 1)
        return x

    def _transpose_if_inference(self, x, incremental_state):
        if incremental_state:
            x = x.transpose(0, 1)
        return x


@register_model("asr_vggtransformer_encoder") |
|
class VGGTransformerEncoderModel(FairseqEncoderModel): |
|
def __init__(self, encoder): |
|
super().__init__(encoder) |
|
|
|
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--input-feat-per-channel",
            type=int,
            metavar="N",
            help="encoder input dimension per input channel",
        )
        parser.add_argument(
            "--vggblock-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    an array of tuples, each containing the configuration of one vggblock:
    [(out_channels, conv_kernel_size, pooling_kernel_size, num_conv_layers,
      use_layer_norm), ...]
            """,
        )
        parser.add_argument(
            "--transformer-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    a tuple containing the configuration of the Transformer layers:
    [(input_dim,
      num_heads,
      ffn_dim,
      normalize_before,
      dropout,
      attention_dropout,
      relu_dropout), ...]""",
        )
        parser.add_argument(
            "--enc-output-dim",
            type=int,
            metavar="N",
            help="encoder output dimension, projecting the transformer output",
        )
        parser.add_argument(
            "--in-channels",
            type=int,
            metavar="N",
            help="number of encoder input channels",
        )
        parser.add_argument(
            "--transformer-context",
            type=str,
            metavar="EXPR",
            help="""
    either None or a tuple of two ints, indicating left/right context a
    transformer can have access to""",
        )
        parser.add_argument(
            "--transformer-sampling",
            type=str,
            metavar="EXPR",
            help="""
    either None or a tuple of ints, indicating sampling factor in each layer""",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        base_architecture_enconly(args)
        encoder = VGGTransformerEncoderOnly(
            vocab_size=len(task.target_dictionary),
            input_feat_per_channel=args.input_feat_per_channel,
            vggblock_config=eval(args.vggblock_enc_config),
            transformer_config=eval(args.transformer_enc_config),
            encoder_output_dim=args.enc_output_dim,
            in_channels=args.in_channels,
            transformer_context=eval(args.transformer_context),
            transformer_sampling=eval(args.transformer_sampling),
        )
        return cls(encoder)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        lprobs = super().get_normalized_probs(net_output, log_probs, sample)
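        # The encoder output is time-first (T, B, C); transpose so downstream
        # consumers get batch-first probabilities, and flag that on the tensor.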
        lprobs = lprobs.transpose(0, 1).contiguous()
        lprobs.batch_first = True
        return lprobs


class VGGTransformerEncoderOnly(VGGTransformerEncoder):
    def __init__(
        self,
        vocab_size,
        input_feat_per_channel,
        vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
        transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
        encoder_output_dim=512,
        in_channels=1,
        transformer_context=None,
        transformer_sampling=None,
    ):
        super().__init__(
            input_feat_per_channel=input_feat_per_channel,
            vggblock_config=vggblock_config,
            transformer_config=transformer_config,
            encoder_output_dim=encoder_output_dim,
            in_channels=in_channels,
            transformer_context=transformer_context,
            transformer_sampling=transformer_sampling,
        )
        self.fc_out = Linear(self.encoder_output_dim, vocab_size)

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        src_tokens: padded tensor (B, T, C * feat)
        src_lengths: tensor of original lengths of input utterances (B,)
        """
        enc_out = super().forward(src_tokens, src_lengths)
        x = self.fc_out(enc_out["encoder_out"])

        return {
            "encoder_out": x,
            "encoder_padding_mask": enc_out["encoder_padding_mask"],
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return (1e6, 1e6)


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    return m


def Linear(in_features, out_features, bias=True, dropout=0):
    """Linear layer (input: N x T x C)"""
    m = nn.Linear(in_features, out_features, bias=bias)
    return m


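# Initialization note: the weight std below scales with fan-in (kernel_size *
# in_channels) and the dropout retention probability; weight norm is applied
# over dim=2, the output-channel dimension of the convolution weight.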
def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
    """Weight-normalized Conv1d layer optimized for decoding"""
    m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
    std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
    nn.init.normal_(m.weight, mean=0, std=std)
    nn.init.constant_(m.bias, 0)
    return nn.utils.weight_norm(m, dim=2)


def LayerNorm(embedding_dim):
    m = nn.LayerNorm(embedding_dim)
    return m


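# The architecture functions below only fill in defaults for options the user
# did not set on the command line (getattr with a fallback), so registered
# architectures can be partially overridden.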
def base_architecture(args):
    args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40)
    args.vggblock_enc_config = getattr(
        args, "vggblock_enc_config", DEFAULT_ENC_VGGBLOCK_CONFIG
    )
    args.transformer_enc_config = getattr(
        args, "transformer_enc_config", DEFAULT_ENC_TRANSFORMER_CONFIG
    )
    args.enc_output_dim = getattr(args, "enc_output_dim", 512)
    args.in_channels = getattr(args, "in_channels", 1)
    args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128)
    args.transformer_dec_config = getattr(
        args, "transformer_dec_config", DEFAULT_ENC_TRANSFORMER_CONFIG
    )
    args.conv_dec_config = getattr(args, "conv_dec_config", DEFAULT_DEC_CONV_CONFIG)
    args.transformer_context = getattr(args, "transformer_context", "None")


@register_model_architecture("asr_vggtransformer", "vggtransformer_1") |
|
def vggtransformer_1(args): |
|
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) |
|
args.vggblock_enc_config = getattr( |
|
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]" |
|
) |
|
args.transformer_enc_config = getattr( |
|
args, |
|
"transformer_enc_config", |
|
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 14", |
|
) |
|
args.enc_output_dim = getattr(args, "enc_output_dim", 1024) |
|
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128) |
|
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4") |
|
args.transformer_dec_config = getattr( |
|
args, |
|
"transformer_dec_config", |
|
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 4", |
|
) |
|
|
|
|
|
@register_model_architecture("asr_vggtransformer", "vggtransformer_2") |
|
def vggtransformer_2(args): |
|
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) |
|
args.vggblock_enc_config = getattr( |
|
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]" |
|
) |
|
args.transformer_enc_config = getattr( |
|
args, |
|
"transformer_enc_config", |
|
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16", |
|
) |
|
args.enc_output_dim = getattr(args, "enc_output_dim", 1024) |
|
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512) |
|
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4") |
|
args.transformer_dec_config = getattr( |
|
args, |
|
"transformer_dec_config", |
|
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 6", |
|
) |
|
|
|
|
|
@register_model_architecture("asr_vggtransformer", "vggtransformer_base") |
|
def vggtransformer_base(args): |
|
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) |
|
args.vggblock_enc_config = getattr( |
|
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]" |
|
) |
|
args.transformer_enc_config = getattr( |
|
args, "transformer_enc_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 12" |
|
) |
|
|
|
args.enc_output_dim = getattr(args, "enc_output_dim", 512) |
|
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512) |
|
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4") |
|
args.transformer_dec_config = getattr( |
|
args, "transformer_dec_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 6" |
|
) |
|
|
|
|
|
def base_architecture_enconly(args):
    args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40)
    args.vggblock_enc_config = getattr(
        args, "vggblock_enc_config", "[(32, 3, 2, 2, True)] * 2"
    )
    args.transformer_enc_config = getattr(
        args, "transformer_enc_config", "((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2"
    )
    args.enc_output_dim = getattr(args, "enc_output_dim", 512)
    args.in_channels = getattr(args, "in_channels", 1)
    args.transformer_context = getattr(args, "transformer_context", "None")
    args.transformer_sampling = getattr(args, "transformer_sampling", "None")


@register_model_architecture("asr_vggtransformer_encoder", "vggtransformer_enc_1") |
|
def vggtransformer_enc_1(args): |
|
|
|
|
|
|
|
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) |
|
args.vggblock_enc_config = getattr( |
|
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]" |
|
) |
|
args.transformer_enc_config = getattr( |
|
args, |
|
"transformer_enc_config", |
|
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16", |
|
) |
|
args.enc_output_dim = getattr(args, "enc_output_dim", 1024) |
|
|