"""Encoder self-attention layer definition.""" | |
import math | |
import pdb | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from vita.model.multimodal_encoder.whale.module.layer.attention import ( | |
Conv1dLinear, | |
MultiHeadedAttention, | |
MultiLayeredConv1d, | |
PositionalEncoding, | |
PositionwiseFeedForward, | |
RelPositionalEncoding, | |
) | |
# from vita.model.multimodal_encoder.whale.module.component.utils import * | |
from vita.model.multimodal_encoder.whale.utils import IGNORE_ID, add_optional_chunk_mask, strtobool | |


def repeat(N, fn):
    """Repeat a module N times.

    :param int N: number of repetitions
    :param function fn: function that builds the n-th module
    :return: repeated modules
    :rtype: MultiSequential
    """
    return MultiSequential(*[fn(n) for n in range(N)])


class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential."""

    def forward(self, x, masks, pos_emb):
        """Pass (x, masks, pos_emb) through every module in order."""
        for m in self:
            x, masks, pos_emb = m(x, masks, pos_emb)
        return x, masks, pos_emb

    def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
        # type: (Tensor, Tensor, Tensor, int, List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, int, List[Tensor]]
        """Run every module's streaming infer(), threading the cache through each layer."""
        for m in self:
            x, pos_emb, buffer, buffer_index, buffer_out = m.infer(
                x, pos_emb, buffer, buffer_index, buffer_out
            )
        return x, pos_emb, buffer, buffer_index, buffer_out

    def infer_hidden(self, x, pos_emb, buffer, buffer_index, buffer_out, hidden_out):
        # type: (Tensor, Tensor, Tensor, int, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, int, List[Tensor], List[Tensor]]
        """Same as infer(), but also collect each layer's output in hidden_out."""
        for m in self:
            x, pos_emb, buffer, buffer_index, buffer_out = m.infer(
                x, pos_emb, buffer, buffer_index, buffer_out
            )
            hidden_out.append(x)
        return x, pos_emb, buffer, buffer_index, buffer_out, hidden_out
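
# Illustrative sketch (comments only): `repeat` is how Transformer.__init__ below
# builds its encoder stack, and the resulting MultiSequential threads the full
# (x, masks, pos_emb) tuple through every layer in forward(), or the streaming
# cache tuple through each layer's .infer() at inference time.
#
#   blocks = repeat(6, lambda lnum: TransformerLayer(...))   # see Transformer.__init__
#   x, masks, pos_emb = blocks(x, masks, pos_emb)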


class TransformerLayer(nn.Module):
    """Transformer layer module.

    :param int size: input dim
    :param self_attn: self-attention module
    :param feed_forward: feed-forward module
    :param float dropout_rate: dropout rate
    :param bool normalize_before: whether to apply layer norm before the first block
    :param bool concat_after: whether to concatenate the attention layer's input and output.
        If True, an additional linear layer is applied, i.e. x -> x + linear(concat(x, att(x))).
        If False, no additional linear layer is applied, i.e. x -> x + att(x).
    """

    def __init__(
        self, size, self_attn, feed_forward, dropout_rate, normalize_before=True, concat_after=False
    ):
        """Construct a TransformerLayer object."""
        super(TransformerLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = torch.nn.LayerNorm(size)
        self.norm2 = torch.nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        else:
            self.concat_linear = nn.Identity()

    def forward(self, x, mask, pos_emb):
        """Compute encoded features.

        :param torch.Tensor x: encoded source features (batch, max_time_in, size)
        :param torch.Tensor mask: mask for x (batch, max_time_in)
        :param torch.Tensor pos_emb: positional embedding
        :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x, x, x, mask, pos_emb)), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(self.self_attn(x, x, x, mask, pos_emb))
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)
        return x, mask, pos_emb

    def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
        # type: (Tensor, Tensor, Tensor, int, List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, int, List[Tensor]]
        """Streaming variant of forward() that threads the attention/feed-forward cache through the layer."""
        residual = x.clone()
        if self.normalize_before:
            x = self.norm1(x)
        if self.concat_after:
            x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(
                x, x, x, pos_emb, buffer, buffer_index, buffer_out
            )
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(
                x, x, x, pos_emb, buffer, buffer_index, buffer_out
            )
            x = residual + x_att
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x.clone()
        if self.normalize_before:
            x = self.norm2(x)
        x_feed, buffer, buffer_index, buffer_out = self.feed_forward.infer(
            x, buffer, buffer_index, buffer_out
        )
        x = residual + x_feed
        if not self.normalize_before:
            x = self.norm2(x)
        return x, pos_emb, buffer, buffer_index, buffer_out
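
# Hypothetical standalone usage of TransformerLayer (comments only), built from
# the attention and feed-forward modules imported above and mirroring the
# argument order used in Transformer.__init__ below; the concrete sizes are
# assumptions, not values required by this file.
#
#   layer = TransformerLayer(
#       size=256,
#       self_attn=MultiHeadedAttention(4, 256, 0.0, -1, -1, "abs-enc"),
#       feed_forward=PositionwiseFeedForward(256, 1024, 0.1),
#       dropout_rate=0.1,
#   )
#   x, mask, pos_emb = layer(x, mask, pos_emb)   # pos_emb from PositionalEncoding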


class Transformer(torch.nn.Module):
    @staticmethod
    def add_arguments(group):
        """Add Transformer common arguments."""
        group.add_argument(
            "--transformer-input-dim", default=256, type=int, help="Input dim of Transformer."
        )
        group.add_argument(
            "--transformer-output-dim", default=4, type=int, help="Output dim of Transformer."
        )
        group.add_argument(
            "--transformer-attention-dim", default=256, type=int, help="Dimension of attention."
        )
        group.add_argument(
            "--transformer-attention-heads",
            default=4,
            type=int,
            help="The number of heads of multi-head attention.",
        )
        group.add_argument(
            "--transformer-linear-units",
            default=1024,
            type=int,
            help="The number of units of position-wise feed forward.",
        )
        group.add_argument(
            "--transformer-num-blocks", default=6, type=int, help="The number of attention blocks."
        )
        group.add_argument(
            "--transformer-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate in Transformer.",
        )
        group.add_argument(
            "--transformer-attention-dropout-rate",
            default=0.0,
            type=float,
            help="Dropout rate in attention.",
        )
        group.add_argument(
            "--transformer-positional-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate after adding positional encoding.",
        )
        group.add_argument(
            "--transformer-input-layer",
            default="linear",
            type=str,
            help="Type of input layer: linear, embed, or none.",
        )
        group.add_argument(
            "--transformer-pos-enc-class",
            default="abs-enc",
            type=str,
            help="Positional encoding type: abs-enc or rel-enc.",
        )
        group.add_argument(
            "--transformer-normalize-before",
            default=True,
            type=strtobool,
            help="Whether to use layer norm before the first block.",
        )
        group.add_argument(
            "--transformer-concat-after",
            default=False,
            type=strtobool,
            help="Whether to concat attention layer's input and output.",
        )
        group.add_argument(
            "--transformer-positionwise-layer-type",
            default="linear",
            type=str,
            help="Position-wise layer type: linear, conv1d, or conv1d-linear.",
        )
        group.add_argument(
            "--transformer-positionwise-conv-kernel_size",
            default=1,
            type=int,
            help="Kernel size of position-wise conv1d layer.",
        )
        group.add_argument(
            "--transformer-chunk_size",
            default=-1,
            type=int,
            help="Chunk size for chunk-based attention masking.",
        )
        group.add_argument(
            "--transformer-left_chunks",
            default=-1,
            type=int,
            help="Number of left chunks for chunk-based attention masking.",
        )
        group.add_argument(
            "--transformer-dynamic-chunks",
            default=True,
            type=strtobool,
            help="Whether to use dynamic chunk masking.",
        )
        return group
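
    # Hypothetical usage sketch (comments only; argparse is not imported in this
    # module's header). The flags registered above are meant to be attached to an
    # argparse group and the resulting namespace handed straight to __init__:
    #
    #   parser = argparse.ArgumentParser()
    #   Transformer.add_arguments(parser.add_argument_group("transformer"))
    #   args = parser.parse_args([])   # all defaults
    #   encoder = Transformer(args)
    #
    # A runnable version of this appears in the __main__ smoke test at the bottom
    # of the file.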

    def __init__(
        self,
        args,
        input_dim=None,
        output_dim=None,
        attention_dim=None,
        attention_heads=None,
        linear_units=None,
        num_blocks=None,
        dropout_rate=None,
        positional_dropout_rate=None,
        attention_dropout_rate=None,
        input_layer=None,
        pos_enc_class=None,
        normalize_before=None,
        concat_after=None,
        positionwise_layer_type=None,
        positionwise_conv_kernel_size=None,
        chunk_size=None,
        left_chunks=None,
    ):
        """Construct a Transformer encoder object."""
        super(Transformer, self).__init__()
        if args is None:
            self.input_dim = input_dim
            self.output_dim = output_dim
            self.attention_dim = attention_dim
            self.attention_heads = attention_heads
            self.linear_units = linear_units
            self.num_blocks = num_blocks
            self.dropout_rate = dropout_rate
            self.positional_dropout_rate = positional_dropout_rate
            self.attention_dropout_rate = attention_dropout_rate
            self.input_layer = input_layer
            self.pos_enc_class = pos_enc_class
            self.normalize_before = normalize_before
            self.concat_after = concat_after
            self.positionwise_layer_type = positionwise_layer_type
            self.positionwise_conv_kernel_size = positionwise_conv_kernel_size
            self.chunk_size = chunk_size
            self.left_chunks = left_chunks
            # Mirror the --transformer-dynamic-chunks default so forward() also
            # works when no args namespace is given; previously this attribute
            # was only set in the args branch below.
            self.transformer_dynamic_chunks = True
        else:
            self.input_dim = args.transformer_input_dim
            self.output_dim = args.transformer_output_dim
            self.attention_dim = args.transformer_attention_dim
            self.attention_heads = args.transformer_attention_heads
            self.linear_units = args.transformer_linear_units
            self.num_blocks = args.transformer_num_blocks
            self.dropout_rate = args.transformer_dropout_rate
            self.positional_dropout_rate = args.transformer_positional_dropout_rate
            self.attention_dropout_rate = args.transformer_attention_dropout_rate
            self.input_layer = args.transformer_input_layer
            self.pos_enc_class = args.transformer_pos_enc_class
            self.normalize_before = args.transformer_normalize_before
            self.concat_after = args.transformer_concat_after
            self.positionwise_layer_type = args.transformer_positionwise_layer_type
            self.positionwise_conv_kernel_size = args.transformer_positionwise_conv_kernel_size
            self.chunk_size = args.transformer_chunk_size
            self.left_chunks = args.transformer_left_chunks
            self.transformer_dynamic_chunks = args.transformer_dynamic_chunks

        if self.pos_enc_class == "abs-enc":
            pos_enc_args = (self.attention_dim, self.positional_dropout_rate)
            pos_enc_class = PositionalEncoding
        elif self.pos_enc_class == "rel-enc":
            pos_enc_args = (
                self.attention_dim,
                self.positional_dropout_rate,
                self.chunk_size,
                self.left_chunks,
            )
            pos_enc_class = RelPositionalEncoding
        else:
            raise ValueError("unknown pos_enc_class: " + str(self.pos_enc_class))

        if self.input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(self.input_dim, self.attention_dim),
                torch.nn.LayerNorm(self.attention_dim),
                torch.nn.Dropout(self.dropout_rate),
                torch.nn.ReLU(),
            )
        elif self.input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(self.input_dim, self.attention_dim, padding_idx=IGNORE_ID)
            )
        elif self.input_layer == "none":
            self.embed = torch.nn.Sequential(torch.nn.Identity())
        else:
            raise ValueError("unknown input_layer: " + self.input_layer)
        self.pe = pos_enc_class(*pos_enc_args)
        self.embed_layer_num = len(self.embed)

        if self.positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (self.attention_dim, self.linear_units, self.dropout_rate)
        elif self.positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                self.attention_dim,
                self.linear_units,
                self.positionwise_conv_kernel_size,
                self.dropout_rate,
            )
        elif self.positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                self.attention_dim,
                self.linear_units,
                self.positionwise_conv_kernel_size,
                self.dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear, conv1d, or conv1d-linear.")

        self.encoders = repeat(
            self.num_blocks,
            lambda lnum: TransformerLayer(
                self.attention_dim,
                MultiHeadedAttention(
                    self.attention_heads,
                    self.attention_dim,
                    self.attention_dropout_rate,
                    self.chunk_size,
                    self.left_chunks,
                    self.pos_enc_class,
                ),
                positionwise_layer(*positionwise_layer_args),
                self.dropout_rate,
                self.normalize_before,
                self.concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = torch.nn.LayerNorm(self.attention_dim)

    def forward(self, xs, ilens=None, masks=None):
        """Embed positions in tensor and encode the input sequence.

        :param torch.Tensor xs: input tensor
        :param torch.Tensor ilens: input lengths (passed through unchanged)
        :param torch.Tensor masks: input mask
        :return: encoded tensor, input lengths, and mask
        :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        """
        if self.transformer_dynamic_chunks:  # and self.training:
            chunk_masks = add_optional_chunk_mask(xs, masks, True, True, 0, 0, -1)
        else:
            chunk_masks = add_optional_chunk_mask(
                xs, masks, False, False, self.chunk_size, self.chunk_size, self.left_chunks
            ).to(xs.device)
        xs = self.embed(xs)
        xs, pos_emb = self.pe(xs)
        xs, chunk_masks, pos_emb = self.encoders(xs, chunk_masks, pos_emb)
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, ilens, masks
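
    # Hypothetical call sketch (comments only; shapes are assumptions based on the
    # wenet-style convention this encoder follows, not verified against whale's
    # add_optional_chunk_mask): xs is (batch, time, input_dim) and masks is a
    # (batch, 1, time) boolean padding mask.
    #
    #   ys, ilens, masks = encoder(xs, ilens, masks)   # ys: (batch, time, attention_dim)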

    def infer(self, xs, buffer, buffer_index, buffer_out):
        """Streaming inference: encode one chunk, threading the layer caches through the stack."""
        xs = self.embed(xs)
        # pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
        # xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
        # buffer_out.append(pe_index.reshape(-1).to(torch.float32))
        # buffer_index = buffer_index + 1
        xs, pos_emb, _ = self.pe.infer(xs, 0)
        xs, pos_emb, buffer, buffer_index, buffer_out = self.encoders.infer(
            xs, pos_emb, buffer, buffer_index, buffer_out
        )
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, buffer, buffer_index, buffer_out

    def infer_hidden(self, xs, buffer, buffer_index, buffer_out, hidden_out):
        """Same as infer(), but also return each encoder layer's output in hidden_out."""
        xs = self.embed(xs)
        # pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
        # xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
        # buffer_out.append(pe_index.reshape(-1).to(torch.float32))
        # buffer_index = buffer_index + 1
        xs, pos_emb, _ = self.pe.infer(xs, 0)
        xs, pos_emb, buffer, buffer_index, buffer_out, hidden_out = self.encoders.infer_hidden(
            xs, pos_emb, buffer, buffer_index, buffer_out, hidden_out
        )
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, buffer, buffer_index, buffer_out, hidden_out
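

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch, not part of the original module. It builds the
# encoder from the argparse defaults defined in Transformer.add_arguments and
# runs one forward pass. The feature shape (batch, time, input_dim) and the
# (batch, 1, time) boolean padding mask are assumptions based on the wenet-style
# conventions this file follows; adjust them if whale's add_optional_chunk_mask
# expects a different layout.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    Transformer.add_arguments(parser.add_argument_group("transformer"))
    cli_args = parser.parse_args([])  # all defaults: 256-dim input, 6 blocks, 4 heads

    encoder = Transformer(cli_args)
    encoder.eval()

    xs = torch.randn(2, 50, cli_args.transformer_input_dim)
    masks = torch.ones(2, 1, 50, dtype=torch.bool)  # assumed (batch, 1, time) padding mask
    with torch.no_grad():
        ys, _, _ = encoder(xs, ilens=None, masks=masks)
    print(ys.shape)  # expected: torch.Size([2, 50, 256]) with the default attention dim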