"""Encoder self-attention layer definition."""
import math
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from vita.model.multimodal_encoder.whale.module.layer.attention import (
Conv1dLinear,
MultiHeadedAttention,
MultiLayeredConv1d,
PositionalEncoding,
PositionwiseFeedForward,
RelPositionalEncoding,
)
# from vita.model.multimodal_encoder.whale.module.component.utils import *
from vita.model.multimodal_encoder.whale.utils import IGNORE_ID, add_optional_chunk_mask, strtobool
def repeat(N, fn):
"""Repeat module N times.
:param int N: repeat time
:param function fn: function to generate module
:return: repeated modules
:rtype: MultiSequential
"""
return MultiSequential(*[fn(n) for n in range(N)])
class MultiSequential(torch.nn.Sequential):
"""Multi-input multi-output torch.nn.Sequential."""
def forward(self, x, masks, pos_emb):
"""Repeat."""
for m in self:
x, masks, pos_emb = m(x, masks, pos_emb)
return x, masks, pos_emb
@torch.jit.export
def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
        # type: (Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor, List[Tensor]]
        """Apply each module's streaming infer() in sequence."""
for m in self:
x, pos_emb, buffer, buffer_index, buffer_out = m.infer(
x, pos_emb, buffer, buffer_index, buffer_out
)
return x, pos_emb, buffer, buffer_index, buffer_out
@torch.jit.export
def infer_hidden(self, x, pos_emb, buffer, buffer_index, buffer_out, hidden_out):
        # type: (Tensor, Tensor, Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor, List[Tensor], List[Tensor]]
        """Apply each module's streaming infer() in sequence, collecting per-layer outputs."""
for m in self:
x, pos_emb, buffer, buffer_index, buffer_out = m.infer(
x, pos_emb, buffer, buffer_index, buffer_out
)
hidden_out.append(x)
return x, pos_emb, buffer, buffer_index, buffer_out, hidden_out
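# Note: unlike a plain torch.nn.Sequential, MultiSequential threads the full
# (x, masks, pos_emb) triple through every block in forward(), and streams the
# (x, pos_emb, buffer, buffer_index, buffer_out) state through every block's
# infer(); infer_hidden() additionally appends each block's output to hidden_out.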
class TransformerLayer(nn.Module):
"""Transformer layer module.
:param int size: input dim
:param self_attn: self attention module
:param feed_forward: feed forward module
:param float dropout_rate: dropout rate
:param bool normalize_before: whether to use layer_norm before the first block
:param bool concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(
self, size, self_attn, feed_forward, dropout_rate, normalize_before=True, concat_after=False
):
"""Construct an TransformerLayer object."""
super(TransformerLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = torch.nn.LayerNorm(size)
self.norm2 = torch.nn.LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
else:
self.concat_linear = nn.Identity()
@torch.jit.unused
def forward(self, x, mask, pos_emb):
"""Compute encoded features.
:param torch.Tensor x: encoded source features (batch, max_time_in, size)
:param torch.Tensor mask: mask for x (batch, max_time_in)
:rtype: Tuple[torch.Tensor, torch.Tensor]
"""
residual = x
if self.normalize_before:
x = self.norm1(x)
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x, x, x, mask, pos_emb)), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(self.self_attn(x, x, x, mask, pos_emb))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
return x, mask, pos_emb
@torch.jit.export
def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
        # type: (Tensor, Tensor, Tensor, Tensor, List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor, List[Tensor]]
residual = x.clone()
if self.normalize_before:
x = self.norm1(x)
if self.concat_after:
x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(
x, x, x, pos_emb, buffer, buffer_index, buffer_out
)
x_concat = torch.cat((x, x_att), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(
x, x, x, pos_emb, buffer, buffer_index, buffer_out
)
x = residual + x_att
if not self.normalize_before:
x = self.norm1(x)
residual = x.clone()
if self.normalize_before:
x = self.norm2(x)
x_feed, buffer, buffer_index, buffer_out = self.feed_forward.infer(
x, buffer, buffer_index, buffer_out
)
x = residual + x_feed
if not self.normalize_before:
x = self.norm2(x)
return x, pos_emb, buffer, buffer_index, buffer_out
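# Construction sketch (illustrative only; the sizes are assumptions, the argument
# order mirrors the repeat(...) call in Transformer.__init__ below):
#
#     layer = TransformerLayer(
#         256,
#         MultiHeadedAttention(4, 256, 0.0, -1, -1, "abs-enc"),
#         PositionwiseFeedForward(256, 1024, 0.1),
#         dropout_rate=0.1,
#     )
#     x, mask, pos_emb = layer(x, mask, pos_emb)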
class Transformer(torch.nn.Module):
@staticmethod
def add_arguments(group):
"""Add TDNN common arguments."""
group.add_argument(
"--transformer-input-dim", default=256, type=int, help="Input dim of Transformer."
)
group.add_argument(
"--transformer-output-dim", default=4, type=int, help="Output dim of Transformer."
)
group.add_argument(
"--transformer-attention-dim", default=256, type=int, help="Dimention of attention."
)
group.add_argument(
"--transformer-attention-heads",
default=4,
type=int,
help="The number of heads of multi head attention.",
)
group.add_argument(
"--transformer-linear-units",
default=1024,
type=int,
help="The number of units of position-wise feed forward.",
)
group.add_argument(
"--transformer-num-blocks", default=6, type=int, help="The number of attention blocks."
)
group.add_argument(
"--transformer-dropout-rate",
default=0.1,
type=float,
help="Dropout rate in Transformer.",
)
group.add_argument(
"--transformer-attention-dropout-rate",
default=0.0,
type=float,
help="Dropout rate in attention.",
)
group.add_argument(
"--transformer-positional-dropout-rate",
default=0.1,
type=float,
help="Dropout rate after adding positional encoding.",
)
group.add_argument(
"--transformer-input-layer", default="linear", type=str, help="Type of input layer"
)
group.add_argument("--transformer-pos-enc-class", default="abs-enc", type=str, help="")
group.add_argument(
"--transformer-normalize-before",
default=True,
type=strtobool,
help="Whether to use layer-norm before the first block.",
)
group.add_argument(
"--transformer-concat-after",
default=False,
type=strtobool,
help="Whether to concat attention layer's input and output.",
)
group.add_argument(
"--transformer-positionwise-layer-type",
default="linear",
type=str,
help="Linear of conv1d.",
)
group.add_argument(
"--transformer-positionwise-conv-kernel_size",
default=1,
type=int,
help="Kernel size of positionwise conv1d layer.",
)
group.add_argument("--transformer-chunk_size", default=-1, type=int, help="")
group.add_argument("--transformer-left_chunks", default=-1, type=int, help="")
group.add_argument("--transformer-dynamic-chunks", default=True, type=strtobool, help="")
return group
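    # Registration sketch (illustrative; assumes `import argparse` at the call site
    # and uses the defaults declared above):
    #
    #     parser = argparse.ArgumentParser()
    #     Transformer.add_arguments(parser.add_argument_group("transformer"))
    #     encoder = Transformer(parser.parse_args([]))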
def __init__(
self,
args,
input_dim=None,
output_dim=None,
attention_dim=None,
attention_heads=None,
linear_units=None,
num_blocks=None,
dropout_rate=None,
positional_dropout_rate=None,
attention_dropout_rate=None,
input_layer=None,
pos_enc_class=None,
normalize_before=None,
concat_after=None,
positionwise_layer_type=None,
positionwise_conv_kernel_size=None,
chunk_size=None,
left_chunks=None,
):
"""Construct an Encoder object."""
super(Transformer, self).__init__()
if args is None:
self.input_dim = input_dim
self.output_dim = output_dim
self.attention_dim = attention_dim
self.attention_heads = attention_heads
self.linear_units = linear_units
self.num_blocks = num_blocks
self.dropout_rate = dropout_rate
self.positional_dropout_rate = positional_dropout_rate
self.attention_dropout_rate = attention_dropout_rate
self.input_layer = input_layer
self.pos_enc_class = pos_enc_class
self.normalize_before = normalize_before
self.concat_after = concat_after
self.positionwise_layer_type = positionwise_layer_type
self.positionwise_conv_kernel_size = positionwise_conv_kernel_size
            self.chunk_size = chunk_size
            self.left_chunks = left_chunks
            # Assumed default: mirror the argparse default above so that forward()
            # does not fail with AttributeError when no args namespace is given.
            self.transformer_dynamic_chunks = True
else:
self.input_dim = args.transformer_input_dim
self.output_dim = args.transformer_output_dim
self.attention_dim = args.transformer_attention_dim
self.attention_heads = args.transformer_attention_heads
self.linear_units = args.transformer_linear_units
self.num_blocks = args.transformer_num_blocks
self.dropout_rate = args.transformer_dropout_rate
self.positional_dropout_rate = args.transformer_positional_dropout_rate
self.attention_dropout_rate = args.transformer_attention_dropout_rate
self.input_layer = args.transformer_input_layer
self.pos_enc_class = args.transformer_pos_enc_class
self.normalize_before = args.transformer_normalize_before
self.concat_after = args.transformer_concat_after
self.positionwise_layer_type = args.transformer_positionwise_layer_type
self.positionwise_conv_kernel_size = args.transformer_positionwise_conv_kernel_size
self.chunk_size = args.transformer_chunk_size
self.left_chunks = args.transformer_left_chunks
self.transformer_dynamic_chunks = args.transformer_dynamic_chunks
if self.pos_enc_class == "abs-enc":
pos_enc_args = (self.attention_dim, self.positional_dropout_rate)
pos_enc_class = PositionalEncoding
elif self.pos_enc_class == "rel-enc":
pos_enc_args = (
self.attention_dim,
self.positional_dropout_rate,
self.chunk_size,
self.left_chunks,
)
            pos_enc_class = RelPositionalEncoding
        else:
            raise ValueError("unknown pos_enc_class: " + self.pos_enc_class)
if self.input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(self.input_dim, self.attention_dim),
torch.nn.LayerNorm(self.attention_dim),
torch.nn.Dropout(self.dropout_rate),
torch.nn.ReLU(),
)
elif self.input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(self.input_dim, self.attention_dim, padding_idx=IGNORE_ID)
)
elif self.input_layer == "none":
self.embed = torch.nn.Sequential(torch.nn.Identity())
else:
raise ValueError("unknown input_layer: " + self.input_layer)
self.pe = pos_enc_class(*pos_enc_args)
self.embed_layer_num = len(self.embed)
if self.positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (self.attention_dim, self.linear_units, self.dropout_rate)
elif self.positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
self.attention_dim,
self.linear_units,
self.positionwise_conv_kernel_size,
self.dropout_rate,
)
elif self.positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
self.attention_dim,
self.linear_units,
self.positionwise_conv_kernel_size,
self.dropout_rate,
)
else:
            raise NotImplementedError("Support only linear, conv1d, or conv1d-linear.")
self.encoders = repeat(
self.num_blocks,
lambda lnum: TransformerLayer(
self.attention_dim,
MultiHeadedAttention(
self.attention_heads,
self.attention_dim,
self.attention_dropout_rate,
self.chunk_size,
self.left_chunks,
self.pos_enc_class,
),
positionwise_layer(*positionwise_layer_args),
self.dropout_rate,
self.normalize_before,
self.concat_after,
),
)
if self.normalize_before:
self.after_norm = torch.nn.LayerNorm(self.attention_dim)
@torch.jit.unused
def forward(self, xs, ilens=None, masks=None):
"""Embed positions in tensor.
:param torch.Tensor xs: input tensor
:param torch.Tensor masks: input mask
:return: position embedded tensor and mask
:rtype Tuple[torch.Tensor, torch.Tensor]:
"""
if self.transformer_dynamic_chunks == True: # and self.training:
chunk_masks = add_optional_chunk_mask(xs, masks, True, True, 0, 0, -1)
else:
chunk_masks = add_optional_chunk_mask(
xs, masks, False, False, self.chunk_size, self.chunk_size, self.left_chunks
).to(xs.device)
xs = self.embed(xs)
xs, pos_emb = self.pe(xs)
xs, chunk_masks, pos_emb = self.encoders(xs, chunk_masks, pos_emb)
if self.normalize_before:
xs = self.after_norm(xs)
return xs, ilens, masks
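    # Forward-pass sketch (illustrative; reuses `encoder` from the add_arguments
    # example above, i.e. a "linear" input layer with input_dim 256, "abs-enc"
    # positions and dynamic chunk masking; the (batch, 1, time) boolean mask layout
    # is the convention assumed for add_optional_chunk_mask):
    #
    #     xs = torch.randn(2, 100, encoder.input_dim)
    #     masks = torch.ones(2, 1, 100, dtype=torch.bool)
    #     ys, _, _ = encoder(xs, None, masks)  # ys: (2, 100, encoder.attention_dim)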
@torch.jit.export
def infer(self, xs, buffer, buffer_index, buffer_out):
xs = self.embed(xs)
# pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
# xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
# buffer_out.append(pe_index.reshape(-1).to(torch.float32))
# buffer_index = buffer_index + 1
xs, pos_emb, _ = self.pe.infer(xs, 0)
xs, pos_emb, buffer, buffer_index, buffer_out = self.encoders.infer(
xs, pos_emb, buffer, buffer_index, buffer_out
)
if self.normalize_before:
xs = self.after_norm(xs)
return xs, buffer, buffer_index, buffer_out
@torch.jit.export
def infer_hidden(self, xs, buffer, buffer_index, buffer_out, hidden_out):
xs = self.embed(xs)
# pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
# xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
# buffer_out.append(pe_index.reshape(-1).to(torch.float32))
# buffer_index = buffer_index + 1
xs, pos_emb, _ = self.pe.infer(xs, 0)
xs, pos_emb, buffer, buffer_index, buffer_out, hidden_out = self.encoders.infer_hidden(
xs, pos_emb, buffer, buffer_index, buffer_out, hidden_out
)
if self.normalize_before:
xs = self.after_norm(xs)
return xs, buffer, buffer_index, buffer_out, hidden_out
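# Scripting note: forward() is decorated with @torch.jit.unused while infer() and
# infer_hidden() are @torch.jit.export, so only the streaming inference paths are
# compiled when the module is scripted, e.g. `torch.jit.script(encoder)`; whether
# scripting succeeds end-to-end also depends on the attention and feed-forward
# modules imported above.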