# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
import math
import random
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.modules import (
AdaptiveSoftmax,
BaseLayer,
FairseqDropout,
LayerDropModuleList,
LayerNorm,
SinusoidalPositionalEmbedding,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
from .unify_transformer_layer import TransformerEncoderLayer, TransformerDecoderLayer
from .resnet import ResNet
from .frozen_bn import FrozenBatchNorm2d
# image
from .encoders.timm_resnet import resnet101, resnet152, resnet50
# video
from .encoders.resnext3d import ResNeXt3D, ResNeXtBottleneck
# audio
from .encoders.pann import create_pann_model
from data.audio_utils import AUDIO_CFG, dotdict
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
def BatchNorm2d(out_chan, momentum=0.1, eps=1e-3):
return nn.SyncBatchNorm.convert_sync_batchnorm(
nn.BatchNorm2d(out_chan, momentum=momentum, eps=eps)
)
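# Map pairwise token offsets to relative-position buckets: offsets with
# |offset| <= bucket_size // 2 keep their exact value, larger offsets fall into
# log-spaced buckets, and the result is shifted by (bucket_size - 1) so that all
# indices are non-negative.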
def make_token_bucket_position(bucket_size, max_position=DEFAULT_MAX_SOURCE_POSITIONS):
context_pos = torch.arange(max_position, dtype=torch.long)[:, None]
memory_pos = torch.arange(max_position, dtype=torch.long)[None, :]
relative_pos = context_pos - memory_pos
sign = torch.sign(relative_pos)
mid = bucket_size // 2
abs_pos = torch.where((relative_pos<mid) & (relative_pos > -mid), mid-1, torch.abs(relative_pos))
log_pos = torch.ceil(torch.log(abs_pos/mid)/math.log((max_position-1)/mid) * (mid-1)) + mid
log_pos = log_pos.int()
bucket_pos = torch.where(abs_pos.le(mid), relative_pos, log_pos*sign).long()
return bucket_pos + bucket_size - 1
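# Build a 2D relative-position index for a bucket_size x bucket_size grid of image
# patches, with three extra bucket indices reserved for interactions with a leading
# special token (its row, its column, and the token with itself).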
def make_image_bucket_position(bucket_size, num_relative_distance):
coords_h = torch.arange(bucket_size)
coords_w = torch.arange(bucket_size)
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += bucket_size - 1 # shift to start from 0
relative_coords[:, :, 1] += bucket_size - 1
relative_coords[:, :, 0] *= 2 * bucket_size - 1
relative_position_index = torch.zeros(size=(bucket_size * bucket_size + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
relative_position_index[0, 0] = num_relative_distance - 1
return relative_position_index
class PromptEncoder(torch.nn.Module):
r"""
Prompt encoder to generate prompts, including prompt, prefix, instance and instruction
"""
def __init__(
self,
type,
length,
projection,
embed_dim,
proj_dim,
layers,
vocab_size):
super().__init__()
self.prefix_projection = projection
if type == "prefix":
prompt_vocab_size = length
if self.prefix_projection:
self.embedding = torch.nn.Embedding(prompt_vocab_size, embed_dim)
self.trans = torch.nn.Sequential(
torch.nn.Linear(embed_dim, proj_dim),
torch.nn.ReLU(),
torch.nn.Linear(proj_dim, layers * 2 * embed_dim)
)
else:
if type == "prefix":
self.embedding = torch.nn.Embedding(
prompt_vocab_size, layers * 2 * embed_dim)
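    # forward() maps prefix token ids to prompt embeddings: with projection enabled,
    # a (length, embed_dim) table is passed through a two-layer MLP; otherwise the
    # table directly stores the (length, layers * 2 * embed_dim) past key/values.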
def forward(self, prefix: torch.Tensor):
if self.prefix_projection:
prefix_tokens = self.embedding(prefix)
past_key_values = self.trans(prefix_tokens)
else:
past_key_values = self.embedding(prefix)
return past_key_values
@register_model("unify_transformer")
class TransformerModel(FairseqEncoderDecoderModel):
"""
    Transformer model from `"Attention Is All You Need" (Vaswani et al., 2017)
<https://arxiv.org/abs/1706.03762>`_.
Args:
encoder (TransformerEncoder): the encoder
decoder (TransformerDecoder): the decoder
The Transformer model provides the following named architectures and
command-line arguments:
.. argparse::
:ref: fairseq.models.transformer_parser
:prog:
"""
def __init__(self, args, encoder, decoder):
super().__init__(encoder, decoder)
self.args = args
self.supports_align_args = True
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--activation-fn',
choices=utils.get_available_activation_fns(),
help='activation function to use')
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
help='dropout probability after activation in FFN.')
parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
help='path to pre-trained encoder embedding')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
help='encoder embedding dimension for FFN')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='num encoder layers')
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
help='num encoder attention heads')
parser.add_argument('--encoder-normalize-before', action='store_true',
help='apply layernorm before each encoder block')
parser.add_argument('--encoder-learned-pos', action='store_true',
help='use learned positional embeddings in the encoder')
parser.add_argument('--bitfit', default=False, action='store_true',
help='use bitfit in the transformer')
parser.add_argument('--freeze-encoder', action='store_true',
help='freeze the parameters in the encoder')
parser.add_argument('--adapter', action='store_true',
help='use adapter in the model')
parser.add_argument('--adapter-dim', type=int, metavar='N',
help='adapter-down-dim')
### vl ADAPTER
parser.add_argument('--adapter-type', type=str, metavar='STR',
help='adapter-type')
parser.add_argument('--unfreeze', action='store_true',
help='unfreeze model when using adapters/prompts')
parser.add_argument('--encoder-prompt', action='store_true',
help='use prompt tuning in the encoder')
parser.add_argument('--encoder-prompt-type', type=str, metavar='STR',
choices=['prefix'],
help='the type of prompt tuning')
parser.add_argument('--encoder-prompt-projection', action='store_true',
help='use prompt projection')
        parser.add_argument('--encoder-prompt-length', type=int, metavar='N',
                            help='length of the encoder prompt (number of prefix tokens)')
parser.add_argument('--encoder-prompt-dim', type=int, metavar='N',
help='encoder prompt dimension if use encoder prompt projection')
parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
help='path to pre-trained decoder embedding')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
help='decoder embedding dimension for FFN')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='num decoder layers')
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
help='num decoder attention heads')
parser.add_argument('--decoder-learned-pos', action='store_true',
help='use learned positional embeddings in the decoder')
parser.add_argument('--decoder-normalize-before', action='store_true',
help='apply layernorm before each decoder block')
        parser.add_argument('--decoder-output-dim', type=int, metavar='N',
                            help='decoder output dimension (extra linear layer '
                                 'if different from decoder embed dim)')
parser.add_argument('--freeze-decoder', action='store_true',
help='freeze the parameters in the decoder')
parser.add_argument('--decoder-prompt', action='store_true',
help='use prompt tuning in the decoder')
parser.add_argument('--decoder-prompt-type', type=str, metavar='STR',
choices=['prefix'],
help='the type of prompt tuning')
        parser.add_argument('--decoder-prompt-length', type=int, metavar='N',
                            help='length of the decoder prompt (number of prefix tokens)')
parser.add_argument('--decoder-prompt-projection', action='store_true',
help='use prompt projection')
parser.add_argument('--decoder-prompt-dim', type=int, metavar='N',
help='decoder prompt dimension if use decoder prompt projection')
parser.add_argument('--share-decoder-input-output-embed', action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
help='if set, disables positional embeddings (outside self attention)')
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--layernorm-embedding', action='store_true',
help='add layernorm to embedding')
parser.add_argument('--no-scale-embedding', action='store_true',
help='if True, dont scale embeddings')
parser.add_argument('--checkpoint-activations', action='store_true',
help='checkpoint activations at each layer, which saves GPU '
'memory usage at the cost of some additional compute')
        parser.add_argument('--offload-activations', action='store_true',
                            help='checkpoint activations at each layer, then offload them to CPU. Sets --checkpoint-activations.')
# args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
parser.add_argument('--no-cross-attention', default=False, action='store_true',
help='do not perform cross-attention')
parser.add_argument('--cross-self-attention', default=False, action='store_true',
help='perform cross+self-attention')
# args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
help='LayerDrop probability for encoder')
parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
help='LayerDrop probability for decoder')
parser.add_argument('--encoder-layers-to-keep', default=None,
help='which layers to *keep* when pruning as a comma-separated list')
parser.add_argument('--decoder-layers-to-keep', default=None,
help='which layers to *keep* when pruning as a comma-separated list')
# args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0,
help='iterative PQ quantization noise at training time')
parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8,
help='block size of quantization noise at training time')
parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
help='scalar quantization noise and scalar quantization at training time')
# args for Fully Sharded Data Parallel (FSDP) training
parser.add_argument(
'--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP,
help=(
'minimum number of params for a layer to be wrapped with FSDP() when '
'training with --ddp-backend=fully_sharded. Smaller values will '
'improve memory efficiency, but may make torch.distributed '
'communication less efficient due to smaller input sizes. This option '
'is set to 0 (i.e., always wrap) when --checkpoint-activations or '
'--offload-activations are passed.'
)
)
parser.add_argument('--resnet-drop-path-rate', type=float,
help='resnet drop path rate')
parser.add_argument('--encoder-drop-path-rate', type=float,
help='encoder drop path rate')
        parser.add_argument('--decoder-drop-path-rate', type=float,
                            help='decoder drop path rate')
parser.add_argument('--token-bucket-size', type=int,
help='token bucket size')
parser.add_argument('--image-bucket-size', type=int,
help='image bucket size')
parser.add_argument('--attn-scale-factor', type=float,
help='attention scale factor')
parser.add_argument('--freeze-resnet', action='store_true',
help='freeze resnet')
parser.add_argument('--freeze-encoder-embedding', action='store_true',
help='freeze encoder token embedding')
parser.add_argument('--freeze-decoder-embedding', action='store_true',
help='freeze decoder token embedding')
parser.add_argument('--add-type-embedding', action='store_true',
help='add source/region/patch type embedding')
        parser.add_argument('--add-mm-type-embedding', action='store_true',
                            help='add per-modality (text/image/video/audio) type embedding')
parser.add_argument('--interpolate-position', action='store_true',
help='interpolate position')
parser.add_argument('--resnet-type', choices=['resnet50', 'resnet101', 'resnet152'],
help='resnet type')
parser.add_argument('--resnet-model-path', type=str, metavar='STR',
help='path to load resnet')
parser.add_argument('--code-image-size', type=int,
help='code image size')
parser.add_argument('--patch-layernorm-embedding', action='store_true',
help='add layernorm to patch embedding')
parser.add_argument('--code-layernorm-embedding', action='store_true',
help='add layernorm to code embedding')
parser.add_argument('--entangle-position-embedding', action='store_true',
help='entangle position embedding')
parser.add_argument('--disable-entangle', action='store_true',
help='disable entangle')
parser.add_argument('--sync-bn', action='store_true',
help='sync batchnorm')
parser.add_argument('--scale-attn', action='store_true',
help='scale attn')
parser.add_argument('--scale-fc', action='store_true',
help='scale fc')
parser.add_argument('--scale-heads', action='store_true',
help='scale heads')
parser.add_argument('--scale-resids', action='store_true',
help='scale resids')
# fmt: on
# image encoder
        parser.add_argument('--image-encoder-name', type=str, metavar='STR', default='resnet',
                            help='name of the image encoder (e.g. resnet, timm_resnet)')
        parser.add_argument('--return-hidden-state-vision', action='store_true',
                            help='also return hidden states from the vision encoder')
        parser.add_argument('--freeze-image-encoder', action='store_true',
                            help='freeze the image encoder parameters')
        parser.add_argument('--nograd', action='store_true',
                            help='wrap the frozen perception encoders in torch.no_grad()')
        parser.add_argument('--encoder-eval', action='store_true',
                            help='keep the vision encoders in eval() mode')
# video
        parser.add_argument('--video-encoder-name', type=str, default=None,
                            help='name of the video encoder (e.g. resnext101)')
        parser.add_argument('--video-model-path', type=str, default=None,
                            help='path to pretrained video encoder weights')
        parser.add_argument('--freeze-video-encoder', action='store_true',
                            help='freeze the video encoder parameters')
        parser.add_argument('--sample-patch-num', type=int, default=None,
                            help='num of tokens selected at random')
        parser.add_argument('--with-cls', action='store_true',
                            help='account for a leading CLS token in the visual features')
        parser.add_argument('--sample-video-patch-num', type=int, default=None,
                            help='num of video tokens selected at random')
# audio
        parser.add_argument('--audio-encoder-name', type=str, default=None,
                            help='name of the audio encoder (e.g. pann_cnn14)')
        parser.add_argument('--audio-model-path', type=str, default=None,
                            help='path to pretrained audio encoder weights')
        parser.add_argument('--freeze-audio-encoder', action='store_true',
                            help='freeze the audio encoder parameters')
        parser.add_argument('--fusion-type', type=str, default=None,
                            help='fusion type passed to the PANN audio model')
        parser.add_argument('--enable-fusion', action='store_true',
                            help='enable feature fusion in the PANN audio model')
        parser.add_argument('--mel-bins', type=int, default=64,
                            help='number of mel bins for the audio spectrogram')
        parser.add_argument('--hop-size', type=int, default=480,
                            help='hop size for the audio spectrogram')
        parser.add_argument('--sample-audio-patch-num', type=int, default=None,
                            help='num of audio tokens selected at random')
        parser.add_argument('--fstride', type=int, default=10,
                            help='frequency stride for audio patch splitting')
        parser.add_argument('--tstride', type=int, default=10,
                            help='time stride for audio patch splitting')
        parser.add_argument('--input-tdim', type=int, default=1024,
                            help='number of time frames of the audio input')
# progressive training
        parser.add_argument('--progressive', action='store_true',
                            help='enable progressive (gradual unfreezing) training')
        parser.add_argument('--unfreeze-epoch-encoder', type=int, default=0,
                            help='epoch at which to unfreeze the encoder')
        parser.add_argument('--unfreeze-epoch-decoder', type=int, default=0,
                            help='epoch at which to unfreeze the decoder')
        parser.add_argument('--unfreeze-epoch-image', type=int, default=0,
                            help='epoch at which to unfreeze the image encoder')
        parser.add_argument('--unfreeze-epoch-video', type=int, default=0,
                            help='epoch at which to unfreeze the video encoder')
        parser.add_argument('--unfreeze-epoch-audio', type=int, default=0,
                            help='epoch at which to unfreeze the audio encoder')
        ## only linear
        parser.add_argument('--only-linear-proj', action='store_true',
                            help='train only the linear projection layers')
        parser.add_argument('--unfreeze-epoch', type=int, default=0,
                            help='epoch at which to unfreeze the model')
        parser.add_argument('--freeze-perception', action='store_true',
                            help='freeze the perception (image/video/audio) encoders')
##
        parser.add_argument('--qk-norm', action='store_true',
                            help='normalize queries and keys in self-attention (QK norm)')
        parser.add_argument('--layernorm-image-embedding', action='store_true',
                            help='add layernorm to the image embedding')
        parser.add_argument('--layernorm-video-embedding', action='store_true',
                            help='add layernorm to the video embedding')
        parser.add_argument('--layernorm-audio-embedding', action='store_true',
                            help='add layernorm to the audio embedding')
        parser.add_argument('--freeze-batchnorm', action='store_true',
                            help='freeze batchnorm parameters in the encoder')
        parser.add_argument('--freeze-resnet-video', action='store_true',
                            help='use frozen batchnorm in the video resnet')
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture(args)
if args.encoder_layers_to_keep:
args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
if args.decoder_layers_to_keep:
args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
if getattr(args, "max_source_positions", None) is None:
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
if args.share_all_embeddings:
if src_dict != tgt_dict:
raise ValueError("--share-all-embeddings requires a joined dictionary")
if args.encoder_embed_dim != args.decoder_embed_dim:
raise ValueError(
"--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
)
if args.decoder_embed_path and (
args.decoder_embed_path != args.encoder_embed_path
):
raise ValueError(
"--share-all-embeddings not compatible with --decoder-embed-path"
)
encoder_embed_tokens = cls.build_embedding(
args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
encoder_embed_tokens = cls.build_embedding(
args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
)
decoder_embed_tokens = cls.build_embedding(
args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
)
if getattr(args, "freeze_encoder_embedding", False) or getattr(
args, "encoder_prompt", False) or getattr(args, "decoder_prompt", False) or getattr(args, "adapter", False):
encoder_embed_tokens.weight.requires_grad = False
if getattr(args, "freeze_decoder_embedding", False) or getattr(
args, "encoder_prompt", False) or getattr(args, "decoder_prompt", False) or getattr(args, "adapter", False):
decoder_embed_tokens.weight.requires_grad = False
if getattr(args, "offload_activations", False):
args.checkpoint_activations = True # offloading implies checkpointing
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
if getattr(args, "freeze_encoder", False):
encoder.requires_grad_(False)
encoder.embed_images.requires_grad_(True)
encoder.image_proj.requires_grad_(True)
encoder.embed_image_positions.requires_grad_(True)
if encoder.layernorm_image_embedding is not None:
encoder.layernorm_image_embedding.requires_grad_(True)
print("freeze image LN")
if hasattr(encoder, 'embed_videos'):
encoder.embed_videos.requires_grad_(True)
encoder.video_proj.requires_grad_(True)
encoder.embed_video_positions.requires_grad_(True)
if encoder.layernorm_video_embedding is not None:
encoder.layernorm_video_embedding.requires_grad_(True)
print("freeze audio LN")
if getattr(args, "audio_encoder_name", False):
encoder.embed_audios.requires_grad_(True)
encoder.audio_proj.requires_grad_(True)
encoder.embed_audio_positions.requires_grad_(True)
if encoder.layernorm_audio_embedding is not None:
encoder.layernorm_audio_embedding.requires_grad_(True)
print("freeze audio LN")
if getattr(args, "freeze_decoder", False):
decoder.requires_grad_(False)
if getattr(args, "encoder_prompt", False) or getattr(
args, "decoder_prompt", False) or getattr(
args, "adapter", False):
if not getattr(args, "unfreeze", False):
encoder.requires_grad_(False)
decoder.requires_grad_(False)
if getattr(args, "encoder_prompt", False):
encoder.encoder_prompt_encoder.requires_grad_(True)
if getattr(args, "decoder_prompt", False):
decoder.decoder_prompt_encoder.requires_grad_(True)
if getattr(args, "adapter", False):
for idx, layer in enumerate(encoder.layers):
layer.adapter.requires_grad_(True)
for idx, layer in enumerate(decoder.layers):
layer.adapter.requires_grad_(True)
if getattr(args, "freeze_image_encoder", False):
encoder.embed_images.requires_grad_(False)
if getattr(args, "freeze_video_encoder", False) and hasattr(encoder, 'embed_videos'):
encoder.embed_videos.requires_grad_(False)
if getattr(args, "freeze_audio_encoder", False) and hasattr(encoder, 'embed_audios'):
encoder.embed_audios.requires_grad_(False)
if getattr(args, "freeze_batchnorm", False):
for n, p in encoder.named_parameters():
if 'bn' in n:
p.requires_grad = False
print("freeze:", n)
if hasattr(encoder, 'embed_audios'):
if hasattr(encoder.embed_audios, 'logmel_extractor'):
for param in encoder.embed_audios.logmel_extractor.parameters():
param.requires_grad = False
if getattr(args, "encoder_eval", False):
print("Set encoders.eval()")
encoder.embed_images.eval()
if getattr(args, "video_encoder_name", None):
encoder.embed_videos.eval()
if not args.share_all_embeddings:
min_params_to_wrap = getattr(
args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
)
# fsdp_wrap is a no-op when --ddp-backend != fully_sharded
encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
return cls(args, encoder, decoder)
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
args.vocab_size = num_embeddings
emb = Embedding(num_embeddings, embed_dim, padding_idx)
# if provided, load from preloaded dictionaries
if path:
embed_dict = utils.parse_embedding(path)
utils.load_embedding(embed_dict, dictionary, emb)
return emb
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerEncoder(args, src_dict, embed_tokens)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return TransformerDecoder(
args,
tgt_dict,
embed_tokens,
no_encoder_attn=getattr(args, "no_cross_attention", False),
)
# TorchScript doesn't support optional arguments with variable length (**kwargs).
# Current workaround is to add union of all arguments in child classes.
def forward(
self,
src_tokens,
src_lengths,
prev_output_tokens,
return_all_hiddens: bool = True,
features_only: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
"""
Run the forward pass for an encoder-decoder model.
Copied from the base class, but without ``**kwargs``,
which are not supported by TorchScript.
"""
encoder_out = self.encoder(
src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens
)
decoder_out = self.decoder(
prev_output_tokens,
encoder_out=encoder_out,
features_only=features_only,
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
src_lengths=src_lengths,
return_all_hiddens=return_all_hiddens,
)
return decoder_out
    # Since get_normalized_probs is defined in the (non-scriptable) FairseqModel base
    # class, we override it here to call the scriptable helper from the base class.
@torch.jit.export
def get_normalized_probs(
self,
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
"""Get normalized probabilities (or log probs) from a net's output."""
return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
class TransformerEncoder(FairseqEncoder):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`TransformerEncoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
"""
def __init__(self, args, dictionary, embed_tokens):
self.args = args
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
self.freeze_image_encoder = getattr(args, "freeze_image_encoder", False)
self.freeze_video_encoder = getattr(args, "freeze_video_encoder", False)
self.nograd = getattr(args, "nograd", False)
self.num_frames = getattr(args, 'num_frames', 4)
self.sample_patch_num = getattr(args, "sample_patch_num", 196)
self.sample_audio_patch_num = getattr(args, "sample_audio_patch_num", self.sample_patch_num)
self.sample_video_patch_num = getattr(args, "sample_video_patch_num", self.sample_patch_num)
self.with_cls = getattr(args, "with_cls", False)
print("self.sample_patch_num", self.sample_patch_num)
print("self.sample_audio_patch_num", self.sample_audio_patch_num)
print("self.sample_video_patch_num", self.sample_video_patch_num)
print("self.with_cls", self.with_cls)
if getattr(args, "encoder_prompt", False):
self.encoder_prompt_encoder = PromptEncoder(
type=args.encoder_prompt_type,
length=args.encoder_prompt_length,
projection=args.encoder_prompt_projection,
embed_dim=args.encoder_embed_dim,
proj_dim=args.encoder_prompt_dim,
layers=args.encoder_layers,
vocab_size=args.vocab_size)
self.encoder_dropout = nn.Dropout(p=0.2)
self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__
)
self.encoder_layerdrop = args.encoder_layerdrop
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
self.max_source_positions = args.max_source_positions
self.num_attention_heads = args.encoder_attention_heads
self.embed_tokens = embed_tokens
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
if getattr(args, "layernorm_embedding", False):
self.layernorm_embedding = LayerNorm(embed_dim)
else:
self.layernorm_embedding = None
self.mm_type_embedding = False
if getattr(args, "add_mm_type_embedding", False):
self.type_embedding = Embedding(4, embed_dim, padding_idx=None)
self.mm_type_embedding = True
elif getattr(args, "add_type_embedding", False):
self.type_embedding = Embedding(2, embed_dim, padding_idx=None)
else:
self.type_embedding = None
norm_layer = None
norm_layer_video = None
if getattr(args, "sync_bn", False):
norm_layer = BatchNorm2d
else:
if getattr(args, "freeze_resnet", False): # or getattr(args, "freeze_batchnorm", False)
norm_layer = FrozenBatchNorm2d
print("Frozen image bn", norm_layer)
if getattr(args, "freeze_resnet_video", False):
norm_layer_video = FrozenBatchNorm2d
print("Frozen video bn", norm_layer_video)
if getattr(args, "image_encoder_name", None):
if 'timm_resnet' in args.image_encoder_name:
if args.resnet_type == 'resnet101':
self.embed_images = resnet101(norm_layer=norm_layer)
elif args.resnet_type == 'resnet152':
self.embed_images = resnet152(norm_layer=norm_layer)
elif args.resnet_type == 'resnet50':
self.embed_images = resnet50(norm_layer=norm_layer)
else:
raise NotImplementedError
self.image_proj = Linear(2048, embed_dim)
else:
if args.resnet_type == 'resnet101':
self.embed_images = ResNet([3, 4, 23], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate)
elif args.resnet_type == 'resnet152':
self.embed_images = ResNet([3, 8, 36], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate)
elif args.resnet_type == 'resnet50':
self.embed_images = ResNet([3, 4, 6], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate)
else:
raise NotImplementedError
self.image_proj = Linear(1024, embed_dim)
else:
if args.resnet_type == 'resnet101':
self.embed_images = ResNet([3, 4, 23], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate)
elif args.resnet_type == 'resnet152':
self.embed_images = ResNet([3, 8, 36], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate)
elif args.resnet_type == 'resnet50':
self.embed_images = ResNet([3, 4, 6], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate)
else:
raise NotImplementedError
self.image_proj = Linear(1024, embed_dim)
if getattr(args, "layernorm_image_embedding", False):
self.layernorm_image_embedding = LayerNorm(embed_dim)
else:
self.layernorm_image_embedding = None
## video
if getattr(args, "video_encoder_name", None):
print("Loading: ", args.video_encoder_name)
patch_frame_size = getattr(args, 'patch_frame_size', 224)
num_frames = getattr(args, 'num_frames', 4)
pretrained_model = getattr(args, "video_model_path", None)
if 'resnext' in args.video_encoder_name:
if 'all' in args.video_encoder_name:
if 'resnext101' in args.video_encoder_name :
self.embed_videos = ResNeXt3D(ResNeXtBottleneck, [3, 4, 23, 3], norm_layer=norm_layer_video)
elif 'resnext152' in args.video_encoder_name:
self.embed_videos = ResNeXt3D(ResNeXtBottleneck, [3, 8, 36, 3], norm_layer=norm_layer_video)
elif 'resnext50' in args.video_encoder_name:
self.embed_videos = ResNeXt3D(ResNeXtBottleneck, [3, 4, 6, 3], norm_layer=norm_layer_video)
else:
raise NotImplementedError
vis_dim = 2048
else:
if args.video_encoder_name == 'resnext101':
self.embed_videos = ResNeXt3D(ResNeXtBottleneck, [3, 4, 23])
elif args.video_encoder_name == 'resnext152':
self.embed_videos = ResNeXt3D(ResNeXtBottleneck, [3, 8, 36])
elif args.video_encoder_name == 'resnext50':
self.embed_videos = ResNeXt3D(ResNeXtBottleneck, [3, 4, 6])
else:
raise NotImplementedError
vis_dim = 1024
self.embed_video_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim)
if pretrained_model:
print("load pretrained_model {}".format(pretrained_model))
state_dict = torch.load(pretrained_model)['state_dict']
if 'module' in list(state_dict.keys())[0]:
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove 'module.' of dataparallel
new_state_dict[name]=v
state_dict = new_state_dict
msg = self.embed_videos.load_state_dict(state_dict, strict=False)
print(msg)
else:
                raise NotImplementedError
self.video_proj = Linear(vis_dim, embed_dim)
if getattr(args, "layernorm_video_embedding", False):
self.layernorm_video_embedding = LayerNorm(embed_dim)
else:
self.layernorm_video_embedding = None
        ## audio
if getattr(args, "audio_encoder_name", None):
print("Loading: ", args.audio_encoder_name)
pretrained_audio_model = getattr(args, "audio_model_path", None)
audio_cfg = getattr(args, "audio_cfg", None)
audio_cfg = audio_cfg if audio_cfg is not None else AUDIO_CFG
audio_cfg = dotdict(audio_cfg)
enable_fusion = getattr(args, "enable_fusion", False)
fusion_type = getattr(args, "fusion_type", None)
audio_cfg['mel_bins'] = getattr(args, "mel_bins", 64)
audio_cfg['hop_size'] = getattr(args, "hop_size", 480)
if 'pann' in args.audio_encoder_name:
if 'cnn6' in args.audio_encoder_name:
audio_cfg['model_name'] = 'Cnn6'
audio_dim = 512
elif 'cnn10' in args.audio_encoder_name:
audio_cfg['model_name'] = 'Cnn10'
audio_dim = 1024
elif 'cnn14' in args.audio_encoder_name:
audio_cfg['model_name'] = 'Cnn14'
# audio_dim = 512
audio_dim = 2048
else:
raise NotImplementedError
self.embed_audios = create_pann_model(audio_cfg, enable_fusion, fusion_type)
self.embed_audio_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim)
if pretrained_audio_model:
print("load pretrained_model {}".format(pretrained_audio_model))
state_dict = torch.load(pretrained_audio_model)
if 'model' in state_dict:
state_dict = state_dict['model']
if 'module' in list(state_dict.keys())[0]:
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k.replace('module.', '') # remove 'module.' of dataparallel
new_state_dict[name]=v
state_dict = new_state_dict
if 'sed_model.' in list(state_dict.keys())[0] or 'module.' in list(state_dict.keys())[0]:
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
                            name = k.replace('sed_model.', '').replace('module.', '').replace('audio_branch.', '')  # strip checkpoint prefixes
new_state_dict[name]=v
state_dict = new_state_dict
if audio_cfg['mel_bins'] != 64:
del_keys = []
for k, v in state_dict.items():
if 'logmel_extractor' in k or 'bn0' in k:
del_keys.append(k)
for k in del_keys:
del state_dict[k]
msg = self.embed_audios.load_state_dict(state_dict, strict=False)
print(msg)
else:
raise NotImplementedError
self.audio_proj = Linear(audio_dim, embed_dim)
if getattr(args, "layernorm_audio_embedding", False):
self.layernorm_audio_embedding = LayerNorm(embed_dim)
else:
self.layernorm_audio_embedding = None
if getattr(args, "resnet_model_path", None):
print("load resnet {}".format(args.resnet_model_path))
resnet_state_dict = torch.load(self.args.resnet_model_path)
msg = self.embed_images.load_state_dict(resnet_state_dict, strict=False)
print(msg)
if getattr(args, "patch_layernorm_embedding", False):
self.patch_layernorm_embedding = LayerNorm(embed_dim)
else:
self.patch_layernorm_embedding = None
self.embed_positions = Embedding(args.max_source_positions + 2, embed_dim)
self.embed_image_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim)
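        # Decoupled absolute-position attention: position embeddings are layer-normed
        # and projected by dedicated pos_q/pos_k linears (scaled by attn_scale_factor)
        # to produce an additive attention bias shared by all encoder layers.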
self.pos_ln = LayerNorm(embed_dim)
self.image_pos_ln = LayerNorm(embed_dim)
self.pos_scaling = float(embed_dim / args.encoder_attention_heads * args.attn_scale_factor) ** -0.5
self.pos_q_linear = nn.Linear(embed_dim, embed_dim)
self.pos_k_linear = nn.Linear(embed_dim, embed_dim)
if not args.adaptive_input and args.quant_noise_pq > 0:
self.quant_noise = apply_quant_noise_(
nn.Linear(embed_dim, embed_dim, bias=False),
args.quant_noise_pq,
args.quant_noise_pq_block_size,
)
else:
self.quant_noise = None
if self.encoder_layerdrop > 0.0:
self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
else:
self.layers = nn.ModuleList([])
dpr = [x.item() for x in torch.linspace(0, args.encoder_drop_path_rate, args.encoder_layers)]
self.layers.extend(
[self.build_encoder_layer(args, drop_path_rate=dpr[i]) for i in range(args.encoder_layers)]
)
self.num_layers = len(self.layers)
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
token_bucket_size = args.token_bucket_size
token_num_rel_dis = 2 * token_bucket_size - 1
token_rp_bucket = make_token_bucket_position(token_bucket_size)
self.token_rel_pos_table_list = nn.ModuleList(
[Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.encoder_layers)]
)
image_bucket_size = args.image_bucket_size
image_num_rel_dis = (2 * image_bucket_size - 1) * (2 * image_bucket_size - 1) + 3
image_rp_bucket = make_image_bucket_position(image_bucket_size, image_num_rel_dis)
self.image_rel_pos_table_list = nn.ModuleList(
[Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.encoder_layers)]
)
self.patch_image_size = args.patch_image_size
self.orig_patch_image_size = args.orig_patch_image_size
self.register_buffer("token_rp_bucket", token_rp_bucket)
self.register_buffer("image_rp_bucket", image_rp_bucket)
self.entangle_position_embedding = args.entangle_position_embedding
def build_encoder_layer(self, args, drop_path_rate=0.0):
layer = TransformerEncoderLayer(args, drop_path_rate=drop_path_rate, \
use_adapter=getattr(args, "adapter", False), adapter_dim=getattr(args, "adapter_dim", 200),
adapter_type=getattr(args, "adapter_type", 'UN'))
checkpoint = getattr(args, "checkpoint_activations", False)
if checkpoint:
offload_to_cpu = getattr(args, "offload_activations", False)
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
# if we are checkpointing, enforce that FSDP always wraps the
# checkpointed layer, regardless of layer size
min_params_to_wrap = (
getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP)
if not checkpoint else 0
)
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
return layer
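    # Look up the relative position bias of text tokens for layer `idx`; returns a
    # (batch, num_heads, seq_len, seq_len) tensor added to the attention logits.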
def get_rel_pos_bias(self, x, idx):
seq_len = x.size(1)
rp_bucket = self.token_rp_bucket[:seq_len, :seq_len]
values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight)
values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1)
values = values.permute([0, 3, 1, 2])
return values.contiguous()
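    # Same as above for image patches, except the bucket table is gathered per sample
    # using image_position_ids, since patch positions can differ across the batch.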
def get_image_rel_pos_bias(self, image_position_ids, idx):
bsz, seq_len = image_position_ids.shape
rp_bucket_size = self.image_rp_bucket.size(1)
rp_bucket = self.image_rp_bucket.unsqueeze(0).expand(
bsz, rp_bucket_size, rp_bucket_size
).gather(1, image_position_ids[:, :, None].expand(bsz, seq_len, rp_bucket_size)
).gather(2, image_position_ids[:, None, :].expand(bsz, seq_len, seq_len))
values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight)
values = values.permute(0, 3, 1, 2)
return values
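    # Encode raw audio with the PANN backbone, flatten the resulting (bs, c, h, w)
    # feature map into a token sequence, and reuse the image position buckets for the
    # positional embeddings (variables keep the image_* naming).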
def get_patch_audios_info(self, patch_images, sample_patch_num, device):
if self.nograd and self.freeze_audio_encoder:
with torch.no_grad():
image_embed = self.embed_audios(patch_images)
else:
image_embed = self.embed_audios(patch_images)
# in case of cnn (bs, c, h, w)
h, w = image_embed.shape[-2:]
image_num_patches = h * w
sh = int(math.ceil(math.sqrt(image_num_patches))) # to keep within image_bucket_size
sw = sh
        image_embed = image_embed.flatten(2).transpose(1, 2)  # (bs, c, h*w) -> (bs, h*w, c)
image_padding_mask = patch_images.new_zeros((patch_images.size(0), image_num_patches)).bool()
image_position_idx = torch.arange(sw).unsqueeze(0).expand(sh, sw) + \
torch.arange(sh).unsqueeze(1) * self.args.image_bucket_size + 1
image_position_idx = image_position_idx.reshape(-1).to(device)[:image_num_patches]
image_position_ids = image_position_idx[None, :].expand(patch_images.size(0), image_num_patches)
if sample_patch_num is not None and sample_patch_num < image_num_patches:
patch_orders = [
random.sample(range(image_num_patches), k=sample_patch_num)
for _ in range(patch_images.size(0))
]
patch_orders = torch.LongTensor(patch_orders).to(device)
image_embed = image_embed.gather(
1, patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2))
)
image_num_patches = sample_patch_num
image_padding_mask = image_padding_mask.gather(1, patch_orders)
image_position_ids = image_position_ids.gather(1, patch_orders)
image_pos_embed = self.embed_audio_positions(image_position_ids)
image_embed = self.audio_proj(image_embed.type(self.audio_proj.weight.dtype))
if self.layernorm_audio_embedding is not None:
image_embed = self.layernorm_audio_embedding(image_embed)
return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed
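    # Encode video clips with the 3D backbone; the (frames, h, w) feature grid is
    # flattened into a single token sequence and every frame reuses the same 2D
    # spatial position ids.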
def get_patch_videos_info(self, patch_images, sample_patch_num, device):
if self.nograd and self.freeze_video_encoder:
with torch.no_grad():
_, image_embed = self.embed_videos(patch_images)
else:
_, image_embed = self.embed_videos(patch_images)
l, h, w = image_embed.shape[-3:]
image_num_patches = h * w * l
image_embed = image_embed.flatten(2).transpose(1, 2) # (bs, c, hlw) -> (bs, hlw, c)
numframes = l
image_padding_mask = patch_images.new_zeros((patch_images.size(0), image_num_patches)).bool()
image_position_idx = torch.arange(w).unsqueeze(0).expand(h, w) + \
torch.arange(h).unsqueeze(1) * self.args.image_bucket_size + 1
image_position_idx = image_position_idx.unsqueeze(0).expand(numframes, -1, -1)
image_position_idx = image_position_idx.reshape(-1).to(device)
image_position_ids = image_position_idx[None, :].expand(patch_images.size(0), image_num_patches)
if sample_patch_num is not None and sample_patch_num < image_num_patches:
patch_orders = [
random.sample(range(image_num_patches), k=sample_patch_num)
for _ in range(patch_images.size(0))
]
patch_orders = torch.LongTensor(patch_orders).to(device)
image_embed = image_embed.gather(
1, patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2))
)
image_num_patches = sample_patch_num
image_padding_mask = image_padding_mask.gather(1, patch_orders)
image_position_ids = image_position_ids.gather(1, patch_orders)
image_pos_embed = self.embed_video_positions(image_position_ids)
image_embed = self.video_proj(image_embed)
if self.layernorm_video_embedding is not None:
image_embed = self.layernorm_video_embedding(image_embed)
return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed
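    # Encode images with the CNN backbone, optionally subsample patches at random
    # (sample_patch_num) and, when --interpolate-position is set, bilinearly resize
    # the position embeddings from the original training resolution.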
def get_patch_images_info(self, patch_images, sample_patch_num, device):
if self.nograd and self.freeze_image_encoder:
with torch.no_grad():
image_embed = self.embed_images(patch_images)
else:
image_embed = self.embed_images(patch_images)
if isinstance(image_embed, tuple):
_, image_embed = image_embed
h, w = image_embed.shape[-2:]
image_num_patches = h * w
image_embed = image_embed.flatten(2).transpose(1, 2)
image_padding_mask = patch_images.new_zeros((patch_images.size(0), image_num_patches)).bool()
image_position_idx = torch.arange(w).unsqueeze(0).expand(h, w) + \
torch.arange(h).unsqueeze(1) * self.args.image_bucket_size + 1
image_position_idx = image_position_idx.view(-1).to(device)
image_position_ids = image_position_idx[None, :].expand(patch_images.size(0), image_num_patches)
if sample_patch_num is not None and sample_patch_num < image_num_patches:
patch_orders = [
random.sample(range(image_num_patches), k=sample_patch_num)
for _ in range(patch_images.size(0))
]
patch_orders = torch.LongTensor(patch_orders).to(device)
image_embed = image_embed.gather(
1, patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2))
)
image_num_patches = sample_patch_num
image_padding_mask = image_padding_mask.gather(1, patch_orders)
image_position_ids = image_position_ids.gather(1, patch_orders)
orig_num_patches = (self.orig_patch_image_size // 16) ** 2
        orig_hw = self.orig_patch_image_size // 16
if getattr(self.args, "interpolate_position", False) and image_num_patches > orig_num_patches:
old_image_position_ids = torch.arange(orig_hw).unsqueeze(0).expand(orig_hw, orig_hw) + \
torch.arange(orig_hw).unsqueeze(1) * self.args.image_bucket_size + 1
old_image_position_ids = old_image_position_ids.to(device)
old_image_pos_embed = self.embed_image_positions(old_image_position_ids)
old_image_pos_embed = old_image_pos_embed.reshape(1, orig_hw, orig_hw, -1).permute(0, 3, 1, 2)
image_pos_embed = F.interpolate(old_image_pos_embed, size=(h, w), mode='bilinear')
image_pos_embed = image_pos_embed.permute(0, 2, 3, 1).reshape(1, image_num_patches, -1)
image_pos_embed = image_pos_embed.expand(patch_images.size(0), -1, -1)
else:
image_pos_embed = self.embed_image_positions(image_position_ids)
image_embed = self.image_proj(image_embed)
if self.layernorm_image_embedding is not None:
image_embed = self.layernorm_image_embedding(image_embed)
return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed
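    # Reshape the prompt encoder output into per-layer prefix key/values: a tuple of
    # args.encoder_layers tensors, each of shape (2, bsz, num_heads, prompt_len, head_dim).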
def get_encoder_prompt(self, prompt_tokens):
past_key_values = self.encoder_prompt_encoder(prompt_tokens)
bsz, seqlen, _ = past_key_values.shape
past_key_values = past_key_values.view(
bsz,
seqlen,
(self.args.encoder_layers) * 2,
self.args.encoder_attention_heads,
self.args.encoder_embed_dim // self.args.encoder_attention_heads,
)
past_key_values = self.encoder_dropout(past_key_values)
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
return past_key_values
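    # Embed text tokens and optional image/patch features, add type and position
    # embeddings where enabled, and concatenate them as [image_2, image, text]
    # along the sequence dimension.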
def forward_embedding(
self,
src_tokens,
image_embed: Optional[torch.Tensor] = None,
image_embed_2: Optional[torch.Tensor] = None,
token_embedding: Optional[torch.Tensor] = None,
pos_embed: Optional[torch.Tensor] = None,
image_pos_embed: Optional[torch.Tensor] = None,
image_pos_embed_2: Optional[torch.Tensor] = None,
patch_types: Optional[torch.Tensor] = None,
):
# embed tokens and positions
if token_embedding is None:
token_embedding = self.embed_tokens(src_tokens)
x = embed = self.embed_scale * token_embedding
if self.entangle_position_embedding and pos_embed is not None:
x += pos_embed
if self.type_embedding is not None:
x += self.type_embedding(src_tokens.new_zeros(x.size()[:2]))
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
if self.quant_noise is not None:
x = self.quant_noise(x)
# embed raw images
if image_embed is not None:
# image_embed = self.image_proj(image_embed)
image_x = image_embed = self.embed_scale * image_embed
if self.entangle_position_embedding and image_pos_embed is not None:
image_x += image_pos_embed[:, -image_x.shape[1]:, :] # account for cls token
if self.type_embedding is not None:
if self.mm_type_embedding:
mm_type = patch_types.unsqueeze(1).to(src_tokens.device) + 1 # 0 for text
image_x += self.type_embedding(mm_type)
else:
image_x += self.type_embedding(src_tokens.new_ones(image_x.size()[:2]))
if self.patch_layernorm_embedding is not None:
image_x = self.patch_layernorm_embedding(image_x)
image_x = self.dropout_module(image_x)
if self.quant_noise is not None:
image_x = self.quant_noise(image_x)
x = torch.cat([image_x, x], dim=1)
embed = torch.cat([image_embed, embed], dim=1)
if image_embed_2 is not None:
assert self.type_embedding is not None
# image_embed_2 = self.image_proj(image_embed_2)
image_x_2 = image_embed_2 = self.embed_scale * image_embed_2
if self.entangle_position_embedding and image_pos_embed_2 is not None:
image_x_2 += image_pos_embed_2[:, -image_x_2.shape[1]:, :]
if self.type_embedding is not None:
image_x_2 += self.type_embedding(src_tokens.new_full(image_x_2.size()[:2], fill_value=2))
if self.patch_layernorm_embedding is not None:
image_x_2 = self.patch_layernorm_embedding(image_x_2)
image_x_2 = self.dropout_module(image_x_2)
if self.quant_noise is not None:
image_x_2 = self.quant_noise(image_x_2)
x = torch.cat([image_x_2, x], dim=1)
embed = torch.cat([image_embed_2, embed], dim=1)
return x, embed
def forward(
self,
src_tokens,
src_lengths,
patch_images: Optional[torch.Tensor] = None,
patch_images_2: Optional[torch.Tensor] = None,
patch_masks: Optional[torch.Tensor] = None,
code_masks: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
sample_patch_num: Optional[int] = None,
patch_videos: Optional[int] = None,
patch_types: Optional[torch.Tensor] = None,
patch_audios: Optional[torch.Tensor] = None,
):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
token_embeddings (torch.Tensor, optional): precomputed embeddings
default `None` will recompute embeddings
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
return self.forward_scriptable(src_tokens,
src_lengths,
patch_images,
patch_images_2,
patch_masks,
return_all_hiddens,
token_embeddings,
self.sample_patch_num,
patch_videos=patch_videos,
patch_types=patch_types,
patch_audios=patch_audios,
sample_audio_patch_num=self.sample_audio_patch_num,
sample_video_patch_num=self.sample_video_patch_num)
    # TorchScript doesn't support super(), so a scriptable subclass can't call the
    # base class implementation directly. The current workaround is to move the logic
    # into a helper with a different name and call that helper from the scriptable
    # subclass.
def forward_scriptable(
self,
src_tokens,
src_lengths,
patch_images: Optional[torch.Tensor] = None,
patch_images_2: Optional[torch.Tensor] = None,
patch_masks: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
sample_patch_num: Optional[int] = None,
patch_videos: Optional[int] = None,
patch_types: Optional[torch.Tensor] = None,
patch_audios: Optional[int] = None,
sample_audio_patch_num: Optional[int] = None,
sample_video_patch_num: Optional[int] = None,
):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
token_embeddings (torch.Tensor, optional): precomputed embeddings
default `None` will recompute embeddings
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
prompt_tokens = None
prompt_padding_mask = None
prompt_kv_list = None
if self.args.encoder_prompt:
bsz, seq_len = src_tokens.shape[0], src_tokens.shape[1]
            if self.args.encoder_prompt_type in ("prefix",):
prompt_tokens = torch.arange(
0, self.args.encoder_prompt_length).to(
src_tokens.device)
prompt_tokens = prompt_tokens.unsqueeze(0).expand(bsz, -1)
prompt_padding_mask = torch.zeros_like(prompt_tokens).to(prompt_tokens.device)
prompt_kv_list = self.get_encoder_prompt(prompt_tokens)
image_embed = None
image_embed_2 = None
image_pos_embed = None
image_pos_embed_2 = None
num_image_tokens = None
if sample_audio_patch_num is None:
sample_audio_patch_num = sample_patch_num
if sample_video_patch_num is None:
sample_video_patch_num = sample_patch_num
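        # When patch_types is given, a batch may mix image / video / audio samples:
        # each present modality is encoded separately, the outputs are concatenated,
        # and the ids_merge gather below rearranges them to match the batch layout.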
if patch_images is not None:
if patch_types is not None:
video_idx = patch_types==1
image_idx = patch_types==0
audio_idx = patch_types==2
image_idx_ = (patch_types==0).nonzero()[:, 0]
video_idx_ = (patch_types==1).nonzero()[:, 0]
audio_idx_ = (patch_types==2).nonzero()[:, 0]
image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = None,None,None,None,None
# print(image_idx_, video_idx_, audio_idx_)
if torch.any(image_idx).item():
image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \
self.get_patch_images_info(patch_images, sample_patch_num, src_tokens.device)
ids_merge = image_idx_
if torch.any(video_idx).item():
video_embed, image_num_patches, video_padding_mask, video_position_ids, video_pos_embed = \
self.get_patch_videos_info(patch_videos[video_idx], sample_video_patch_num, src_tokens.device)
if image_embed is not None:
ids_merge = torch.cat((ids_merge, video_idx_), dim=0).long()
image_embed = torch.cat((image_embed, video_embed), dim=0)
bs, L, D = image_embed.shape
image_padding_mask = torch.cat((image_padding_mask, video_padding_mask), dim=0)
image_position_ids = torch.cat((image_position_ids, video_position_ids), dim=0)
image_pos_embed = torch.cat((image_pos_embed, video_pos_embed), dim=0)
image_embed = torch.gather(image_embed, dim=0, index=ids_merge[:, None, None].repeat(1, L, D))
image_padding_mask = torch.gather(image_padding_mask, dim=0, index=ids_merge[:, None].repeat(1, image_padding_mask.shape[1]))
image_position_ids = torch.gather(image_position_ids, dim=0, index=ids_merge[:, None].repeat(1, image_position_ids.shape[1]))
image_pos_embed = torch.gather(image_pos_embed, dim=0, index=ids_merge[:, None, None].repeat(1, image_pos_embed.shape[1], D))
else:
image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \
video_embed, image_num_patches, video_padding_mask, video_position_ids, video_pos_embed
ids_merge = video_idx_
if torch.any(audio_idx).item() : # or image_embed is None
audio_embed, image_num_patches, audio_padding_mask, audio_position_ids, audio_pos_embed = \
self.get_patch_audios_info(patch_audios[audio_idx], sample_audio_patch_num, src_tokens.device)
if image_embed is not None:
ids_merge = torch.cat((ids_merge, audio_idx_), dim=0).long()
image_embed = torch.cat((image_embed, audio_embed), dim=0)
bs, L, D = image_embed.shape
image_padding_mask = torch.cat((image_padding_mask, audio_padding_mask), dim=0)
image_position_ids = torch.cat((image_position_ids, audio_position_ids), dim=0)
image_pos_embed = torch.cat((image_pos_embed, audio_pos_embed), dim=0)
image_embed = torch.gather(image_embed, dim=0, index=ids_merge[:, None, None].repeat(1, L, D))
image_padding_mask = torch.gather(image_padding_mask, dim=0, index=ids_merge[:, None].repeat(1, image_padding_mask.shape[1]))
image_position_ids = torch.gather(image_position_ids, dim=0, index=ids_merge[:, None].repeat(1, image_position_ids.shape[1]))
image_pos_embed = torch.gather(image_pos_embed, dim=0, index=ids_merge[:, None, None].repeat(1, image_pos_embed.shape[1], D))
else:
image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \
audio_embed, image_num_patches, audio_padding_mask, audio_position_ids, audio_pos_embed
else:
image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \
self.get_patch_images_info(patch_images, sample_patch_num, src_tokens.device)
image_padding_mask[~patch_masks] = True
num_image_tokens=image_num_patches
if patch_images_2 is not None:
image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \
self.get_patch_images_info(patch_images_2, sample_patch_num, src_tokens.device)
image_padding_mask_2[~patch_masks] = True
num_image_tokens+=image_num_patches_2
encoder_padding_mask = src_tokens.eq(self.padding_idx)
if patch_images is not None:
encoder_padding_mask = torch.cat([image_padding_mask, encoder_padding_mask], dim=1)
if patch_images_2 is not None:
encoder_padding_mask = torch.cat([image_padding_mask_2, encoder_padding_mask], dim=1)
has_pads = (src_tokens.device.type == "xla" or encoder_padding_mask.any())
pos_embed = self.embed_positions(utils.new_arange(src_tokens))
x, encoder_embedding = self.forward_embedding(
src_tokens, image_embed, image_embed_2, token_embeddings,
pos_embed, image_pos_embed, image_pos_embed_2, patch_types=patch_types
)
# account for padding while computing the representation
if has_pads:
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
# B x T x C -> T x B x C
x = x.transpose(0, 1)
pos_embed = self.pos_ln(pos_embed)
if patch_images is not None:
image_pos_embed = self.image_pos_ln(image_pos_embed)
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
if patch_images_2 is not None:
image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2)
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)
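        # Absolute position bias: the layer-normed position embeddings are projected
        # by pos_q/pos_k and multiplied into a (bsz, heads, T, T) bias; per-layer
        # relative position biases are added on top inside the loop below.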
pos_q = self.pos_q_linear(pos_embed).view(
pos_embed.size(0), pos_embed.size(1), self.num_attention_heads, -1
).transpose(1, 2) * self.pos_scaling
pos_k = self.pos_k_linear(pos_embed).view(
pos_embed.size(0), pos_embed.size(1), self.num_attention_heads, -1
).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
encoder_states = []
if return_all_hiddens:
encoder_states.append(x)
if prompt_padding_mask is not None:
encoder_padding_mask = torch.cat([prompt_padding_mask, encoder_padding_mask], dim=1)
# encoder layers
if self.with_cls:
offset_rel_pos = 1
abs_pos_bias = F.pad(abs_pos_bias, (1, 0, 1, 0), "constant", 0)
else:
offset_rel_pos = 0
for idx, layer in enumerate(self.layers):
self_attn_bias = abs_pos_bias.clone()
self_attn_bias[:, :, -src_tokens.size(1):, -src_tokens.size(1):] += self.get_rel_pos_bias(src_tokens, idx)
if patch_images_2 is not None:
self_attn_bias[:, :, offset_rel_pos:image_num_patches_2, offset_rel_pos:image_num_patches_2] += \
self.get_image_rel_pos_bias(image_position_ids_2, idx)
self_attn_bias[:, :, offset_rel_pos+image_num_patches_2:image_num_patches_2+image_num_patches, offset_rel_pos+image_num_patches_2:image_num_patches_2+image_num_patches] += \
self.get_image_rel_pos_bias(image_position_ids, idx)
elif patch_images is not None:
self_attn_bias[:, :, offset_rel_pos:x.size(0) - src_tokens.size(1), offset_rel_pos:x.size(0) - src_tokens.size(1)] += \
self.get_image_rel_pos_bias(image_position_ids, idx)
self_attn_bias = self_attn_bias.reshape(-1, self_attn_bias.size(2), self_attn_bias.size(2))
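# flatten to (B * num_heads, S, S), the 3-D attention-bias layout the attention modules consume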
if self.args.encoder_prompt:
if self.args.encoder_prompt_type != "prompt":
prompt_kv = prompt_kv_list[idx]
else:
if idx == 0:
prompt_kv = prompt_kv_list[idx]
else:
prompt_kv = None
else:
prompt_kv = None
x = layer(x, encoder_padding_mask=encoder_padding_mask if has_pads else None,
self_attn_bias=self_attn_bias, prompt_kv=prompt_kv, num_image_tokens=num_image_tokens)
if return_all_hiddens:
assert encoder_states is not None
encoder_states.append(x)
if self.layer_norm is not None:
x = self.layer_norm(x)
if self.args.encoder_prompt:
encoder_padding_mask = encoder_padding_mask[:, prompt_tokens.size(1):]
# The PyTorch Mobile lite interpreter does not support returning NamedTuple in
# `forward` so we use a dictionary instead.
# TorchScript does not support mixed values so the values are all lists.
# The empty list is equivalent to None.
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [], # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": [],
"src_lengths": [],
"position_embeddings": [pos_embed], # B x T x C
}
@torch.jit.export
def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
if len(encoder_out["encoder_out"]) == 0:
new_encoder_out = []
else:
new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
if len(encoder_out["encoder_padding_mask"]) == 0:
new_encoder_padding_mask = []
else:
new_encoder_padding_mask = [
encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
]
if len(encoder_out["encoder_embedding"]) == 0:
new_encoder_embedding = []
else:
new_encoder_embedding = [
encoder_out["encoder_embedding"][0].index_select(0, new_order)
]
if len(encoder_out["src_tokens"]) == 0:
new_src_tokens = []
else:
new_src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]
if len(encoder_out["src_lengths"]) == 0:
new_src_lengths = []
else:
new_src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]
if len(encoder_out["position_embeddings"]) == 0:
new_position_embeddings = []
else:
new_position_embeddings = [(encoder_out["position_embeddings"][0]).index_select(0, new_order)]
encoder_states = encoder_out["encoder_states"]
if len(encoder_states) > 0:
for idx, state in enumerate(encoder_states):
encoder_states[idx] = state.index_select(1, new_order)
return {
"encoder_out": new_encoder_out, # T x B x C
"encoder_padding_mask": new_encoder_padding_mask, # B x T
"encoder_embedding": new_encoder_embedding, # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": new_src_tokens, # B x T
"src_lengths": new_src_lengths, # B x 1
"position_embeddings": new_position_embeddings, # B x T x C
}
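# Illustrative sketch (hypothetical tensors): beam search keeps the encoder output cached and calls
# reorder_encoder_out whenever beam hypotheses are reordered, e.g.
#
#   new_order = torch.tensor([1, 1, 0], device=src_tokens.device)  # hypothetical beam permutation
#   enc_out = encoder.reorder_encoder_out(enc_out, new_order)      # enc_out: dict returned by forward()
#
# "encoder_out" entries are T x B x C, hence index_select on dim 1; the B x T masks use dim 0.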
def max_positions(self):
"""Maximum input length supported by the encoder."""
return self.max_source_positions
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = "{}.embed_positions.weights".format(name)
if weights_key in state_dict:
print("deleting {0}".format(weights_key))
del state_dict[weights_key]
state_dict[
"{}.embed_positions._float_tensor".format(name)
] = torch.FloatTensor(1)
for i in range(self.num_layers):
# update layer norms
self.layers[i].upgrade_state_dict_named(
state_dict, "{}.layers.{}".format(name, i)
)
prefix = name + "." if name != "" else ""
for param_name, param_tensor in self.state_dict().items():
if (prefix + param_name) not in state_dict:
state_dict[prefix + param_name] = self.state_dict()[param_name]
if len(state_dict["encoder.embed_image_positions.weight"]) < len(self.state_dict()["embed_image_positions.weight"]):
num_posids_to_add = len(self.state_dict()["embed_image_positions.weight"]) - len(state_dict["encoder.embed_image_positions.weight"])
embed_dim = state_dict["encoder.embed_image_positions.weight"].size(1)
new_pos_embed_to_add = torch.zeros(num_posids_to_add, embed_dim)
nn.init.normal_(new_pos_embed_to_add, mean=0, std=embed_dim ** -0.5)
new_pos_embed_to_add = new_pos_embed_to_add.to(
dtype=state_dict["encoder.embed_image_positions.weight"].dtype,
)
state_dict["encoder.embed_image_positions.weight"] = torch.cat(
[state_dict["encoder.embed_image_positions.weight"], new_pos_embed_to_add]
)
return state_dict
class TransformerDecoder(FairseqIncrementalDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): do not attend to encoder outputs
(default: False).
"""
def __init__(
self,
args,
dictionary,
embed_tokens,
no_encoder_attn=False,
output_projection=None,
):
self.args = args
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
self._future_mask = torch.empty(0)
self.with_cls = getattr(args, "with_cls", False)
if getattr(args, "decoder_prompt", False):
self.decoder_prompt_encoder = PromptEncoder(
type=args.decoder_prompt_type,
length=args.decoder_prompt_length,
projection=args.decoder_prompt_projection,
embed_dim=args.decoder_embed_dim,
proj_dim=args.decoder_prompt_dim,
layers=args.decoder_layers,
vocab_size=args.vocab_size)
self.decoder_dropout = nn.Dropout(p=0.2)
self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__
)
self.decoder_layerdrop = args.decoder_layerdrop
self.share_input_output_embed = args.share_decoder_input_output_embed
self.num_attention_heads = args.decoder_attention_heads
input_embed_dim = embed_tokens.embedding_dim
embed_dim = args.decoder_embed_dim
self.embed_dim = embed_dim
self.output_embed_dim = args.decoder_output_dim
self.padding_idx = embed_tokens.padding_idx
self.max_target_positions = args.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
if not args.adaptive_input and args.quant_noise_pq > 0:
self.quant_noise = apply_quant_noise_(
nn.Linear(embed_dim, embed_dim, bias=False),
args.quant_noise_pq,
args.quant_noise_pq_block_size,
)
else:
self.quant_noise = None
self.project_in_dim = (
Linear(input_embed_dim, embed_dim, bias=False)
if embed_dim != input_embed_dim
else None
)
if getattr(args, "layernorm_embedding", False):
self.layernorm_embedding = LayerNorm(embed_dim)
else:
self.layernorm_embedding = None
self.window_size = args.code_image_size // 8
self.embed_positions = Embedding(args.max_target_positions + 2, embed_dim)
self.embed_image_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim)
self.pos_ln = LayerNorm(embed_dim)
self.image_pos_ln = LayerNorm(embed_dim)
self.pos_scaling = float(embed_dim / self.num_attention_heads * args.attn_scale_factor) ** -0.5
self.self_pos_q_linear = nn.Linear(embed_dim, embed_dim)
self.self_pos_k_linear = nn.Linear(embed_dim, embed_dim)
self.cross_pos_q_linear = nn.Linear(embed_dim, embed_dim)
self.cross_pos_k_linear = nn.Linear(embed_dim, embed_dim)
if getattr(args, "code_layernorm_embedding", False):
self.code_layernorm_embedding = LayerNorm(embed_dim)
else:
self.code_layernorm_embedding = None
self.cross_self_attention = getattr(args, "cross_self_attention", False)
if self.decoder_layerdrop > 0.0:
self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
else:
self.layers = nn.ModuleList([])
dpr = [x.item() for x in torch.linspace(0, args.decoder_drop_path_rate, args.decoder_layers)]
self.layers.extend(
[
self.build_decoder_layer(args, no_encoder_attn, drop_path_rate=dpr[i])
for i in range(args.decoder_layers)
]
)
self.num_layers = len(self.layers)
if args.decoder_normalize_before:
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
self.project_out_dim = (
Linear(embed_dim, self.output_embed_dim, bias=False)
if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
else None
)
self.adaptive_softmax = None
self.output_projection = output_projection
if self.output_projection is None:
self.build_output_projection(args, dictionary, embed_tokens)
token_bucket_size = args.token_bucket_size
token_num_rel_dis = 2 * token_bucket_size - 1
token_rp_bucket = make_token_bucket_position(token_bucket_size)
self.token_rel_pos_table_list = nn.ModuleList(
[Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.decoder_layers)]
)
image_bucket_size = args.image_bucket_size
image_num_rel_dis = (2 * image_bucket_size - 1) * (2 * image_bucket_size - 1) + 3
image_rp_bucket = make_image_bucket_position(image_bucket_size, image_num_rel_dis)
image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \
torch.arange(self.window_size).unsqueeze(1) * image_bucket_size + 1
image_position_idx = torch.cat([torch.tensor([0]), image_position_idx.view(-1)])
image_position_idx = torch.cat([image_position_idx, torch.tensor([1024] * 769)])
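# index 0 is prepended for the BOS slot; the trailing 1024-valued ids presumably pad the table so that
# target sequences longer than the code-image window still map to a shared image position id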
self.image_rel_pos_table_list = nn.ModuleList(
[Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.decoder_layers)]
)
self.register_buffer("token_rp_bucket", token_rp_bucket)
self.register_buffer("image_rp_bucket", image_rp_bucket)
self.register_buffer("image_position_idx", image_position_idx)
self.entangle_position_embedding = args.entangle_position_embedding
def get_decoder_prompt(self, prompt_tokens):
past_key_values = self.decoder_prompt_encoder(prompt_tokens)
bsz, seqlen, _ = past_key_values.shape
past_key_values = past_key_values.view(
bsz,
seqlen,
self.args.decoder_layers * 2,
self.args.decoder_attention_heads,
self.args.decoder_embed_dim // self.args.decoder_attention_heads,
)
past_key_values = self.decoder_dropout(past_key_values)
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
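# shape walk-through: (bsz, prompt_len, 2 * n_layers * embed_dim) -> view to
# (bsz, prompt_len, 2 * n_layers, n_heads, head_dim) -> permute to (2 * n_layers, bsz, n_heads, prompt_len, head_dim)
# -> split(2) yields one (key, value) prefix pair per decoder layer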
return past_key_values
def build_output_projection(self, args, dictionary, embed_tokens):
if args.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
self.output_embed_dim,
utils.eval_str_list(args.adaptive_softmax_cutoff, type=int),
dropout=args.adaptive_softmax_dropout,
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
factor=args.adaptive_softmax_factor,
tie_proj=args.tie_adaptive_proj,
)
elif self.share_input_output_embed:
self.output_projection = nn.Linear(
self.embed_tokens.weight.shape[1],
self.embed_tokens.weight.shape[0],
bias=False,
)
self.output_projection.weight = self.embed_tokens.weight
else:
self.output_projection = nn.Linear(
self.output_embed_dim, len(dictionary), bias=False
)
nn.init.normal_(
self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
)
num_base_layers = getattr(args, "base_layers", 0)
for i in range(num_base_layers):
self.layers.insert(((i+1) * args.decoder_layers) // (num_base_layers + 1), BaseLayer(args))
def build_decoder_layer(self, args, no_encoder_attn=False, drop_path_rate=0.0):
layer = TransformerDecoderLayer(
args, no_encoder_attn, drop_path_rate=drop_path_rate,
use_adapter=getattr(args, "adapter", False), adapter_dim=getattr(args, "adapter_dim", 200),
)
checkpoint = getattr(args, "checkpoint_activations", False)
if checkpoint:
offload_to_cpu = getattr(args, "offload_activations", False)
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
# if we are checkpointing, enforce that FSDP always wraps the
# checkpointed layer, regardless of layer size
min_params_to_wrap = (
getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP)
if not checkpoint else 0
)
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
return layer
def get_rel_pos_bias(self, x, idx):
seq_len = x.size(1)
rp_bucket = self.token_rp_bucket[:seq_len, :seq_len]
values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight)
values = values.permute([2, 0, 1])
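# values: (num_heads, seq_len, seq_len) bucketed relative-position bias for text tokens, one map per head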
return values.contiguous()
def get_image_rel_pos_bias(self, x, idx):
seq_len = x.size(1)
image_position_idx = self.image_position_idx[:seq_len]
rp_bucket = self.image_rp_bucket[image_position_idx][:, image_position_idx]
values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight)
values = values.permute(2, 0, 1)
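# same (num_heads, seq_len, seq_len) layout, but indexed via image_position_idx so the bias follows 2-D patch geometry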
return values
def get_pos_info(self, tokens, tgt_pos_embed, src_pos_embed=None, use_image=False):
batch_size = tokens.size(0)
tgt_len = tokens.size(1)
tgt_pos_embed = self.image_pos_ln(tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed)
if src_pos_embed is not None:
src_len = src_pos_embed.size(1)
pos_q = self.cross_pos_q_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads, -1
).transpose(1, 2) * self.pos_scaling
pos_k = self.cross_pos_k_linear(src_pos_embed).view(
batch_size, src_len, self.num_attention_heads, -1
).transpose(1, 2)
else:
src_len = tgt_pos_embed.size(1)
pos_q = self.self_pos_q_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads, -1
).transpose(1, 2) * self.pos_scaling
pos_k = self.self_pos_k_linear(tgt_pos_embed).view(
batch_size, src_len, self.num_attention_heads, -1
).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
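# abs_pos_bias: (batch, num_heads, tgt_len, src_len); used as the self-attention position bias when
# src_pos_embed is None (then src_len == tgt_len) and as the cross-attention bias towards the encoder otherwise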
return abs_pos_bias
def forward(
self,
prev_output_tokens,
code_masks: Optional[torch.Tensor] = None,
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
features_only: bool = False,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
src_lengths: Optional[Any] = None,
return_all_hiddens: bool = False,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention, should be of size T x B x C
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(
prev_output_tokens,
code_masks=code_masks,
encoder_out=encoder_out,
incremental_state=incremental_state,
full_context_alignment=full_context_alignment,
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
)
if not features_only:
x = self.output_layer(x)
return x, extra
def extract_features(
self,
prev_output_tokens,
code_masks: Optional[torch.Tensor],
encoder_out: Optional[Dict[str, List[Tensor]]],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
return self.extract_features_scriptable(
prev_output_tokens,
code_masks,
encoder_out,
incremental_state,
full_context_alignment,
alignment_layer,
alignment_heads,
)
"""
A scriptable subclass of this class has an extract_features method and calls
super().extract_features, but super() is not supported in torchscript. A copy of
this function is made to be used in the subclass instead.
"""
def extract_features_scriptable(
self,
prev_output_tokens,
code_masks: Optional[torch.Tensor],
encoder_out: Optional[Dict[str, List[Tensor]]],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
"""
Similar to *forward* but only return features.
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
alignment_layer (int, optional): return mean alignment over
heads at this layer (default: last layer).
alignment_heads (int, optional): only average alignment over
this many heads (default: all heads).
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
prompt_tokens = None
prompt_padding_mask = None
prompt_kv_list = None
if self.args.decoder_prompt:
bsz, seq_len = prev_output_tokens.shape[0], prev_output_tokens.shape[1]
if self.args.decoder_prompt_type in ("prefix",):
prompt_tokens = torch.arange(
0, self.args.decoder_prompt_length).to(
prev_output_tokens.device)
prompt_tokens = prompt_tokens.unsqueeze(0).expand(bsz, -1)
prompt_padding_mask = torch.zeros_like(prompt_tokens)
prompt_kv_list = self.get_decoder_prompt(prompt_tokens)
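# with "prefix" prompting, prompt_tokens are simply indices 0..decoder_prompt_length-1 expanded over the
# batch, and prompt_kv_list holds the per-layer key/value prefixes consumed via each layer's prompt_kv argument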
bs, slen = prev_output_tokens.size()
if alignment_layer is None:
alignment_layer = self.num_layers - 1
enc: Optional[Tensor] = None
padding_mask: Optional[Tensor] = None
if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
enc = encoder_out["encoder_out"][0]
assert (
enc.size()[1] == bs
), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
padding_mask = encoder_out["encoder_padding_mask"][0]
bsz, tgt_len = prev_output_tokens.shape
token_position_idx = utils.new_arange(prev_output_tokens)
tgt_pos_embed = self.embed_positions(token_position_idx)
if code_masks is not None and torch.any(code_masks):
image_position_idx = self.image_position_idx[:prev_output_tokens.size(1)].unsqueeze(0).expand(bsz, tgt_len)
tgt_pos_embed[code_masks] = self.embed_image_positions(image_position_idx)[code_masks]
# self attn position bias
self_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=False)
if code_masks is not None and torch.any(code_masks):
self_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=True)
self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks]
# cross attn position bias
src_pos_embed = encoder_out['position_embeddings'][0]
cross_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed)
if code_masks is not None and torch.any(code_masks):
cross_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed, use_image=True)
cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[code_masks]
if self.with_cls:
cross_abs_pos_bias = F.pad(cross_abs_pos_bias, (enc.shape[0] - cross_abs_pos_bias.shape[-1], 0, prev_output_tokens.shape[1] - cross_abs_pos_bias.shape[2], 0), "constant", 0)
cross_abs_pos_bias = cross_abs_pos_bias.reshape(-1, *cross_abs_pos_bias.size()[-2:])
all_prev_output_tokens = prev_output_tokens.clone()
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
cross_abs_pos_bias = cross_abs_pos_bias[:, -1:, :]
tgt_pos_embed = tgt_pos_embed[:, -1:, :]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.quant_noise is not None:
x = self.quant_noise(x)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if self.entangle_position_embedding is not None and not self.args.disable_entangle:
x += tgt_pos_embed
if self.layernorm_embedding is not None:
if code_masks is None or not code_masks.any() or not getattr(self, "code_layernorm_embedding", False):
x = self.layernorm_embedding(x)
elif code_masks is not None and code_masks.all():
x = self.code_layernorm_embedding(x)
else:
x[~code_masks] = self.layernorm_embedding(x[~code_masks])
x[code_masks] = self.code_layernorm_embedding(x[code_masks])
x = self.dropout_module(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
self_attn_padding_mask: Optional[Tensor] = None
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
if prompt_padding_mask is not None:
self_attn_padding_mask = torch.cat([prompt_padding_mask, self_attn_padding_mask], dim=1)
# decoder layers
attn: Optional[Tensor] = None
inner_states: List[Optional[Tensor]] = [x]
for idx, layer in enumerate(self.layers):
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
if self.args.decoder_prompt:
seq_len, prompt_len = x.size(0), prompt_tokens.size(1)
prompt_mask = torch.zeros([seq_len, prompt_len]).to(x.device)
self_attn_mask = torch.cat([prompt_mask, self_attn_mask], dim=1)
else:
self_attn_mask = None
self_attn_bias = self_abs_pos_bias.clone()
if code_masks is None or not code_masks.any():
self_attn_bias += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
elif code_masks is not None and code_masks.all():
self_attn_bias += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
else:
self_attn_bias[~code_masks] += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
self_attn_bias[code_masks] += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
self_attn_bias = self_attn_bias.reshape(-1, *self_attn_bias.size()[-2:])
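# flatten to (batch * num_heads, tgt_len, tgt_len); under incremental decoding only the bias row for the
# newest target position is kept just below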
if incremental_state is not None:
self_attn_bias = self_attn_bias[:, -1:, :]
if self.args.decoder_prompt:
if self.args.decoder_prompt_type != "prompt":
prompt_kv = prompt_kv_list[idx]
else:
if idx == 0:
prompt_kv = prompt_kv_list[idx]
else:
prompt_kv = None
else:
prompt_kv = None
x, layer_attn, _ = layer(
x,
enc,
padding_mask,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=bool((idx == alignment_layer)),
need_head_weights=bool((idx == alignment_layer)),
self_attn_bias=self_attn_bias,
cross_attn_bias=cross_abs_pos_bias,
prompt_kv=prompt_kv
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float().to(x)
if attn is not None:
if alignment_heads is not None:
attn = attn[:alignment_heads]
# average probabilities over heads
attn = attn.mean(dim=0)
if self.layer_norm is not None:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x, {"attn": [attn], "inner_states": inner_states}
def output_layer(self, features):
"""Project features to the vocabulary size."""
if self.adaptive_softmax is None:
# project back to size of vocabulary
return self.output_projection(features)
else:
return features
def max_positions(self):
"""Maximum output length supported by the decoder."""
return self.max_target_positions
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if (
self._future_mask.size(0) == 0
or (not self._future_mask.device == tensor.device)
or self._future_mask.size(0) < dim
):
self._future_mask = torch.triu(
utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
)
self._future_mask = self._future_mask.to(tensor)
return self._future_mask[:dim, :dim]
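# e.g. for dim == 3 the cached mask is
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# so position i may only attend to positions <= i; the buffer is grown lazily and sliced per call.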
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = "{}.embed_positions.weights".format(name)
if weights_key in state_dict:
del state_dict[weights_key]
state_dict[
"{}.embed_positions._float_tensor".format(name)
] = torch.FloatTensor(1)
if f"{name}.output_projection.weight" not in state_dict:
if self.share_input_output_embed:
embed_out_key = f"{name}.embed_tokens.weight"
else:
embed_out_key = f"{name}.embed_out"
if embed_out_key in state_dict:
state_dict[f"{name}.output_projection.weight"] = state_dict[
embed_out_key
]
if not self.share_input_output_embed:
del state_dict[embed_out_key]
for i in range(self.num_layers):
# update layer norms
self.layers[i].upgrade_state_dict_named(
state_dict, "{}.layers.{}".format(name, i)
)
prefix = name + "." if name != "" else ""
image_params = ["image_position_idx"]
for image_param in image_params:
state_dict[prefix + image_param] = self.state_dict()[image_param]
for param_name, param_tensor in self.state_dict().items():
if (prefix + param_name) not in state_dict:
state_dict[prefix + param_name] = self.state_dict()[param_name]
if len(state_dict["decoder.embed_image_positions.weight"]) < len(self.state_dict()["embed_image_positions.weight"]):
num_posids_to_add = len(self.state_dict()["embed_image_positions.weight"]) - len(state_dict["decoder.embed_image_positions.weight"])
embed_dim = state_dict["decoder.embed_image_positions.weight"].size(1)
new_pos_embed_to_add = torch.zeros(num_posids_to_add, embed_dim)
nn.init.normal_(new_pos_embed_to_add, mean=0, std=embed_dim ** -0.5)
new_pos_embed_to_add = new_pos_embed_to_add.to(
dtype=state_dict["decoder.embed_image_positions.weight"].dtype,
)
state_dict["decoder.embed_image_positions.weight"] = torch.cat(
[state_dict["decoder.embed_image_positions.weight"], new_pos_embed_to_add]
)
return state_dict
def Embedding(num_embeddings, embedding_dim, padding_idx=None, zero_init=False):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
if padding_idx is not None:
nn.init.constant_(m.weight[padding_idx], 0)
if zero_init:
nn.init.constant_(m.weight, 0)
return m
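# zero_init=True is used for the relative-position bias tables so that every layer starts from a zero bias;
# when padding_idx is also given, the padding row is (re-)zeroed either way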
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
if bias:
nn.init.constant_(m.bias, 0.0)
return m
@register_model_architecture("unify_transformer", "unify_transformer")
def base_architecture(args):
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.dropout = getattr(args, "dropout", 0.1)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.no_cross_attention = getattr(args, "no_cross_attention", False)
args.cross_self_attention = getattr(args, "cross_self_attention", False)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.encoder_prompt = getattr(args, "encoder_prompt", False)
args.encoder_prompt_length = getattr(args, "encoder_prompt_length", 100)
args.encoder_prompt_type = getattr(args, "encoder_prompt_type", "prefix")
args.encoder_prompt_projection = getattr(args, "encoder_prompt_projection", False)
args.encoder_prompt_dim = getattr(args, "encoder_prompt_dim", 2 * args.encoder_embed_dim)
args.decoder_prompt = getattr(args, "decoder_prompt", False)
args.decoder_prompt_length = getattr(args, "decoder_prompt_length", 100)
args.decoder_prompt_type = getattr(args, "decoder_prompt_type", "prefix")
args.decoder_prompt_projection = getattr(args, "decoder_prompt_projection", False)
args.decoder_prompt_dim = getattr(args, "decoder_prompt_dim", 2 * args.encoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)