# InternVideo2-Stage2_6B / modeling_internvideo2.py
# Last commit by shepnerd: "Update modeling_internvideo2.py" (9e1bb51, verified)
import numpy as np
import cv2
import os
import math
import torch
from torch import nn
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import torch.utils.checkpoint as checkpoint
from functools import partial
from einops import rearrange
try:
from flash_attn.modules.mlp import FusedMLP
except:
print(f'FusedMLP of flash_attn is not installed!!!')
try:
from flash_attn.ops.rms_norm import DropoutAddRMSNorm
except:
print(f'DropoutAddRMSNorm of flash_attn is not installed!!!')
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
class FlashAttention(nn.Module):
    """Scaled dot-product attention with softmax, backed by flash-attn's varlen kernel.

    Arguments
    ---------
    softmax_scale: The temperature to use for the softmax attention.
        (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
    attention_dropout: The dropout rate to apply to the attention
        (default: 0.0); only used while the module is in training mode.
    device, dtype: Accepted for signature compatibility; not used.
    """

    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Implements the multihead softmax attention.

        Arguments
        ---------
        qkv: Packed query/key/value tensor; must be fp16/bf16 and on CUDA.
            (B, S, 3, H, D) when ``cu_seqlens`` is None (with or without
            ``key_padding_mask``); already unpadded (nnz, 3, H, D) when
            ``cu_seqlens`` is given.
        key_padding_mask: bool tensor of shape (B, S), handed to
            ``unpad_input``; only consulted when ``cu_seqlens`` is None.
        causal: Forwarded to the kernel as its ``causal`` flag.
        cu_seqlens, max_s: Cumulative sequence lengths and maximum length for
            the unpadded layout; ``max_s`` is required together with them.
        need_weights: Must be False; attention weights are never returned.

        Returns
        -------
        ``(output, None)``. For padded input ``output`` is (B, S, H, D); for
        unpadded input it is the kernel's output as-is (presumably (nnz, H, D)).
        """
        assert not need_weights
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        # Attention dropout is only applied in training mode.
        dropout_p = self.dropout_p if self.training else 0.0

        if cu_seqlens is not None:
            # The caller already supplies the unpadded varlen layout.
            assert max_s is not None
            output = flash_attn_varlen_qkvpacked_func(
                qkv, cu_seqlens, max_s, dropout_p,
                softmax_scale=self.softmax_scale, causal=causal
            )
            return output, None

        batch_size = qkv.shape[0]
        seqlen = qkv.shape[1]
        if key_padding_mask is None:
            # All sequences have full length: flatten the batch and describe it
            # with evenly spaced cumulative offsets 0, S, 2S, ..., B*S.
            qkv = rearrange(qkv, 'b s ... -> (b s) ...')
            max_s = seqlen
            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                      device=qkv.device)
            output = flash_attn_varlen_qkvpacked_func(
                qkv, cu_seqlens, max_s, dropout_p,
                softmax_scale=self.softmax_scale, causal=causal
            )
            output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        else:
            nheads = qkv.shape[-2]
            x = rearrange(qkv, 'b s three h d -> b s (three h d)')
            # Newer flash-attn releases return a fifth value (used seqlens) from
            # unpad_input; only the first four are needed, so slice to stay
            # compatible with both old and new versions.
            x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)[:4]
            x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
            output_unpad = flash_attn_varlen_qkvpacked_func(
                x_unpad, cu_seqlens, max_s, dropout_p,
                softmax_scale=self.softmax_scale, causal=causal
            )
            output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                         indices, batch_size, seqlen),
                               'b s (h d) -> b s h d', h=nheads)
        return output, None
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a fixed 2D sine-cosine position embedding for a square grid.

    Args:
        embed_dim: Output channel count. The helpers assert that it is even
            and that half of it is even, i.e. a multiple of 4.
        grid_size: Side length of the square grid (height == width).
        cls_token: If True, prepend one all-zero row for a class token.

    Returns:
        pos_embed: [grid_size*grid_size, embed_dim], or
        [1+grid_size*grid_size, embed_dim] when ``cls_token`` is set.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # meshgrid(w, h): the width coordinates come first in the stacked grid.
    mesh = np.stack(np.meshgrid(coords, coords), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])
    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return table
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)
def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
    """Build a fixed 1D sine-cosine embedding for ``t_size`` temporal positions.

    Args:
        embed_dim: Output channel count; must be even (asserted by the helper).
        t_size: Number of temporal positions (0, 1, ..., t_size - 1).
        cls_token: If True, prepend one all-zero row for a class token.

    Returns:
        pos_embed: [t_size, embed_dim], or [1+t_size, embed_dim] with ``cls_token``.
    """
    positions = np.arange(t_size, dtype=np.float32)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
    if cls_token:
        zero_row = np.zeros([1, embed_dim])
        table = np.concatenate([zero_row, table], axis=0)
    return table
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-axis coordinate grid with sine-cosine features.

    Args:
        embed_dim: Output channel count; asserted even. Each axis gets half.
        grid: Array whose ``grid[0]`` and ``grid[1]`` hold the coordinates of
            the two axes (flattened by the 1D helper). In this file the
            callers build it with ``np.meshgrid(w, h)``, so ``grid[0]`` holds
            the width (column) coordinate.

    Returns:
        (H*W, embed_dim) array: the first half of the channels encodes
        ``grid[0]`` and the second half encodes ``grid[1]``.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    per_axis = [
        get_1d_sincos_pos_embed_from_grid(half_dim, axis_coords)  # (H*W, D/2)
        for axis_coords in (grid[0], grid[1])
    ]
    return np.concatenate(per_axis, axis=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode scalar positions with sine-cosine features.

    Frequencies are ``1 / 10000 ** (i / (embed_dim / 2))`` for
    ``i = 0 .. embed_dim/2 - 1``. The first half of the output holds the sines
    and the second half the cosines.

    Args:
        embed_dim: Output dimension for each position; asserted even.
        pos: Array of positions of any shape; flattened to (M,).

    Returns:
        (M, embed_dim) array.
    """
    assert embed_dim % 2 == 0
    freqs = np.arange(embed_dim // 2, dtype=np.float32)
    freqs /= embed_dim / 2.0
    freqs = 1.0 / 10000**freqs  # (D/2,)
    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
def interpolate_pos_embed(checkpoint_model, model, orig_t_size=4, pos_name='vision_encoder.pos_embed'):
    """Resize ``checkpoint_model[pos_name]`` in place so it fits ``model``.

    The stored tensor is treated as (1, extra + T*H*W, C) with ``orig_t_size``
    frames and a square H == W grid. The leading "extra" (class) tokens are
    copied unchanged. Their count is ``model.pos_embed.shape[-2] -
    model.patch_embed.num_patches``. The target temporal length is ``model.T``
    and the target side is ``sqrt(num_patches // model.T)``. Time is resized
    with linear interpolation first, then space with bicubic. Nothing happens
    when ``pos_name`` is absent or the sizes already match.
    """
    if pos_name not in checkpoint_model:
        return
    table = checkpoint_model[pos_name]
    dim = table.shape[-1]  # channel dim
    n_patches = model.patch_embed.num_patches
    n_extra = model.pos_embed.shape[-2] - n_patches  # 0/1
    new_t = model.T
    old_side = int(((table.shape[-2] - n_extra) // (orig_t_size)) ** 0.5)
    new_side = int((n_patches // (new_t)) ** 0.5)

    if orig_t_size != new_t:
        print(f"Temporal interpolate from {orig_t_size} to {new_t} ({pos_name})")
        head, body = table[:, :n_extra], table[:, n_extra:]
        # (1, T*HW, C) -> (HW, C, T): interpolate along the trailing time axis.
        body = body.view(1, orig_t_size, -1, dim).permute(0, 2, 3, 1).reshape(-1, dim, orig_t_size)
        body = torch.nn.functional.interpolate(body, size=new_t, mode='linear')
        # (HW, C, T') -> (1, T'*HW, C)
        body = body.view(1, -1, dim, new_t).permute(0, 3, 1, 2).reshape(1, -1, dim)
        table = torch.cat((head, body), dim=1)
        checkpoint_model[pos_name] = table

    if old_side != new_side:
        print(f"Position interpolate from {old_side}x{old_side} to {new_side}x{new_side} ({pos_name})")
        head, body = table[:, :n_extra], table[:, n_extra:]
        # (1, T*H*W, C) -> (T, C, H, W) for 2D interpolation per frame.
        body = body.reshape(-1, new_t, old_side, old_side, dim)
        body = body.reshape(-1, old_side, old_side, dim).permute(0, 3, 1, 2)
        body = torch.nn.functional.interpolate(
            body, size=(new_side, new_side), mode='bicubic', align_corners=False)
        # (T, C, H', W') -> (1, T*H'*W', C)
        body = body.permute(0, 2, 3, 1).reshape(-1, new_t, new_side, new_side, dim).flatten(1, 3)
        checkpoint_model[pos_name] = torch.cat((head, body), dim=1)
def interpolate_pos_embed_internvideo2(checkpoint_model, model, orig_t_size=8):
    """Resize ``pos_embed`` / ``clip_pos_embed`` in a checkpoint to fit ``model``.

    Each key that is present is treated as (1, extra + T*H*W, C) with
    ``orig_t_size`` frames and a square H == W grid. The target temporal
    length is ``model.num_frames // model.tubelet_size``. The target side is
    ``sqrt(model.patch_embed.num_patches // new_t)``. The extra (class) token
    count comes from ``model.pos_embed`` and those tokens are kept unchanged.
    Time is resized linearly and space bicubically. ``checkpoint_model`` is
    updated in place.

    Raises:
        NotImplementedError: if ``pos_embed_spatial`` or ``pos_embed_temporal``
            is present. This is checked only after the keys above have been
            processed.
    """
    def _resize(name, table):
        # Resize one stored table; mirrors the loop body for a single key.
        dim = table.shape[-1]  # channel dim
        n_patches = model.patch_embed.num_patches
        n_extra = model.pos_embed.shape[-2] - n_patches  # 0/1
        new_t = model.num_frames // model.tubelet_size
        old_side = int(((table.shape[-2] - n_extra) // (orig_t_size)) ** 0.5)
        new_side = int((n_patches // (new_t)) ** 0.5)

        if orig_t_size != new_t:
            print(f"Temporal interpolate from {orig_t_size} to {new_t} ({name})")
            cls_part, grid_part = table[:, :n_extra], table[:, n_extra:]
            # (1, T*HW, C) -> (HW, C, T), interpolate time, then back to (1, T'*HW, C).
            grid_part = grid_part.view(1, orig_t_size, -1, dim).permute(0, 2, 3, 1).reshape(-1, dim, orig_t_size)
            grid_part = torch.nn.functional.interpolate(grid_part, size=new_t, mode='linear')
            grid_part = grid_part.view(1, -1, dim, new_t).permute(0, 3, 1, 2).reshape(1, -1, dim)
            table = torch.cat((cls_part, grid_part), dim=1)
            checkpoint_model[name] = table

        if old_side != new_side:
            print(f"Position interpolate from {old_side}x{old_side} to {new_side}x{new_side} ({name})")
            cls_part, grid_part = table[:, :n_extra], table[:, n_extra:]
            # (1, T*H*W, C) -> (T, C, H, W), resize each frame, then flatten back.
            grid_part = grid_part.reshape(-1, new_t, old_side, old_side, dim)
            grid_part = grid_part.reshape(-1, old_side, old_side, dim).permute(0, 3, 1, 2)
            grid_part = torch.nn.functional.interpolate(
                grid_part, size=(new_side, new_side), mode='bicubic', align_corners=False)
            grid_part = grid_part.permute(0, 2, 3, 1).reshape(-1, new_t, new_side, new_side, dim)
            checkpoint_model[name] = torch.cat((cls_part, grid_part.flatten(1, 3)), dim=1)

    for pos_name in ['pos_embed', 'clip_pos_embed']:
        if pos_name in checkpoint_model:
            _resize(pos_name, checkpoint_model[pos_name])
    if 'pos_embed_spatial' in checkpoint_model or 'pos_embed_temporal' in checkpoint_model:
        raise NotImplementedError
def interpolate_pos_embed_internvideo2_new(checkpoint_model, model, orig_t_size=8):
    """Resize every position embedding found in a checkpoint to fit ``model``.

    Keys containing ``'pos_embed'`` are resized, except keys containing
    ``'img_pos_embed'``, which are left untouched. Each table is treated as
    (1, extra + T*H*W, C) with ``orig_t_size`` frames and a square grid. The
    targets come from ``model.num_frames // model.tubelet_size`` and
    ``model.patch_embed.num_patches``. The extra (class) tokens are copied
    unchanged. ``checkpoint_model`` is updated in place.

    Raises:
        AssertionError: if no matching key exists.
        NotImplementedError: if ``pos_embed_spatial`` or ``pos_embed_temporal``
            is present. This is checked before anything is modified.
    """
    # 'clip_pos_embed' contains 'pos_embed', so one substring test covers both.
    pos_names = [k for k in checkpoint_model.keys()
                 if 'pos_embed' in k and 'img_pos_embed' not in k]
    print(f"pos names list for interpolating: {pos_names}")
    assert len(pos_names) > 0, checkpoint_model.keys()
    ckpt_keys = checkpoint_model.keys()
    if 'pos_embed_spatial' in ckpt_keys or 'pos_embed_temporal' in ckpt_keys:
        raise NotImplementedError

    def _stretch_time(table, n_extra, dim, new_t):
        # (1, extra + T*HW, C) -> (HW, C, T) -> linear resize -> (1, extra + T'*HW, C)
        head, body = table[:, :n_extra], table[:, n_extra:]
        body = body.view(1, orig_t_size, -1, dim).permute(0, 2, 3, 1).reshape(-1, dim, orig_t_size)
        body = torch.nn.functional.interpolate(body, size=new_t, mode='linear')
        body = body.view(1, -1, dim, new_t).permute(0, 3, 1, 2).reshape(1, -1, dim)
        return torch.cat((head, body), dim=1)

    def _stretch_space(table, n_extra, dim, new_t, old_side, new_side):
        # (1, extra + T*H*W, C) -> (T, C, H, W) -> bicubic resize -> (1, extra + T*H'*W', C)
        head, body = table[:, :n_extra], table[:, n_extra:]
        body = body.reshape(-1, new_t, old_side, old_side, dim)
        body = body.reshape(-1, old_side, old_side, dim).permute(0, 3, 1, 2)
        body = torch.nn.functional.interpolate(
            body, size=(new_side, new_side), mode='bicubic', align_corners=False)
        body = body.permute(0, 2, 3, 1).reshape(-1, new_t, new_side, new_side, dim).flatten(1, 3)
        return torch.cat((head, body), dim=1)

    for pos_name in pos_names:
        table = checkpoint_model[pos_name]
        dim = table.shape[-1]  # channel dim
        n_patches = model.patch_embed.num_patches
        n_extra = model.pos_embed.shape[-2] - n_patches  # 0/1
        new_t = model.num_frames // model.tubelet_size
        old_side = int(((table.shape[-2] - n_extra) // (orig_t_size)) ** 0.5)
        new_side = int((n_patches // (new_t)) ** 0.5)
        if orig_t_size != new_t:
            print(f"Temporal interpolate from {orig_t_size} to {new_t} ({pos_name})")
            table = _stretch_time(table, n_extra, dim, new_t)
            checkpoint_model[pos_name] = table
        if old_side != new_side:
            print(f"Position interpolate from {old_side}x{old_side} to {new_side}x{new_side} ({pos_name})")
            checkpoint_model[pos_name] = _stretch_space(table, n_extra, dim, new_t, old_side, new_side)
def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False):
    """Build a fixed (T, H, W) sine-cosine embedding for a video patch grid.

    A quarter of the channels encodes time and three quarters encode the
    square spatial grid. Each row is laid out as ``[temporal | spatial]`` and
    rows are ordered T-major, then H*W.

    Args:
        embed_dim: Output channel count. It is asserted to be a multiple of 4
            here; the helpers' own assertions effectively require a
            multiple of 16.
        grid_size: Side length of the square spatial grid.
        t_size: Number of temporal positions.
        cls_token: If True, prepend one all-zero row for a class token.

    Returns:
        pos_embed: [t_size*grid_size*grid_size, embed_dim], or
        [1+t_size*grid_size*grid_size, embed_dim] with ``cls_token``.
    """
    assert embed_dim % 4 == 0
    dim_spatial = embed_dim // 4 * 3
    dim_temporal = embed_dim // 4

    # Spatial table over the square grid (width coordinates first).
    coords = np.arange(grid_size, dtype=np.float32)
    mesh = np.stack(np.meshgrid(coords, coords), axis=0).reshape([2, 1, grid_size, grid_size])
    spatial = get_2d_sincos_pos_embed_from_grid(dim_spatial, mesh)  # [H*W, D // 4 * 3]

    # Temporal table.
    temporal = get_1d_sincos_pos_embed_from_grid(
        dim_temporal, np.arange(t_size, dtype=np.float32)
    )  # [T, D // 4]

    # Broadcast both to [T, H*W, .] and join along channels.
    temporal = np.repeat(temporal[:, np.newaxis, :], grid_size**2, axis=1)
    spatial = np.repeat(spatial[np.newaxis, :, :], t_size, axis=0)
    pos_embed = np.concatenate([temporal, spatial], axis=-1).reshape([-1, embed_dim])  # [T*H*W, D]

    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension, with a learnable scale.

    There is no mean centering and no bias. Statistics are computed in
    float32. The normalized values are cast back to the input dtype before
    being multiplied by ``weight``, so the result dtype follows PyTorch type
    promotion between the two.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)
class PatchEmbed(nn.Module):
    """ 3D Image to Patch Embedding

    Cuts a clip into non-overlapping tubelets with a Conv3d whose kernel and
    stride are both ``(tubelet_size, patch_h, patch_w)``. It then projects
    each tubelet to ``embed_dim``. ``forward`` maps the conv output
    (B, C, T', H', W') to (B, T', H'*W', C) and applies ``norm``.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
        num_frames=8, tubelet_size=1, norm_layer=None
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # Patch-grid size along (T, H, W).
        self.grid_size = (
            num_frames // tubelet_size,
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[1],
        )
        n_t, n_h, n_w = self.grid_size
        self.num_patches = n_t * n_h * n_w
        self.num_img_patches = n_h * n_w
        tubelet = (tubelet_size, patch_size[0], patch_size[1])
        self.proj = nn.Conv3d(
            in_channels=in_chans, out_channels=embed_dim,
            kernel_size=tubelet, stride=tubelet,
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        tokens = self.proj(x)              # B x C x T x H x W
        tokens = tokens.flatten(3)         # B x C x T x HW
        tokens = tokens.permute(0, 2, 3, 1)  # B x T x HW x C
        return self.norm(tokens)
class CrossAttention(nn.Module):
    """Multi-head attention with queries from ``x`` and keys/values from ``k``/``v``.

    Args:
        dim: Channel width of ``x``, ``k`` and ``v``.
        num_heads: Number of attention heads.
        qkv_bias: If True, zero-initialised learnable biases are added to the
            q/k/v projections. The Linear layers themselves never carry a bias.
        qk_scale: Used instead of ``head_dim ** -0.5`` when truthy.
        attn_drop: Dropout on the attention probabilities.
        proj_drop: Dropout after the output projection.
        attn_head_dim: Per-head width; ``head_dim * num_heads`` must equal ``dim``.
        out_dim: Output width (defaults to ``dim``).
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None, out_dim=None):
        super().__init__()
        out_dim = dim if out_dim is None else out_dim
        self.num_heads = num_heads
        head_dim = attn_head_dim if attn_head_dim is not None else dim // num_heads
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        assert all_head_dim == dim
        self.q = nn.Linear(dim, all_head_dim, bias=False)
        self.k = nn.Linear(dim, all_head_dim, bias=False)
        self.v = nn.Linear(dim, all_head_dim, bias=False)
        if qkv_bias:
            self.q_bias, self.k_bias, self.v_bias = (
                nn.Parameter(torch.zeros(all_head_dim)) for _ in range(3)
            )
        else:
            self.q_bias = self.k_bias = self.v_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, out_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, k=None, v=None):
        B, N, C = x.shape
        n_keys, n_values = k.shape[1], v.shape[1]
        heads = self.num_heads

        def split_heads(t, length):
            # (B, L, heads * D) -> (B, heads, L, D)
            return t.reshape(B, length, heads, -1).transpose(1, 2)

        if self.q_bias is None:
            q_bias = k_bias = v_bias = None
        else:
            q_bias, k_bias, v_bias = self.q_bias, self.k_bias, self.v_bias

        q = split_heads(F.linear(input=x, weight=self.q.weight, bias=q_bias), N)
        k = split_heads(F.linear(input=k, weight=self.k.weight, bias=k_bias), n_keys)
        v = split_heads(F.linear(input=v, weight=self.v.weight, bias=v_bias), n_values)

        scores = (q * self.scale) @ k.transpose(-2, -1)  # (B, heads, N, n_keys)
        probs = self.attn_drop(scores.softmax(dim=-1))
        merged = (probs @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj_drop(self.proj(merged))
class AttentiveBlock(nn.Module):
    """Normalize query/key/value streams separately, then apply ``CrossAttention``.

    The key stream is ``x_kv + pos_k`` and the value stream is ``x_kv``. The
    block returns the cross-attention output directly, with no residual
    connection. ``drop_path`` is registered but never applied in ``forward``.
    """

    def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
        super().__init__()
        self.norm1_q = norm_layer(dim)
        self.norm1_k = norm_layer(dim)
        self.norm1_v = norm_layer(dim)
        self.cross_attn = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            attn_head_dim=attn_head_dim,
            out_dim=out_dim,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
        # bool_masked_pos and rel_pos_bias are accepted for interface
        # compatibility only; they are not used.
        query = self.norm1_q(x_q + pos_q)
        key = self.norm1_k(x_kv + pos_k)
        value = self.norm1_v(x_kv)
        return self.cross_attn(query, k=key, v=value)
class AttentionPoolingBlock(AttentiveBlock):
    """Pool a token sequence into one vector by cross-attending from its mean token.

    The query is ``x.mean(1, keepdim=True)`` and keys/values are ``x`` itself,
    with zero positional terms. The singleton query axis is squeezed away, so
    an input presumably shaped (B, N, C) yields (B, out_dim).
    """

    def forward(self, x):
        mean_query = x.mean(1, keepdim=True)
        pooled = super().forward(mean_query, x, 0, 0, bool_masked_pos=None, rel_pos_bias=None)
        return pooled.squeeze(1)
class LayerScale(nn.Module):
    """Multiply the input by a learnable per-channel vector ``gamma``.

    Args:
        dim: Length of ``gamma``, broadcast against the input's last dimension.
        init_values: Initial value of every entry of ``gamma``.
        inplace: If True, scale with ``mul_``. In the ``force_fp32`` path this
            only mutates the input when it is already float32, because
            ``.float()`` returns a copy otherwise.
        force_fp32: If True, multiply in float32 and cast the result back to
            the input dtype.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))
        self.force_fp32 = force_fp32

    # CUDA autocast is disabled so the dtypes chosen below are honoured.
    # torch.autocast("cuda", ...) replaces the deprecated
    # torch.cuda.amp.autocast(...) decorator (FutureWarning in newer PyTorch).
    @torch.autocast("cuda", enabled=False)
    def forward(self, x):
        if not self.force_fp32:
            return x.mul_(self.gamma) if self.inplace else x * self.gamma
        output_type = x.dtype
        x_fp32 = x.float()
        gamma_fp32 = self.gamma.float()
        out = x_fp32.mul_(gamma_fp32) if self.inplace else x_fp32 * gamma_fp32
        return out.to(dtype=output_type)
class Attention(nn.Module):
    """Multi-head self-attention with optional QK normalization.

    Two backends share the same projections:
    * naive (``use_flash_attn=False``): explicit ``softmax(Q K^T * scale) V``;
    * flash (``use_flash_attn=True``): ``FlashAttention`` on packed QKV.

    Args:
        dim: Channel width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Add a bias to the fused QKV projection.
        attn_drop: Dropout on attention probabilities (naive) or inside the kernel (flash).
        proj_drop: Dropout after the output projection.
        use_flash_attn: Select the flash backend.
        causal: Restrict each position to attend to itself and earlier
            positions. Both backends honour it.
        norm_layer: Factory for the Q/K norms, called as ``norm_layer(dim)``.
        qk_normalization: Normalize Q and K over all heads jointly (the
            flattened H*D channels) before attention.
        use_fused_rmsnorm: ``norm_layer`` returns a ``(normed, residual)`` tuple
            (as the fused prenorm RMSNorm does); only the first item is used.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
                 causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.use_flash_attn = use_flash_attn
        # Stored for both backends: the naive path previously ignored `causal`.
        self.causal = causal
        if use_flash_attn:
            self.inner_attn = FlashAttention(attention_dropout=attn_drop)
        self.qk_normalization = qk_normalization
        self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
        self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
        self.use_fused_rmsnorm = use_fused_rmsnorm

    def _norm_heads(self, norm, t):
        """Apply ``norm`` to ``t`` (..., H, D) over its flattened H*D channels; keep t's shape."""
        out = norm(t.flatten(-2, -1))
        if self.use_fused_rmsnorm:
            # The fused prenorm RMSNorm returns (normed, residual).
            out = out[0]
        return out.view(t.shape)

    def _naive_attn(self, x):
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, H, N, D)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
        if self.qk_normalization:
            # Normalize in (B, N, H, D) layout, then return to (B, H, N, D).
            q = self._norm_heads(self.q_norm, q.transpose(1, 2)).transpose(1, 2)
            k = self._norm_heads(self.k_norm, k.transpose(1, 2)).transpose(1, 2)
        attn = ((q * self.scale) @ k.transpose(-2, -1))
        if self.causal:
            # Mask strictly-future keys (above the diagonal).
            future = torch.ones(N, N, dtype=torch.bool, device=x.device).triu(1)
            attn = attn.masked_fill(future, float('-inf'))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
        qkv = self.qkv(x)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
        if self.qk_normalization:
            q, k, v = qkv.unbind(2)
            q = self._norm_heads(self.q_norm, q)
            k = self._norm_heads(self.k_norm, k)
            qkv = torch.stack([q, k, v], dim=2)
        context, _ = self.inner_attn(
            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
        )
        outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
        outs = self.proj_drop(outs)
        return outs

    def forward(self, x):
        return self._flash_attn(x) if self.use_flash_attn else self._naive_attn(x)
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks

    fc1 -> activation -> dropout -> fc2 -> dropout. ``bias`` and ``drop`` may be
    scalars (shared by both layers) or pairs (one value per layer).
    Missing widths default to ``in_features``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 bias=True, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias_pair[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias_pair[1])
        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        for stage in (self.fc1, self.act, self.drop1, self.fc2, self.drop2):
            x = stage(x)
        return x
class Block(nn.Module):
    """Pre-norm transformer block: attention, then MLP, each with LayerScale and drop path.

    Two residual schemes:
    * ``use_fused_rmsnorm=False``: ``x + f(norm(x))`` updates. ``forward``
      returns a tensor and ``residual`` must be None.
    * ``use_fused_rmsnorm=True``: norms are called as ``norm(x, residual)``
      and return ``(x, residual)``. The block returns that pair and the
      caller carries the residual stream forward.

    When ``with_cp`` is set, the body runs under ``torch.utils.checkpoint``.
    """

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False,
            fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False,
            use_fused_rmsnorm=False):
        super().__init__()
        force_fp32 = not layerscale_no_force_fp32

        def make_layerscale():
            # Identity when init_values is falsy (None or 0).
            if init_values:
                return LayerScale(dim, init_values=init_values, force_fp32=force_fp32)
            return nn.Identity()

        def make_drop_path():
            # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # Registration order is kept (norm1, attn, ls1, drop_path1, norm2, mlp, ls2, drop_path2).
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
                              use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
                              qk_normalization=qk_normalization,
                              use_fused_rmsnorm=use_fused_rmsnorm)
        self.ls1 = make_layerscale()
        self.drop_path1 = make_drop_path()
        self.norm2 = norm_layer(dim)
        hidden_dim = int(dim * mlp_ratio)
        if not use_fused_mlp:
            self.mlp = Mlp(in_features=dim, hidden_features=hidden_dim, act_layer=act_layer, drop=drop)
        else:
            self.mlp = FusedMLP(in_features=dim, hidden_features=hidden_dim, heuristic=fused_mlp_heuristic)
        self.ls2 = make_layerscale()
        self.drop_path2 = make_drop_path()
        self.with_cp = with_cp
        self.use_fused_rmsnorm = use_fused_rmsnorm

    def forward(self, x, residual=None):
        def _inner_forward(x, residual=None):
            if not self.use_fused_rmsnorm:
                assert residual is None
                x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
                return x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
            # Fused path: each norm consumes and returns the running residual.
            hidden, residual = self.norm1(x, residual)
            hidden = self.drop_path1(self.ls1(self.attn(hidden)))
            hidden, residual = self.norm2(hidden, residual)
            hidden = self.drop_path2(self.ls2(self.mlp(hidden)))
            return hidden, residual

        if not self.with_cp:
            return _inner_forward(x, residual=residual)
        return checkpoint.checkpoint(_inner_forward, x, residual)
class Linear_Decoder(nn.Module):
    """Linear projection followed by ``norm_layer`` and an optional L2 normalization.

    Args:
        in_channels: Input feature width.
        out_channels: Output feature width.
        norm_layer: Factory called as ``norm_layer(out_channels)``.
        clip_norm_type: ``'l2'`` divides each output vector by its L2 norm over
            the last dim. ``'none'`` leaves it unchanged. Any other value
            raises ``NotImplementedError`` in ``forward``.
    """

    def __init__(self, in_channels=1408, out_channels=3200,
                 norm_layer=nn.LayerNorm, clip_norm_type='l2'):
        super().__init__()
        self.clip_norm_type = clip_norm_type
        self.head = nn.Linear(in_channels, out_channels)
        self.norm = norm_layer(out_channels)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights, zero biases, unit weight / zero bias for LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # A LayerNorm built with elementwise_affine=False (or bias=False) has
            # None in place of these tensors; initializing None would crash.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.norm(self.head(x))
        if self.clip_norm_type == 'l2':
            x = x / x.norm(dim=-1, keepdim=True)
        elif self.clip_norm_type == 'none':
            pass
        else:
            raise NotImplementedError
        return x
class PretrainInternVideo2(nn.Module):
def __init__(
    self,
    in_chans: int = 3,
    patch_size: int = 14,
    img_size: int = 224,
    qkv_bias: bool = False,
    drop_path_rate: float = 0.25,
    embed_dim: int = 1408,
    num_heads: int = 16,
    mlp_ratio: float = 48/11,
    init_values: float = 1e-5,
    qk_normalization: bool = True,
    depth: int = 40,
    use_flash_attn: bool = True,
    use_fused_rmsnorm: bool = True,
    use_fused_mlp: bool = True,
    fused_mlp_heuristic: int = 1,
    attn_pool_num_heads: int = 16,
    clip_embed_dim: int = 768,
    layerscale_no_force_fp32: bool = False,
    num_frames: int = 8,
    tubelet_size: int = 1,
    sep_pos_embed: bool = False,
    sep_image_video_pos_embed: bool = False,
    use_checkpoint: bool = False,
    checkpoint_num: int = 0,
    # for unmasked teacher
    clip_teacher_embed_dim: int = 3200,
    clip_teacher_final_dim: int = 768,  # if 0, not distill final features
    clip_norm_type: str = 'l2',
    clip_return_layer: int = 1,
    clip_student_return_interval: int = 1,
):
    """Build the patch embedding, transformer blocks and CLIP distillation heads.

    Args:
        in_chans, patch_size, img_size, num_frames, tubelet_size: Forwarded to
            ``PatchEmbed``; together they fix the (T, H, W) patch grid.
        qkv_bias, embed_dim, num_heads, mlp_ratio, init_values,
            qk_normalization, depth: Width/depth settings passed to each ``Block``.
        drop_path_rate: Final value of the per-block drop-path rates, which grow
            linearly from 0 across the ``depth`` blocks.
        use_flash_attn, use_fused_rmsnorm, use_fused_mlp: flash-attn kernels;
            all three must be equal (asserted below).
        fused_mlp_heuristic: Forwarded to ``FusedMLP`` through ``Block``.
        attn_pool_num_heads: Heads of the attention-pooling ``clip_projector``.
        clip_embed_dim: Output width of ``clip_projector``.
        layerscale_no_force_fp32: If True, ``LayerScale`` does not force fp32.
        sep_pos_embed: Not supported; True raises ``NotImplementedError``.
        sep_image_video_pos_embed: If True, also create image-only position
            embeddings ``img_pos_embed`` / ``clip_img_pos_embed``.
        use_checkpoint, checkpoint_num: When ``use_checkpoint`` is set, the first
            ``checkpoint_num`` blocks run under ``torch.utils.checkpoint``.
        clip_teacher_embed_dim: Output width of each ``clip_decoder`` head.
        clip_teacher_final_dim: Output width of ``final_clip_decoder``. With 0
            (or less) that decoder stays ``nn.Identity``.
        clip_norm_type: Output normalization of the ``Linear_Decoder`` heads
            (``'l2'`` or ``'none'``).
        clip_return_layer: How many intermediate block outputs are collected
            for distillation. This is also the number of ``clip_decoder`` heads.
        clip_student_return_interval: Spacing, in blocks, between those
            collected outputs, counted back from the last block.
    """
    super().__init__()
    self.num_frames = num_frames
    self.tubelet_size = tubelet_size
    # The three flash-attn features must be switched on or off together.
    assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, 'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent'
    self.use_flash_attn = use_flash_attn
    self.embed_dim = embed_dim
    self.depth = depth
    self.clip_norm_type = clip_norm_type
    # Block indices whose outputs are collected for CLIP distillation: the last
    # block, then every `clip_student_return_interval` blocks before it.
    self.return_index = []
    for i in range(clip_return_layer):
        self.return_index.append(depth - int(i * clip_student_return_interval) - 1)
    # Blocks use either the fused flash-attn RMSNorm (prenorm=True; Block calls
    # it as norm(x, residual) and gets (x, residual) back) or the plain RMSNorm
    # defined in this file.
    if use_fused_rmsnorm:
        norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True)
    else:
        norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
    self.norm_layer_for_blocks = norm_layer_for_blocks
    self.patch_embed = PatchEmbed(
        img_size, patch_size, in_chans, embed_dim,
        num_frames=num_frames, tubelet_size=tubelet_size,
    )
    num_patches = self.patch_embed.num_patches
    num_img_patches = self.patch_embed.num_img_patches
    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    # Adapted from https://github.com/facebookresearch/mae_st/blob/dc072aaaf640d06892e23a33b42223a994efe272/models_vit.py#L65-L73C17
    self.sep_pos_embed = sep_pos_embed
    self.sep_image_video_pos_embed = sep_image_video_pos_embed
    if sep_pos_embed:
        raise NotImplementedError
    else:
        # Every position embedding has one extra leading slot for the class token.
        # They are zero here and filled with sine-cosine tables by init_pos_embed().
        if sep_image_video_pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
            self.img_pos_embed = nn.Parameter(torch.zeros(1, num_img_patches + 1, embed_dim))
            # for CLIP decoder
            self.clip_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
            self.clip_img_pos_embed = nn.Parameter(torch.zeros(1, num_img_patches + 1, embed_dim))
        else:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
            self.clip_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
    # Stochastic-depth (drop path) rate per block, linear from 0 to drop_path_rate.
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    # choose which layer to use checkpoint: the first `checkpoint_num` blocks
    with_cp_list = [False] * depth
    if use_checkpoint:
        for idx in range(depth):
            if idx < checkpoint_num:
                with_cp_list[idx] = True
    self.blocks = nn.ModuleList([
        Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
              norm_layer=norm_layer_for_blocks,
              drop_path=dpr[i], init_values=init_values, attn_drop=0.,
              use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp,
              fused_mlp_heuristic=fused_mlp_heuristic,
              with_cp=with_cp_list[i],
              qk_normalization=qk_normalization,
              layerscale_no_force_fp32=layerscale_no_force_fp32,
              use_fused_rmsnorm=use_fused_rmsnorm)
        for i in range(depth)])
    # Attention pooling from embed_dim tokens to a single clip_embed_dim vector.
    self.clip_projector = AttentionPoolingBlock(
        dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
        drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
    # CLIP decoder: one head per collected intermediate layer (see return_index).
    self.clip_decoder = nn.ModuleList([
        Linear_Decoder(
            in_channels=embed_dim,
            out_channels=clip_teacher_embed_dim,
            norm_layer=partial(nn.LayerNorm, eps=1e-5),
            clip_norm_type=clip_norm_type
        ) for _ in range(clip_return_layer)
    ])
    # Decoder for the pooled feature; stays Identity when clip_teacher_final_dim <= 0.
    self.final_clip_decoder = nn.Identity()
    if clip_teacher_final_dim > 0:
        self.final_clip_decoder = Linear_Decoder(
            in_channels=clip_embed_dim,
            out_channels=clip_teacher_final_dim,
            norm_layer=partial(nn.LayerNorm, eps=1e-5),
            clip_norm_type=clip_norm_type
        )
    # Initialization order matters: fixed position tables, class token, generic
    # per-module init (which draws random numbers), then per-block rescaling.
    self.init_pos_embed()
    trunc_normal_(self.cls_token, std=.02)
    self.apply(self._init_weights)
    self.fix_init_weight()
def init_pos_embed(self):
    """Fill the position-embedding parameters with fixed 3D sine-cosine tables.

    ``pos_embed`` and ``clip_pos_embed`` both receive the (T, H, W) table,
    including an all-zero leading class-token row. When
    ``sep_image_video_pos_embed`` is set, ``img_pos_embed`` and
    ``clip_img_pos_embed`` receive the single-frame (T=1) table. The patch
    grid is treated as square: only ``grid_size[1]`` is used as the side.

    Raises:
        NotImplementedError: if ``sep_pos_embed`` is set.
    """
    if self.sep_pos_embed:
        raise NotImplementedError
    dim = self.pos_embed.shape[-1]
    side = self.patch_embed.grid_size[1]      # height & width
    n_frames = self.patch_embed.grid_size[0]  # t_size

    video_table = torch.from_numpy(
        get_3d_sincos_pos_embed(dim, side, n_frames, cls_token=True)
    ).float().unsqueeze(0)
    self.pos_embed.data.copy_(video_table)
    self.clip_pos_embed.data.copy_(video_table)

    if self.sep_image_video_pos_embed:
        image_table = torch.from_numpy(
            get_3d_sincos_pos_embed(dim, side, 1, cls_token=True)
        ).float().unsqueeze(0)
        self.img_pos_embed.data.copy_(image_table)
        self.clip_img_pos_embed.data.copy_(image_table)
def _init_weights(self, m):
    """Initialize one submodule; intended for ``self.apply``.

    Linear: truncated-normal weights (std 0.02) and zero bias (if present).
    LayerNorm: unit weight and zero bias. Either may be None, e.g. with
    ``elementwise_affine=False``, and is then skipped instead of crashing.
    Other module types are left untouched.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
def fix_init_weight(self):
    """Shrink each block's output projections by ``sqrt(2 * layer_id)``.

    ``layer_id`` is 1-based, so block ``i`` (0-based) has its
    ``attn.proj.weight`` and ``mlp.fc2.weight`` divided in place by
    ``sqrt(2 * (i + 1))``. Deeper blocks therefore start with smaller outputs.
    """
    for layer_id, blk in enumerate(self.blocks, start=1):
        divisor = math.sqrt(2.0 * layer_id)
        blk.attn.proj.weight.data.div_(divisor)
        blk.mlp.fc2.weight.data.div_(divisor)
@property
def dtype(self):
    """Parameter dtype of the model, read from the patch-embedding conv weight."""
    proj_weight = self.patch_embed.proj.weight
    return proj_weight.dtype
def get_num_layers(self):
    """Return how many transformer blocks the model holds."""
    num_layers = len(self.blocks)
    return num_layers
@torch.jit.ignore
def no_weight_decay(self):
    """Return the parameter names to exempt from weight decay.

    The set holds ``cls_token`` plus every position-embedding variant, each
    with and without the ``clip_`` prefix. Presumably the training code's
    optimizer setup reads it; that code is not visible here.
    """
    embed_variants = (
        'pos_embed',
        'pos_embed_spatial',
        'pos_embed_temporal',
        'pos_embed_cls',
        'img_pos_embed',
    )
    names = {prefix + variant for prefix in ('', 'clip_') for variant in embed_variants}
    names.add('cls_token')
    return names
def forward(self, x, mask=None, use_image=False, x_vis_return_idx=-1, x_vis_only=False):
    """Encode ``x`` with the student blocks and, optionally, the CLIP-alignment heads.

    Args:
        x: input to ``self.patch_embed`` (cast to ``self.dtype`` first). The
            patch embedding must return a 4-D ``(B, T, L, C)`` tensor
            (T temporal, L spatial tokens).
        mask: optional boolean mask over the ``(B, 1 + T * L)`` token grid
            (cls token included); ``True`` drops a token. Every sample must
            keep the same number of tokens, since the result is reshaped to
            ``(B, -1, C)``.
        use_image: use image positional embeddings instead of the video ones.
        x_vis_return_idx: the block loop stops after block index
            ``self.depth + x_vis_return_idx``.
        x_vis_only: return only ``x_vis`` and skip the projection/decoder heads.

    Returns:
        ``x_vis`` if ``x_vis_only``; otherwise
        ``(x_vis, x_pool_vis, x_clip_align, x_align)``, where ``x_clip_align``
        stacks ``self.clip_decoder[k]`` applied to the k-th collected
        intermediate feature (one per index in ``self.return_index`` reached
        before the loop stopped).

    Raises:
        NotImplementedError: if ``self.sep_pos_embed`` is set.
    """

    def select_pos_embed(video_pos_embed, image_attr):
        # Positional table for this call. Image inputs use the dedicated table
        # named `image_attr` when sep_image_video_pos_embed is set (looked up
        # lazily, it may not exist otherwise); else the video table averaged
        # over its temporal axis.
        if not use_image:
            return video_pos_embed
        if self.sep_image_video_pos_embed:
            return getattr(self, image_attr)
        # NOTE(review): the temporal axis is assumed to have self.num_frames
        # entries, while init_pos_embed builds the table with
        # patch_embed.grid_size[0] temporal positions; these may differ
        # (e.g. with a tubelet size > 1) — confirm.
        cls_pos = video_pos_embed[:, 0:1, :]
        patch_pos = video_pos_embed[:, 1:, :].view(
            1,
            self.num_frames,
            self.patch_embed.num_patches // self.num_frames,
            self.embed_dim,
        ).mean(dim=1)
        return torch.cat([cls_pos, patch_pos], dim=1)

    x = self.patch_embed(x.type(self.dtype))
    B, T, L, C = x.shape  # T: temporal; L: spatial
    x = x.view([B, T * L, C])

    # prepend the cls token
    cls_tokens = self.cls_token.expand(B, -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)

    if self.sep_pos_embed:
        raise NotImplementedError
    x = x + select_pos_embed(self.pos_embed, "img_pos_embed")

    # drop masked tokens; ~mask marks the visible ones
    if mask is not None:
        x = x[~mask].reshape(B, -1, C)
    else:
        x = x.reshape(B, -1, C)

    # A block may return a tensor or an (x, residual) pair; a pair is
    # unpacked before it is fed to the next block.
    residual = None
    x_clip = []
    for idx, blk in enumerate(self.blocks):
        if isinstance(x, tuple) and len(x) == 2:
            x, residual = x
        x = blk(x, residual=residual)
        # collect intermediate features for the CLIP decoders
        if idx in self.return_index:
            if isinstance(x, tuple) and len(x) == 2:
                tmp_x, tmp_residual = x
                # Test the residual this block just returned (not the one fed
                # into it), so the feature is never silently skipped.
                x_clip.append(tmp_x if tmp_residual is None else tmp_x + tmp_residual)
            else:
                x_clip.append(x)
        if idx == (self.depth + x_vis_return_idx):
            break

    if isinstance(x, tuple) and len(x) == 2:
        x, residual = x
        if residual is not None:
            x = x + residual

    x_vis = x
    if x_vis_only:
        return x_vis

    x_pool_vis = self.clip_projector(x_vis)
    x_align = self.final_clip_decoder(x_pool_vis)

    # align CLIP: stack the K collected features -> (K, B, N, C_CLIP)
    x_clip = torch.stack(x_clip)
    K, B, _, C_CLIP = x_clip.shape
    clip_pos_embed = select_pos_embed(self.clip_pos_embed, "clip_img_pos_embed").repeat(B, 1, 1)
    if mask is not None:
        clip_pos_embed = clip_pos_embed[~mask]
    x_clip = x_clip + clip_pos_embed.view(B, -1, C_CLIP).unsqueeze(0).repeat(K, 1, 1, 1)

    # CLIP decoder k consumes collected feature k
    x_clip_align = torch.stack(
        [clip_decoder(x_clip[idx]) for idx, clip_decoder in enumerate(self.clip_decoder)]
    )
    return x_vis, x_pool_vis, x_clip_align, x_align
def pretrain_internvideo2_1b_patch14_224(config):
    """Build the "1b" PretrainInternVideo2 variant (patch 14, 224px input).

    Fixed architecture: embed_dim 1408, depth 40, 16 heads, mlp_ratio 48/11,
    drop_path_rate 0.25. Everything else comes from ``config.vision_encoder``.
    """
    vision_cfg = config.vision_encoder
    return PretrainInternVideo2(
        in_chans=3, img_size=224, patch_size=14,
        embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
        clip_embed_dim=vision_cfg.clip_embed_dim,
        attn_pool_num_heads=16, qkv_bias=False,
        drop_path_rate=0.25,
        init_values=0.00001,
        qk_normalization=True,
        use_flash_attn=vision_cfg.use_flash_attn,
        use_fused_rmsnorm=vision_cfg.use_fused_rmsnorm,
        use_fused_mlp=vision_cfg.use_fused_mlp,
        fused_mlp_heuristic=1,
        layerscale_no_force_fp32=False,
        num_frames=vision_cfg.num_frames,
        tubelet_size=vision_cfg.tubelet_size,
        sep_pos_embed=False,
        sep_image_video_pos_embed=vision_cfg.sep_image_video_pos_embed,
        use_checkpoint=vision_cfg.use_checkpoint,
        checkpoint_num=vision_cfg.checkpoint_num,
        clip_teacher_embed_dim=vision_cfg.clip_teacher_embed_dim,
        clip_teacher_final_dim=vision_cfg.clip_teacher_final_dim,
        clip_norm_type=vision_cfg.clip_norm_type,
        clip_return_layer=vision_cfg.clip_return_layer,
        clip_student_return_interval=vision_cfg.clip_student_return_interval,
    )
def pretrain_internvideo2_6b_patch14_224(config):
    """Build the "6b" PretrainInternVideo2 variant (patch 14, 224px input).

    Fixed architecture: embed_dim 3200, depth 48, 25 heads, mlp_ratio 4,
    drop_path_rate 0.3. Everything else comes from ``config.vision_encoder``.
    """
    vision_cfg = config.vision_encoder
    return PretrainInternVideo2(
        in_chans=3, img_size=224, patch_size=14,
        embed_dim=3200, depth=48, num_heads=25, mlp_ratio=4,
        clip_embed_dim=vision_cfg.clip_embed_dim,
        attn_pool_num_heads=16, qkv_bias=False,
        drop_path_rate=0.3,
        init_values=0.00001,
        qk_normalization=True,
        use_flash_attn=vision_cfg.use_flash_attn,
        use_fused_rmsnorm=vision_cfg.use_fused_rmsnorm,
        use_fused_mlp=vision_cfg.use_fused_mlp,
        fused_mlp_heuristic=1,
        layerscale_no_force_fp32=False,
        num_frames=vision_cfg.num_frames,
        tubelet_size=vision_cfg.tubelet_size,
        sep_pos_embed=False,
        sep_image_video_pos_embed=vision_cfg.sep_image_video_pos_embed,
        use_checkpoint=vision_cfg.use_checkpoint,
        checkpoint_num=vision_cfg.checkpoint_num,
        clip_teacher_embed_dim=vision_cfg.clip_teacher_embed_dim,
        clip_teacher_final_dim=vision_cfg.clip_teacher_final_dim,
        clip_norm_type=vision_cfg.clip_norm_type,
        clip_return_layer=vision_cfg.clip_return_layer,
        clip_student_return_interval=vision_cfg.clip_student_return_interval,
    )
from dataclasses import dataclass
from typing import Tuple, Optional, List
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import (PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer)
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
MaskedLMOutput,
)
from torch import Tensor, device
from torch.nn import CrossEntropyLoss
class BertConfig(PretrainedConfig):
    r"""
    Configuration for the BERT text / fusion encoder defined in this file.

    The defaults reproduce the hyper-parameters of
    `bert-base-uncased <https://huggingface.co/bert-base-uncased>`__. Serialisation
    and the handling of extra keyword arguments are inherited from
    :class:`~transformers.PretrainedConfig`.

    Args:
        vocab_size (int, default 30522): number of token ids covered by the word-embedding table.
        hidden_size (int, default 768): width of the encoder layers and the pooler.
        num_hidden_layers (int, default 12): number of transformer layers.
        num_attention_heads (int, default 12): attention heads per attention module.
        intermediate_size (int, default 3072): width of the feed-forward expansion.
        hidden_act (str or callable, default "gelu"): activation of the feed-forward layer and
            prediction head; strings are looked up in ``ACT2FN``.
        hidden_dropout_prob (float, default 0.1): dropout on embeddings and layer outputs.
        attention_probs_dropout_prob (float, default 0.1): dropout on attention probabilities.
        max_position_embeddings (int, default 512): size of the absolute position table, and the
            bound on relative distances for the relative position types.
        type_vocab_size (int, default 2): number of token-type (segment) ids.
        initializer_range (float, default 0.02): std of the normal init for Linear/Embedding weights.
        layer_norm_eps (float, default 1e-12): LayerNorm epsilon.
        pad_token_id (int, default 0): forwarded to ``PretrainedConfig``; also used as the
            word-embedding ``padding_idx``.
        position_embedding_type (str, default "absolute"): ``"absolute"``, ``"relative_key"``
            (Shaw et al., https://arxiv.org/abs/1803.02155) or ``"relative_key_query"``
            (method 4 of Huang et al., https://arxiv.org/abs/2009.13658).
        use_cache (bool, default True): whether the model should return the last key/value
            attentions (decoder use).
        classifier_dropout (float, optional): dropout ratio for a classification head.
        cross_module (str, default "ca"): stored on the config as ``cross_module``.
            NOTE(review): accepted values are not visible here — document them.
        **kwargs: forwarded to ``PretrainedConfig``. Layers in this file also read
            ``fusion_layer`` and ``encoder_width`` from the config; they are not explicit
            parameters here and presumably arrive through ``kwargs`` — confirm.

    Example::

        >>> configuration = BertConfig()
        >>> configuration.hidden_size
        768
    """

    model_type = "bert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        cross_module="ca",
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        # model size
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        # regularisation
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        # embeddings / initialisation
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        # runtime behaviour
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout
        self.cross_module = cross_module
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
"""Load tf checkpoints in a pytorch model."""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
print(
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
tf_path = os.path.abspath(tf_checkpoint_path)
print("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
for name, array in zip(names, arrays):
name = name.split("/")
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if any(
n
in [
"adam_v",
"adam_m",
"AdamWeightDecayOptimizer",
"AdamWeightDecayOptimizer_1",
"global_step",
]
for n in name
):
print("Skipping {}".format("/".join(name)))
continue
pointer = model
for m_name in name:
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
scope_names = re.split(r"_(\d+)", m_name)
else:
scope_names = [m_name]
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
pointer = getattr(pointer, "bias")
elif scope_names[0] == "output_weights":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "squad":
pointer = getattr(pointer, "classifier")
else:
try:
pointer = getattr(pointer, scope_names[0])
except AttributeError:
print("Skipping {}".format("/".join(name)))
continue
if len(scope_names) >= 2:
num = int(scope_names[1])
pointer = pointer[num]
if m_name[-11:] == "_embeddings":
pointer = getattr(pointer, "weight")
elif m_name == "kernel":
array = np.transpose(array)
try:
assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
return model
class BertEmbeddings(nn.Module):
    """Sum word, token-type and (for ``"absolute"``) position embeddings, then LayerNorm + dropout."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.word_embeddings = nn.Embedding(
            config.vocab_size, width, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, width)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, width)
        # CamelCase attribute name kept so TensorFlow checkpoint variable names map onto it.
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # (1, max_position_embeddings) row of position indices, registered as a buffer
        position_row = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", position_row)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.config = config

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        past_key_values_length=0,
    ):
        """Embed ``input_ids`` (or take ``inputs_embeds``) for a sequence.

        Positions default to ``past_key_values_length .. past_key_values_length
        + seq_len - 1`` and token types default to 0.
        """
        if input_ids is None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            input_shape = input_ids.size()
        seq_length = input_shape[1]

        if position_ids is None:
            start = past_key_values_length
            position_ids = self.position_ids[:, start:start + seq_length]
        if token_type_ids is None:
            token_type_ids = torch.zeros(
                input_shape, dtype=torch.long, device=self.position_ids.device
            )
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        summed = inputs_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            summed = summed + self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(summed))
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention used for both self- and cross-attention.

    With ``is_cross_attention=True`` the key/value projections take inputs of
    width ``config.encoder_width`` (the ``encoder_hidden_states``); otherwise
    they take ``hidden_states`` of width ``config.hidden_size``. For
    ``position_embedding_type`` ``"relative_key"`` / ``"relative_key_query"`` a
    learned table of ``2 * max_position_embeddings - 1`` relative distances
    contributes to the attention scores.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
            config, "embedding_size"
        ):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if (
            self.position_embedding_type == "relative_key"
            or self.position_embedding_type == "relative_key_query"
        ):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size
            )
        # When True, cross-attention calls store the attention map (and its
        # gradient, if it requires grad).
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        """Split the last axis into heads: (B, S, all_head_size) -> (B, heads, S, head_size) for a 3-D input."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Attend from ``hidden_states`` to itself (plus cached keys) or to ``encoder_hidden_states``.

        ``attention_mask`` (replaced by ``encoder_attention_mask`` in
        cross-attention) is added to the raw scores, so it is expected to be
        additive. Returns ``(context, [probs, scores,] (key_layer, value_layer))``;
        the bracketed entries are present only when ``output_attentions``.
        """
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if (
            self.position_embedding_type == "relative_key"
            or self.position_embedding_type == "relative_key_query"
        ):
            query_length = query_layer.shape[2]
            key_length = key_layer.shape[2]
            # Queries and keys may differ in length (cross-attention, or cached
            # keys in self-attention). With cached self-attention keys the
            # queries are the last `query_length` key positions; otherwise both
            # sequences start at position 0. Without a cache the offset is 0.
            offset = 0 if is_cross_attention else key_length - query_length
            position_ids_l = torch.arange(
                offset, offset + query_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            position_ids_r = torch.arange(
                key_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            )
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype
            )  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                relative_position_scores_key = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding
                )
                attention_scores = (
                    attention_scores
                    + relative_position_scores_query
                    + relative_position_scores_key
                )

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            # register_hook raises on tensors that do not require grad
            # (e.g. under torch.no_grad()), so only hook when it can fire.
            if attention_probs.requires_grad:
                attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # added `attention_scores` to return tuple
        outputs = (
            (context_layer, attention_probs, attention_scores)
            if output_attentions
            else (context_layer,)
        )

        outputs = outputs + (past_key_value,)
        return outputs
class BertSelfOutput(nn.Module):
    """Project the attention context, apply dropout, then LayerNorm over the residual sum."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class BertAttention(nn.Module):
    """``BertSelfAttention`` (self or cross) followed by ``BertSelfOutput``, with head pruning."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove ``heads`` from the Q/K/V and output projections and update head bookkeeping."""
        if len(heads) == 0:
            return
        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads,
            attn.num_attention_heads,
            attn.attention_head_size,
            self.pruned_heads,
        )

        # Q/K/V are pruned along the default dim; the output projection along dim=1.
        for proj_name in ("query", "key", "value"):
            setattr(attn, proj_name, prune_linear_layer(getattr(attn, proj_name), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Return ``(attention_output, *rest)`` where ``rest`` is everything after the
        context tensor in the ``BertSelfAttention`` result."""
        context, *rest = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        return (self.output(context, hidden_states), *rest)
class BertIntermediate(nn.Module):
    """Feed-forward expansion ``hidden_size -> intermediate_size`` followed by ``config.hidden_act``."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
class BertOutput(nn.Module):
    """Project ``intermediate_size -> hidden_size``, apply dropout, then LayerNorm over the residual sum."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class BertLayer(nn.Module):
    """One encoder layer: self-attention, optional cross-attention, then feed-forward.

    Cross-attention is created only for layers whose index is
    ``>= config.fusion_layer``.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        # forward() reads this to pick the per-layer entry when
        # encoder_hidden_states is a list.
        self.layer_num = layer_num
        self.attention = BertAttention(config)

        self.has_cross_attention = layer_num >= config.fusion_layer
        if self.has_cross_attention:
            self.crossattention = BertAttention(config, is_cross_attention=True)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run the layer.

        ``encoder_hidden_states`` is required on cross-attention layers. It may
        be a list: the layer then uses entry
        ``(layer_num - fusion_layer) % len(encoder_hidden_states)``, together
        with the same entry of ``encoder_attention_mask`` when a mask is given.

        Returns:
            ``(layer_output, *attention_outputs, present_key_value)``; the
            middle entries are present only when ``output_attentions``.
        """
        # the self-attention cache entry occupies positions 0 and 1 of past_key_value
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )  # (context_layer, attention_probs, attention_scores, past_key_value,)
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if self.has_cross_attention:
            assert (
                encoder_hidden_states is not None
            ), "encoder_hidden_states must be given for cross-attention layers"

            if isinstance(encoder_hidden_states, list):
                stream = (self.layer_num - self.config.fusion_layer) % len(encoder_hidden_states)
                cross_states = encoder_hidden_states[stream]
                cross_mask = (
                    encoder_attention_mask[stream]
                    if encoder_attention_mask is not None
                    else None
                )
            else:
                cross_states = encoder_hidden_states
                cross_mask = encoder_attention_mask

            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                cross_states,
                cross_mask,
                output_attentions=output_attentions,
            )  # (context_layer, attention_probs, attention_scores, past_key_value,)
            attention_output = cross_attention_outputs[0]
            # add cross attentions if we output attention weights
            outputs = outputs + cross_attention_outputs[1:-1]

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            attention_output,
        )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        """Feed-forward sub-layer (``intermediate`` then ``output`` with residual) on one chunk."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``BertLayer`` modules.

    ``mode`` in ``forward`` selects which slice of the stack runs:
    ``"text"`` / ``"temporal"`` -> layers ``[0, fusion_layer)``;
    ``"fusion"`` -> ``[fusion_layer, num_hidden_layers)``;
    ``"multi_modal"`` -> all layers.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config, i) for i in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        mode="multi_modal",
        normalize_attention=True,
    ):
        """Run the layers selected by ``mode``.

        When ``config.gradient_checkpointing`` is set and the module is in
        training mode, each layer runs under non-reentrant
        ``torch.utils.checkpoint`` and ``use_cache`` is forced off.
        ``normalize_attention`` picks attention probabilities (True) or raw
        scores (False) for the returned attentions.

        Raises:
            ValueError: for an unknown ``mode``.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None

        next_decoder_cache = () if use_cache else None

        if mode == "text" or mode == "temporal":
            # temporal is added and used for temporal att module.
            start_layer, output_layer = 0, self.config.fusion_layer
        elif mode == "fusion":
            start_layer, output_layer = self.config.fusion_layer, self.config.num_hidden_layers
        elif mode == "multi_modal":
            start_layer, output_layer = 0, self.config.num_hidden_layers
        else:
            raise ValueError(
                f"Unknown mode {mode!r}; expected 'text', 'temporal', 'fusion' or 'multi_modal'."
            )

        def bind_layer_call(module, layer_past):
            # Bind this layer's cache entry now: with use_reentrant=False the
            # forward is re-run during backward, after the loop has moved on,
            # so closing over the loop variable would pick up the last layer's value.
            def custom_forward(*inputs):
                return module(*inputs, layer_past, output_attentions)

            return custom_forward

        for i in range(start_layer, output_layer):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                if use_cache:
                    print(
                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                        "`use_cache=False`..."
                    )
                    use_cache = False
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    bind_layer_call(layer_module, past_key_value),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    use_reentrant=False,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )  # (context_layer, attention_probs, attention_scores, past_key_value,)
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                # whether to output normalized attention,
                # note for unnormalized attention, there is a mask added
                offset = int(normalize_attention)
                all_self_attentions = all_self_attentions + (layer_outputs[2 - offset],)
                if hasattr(layer_module, "crossattention"):
                    all_cross_attentions = all_cross_attentions + (layer_outputs[4 - offset],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
class BertPooler(nn.Module):
    """Pool a sequence by passing the first token's hidden state through ``dense`` + ``tanh``."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        leading_token = hidden_states[:, 0]
        return self.activation(self.dense(leading_token))
