|
|
|
|
|
|
|
|
|
|
|
import math |
|
import random |
|
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from fairseq import utils |
|
from fairseq.distributed import fsdp_wrap |
|
from fairseq.models import ( |
|
FairseqEncoder, |
|
FairseqEncoderDecoderModel, |
|
FairseqIncrementalDecoder, |
|
register_model, |
|
register_model_architecture, |
|
) |
|
from fairseq.modules import ( |
|
AdaptiveSoftmax, |
|
BaseLayer, |
|
FairseqDropout, |
|
LayerDropModuleList, |
|
LayerNorm, |
|
SinusoidalPositionalEmbedding, |
|
) |
|
from fairseq.modules.checkpoint_activations import checkpoint_wrapper |
|
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ |
|
from torch import Tensor |
|
|
|
from .unify_transformer_layer import TransformerEncoderLayer, TransformerDecoderLayer |
|
from .resnet import ResNet |
|
from .frozen_bn import FrozenBatchNorm2d |
|
|
|
|
|
|
|
from .encoders.timm_resnet import resnet101, resnet152, resnet50 |
|
|
|
from .encoders.resnext3d import ResNeXt3D, ResNeXtBottleneck |
|
|
|
|
|
from .encoders.pann import create_pann_model |
|
from data.audio_utils import AUDIO_CFG, dotdict |
|
|
|
DEFAULT_MAX_SOURCE_POSITIONS = 1024 |
|
DEFAULT_MAX_TARGET_POSITIONS = 1024 |
|
|
|
|
|
DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) |
|
|
|
|
def BatchNorm2d(out_chan, momentum=0.1, eps=1e-3): |
|
return nn.SyncBatchNorm.convert_sync_batchnorm( |
|
nn.BatchNorm2d(out_chan, momentum=momentum, eps=eps) |
|
) |
|
|
|
|
|
def make_token_bucket_position(bucket_size, max_position=DEFAULT_MAX_SOURCE_POSITIONS): |
|
context_pos = torch.arange(max_position, dtype=torch.long)[:, None] |
|
memory_pos = torch.arange(max_position, dtype=torch.long)[None, :] |
|
relative_pos = context_pos - memory_pos |
|
sign = torch.sign(relative_pos) |
|
mid = bucket_size // 2 |
|
abs_pos = torch.where((relative_pos<mid) & (relative_pos > -mid), mid-1, torch.abs(relative_pos)) |
|
log_pos = torch.ceil(torch.log(abs_pos/mid)/math.log((max_position-1)/mid) * (mid-1)) + mid |
|
log_pos = log_pos.int() |
|
bucket_pos = torch.where(abs_pos.le(mid), relative_pos, log_pos*sign).long() |
|
return bucket_pos + bucket_size - 1 |
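# Illustrative sketch (not executed): with e.g. bucket_size=256 and the default
# max_position=1024, every (query, key) offset is mapped to one of
# 2 * bucket_size - 1 relative-position buckets, matching the number of rows in
# token_rel_pos_table_list built by the encoder below.
#   >>> bucket = make_token_bucket_position(256)
#   >>> bucket.shape
#   torch.Size([1024, 1024])
#   >>> int(bucket.min()), int(bucket.max())
#   (0, 510)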
|
|
|
|
|
def make_image_bucket_position(bucket_size, num_relative_distance): |
|
coords_h = torch.arange(bucket_size) |
|
coords_w = torch.arange(bucket_size) |
|
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) |
|
coords_flatten = torch.flatten(coords, 1) |
|
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] |
|
relative_coords = relative_coords.permute(1, 2, 0).contiguous() |
|
relative_coords[:, :, 0] += bucket_size - 1 |
|
relative_coords[:, :, 1] += bucket_size - 1 |
|
relative_coords[:, :, 0] *= 2 * bucket_size - 1 |
|
relative_position_index = torch.zeros(size=(bucket_size * bucket_size + 1,) * 2, dtype=relative_coords.dtype) |
|
relative_position_index[1:, 1:] = relative_coords.sum(-1) |
|
relative_position_index[0, 0:] = num_relative_distance - 3 |
|
relative_position_index[0:, 0] = num_relative_distance - 2 |
|
relative_position_index[0, 0] = num_relative_distance - 1 |
|
return relative_position_index |
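# Illustrative sketch (not executed): the encoder below calls this with
# num_relative_distance = (2 * bucket_size - 1) ** 2 + 3, and the returned index
# matrix covers a bucket_size x bucket_size patch grid plus one extra slot whose
# row, column and self-position use the three reserved indices.
#   >>> bucket_size = 16
#   >>> num_rel_dis = (2 * bucket_size - 1) ** 2 + 3
#   >>> idx = make_image_bucket_position(bucket_size, num_rel_dis)
#   >>> idx.shape
#   torch.Size([257, 257])
#   >>> int(idx.max()) == num_rel_dis - 1
#   True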
|
|
|
|
|
class PromptEncoder(torch.nn.Module): |
|
r""" |
|
    Prompt encoder that generates prompt embeddings for prompt tuning, covering the prompt, prefix, instance and instruction variants.
|
""" |
|
|
|
def __init__( |
|
self, |
|
type, |
|
length, |
|
projection, |
|
embed_dim, |
|
proj_dim, |
|
layers, |
|
vocab_size): |
|
super().__init__() |
|
self.prefix_projection = projection |
|
|
|
        if type == "prefix":
            prompt_vocab_size = length
|
|
|
if self.prefix_projection: |
|
self.embedding = torch.nn.Embedding(prompt_vocab_size, embed_dim) |
|
self.trans = torch.nn.Sequential( |
|
torch.nn.Linear(embed_dim, proj_dim), |
|
torch.nn.ReLU(), |
|
torch.nn.Linear(proj_dim, layers * 2 * embed_dim) |
|
) |
|
else: |
|
if type == "prefix": |
|
self.embedding = torch.nn.Embedding( |
|
prompt_vocab_size, layers * 2 * embed_dim) |
|
|
|
def forward(self, prefix: torch.Tensor): |
|
if self.prefix_projection: |
|
prefix_tokens = self.embedding(prefix) |
|
past_key_values = self.trans(prefix_tokens) |
|
else: |
|
past_key_values = self.embedding(prefix) |
|
return past_key_values |
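    # Illustrative sketch (not executed), assuming the "prefix" configuration used in
    # this file; vocab_size is accepted but not used by this module. The output is
    # later reshaped into per-layer key/value prefixes (see get_encoder_prompt below).
    #   >>> enc = PromptEncoder(type="prefix", length=10, projection=False,
    #   ...                     embed_dim=768, proj_dim=512, layers=12, vocab_size=50000)
    #   >>> out = enc(torch.arange(10).unsqueeze(0))  # prompt ids of shape (1, 10)
    #   >>> out.shape                                 # layers * 2 * embed_dim = 18432
    #   torch.Size([1, 10, 18432])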
|
|
|
|
|
@register_model("unify_transformer") |
|
class TransformerModel(FairseqEncoderDecoderModel): |
|
""" |
|
    Transformer model from `"Attention Is All You Need" (Vaswani et al., 2017)
|
<https://arxiv.org/abs/1706.03762>`_. |
|
|
|
Args: |
|
encoder (TransformerEncoder): the encoder |
|
decoder (TransformerDecoder): the decoder |
|
|
|
The Transformer model provides the following named architectures and |
|
command-line arguments: |
|
|
|
.. argparse:: |
|
:ref: fairseq.models.transformer_parser |
|
:prog: |
|
""" |
|
|
|
def __init__(self, args, encoder, decoder): |
|
super().__init__(encoder, decoder) |
|
self.args = args |
|
self.supports_align_args = True |
|
|
|
|
|
|
|
@staticmethod |
|
def add_args(parser): |
|
"""Add model-specific arguments to the parser.""" |
|
|
|
parser.add_argument('--activation-fn', |
|
choices=utils.get_available_activation_fns(), |
|
help='activation function to use') |
|
parser.add_argument('--dropout', type=float, metavar='D', |
|
help='dropout probability') |
|
parser.add_argument('--attention-dropout', type=float, metavar='D', |
|
help='dropout probability for attention weights') |
|
parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D', |
|
help='dropout probability after activation in FFN.') |
|
parser.add_argument('--encoder-embed-path', type=str, metavar='STR', |
|
help='path to pre-trained encoder embedding') |
|
parser.add_argument('--encoder-embed-dim', type=int, metavar='N', |
|
help='encoder embedding dimension') |
|
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N', |
|
help='encoder embedding dimension for FFN') |
|
parser.add_argument('--encoder-layers', type=int, metavar='N', |
|
help='num encoder layers') |
|
parser.add_argument('--encoder-attention-heads', type=int, metavar='N', |
|
help='num encoder attention heads') |
|
parser.add_argument('--encoder-normalize-before', action='store_true', |
|
help='apply layernorm before each encoder block') |
|
parser.add_argument('--encoder-learned-pos', action='store_true', |
|
help='use learned positional embeddings in the encoder') |
|
parser.add_argument('--bitfit', default=False, action='store_true', |
|
help='use bitfit in the transformer') |
|
parser.add_argument('--freeze-encoder', action='store_true', |
|
help='freeze the parameters in the encoder') |
|
|
|
|
|
parser.add_argument('--adapter', action='store_true', |
|
help='use adapter in the model') |
|
parser.add_argument('--adapter-dim', type=int, metavar='N', |
|
help='adapter-down-dim') |
|
|
|
parser.add_argument('--adapter-type', type=str, metavar='STR', |
|
help='adapter-type') |
|
parser.add_argument('--unfreeze', action='store_true', |
|
help='unfreeze model when using adapters/prompts') |
|
|
|
parser.add_argument('--encoder-prompt', action='store_true', |
|
help='use prompt tuning in the encoder') |
|
parser.add_argument('--encoder-prompt-type', type=str, metavar='STR', |
|
choices=['prefix'], |
|
help='the type of prompt tuning') |
|
parser.add_argument('--encoder-prompt-projection', action='store_true', |
|
help='use prompt projection') |
|
        parser.add_argument('--encoder-prompt-length', type=int, metavar='N',
                            help='length of the encoder prompt')
|
parser.add_argument('--encoder-prompt-dim', type=int, metavar='N', |
|
help='encoder prompt dimension if use encoder prompt projection') |
|
parser.add_argument('--decoder-embed-path', type=str, metavar='STR', |
|
help='path to pre-trained decoder embedding') |
|
parser.add_argument('--decoder-embed-dim', type=int, metavar='N', |
|
help='decoder embedding dimension') |
|
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N', |
|
help='decoder embedding dimension for FFN') |
|
parser.add_argument('--decoder-layers', type=int, metavar='N', |
|
help='num decoder layers') |
|
parser.add_argument('--decoder-attention-heads', type=int, metavar='N', |
|
help='num decoder attention heads') |
|
parser.add_argument('--decoder-learned-pos', action='store_true', |
|
help='use learned positional embeddings in the decoder') |
|
parser.add_argument('--decoder-normalize-before', action='store_true', |
|
help='apply layernorm before each decoder block') |
|
        parser.add_argument('--decoder-output-dim', type=int, metavar='N',
                            help='decoder output dimension (extra linear layer '
                                 'if different from decoder embed dim)')
|
parser.add_argument('--freeze-decoder', action='store_true', |
|
help='freeze the parameters in the decoder') |
|
|
|
parser.add_argument('--decoder-prompt', action='store_true', |
|
help='use prompt tuning in the decoder') |
|
parser.add_argument('--decoder-prompt-type', type=str, metavar='STR', |
|
choices=['prefix'], |
|
help='the type of prompt tuning') |
|
        parser.add_argument('--decoder-prompt-length', type=int, metavar='N',
                            help='length of the decoder prompt')
|
parser.add_argument('--decoder-prompt-projection', action='store_true', |
|
help='use prompt projection') |
|
parser.add_argument('--decoder-prompt-dim', type=int, metavar='N', |
|
help='decoder prompt dimension if use decoder prompt projection') |
|
parser.add_argument('--share-decoder-input-output-embed', action='store_true', |
|
help='share decoder input and output embeddings') |
|
parser.add_argument('--share-all-embeddings', action='store_true', |
|
help='share encoder, decoder and output embeddings' |
|
' (requires shared dictionary and embed dim)') |
|
parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true', |
|
help='if set, disables positional embeddings (outside self attention)') |
|
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
|
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', |
|
help='sets adaptive softmax dropout for the tail projections') |
|
parser.add_argument('--layernorm-embedding', action='store_true', |
|
help='add layernorm to embedding') |
|
        parser.add_argument('--no-scale-embedding', action='store_true',
                            help="if True, don't scale embeddings")
|
parser.add_argument('--checkpoint-activations', action='store_true', |
|
help='checkpoint activations at each layer, which saves GPU ' |
|
'memory usage at the cost of some additional compute') |
|
        parser.add_argument('--offload-activations', action='store_true',
                            help='checkpoint activations at each layer, then offload them to CPU. Sets --checkpoint-activations.')
|
|
|
parser.add_argument('--no-cross-attention', default=False, action='store_true', |
|
help='do not perform cross-attention') |
|
parser.add_argument('--cross-self-attention', default=False, action='store_true', |
|
help='perform cross+self-attention') |
|
|
|
parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0, |
|
help='LayerDrop probability for encoder') |
|
parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0, |
|
help='LayerDrop probability for decoder') |
|
parser.add_argument('--encoder-layers-to-keep', default=None, |
|
help='which layers to *keep* when pruning as a comma-separated list') |
|
parser.add_argument('--decoder-layers-to-keep', default=None, |
|
help='which layers to *keep* when pruning as a comma-separated list') |
|
|
|
parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0, |
|
help='iterative PQ quantization noise at training time') |
|
parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8, |
|
help='block size of quantization noise at training time') |
|
parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0, |
|
help='scalar quantization noise and scalar quantization at training time') |
|
|
|
parser.add_argument( |
|
'--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP, |
|
help=( |
|
'minimum number of params for a layer to be wrapped with FSDP() when ' |
|
'training with --ddp-backend=fully_sharded. Smaller values will ' |
|
'improve memory efficiency, but may make torch.distributed ' |
|
'communication less efficient due to smaller input sizes. This option ' |
|
'is set to 0 (i.e., always wrap) when --checkpoint-activations or ' |
|
'--offload-activations are passed.' |
|
) |
|
) |
|
|
|
parser.add_argument('--resnet-drop-path-rate', type=float, |
|
help='resnet drop path rate') |
|
parser.add_argument('--encoder-drop-path-rate', type=float, |
|
help='encoder drop path rate') |
|
        parser.add_argument('--decoder-drop-path-rate', type=float,
                            help='decoder drop path rate')
|
|
|
parser.add_argument('--token-bucket-size', type=int, |
|
help='token bucket size') |
|
parser.add_argument('--image-bucket-size', type=int, |
|
help='image bucket size') |
|
|
|
parser.add_argument('--attn-scale-factor', type=float, |
|
help='attention scale factor') |
|
parser.add_argument('--freeze-resnet', action='store_true', |
|
help='freeze resnet') |
|
parser.add_argument('--freeze-encoder-embedding', action='store_true', |
|
help='freeze encoder token embedding') |
|
parser.add_argument('--freeze-decoder-embedding', action='store_true', |
|
help='freeze decoder token embedding') |
|
parser.add_argument('--add-type-embedding', action='store_true', |
|
help='add source/region/patch type embedding') |
|
        parser.add_argument('--add-mm-type-embedding', action='store_true',
                            help='add text/image/video/audio multimodal type embedding')
|
|
|
parser.add_argument('--interpolate-position', action='store_true', |
|
help='interpolate position') |
|
|
|
parser.add_argument('--resnet-type', choices=['resnet50', 'resnet101', 'resnet152'], |
|
help='resnet type') |
|
parser.add_argument('--resnet-model-path', type=str, metavar='STR', |
|
help='path to load resnet') |
|
parser.add_argument('--code-image-size', type=int, |
|
help='code image size') |
|
parser.add_argument('--patch-layernorm-embedding', action='store_true', |
|
help='add layernorm to patch embedding') |
|
parser.add_argument('--code-layernorm-embedding', action='store_true', |
|
help='add layernorm to code embedding') |
|
parser.add_argument('--entangle-position-embedding', action='store_true', |
|
help='entangle position embedding') |
|
parser.add_argument('--disable-entangle', action='store_true', |
|
help='disable entangle') |
|
parser.add_argument('--sync-bn', action='store_true', |
|
help='sync batchnorm') |
|
|
|
parser.add_argument('--scale-attn', action='store_true', |
|
help='scale attn') |
|
parser.add_argument('--scale-fc', action='store_true', |
|
help='scale fc') |
|
parser.add_argument('--scale-heads', action='store_true', |
|
help='scale heads') |
|
parser.add_argument('--scale-resids', action='store_true', |
|
help='scale resids') |
|
|
|
|
|
|
|
parser.add_argument('--image-encoder-name', type=str, metavar='STR', default='resnet', |
|
help='image_encoder_name') |
|
parser.add_argument('--return-hidden-state-vision', action='store_true', |
|
help='return_hidden_state_vision') |
|
|
|
parser.add_argument('--freeze-image-encoder', action='store_true', |
|
help='freeze_image_encoder') |
|
|
|
parser.add_argument('--nograd', action='store_true', |
|
help='nograd for vis encoder') |
|
|
|
parser.add_argument('--encoder-eval', action='store_true', |
|
help='vision encoder.eval()') |
|
|
|
parser.add_argument('--video-encoder-name', type=str, default=None, |
|
help='video_encoder_name') |
|
|
|
parser.add_argument('--video-model-path', type=str, default=None, |
|
help='video_model_path') |
|
parser.add_argument('--freeze-video-encoder', action='store_true', |
|
help='freeze_video_encoder') |
|
|
|
parser.add_argument('--sample-patch-num', type=int, default=None, |
|
help='num of tokens selected at random') |
|
parser.add_argument('--with-cls', action='store_true', |
|
help='with_cls') |
|
|
|
parser.add_argument('--sample-video-patch-num', type=int, default=None, |
|
help='num of tokens selected at random') |
|
|
|
parser.add_argument('--audio-encoder-name', type=str, default=None, |
|
help='audio_encoder_name') |
|
parser.add_argument('--audio-model-path', type=str, default=None, |
|
help='audio_model_path') |
|
parser.add_argument('--freeze-audio-encoder', action='store_true', |
|
help='freeze_audio_encoder') |
|
parser.add_argument('--fusion-type', type=str, default=None, |
|
help='fusion_type') |
|
parser.add_argument('--enable-fusion', action='store_true', |
|
help='enable_fusion') |
|
parser.add_argument('--mel-bins', type=int, default=64, |
|
help='mel_bins') |
|
parser.add_argument('--hop-size', type=int, default=480, |
|
help='hop_size') |
|
parser.add_argument('--sample-audio-patch-num', type=int, default=None, |
|
help='num of tokens selected at random') |
|
parser.add_argument('--fstride', type=int, default=10, |
|
help='fstride') |
|
parser.add_argument('--tstride', type=int, default=10, |
|
help='tstride') |
|
parser.add_argument('--input-tdim', type=int, default=1024, |
|
help='input_tdim') |
|
|
|
|
|
|
|
|
|
parser.add_argument('--progressive', action='store_true', |
|
help='progressive') |
|
parser.add_argument('--unfreeze-epoch-encoder', type=int, default=0, |
|
help='unfreeze_epoch_encoder') |
|
parser.add_argument('--unfreeze-epoch-decoder', type=int, default=0, |
|
help='unfreeze_epoch_decoder') |
|
|
|
parser.add_argument('--unfreeze-epoch-image', type=int, default=0, |
|
help='unfreeze_epoch_image') |
|
parser.add_argument('--unfreeze-epoch-video', type=int, default=0, |
|
help='unfreeze_epoch_video') |
|
parser.add_argument('--unfreeze-epoch-audio', type=int, default=0, |
|
help='unfreeze_epoch_audio') |
|
|
|
|
|
|
|
parser.add_argument('--only-linear-proj', action='store_true', |
|
help='only_linear_proj') |
|
parser.add_argument('--unfreeze-epoch', type=int, default=0, |
|
help='unfreeze_epoch') |
|
|
|
parser.add_argument('--freeze-perception', action='store_true', |
|
help='freeze_perception') |
|
|
|
parser.add_argument('--qk-norm', action='store_true', |
|
help='qk_norm') |
|
|
|
|
|
parser.add_argument('--layernorm-image-embedding', action='store_true', |
|
help='layernorm_image_embedding') |
|
|
|
parser.add_argument('--layernorm-video-embedding', action='store_true', |
|
help='layernorm_video_embedding') |
|
|
|
parser.add_argument('--layernorm-audio-embedding', action='store_true', |
|
help='layernorm_audio_embedding') |
|
|
|
parser.add_argument('--freeze-batchnorm', action='store_true', |
|
help='freeze_batchnorm') |
|
|
|
parser.add_argument('--freeze-resnet-video', action='store_true', |
|
help='freeze resnet video') |
|
|
|
@classmethod |
|
def build_model(cls, args, task): |
|
"""Build a new model instance.""" |
|
|
|
|
|
base_architecture(args) |
|
|
|
if args.encoder_layers_to_keep: |
|
args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) |
|
if args.decoder_layers_to_keep: |
|
args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) |
|
|
|
if getattr(args, "max_source_positions", None) is None: |
|
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS |
|
if getattr(args, "max_target_positions", None) is None: |
|
args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS |
|
|
|
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary |
|
|
|
if args.share_all_embeddings: |
|
if src_dict != tgt_dict: |
|
raise ValueError("--share-all-embeddings requires a joined dictionary") |
|
if args.encoder_embed_dim != args.decoder_embed_dim: |
|
raise ValueError( |
|
"--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" |
|
) |
|
if args.decoder_embed_path and ( |
|
args.decoder_embed_path != args.encoder_embed_path |
|
): |
|
raise ValueError( |
|
"--share-all-embeddings not compatible with --decoder-embed-path" |
|
) |
|
encoder_embed_tokens = cls.build_embedding( |
|
args, src_dict, args.encoder_embed_dim, args.encoder_embed_path |
|
) |
|
decoder_embed_tokens = encoder_embed_tokens |
|
args.share_decoder_input_output_embed = True |
|
else: |
|
encoder_embed_tokens = cls.build_embedding( |
|
args, src_dict, args.encoder_embed_dim, args.encoder_embed_path |
|
) |
|
decoder_embed_tokens = cls.build_embedding( |
|
args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path |
|
) |
|
if getattr(args, "freeze_encoder_embedding", False) or getattr( |
|
args, "encoder_prompt", False) or getattr(args, "decoder_prompt", False) or getattr(args, "adapter", False): |
|
encoder_embed_tokens.weight.requires_grad = False |
|
if getattr(args, "freeze_decoder_embedding", False) or getattr( |
|
args, "encoder_prompt", False) or getattr(args, "decoder_prompt", False) or getattr(args, "adapter", False): |
|
decoder_embed_tokens.weight.requires_grad = False |
|
if getattr(args, "offload_activations", False): |
|
args.checkpoint_activations = True |
|
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) |
|
decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) |
|
|
|
if getattr(args, "freeze_encoder", False): |
|
encoder.requires_grad_(False) |
|
|
|
|
|
encoder.embed_images.requires_grad_(True) |
|
encoder.image_proj.requires_grad_(True) |
|
encoder.embed_image_positions.requires_grad_(True) |
|
|
|
if encoder.layernorm_image_embedding is not None: |
|
encoder.layernorm_image_embedding.requires_grad_(True) |
|
print("freeze image LN") |
|
|
|
if hasattr(encoder, 'embed_videos'): |
|
encoder.embed_videos.requires_grad_(True) |
|
encoder.video_proj.requires_grad_(True) |
|
encoder.embed_video_positions.requires_grad_(True) |
|
|
|
if encoder.layernorm_video_embedding is not None: |
|
encoder.layernorm_video_embedding.requires_grad_(True) |
|
print("freeze audio LN") |
|
|
|
if getattr(args, "audio_encoder_name", False): |
|
encoder.embed_audios.requires_grad_(True) |
|
encoder.audio_proj.requires_grad_(True) |
|
encoder.embed_audio_positions.requires_grad_(True) |
|
if encoder.layernorm_audio_embedding is not None: |
|
encoder.layernorm_audio_embedding.requires_grad_(True) |
|
print("freeze audio LN") |
|
|
|
|
|
|
|
if getattr(args, "freeze_decoder", False): |
|
decoder.requires_grad_(False) |
|
|
|
|
|
if getattr(args, "encoder_prompt", False) or getattr( |
|
args, "decoder_prompt", False) or getattr( |
|
args, "adapter", False): |
|
if not getattr(args, "unfreeze", False): |
|
encoder.requires_grad_(False) |
|
decoder.requires_grad_(False) |
|
if getattr(args, "encoder_prompt", False): |
|
encoder.encoder_prompt_encoder.requires_grad_(True) |
|
if getattr(args, "decoder_prompt", False): |
|
decoder.decoder_prompt_encoder.requires_grad_(True) |
|
if getattr(args, "adapter", False): |
|
for idx, layer in enumerate(encoder.layers): |
|
layer.adapter.requires_grad_(True) |
|
for idx, layer in enumerate(decoder.layers): |
|
layer.adapter.requires_grad_(True) |
|
if getattr(args, "freeze_image_encoder", False): |
|
encoder.embed_images.requires_grad_(False) |
|
|
|
if getattr(args, "freeze_video_encoder", False) and hasattr(encoder, 'embed_videos'): |
|
encoder.embed_videos.requires_grad_(False) |
|
|
|
if getattr(args, "freeze_audio_encoder", False) and hasattr(encoder, 'embed_audios'): |
|
encoder.embed_audios.requires_grad_(False) |
|
|
|
if getattr(args, "freeze_batchnorm", False): |
|
|
|
for n, p in encoder.named_parameters(): |
|
if 'bn' in n: |
|
p.requires_grad = False |
|
print("freeze:", n) |
|
|
|
if hasattr(encoder, 'embed_audios'): |
|
if hasattr(encoder.embed_audios, 'logmel_extractor'): |
|
for param in encoder.embed_audios.logmel_extractor.parameters(): |
|
param.requires_grad = False |
|
|
|
if getattr(args, "encoder_eval", False): |
|
print("Set encoders.eval()") |
|
encoder.embed_images.eval() |
|
|
|
if getattr(args, "video_encoder_name", None): |
|
encoder.embed_videos.eval() |
|
|
|
|
|
if not args.share_all_embeddings: |
|
min_params_to_wrap = getattr( |
|
args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP |
|
) |
|
|
|
encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap) |
|
decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap) |
|
return cls(args, encoder, decoder) |
|
|
|
@classmethod |
|
def build_embedding(cls, args, dictionary, embed_dim, path=None): |
|
num_embeddings = len(dictionary) |
|
padding_idx = dictionary.pad() |
|
|
|
args.vocab_size = num_embeddings |
|
emb = Embedding(num_embeddings, embed_dim, padding_idx) |
|
|
|
if path: |
|
embed_dict = utils.parse_embedding(path) |
|
utils.load_embedding(embed_dict, dictionary, emb) |
|
return emb |
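    # Illustrative note: `Embedding` is the weight-initialising helper defined elsewhere
    # in this file (normal init, zeroed padding row). When a pre-trained embedding path
    # is given, fairseq's parse_embedding/load_embedding only overwrite rows whose
    # tokens appear in that file; all other rows keep their random initialisation.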
|
|
|
@classmethod |
|
def build_encoder(cls, args, src_dict, embed_tokens): |
|
return TransformerEncoder(args, src_dict, embed_tokens) |
|
|
|
@classmethod |
|
def build_decoder(cls, args, tgt_dict, embed_tokens): |
|
return TransformerDecoder( |
|
args, |
|
tgt_dict, |
|
embed_tokens, |
|
no_encoder_attn=getattr(args, "no_cross_attention", False), |
|
) |
|
|
|
|
|
|
|
def forward( |
|
self, |
|
src_tokens, |
|
src_lengths, |
|
prev_output_tokens, |
|
return_all_hiddens: bool = True, |
|
features_only: bool = False, |
|
alignment_layer: Optional[int] = None, |
|
alignment_heads: Optional[int] = None, |
|
): |
|
""" |
|
Run the forward pass for an encoder-decoder model. |
|
|
|
Copied from the base class, but without ``**kwargs``, |
|
which are not supported by TorchScript. |
|
""" |
|
encoder_out = self.encoder( |
|
src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens |
|
) |
|
decoder_out = self.decoder( |
|
prev_output_tokens, |
|
encoder_out=encoder_out, |
|
features_only=features_only, |
|
alignment_layer=alignment_layer, |
|
alignment_heads=alignment_heads, |
|
src_lengths=src_lengths, |
|
return_all_hiddens=return_all_hiddens, |
|
) |
|
return decoder_out |
|
|
|
|
|
|
|
|
|
@torch.jit.export |
|
def get_normalized_probs( |
|
self, |
|
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], |
|
log_probs: bool, |
|
sample: Optional[Dict[str, Tensor]] = None, |
|
): |
|
"""Get normalized probabilities (or log probs) from a net's output.""" |
|
return self.get_normalized_probs_scriptable(net_output, log_probs, sample) |
|
|
|
|
|
class TransformerEncoder(FairseqEncoder): |
|
""" |
|
Transformer encoder consisting of *args.encoder_layers* layers. Each layer |
|
is a :class:`TransformerEncoderLayer`. |
|
|
|
Args: |
|
args (argparse.Namespace): parsed command-line arguments |
|
dictionary (~fairseq.data.Dictionary): encoding dictionary |
|
embed_tokens (torch.nn.Embedding): input embedding |
|
""" |
|
|
|
def __init__(self, args, dictionary, embed_tokens): |
|
self.args = args |
|
super().__init__(dictionary) |
|
self.register_buffer("version", torch.Tensor([3])) |
|
|
|
|
|
self.freeze_image_encoder = getattr(args, "freeze_image_encoder", False) |
|
        self.freeze_video_encoder = getattr(args, "freeze_video_encoder", False)
        self.freeze_audio_encoder = getattr(args, "freeze_audio_encoder", False)
|
self.nograd = getattr(args, "nograd", False) |
|
self.num_frames = getattr(args, 'num_frames', 4) |
|
|
|
self.sample_patch_num = getattr(args, "sample_patch_num", 196) |
|
|
|
self.sample_audio_patch_num = getattr(args, "sample_audio_patch_num", self.sample_patch_num) |
|
self.sample_video_patch_num = getattr(args, "sample_video_patch_num", self.sample_patch_num) |
|
|
|
self.with_cls = getattr(args, "with_cls", False) |
|
print("self.sample_patch_num", self.sample_patch_num) |
|
print("self.sample_audio_patch_num", self.sample_audio_patch_num) |
|
print("self.sample_video_patch_num", self.sample_video_patch_num) |
|
print("self.with_cls", self.with_cls) |
|
|
|
if getattr(args, "encoder_prompt", False): |
|
self.encoder_prompt_encoder = PromptEncoder( |
|
type=args.encoder_prompt_type, |
|
length=args.encoder_prompt_length, |
|
projection=args.encoder_prompt_projection, |
|
embed_dim=args.encoder_embed_dim, |
|
proj_dim=args.encoder_prompt_dim, |
|
layers=args.encoder_layers, |
|
vocab_size=args.vocab_size) |
|
self.encoder_dropout = nn.Dropout(p=0.2) |
|
|
|
self.dropout_module = FairseqDropout( |
|
args.dropout, module_name=self.__class__.__name__ |
|
) |
|
self.encoder_layerdrop = args.encoder_layerdrop |
|
|
|
embed_dim = embed_tokens.embedding_dim |
|
self.padding_idx = embed_tokens.padding_idx |
|
self.max_source_positions = args.max_source_positions |
|
self.num_attention_heads = args.encoder_attention_heads |
|
|
|
self.embed_tokens = embed_tokens |
|
|
|
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) |
|
|
|
if getattr(args, "layernorm_embedding", False): |
|
self.layernorm_embedding = LayerNorm(embed_dim) |
|
else: |
|
self.layernorm_embedding = None |
|
|
|
self.mm_type_embedding = False |
|
if getattr(args, "add_mm_type_embedding", False): |
|
self.type_embedding = Embedding(4, embed_dim, padding_idx=None) |
|
self.mm_type_embedding = True |
|
elif getattr(args, "add_type_embedding", False): |
|
self.type_embedding = Embedding(2, embed_dim, padding_idx=None) |
|
else: |
|
self.type_embedding = None |
|
|
|
norm_layer = None |
|
norm_layer_video = None |
|
if getattr(args, "sync_bn", False): |
|
norm_layer = BatchNorm2d |
|
else: |
|
if getattr(args, "freeze_resnet", False): |
|
norm_layer = FrozenBatchNorm2d |
|
print("Frozen image bn", norm_layer) |
|
if getattr(args, "freeze_resnet_video", False): |
|
norm_layer_video = FrozenBatchNorm2d |
|
print("Frozen video bn", norm_layer_video) |
|
|
|
|
|
|
|
if getattr(args, "image_encoder_name", None): |
|
|
|
if 'timm_resnet' in args.image_encoder_name: |
|
if args.resnet_type == 'resnet101': |
|
self.embed_images = resnet101(norm_layer=norm_layer) |
|
elif args.resnet_type == 'resnet152': |
|
self.embed_images = resnet152(norm_layer=norm_layer) |
|
elif args.resnet_type == 'resnet50': |
|
self.embed_images = resnet50(norm_layer=norm_layer) |
|
else: |
|
raise NotImplementedError |
|
self.image_proj = Linear(2048, embed_dim) |
|
else: |
|
if args.resnet_type == 'resnet101': |
|
self.embed_images = ResNet([3, 4, 23], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate) |
|
elif args.resnet_type == 'resnet152': |
|
self.embed_images = ResNet([3, 8, 36], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate) |
|
elif args.resnet_type == 'resnet50': |
|
self.embed_images = ResNet([3, 4, 6], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate) |
|
else: |
|
raise NotImplementedError |
|
self.image_proj = Linear(1024, embed_dim) |
|
else: |
|
if args.resnet_type == 'resnet101': |
|
self.embed_images = ResNet([3, 4, 23], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate) |
|
elif args.resnet_type == 'resnet152': |
|
self.embed_images = ResNet([3, 8, 36], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate) |
|
elif args.resnet_type == 'resnet50': |
|
self.embed_images = ResNet([3, 4, 6], norm_layer=norm_layer, drop_path_rate=args.resnet_drop_path_rate) |
|
else: |
|
raise NotImplementedError |
|
self.image_proj = Linear(1024, embed_dim) |
|
|
|
if getattr(args, "layernorm_image_embedding", False): |
|
self.layernorm_image_embedding = LayerNorm(embed_dim) |
|
else: |
|
self.layernorm_image_embedding = None |
|
|
|
|
|
|
|
if getattr(args, "video_encoder_name", None): |
|
print("Loading: ", args.video_encoder_name) |
|
patch_frame_size = getattr(args, 'patch_frame_size', 224) |
|
num_frames = getattr(args, 'num_frames', 4) |
|
pretrained_model = getattr(args, "video_model_path", None) |
|
|
|
if 'resnext' in args.video_encoder_name: |
|
|
|
if 'all' in args.video_encoder_name: |
|
                    if 'resnext101' in args.video_encoder_name:
|
self.embed_videos = ResNeXt3D(ResNeXtBottleneck, [3, 4, 23, 3], norm_layer=norm_layer_video) |
|
elif 'resnext152' in args.video_encoder_name: |
|
self.embed_videos = ResNeXt3D(ResNeXtBottleneck, [3, 8, 36, 3], norm_layer=norm_layer_video) |
|
elif 'resnext50' in args.video_encoder_name: |
|
self.embed_videos = ResNeXt3D(ResNeXtBottleneck, [3, 4, 6, 3], norm_layer=norm_layer_video) |
|
else: |
|
raise NotImplementedError |
|
vis_dim = 2048 |
|
else: |
|
if args.video_encoder_name == 'resnext101': |
|
self.embed_videos = ResNeXt3D(ResNeXtBottleneck, [3, 4, 23]) |
|
elif args.video_encoder_name == 'resnext152': |
|
self.embed_videos = ResNeXt3D(ResNeXtBottleneck, [3, 8, 36]) |
|
elif args.video_encoder_name == 'resnext50': |
|
self.embed_videos = ResNeXt3D(ResNeXtBottleneck, [3, 4, 6]) |
|
else: |
|
raise NotImplementedError |
|
vis_dim = 1024 |
|
|
|
self.embed_video_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim) |
|
|
|
if pretrained_model: |
|
print("load pretrained_model {}".format(pretrained_model)) |
|
state_dict = torch.load(pretrained_model)['state_dict'] |
|
if 'module' in list(state_dict.keys())[0]: |
|
from collections import OrderedDict |
|
new_state_dict = OrderedDict() |
|
for k, v in state_dict.items(): |
|
name = k[7:] |
|
new_state_dict[name]=v |
|
state_dict = new_state_dict |
|
|
|
msg = self.embed_videos.load_state_dict(state_dict, strict=False) |
|
print(msg) |
|
|
|
else: |
|
                raise NotImplementedError
|
|
|
|
|
self.video_proj = Linear(vis_dim, embed_dim) |
|
|
|
|
|
if getattr(args, "layernorm_video_embedding", False): |
|
self.layernorm_video_embedding = LayerNorm(embed_dim) |
|
else: |
|
self.layernorm_video_embedding = None |
|
|
|
|
|
if getattr(args, "audio_encoder_name", None): |
|
print("Loading: ", args.audio_encoder_name) |
|
|
|
|
|
|
|
pretrained_audio_model = getattr(args, "audio_model_path", None) |
|
audio_cfg = getattr(args, "audio_cfg", None) |
|
audio_cfg = audio_cfg if audio_cfg is not None else AUDIO_CFG |
|
|
|
audio_cfg = dotdict(audio_cfg) |
|
enable_fusion = getattr(args, "enable_fusion", False) |
|
fusion_type = getattr(args, "fusion_type", None) |
|
|
|
audio_cfg['mel_bins'] = getattr(args, "mel_bins", 64) |
|
audio_cfg['hop_size'] = getattr(args, "hop_size", 480) |
|
|
|
if 'pann' in args.audio_encoder_name: |
|
|
|
if 'cnn6' in args.audio_encoder_name: |
|
audio_cfg['model_name'] = 'Cnn6' |
|
audio_dim = 512 |
|
elif 'cnn10' in args.audio_encoder_name: |
|
audio_cfg['model_name'] = 'Cnn10' |
|
audio_dim = 1024 |
|
elif 'cnn14' in args.audio_encoder_name: |
|
audio_cfg['model_name'] = 'Cnn14' |
|
|
|
audio_dim = 2048 |
|
else: |
|
raise NotImplementedError |
|
|
|
self.embed_audios = create_pann_model(audio_cfg, enable_fusion, fusion_type) |
|
|
|
|
|
self.embed_audio_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim) |
|
|
|
|
|
|
|
if pretrained_audio_model: |
|
print("load pretrained_model {}".format(pretrained_audio_model)) |
|
state_dict = torch.load(pretrained_audio_model) |
|
if 'model' in state_dict: |
|
state_dict = state_dict['model'] |
|
|
|
if 'module' in list(state_dict.keys())[0]: |
|
from collections import OrderedDict |
|
new_state_dict = OrderedDict() |
|
for k, v in state_dict.items(): |
|
name = k.replace('module.', '') |
|
new_state_dict[name]=v |
|
state_dict = new_state_dict |
|
|
|
if 'sed_model.' in list(state_dict.keys())[0] or 'module.' in list(state_dict.keys())[0]: |
|
from collections import OrderedDict |
|
new_state_dict = OrderedDict() |
|
for k, v in state_dict.items(): |
|
name = k.replace('sed_model.', '').replace('module.', '').replace('audio_branch.', '') |
|
new_state_dict[name]=v |
|
state_dict = new_state_dict |
|
|
|
if audio_cfg['mel_bins'] != 64: |
|
|
|
del_keys = [] |
|
for k, v in state_dict.items(): |
|
if 'logmel_extractor' in k or 'bn0' in k: |
|
del_keys.append(k) |
|
|
|
for k in del_keys: |
|
del state_dict[k] |
|
|
|
msg = self.embed_audios.load_state_dict(state_dict, strict=False) |
|
print(msg) |
|
|
|
else: |
|
raise NotImplementedError |
|
|
|
self.audio_proj = Linear(audio_dim, embed_dim) |
|
|
|
|
|
if getattr(args, "layernorm_audio_embedding", False): |
|
self.layernorm_audio_embedding = LayerNorm(embed_dim) |
|
else: |
|
self.layernorm_audio_embedding = None |
|
|
|
|
|
if getattr(args, "resnet_model_path", None): |
|
print("load resnet {}".format(args.resnet_model_path)) |
|
resnet_state_dict = torch.load(self.args.resnet_model_path) |
|
msg = self.embed_images.load_state_dict(resnet_state_dict, strict=False) |
|
print(msg) |
|
|
|
if getattr(args, "patch_layernorm_embedding", False): |
|
self.patch_layernorm_embedding = LayerNorm(embed_dim) |
|
else: |
|
self.patch_layernorm_embedding = None |
|
|
|
self.embed_positions = Embedding(args.max_source_positions + 2, embed_dim) |
|
self.embed_image_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim) |
|
self.pos_ln = LayerNorm(embed_dim) |
|
self.image_pos_ln = LayerNorm(embed_dim) |
|
self.pos_scaling = float(embed_dim / args.encoder_attention_heads * args.attn_scale_factor) ** -0.5 |
|
self.pos_q_linear = nn.Linear(embed_dim, embed_dim) |
|
self.pos_k_linear = nn.Linear(embed_dim, embed_dim) |
|
|
|
if not args.adaptive_input and args.quant_noise_pq > 0: |
|
self.quant_noise = apply_quant_noise_( |
|
nn.Linear(embed_dim, embed_dim, bias=False), |
|
args.quant_noise_pq, |
|
args.quant_noise_pq_block_size, |
|
) |
|
else: |
|
self.quant_noise = None |
|
|
|
if self.encoder_layerdrop > 0.0: |
|
self.layers = LayerDropModuleList(p=self.encoder_layerdrop) |
|
else: |
|
self.layers = nn.ModuleList([]) |
|
|
|
dpr = [x.item() for x in torch.linspace(0, args.encoder_drop_path_rate, args.encoder_layers)] |
|
self.layers.extend( |
|
[self.build_encoder_layer(args, drop_path_rate=dpr[i]) for i in range(args.encoder_layers)] |
|
) |
|
self.num_layers = len(self.layers) |
|
|
|
if args.encoder_normalize_before: |
|
self.layer_norm = LayerNorm(embed_dim) |
|
else: |
|
self.layer_norm = None |
|
|
|
token_bucket_size = args.token_bucket_size |
|
token_num_rel_dis = 2 * token_bucket_size - 1 |
|
token_rp_bucket = make_token_bucket_position(token_bucket_size) |
|
self.token_rel_pos_table_list = nn.ModuleList( |
|
[Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.encoder_layers)] |
|
) |
|
|
|
image_bucket_size = args.image_bucket_size |
|
image_num_rel_dis = (2 * image_bucket_size - 1) * (2 * image_bucket_size - 1) + 3 |
|
image_rp_bucket = make_image_bucket_position(image_bucket_size, image_num_rel_dis) |
|
self.image_rel_pos_table_list = nn.ModuleList( |
|
[Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.encoder_layers)] |
|
) |
|
|
|
self.patch_image_size = args.patch_image_size |
|
self.orig_patch_image_size = args.orig_patch_image_size |
|
|
|
self.register_buffer("token_rp_bucket", token_rp_bucket) |
|
self.register_buffer("image_rp_bucket", image_rp_bucket) |
|
self.entangle_position_embedding = args.entangle_position_embedding |
|
|
|
|
|
def build_encoder_layer(self, args, drop_path_rate=0.0): |
|
        layer = TransformerEncoderLayer(args, drop_path_rate=drop_path_rate,
                                        use_adapter=getattr(args, "adapter", False),
                                        adapter_dim=getattr(args, "adapter_dim", 200),
                                        adapter_type=getattr(args, "adapter_type", 'UN'))
|
checkpoint = getattr(args, "checkpoint_activations", False) |
|
if checkpoint: |
|
offload_to_cpu = getattr(args, "offload_activations", False) |
|
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) |
|
|
|
|
|
min_params_to_wrap = ( |
|
getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) |
|
if not checkpoint else 0 |
|
) |
|
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) |
|
return layer |
|
|
|
def get_rel_pos_bias(self, x, idx): |
|
seq_len = x.size(1) |
|
rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] |
|
values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight) |
|
values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1) |
|
values = values.permute([0, 3, 1, 2]) |
|
return values.contiguous() |
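    # Shape note (illustrative): for src of shape (bsz, seq_len) this returns a
    # (bsz, num_heads, seq_len, seq_len) bias that forward_scriptable adds to the
    # attention logits of layer `idx` via self_attn_bias.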
|
|
|
def get_image_rel_pos_bias(self, image_position_ids, idx): |
|
bsz, seq_len = image_position_ids.shape |
|
rp_bucket_size = self.image_rp_bucket.size(1) |
|
|
|
rp_bucket = self.image_rp_bucket.unsqueeze(0).expand( |
|
bsz, rp_bucket_size, rp_bucket_size |
|
).gather(1, image_position_ids[:, :, None].expand(bsz, seq_len, rp_bucket_size) |
|
).gather(2, image_position_ids[:, None, :].expand(bsz, seq_len, seq_len)) |
|
values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight) |
|
values = values.permute(0, 3, 1, 2) |
|
return values |
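    # Shape note (illustrative): image_position_ids is (bsz, seq_len); the two gathers
    # select the rows and columns of the precomputed image_rp_bucket for exactly those
    # grid positions, so the result is (bsz, num_heads, seq_len, seq_len), mirroring
    # get_rel_pos_bias above but with per-sample position ids.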
|
|
|
def get_patch_audios_info(self, patch_images, sample_patch_num, device): |
|
|
|
if self.nograd and self.freeze_audio_encoder: |
|
with torch.no_grad(): |
|
image_embed = self.embed_audios(patch_images) |
|
else: |
|
image_embed = self.embed_audios(patch_images) |
|
|
|
|
|
h, w = image_embed.shape[-2:] |
|
image_num_patches = h * w |
|
sh = int(math.ceil(math.sqrt(image_num_patches))) |
|
sw = sh |
|
image_embed = image_embed.flatten(2).transpose(1, 2) |
|
|
|
image_padding_mask = patch_images.new_zeros((patch_images.size(0), image_num_patches)).bool() |
|
|
|
image_position_idx = torch.arange(sw).unsqueeze(0).expand(sh, sw) + \ |
|
torch.arange(sh).unsqueeze(1) * self.args.image_bucket_size + 1 |
|
image_position_idx = image_position_idx.reshape(-1).to(device)[:image_num_patches] |
|
image_position_ids = image_position_idx[None, :].expand(patch_images.size(0), image_num_patches) |
|
|
|
if sample_patch_num is not None and sample_patch_num < image_num_patches: |
|
patch_orders = [ |
|
random.sample(range(image_num_patches), k=sample_patch_num) |
|
for _ in range(patch_images.size(0)) |
|
] |
|
patch_orders = torch.LongTensor(patch_orders).to(device) |
|
image_embed = image_embed.gather( |
|
1, patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2)) |
|
) |
|
image_num_patches = sample_patch_num |
|
image_padding_mask = image_padding_mask.gather(1, patch_orders) |
|
image_position_ids = image_position_ids.gather(1, patch_orders) |
|
|
|
image_pos_embed = self.embed_audio_positions(image_position_ids) |
|
|
|
|
|
image_embed = self.audio_proj(image_embed.type(self.audio_proj.weight.dtype)) |
|
|
|
if self.layernorm_audio_embedding is not None: |
|
image_embed = self.layernorm_audio_embedding(image_embed) |
|
|
|
return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed |
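    # Shape note (illustrative): the PANN backbone returns a feature map whose last two
    # dimensions are treated as an h x w grid, so this audio path reuses the image-style
    # variable names; apart from using embed_audio_positions / audio_proj /
    # layernorm_audio_embedding, the flattening and optional random patch sampling are
    # the same as in get_patch_images_info below.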
|
|
|
|
|
def get_patch_videos_info(self, patch_images, sample_patch_num, device): |
|
if self.nograd and self.freeze_video_encoder: |
|
with torch.no_grad(): |
|
_, image_embed = self.embed_videos(patch_images) |
|
else: |
|
_, image_embed = self.embed_videos(patch_images) |
|
|
|
l, h, w = image_embed.shape[-3:] |
|
image_num_patches = h * w * l |
|
image_embed = image_embed.flatten(2).transpose(1, 2) |
|
numframes = l |
|
|
|
image_padding_mask = patch_images.new_zeros((patch_images.size(0), image_num_patches)).bool() |
|
|
|
image_position_idx = torch.arange(w).unsqueeze(0).expand(h, w) + \ |
|
torch.arange(h).unsqueeze(1) * self.args.image_bucket_size + 1 |
|
image_position_idx = image_position_idx.unsqueeze(0).expand(numframes, -1, -1) |
|
image_position_idx = image_position_idx.reshape(-1).to(device) |
|
image_position_ids = image_position_idx[None, :].expand(patch_images.size(0), image_num_patches) |
|
|
|
if sample_patch_num is not None and sample_patch_num < image_num_patches: |
|
patch_orders = [ |
|
random.sample(range(image_num_patches), k=sample_patch_num) |
|
for _ in range(patch_images.size(0)) |
|
] |
|
patch_orders = torch.LongTensor(patch_orders).to(device) |
|
image_embed = image_embed.gather( |
|
1, patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2)) |
|
) |
|
image_num_patches = sample_patch_num |
|
image_padding_mask = image_padding_mask.gather(1, patch_orders) |
|
image_position_ids = image_position_ids.gather(1, patch_orders) |
|
|
|
image_pos_embed = self.embed_video_positions(image_position_ids) |
|
|
|
|
|
image_embed = self.video_proj(image_embed) |
|
|
|
if self.layernorm_video_embedding is not None: |
|
image_embed = self.layernorm_video_embedding(image_embed) |
|
|
|
return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed |
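    # Shape note (illustrative): the 3D backbone returns a (bsz, C, l, h, w) feature
    # volume; flatten(2).transpose(1, 2) turns it into (bsz, l * h * w, C), so each of
    # the l frames contributes an h x w grid of patch tokens that all share the same
    # 2D position ids, repeated once per frame.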
|
|
|
def get_patch_images_info(self, patch_images, sample_patch_num, device): |
|
if self.nograd and self.freeze_image_encoder: |
|
with torch.no_grad(): |
|
image_embed = self.embed_images(patch_images) |
|
else: |
|
image_embed = self.embed_images(patch_images) |
|
|
|
if isinstance(image_embed, tuple): |
|
_, image_embed = image_embed |
|
|
|
h, w = image_embed.shape[-2:] |
|
image_num_patches = h * w |
|
image_embed = image_embed.flatten(2).transpose(1, 2) |
|
image_padding_mask = patch_images.new_zeros((patch_images.size(0), image_num_patches)).bool() |
|
image_position_idx = torch.arange(w).unsqueeze(0).expand(h, w) + \ |
|
torch.arange(h).unsqueeze(1) * self.args.image_bucket_size + 1 |
|
image_position_idx = image_position_idx.view(-1).to(device) |
|
image_position_ids = image_position_idx[None, :].expand(patch_images.size(0), image_num_patches) |
|
|
|
if sample_patch_num is not None and sample_patch_num < image_num_patches: |
|
patch_orders = [ |
|
random.sample(range(image_num_patches), k=sample_patch_num) |
|
for _ in range(patch_images.size(0)) |
|
] |
|
patch_orders = torch.LongTensor(patch_orders).to(device) |
|
image_embed = image_embed.gather( |
|
1, patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2)) |
|
) |
|
image_num_patches = sample_patch_num |
|
image_padding_mask = image_padding_mask.gather(1, patch_orders) |
|
image_position_ids = image_position_ids.gather(1, patch_orders) |
|
orig_num_patches = (self.orig_patch_image_size // 16) ** 2 |
|
        orig_hw = self.orig_patch_image_size // 16
|
if getattr(self.args, "interpolate_position", False) and image_num_patches > orig_num_patches: |
|
old_image_position_ids = torch.arange(orig_hw).unsqueeze(0).expand(orig_hw, orig_hw) + \ |
|
torch.arange(orig_hw).unsqueeze(1) * self.args.image_bucket_size + 1 |
|
old_image_position_ids = old_image_position_ids.to(device) |
|
old_image_pos_embed = self.embed_image_positions(old_image_position_ids) |
|
old_image_pos_embed = old_image_pos_embed.reshape(1, orig_hw, orig_hw, -1).permute(0, 3, 1, 2) |
|
image_pos_embed = F.interpolate(old_image_pos_embed, size=(h, w), mode='bilinear') |
|
image_pos_embed = image_pos_embed.permute(0, 2, 3, 1).reshape(1, image_num_patches, -1) |
|
image_pos_embed = image_pos_embed.expand(patch_images.size(0), -1, -1) |
|
else: |
|
image_pos_embed = self.embed_image_positions(image_position_ids) |
|
|
|
|
|
|
|
image_embed = self.image_proj(image_embed) |
|
if self.layernorm_image_embedding is not None: |
|
image_embed = self.layernorm_image_embedding(image_embed) |
|
|
|
return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed |
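    # Shape note (illustrative): a (bsz, 3, H, W) image batch becomes a (bsz, C, h, w)
    # ResNet feature map, is flattened to (bsz, h * w, C) and projected to
    # (bsz, h * w, embed_dim); the padding mask, position ids and position embeddings
    # returned with it follow the same (bsz, h * w[, embed_dim]) layout used when
    # concatenating patch tokens with text tokens.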
|
|
|
def get_encoder_prompt(self, prompt_tokens): |
|
past_key_values = self.encoder_prompt_encoder(prompt_tokens) |
|
bsz, seqlen, _ = past_key_values.shape |
|
past_key_values = past_key_values.view( |
|
bsz, |
|
seqlen, |
|
(self.args.encoder_layers) * 2, |
|
self.args.encoder_attention_heads, |
|
self.args.encoder_embed_dim // self.args.encoder_attention_heads, |
|
) |
|
past_key_values = self.encoder_dropout(past_key_values) |
|
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) |
|
return past_key_values |
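    # Shape note (illustrative): with prompt length L, encoder_layers N, attention heads
    # H and head dim D = embed_dim // H, the (bsz, L, N * 2 * embed_dim) prompt encoder
    # output is reshaped to (N * 2, bsz, H, L, D); split(2) then yields one
    # (2, bsz, H, L, D) key/value prefix per encoder layer.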
|
|
|
def forward_embedding( |
|
self, |
|
src_tokens, |
|
image_embed: Optional[torch.Tensor] = None, |
|
image_embed_2: Optional[torch.Tensor] = None, |
|
token_embedding: Optional[torch.Tensor] = None, |
|
pos_embed: Optional[torch.Tensor] = None, |
|
image_pos_embed: Optional[torch.Tensor] = None, |
|
image_pos_embed_2: Optional[torch.Tensor] = None, |
|
patch_types: Optional[torch.Tensor] = None, |
|
): |
|
|
|
if token_embedding is None: |
|
token_embedding = self.embed_tokens(src_tokens) |
|
x = embed = self.embed_scale * token_embedding |
|
if self.entangle_position_embedding and pos_embed is not None: |
|
x += pos_embed |
|
if self.type_embedding is not None: |
|
x += self.type_embedding(src_tokens.new_zeros(x.size()[:2])) |
|
if self.layernorm_embedding is not None: |
|
x = self.layernorm_embedding(x) |
|
x = self.dropout_module(x) |
|
if self.quant_noise is not None: |
|
x = self.quant_noise(x) |
|
|
|
|
|
if image_embed is not None: |
|
|
|
image_x = image_embed = self.embed_scale * image_embed |
|
if self.entangle_position_embedding and image_pos_embed is not None: |
|
image_x += image_pos_embed[:, -image_x.shape[1]:, :] |
|
if self.type_embedding is not None: |
|
if self.mm_type_embedding: |
|
mm_type = patch_types.unsqueeze(1).to(src_tokens.device) + 1 |
|
image_x += self.type_embedding(mm_type) |
|
else: |
|
image_x += self.type_embedding(src_tokens.new_ones(image_x.size()[:2])) |
|
if self.patch_layernorm_embedding is not None: |
|
image_x = self.patch_layernorm_embedding(image_x) |
|
image_x = self.dropout_module(image_x) |
|
if self.quant_noise is not None: |
|
image_x = self.quant_noise(image_x) |
|
x = torch.cat([image_x, x], dim=1) |
|
embed = torch.cat([image_embed, embed], dim=1) |
|
|
|
if image_embed_2 is not None: |
|
assert self.type_embedding is not None |
|
|
|
image_x_2 = image_embed_2 = self.embed_scale * image_embed_2 |
|
if self.entangle_position_embedding and image_pos_embed_2 is not None: |
|
image_x_2 += image_pos_embed_2[:, -image_x_2.shape[1]:, :] |
|
if self.type_embedding is not None: |
|
image_x_2 += self.type_embedding(src_tokens.new_full(image_x_2.size()[:2], fill_value=2)) |
|
if self.patch_layernorm_embedding is not None: |
|
image_x_2 = self.patch_layernorm_embedding(image_x_2) |
|
image_x_2 = self.dropout_module(image_x_2) |
|
if self.quant_noise is not None: |
|
image_x_2 = self.quant_noise(image_x_2) |
|
x = torch.cat([image_x_2, x], dim=1) |
|
embed = torch.cat([image_embed_2, embed], dim=1) |
|
|
|
return x, embed |
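    # Ordering note (from the code above): patch features are prepended to the text
    # embedding, so with both optional streams present the sequence reads
    # [image_2 patches | image/video/audio patches | text tokens]; the padding mask and
    # positional embeddings assembled in forward_scriptable follow the same order.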
|
|
|
def forward( |
|
self, |
|
src_tokens, |
|
src_lengths, |
|
patch_images: Optional[torch.Tensor] = None, |
|
patch_images_2: Optional[torch.Tensor] = None, |
|
patch_masks: Optional[torch.Tensor] = None, |
|
code_masks: Optional[torch.Tensor] = None, |
|
return_all_hiddens: bool = False, |
|
token_embeddings: Optional[torch.Tensor] = None, |
|
sample_patch_num: Optional[int] = None, |
|
patch_videos: Optional[int] = None, |
|
patch_types: Optional[torch.Tensor] = None, |
|
patch_audios: Optional[torch.Tensor] = None, |
|
): |
|
""" |
|
Args: |
|
src_tokens (LongTensor): tokens in the source language of shape |
|
`(batch, src_len)` |
|
src_lengths (torch.LongTensor): lengths of each source sentence of |
|
shape `(batch)` |
|
return_all_hiddens (bool, optional): also return all of the |
|
intermediate hidden states (default: False). |
|
token_embeddings (torch.Tensor, optional): precomputed embeddings |
|
default `None` will recompute embeddings |
|
|
|
Returns: |
|
dict: |
|
- **encoder_out** (Tensor): the last encoder layer's output of |
|
shape `(src_len, batch, embed_dim)` |
|
- **encoder_padding_mask** (ByteTensor): the positions of |
|
padding elements of shape `(batch, src_len)` |
|
- **encoder_embedding** (Tensor): the (scaled) embedding lookup |
|
of shape `(batch, src_len, embed_dim)` |
|
- **encoder_states** (List[Tensor]): all intermediate |
|
hidden states of shape `(src_len, batch, embed_dim)`. |
|
Only populated if *return_all_hiddens* is True. |
|
""" |
|
return self.forward_scriptable(src_tokens, |
|
src_lengths, |
|
patch_images, |
|
patch_images_2, |
|
patch_masks, |
|
return_all_hiddens, |
|
token_embeddings, |
|
self.sample_patch_num, |
|
patch_videos=patch_videos, |
|
patch_types=patch_types, |
|
patch_audios=patch_audios, |
|
sample_audio_patch_num=self.sample_audio_patch_num, |
|
sample_video_patch_num=self.sample_video_patch_num) |
|
|
|
|
|
|
|
|
|
|
|
def forward_scriptable( |
|
self, |
|
src_tokens, |
|
src_lengths, |
|
patch_images: Optional[torch.Tensor] = None, |
|
patch_images_2: Optional[torch.Tensor] = None, |
|
patch_masks: Optional[torch.Tensor] = None, |
|
return_all_hiddens: bool = False, |
|
token_embeddings: Optional[torch.Tensor] = None, |
|
sample_patch_num: Optional[int] = None, |
|
patch_videos: Optional[int] = None, |
|
patch_types: Optional[torch.Tensor] = None, |
|
patch_audios: Optional[int] = None, |
|
sample_audio_patch_num: Optional[int] = None, |
|
sample_video_patch_num: Optional[int] = None, |
|
): |
|
""" |
|
Args: |
|
src_tokens (LongTensor): tokens in the source language of shape |
|
`(batch, src_len)` |
|
src_lengths (torch.LongTensor): lengths of each source sentence of |
|
shape `(batch)` |
|
return_all_hiddens (bool, optional): also return all of the |
|
intermediate hidden states (default: False). |
|
token_embeddings (torch.Tensor, optional): precomputed embeddings |
|
default `None` will recompute embeddings |
|
|
|
Returns: |
|
dict: |
|
- **encoder_out** (Tensor): the last encoder layer's output of |
|
shape `(src_len, batch, embed_dim)` |
|
- **encoder_padding_mask** (ByteTensor): the positions of |
|
padding elements of shape `(batch, src_len)` |
|
- **encoder_embedding** (Tensor): the (scaled) embedding lookup |
|
of shape `(batch, src_len, embed_dim)` |
|
- **encoder_states** (List[Tensor]): all intermediate |
|
hidden states of shape `(src_len, batch, embed_dim)`. |
|
Only populated if *return_all_hiddens* is True. |
|
""" |
|
prompt_tokens = None |
|
prompt_padding_mask = None |
|
prompt_kv_list = None |
|
if self.args.encoder_prompt: |
|
bsz, seq_len = src_tokens.shape[0], src_tokens.shape[1] |
|
            if self.args.encoder_prompt_type == "prefix":
|
prompt_tokens = torch.arange( |
|
0, self.args.encoder_prompt_length).to( |
|
src_tokens.device) |
|
prompt_tokens = prompt_tokens.unsqueeze(0).expand(bsz, -1) |
|
prompt_padding_mask = torch.zeros_like(prompt_tokens).to(prompt_tokens.device) |
|
prompt_kv_list = self.get_encoder_prompt(prompt_tokens) |
|
image_embed = None |
|
image_embed_2 = None |
|
image_pos_embed = None |
|
image_pos_embed_2 = None |
|
num_image_tokens = None |
|
|
|
if sample_audio_patch_num is None: |
|
sample_audio_patch_num = sample_patch_num |
|
|
|
if sample_video_patch_num is None: |
|
sample_video_patch_num = sample_patch_num |
|
|
|
if patch_images is not None: |
|
|
|
if patch_types is not None: |
|
video_idx = patch_types==1 |
|
image_idx = patch_types==0 |
|
audio_idx = patch_types==2 |
|
|
|
image_idx_ = (patch_types==0).nonzero()[:, 0] |
|
video_idx_ = (patch_types==1).nonzero()[:, 0] |
|
audio_idx_ = (patch_types==2).nonzero()[:, 0] |
|
|
|
                image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = None, None, None, None, None
|
|
|
if torch.any(image_idx).item(): |
|
image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \ |
|
self.get_patch_images_info(patch_images, sample_patch_num, src_tokens.device) |
|
|
|
ids_merge = image_idx_ |
|
|
|
if torch.any(video_idx).item(): |
|
video_embed, image_num_patches, video_padding_mask, video_position_ids, video_pos_embed = \ |
|
self.get_patch_videos_info(patch_videos[video_idx], sample_video_patch_num, src_tokens.device) |
|
|
|
if image_embed is not None: |
|
ids_merge = torch.cat((ids_merge, video_idx_), dim=0).long() |
|
|
|
image_embed = torch.cat((image_embed, video_embed), dim=0) |
|
bs, L, D = image_embed.shape |
|
image_padding_mask = torch.cat((image_padding_mask, video_padding_mask), dim=0) |
|
image_position_ids = torch.cat((image_position_ids, video_position_ids), dim=0) |
|
image_pos_embed = torch.cat((image_pos_embed, video_pos_embed), dim=0) |
|
|
|
image_embed = torch.gather(image_embed, dim=0, index=ids_merge[:, None, None].repeat(1, L, D)) |
|
image_padding_mask = torch.gather(image_padding_mask, dim=0, index=ids_merge[:, None].repeat(1, image_padding_mask.shape[1])) |
|
image_position_ids = torch.gather(image_position_ids, dim=0, index=ids_merge[:, None].repeat(1, image_position_ids.shape[1])) |
|
image_pos_embed = torch.gather(image_pos_embed, dim=0, index=ids_merge[:, None, None].repeat(1, image_pos_embed.shape[1], D)) |
|
|
|
else: |
|
image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \ |
|
video_embed, image_num_patches, video_padding_mask, video_position_ids, video_pos_embed |
|
ids_merge = video_idx_ |
|
|
|
                if torch.any(audio_idx).item():
|
audio_embed, image_num_patches, audio_padding_mask, audio_position_ids, audio_pos_embed = \ |
|
self.get_patch_audios_info(patch_audios[audio_idx], sample_audio_patch_num, src_tokens.device) |
|
|
|
if image_embed is not None: |
|
|
|
ids_merge = torch.cat((ids_merge, audio_idx_), dim=0).long() |
|
|
|
|
|
image_embed = torch.cat((image_embed, audio_embed), dim=0) |
|
bs, L, D = image_embed.shape |
|
image_padding_mask = torch.cat((image_padding_mask, audio_padding_mask), dim=0) |
|
image_position_ids = torch.cat((image_position_ids, audio_position_ids), dim=0) |
|
image_pos_embed = torch.cat((image_pos_embed, audio_pos_embed), dim=0) |
|
|
|
image_embed = torch.gather(image_embed, dim=0, index=ids_merge[:, None, None].repeat(1, L, D)) |
|
image_padding_mask = torch.gather(image_padding_mask, dim=0, index=ids_merge[:, None].repeat(1, image_padding_mask.shape[1])) |
|
image_position_ids = torch.gather(image_position_ids, dim=0, index=ids_merge[:, None].repeat(1, image_position_ids.shape[1])) |
|
image_pos_embed = torch.gather(image_pos_embed, dim=0, index=ids_merge[:, None, None].repeat(1, image_pos_embed.shape[1], D)) |
|
|
|
else: |
|
image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \ |
|
audio_embed, image_num_patches, audio_padding_mask, audio_position_ids, audio_pos_embed |
|
|
|
else: |
|
image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \ |
|
self.get_patch_images_info(patch_images, sample_patch_num, src_tokens.device) |
|
|
|
image_padding_mask[~patch_masks] = True |
|
|
|
num_image_tokens = image_num_patches
|
if patch_images_2 is not None: |
|
image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \ |
|
self.get_patch_images_info(patch_images_2, sample_patch_num, src_tokens.device) |
|
image_padding_mask_2[~patch_masks] = True |
|
num_image_tokens += image_num_patches_2
|
|
|
encoder_padding_mask = src_tokens.eq(self.padding_idx) |
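# Patch tokens are prepended to the text tokens, so their padding masks are concatenated
# in front of the token padding mask (and likewise for the second image stream).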
|
if patch_images is not None: |
|
encoder_padding_mask = torch.cat([image_padding_mask, encoder_padding_mask], dim=1) |
|
if patch_images_2 is not None: |
|
encoder_padding_mask = torch.cat([image_padding_mask_2, encoder_padding_mask], dim=1) |
|
has_pads = (src_tokens.device.type == "xla" or encoder_padding_mask.any()) |
|
|
|
pos_embed = self.embed_positions(utils.new_arange(src_tokens)) |
|
x, encoder_embedding = self.forward_embedding( |
|
src_tokens, image_embed, image_embed_2, token_embeddings, |
|
pos_embed, image_pos_embed, image_pos_embed_2, patch_types=patch_types |
|
) |
|
|
|
|
|
if has_pads: |
|
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) |
|
|
|
|
|
x = x.transpose(0, 1) |
|
|
|
pos_embed = self.pos_ln(pos_embed) |
|
if patch_images is not None: |
|
image_pos_embed = self.image_pos_ln(image_pos_embed) |
|
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) |
|
if patch_images_2 is not None: |
|
image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2) |
|
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) |
|
|
|
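# Absolute position bias: position embeddings are projected into per-head queries and keys,
# and their scaled dot product is added to the attention logits of every layer.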
pos_q = self.pos_q_linear(pos_embed).view( |
|
pos_embed.size(0), pos_embed.size(1), self.num_attention_heads, -1 |
|
).transpose(1, 2) * self.pos_scaling |
|
pos_k = self.pos_k_linear(pos_embed).view( |
|
pos_embed.size(0), pos_embed.size(1), self.num_attention_heads, -1 |
|
).transpose(1, 2) |
|
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) |
|
|
|
encoder_states = [] |
|
|
|
if return_all_hiddens: |
|
encoder_states.append(x) |
|
|
|
if prompt_padding_mask is not None: |
|
encoder_padding_mask = torch.cat([prompt_padding_mask, encoder_padding_mask], dim=1) |
|
|
|
if self.with_cls: |
|
offset_rel_pos = 1 |
|
abs_pos_bias = F.pad(abs_pos_bias, (1, 0, 1, 0), "constant", 0) |
|
else: |
|
offset_rel_pos = 0 |
|
|
|
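# Per layer, relative position biases are added on top of the absolute bias: the text block
# of the attention bias gets the token relative-position table, and the patch block gets the
# image relative-position table.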
for idx, layer in enumerate(self.layers): |
|
self_attn_bias = abs_pos_bias.clone() |
|
self_attn_bias[:, :, -src_tokens.size(1):, -src_tokens.size(1):] += self.get_rel_pos_bias(src_tokens, idx) |
|
if patch_images_2 is not None: |
|
self_attn_bias[:, :, offset_rel_pos:image_num_patches_2, offset_rel_pos:image_num_patches_2] += \ |
|
self.get_image_rel_pos_bias(image_position_ids_2, idx) |
|
self_attn_bias[:, :, offset_rel_pos+image_num_patches_2:image_num_patches_2+image_num_patches, offset_rel_pos+image_num_patches_2:image_num_patches_2+image_num_patches] += \ |
|
self.get_image_rel_pos_bias(image_position_ids, idx) |
|
elif patch_images is not None: |
|
self_attn_bias[:, :, offset_rel_pos:x.size(0) - src_tokens.size(1), offset_rel_pos:x.size(0) - src_tokens.size(1)] += \ |
|
self.get_image_rel_pos_bias(image_position_ids, idx) |
|
|
|
self_attn_bias = self_attn_bias.reshape(-1, self_attn_bias.size(2), self_attn_bias.size(2)) |
|
|
|
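# Prompt key/values: prefix-style prompting injects a key/value pair at every layer, while
# the plain "prompt" type only injects key/values at the first layer.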
if self.args.encoder_prompt: |
|
if self.args.encoder_prompt_type != "prompt": |
|
prompt_kv = prompt_kv_list[idx] |
|
else: |
|
if idx == 0: |
|
prompt_kv = prompt_kv_list[idx] |
|
else: |
|
prompt_kv = None |
|
else: |
|
prompt_kv = None |
|
x = layer(x, encoder_padding_mask=encoder_padding_mask if has_pads else None,
          self_attn_bias=self_attn_bias, prompt_kv=prompt_kv, num_image_tokens=num_image_tokens)
|
if return_all_hiddens: |
|
assert encoder_states is not None |
|
encoder_states.append(x) |
|
|
|
if self.layer_norm is not None: |
|
x = self.layer_norm(x) |
|
if self.args.encoder_prompt: |
|
encoder_padding_mask = encoder_padding_mask[:, prompt_tokens.size(1):] |
|
|
|
|
|
|
|
|
|
return { |
|
"encoder_out": [x], |
|
"encoder_padding_mask": [encoder_padding_mask], |
|
"encoder_embedding": [], |
|
"encoder_states": encoder_states, |
|
"src_tokens": [], |
|
"src_lengths": [], |
|
"position_embeddings": [pos_embed], |
|
} |
|
|
|
@torch.jit.export |
|
def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): |
|
""" |
|
Reorder encoder output according to *new_order*. |
|
|
|
Args: |
|
encoder_out: output from the ``forward()`` method |
|
new_order (LongTensor): desired order |
|
|
|
Returns: |
|
*encoder_out* rearranged according to *new_order* |
|
""" |
|
if len(encoder_out["encoder_out"]) == 0: |
|
new_encoder_out = [] |
|
else: |
|
new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] |
|
if len(encoder_out["encoder_padding_mask"]) == 0: |
|
new_encoder_padding_mask = [] |
|
else: |
|
new_encoder_padding_mask = [ |
|
encoder_out["encoder_padding_mask"][0].index_select(0, new_order) |
|
] |
|
if len(encoder_out["encoder_embedding"]) == 0: |
|
new_encoder_embedding = [] |
|
else: |
|
new_encoder_embedding = [ |
|
encoder_out["encoder_embedding"][0].index_select(0, new_order) |
|
] |
|
|
|
if len(encoder_out["src_tokens"]) == 0: |
|
new_src_tokens = [] |
|
else: |
|
new_src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)] |
|
|
|
if len(encoder_out["src_lengths"]) == 0: |
|
new_src_lengths = [] |
|
else: |
|
new_src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)] |
|
|
|
if len(encoder_out["position_embeddings"]) == 0: |
|
new_position_embeddings = [] |
|
else: |
|
new_position_embeddings = [(encoder_out["position_embeddings"][0]).index_select(0, new_order)] |
|
|
|
encoder_states = encoder_out["encoder_states"] |
|
if len(encoder_states) > 0: |
|
for idx, state in enumerate(encoder_states): |
|
encoder_states[idx] = state.index_select(1, new_order) |
|
|
|
return { |
|
"encoder_out": new_encoder_out, |
|
"encoder_padding_mask": new_encoder_padding_mask, |
|
"encoder_embedding": new_encoder_embedding, |
|
"encoder_states": encoder_states, |
|
"src_tokens": new_src_tokens, |
|
"src_lengths": new_src_lengths, |
|
"position_embeddings": new_position_embeddings, |
|
} |
|
|
|
def max_positions(self):
    """Maximum input length supported by the encoder."""
    return self.max_source_positions
|
|
|
def upgrade_state_dict_named(self, state_dict, name): |
|
"""Upgrade a (possibly old) state dict for new versions of fairseq.""" |
|
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): |
|
weights_key = "{}.embed_positions.weights".format(name) |
|
if weights_key in state_dict: |
|
print("deleting {0}".format(weights_key)) |
|
del state_dict[weights_key] |
|
state_dict[ |
|
"{}.embed_positions._float_tensor".format(name) |
|
] = torch.FloatTensor(1) |
|
for i in range(self.num_layers): |
|
|
|
self.layers[i].upgrade_state_dict_named( |
|
state_dict, "{}.layers.{}".format(name, i) |
|
) |
|
|
|
prefix = name + "." if name != "" else "" |
|
for param_name, param_tensor in self.state_dict().items(): |
|
if (prefix + param_name) not in state_dict: |
|
state_dict[prefix + param_name] = self.state_dict()[param_name] |
|
|
|
if len(state_dict["encoder.embed_image_positions.weight"]) < len(self.state_dict()["embed_image_positions.weight"]): |
|
num_posids_to_add = len(self.state_dict()["embed_image_positions.weight"]) - len(state_dict["encoder.embed_image_positions.weight"]) |
|
embed_dim = state_dict["encoder.embed_image_positions.weight"].size(1) |
|
new_pos_embed_to_add = torch.zeros(num_posids_to_add, embed_dim) |
|
nn.init.normal_(new_pos_embed_to_add, mean=0, std=embed_dim ** -0.5) |
|
new_pos_embed_to_add = new_pos_embed_to_add.to( |
|
dtype=state_dict["encoder.embed_image_positions.weight"].dtype, |
|
) |
|
state_dict["encoder.embed_image_positions.weight"] = torch.cat( |
|
[state_dict["encoder.embed_image_positions.weight"], new_pos_embed_to_add] |
|
) |
|
return state_dict |
|
|
|
|
|
class TransformerDecoder(FairseqIncrementalDecoder): |
|
""" |
|
Transformer decoder consisting of *args.decoder_layers* layers. Each layer |
|
is a :class:`TransformerDecoderLayer`. |
|
|
|
Args: |
|
args (argparse.Namespace): parsed command-line arguments |
|
dictionary (~fairseq.data.Dictionary): decoding dictionary |
|
embed_tokens (torch.nn.Embedding): output embedding |
|
no_encoder_attn (bool, optional): whether to attend to encoder outputs |
|
(default: False). |
|
""" |
|
|
|
def __init__( |
|
self, |
|
args, |
|
dictionary, |
|
embed_tokens, |
|
no_encoder_attn=False, |
|
output_projection=None, |
|
): |
|
self.args = args |
|
super().__init__(dictionary) |
|
self.register_buffer("version", torch.Tensor([3])) |
|
self._future_mask = torch.empty(0) |
|
self.with_cls = getattr(args, "with_cls", False) |
|
|
|
if getattr(args, "decoder_prompt", False): |
|
self.decoder_prompt_encoder = PromptEncoder( |
|
type=args.decoder_prompt_type, |
|
length=args.decoder_prompt_length, |
|
projection=args.decoder_prompt_projection, |
|
embed_dim=args.decoder_embed_dim, |
|
proj_dim=args.decoder_prompt_dim, |
|
layers=args.decoder_layers, |
|
vocab_size=args.vocab_size) |
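# Note: the dropout applied to the generated decoder prompt key/values is hard-coded to 0.2.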
|
self.decoder_dropout = nn.Dropout(p=0.2) |
|
|
|
self.dropout_module = FairseqDropout( |
|
args.dropout, module_name=self.__class__.__name__ |
|
) |
|
self.decoder_layerdrop = args.decoder_layerdrop |
|
self.share_input_output_embed = args.share_decoder_input_output_embed |
|
self.num_attention_heads = args.decoder_attention_heads |
|
|
|
input_embed_dim = embed_tokens.embedding_dim |
|
embed_dim = args.decoder_embed_dim |
|
self.embed_dim = embed_dim |
|
self.output_embed_dim = args.decoder_output_dim |
|
|
|
self.padding_idx = embed_tokens.padding_idx |
|
self.max_target_positions = args.max_target_positions |
|
|
|
self.embed_tokens = embed_tokens |
|
|
|
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) |
|
|
|
if not args.adaptive_input and args.quant_noise_pq > 0: |
|
self.quant_noise = apply_quant_noise_( |
|
nn.Linear(embed_dim, embed_dim, bias=False), |
|
args.quant_noise_pq, |
|
args.quant_noise_pq_block_size, |
|
) |
|
else: |
|
self.quant_noise = None |
|
|
|
self.project_in_dim = ( |
|
Linear(input_embed_dim, embed_dim, bias=False) |
|
if embed_dim != input_embed_dim |
|
else None |
|
) |
|
|
|
if getattr(args, "layernorm_embedding", False): |
|
self.layernorm_embedding = LayerNorm(embed_dim) |
|
else: |
|
self.layernorm_embedding = None |
|
|
|
self.window_size = args.code_image_size // 8 |
|
|
|
self.embed_positions = Embedding(args.max_target_positions + 2, embed_dim) |
|
self.embed_image_positions = Embedding(args.image_bucket_size ** 2 + 1, embed_dim) |
|
self.pos_ln = LayerNorm(embed_dim) |
|
self.image_pos_ln = LayerNorm(embed_dim) |
|
self.pos_scaling = float(embed_dim / self.num_attention_heads * args.attn_scale_factor) ** -0.5 |
|
self.self_pos_q_linear = nn.Linear(embed_dim, embed_dim) |
|
self.self_pos_k_linear = nn.Linear(embed_dim, embed_dim) |
|
self.cross_pos_q_linear = nn.Linear(embed_dim, embed_dim) |
|
self.cross_pos_k_linear = nn.Linear(embed_dim, embed_dim) |
|
|
|
if getattr(args, "code_layernorm_embedding", False): |
|
self.code_layernorm_embedding = LayerNorm(embed_dim) |
|
else: |
|
self.code_layernorm_embedding = None |
|
|
|
self.cross_self_attention = getattr(args, "cross_self_attention", False) |
|
|
|
if self.decoder_layerdrop > 0.0: |
|
self.layers = LayerDropModuleList(p=self.decoder_layerdrop) |
|
else: |
|
self.layers = nn.ModuleList([]) |
|
|
|
dpr = [x.item() for x in torch.linspace(0, args.decoder_drop_path_rate, args.decoder_layers)] |
|
self.layers.extend( |
|
[ |
|
self.build_decoder_layer(args, no_encoder_attn, drop_path_rate=dpr[i]) |
|
for i in range(args.decoder_layers) |
|
] |
|
) |
|
self.num_layers = len(self.layers) |
|
|
|
if args.decoder_normalize_before: |
|
self.layer_norm = LayerNorm(embed_dim) |
|
else: |
|
self.layer_norm = None |
|
|
|
self.project_out_dim = ( |
|
Linear(embed_dim, self.output_embed_dim, bias=False) |
|
if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights |
|
else None |
|
) |
|
|
|
self.adaptive_softmax = None |
|
self.output_projection = output_projection |
|
if self.output_projection is None: |
|
self.build_output_projection(args, dictionary, embed_tokens) |
|
|
|
token_bucket_size = args.token_bucket_size |
|
token_num_rel_dis = 2 * token_bucket_size - 1 |
|
token_rp_bucket = make_token_bucket_position(token_bucket_size) |
|
self.token_rel_pos_table_list = nn.ModuleList( |
|
[Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.decoder_layers)] |
|
) |
|
|
|
image_bucket_size = args.image_bucket_size |
|
image_num_rel_dis = (2 * image_bucket_size - 1) * (2 * image_bucket_size - 1) + 3 |
|
image_rp_bucket = make_image_bucket_position(image_bucket_size, image_num_rel_dis) |
|
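# Position ids for image code tokens: a window_size x window_size grid offset by the image
# bucket size, with index 0 reserved; the trailing constants (1024 repeated 769 times)
# appear to be hard-coded padding for longer code sequences and are kept as-is.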
image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \ |
|
torch.arange(self.window_size).unsqueeze(1) * image_bucket_size + 1 |
|
image_position_idx = torch.cat([torch.tensor([0]), image_position_idx.view(-1)]) |
|
image_position_idx = torch.cat([image_position_idx, torch.tensor([1024] * 769)]) |
|
self.image_rel_pos_table_list = nn.ModuleList( |
|
[Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in range(args.decoder_layers)] |
|
) |
|
|
|
self.register_buffer("token_rp_bucket", token_rp_bucket) |
|
self.register_buffer("image_rp_bucket", image_rp_bucket) |
|
self.register_buffer("image_position_idx", image_position_idx) |
|
self.entangle_position_embedding = args.entangle_position_embedding |
|
|
|
def get_decoder_prompt(self, prompt_tokens): |
|
past_key_values = self.decoder_prompt_encoder(prompt_tokens) |
|
bsz, seqlen, _ = past_key_values.shape |
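# Reshape to (bsz, seq, 2 * layers, heads, head_dim), permute to
# (2 * layers, bsz, heads, seq, head_dim), and split into per-layer (key, value) pairs.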
|
past_key_values = past_key_values.view( |
|
bsz, |
|
seqlen, |
|
self.args.decoder_layers * 2, |
|
self.args.decoder_attention_heads, |
|
self.args.decoder_embed_dim // self.args.decoder_attention_heads, |
|
) |
|
past_key_values = self.decoder_dropout(past_key_values) |
|
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) |
|
return past_key_values |
|
|
|
def build_output_projection(self, args, dictionary, embed_tokens): |
|
if args.adaptive_softmax_cutoff is not None: |
|
self.adaptive_softmax = AdaptiveSoftmax( |
|
len(dictionary), |
|
self.output_embed_dim, |
|
utils.eval_str_list(args.adaptive_softmax_cutoff, type=int), |
|
dropout=args.adaptive_softmax_dropout, |
|
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, |
|
factor=args.adaptive_softmax_factor, |
|
tie_proj=args.tie_adaptive_proj, |
|
) |
|
elif self.share_input_output_embed: |
|
self.output_projection = nn.Linear( |
|
self.embed_tokens.weight.shape[1], |
|
self.embed_tokens.weight.shape[0], |
|
bias=False, |
|
) |
|
self.output_projection.weight = self.embed_tokens.weight |
|
else: |
|
self.output_projection = nn.Linear( |
|
self.output_embed_dim, len(dictionary), bias=False |
|
) |
|
nn.init.normal_( |
|
self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5 |
|
) |
|
num_base_layers = getattr(args, "base_layers", 0) |
|
for i in range(num_base_layers): |
|
self.layers.insert(((i+1) * args.decoder_layers) // (num_base_layers + 1), BaseLayer(args)) |
|
|
|
def build_decoder_layer(self, args, no_encoder_attn=False, drop_path_rate=0.0): |
|
layer = TransformerDecoderLayer(args, no_encoder_attn, drop_path_rate=drop_path_rate,
                                use_adapter=getattr(args, "adapter", False),
                                adapter_dim=getattr(args, "adapter_dim", 200))
|
checkpoint = getattr(args, "checkpoint_activations", False) |
|
if checkpoint: |
|
offload_to_cpu = getattr(args, "offload_activations", False) |
|
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) |
|
|
|
|
|
min_params_to_wrap = ( |
|
getattr(args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP) |
|
if not checkpoint else 0 |
|
) |
|
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) |
|
return layer |
|
|
|
def get_rel_pos_bias(self, x, idx): |
|
seq_len = x.size(1) |
|
rp_bucket = self.token_rp_bucket[:seq_len, :seq_len] |
|
values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight) |
|
values = values.permute([2, 0, 1]) |
|
return values.contiguous() |
|
|
|
def get_image_rel_pos_bias(self, x, idx): |
|
seq_len = x.size(1) |
|
image_position_idx = self.image_position_idx[:seq_len] |
|
rp_bucket = self.image_rp_bucket[image_position_idx][:, image_position_idx] |
|
values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight) |
|
values = values.permute(2, 0, 1) |
|
return values |
|
|
|
def get_pos_info(self, tokens, tgt_pos_embed, src_pos_embed=None, use_image=False): |
|
batch_size = tokens.size(0) |
|
tgt_len = tokens.size(1) |
|
tgt_pos_embed = self.image_pos_ln(tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed) |
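# With src_pos_embed the bias is for cross-attention (queries from target positions, keys
# from source positions); otherwise it is the self-attention bias over target positions.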
|
if src_pos_embed is not None: |
|
src_len = src_pos_embed.size(1) |
|
pos_q = self.cross_pos_q_linear(tgt_pos_embed).view( |
|
batch_size, tgt_len, self.num_attention_heads, -1 |
|
).transpose(1, 2) * self.pos_scaling |
|
pos_k = self.cross_pos_k_linear(src_pos_embed).view( |
|
batch_size, src_len, self.num_attention_heads, -1 |
|
).transpose(1, 2) |
|
else: |
|
src_len = tgt_pos_embed.size(1) |
|
pos_q = self.self_pos_q_linear(tgt_pos_embed).view( |
|
batch_size, tgt_len, self.num_attention_heads, -1 |
|
).transpose(1, 2) * self.pos_scaling |
|
pos_k = self.self_pos_k_linear(tgt_pos_embed).view( |
|
batch_size, src_len, self.num_attention_heads, -1 |
|
).transpose(1, 2) |
|
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) |
|
return abs_pos_bias |
|
|
|
def forward( |
|
self, |
|
prev_output_tokens, |
|
code_masks: Optional[torch.Tensor] = None, |
|
encoder_out: Optional[Dict[str, List[Tensor]]] = None, |
|
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, |
|
features_only: bool = False, |
|
full_context_alignment: bool = False, |
|
alignment_layer: Optional[int] = None, |
|
alignment_heads: Optional[int] = None, |
|
src_lengths: Optional[Any] = None, |
|
return_all_hiddens: bool = False, |
|
): |
|
""" |
|
Args: |
|
prev_output_tokens (LongTensor): previous decoder outputs of shape |
|
`(batch, tgt_len)`, for teacher forcing |
|
encoder_out (optional): output from the encoder, used for |
|
encoder-side attention, should be of size T x B x C |
|
incremental_state (dict): dictionary used for storing state during |
|
:ref:`Incremental decoding` |
|
features_only (bool, optional): only return features without |
|
applying output layer (default: False). |
|
full_context_alignment (bool, optional): don't apply |
|
auto-regressive mask to self-attention (default: False). |
|
|
|
Returns: |
|
tuple: |
|
- the decoder's output of shape `(batch, tgt_len, vocab)` |
|
- a dictionary with any model-specific outputs |
|
""" |
|
|
|
x, extra = self.extract_features( |
|
prev_output_tokens, |
|
code_masks=code_masks, |
|
encoder_out=encoder_out, |
|
incremental_state=incremental_state, |
|
full_context_alignment=full_context_alignment, |
|
alignment_layer=alignment_layer, |
|
alignment_heads=alignment_heads, |
|
) |
|
|
|
if not features_only: |
|
x = self.output_layer(x) |
|
return x, extra |
|
|
|
def extract_features( |
|
self, |
|
prev_output_tokens, |
|
code_masks: Optional[torch.Tensor], |
|
encoder_out: Optional[Dict[str, List[Tensor]]], |
|
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, |
|
full_context_alignment: bool = False, |
|
alignment_layer: Optional[int] = None, |
|
alignment_heads: Optional[int] = None, |
|
): |
|
return self.extract_features_scriptable( |
|
prev_output_tokens, |
|
code_masks, |
|
encoder_out, |
|
incremental_state, |
|
full_context_alignment, |
|
alignment_layer, |
|
alignment_heads, |
|
) |
|
|
|
""" |
|
A scriptable subclass of this class has an extract_features method and calls |
|
super().extract_features, but super() is not supported in torchscript. A copy of |
|
this function is made to be used in the subclass instead. |
|
""" |
|
|
|
def extract_features_scriptable( |
|
self, |
|
prev_output_tokens, |
|
code_masks: Optional[torch.Tensor], |
|
encoder_out: Optional[Dict[str, List[Tensor]]], |
|
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, |
|
full_context_alignment: bool = False, |
|
alignment_layer: Optional[int] = None, |
|
alignment_heads: Optional[int] = None, |
|
): |
|
""" |
|
Similar to *forward* but only return features. |
|
|
|
Includes several features from "Jointly Learning to Align and |
|
Translate with Transformer Models" (Garg et al., EMNLP 2019). |
|
|
|
Args: |
|
full_context_alignment (bool, optional): don't apply |
|
auto-regressive mask to self-attention (default: False). |
|
alignment_layer (int, optional): return mean alignment over |
|
heads at this layer (default: last layer). |
|
alignment_heads (int, optional): only average alignment over |
|
this many heads (default: all heads). |
|
|
|
Returns: |
|
tuple: |
|
- the decoder's features of shape `(batch, tgt_len, embed_dim)` |
|
- a dictionary with any model-specific outputs |
|
""" |
|
prompt_tokens = None |
|
prompt_padding_mask = None |
|
prompt_kv_list = None |
|
if self.args.decoder_prompt: |
|
bsz, seq_len = prev_output_tokens.shape[0], prev_output_tokens.shape[1] |
|
if self.args.decoder_prompt_type in ("prefix",):
|
prompt_tokens = torch.arange( |
|
0, self.args.decoder_prompt_length).to( |
|
prev_output_tokens.device) |
|
prompt_tokens = prompt_tokens.unsqueeze(0).expand(bsz, -1) |
|
prompt_padding_mask = torch.zeros_like(prompt_tokens, dtype=torch.bool)
|
prompt_kv_list = self.get_decoder_prompt(prompt_tokens) |
|
bs, slen = prev_output_tokens.size() |
|
if alignment_layer is None: |
|
alignment_layer = self.num_layers - 1 |
|
|
|
enc: Optional[Tensor] = None |
|
padding_mask: Optional[Tensor] = None |
|
if encoder_out is not None and len(encoder_out["encoder_out"]) > 0: |
|
enc = encoder_out["encoder_out"][0] |
|
assert ( |
|
enc.size()[1] == bs |
|
), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}" |
|
if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: |
|
padding_mask = encoder_out["encoder_padding_mask"][0] |
|
|
|
bsz, tgt_len = prev_output_tokens.shape |
|
token_position_idx = utils.new_arange(prev_output_tokens) |
|
tgt_pos_embed = self.embed_positions(token_position_idx) |
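# code_masks marks target sequences made of image code tokens; those rows use the image
# positional embeddings and, below, the image variants of the position biases.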
|
if code_masks is not None and torch.any(code_masks): |
|
image_position_idx = self.image_position_idx[:prev_output_tokens.size(1)].unsqueeze(0).expand(bsz, tgt_len) |
|
tgt_pos_embed[code_masks] = self.embed_image_positions(image_position_idx)[code_masks] |
|
|
|
|
|
self_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=False) |
|
if code_masks is not None and torch.any(code_masks): |
|
self_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, use_image=True) |
|
self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks] |
|
|
|
src_pos_embed = encoder_out['position_embeddings'][0] |
|
cross_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed) |
|
|
|
|
|
|
|
if code_masks is not None and torch.any(code_masks): |
|
cross_image_abs_pos_bias = self.get_pos_info(prev_output_tokens, tgt_pos_embed, src_pos_embed=src_pos_embed, use_image=True) |
|
cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[code_masks] |
|
|
|
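# with_cls: the encoder prepends an extra (CLS-like) position, so pad the cross-attention
# position bias to match the encoder output length and the target length.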
if self.with_cls: |
|
cross_abs_pos_bias = F.pad(cross_abs_pos_bias, (enc.shape[0] - cross_abs_pos_bias.shape[-1], 0, prev_output_tokens.shape[1] - cross_abs_pos_bias.shape[2], 0), "constant", 0) |
|
|
|
|
|
cross_abs_pos_bias = cross_abs_pos_bias.reshape(-1, *cross_abs_pos_bias.size()[-2:]) |
|
|
|
all_prev_output_tokens = prev_output_tokens.clone() |
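# Incremental decoding: only the newest target position is fed through the layers, so the
# position-dependent tensors are sliced to their last step.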
|
|
|
if incremental_state is not None: |
|
prev_output_tokens = prev_output_tokens[:, -1:] |
|
cross_abs_pos_bias = cross_abs_pos_bias[:, -1:, :] |
|
tgt_pos_embed = tgt_pos_embed[:, -1:, :] |
|
|
|
|
|
x = self.embed_scale * self.embed_tokens(prev_output_tokens) |
|
|
|
if self.quant_noise is not None: |
|
x = self.quant_noise(x) |
|
|
|
if self.project_in_dim is not None: |
|
x = self.project_in_dim(x) |
|
|
|
if self.entangle_position_embedding and not self.args.disable_entangle:
|
x += tgt_pos_embed |
|
|
|
if self.layernorm_embedding is not None: |
|
if code_masks is None or not code_masks.any() or not getattr(self, "code_layernorm_embedding", False): |
|
x = self.layernorm_embedding(x) |
|
elif code_masks is not None and code_masks.all(): |
|
x = self.code_layernorm_embedding(x) |
|
else: |
|
x[~code_masks] = self.layernorm_embedding(x[~code_masks]) |
|
x[code_masks] = self.code_layernorm_embedding(x[code_masks]) |
|
|
|
x = self.dropout_module(x) |
|
|
|
|
|
x = x.transpose(0, 1) |
|
|
|
self_attn_padding_mask: Optional[Tensor] = None |
|
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): |
|
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) |
|
if prompt_padding_mask is not None: |
|
self_attn_padding_mask = torch.cat([prompt_padding_mask, self_attn_padding_mask], dim=1) |
|
|
|
|
|
attn: Optional[Tensor] = None |
|
inner_states: List[Optional[Tensor]] = [x] |
|
for idx, layer in enumerate(self.layers): |
|
if incremental_state is None and not full_context_alignment: |
|
self_attn_mask = self.buffered_future_mask(x) |
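# Prepend fully-visible (zero) columns for the prompt positions so every target token can
# attend to the decoder prefix despite the causal mask.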
|
if self.args.decoder_prompt: |
|
seq_len, prompt_len = x.size(0), prompt_tokens.size(1) |
|
prompt_mask = torch.zeros([seq_len, prompt_len]).to(x.device) |
|
self_attn_mask = torch.cat([prompt_mask, self_attn_mask], dim=1) |
|
else: |
|
self_attn_mask = None |
|
|
|
self_attn_bias = self_abs_pos_bias.clone() |
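# Add the token relative-position bias for text targets, the image relative-position bias
# for image-code targets, or both row-wise when the batch mixes the two.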
|
|
|
if code_masks is None or not code_masks.any(): |
|
|
|
self_attn_bias += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) |
|
|
|
|
|
|
|
elif code_masks is not None and code_masks.all(): |
|
self_attn_bias += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) |
|
else: |
|
self_attn_bias[~code_masks] += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) |
|
self_attn_bias[code_masks] += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0) |
|
self_attn_bias = self_attn_bias.reshape(-1, *self_attn_bias.size()[-2:]) |
|
if incremental_state is not None: |
|
self_attn_bias = self_attn_bias[:, -1:, :] |
|
|
|
if self.args.decoder_prompt: |
|
if self.args.decoder_prompt_type != "prompt": |
|
prompt_kv = prompt_kv_list[idx] |
|
else: |
|
if idx == 0: |
|
prompt_kv = prompt_kv_list[idx] |
|
else: |
|
prompt_kv = None |
|
else: |
|
prompt_kv = None |
|
|
|
x, layer_attn, _ = layer( |
|
x, |
|
enc, |
|
padding_mask, |
|
incremental_state, |
|
self_attn_mask=self_attn_mask, |
|
self_attn_padding_mask=self_attn_padding_mask, |
|
need_attn=bool((idx == alignment_layer)), |
|
need_head_weights=bool((idx == alignment_layer)), |
|
self_attn_bias=self_attn_bias, |
|
cross_attn_bias=cross_abs_pos_bias, |
|
prompt_kv=prompt_kv |
|
) |
|
inner_states.append(x) |
|
if layer_attn is not None and idx == alignment_layer: |
|
attn = layer_attn.float().to(x) |
|
|
|
if attn is not None: |
|
if alignment_heads is not None: |
|
attn = attn[:alignment_heads] |
|
|
|
|
|
attn = attn.mean(dim=0) |
|
|
|
if self.layer_norm is not None: |
|
x = self.layer_norm(x) |
|
|
|
|
|
x = x.transpose(0, 1) |
|
|
|
if self.project_out_dim is not None: |
|
x = self.project_out_dim(x) |
|
|
|
return x, {"attn": [attn], "inner_states": inner_states} |
|
|
|
def output_layer(self, features): |
|
"""Project features to the vocabulary size.""" |
|
if self.adaptive_softmax is None: |
|
|
|
return self.output_projection(features) |
|
else: |
|
return features |
|
|
|
def max_positions(self):
    """Maximum output length supported by the decoder."""
    return self.max_target_positions
|
|
|
def buffered_future_mask(self, tensor): |
|
dim = tensor.size(0) |
|
if ( |
|
self._future_mask.size(0) == 0 |
|
or (not self._future_mask.device == tensor.device) |
|
or self._future_mask.size(0) < dim |
|
): |
|
self._future_mask = torch.triu( |
|
utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1 |
|
) |
|
self._future_mask = self._future_mask.to(tensor) |
|
return self._future_mask[:dim, :dim] |
|
|
|
def upgrade_state_dict_named(self, state_dict, name): |
|
"""Upgrade a (possibly old) state dict for new versions of fairseq.""" |
|
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): |
|
weights_key = "{}.embed_positions.weights".format(name) |
|
if weights_key in state_dict: |
|
del state_dict[weights_key] |
|
state_dict[ |
|
"{}.embed_positions._float_tensor".format(name) |
|
] = torch.FloatTensor(1) |
|
|
|
if f"{name}.output_projection.weight" not in state_dict: |
|
if self.share_input_output_embed: |
|
embed_out_key = f"{name}.embed_tokens.weight" |
|
else: |
|
embed_out_key = f"{name}.embed_out" |
|
if embed_out_key in state_dict: |
|
state_dict[f"{name}.output_projection.weight"] = state_dict[ |
|
embed_out_key |
|
] |
|
if not self.share_input_output_embed: |
|
del state_dict[embed_out_key] |
|
|
|
for i in range(self.num_layers): |
|
|
|
self.layers[i].upgrade_state_dict_named( |
|
state_dict, "{}.layers.{}".format(name, i) |
|
) |
|
|
|
|
|
|
|
prefix = name + "." if name != "" else "" |
|
image_params = ["image_position_idx"] |
|
for image_param in image_params: |
|
state_dict[prefix + image_param] = self.state_dict()[image_param] |
|
for param_name, param_tensor in self.state_dict().items(): |
|
if (prefix + param_name) not in state_dict: |
|
state_dict[prefix + param_name] = self.state_dict()[param_name] |
|
|
|
if len(state_dict["decoder.embed_image_positions.weight"]) < len(self.state_dict()["embed_image_positions.weight"]): |
|
num_posids_to_add = len(self.state_dict()["embed_image_positions.weight"]) - len(state_dict["decoder.embed_image_positions.weight"]) |
|
embed_dim = state_dict["decoder.embed_image_positions.weight"].size(1) |
|
new_pos_embed_to_add = torch.zeros(num_posids_to_add, embed_dim) |
|
nn.init.normal_(new_pos_embed_to_add, mean=0, std=embed_dim ** -0.5) |
|
new_pos_embed_to_add = new_pos_embed_to_add.to( |
|
dtype=state_dict["decoder.embed_image_positions.weight"].dtype, |
|
) |
|
state_dict["decoder.embed_image_positions.weight"] = torch.cat( |
|
[state_dict["decoder.embed_image_positions.weight"], new_pos_embed_to_add] |
|
) |
|
return state_dict |
|
|
|
|
|
def Embedding(num_embeddings, embedding_dim, padding_idx=None, zero_init=False): |
|
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) |
|
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) |
|
if padding_idx is not None: |
|
nn.init.constant_(m.weight[padding_idx], 0) |
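# zero_init is used for the relative-position bias tables above so they start as a no-op.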
|
if zero_init: |
|
nn.init.constant_(m.weight, 0) |
|
return m |
|
|
|
|
|
def Linear(in_features, out_features, bias=True): |
|
m = nn.Linear(in_features, out_features, bias) |
|
nn.init.xavier_uniform_(m.weight) |
|
if bias: |
|
nn.init.constant_(m.bias, 0.0) |
|
return m |
|
|
|
|
|
@register_model_architecture("unify_transformer", "unify_transformer") |
|
def base_architecture(args): |
|
args.encoder_embed_path = getattr(args, "encoder_embed_path", None) |
|
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) |
|
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) |
|
args.encoder_layers = getattr(args, "encoder_layers", 6) |
|
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) |
|
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) |
|
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) |
|
args.decoder_embed_path = getattr(args, "decoder_embed_path", None) |
|
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) |
|
args.decoder_ffn_embed_dim = getattr( |
|
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim |
|
) |
|
args.decoder_layers = getattr(args, "decoder_layers", 6) |
|
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) |
|
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) |
|
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) |
|
args.attention_dropout = getattr(args, "attention_dropout", 0.0) |
|
args.activation_dropout = getattr(args, "activation_dropout", 0.0) |
|
args.activation_fn = getattr(args, "activation_fn", "relu") |
|
args.dropout = getattr(args, "dropout", 0.1) |
|
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) |
|
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) |
|
args.share_decoder_input_output_embed = getattr( |
|
args, "share_decoder_input_output_embed", False |
|
) |
|
args.share_all_embeddings = getattr(args, "share_all_embeddings", False) |
|
args.no_token_positional_embeddings = getattr( |
|
args, "no_token_positional_embeddings", False |
|
) |
|
args.adaptive_input = getattr(args, "adaptive_input", False) |
|
args.no_cross_attention = getattr(args, "no_cross_attention", False) |
|
args.cross_self_attention = getattr(args, "cross_self_attention", False) |
|
|
|
args.decoder_output_dim = getattr( |
|
args, "decoder_output_dim", args.decoder_embed_dim |
|
) |
|
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) |
|
|
|
args.encoder_prompt = getattr(args, "encoder_prompt", False) |
|
args.encoder_prompt_length = getattr(args, "encoder_prompt_length", 100) |
|
args.encoder_prompt_type = getattr(args, "encoder_prompt_type", "prefix") |
|
args.encoder_prompt_projection = getattr(args, "encoder_prompt_projection", False) |
|
args.encoder_prompt_dim = getattr(args, "encoder_prompt_dim", 2 * args.encoder_embed_dim) |
|
|
|
args.decoder_prompt = getattr(args, "decoder_prompt", False) |
|
args.decoder_prompt_length = getattr(args, "decoder_prompt_length", 100) |
|
args.decoder_prompt_type = getattr(args, "decoder_prompt_type", "prefix") |
|
args.decoder_prompt_projection = getattr(args, "decoder_prompt_projection", False) |
|
args.decoder_prompt_dim = getattr(args, "decoder_prompt_dim", 2 * args.encoder_embed_dim) |
|
|
|
args.no_scale_embedding = getattr(args, "no_scale_embedding", False) |
|
args.layernorm_embedding = getattr(args, "layernorm_embedding", False) |
|
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) |
|
args.checkpoint_activations = getattr(args, "checkpoint_activations", False) |
|
args.offload_activations = getattr(args, "offload_activations", False) |
|
if args.offload_activations: |
|
args.checkpoint_activations = True |
|
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) |
|
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) |
|
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) |
|
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) |
|
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) |
|
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) |
|
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) |
|
|