diff --git "a/modeling_internvideo2.py" "b/modeling_internvideo2.py" new file mode 100644--- /dev/null +++ "b/modeling_internvideo2.py" @@ -0,0 +1,3346 @@ +import numpy as np +import cv2 +import os +import math + +import torch +from torch import nn + +import torch.nn.functional as F +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +import torch.utils.checkpoint as checkpoint +from functools import partial +from einops import rearrange + +try: + from flash_attn.modules.mlp import FusedMLP +except: + print(f'FusedMLP of flash_attn is not installed!!!') + +try: + from flash_attn.ops.rms_norm import DropoutAddRMSNorm +except: + print(f'DropoutAddRMSNorm of flash_attn is not installed!!!') + +from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func +from flash_attn.bert_padding import unpad_input, pad_input + + +class FlashAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + + def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None): + super().__init__() + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None, + max_s=None, need_weights=False): + """Implements the multihead softmax attention. + Arguments + --------- + qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None + if unpadded: (nnz, 3, h, d) + key_padding_mask: a bool tensor of shape (B, S) + """ + assert not need_weights + assert qkv.dtype in [torch.float16, torch.bfloat16] + assert qkv.is_cuda + + if cu_seqlens is None: + batch_size = qkv.shape[0] + seqlen = qkv.shape[1] + if key_padding_mask is None: + qkv = rearrange(qkv, 'b s ... -> (b s) ...') + max_s = seqlen + cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, + device=qkv.device) + output = flash_attn_varlen_qkvpacked_func( + qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + output = rearrange(output, '(b s) ... 
-> b s ...', b=batch_size) + else: + nheads = qkv.shape[-2] + x = rearrange(qkv, 'b s three h d -> b s (three h d)') + x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask) + x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) + output_unpad = flash_attn_varlen_qkvpacked_func( + x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), + indices, batch_size, seqlen), + 'b s (h d) -> b s h d', h=nheads) + else: + assert max_s is not None + output = flash_attn_varlen_qkvpacked_func( + qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + + return output, None + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate( + [np.zeros([1, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False): + """ + t_size: int of the temporal size + return: + pos_embed: [t_size, embed_dim] or [1+t_size, embed_dim] (w/ or w/o cls_token) + """ + grid_t = np.arange(t_size, dtype=np.float32) + pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t) + if cls_token: + pos_embed = np.concatenate( + [np.zeros([1, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[0] + ) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[1] + ) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def interpolate_pos_embed(checkpoint_model, model, orig_t_size=4, pos_name='vision_encoder.pos_embed'): + if pos_name in checkpoint_model: + pos_embed_checkpoint = checkpoint_model[pos_name] + embedding_size = pos_embed_checkpoint.shape[-1] # channel dim + num_patches = model.patch_embed.num_patches # + 
num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1 + + # we use 4 frames for pretraining + new_t_size = model.T + # height (== width) for the checkpoint position embedding + orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5) + # height (== width) for the new position embedding + new_size = int((num_patches // (new_t_size))** 0.5) + + # class_token and dist_token are kept unchanged + if orig_t_size != new_t_size: + print(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})") + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1) + pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size) + pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear') + pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size) + pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model[pos_name] = new_pos_embed + pos_embed_checkpoint = new_pos_embed + + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})") + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + # B, L, C -> BT, H, W, C -> BT, C, H, W + pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size) + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C + pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size) + pos_tokens = pos_tokens.flatten(1, 3) # B, L, C + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model[pos_name] = new_pos_embed + + +def interpolate_pos_embed_internvideo2(checkpoint_model, model, orig_t_size = 8): + # interpolate position embedding + for pos_name in ['pos_embed', 'clip_pos_embed']: + if pos_name in checkpoint_model: + pos_embed_checkpoint = checkpoint_model[pos_name] + embedding_size = pos_embed_checkpoint.shape[-1] # channel dim + num_patches = model.patch_embed.num_patches # + num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1 + + # we use 8 frames for pretraining + # new_t_size = args.num_frames * args.num_segments // model.patch_embed.tubelet_size + new_t_size = model.num_frames // model.tubelet_size + # height (== width) for the checkpoint position embedding + orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5) + # height (== width) for the new position embedding + new_size = int((num_patches // (new_t_size))** 0.5) + + # class_token and dist_token are kept unchanged + if orig_t_size != new_t_size: + print(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})") + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + # B, L, C -> B, T, HW, C -> BHW, C, 
T (B = 1) + pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size) + pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear') + pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size) + pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model[pos_name] = new_pos_embed + pos_embed_checkpoint = new_pos_embed + + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})") + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + # B, L, C -> BT, H, W, C -> BT, C, H, W + pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size) + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C + pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size) + pos_tokens = pos_tokens.flatten(1, 3) # B, L, C + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model[pos_name] = new_pos_embed + + if 'pos_embed_spatial' in checkpoint_model or 'pos_embed_temporal' in checkpoint_model: + raise NotImplementedError + + +def interpolate_pos_embed_internvideo2_new(checkpoint_model, model, orig_t_size = 8): + pos_names = [] + for k in checkpoint_model.keys(): + if ('pos_embed' in k or 'clip_pos_embed' in k) and 'img_pos_embed' not in k: + pos_names.append(k) + + print(f"pos names list for interpolating: {pos_names}") + + assert len(pos_names) > 0, checkpoint_model.keys() + + if 'pos_embed_spatial' in checkpoint_model.keys() or 'pos_embed_temporal' in checkpoint_model.keys(): + raise NotImplementedError + + # interpolate position embedding + for pos_name in pos_names: + + pos_embed_checkpoint = checkpoint_model[pos_name] + embedding_size = pos_embed_checkpoint.shape[-1] # channel dim + num_patches = model.patch_embed.num_patches # + num_extra_tokens = model.pos_embed.shape[-2] - num_patches # 0/1 + + # we use 8 frames for pretraining + # new_t_size = args.num_frames * args.num_segments // model.patch_embed.tubelet_size + new_t_size = model.num_frames // model.tubelet_size + # height (== width) for the checkpoint position embedding + orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens)//(orig_t_size)) ** 0.5) + # height (== width) for the new position embedding + new_size = int((num_patches // (new_t_size))** 0.5) + + # class_token and dist_token are kept unchanged + if orig_t_size != new_t_size: + print(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})") + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + # B, L, C -> B, T, HW, C -> BHW, C, T (B = 1) + pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size) + pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear') + pos_tokens = 
pos_tokens.view(1, -1, embedding_size, new_t_size) + pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model[pos_name] = new_pos_embed + pos_embed_checkpoint = new_pos_embed + + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})") + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + # B, L, C -> BT, H, W, C -> BT, C, H, W + pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size) + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C + pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size) + pos_tokens = pos_tokens.flatten(1, 3) # B, L, C + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model[pos_name] = new_pos_embed + + +def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False): + """ + grid_size: int of the grid height and width + t_size: int of the temporal size + return: + pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + assert embed_dim % 4 == 0 + embed_dim_spatial = embed_dim // 4 * 3 + embed_dim_temporal = embed_dim // 4 + + # spatial + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed_spatial = get_2d_sincos_pos_embed_from_grid( + embed_dim_spatial, grid + ) + + # temporal + grid_t = np.arange(t_size, dtype=np.float32) + pos_embed_temporal = get_1d_sincos_pos_embed_from_grid( + embed_dim_temporal, grid_t + ) + + # concate: [T, H, W] order + pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] + pos_embed_temporal = np.repeat( + pos_embed_temporal, grid_size**2, axis=1 + ) # [T, H*W, D // 4] + pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] + pos_embed_spatial = np.repeat( + pos_embed_spatial, t_size, axis=0 + ) # [T, H*W, D // 4 * 3] + + pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) + pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D] + + if cls_token: + pos_embed = np.concatenate( + [np.zeros([1, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class PatchEmbed(nn.Module): + """ 3D Image to Patch Embedding + """ + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, + num_frames=8, tubelet_size=1, norm_layer=None + ): + super().__init__() + img_size = 
to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = ( + num_frames // tubelet_size, + img_size[0] // patch_size[0], + img_size[1] // patch_size[1] + ) # (T, H, W) + self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2] + self.num_img_patches = self.grid_size[1] * self.grid_size[2] + + self.proj = nn.Conv3d( + in_channels=in_chans, out_channels=embed_dim, + kernel_size=(tubelet_size, patch_size[0], patch_size[1]), + stride=(tubelet_size, patch_size[0], patch_size[1]) + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + x = x.flatten(3).permute(0, 2, 3, 1) # B x C x T x HW => B x T x HW x C + x = self.norm(x) + return x + + +class CrossAttention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., attn_head_dim=None, out_dim=None): + super().__init__() + if out_dim is None: + out_dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + assert all_head_dim == dim + + self.q = nn.Linear(dim, all_head_dim, bias=False) + self.k = nn.Linear(dim, all_head_dim, bias=False) + self.v = nn.Linear(dim, all_head_dim, bias=False) + + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.k_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.k_bias = None + self.v_bias = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, out_dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, k=None, v=None): + B, N, C = x.shape + N_k = k.shape[1] + N_v = v.shape[1] + + q_bias, k_bias, v_bias = None, None, None + if self.q_bias is not None: + q_bias = self.q_bias + k_bias = self.k_bias + v_bias = self.v_bias + + q = F.linear(input=x, weight=self.q.weight, bias=q_bias) + q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim) + + k = F.linear(input=k, weight=self.k.weight, bias=k_bias) + k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) + + v = F.linear(input=v, weight=self.v.weight, bias=v_bias) + v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class AttentiveBlock(nn.Module): + + def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None): + super().__init__() + + self.norm1_q = norm_layer(dim) + self.norm1_k = norm_layer(dim) + self.norm1_v = norm_layer(dim) + self.cross_attn = CrossAttention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, + proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None): + x_q = self.norm1_q(x_q + pos_q) + x_k = self.norm1_k(x_kv + pos_k) + x_v = self.norm1_v(x_kv) + x = self.cross_attn(x_q, k=x_k, v=x_v) + + return x + + +class AttentionPoolingBlock(AttentiveBlock): + + def forward(self, x): + x_q = x.mean(1, keepdim=True) + x_kv, pos_q, pos_k = x, 0, 0 + x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None) + x = x.squeeze(1) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + self.force_fp32 = force_fp32 + + @torch.cuda.amp.autocast(enabled=False) + def forward(self, x): + if self.force_fp32: + output_type = x.dtype + out = x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float() + return out.to(dtype=output_type) + else: + out = x.mul_(self.gamma) if self.inplace else x * self.gamma + return out + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False, + causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.use_flash_attn = use_flash_attn + if use_flash_attn: + self.causal = causal + self.inner_attn = FlashAttention(attention_dropout=attn_drop) + + self.qk_normalization = qk_normalization + self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity() + self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity() + self.use_fused_rmsnorm = use_fused_rmsnorm + + def _naive_attn(self, x): + B, N, C = x.shape + # print(x.shape, torch.cuda.memory_allocated(), torch.cuda.memory_allocated()) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + if self.qk_normalization: + B_, H_, N_, D_ = q.shape + q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) + k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) + + attn = ((q * self.scale) @ k.transpose(-2, -1)) + # attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16 + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + # print(torch.cuda.memory_allocated(), torch.cuda.memory_allocated()) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def _flash_attn(self, x, key_padding_mask=None, need_weights=False): + + qkv = self.qkv(x) + qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads) + + if self.qk_normalization: + q, k, v = qkv.unbind(2) + if self.use_fused_rmsnorm: + q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape) + k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape) + else: + q = self.q_norm(q.flatten(-2, -1)).view(q.shape) + k = self.k_norm(k.flatten(-2, -1)).view(k.shape) + qkv = torch.stack([q, k, v], dim=2) + + context, _ = self.inner_attn( + qkv, 
key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal + ) + outs = self.proj(rearrange(context, "b s h d -> b s (h d)")) + outs = self.proj_drop(outs) + return outs + + def forward(self, x): + x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x) + return x + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, + bias=True, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class Block(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False, + fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False, + use_fused_rmsnorm=False): + super().__init__() + + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, + use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer, + qk_normalization=qk_normalization, + use_fused_rmsnorm=use_fused_rmsnorm) + self.ls1 = LayerScale(dim, init_values=init_values, + force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity() + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + if use_fused_mlp: + self.mlp = FusedMLP(in_features=dim, hidden_features=mlp_hidden_dim, heuristic=fused_mlp_heuristic) + else: + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.ls2 = LayerScale(dim, init_values=init_values, + force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.with_cp = with_cp + self.use_fused_rmsnorm = use_fused_rmsnorm + + def forward(self, x, residual=None): + + def _inner_forward(x, residual=None): + if self.use_fused_rmsnorm: + x, residual = self.norm1(x, residual) + x = self.drop_path1(self.ls1(self.attn(x))) + x, residual = self.norm2(x, residual) + x = self.drop_path2(self.ls2(self.mlp(x))) + return x, residual + else: + assert residual is None + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + if self.with_cp: + return checkpoint.checkpoint(_inner_forward, x, residual) + else: + return _inner_forward(x, residual=residual) + + +class Linear_Decoder(nn.Module): + def __init__(self, in_channels=1408, out_channels=3200, + norm_layer=nn.LayerNorm, clip_norm_type='l2'): + super().__init__() + self.clip_norm_type = clip_norm_type + + self.head = nn.Linear(in_channels, out_channels) + self.norm = norm_layer(out_channels) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + x = self.norm(self.head(x)) + + if self.clip_norm_type == 'l2': + x = x / x.norm(dim=-1, keepdim=True) + elif self.clip_norm_type == 'none': + pass + else: + raise NotImplementedError + + return x + + +class PretrainInternVideo2(nn.Module): + def __init__( + self, + in_chans: int = 3, + patch_size: int = 14, + img_size: int = 224, + qkv_bias: bool = False, + drop_path_rate: float = 0.25, + embed_dim: int = 1408, + num_heads: int = 16, + mlp_ratio: float = 48/11, + init_values: float = 1e-5, + qk_normalization: bool = True, + depth: int = 40, + use_flash_attn: bool = True, + use_fused_rmsnorm: bool = True, + use_fused_mlp: bool = True, + fused_mlp_heuristic: int = 1, + attn_pool_num_heads: int = 16, + clip_embed_dim: int = 768, + layerscale_no_force_fp32: bool = False, + num_frames: int = 8, + tubelet_size: int = 1, + sep_pos_embed: bool = False, + sep_image_video_pos_embed: bool = False, + use_checkpoint: bool = False, + checkpoint_num: int = 0, + # for unmasked teacher + clip_teacher_embed_dim: int = 3200, + clip_teacher_final_dim: int = 768, # if 0, not distill final features + clip_norm_type: str = 'l2', + clip_return_layer: int = 1, + clip_student_return_interval: int = 1, + ): + super().__init__() + + self.num_frames = num_frames + self.tubelet_size = tubelet_size + assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, 'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent' + + self.use_flash_attn = use_flash_attn + self.embed_dim = embed_dim + + self.depth = depth + self.clip_norm_type = clip_norm_type + self.return_index = [] + for i in range(clip_return_layer): + self.return_index.append(depth - int(i * clip_student_return_interval) - 1) + + if use_fused_rmsnorm: + norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True) + else: + norm_layer_for_blocks = partial(RMSNorm, eps=1e-6) + self.norm_layer_for_blocks = norm_layer_for_blocks + self.patch_embed = PatchEmbed( + img_size, patch_size, in_chans, embed_dim, + num_frames=num_frames, tubelet_size=tubelet_size, + ) + num_patches = self.patch_embed.num_patches + num_img_patches = self.patch_embed.num_img_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, 
embed_dim)) + + # stolen from https://github.com/facebookresearch/mae_st/blob/dc072aaaf640d06892e23a33b42223a994efe272/models_vit.py#L65-L73C17 + self.sep_pos_embed = sep_pos_embed + self.sep_image_video_pos_embed = sep_image_video_pos_embed + if sep_pos_embed: + raise NotImplementedError + else: + if sep_image_video_pos_embed: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.img_pos_embed = nn.Parameter(torch.zeros(1, num_img_patches + 1, embed_dim)) + # for CLIP decoder + self.clip_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.clip_img_pos_embed = nn.Parameter(torch.zeros(1, num_img_patches + 1, embed_dim)) + else: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.clip_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + # choose which layer to use checkpoint + with_cp_list = [False] * depth + if use_checkpoint: + for idx in range(depth): + if idx < checkpoint_num: + with_cp_list[idx] = True + + self.blocks = nn.ModuleList([ + Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias, + norm_layer=norm_layer_for_blocks, + drop_path=dpr[i], init_values=init_values, attn_drop=0., + use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp, + fused_mlp_heuristic=fused_mlp_heuristic, + with_cp=with_cp_list[i], + qk_normalization=qk_normalization, + layerscale_no_force_fp32=layerscale_no_force_fp32, + use_fused_rmsnorm=use_fused_rmsnorm) + for i in range(depth)]) + self.clip_projector = AttentionPoolingBlock( + dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None, + drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim) + + # CLIP decoder + self.clip_decoder = nn.ModuleList([ + Linear_Decoder( + in_channels=embed_dim, + out_channels=clip_teacher_embed_dim, + norm_layer=partial(nn.LayerNorm, eps=1e-5), + clip_norm_type=clip_norm_type + ) for _ in range(clip_return_layer) + ]) + self.final_clip_decoder = nn.Identity() + if clip_teacher_final_dim > 0: + self.final_clip_decoder = Linear_Decoder( + in_channels=clip_embed_dim, + out_channels=clip_teacher_final_dim, + norm_layer=partial(nn.LayerNorm, eps=1e-5), + clip_norm_type=clip_norm_type + ) + + self.init_pos_embed() + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + self.fix_init_weight() + + def init_pos_embed(self): + if self.sep_pos_embed: + raise NotImplementedError + else: + # trunc_normal_(self.pos_embed, std=.02) + # trunc_normal_(self.clip_pos_embed, std=.02) + pos_embed = get_3d_sincos_pos_embed( + self.pos_embed.shape[-1], + self.patch_embed.grid_size[1], # height & weight + self.patch_embed.grid_size[0], # t_size + cls_token=True + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + self.clip_pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + if self.sep_image_video_pos_embed: + img_pos_embed = get_3d_sincos_pos_embed( + self.pos_embed.shape[-1], + self.patch_embed.grid_size[1], # height & weight + 1, + cls_token=True + ) + self.img_pos_embed.data.copy_(torch.from_numpy(img_pos_embed).float().unsqueeze(0)) + self.clip_img_pos_embed.data.copy_(torch.from_numpy(img_pos_embed).float().unsqueeze(0)) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + 
nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + @property + def dtype(self): + return self.patch_embed.proj.weight.dtype + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return { + 'pos_embed', + 'pos_embed_spatial', + 'pos_embed_temporal', + 'pos_embed_cls', + 'img_pos_embed', + 'cls_token', + 'clip_pos_embed', + 'clip_pos_embed_spatial', + 'clip_pos_embed_temporal', + 'clip_pos_embed_cls', + 'clip_img_pos_embed' + } + + # @torch.cuda.amp.autocast(enabled=False) + def forward(self, x, mask=None, use_image=False, x_vis_return_idx=-1, x_vis_only=False): + x = self.patch_embed(x.type(self.dtype)) + # print(f"x.shape: {x.shape} x.dtype: {x.dtype}, model.dtype: {self.dtype}") + B, T, L, C = x.shape # T: temporal; L: spatial + x = x.view([B, T * L, C]) + + # append cls token + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # add pos_embed + if self.sep_pos_embed: + raise NotImplementedError + else: + if use_image: + if self.sep_image_video_pos_embed: + pos_embed = self.img_pos_embed + else: + # (1, num_img_patches + 1, embed_dim) + # print('origin pos_embed.shape:', self.pos_embed.shape) + cls_pos_embed = self.pos_embed[:, 0:1, :] + # print('cls_pos_embed.shape:', cls_pos_embed.shape) + + img_pos_embed = self.pos_embed[:, 1:, :].view(1, self.num_frames, self.patch_embed.num_patches // self.num_frames, self.embed_dim).mean(dim=1) + # print('img_pos_embed.shape:', img_pos_embed.shape) + + pos_embed = torch.cat([cls_pos_embed, img_pos_embed], dim=1) + # print('final img_pos_embed.shape:', pos_embed.shape) + else: + pos_embed = self.pos_embed + x = x + pos_embed + + # mask tokens, ~mask means visible + if mask is not None: + x = x[~mask].reshape(B, -1, C) + else: + x = x.reshape(B, -1, C) + + residual = None + x_clip = [] + for idx, blk in enumerate(self.blocks): + if isinstance(x, tuple) and len(x) == 2: + x, residual = x + # print(f"\033[31m这是{idx}, {x.shape}\033[0m") + x = blk(x, residual=residual) + # return intermediate features + if idx in self.return_index: + if isinstance(x, tuple) and len(x) == 2: + tmp_x, tmp_residual = x + if residual is not None: + x_clip.append(tmp_x + tmp_residual) + else: + x_clip.append(x) + if idx == (self.depth + x_vis_return_idx): + # print(f'idx = {idx} len(self.blocks)={len(self.blocks)}') + break + + if isinstance(x, tuple) and len(x) == 2: + x, residual = x + if residual is not None: + x = x + residual + + x_vis = x + if x_vis_only: + return x_vis + + x_pool_vis = self.clip_projector(x_vis) + x_align = self.final_clip_decoder(x_pool_vis) + + # align CLIP + x_clip = torch.stack(x_clip) + K, B, _, C_CLIP = x_clip.shape + # add pos_embed + if self.sep_pos_embed: + raise NotImplementedError + else: + if use_image: + if self.sep_image_video_pos_embed: + clip_pos_embed = self.clip_img_pos_embed + else: + # (1, num_img_patches + 1, embed_dim) + # print('origin pos_embed.shape:', self.pos_embed.shape) + clip_cls_pos_embed = self.clip_pos_embed[:, 0:1, :] + # print('cls_pos_embed.shape:', cls_pos_embed.shape) + + clip_img_pos_embed = self.clip_pos_embed[:, 1:, :].view(1, self.num_frames, self.patch_embed.num_patches // self.num_frames, self.embed_dim).mean(dim=1) + # 
print('img_pos_embed.shape:', img_pos_embed.shape) + + clip_pos_embed = torch.cat([clip_cls_pos_embed, clip_img_pos_embed], dim=1) + # print('final img_pos_embed.shape:', pos_embed.shape) + + else: + clip_pos_embed = self.clip_pos_embed + + clip_pos_embed = clip_pos_embed.repeat(B, 1, 1) + if mask is not None: + x_clip = x_clip + clip_pos_embed[~mask].view(B, -1, C_CLIP).unsqueeze(0).repeat(K, 1, 1, 1) + else: + x_clip = x_clip + clip_pos_embed.view(B, -1, C_CLIP).unsqueeze(0).repeat(K, 1, 1, 1) + + # CLIP decoder + x_clip_align = [] + for idx, clip_decoder in enumerate(self.clip_decoder): + x_clip_align.append(clip_decoder(x_clip[idx])) + x_clip_align = torch.stack(x_clip_align) + + return x_vis, x_pool_vis, x_clip_align, x_align + + +def pretrain_internvideo2_1b_patch14_224(config): + model = PretrainInternVideo2( + in_chans=3, img_size=224, patch_size=14, + embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11, + clip_embed_dim=config.vision_encoder.clip_embed_dim, + attn_pool_num_heads=16, qkv_bias=False, + drop_path_rate=0.25, + init_values=0.00001, + qk_normalization=True, + use_flash_attn=config.vision_encoder.use_flash_attn, + use_fused_rmsnorm=config.vision_encoder.use_fused_rmsnorm, + use_fused_mlp=config.vision_encoder.use_fused_mlp, + fused_mlp_heuristic=1, + layerscale_no_force_fp32=False, + num_frames=config.vision_encoder.num_frames, + tubelet_size=config.vision_encoder.tubelet_size, + sep_pos_embed=False, + sep_image_video_pos_embed=config.vision_encoder.sep_image_video_pos_embed, + use_checkpoint=config.vision_encoder.use_checkpoint, + checkpoint_num=config.vision_encoder.checkpoint_num, + clip_teacher_embed_dim=config.vision_encoder.clip_teacher_embed_dim, + clip_teacher_final_dim=config.vision_encoder.clip_teacher_final_dim, + clip_norm_type=config.vision_encoder.clip_norm_type, + clip_return_layer=config.vision_encoder.clip_return_layer, + clip_student_return_interval=config.vision_encoder.clip_student_return_interval, + ) + + if config.vision_encoder.pretrained is not None: + state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu') + interpolate_pos_embed_internvideo2(state_dict, model, orig_t_size=8) + message = model.load_state_dict(state_dict, strict=False) + + return model + + +def pretrain_internvideo2_6b_patch14_224(config): + model = PretrainInternVideo2( + in_chans=3, img_size=224, patch_size=14, + embed_dim=3200, depth=48, num_heads=25, mlp_ratio=4, + clip_embed_dim=config.vision_encoder.clip_embed_dim, + attn_pool_num_heads=16, qkv_bias=False, + drop_path_rate=0.3, + init_values=0.00001, + qk_normalization=True, + use_flash_attn=config.vision_encoder.use_flash_attn, + use_fused_rmsnorm=config.vision_encoder.use_fused_rmsnorm, + use_fused_mlp=config.vision_encoder.use_fused_mlp, + fused_mlp_heuristic=1, + layerscale_no_force_fp32=False, + num_frames=config.vision_encoder.num_frames, + tubelet_size=config.vision_encoder.tubelet_size, + sep_pos_embed=False, + sep_image_video_pos_embed=config.vision_encoder.sep_image_video_pos_embed, + use_checkpoint=config.vision_encoder.use_checkpoint, + checkpoint_num=config.vision_encoder.checkpoint_num, + clip_teacher_embed_dim=config.vision_encoder.clip_teacher_embed_dim, + clip_teacher_final_dim=config.vision_encoder.clip_teacher_final_dim, + clip_norm_type=config.vision_encoder.clip_norm_type, + clip_return_layer=config.vision_encoder.clip_return_layer, + clip_student_return_interval=config.vision_encoder.clip_student_return_interval, + ) + + if config.vision_encoder.pretrained is not None: + + 
state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu') + interpolate_pos_embed_internvideo2(state_dict, model, orig_t_size=8) + msg = model.load_state_dict(state_dict, strict=False) + + return model + + +from dataclasses import dataclass +from typing import Tuple, Optional, List +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import (PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer) +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + MaskedLMOutput, + ) +from torch import Tensor, device +from torch.nn import CrossEntropyLoss + + +class BertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to + instantiate a BERT model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the BERT + [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. 
For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + + Examples: + + ```python + >>> from transformers import BertModel, BertConfig + + >>> # Initializing a BERT bert-base-uncased style configuration + >>> configuration = BertConfig() + + >>> # Initializing a model from the bert-base-uncased style configuration + >>> model = BertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "bert" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type="absolute", + use_cache=True, + classifier_dropout=None, + cross_module="ca", + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + self.cross_module = cross_module + + +def load_tf_weights_in_bert(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + import numpy as np + import tensorflow as tf + except ImportError: + print( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + print("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + print("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n + in [ + "adam_v", + "adam_m", + "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", + "global_step", + ] + for n in name + ): + print("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + print("Skipping {}".format("/".join(name))) + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + print("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ] + + if token_type_ids is None: 
+ token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=self.position_ids.device + ) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + # added `attention_scores` to return tuple + outputs = ( + (context_layer, attention_probs, attention_scores) + if output_attentions + else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + + self.self = BertSelfAttention(config, is_cross_attention) + + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + # add attentions if we output them + outputs = (attention_output,) + self_outputs[1:] + return outputs # (context_layer, attention_probs, attention_scores, past_key_value,) + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = 
nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertLayer(nn.Module):
+    def __init__(self, config, layer_num):
+        super().__init__()
+        self.config = config
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = BertAttention(config)
+        # keep the layer index; forward() uses it to pick the matching encoder
+        # states when `encoder_hidden_states` is passed as a list
+        self.layer_num = layer_num
+
+        self.has_cross_attention = layer_num >= config.fusion_layer
+        if self.has_cross_attention:
+            self.crossattention = BertAttention(config, is_cross_attention=True)
+        self.intermediate = BertIntermediate(config)
+        self.output = BertOutput(config)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        self_attention_outputs = self.attention(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            output_attentions=output_attentions,
+            past_key_value=self_attn_past_key_value,
+        )  # (context_layer, attention_probs, attention_scores, past_key_value,)
+        attention_output = self_attention_outputs[0]
+
+        outputs = self_attention_outputs[1:-1]
+        present_key_value = self_attention_outputs[-1]
+
+        if self.has_cross_attention:
+            assert (
+                encoder_hidden_states is not None
+            ), "encoder_hidden_states must be given for cross-attention layers"
+
+            if type(encoder_hidden_states) == list:
+                cross_attention_outputs = self.crossattention(
+                    attention_output,
+                    attention_mask,
+                    head_mask,
+                    encoder_hidden_states[
+                        (self.layer_num - self.config.fusion_layer)
+                        % len(encoder_hidden_states)
+                    ],
+                    encoder_attention_mask[
+                        (self.layer_num - self.config.fusion_layer)
+                        % len(encoder_hidden_states)
+                    ],
+                    output_attentions=output_attentions,
+                )
+                attention_output = cross_attention_outputs[0]
+                outputs = outputs + cross_attention_outputs[1:-1]
+
+            else:
+                cross_attention_outputs = self.crossattention(
+                    attention_output,
+                    attention_mask,
+                    head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    output_attentions=output_attentions,
+                )  # (context_layer, attention_probs, attention_scores, past_key_value,)
+                attention_output = cross_attention_outputs[0]
+                # add cross attentions if we output attention weights
+                outputs = outputs + cross_attention_outputs[1:-1]
+        layer_output = apply_chunking_to_forward(
+            self.feed_forward_chunk,
+            self.chunk_size_feed_forward,
+            self.seq_len_dim,
+            attention_output,
+        )
+        outputs = (layer_output,) + outputs
+
+        outputs = outputs + (present_key_value,)
+
+        return outputs
+
+    def feed_forward_chunk(self, attention_output):
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
+
+
+class BertEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList(
+            [BertLayer(config, i) for i in range(config.num_hidden_layers)]
+        )
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        mode="multi_modal",
+        
normalize_attention=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + # all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_cross_attentions = () if output_attentions else None + + next_decoder_cache = () if use_cache else None + + if ( + mode == "text" or mode == "temporal" + ): # temporal is added and used for temporal att module. + start_layer = 0 + output_layer = self.config.fusion_layer + + elif mode == "fusion": + start_layer = self.config.fusion_layer + output_layer = self.config.num_hidden_layers + + elif mode == "multi_modal": + start_layer = 0 + output_layer = self.config.num_hidden_layers + + for i in range(start_layer, output_layer): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + print( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + use_reentrant=False, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) # (context_layer, attention_probs, attention_scores, past_key_value,) + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + # whether to output normalized attention, + # note for unnormalized attention, there is a mask added + offset = int(normalize_attention) + # all_self_attentions = all_self_attentions + (layer_outputs[1], ) + all_self_attentions = all_self_attentions + (layer_outputs[2 - offset],) + if hasattr(layer_module, "crossattention"): + # all_cross_attentions = all_cross_attentions + (layer_outputs[3], ) + all_cross_attentions = all_cross_attentions + (layer_outputs[4 - offset],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = BertConfig + load_tf_weights = load_tf_weights_in_bert + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode="multi_modal", + normalize_attention=True, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError( + "You have to specify either input_ids or inputs_embeds or encoder_embeds" + ) + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + normalize_attention=normalize_attention, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@dataclass +class MaskedLMOutputWithDistill(MaskedLMOutput): + loss_aux: Optional[torch.FloatTensor] = None + loss_distill: Optional[torch.FloatTensor] = None + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def tie_aux_decoder_weights(self, module, aux_modules): + """Tie decoder weights of all `aux_modules` to `module`, (not bias)""" + for m in aux_modules: + m.predictions.decoder.weight = module.predictions.decoder.weight + + def 
get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode="multi_modal", + normalize_attention=True, + soft_labels=None, + alpha=0, + return_logits=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_embeds=encoder_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + normalize_attention=normalize_attention, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + masked_lm_loss_aux = 0.0 + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if soft_labels is not None: + loss_distill = -torch.sum( + F.log_softmax(prediction_scores, dim=1) * soft_labels, dim=-1 + ) + loss_distill = loss_distill[labels != -100].mean() + masked_lm_loss = (1 - alpha) * masked_lm_loss + alpha * loss_distill + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + # changed from MaskedLMOutput to MaskedLMOutputWithDistill + return MaskedLMOutputWithDistill( + loss=masked_lm_loss, + loss_aux=masked_lm_loss_aux, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + assert ( + self.config.pad_token_id is not None + ), "The PAD token should be defined for generation" + attention_mask = torch.cat( + [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1 + ) + dummy_token = torch.full( + (effective_batch_size, 1), + self.config.pad_token_id, + dtype=torch.long, + device=input_ids.device, + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +def build_bert(model_config, pretrain, checkpoint, encoder_width=None): + """build text encoder. + + Args: + model_config (dict): model config. + pretrain (bool): Whether to do pretrain or finetuning. 
+ checkpoint (bool): whether to do gradient_checkpointing. + + Returns: TODO + + """ + bert_config = BertConfig.from_json_file(model_config.text_encoder.config) + if encoder_width is None: + bert_config.encoder_width = model_config.vision_encoder.d_model + else: + bert_config.encoder_width = encoder_width + + bert_config.gradient_checkpointing = checkpoint + bert_config.fusion_layer = model_config.text_encoder.fusion_layer + + if not model_config.multimodal.enable: + bert_config.fusion_layer = bert_config.num_hidden_layers + + if pretrain: + try: + text_encoder, loading_info = BertForMaskedLM.from_pretrained( + model_config.text_encoder.pretrained, + config=bert_config, + output_loading_info=True, + local_files_only=True + ) + except: + text_encoder, loading_info = BertForMaskedLM.from_pretrained( + model_config.text_encoder.pretrained, + config=bert_config, + output_loading_info=True, + local_files_only=False + ) + else: + try: + text_encoder, loading_info = BertModel.from_pretrained( + model_config.text_encoder.pretrained, + config=bert_config, + add_pooling_layer=False, + output_loading_info=True, + local_files_only=True + ) + except: + text_encoder, loading_info = BertModel.from_pretrained( + model_config.text_encoder.pretrained, + config=bert_config, + add_pooling_layer=False, + output_loading_info=True, + local_files_only=False + ) + + return text_encoder + + +def get_sim( + vision_proj: torch.Tensor, + text_proj: torch.Tensor, + temp=1.0, + agg_method="mean", +): + """calculate pair-wise video-text similarity. + + Args: + vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C]. + text_proj (torch.Tensor): The text representation. Shape: [B,C]. + temp (torch.Tensor): The temperature. Shape: []. + + Returns: The similarity between video and text. Shape: [B,B]. 
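+
+    Example (illustrative sketch; the tensors below are random placeholders and
+    ``temp`` is passed as a plain float, though a scalar tensor works the same)::
+
+        vision_proj = torch.randn(4, 8, 512)  # [B, T, C] per-frame features
+        text_proj = torch.randn(4, 512)       # [B, C] pooled text features
+        sim_v2t, sim_t2v = get_sim(vision_proj, text_proj, temp=0.07)
+        # sim_v2t.shape == sim_t2v.shape == torch.Size([4, 4])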
+ + """ + vision_proj = F.normalize(vision_proj, dim=-1) + text_proj = F.normalize(text_proj, dim=-1) + if vision_proj.ndim == 3: + sim_v2t = torch.einsum("mld,nd->mln", vision_proj, text_proj) / temp # [B, L, B] + sim_t2v = torch.einsum("nd,mld->nlm", text_proj, vision_proj) / temp # [B, L, B] + if agg_method == "mean": + sim_v2t = sim_v2t.mean(1) + sim_t2v = sim_t2v.mean(1) + elif agg_method == "max": + sim_v2t = sim_v2t.max(1)[0] + sim_t2v = sim_t2v.max(1)[0] + elif text_proj.ndim == 3: + sim_v2t = torch.einsum("nd,mld->nlm", vision_proj, text_proj) / temp # [B, L, B] + sim_t2v = torch.einsum("nld,md->nlm", text_proj, vision_proj) / temp # [B, L, B] + if agg_method == "mean": + sim_v2t = sim_v2t.mean(1) + sim_t2v = sim_t2v.mean(1) + elif agg_method == "max": + sim_v2t = sim_v2t.max(1)[0] + sim_t2v = sim_t2v.max(1)[0] + else: + sim_v2t = vision_proj @ text_proj.T / temp + sim_t2v = sim_v2t.T + + return sim_v2t, sim_t2v + + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt", + "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt", + "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt", + "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt", + "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt", + "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt", + "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt", + "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt", + "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt", + "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt", + "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", + "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", + "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt", + "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt", + "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt", + "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt", + "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt", + "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt", + } +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "bert-base-uncased": 512, + "bert-large-uncased": 512, + "bert-base-cased": 512, + "bert-large-cased": 512, + "bert-base-multilingual-uncased": 512, + "bert-base-multilingual-cased": 512, + "bert-base-chinese": 512, + "bert-base-german-cased": 512, + "bert-large-uncased-whole-word-masking": 512, + "bert-large-cased-whole-word-masking": 512, + 
"bert-large-uncased-whole-word-masking-finetuned-squad": 512, + "bert-large-cased-whole-word-masking-finetuned-squad": 512, + "bert-base-cased-finetuned-mrpc": 512, + "bert-base-german-dbmdz-cased": 512, + "bert-base-german-dbmdz-uncased": 512, + "TurkuNLP/bert-base-finnish-cased-v1": 512, + "TurkuNLP/bert-base-finnish-uncased-v1": 512, + "wietsedv/bert-base-dutch-cased": 512, +} + +PRETRAINED_INIT_CONFIGURATION = { + "bert-base-uncased": {"do_lower_case": True}, + "bert-large-uncased": {"do_lower_case": True}, + "bert-base-cased": {"do_lower_case": False}, + "bert-large-cased": {"do_lower_case": False}, + "bert-base-multilingual-uncased": {"do_lower_case": True}, + "bert-base-multilingual-cased": {"do_lower_case": False}, + "bert-base-chinese": {"do_lower_case": False}, + "bert-base-german-cased": {"do_lower_case": False}, + "bert-large-uncased-whole-word-masking": {"do_lower_case": True}, + "bert-large-cased-whole-word-masking": {"do_lower_case": False}, + "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True}, + "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False}, + "bert-base-cased-finetuned-mrpc": {"do_lower_case": False}, + "bert-base-german-dbmdz-cased": {"do_lower_case": False}, + "bert-base-german-dbmdz-uncased": {"do_lower_case": True}, + "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False}, + "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True}, + "wietsedv/bert-base-dutch-cased": {"do_lower_case": False}, +} + + +import collections +import unicodedata +from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class BasicTokenizer(object): + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + Args: + do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to lowercase the input when tokenizing. + never_split (:obj:`Iterable`, `optional`): + Collection of tokens which will never be split during tokenization. Only has an effect when + :obj:`do_basic_tokenize=True` + tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to tokenize Chinese characters. + This should likely be deactivated for Japanese (see this `issue + `__). + strip_accents: (:obj:`bool`, `optional`): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for :obj:`lowercase` (as in the original BERT). + """ + + def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see + WordPieceTokenizer. 
+ Args: + **never_split**: (`optional`) list of str + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + :func:`PreTrainedTokenizer.tokenize`) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union( + set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if never_split is not None and text in never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. 
+ if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`. + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +class BertTokenizer(PreTrainedTokenizer): + r""" + Construct a BERT tokenizer. Based on WordPiece. + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. + Users should refer to this superclass for more information regarding those methods. + Args: + vocab_file (:obj:`str`): + File containing the vocabulary. + do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to do basic tokenization before WordPiece. + never_split (:obj:`Iterable`, `optional`): + Collection of tokens which will never be split during tokenization. Only has an effect when + :obj:`do_basic_tokenize=True` + unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. 
+ pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to tokenize Chinese characters. + This should likely be deactivated for Japanese (see this `issue + `__). + strip_accents: (:obj:`bool`, `optional`): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for :obj:`lowercase` (as in the original BERT). + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs + ): + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " + "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + vocab_file) + ) + self.vocab = load_vocab(vocab_file) + + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + self.wordpiece_tokenizer = WordpieceTokenizer( + vocab=self.vocab, unk_token=self.unk_token) + + @property + def do_lower_case(self): + return self.basic_tokenizer.do_lower_case + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): + + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. 
""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. """ + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + - single sequence: ``[CLS] X `` + - pair of sequences: ``[CLS] A [SEP] B [SEP]`` + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + Returns: + :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` method. + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the token list is already formatted with special tokens for the model. + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence + pair mask has the following format: + :: + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + Returns: + :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given + sequence(s). 
+ """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"] + ) + else: + vocab_file = (filename_prefix + + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + print( + "Saving vocabulary to {}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!".format( + vocab_file) + ) + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) + + +from huggingface_hub import PyTorchModelHubMixin + + +def _frame_from_video(video): + while video.isOpened(): + success, frame = video.read() + if success: + yield frame + else: + break + +v_mean = np.array([0.485, 0.456, 0.406]).reshape(1,1,3) +v_std = np.array([0.229, 0.224, 0.225]).reshape(1,1,3) +def normalize(data): + return (data/255.0-v_mean)/v_std + + +def frames2tensor(vid_list, fnum=8, target_size=(224, 224), device=torch.device('cuda')): + assert(len(vid_list) >= fnum) + step = len(vid_list) // fnum + vid_list = vid_list[::step][:fnum] + vid_list = [cv2.resize(x[:,:,::-1], target_size) for x in vid_list] + vid_tube = [np.expand_dims(normalize(x), axis=(0, 1)) for x in vid_list] + vid_tube = np.concatenate(vid_tube, axis=1) + vid_tube = np.transpose(vid_tube, (0, 1, 4, 2, 3)) + vid_tube = torch.from_numpy(vid_tube).to(device, non_blocking=True).float() + return vid_tube + +def vid2tensor(path: str, fnum: int=8, target_size: tuple=(224, 224), device=torch.device('cuda')): + video = cv2.VideoCapture(path) + frames = [x for x in _frame_from_video(video)] + return frames2tensor(frames, fnum, target_size, device) + +def get_text_feat_dict(texts, clip, text_feat_d={}): + for t in texts: + feat = clip.get_txt_feat(t) + text_feat_d[t] = feat + return text_feat_d + +def get_vid_feat(frames, vlm): + return vlm.get_vid_features(frames) + + +def retrieve_text(frames, + texts, + model, + topk:int=5, + device=torch.device('cuda')): + + vlm = model.to(device) + config = vlm.config + + fn = config.num_frames + size_t = config.size_t + frames_tensor = frames2tensor(frames, fnum=fn, target_size=(size_t, size_t), device=device) + vid_feat = vlm.get_vid_feat(frames_tensor) + + text_feat_d = {} + text_feat_d = get_text_feat_dict(texts, vlm, text_feat_d) + text_feats = [text_feat_d[t] for t in texts] + text_feats_tensor = torch.cat(text_feats, 0) + + probs, idxs = vlm.predict_label(vid_feat, text_feats_tensor, top=topk) + + ret_texts = [texts[i] for i in idxs.long().numpy()[0].tolist()] + return ret_texts, probs.float().numpy()[0] + + +def setup_internvideo2(config): + + model = InternVideo2_Stage2(config=config, is_pretrain=True) + + torch.set_float32_matmul_precision('high') + model = torch.compile(model) + + model = model.to(torch.device(config.device)) + model_without_ddp = model + + if (config.pretrained_path.strip() and (os.path.isfile(config.pretrained_path)) or "s3://" in config.pretrained_path): + checkpoint = torch.load(config.pretrained_path, map_location="cpu") + try: + if "model" in checkpoint.keys(): 
+                state_dict = checkpoint["model"]
+            else:
+                state_dict = checkpoint["module"] # This is a deepspeed stage 1 model
+        except:
+            state_dict = checkpoint
+
+        # if config.get('origin_num_frames', None) is not None:
+        a = len(state_dict)
+        interpolate_pos_embed_internvideo2_new(state_dict, model_without_ddp.vision_encoder, orig_t_size=config.origin_num_frames)
+        assert a == len(state_dict), state_dict.keys()
+
+        msg = model_without_ddp.load_state_dict(state_dict, strict=False)
+
+    model_without_ddp = model_without_ddp.to(torch.float32)
+
+    return model_without_ddp.eval()
+
+
+class DictToClass:
+    def __init__(self, data):
+        for key, value in data.items():
+            key = str(key)
+            if isinstance(value, dict):
+                setattr(self, key, DictToClass(value))
+            elif isinstance(value, list):
+                setattr(self, key, [
+                    DictToClass(item) if isinstance(item, dict) else item
+                    for item in value
+                ])
+            else:
+                setattr(self, key, value)
+
+    def __repr__(self):
+        """Object representation convenient for debugging."""
+        attrs = ', '.join(f"{k}={v!r}" for k, v in self.__dict__.items())
+        return f"{self.__class__.__name__}({attrs})"
+
+
+def instance2dict(obj):
+    """Convert a class instance and its nested attributes into a dictionary."""
+    if isinstance(obj, (str, int, float, bool, type(None))):
+        # basic types are returned directly
+        return obj
+    elif isinstance(obj, dict):
+        # dicts: recursively convert the values
+        return {k: instance2dict(v) for k, v in obj.items()}
+    elif isinstance(obj, (list, tuple, set)):
+        # iterables: recursively convert each element
+        return type(obj)(instance2dict(item) for item in obj)
+    elif hasattr(obj, '__dict__'):
+        # class instances
+        result = {}
+        for key, value in obj.__dict__.items():
+            # filter out private attributes (optional)
+            if not key.startswith('_'):
+                result[key] = instance2dict(value)
+        return result
+    else:
+        # other non-serializable types are returned directly
+        return str(obj)  # or raise an exception, depending on the use case
+
+
+class InternVideo2_Stage2_Config(PretrainedConfig):
+    _auto_class='AutoConfig'
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+class InternVideo2_Stage2(
+    PreTrainedModel,
+    ):
+    """InternVideo2 Stage2: vision encoder + BERT text encoder with projection heads."""
+
+    _auto_class="AutoModel"
+    config_class=InternVideo2_Stage2_Config
+
+    def __init__(self,
+                 config: InternVideo2_Stage2_Config,
+                 # tokenizer,
+                 is_pretrain: bool=True):
+
+        super(InternVideo2_Stage2, self).__init__(config)
+
+        config = config.to_dict()
+        self._config = DictToClass(config) if isinstance(config, dict) else config
+
+        self.tokenizer = BertTokenizer.from_pretrained(self._config.model.text_encoder.pretrained, local_files_only=True, use_safetensors=True)
+
+        self.is_pretrain = is_pretrain
+        self.vision_width = self._config.model.vision_encoder.clip_embed_dim
+        self.text_width = self._config.model.text_encoder.d_model
+        self.embed_dim = self._config.model.embed_dim
+
+        # create modules.
+        self.vision_encoder = self.build_vision_encoder()
+        self.freeze_vision()
+
+        self.text_encoder = self.build_text_encoder()
+        self.freeze_text()
+
+        self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)
+        self.text_proj = nn.Linear(self.text_width, self.embed_dim)
+
+    def freeze_vision(self):
+        """freeze vision encoder"""
+        for p in self.vision_encoder.parameters():
+            p.requires_grad = False
+
+    def freeze_text(self):
+        """freeze text encoder"""
+        for p in self.text_encoder.parameters():
+            p.requires_grad = False
+
+    @property
+    def dtype(self):
+        return self.vision_encoder.patch_embed.proj.weight.dtype
+
+    def encode_vision(self,
+                      image: torch.Tensor,
+                      test: bool=False):
+        """encode image / videos as features.
+
+        Args:
+            image (torch.Tensor): The input images.
+            test (bool): Whether testing.
+
+        Returns: tuple.
+            - vision_embeds (torch.Tensor): The output features. Shape: [B,N,C].
+ - pooled_vision_embeds (torch.Tensor): The pooled output features. Shape: [B,1,C]. + - student_output (torch.Tensor): The features of alignment. Shape: [K,B,N,C]. + - clip_output (torch.Tensor): The features of clip. Shape: [K,B,N,C]. + + """ + + T = image.shape[1] + use_image = True if T == 1 else False + image = image.permute(0, 2, 1, 3, 4).to(self.dtype) # [B,T,C,H,W] -> [B,C,T,H,W] + # whether save temporal dimension + # keep_temporal=self._config.model.vision_encoder.keep_temporal + if test: + vision_embeds, pooled_vision_embeds, _, _ = self.vision_encoder( + image, None, use_image) + return vision_embeds, pooled_vision_embeds + else: + mask, targets_clip_middle_vis, targets_clip_final_vis = self.encode_teacher(image) + # if mask is not None and (self.video_mask_type != 'tube' or self.image_mask_type != 'tube'): + # keep_temporal = False + # print(f"\033[31mmask is {type(mask)}\033[0m") + vision_embeds, pooled_vision_embeds, student_output, student_output_final = self.vision_encoder( + image, mask, use_image) + return vision_embeds, pooled_vision_embeds, student_output, student_output_final, targets_clip_middle_vis, targets_clip_final_vis + + def encode_text(self, + text: dict): + """encode text. + Args: + text (dict): The output of huggingface's `PreTrainedTokenizer`. contains keys: + - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L]. + - attention_mask (torch.Tensor): The mask indicate padded tokens. Shape: [B,L]. 0 is padded token. + - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__". + Returns: tuple. + - text_embeds (torch.Tensor): The features of all tokens. Shape: [B,L,C]. + - pooled_text_embeds (torch.Tensor): The pooled features. Shape: [B,C]. + + """ + text_output = self.get_text_encoder()( + text.input_ids, + attention_mask=text.attention_mask, + return_dict=True, + mode="text", + ) + text_embeds = text_output.last_hidden_state + pooled_text_embeds = text_embeds[:, 0] + return text_embeds, pooled_text_embeds + + def build_vision_encoder(self): + """build vision encoder + Returns: (vision_encoder, clip_teacher). Each is a `nn.Module`. 
+ + """ + encoder_name = self._config.model.vision_encoder.name + + if encoder_name == 'pretrain_internvideo2_1b_patch14_224': + vision_encoder = pretrain_internvideo2_1b_patch14_224(self._config.model) + elif encoder_name == 'pretrain_internvideo2_6b_patch14_224': + vision_encoder = pretrain_internvideo2_6b_patch14_224(self._config.model) + else: + raise ValueError(f"Not implemented: {encoder_name}") + + # parameters for mask + img_size = self._config.model.vision_encoder.img_size + num_frames = self._config.model.vision_encoder.num_frames + tublet_size = self._config.model.vision_encoder.tubelet_size + patch_size = self._config.model.vision_encoder.patch_size + self.clip_img_size = self._config.model.vision_encoder.clip_input_resolution + self.video_mask_type = self._config.model.vision_encoder.video_mask_type + self.video_window_size = (num_frames // tublet_size, img_size // patch_size, img_size // patch_size) + self.video_mask_ratio = self._config.model.vision_encoder.video_mask_ratio + self.image_mask_type = self._config.model.vision_encoder.image_mask_type + self.image_window_size = (1, img_size // patch_size, img_size // patch_size) + self.image_mask_ratio = self._config.model.vision_encoder.image_mask_ratio + + return vision_encoder + + def build_text_encoder(self): + """build text_encoder and possiblly video-to-text multimodal fusion encoder. + Returns: nn.Module. The text encoder + + """ + encoder_name = self._config.model.text_encoder.name + + if "bert" in encoder_name: + text_encoder = build_bert( + self._config.model, + self.is_pretrain, + self._config.gradient_checkpointing, + ) + else: + raise ValueError(f"Not implemented: {encoder_name}") + + return text_encoder + + def get_text_encoder(self): + """get text encoder, used for text and cross-modal encoding""" + encoder = self.text_encoder + return encoder.bert if hasattr(encoder, "bert") else encoder + + def get_vid_feat(self, + frames: torch.Tensor): + """get the video features for the given frames. + + Args: + frames (torch.Tensor): The input frames. Shape: [B,T,C,H,W]. + + Returns: tuple. + - vision_embeds (torch.Tensor): The output features. Shape: [B,N,C]. + - pooled_vision_embeds (torch.Tensor): The pooled output features. Shape: [B,1,C]. + + """ + with torch.no_grad(): + _, vfeat = self.encode_vision(frames, test=True) + vfeat = self.vision_proj(vfeat) + vfeat /= vfeat.norm(dim=-1, keepdim=True) + return vfeat + + def get_txt_feat(self, + text: str): + """get the text features for the given text.""" + with torch.no_grad(): + text = self.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=self._config.max_txt_l, + return_tensors="pt",).to(self._config.device) + _, tfeat = self.encode_text(text) + tfeat = self.text_proj(tfeat) + tfeat /= tfeat.norm(dim=-1, keepdim=True) + return tfeat + + def predict_label(self, + vid_feat: torch.Tensor, + txt_feat: torch.Tensor, + top: int=5): + label_probs = (100.0 * vid_feat @ txt_feat.T).softmax(dim=-1) + top_probs, top_labels = label_probs.float().cpu().topk(top, dim=-1) + return top_probs, top_labels
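+
+
+# ---------------------------------------------------------------------------
+# Illustrative usage sketch, not part of the model definition. The repo id and
+# the video path below are placeholders, and loading this file through
+# AutoConfig/AutoModel assumes the hosting repo registers it for
+# `trust_remote_code=True`; adapt both to your own setup.
+if __name__ == "__main__":
+    from transformers import AutoConfig, AutoModel
+
+    repo_id = "OpenGVLab/InternVideo2-Stage2_1B-224p-f4"  # placeholder repo id
+    cfg = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
+    model = AutoModel.from_pretrained(
+        repo_id, config=cfg, trust_remote_code=True
+    ).eval()
+
+    # Decode a local video into a list of BGR frames and run text retrieval.
+    video = cv2.VideoCapture("example.mp4")  # placeholder path
+    frames = [f for f in _frame_from_video(video)]
+    texts = ["a person riding a bike", "a dog playing with a ball"]
+
+    ret_texts, probs = retrieve_text(frames, texts, model=model, topk=2)
+    for t, p in zip(ret_texts, probs):
+        print(f"{t}: {p:.4f}")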