class BertPredictionHeadTransform(nn.Module):
    """``dense`` -> ``config.hidden_act`` -> LayerNorm, applied before the LM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        act = config.hidden_act
        self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
class BertLMPredictionHead(nn.Module):
    """Masked-LM scores: ``BertPredictionHeadTransform`` then a vocab-sized linear decoder."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        # The decoder weight mirrors the input embeddings; the bias is a
        # separate output-only parameter.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # The decoder shares this exact Parameter as its bias, so the bias is
        # resized together with `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        transformed = self.transform(hidden_states)
        return self.decoder(transformed)
class BertOnlyMLMHead(nn.Module):
    """Masked-LM-only head: wraps ``BertLMPredictionHead`` as ``self.predictions``."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        return self.predictions(sequence_output)
class BertOnlyNSPHead(nn.Module):
    """Next-sentence-prediction head: two logits from the pooled output."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        return self.seq_relationship(pooled_output)
class BertPreTrainingHeads(nn.Module):
    """Masked-LM scores for every token plus next-sentence logits from the pooled output."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        return self.predictions(sequence_output), self.seq_relationship(pooled_output)
class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights of a single submodule.

        LayerNorm: weight = 1, bias = 0. Linear / Embedding: weight ~ N(0, config.initializer_range);
        a Linear's bias (if any) is zeroed.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    In this implementation decoder behaviour is selected per call: pass :obj:`is_decoder=True` to :meth:`forward`
    to combine the padding mask with a causal mask. Cross-attention layers read the :obj:`encoder_hidden_states`
    (and optional :obj:`encoder_attention_mask`) given to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None
        self.init_weights()

    def get_input_embeddings(self):
        """Return the word-embedding module of ``self.embeddings``."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module of ``self.embeddings`` with ``value``."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore. Either 2-D
                ``[batch_size, seq_length]`` or 3-D ``[batch_size, from_seq_length, to_seq_length]``.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model, unpacked as ``(batch_size, seq_length)`` when a causal mask
                is built.
            device: (:obj:`torch.device`):
                The device of the input to the model.
            is_decoder (:obj:`bool`):
                If True and ``attention_mask`` is 2-D, combine it with a causal mask (position ``i`` may only
                attend to positions ``j <= i``).

        Returns:
            :obj:`torch.Tensor` The additive extended attention mask, cast to ``self.dtype``: 0.0 where attention
            is allowed and -10000.0 where it is masked. A 2-D mask becomes ``[batch_size, 1, 1, to_seq_length]``
            (encoder) or ``[batch_size, 1, seq_length, to_seq_length]`` (decoder); a 3-D mask becomes
            ``[batch_size, 1, from_seq_length, to_seq_length]``.

        Raises:
            ValueError: if ``attention_mask`` is neither 2-D nor 3-D.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = (
                    seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
                    <= seq_ids[None, :, None]
                )
                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)
                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, seq_length, prefix_seq_len),
                                device=device,
                                dtype=causal_mask.dtype,
                            ),
                            causal_mask,
                        ],
                        axis=-1,
                    )
                extended_attention_mask = (
                    causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
                )
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )
        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=self.dtype
        )  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
        mode="multi_modal",
        normalize_attention=True,
    ):
        r"""
        encoder_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Precomputed input embeddings fed straight to the encoder; when given, ``self.embeddings`` is skipped.
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder. A list of such tensors is also accepted.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            A list of masks is also accepted; each one is inverted separately.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`). Always forced to :obj:`False` unless ``is_decoder`` is True.
        is_decoder (:obj:`bool`, defaults to :obj:`False`):
            Add a causal mask to the self-attention mask (see :meth:`get_extended_attention_mask`).
        mode, normalize_attention:
            Forwarded unchanged to ``self.encoder``.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = inputs_embeds.device
        elif encoder_embeds is not None:
            input_shape = encoder_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = encoder_embeds.device
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds or encoder_embeds"
            )
        # past_key_values_length
        past_key_values_length = (
            past_key_values[0][0].shape[2] if past_key_values is not None else 0
        )
        if attention_mask is None:
            attention_mask = torch.ones(
                ((batch_size, seq_length + past_key_values_length)), device=device
            )
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, input_shape, device, is_decoder
        )
        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            # A list of encoder states is accepted; its first element defines the default mask shape.
            if isinstance(encoder_hidden_states, list):
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
                    0
                ].size()
            else:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if isinstance(encoder_attention_mask, list):
                encoder_extended_attention_mask = [
                    self.invert_attention_mask(mask) for mask in encoder_attention_mask
                ]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask
                )
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask
                )
        else:
            encoder_extended_attention_mask = None
        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        if encoder_embeds is None:
            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                token_type_ids=token_type_ids,
                inputs_embeds=inputs_embeds,
                past_key_values_length=past_key_values_length,
            )
        else:
            embedding_output = encoder_embeds
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            mode=mode,
            normalize_attention=normalize_attention,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
@dataclass
class MaskedLMOutputWithDistill(MaskedLMOutput):
    """:class:`MaskedLMOutput` extended with two optional extra loss fields (both default to ``None``)."""

    # Auxiliary loss term.
    loss_aux: Optional[torch.FloatTensor] = None
    # Distillation loss term.
    loss_distill: Optional[torch.FloatTensor] = None
class BertForMaskedLM(BertPreTrainedModel):
    """BERT encoder (without pooling layer) topped with a masked-language-modeling head.

    ``forward`` can optionally mix the hard-label cross-entropy loss with a soft-label
    distillation loss (``soft_labels`` weighted by ``alpha``).
    """

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)
        self.init_weights()

    def tie_aux_decoder_weights(self, module, aux_modules):
        """Tie decoder weights of all `aux_modules` to `module`, (not bias)"""
        for m in aux_modules:
            m.predictions.decoder.weight = module.predictions.decoder.weight

    def get_output_embeddings(self):
        """Return the vocabulary projection of the MLM head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the vocabulary projection of the MLM head."""
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
        mode="multi_modal",
        normalize_attention=True,
        soft_labels=None,
        alpha=0,
        return_logits=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        soft_labels (:obj:`torch.FloatTensor`, `optional`):
            Per-token target distributions over the vocabulary, broadcastable against the logits
            ``(batch_size, sequence_length, vocab_size)``. Requires ``labels``: the distillation loss is averaged
            over positions whose label is not ``-100``.
        alpha (:obj:`float`, defaults to 0):
            Mixing weight; the returned loss is ``(1 - alpha) * mlm_loss + alpha * distill_loss``.
        return_logits (:obj:`bool`, defaults to :obj:`False`):
            If True, return only the prediction logits and skip all loss computation.

        Raises:
            ValueError: if ``soft_labels`` is given without ``labels``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_embeds=encoder_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
            mode=mode,
            normalize_attention=normalize_attention,
        )
        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)
        if return_logits:
            return prediction_scores
        masked_lm_loss = None
        masked_lm_loss_aux = 0.0
        loss_distill = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )
        if soft_labels is not None:
            if labels is None:
                raise ValueError(
                    "`soft_labels` requires `labels` to select the masked positions"
                )
            # Normalize over the vocabulary axis (last dim) -- the same axis the
            # soft-label cross-entropy is summed over.
            loss_distill = -torch.sum(
                F.log_softmax(prediction_scores, dim=-1) * soft_labels, dim=-1
            )
            loss_distill = loss_distill[labels != -100].mean()
            masked_lm_loss = (1 - alpha) * masked_lm_loss + alpha * loss_distill
        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        # changed from MaskedLMOutput to MaskedLMOutputWithDistill
        return MaskedLMOutputWithDistill(
            loss=masked_lm_loss,
            loss_aux=masked_lm_loss_aux,
            loss_distill=loss_distill,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        """Append a dummy PAD token (masked out) to ``input_ids`` for generation.

        If ``attention_mask`` is None, an all-ones mask matching ``input_ids`` is used.
        """
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)
        # add a dummy token
        assert (
            self.config.pad_token_id is not None
        ), "The PAD token should be defined for generation"
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1
        )
        dummy_token = torch.full(
            (effective_batch_size, 1),
            self.config.pad_token_id,
            dtype=torch.long,
            device=input_ids.device,
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)
        return {"input_ids": input_ids, "attention_mask": attention_mask}
def build_bert(model_config, pretrain, checkpoint, encoder_width=None):
    """Build the BERT text encoder and load its pretrained weights.

    Args:
        model_config: config object. Reads ``text_encoder.config`` (BERT JSON config path),
            ``text_encoder.pretrained``, ``text_encoder.fusion_layer``, ``multimodal.enable``
            and, when ``encoder_width`` is None, ``vision_encoder.d_model``.
        pretrain (bool): if True build a :class:`BertForMaskedLM` (with MLM head); otherwise a
            :class:`BertModel` without pooling layer.
        checkpoint (bool): whether to do gradient_checkpointing (stored on the BERT config).
        encoder_width (int, optional): value for ``bert_config.encoder_width``; defaults to
            ``model_config.vision_encoder.d_model``.

    Returns:
        The loaded text encoder (``BertForMaskedLM`` if ``pretrain`` else ``BertModel``).
    """
    bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
    bert_config.encoder_width = (
        model_config.vision_encoder.d_model if encoder_width is None else encoder_width
    )
    bert_config.gradient_checkpointing = checkpoint
    bert_config.fusion_layer = model_config.text_encoder.fusion_layer
    if not model_config.multimodal.enable:
        # NOTE(review): presumably this means "no fusion layers" because fusion would only
        # start after the last layer -- confirm against BertEncoder.
        bert_config.fusion_layer = bert_config.num_hidden_layers

    if pretrain:
        model_cls, extra_kwargs = BertForMaskedLM, {}
    else:
        model_cls, extra_kwargs = BertModel, {"add_pooling_layer": False}

    def _load(local_files_only):
        # from_pretrained(..., output_loading_info=True) returns (model, loading_info).
        return model_cls.from_pretrained(
            model_config.text_encoder.pretrained,
            config=bert_config,
            output_loading_info=True,
            local_files_only=local_files_only,
            **extra_kwargs,
        )[0]

    # Prefer local files; fall back to a (possibly remote) load only if that fails.
    # `except Exception` (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
    try:
        return _load(local_files_only=True)
    except Exception:
        return _load(local_files_only=False)
def get_sim(
    vision_proj: torch.Tensor,
    text_proj: torch.Tensor,
    temp=1.0,
    agg_method="mean",
):
    """Calculate pair-wise video-text similarity.

    Both inputs are L2-normalized along their last dimension before comparison.

    Args:
        vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C] or [B,C].
        text_proj (torch.Tensor): The text representation. Shape: [B,C], or [B,L,C] when
            ``vision_proj`` is 2-D.
        temp (float or torch.Tensor): The temperature the similarities are divided by. Shape: [].
        agg_method (str): How the per-frame / per-token similarities of a 3-D input are reduced
            over that axis: ``"mean"`` or ``"max"``. Unused when both inputs are 2-D.

    Returns:
        Tuple ``(sim_v2t, sim_t2v)`` of shapes [B_v, B_t] and [B_t, B_v].

    Raises:
        ValueError: if an input is 3-D and ``agg_method`` is not ``"mean"`` or ``"max"``.
    """
    vision_proj = F.normalize(vision_proj, dim=-1)
    text_proj = F.normalize(text_proj, dim=-1)

    def _aggregate(sim):
        # Collapse the frame/token axis (dim 1) of a [B, L, B] similarity tensor.
        if agg_method == "mean":
            return sim.mean(1)
        if agg_method == "max":
            return sim.max(1)[0]
        raise ValueError(f"Unsupported agg_method: {agg_method!r} (expected 'mean' or 'max')")

    if vision_proj.ndim == 3:
        sim_v2t = _aggregate(torch.einsum("mld,nd->mln", vision_proj, text_proj) / temp)  # [B, L, B]
        sim_t2v = _aggregate(torch.einsum("nd,mld->nlm", text_proj, vision_proj) / temp)  # [B, L, B]
    elif text_proj.ndim == 3:
        sim_v2t = _aggregate(torch.einsum("nd,mld->nlm", vision_proj, text_proj) / temp)  # [B, L, B]
        sim_t2v = _aggregate(torch.einsum("nld,md->nlm", text_proj, vision_proj) / temp)  # [B, L, B]
    else:
        sim_v2t = vision_proj @ text_proj.T / temp
        sim_t2v = sim_v2t.T
    return sim_v2t, sim_t2v
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

# (checkpoint name, do_lower_case) for every known pretrained BERT vocabulary, in lookup order.
# The three tables below are all derived from this single list so they cannot drift apart.
_BERT_PRETRAINED_CHECKPOINTS = (
    ("bert-base-uncased", True),
    ("bert-large-uncased", True),
    ("bert-base-cased", False),
    ("bert-large-cased", False),
    ("bert-base-multilingual-uncased", True),
    ("bert-base-multilingual-cased", False),
    ("bert-base-chinese", False),
    ("bert-base-german-cased", False),
    ("bert-large-uncased-whole-word-masking", True),
    ("bert-large-cased-whole-word-masking", False),
    ("bert-large-uncased-whole-word-masking-finetuned-squad", True),
    ("bert-large-cased-whole-word-masking-finetuned-squad", False),
    ("bert-base-cased-finetuned-mrpc", False),
    ("bert-base-german-dbmdz-cased", False),
    ("bert-base-german-dbmdz-uncased", True),
    ("TurkuNLP/bert-base-finnish-cased-v1", False),
    ("TurkuNLP/bert-base-finnish-uncased-v1", True),
    ("wietsedv/bert-base-dutch-cased", False),
)

# Every vocabulary is hosted at the same URL pattern on the Hugging Face hub.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        name: f"https://huggingface.co/{name}/resolve/main/vocab.txt"
        for name, _ in _BERT_PRETRAINED_CHECKPOINTS
    }
}

# All listed checkpoints have 512 position embeddings.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    name: 512 for name, _ in _BERT_PRETRAINED_CHECKPOINTS
}

PRETRAINED_INIT_CONFIGURATION = {
    name: {"do_lower_case": lower_case} for name, lower_case in _BERT_PRETRAINED_CHECKPOINTS
}
import collections
import unicodedata
from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
def load_vocab(vocab_file):
    """Read a vocabulary file into an ordered ``token -> index`` mapping.

    Each line holds one token; its zero-based line number is its index. Only the trailing
    newline is removed, so any other whitespace on the line is kept as part of the token.
    """
    token_to_index = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as handle:
        for line_no, raw_line in enumerate(handle):
            token_to_index[raw_line.rstrip("\n")] = line_no
    return token_to_index
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text.

    Returns the whitespace-separated pieces of ``text``; empty or all-whitespace input gives ``[]``.
    """
    # str.split() with no separator already ignores leading/trailing whitespace
    # and returns [] for empty or blank strings.
    return text.split()
class BasicTokenizer(object):
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (:obj:`Iterable`, `optional`):
            Tokens that are passed through untouched: no lower-casing, accent stripping or punctuation splitting.
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to surround each CJK ideograph with spaces so it becomes its own token.
            This should likely be deactivated for Japanese (see this `issue
            <https://github.com/huggingface/transformers/issues/328>`__).
        strip_accents: (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If left as ``None``, accents are stripped exactly when
            :obj:`do_lower_case` is true (as in the original BERT).
    """

    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split) if never_split is not None else set()
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents

    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
        WordPieceTokenizer.

        Args:
            **never_split**: (`optional`) list of str
                Extra tokens to protect for this call, merged with the instance's ``never_split`` set.
        """
        protected = self.never_split.union(set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)
        # CJK splitting was added in November 2018 for the multilingual and Chinese models. It is applied to
        # English models too; this is harmless since they contain essentially no Chinese data.
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        pieces = []
        for word in whitespace_tokenize(text):
            if word not in protected:
                if self.do_lower_case:
                    word = word.lower()
                    # With lower-casing on, accents are stripped unless explicitly disabled.
                    if self.strip_accents is not False:
                        word = self._run_strip_accents(word)
                elif self.strip_accents:
                    word = self._run_strip_accents(word)
            pieces.extend(self._run_split_on_punc(word, protected))
        return whitespace_tokenize(" ".join(pieces))

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text (NFD-decompose, then drop combining marks, category Mn)."""
        decomposed = unicodedata.normalize("NFD", text)
        return "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text: every punctuation char becomes its own piece."""
        if never_split is not None and text in never_split:
            return [text]
        groups = []
        begin_new = True
        for ch in text:
            if _is_punctuation(ch):
                groups.append([ch])
                begin_new = True
            elif begin_new:
                groups.append([ch])
                begin_new = False
            else:
                groups[-1].append(ch)
        return ["".join(group) for group in groups]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        return "".join(f" {ch} " if self._is_chinese_char(ord(ch)) else ch for ch in text)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        # The CJK block is NOT all Japanese and Korean characters: Hangul, Hiragana and Katakana
        # live in other blocks and are handled like any other space-separated script.
        cjk_ranges = (
            (0x4E00, 0x9FFF),
            (0x3400, 0x4DBF),
            (0x20000, 0x2A6DF),
            (0x2A700, 0x2B73F),
            (0x2B740, 0x2B81F),
            (0x2B820, 0x2CEAF),
            (0xF900, 0xFAFF),
            (0x2F800, 0x2FA1F),
        )
        return any(low <= cp <= high for low, high in cjk_ranges)

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        kept = []
        for ch in text:
            # Drop NUL, the replacement character and control characters.
            if ord(ch) in (0, 0xFFFD) or _is_control(ch):
                continue
            kept.append(" " if _is_whitespace(ch) else ch)
        return "".join(kept)
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization (greedy longest-match-first against a fixed vocabulary)."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces using a greedy longest-match-first search of the vocabulary.

        Every piece after the first one of a word is looked up with a ``"##"`` prefix. A word that is longer than
        ``max_input_chars_per_word`` characters, or that cannot be fully covered by vocabulary pieces, becomes a
        single ``unk_token``.

        For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]` (given a
        vocabulary containing those pieces).

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """
        pieces = []
        for word in whitespace_tokenize(text):
            if len(word) > self.max_input_chars_per_word:
                pieces.append(self.unk_token)
                continue
            word_pieces = []
            start = 0
            covered = True
            while covered and start < len(word):
                # Try the longest remaining substring first, shrinking it from the right.
                for end in range(len(word), start, -1):
                    candidate = word[start:end] if start == 0 else "##" + word[start:end]
                    if candidate in self.vocab:
                        word_pieces.append(candidate)
                        start = end
                        break
                else:
                    # No prefix of the remainder is in the vocabulary.
                    covered = False
            if covered:
                pieces.extend(word_pieces)
            else:
                pieces.append(self.unk_token)
        return pieces
class BertTokenizer(PreTrainedTokenizer):
r"""
Construct a BERT tokenizer. Based on WordPiece.
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
Users should refer to this superclass for more information regarding those methods.
Args:
vocab_file (:obj:`str`):
File containing the vocabulary.
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to lowercase the input when tokenizing.
do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to do basic tokenization before WordPiece.
never_split (:obj:`Iterable`, `optional`):
Collection of tokens which will never be split during tokenization. Only has an effect when
:obj:`do_basic_tokenize=True`
unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
sequence classification or for a text and a question for question answering. It is also used as the last
token of a sequence built with special tokens.
pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
The token used for padding, for example when batching sequences of different lengths.
cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
The classifier token which is used when doing sequence classification (classification of the whole sequence
instead of per-token classification). It is the first token of the sequence when built with special tokens.
mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
The token used for masking values. This is the token used when training this model with masked language
modeling. This is the token which the model will try to predict.
tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to tokenize Chinese characters.
This should likely be deactivated for Japanese (see this `issue
<https://github.com/huggingface/transformers/issues/328>`__).
strip_accents: (:obj:`bool`, `optional`):
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
value for :obj:`lowercase` (as in the original BERT).
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(
    self,
    vocab_file,
    do_lower_case=True,
    do_basic_tokenize=True,
    never_split=None,
    unk_token="[UNK]",
    sep_token="[SEP]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    mask_token="[MASK]",
    tokenize_chinese_chars=True,
    strip_accents=None,
    **kwargs
):
    """Load ``vocab_file`` and build the basic and WordPiece sub-tokenizers.

    Raises:
        ValueError: if ``vocab_file`` is not an existing file.
    """
    if not os.path.isfile(vocab_file):
        raise ValueError(
            "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
            "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                vocab_file)
        )
    # The vocabulary is loaded BEFORE the base-class constructor runs.
    # NOTE(review): presumably because the transformers base __init__ may call vocab-dependent
    # methods (e.g. get_vocab); keep this ordering unless verified otherwise.
    self.vocab = load_vocab(vocab_file)
    super().__init__(
        do_lower_case=do_lower_case,
        do_basic_tokenize=do_basic_tokenize,
        never_split=never_split,
        unk_token=unk_token,
        sep_token=sep_token,
        pad_token=pad_token,
        cls_token=cls_token,
        mask_token=mask_token,
        tokenize_chinese_chars=tokenize_chinese_chars,
        strip_accents=strip_accents,
        **kwargs,
    )
    # Reverse mapping id -> token, used by _convert_id_to_token.
    self.ids_to_tokens = collections.OrderedDict(
        [(ids, tok) for tok, ids in self.vocab.items()])
    self.do_basic_tokenize = do_basic_tokenize
    # basic_tokenizer is only created when basic tokenization is enabled.
    if do_basic_tokenize:
        self.basic_tokenizer = BasicTokenizer(
            do_lower_case=do_lower_case,
            never_split=never_split,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
        )
    # self.unk_token is not assigned in this method; presumably the base class sets it from
    # the unk_token argument -- confirm for the pinned transformers version.
    self.wordpiece_tokenizer = WordpieceTokenizer(
        vocab=self.vocab, unk_token=self.unk_token)
@property
def do_lower_case(self):
    """Whether text is lower-cased, as configured on the basic tokenizer.

    Only available when the tokenizer was built with ``do_basic_tokenize=True``;
    otherwise ``basic_tokenizer`` does not exist.
    """
    basic = self.basic_tokenizer
    return basic.do_lower_case
@property
def vocab_size(self):
    """Number of entries in the base vocabulary (added tokens excluded)."""
    base_vocab = self.vocab
    return len(base_vocab)
def get_vocab(self):
    """Return a new dict of the base vocabulary merged with the added tokens.

    Added tokens override base entries that have the same key.
    """
    merged = dict(self.vocab)
    merged.update(**self.added_tokens_encoder)
    return merged
def _tokenize(self, text):
split_tokens = []
if self.do_basic_tokenize:
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
# If the token is part of the never_split set
if token in self.basic_tokenizer.never_split:
split_tokens.append(token)
else:
split_tokens += self.wordpiece_tokenizer.tokenize(token)
else:
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
return self.vocab.get(token, self.vocab.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.ids_to_tokens.get(index, self.unk_token)
def convert_tokens_to_string(self, tokens):
    """Join tokens with spaces, gluing '##' continuation pieces onto the previous token."""
    spaced = " ".join(tokens)
    glued = spaced.replace(" ##", "")
    return glued.strip()
def build_inputs_with_special_tokens(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Add special tokens to one sequence or a pair of sequences.

    Layouts produced:
    - single sequence: ``[CLS] X`` (note: no trailing ``[SEP]``)
    - pair of sequences: ``[CLS] A [SEP] B [SEP]``

    Args:
        token_ids_0 (:obj:`List[int]`):
            IDs of the first sequence.
        token_ids_1 (:obj:`List[int]`, `optional`):
            IDs of the optional second sequence.

    Returns:
        :obj:`List[int]`: The input IDs with the special tokens inserted.
    """
    head = [self.cls_token_id]
    if token_ids_1 is None:
        return head + token_ids_0
    separator = [self.sep_token_id]
    return head + token_ids_0 + separator + token_ids_1 + separator
def get_special_tokens_mask(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
    """
    Return a mask marking special-token positions (1) versus sequence tokens (0).

    The mask mirrors the layout of ``build_inputs_with_special_tokens``:
    ``[CLS] X`` for a single sequence (no trailing ``[SEP]``) and
    ``[CLS] A [SEP] B [SEP]`` for a pair.

    Args:
        token_ids_0 (:obj:`List[int]`):
            List of IDs.
        token_ids_1 (:obj:`List[int]`, `optional`):
            Optional second list of IDs for sequence pairs.
        already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether ``token_ids_0`` already contains the special tokens; if so,
            positions holding the [SEP] or [CLS] id are marked.

    Returns:
        :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.

    Raises:
        ValueError: If ``already_has_special_tokens`` is True and ``token_ids_1`` is given.
    """
    if already_has_special_tokens:
        if token_ids_1 is not None:
            raise ValueError(
                "You should not supply a second sequence if the provided sequence of "
                "ids is already formatted with special tokens for the model."
            )
        special_ids = (self.sep_token_id, self.cls_token_id)
        return [1 if x in special_ids else 0 for x in token_ids_0]
    if token_ids_1 is not None:
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
    # Single sequences are built as ``[CLS] X`` (no trailing [SEP]), so the mask
    # must have the same length as those input ids: one leading 1 only.
    return [1] + ([0] * len(token_ids_0))
def create_token_type_ids_from_sequences(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Create token type IDs aligned with ``build_inputs_with_special_tokens``.

    For a pair the mask has the following format:

    ::

        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |

    A single sequence is laid out as ``[CLS] X`` (no trailing ``[SEP]``), so only
    ``len(token_ids_0) + 1`` zeros are returned in that case.

    Args:
        token_ids_0 (:obj:`List[int]`):
            List of IDs.
        token_ids_1 (:obj:`List[int]`, `optional`):
            Optional second list of IDs for sequence pairs.

    Returns:
        :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
        sequence(s).
    """
    cls = [self.cls_token_id]
    sep = [self.sep_token_id]
    if token_ids_1 is None:
        # Match the single-sequence layout [CLS] X so lengths agree with input_ids.
        return len(cls + token_ids_0) * [0]
    return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """Write the vocabulary, one token per line in id order, and return the file path.

    If ``save_directory`` is a directory, the file is created inside it under the
    standard vocab filename. Otherwise ``save_directory`` itself is used as the file
    path. In both cases an optional ``filename_prefix`` plus "-" is prepended. A
    warning is printed if the ids are not consecutive.
    """
    prefix = filename_prefix + "-" if filename_prefix else ""
    if not os.path.isdir(save_directory):
        vocab_file = prefix + save_directory
    else:
        vocab_file = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES["vocab_file"])

    expected_index = 0
    ordered = sorted(self.vocab.items(), key=lambda pair: pair[1])
    with open(vocab_file, "w", encoding="utf-8") as out:
        for token, token_index in ordered:
            if token_index != expected_index:
                print(
                    "Saving vocabulary to {}: vocabulary indices are not consecutive."
                    " Please check that the vocabulary is not corrupted!".format(vocab_file)
                )
                expected_index = token_index
            out.write(token + "\n")
            expected_index += 1
    return (vocab_file,)
from huggingface_hub import PyTorchModelHubMixin
def _frame_from_video(video):
while video.isOpened():
success, frame = video.read()
if success:
yield frame
else:
break
# Per-channel RGB mean and std (the values of the commonly used ImageNet statistics),
# shaped (1, 1, 3) so they broadcast over an (H, W, 3) frame in normalize().
v_mean = np.array([0.485, 0.456, 0.406]).reshape(1, 1, -1)
v_std = np.array([0.229, 0.224, 0.225]).reshape(1, 1, -1)
def normalize(data):
    """Scale 0-255 pixel values to [0, 1], then standardise with the module-level v_mean / v_std."""
    unit_range = data / 255.0
    return (unit_range - v_mean) / v_std
def frames2tensor(vid_list, fnum=8, target_size=(224, 224), device=torch.device('cuda')):
    """Sample ``fnum`` frames and pack them into a normalised float tensor.

    Frames are taken every ``len(vid_list) // fnum`` positions (first ``fnum`` of
    those), flipped from BGR to RGB channel order, resized to ``target_size`` and
    normalised. The result is laid out as (1, T, C, H, W) on ``device``.
    """
    assert(len(vid_list) >= fnum)
    stride = len(vid_list) // fnum
    picked = vid_list[::stride][:fnum]
    processed = [normalize(cv2.resize(frame[:, :, ::-1], target_size)) for frame in picked]
    clip = np.stack(processed, axis=0)[np.newaxis]  # (1, T, H, W, C)
    clip = clip.transpose(0, 1, 4, 2, 3)  # -> (1, T, C, H, W)
    return torch.from_numpy(clip).to(device, non_blocking=True).float()
def vid2tensor(path: str, fnum: int=8, target_size: tuple=(224, 224), device=torch.device('cuda')):
    """Decode the video at ``path`` with OpenCV and convert it via ``frames2tensor``.

    All frames are read into memory first, then sampled by ``frames2tensor``.
    The capture handle is released even if decoding fails.
    """
    video = cv2.VideoCapture(path)
    try:
        frames = list(_frame_from_video(video))
    finally:
        # Free the underlying file/decoder handle; it was previously never released.
        video.release()
    return frames2tensor(frames, fnum, target_size, device)
def get_text_feat_dict(texts, clip, text_feat_d=None):
    """Compute a feature for every text and store it in a dict keyed by the text.

    Args:
        texts: Iterable of strings.
        clip: Object exposing ``get_txt_feat(text)``.
        text_feat_d: Optional dict to fill in place; a fresh dict is used when omitted.

    Returns:
        The dict mapping each text to ``clip.get_txt_feat(text)``.
    """
    # A mutable ``{}`` default would be shared by every call and leak features
    # from earlier calls into later results.
    if text_feat_d is None:
        text_feat_d = {}
    for t in texts:
        text_feat_d[t] = clip.get_txt_feat(t)
    return text_feat_d
def get_vid_feat(frames, vlm):
    """Return video features for ``frames`` from the model ``vlm``.

    Uses ``vlm.get_vid_features`` when the model provides it (the name this helper
    originally called). Otherwise it uses ``vlm.get_vid_feat``, which is the method
    that ``InternVideo2_Stage2`` actually defines.
    """
    extractor = getattr(vlm, "get_vid_features", None)
    if extractor is None:
        extractor = vlm.get_vid_feat
    return extractor(frames)
def retrieve_text(frames,
                  texts,
                  model,
                  topk:int=5,
                  device=torch.device('cuda')):
    """Rank candidate ``texts`` against a video given as raw frames.

    Args:
        frames: List of decoded frames, passed to ``frames2tensor``.
        texts: Candidate strings.
        model: Video-text model; it is moved to ``device``.
        topk: Maximum number of texts to return; clamped to ``len(texts)``.
        device: Device to run on.

    Returns:
        (ret_texts, probs): the best-matching texts in rank order and a numpy
        array of their probabilities.
    """
    vlm = model.to(device)
    config = vlm.config
    fn = config.num_frames
    size_t = config.size_t
    frames_tensor = frames2tensor(frames, fnum=fn, target_size=(size_t, size_t), device=device)
    vid_feat = vlm.get_vid_feat(frames_tensor)

    # Pass an explicit fresh dict so no features are shared across calls.
    text_feat_d = get_text_feat_dict(texts, vlm, {})
    text_feats = [text_feat_d[t] for t in texts]
    text_feats_tensor = torch.cat(text_feats, 0)

    # Tensor.topk raises when k exceeds the number of candidates.
    k = min(topk, len(texts))
    probs, idxs = vlm.predict_label(vid_feat, text_feats_tensor, top=k)

    ret_texts = [texts[i] for i in idxs.long().numpy()[0].tolist()]
    return ret_texts, probs.float().numpy()[0]
def setup_internvideo2(config):
    """Build an InternVideo2_Stage2 model, optionally load a checkpoint, and compile it.

    Args:
        config: Model config. This function reads ``device`` and ``pretrained_path``,
            plus ``origin_num_frames`` when a checkpoint is loaded.

    Returns:
        The ``torch.compile``-wrapped model, cast to float32 and in eval mode.
    """
    model = InternVideo2_Stage2(config=config, is_pretrain=True)
    torch.set_float32_matmul_precision('high')
    model = model.to(torch.device(config.device))

    pretrained_path = config.pretrained_path
    if (pretrained_path.strip() and os.path.isfile(pretrained_path)) or "s3://" in pretrained_path:
        ckpt = torch.load(pretrained_path, map_location="cpu")
        # Weights may be nested under "model" or "module" (DeepSpeed stage 1),
        # or the checkpoint may be a bare state dict.
        ckpt_keys = ckpt.keys() if hasattr(ckpt, "keys") else ()
        if "model" in ckpt_keys:
            state_dict = ckpt["model"]
        elif "module" in ckpt_keys:
            state_dict = ckpt["module"]
        else:
            state_dict = ckpt

        num_keys = len(state_dict)
        interpolate_pos_embed_internvideo2_new(state_dict, model.vision_encoder, orig_t_size=config.origin_num_frames)
        assert num_keys == len(state_dict), state_dict.keys()

        # Load into the plain module *before* compiling: the torch.compile wrapper
        # exposes its parameters under an "_orig_mod." prefix, so with strict=False
        # a checkpoint loaded into it could match nothing and be silently ignored.
        msg = model.load_state_dict(state_dict, strict=False)
        print(f"Loaded {pretrained_path}: {len(msg.missing_keys)} missing keys, "
              f"{len(msg.unexpected_keys)} unexpected keys")

    model = torch.compile(model)
    model = model.to(torch.float32)
    return model.eval()
class DictToClass:
    """Expose the keys of a (possibly nested) dict as object attributes.

    Nested dicts become ``DictToClass`` instances, as do dicts found directly
    inside a list value. All other values are stored unchanged. Keys are converted
    with ``str`` before being used as attribute names.
    """

    def __init__(self, data):
        for raw_key, raw_value in data.items():
            if isinstance(raw_value, dict):
                converted = DictToClass(raw_value)
            elif isinstance(raw_value, list):
                converted = [DictToClass(elem) if isinstance(elem, dict) else elem
                             for elem in raw_value]
            else:
                converted = raw_value
            setattr(self, str(raw_key), converted)

    def __repr__(self):
        """Debug-friendly representation listing every attribute."""
        parts = [f"{name}={val!r}" for name, val in self.__dict__.items()]
        return "{}({})".format(type(self).__name__, ", ".join(parts))
def instance2dict(obj):
    """Recursively convert an object and its nested attributes into plain Python data.

    - ``str``/``int``/``float``/``bool``/``None`` are returned unchanged.
    - ``dict`` values are converted recursively; keys are kept as-is.
    - ``list``/``tuple``/``set`` elements are converted recursively and the
      container type is preserved. Named tuples are rebuilt field by field.
    - Objects with a ``__dict__`` become a dict of their public attributes
      (names starting with ``_`` are skipped).
    - Anything else is converted with ``str()``.
    """
    if isinstance(obj, (str, int, float, bool, type(None))):
        return obj
    if isinstance(obj, dict):
        return {k: instance2dict(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        items = [instance2dict(item) for item in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple constructors take the fields positionally, not one iterable.
            return type(obj)(*items)
        return type(obj)(items)
    if hasattr(obj, '__dict__'):
        # Class instance: keep only public attributes.
        return {
            key: instance2dict(value)
            for key, value in obj.__dict__.items()
            if not key.startswith('_')
        }
    # Fallback for other, non-serializable types (could raise instead if preferred).
    return str(obj)
class InternVideo2_Stage2_Config(PretrainedConfig):
    """Configuration for InternVideo2_Stage2.

    Adds no fields of its own; every keyword argument is forwarded unchanged to
    ``PretrainedConfig``.
    """
    _auto_class = 'AutoConfig'

    def __init__(self, **kwargs):
        super(InternVideo2_Stage2_Config, self).__init__(**kwargs)
class InternVideo2_Stage2(
    PreTrainedModel,
):
    """InternVideo2 stage-2 video-text model.

    Holds a vision encoder and a BERT-style text encoder, both built from
    ``config.model``. Linear projections ``vision_proj`` / ``text_proj`` map their
    pooled outputs into a shared ``embed_dim`` space.
    """
    _auto_class="AutoModel"
    config_class=InternVideo2_Stage2_Config

    def __init__(self,
                 config: InternVideo2_Stage2_Config,
                 is_pretrain: bool=True):
        super(InternVideo2_Stage2, self).__init__(config)
        # Convert the config to a nested attribute view, so sub-configs can be read
        # as e.g. self._config.model.vision_encoder.<field>.
        config = config.to_dict()
        self._config = DictToClass(config) if isinstance(config, dict) else config
        self.tokenizer = BertTokenizer.from_pretrained(self._config.model.text_encoder.pretrained, local_files_only=True, use_safetensors=True)
        self.is_pretrain = is_pretrain
        self.vision_width = self._config.model.vision_encoder.clip_embed_dim
        self.text_width = self._config.model.text_encoder.d_model
        self.embed_dim = self._config.model.embed_dim

        # create modules.
        self.vision_encoder = self.build_vision_encoder()
        self.text_encoder = self.build_text_encoder()

        # Project both modalities into the shared embedding space.
        self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)
        self.text_proj = nn.Linear(self.text_width, self.embed_dim)

    def freeze_vision(self):
        """Freeze the vision encoder (disable gradients for all of its parameters)."""
        for p in self.vision_encoder.parameters():
            p.requires_grad = False

    def freeze_text(self):
        """Freeze the text encoder (disable gradients for all of its parameters)."""
        for p in self.text_encoder.parameters():
            p.requires_grad = False

    @property
    def dtype(self):
        """Dtype of the vision patch-embedding weights; encode_vision casts its input to it."""
        return self.vision_encoder.patch_embed.proj.weight.dtype

    def encode_vision(self,
                      image: torch.Tensor,
                      test: bool=False):
        """Encode images / videos as features.

        Args:
            image (torch.Tensor): Input frames laid out as [B,T,C,H,W]; ``T == 1``
                is treated as a single image.
            test (bool): If True, run the encoder without a mask and return only
                the two embedding tensors.

        Returns: tuple.
            With ``test=True``:
            - vision_embeds (torch.Tensor): The output features.
            - pooled_vision_embeds (torch.Tensor): The pooled output features.
            With ``test=False``, the four vision-encoder outputs
            (vision_embeds, pooled_vision_embeds, student_output, student_output_final)
            followed by the two teacher targets returned by ``encode_teacher``
            (targets_clip_middle_vis, targets_clip_final_vis).
        """
        T = image.shape[1]
        use_image = T == 1
        image = image.permute(0, 2, 1, 3, 4).to(self.dtype)  # [B,T,C,H,W] -> [B,C,T,H,W]

        if test:
            vision_embeds, pooled_vision_embeds, _, _ = self.vision_encoder(
                image, None, use_image)
            return vision_embeds, pooled_vision_embeds

        # NOTE(review): `encode_teacher` is not defined on this class, so this
        # training path only works if it is provided elsewhere — confirm.
        mask, targets_clip_middle_vis, targets_clip_final_vis = self.encode_teacher(image)
        vision_embeds, pooled_vision_embeds, student_output, student_output_final = self.vision_encoder(
            image, mask, use_image)
        return vision_embeds, pooled_vision_embeds, student_output, student_output_final, targets_clip_middle_vis, targets_clip_final_vis

    def encode_text(self,
                    text: dict):
        """Encode tokenized text.

        Args:
            text (dict): The output of huggingface's `PreTrainedTokenizer`. contains keys:
                - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L].
                - attention_mask (torch.Tensor): The mask indicate padded tokens. Shape: [B,L]. 0 is padded token.
                - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__".

        Returns: tuple.
            - text_embeds (torch.Tensor): The features of all tokens. Shape: [B,L,C].
            - pooled_text_embeds (torch.Tensor): The first-token features. Shape: [B,C].
        """
        text_output = self.get_text_encoder()(
            text.input_ids,
            attention_mask=text.attention_mask,
            return_dict=True,
            mode="text",
        )
        text_embeds = text_output.last_hidden_state
        pooled_text_embeds = text_embeds[:, 0]
        return text_embeds, pooled_text_embeds

    def build_vision_encoder(self):
        """Build the vision encoder and record the masking parameters on ``self``.

        Returns: nn.Module. The vision encoder.

        Raises:
            ValueError: If ``config.model.vision_encoder.name`` is not supported.
        """
        encoder_name = self._config.model.vision_encoder.name
        if encoder_name == 'pretrain_internvideo2_1b_patch14_224':
            vision_encoder = pretrain_internvideo2_1b_patch14_224(self._config.model)
        elif encoder_name == 'pretrain_internvideo2_6b_patch14_224':
            vision_encoder = pretrain_internvideo2_6b_patch14_224(self._config.model)
        else:
            raise ValueError(f"Not implemented: {encoder_name}")

        # parameters for mask
        img_size = self._config.model.vision_encoder.img_size
        num_frames = self._config.model.vision_encoder.num_frames
        tubelet_size = self._config.model.vision_encoder.tubelet_size
        patch_size = self._config.model.vision_encoder.patch_size
        self.clip_img_size = self._config.model.vision_encoder.clip_input_resolution
        self.video_mask_type = self._config.model.vision_encoder.video_mask_type
        # Token grid (time, height, width) for videos and single images.
        self.video_window_size = (num_frames // tubelet_size, img_size // patch_size, img_size // patch_size)
        self.video_mask_ratio = self._config.model.vision_encoder.video_mask_ratio
        self.image_mask_type = self._config.model.vision_encoder.image_mask_type
        self.image_window_size = (1, img_size // patch_size, img_size // patch_size)
        self.image_mask_ratio = self._config.model.vision_encoder.image_mask_ratio

        return vision_encoder

    def build_text_encoder(self):
        """Build the text encoder (and possibly a video-to-text multimodal fusion encoder).

        Returns: nn.Module. The text encoder.

        Raises:
            ValueError: If ``config.model.text_encoder.name`` does not contain "bert".
        """
        encoder_name = self._config.model.text_encoder.name

        if "bert" in encoder_name:
            text_encoder = build_bert(
                self._config.model,
                self.is_pretrain,
                self._config.gradient_checkpointing,
            )
        else:
            raise ValueError(f"Not implemented: {encoder_name}")

        return text_encoder

    def get_text_encoder(self):
        """Return the text encoder (its ``.bert`` submodule if present), used for text and cross-modal encoding."""
        encoder = self.text_encoder
        return encoder.bert if hasattr(encoder, "bert") else encoder

    def get_vid_feat(self,
                     frames: torch.Tensor):
        """Get normalised video features for the given frames.

        Args:
            frames (torch.Tensor): The input frames. Shape: [B,T,C,H,W].

        Returns:
            torch.Tensor: The pooled vision features from ``encode_vision``, projected
            by ``vision_proj`` and L2-normalised along the last dimension.
        """
        with torch.no_grad():
            _, vfeat = self.encode_vision(frames, test=True)
            vfeat = self.vision_proj(vfeat)
            vfeat /= vfeat.norm(dim=-1, keepdim=True)
        return vfeat

    def get_txt_feat(self,
                     text: str):
        """Get the normalised text feature for ``text``.

        The text is tokenized (padded / truncated to ``config.max_txt_l``) and encoded.
        The first-token feature is projected by ``text_proj`` and L2-normalised.
        """
        with torch.no_grad():
            # Send the tokens to the device the weights actually live on, so the model
            # keeps working after `.to(other_device)` (config.device may be stale).
            device = self.text_proj.weight.device
            text = self.tokenizer(
                text,
                padding="max_length",
                truncation=True,
                max_length=self._config.max_txt_l,
                return_tensors="pt",).to(device)
            _, tfeat = self.encode_text(text)
            tfeat = self.text_proj(tfeat)
            tfeat /= tfeat.norm(dim=-1, keepdim=True)
        return tfeat

    def predict_label(self,
                      vid_feat: torch.Tensor,
                      txt_feat: torch.Tensor,
                      top: int=5):
        """Return the top-``top`` text probabilities and indices for each video feature.

        Similarities ``vid_feat @ txt_feat.T`` are scaled by 100 and softmaxed over
        the texts. ``top`` is clamped to the number of texts. Both returned tensors
        are on the CPU.
        """
        label_probs = (100.0 * vid_feat @ txt_feat.T).softmax(dim=-1)
        # Tensor.topk raises if k exceeds the size of the ranked dimension.
        k = min(top, label_probs.shape[-1])
        top_probs, top_labels = label_probs.float().cpu().topk(k, dim=-1)
        return top_probs, top_labels