from safetensors import safe_open

import torch
import torch.nn as nn
import numpy as np

from timm.models.layers import to_2tuple
from timm.models.vision_transformer import Block


def _convTranspose2dOutput(
    input_size: int,
    stride: int,
    padding: int,
    dilation: int,
    kernel_size: int,
    output_padding: int,
):
    """
    Calculate the output size of a ConvTranspose2d.
    Taken from: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
    """
    return (
        (input_size - 1) * stride
        - 2 * padding
        + dilation * (kernel_size - 1)
        + output_padding
        + 1
    )
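

# Illustrative sketch (not part of the original module): with the neck settings used
# below (kernel_size=2, stride=2, padding=0, dilation=1, output_padding=0) each
# ConvTranspose2d doubles the spatial size, so a 14-patch side grows
# 14 -> 28 -> 56 -> 112 -> 224 over four stages.
def _example_conv_transpose_output():
    size = 14
    for _ in range(4):
        size = _convTranspose2dOutput(size, 2, 0, 1, 2, 0)
    return size  # 224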


def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
    """
    embed_dim: output dimension for each position
    pos: an array of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
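

# Minimal usage sketch: encode 10 positions into 8-dimensional sin/cos features;
# the first 4 columns are sines, the last 4 cosines.
def _example_1d_pos_embed():
    emb = get_1d_sincos_pos_embed_from_grid(8, np.arange(10))
    return emb.shape  # (10, 8)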


def get_3d_sincos_pos_embed(embed_dim: int, grid_size: tuple, cls_token: bool = False):
    """
    embed_dim: output dimension for each position (must be divisible by 16)
    grid_size: 3d tuple of grid size: t, h, w
    cls_token: whether to prepend an all-zero embedding for a cls token
    return:
    pos_embed: L, D (L = t * h * w, plus 1 if cls_token)
    """

    assert embed_dim % 16 == 0

    t_size, h_size, w_size = grid_size

    # Split embed_dim across the three axes: 6/16 for w, 6/16 for h, 4/16 for t.
    w_embed_dim = embed_dim // 16 * 6
    h_embed_dim = embed_dim // 16 * 6
    t_embed_dim = embed_dim // 16 * 4

    w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
    h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
    t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))

    w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
    h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
    t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)

    pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)

    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
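

# Hedged example (assumes a ViT-L sized embedding): a 224x224 input with 16x16
# patches and 4 frames gives a (4, 14, 14) token grid, so the embedding has
# 1 + 4*14*14 = 785 rows when a cls token is prepended.
def _example_3d_pos_embed():
    pos_embed = get_3d_sincos_pos_embed(1024, (4, 14, 14), cls_token=True)
    return pos_embed.shape  # (785, 1024)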


class Norm2d(nn.Module):
    """LayerNorm applied over the channel dimension of an NCHW feature map."""

    def __init__(self, embed_dim: int):
        super().__init__()
        self.ln = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = self.ln(x)
        x = x.permute(0, 3, 1, 2).contiguous()
        return x
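

# Minimal sketch: Norm2d normalizes across channels of an NCHW feature map by
# permuting to NHWC and back, so the output shape matches the input.
def _example_norm2d():
    feat = torch.randn(2, 256, 14, 14)
    out = Norm2d(256)(feat)
    return out.shape  # torch.Size([2, 256, 14, 14])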


class PatchEmbed(nn.Module):
    """Frames of 2D Images to Patch Embedding

    The 3D version of timm.models.vision_transformer.PatchEmbed
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        num_frames: int = 3,
        tubelet_size: int = 1,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: nn.Module = None,
        flatten: bool = True,
        bias: bool = True,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_frames = num_frames
        self.tubelet_size = tubelet_size
        self.grid_size = (
            num_frames // tubelet_size,
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[1],
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        self.flatten = flatten

        self.proj = nn.Conv3d(
            in_chans,
            embed_dim,
            kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
            stride=(tubelet_size, patch_size[0], patch_size[1]),
            bias=bias,
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, T, H, W = x.shape
        assert (
            H == self.img_size[0]
        ), f"Input image height ({H}) doesn't match model ({self.img_size[0]})."
        assert (
            W == self.img_size[1]
        ), f"Input image width ({W}) doesn't match model ({self.img_size[1]})."
        x = self.proj(x)
        Hp, Wp = x.shape[3], x.shape[4]
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, Hp, Wp
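

# Hedged sketch with an illustrative 6-channel input: a 224x224, 3-frame input with
# 16x16 patches and tubelet_size=1 yields a (3, 14, 14) grid, i.e. 588 tokens of
# width embed_dim, plus the per-frame patch-grid height and width.
def _example_patch_embed():
    pe = PatchEmbed(img_size=224, patch_size=16, num_frames=3, tubelet_size=1,
                    in_chans=6, embed_dim=768)
    x = torch.randn(2, 6, 3, 224, 224)  # (B, C, T, H, W)
    tokens, Hp, Wp = pe(x)
    return tokens.shape, Hp, Wp  # torch.Size([2, 588, 768]), 14, 14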


class ConvTransformerTokensToEmbeddingNeck(nn.Module):
    """
    Neck that transforms the token-based output of a transformer into a single embedding suitable for processing with standard layers.
    Performs 4 ConvTranspose2d operations on the rearranged input with kernel_size=2 and stride=2.
    """

    def __init__(
        self,
        embed_dim: int,
        output_embed_dim: int,
        Hp: int = 14,
        Wp: int = 14,
        drop_cls_token: bool = True,
    ):
        """
        Args:
            embed_dim (int): Input embedding dimension
            output_embed_dim (int): Output embedding dimension
            Hp (int, optional): Height (in patches) of embedding to be upscaled. Defaults to 14.
            Wp (int, optional): Width (in patches) of embedding to be upscaled. Defaults to 14.
            drop_cls_token (bool, optional): Whether there is a cls_token, which should be dropped. This assumes the cls token is the first token. Defaults to True.
        """
        super().__init__()
        self.drop_cls_token = drop_cls_token
        self.Hp = Hp
        self.Wp = Wp
        self.H_out = Hp
        self.W_out = Wp

        kernel_size = 2
        stride = 2
        dilation = 1
        padding = 0
        output_padding = 0
        for _ in range(4):
            self.H_out = _convTranspose2dOutput(
                self.H_out, stride, padding, dilation, kernel_size, output_padding
            )
            self.W_out = _convTranspose2dOutput(
                self.W_out, stride, padding, dilation, kernel_size, output_padding
            )

        self.embed_dim = embed_dim
        self.output_embed_dim = output_embed_dim
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(
                self.embed_dim,
                self.output_embed_dim,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding,
                output_padding=output_padding,
            ),
            Norm2d(self.output_embed_dim),
            nn.GELU(),
            nn.ConvTranspose2d(
                self.output_embed_dim,
                self.output_embed_dim,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding,
                output_padding=output_padding,
            ),
        )
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(
                self.output_embed_dim,
                self.output_embed_dim,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding,
                output_padding=output_padding,
            ),
            Norm2d(self.output_embed_dim),
            nn.GELU(),
            nn.ConvTranspose2d(
                self.output_embed_dim,
                self.output_embed_dim,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding,
                output_padding=output_padding,
            ),
        )

    def forward(self, x):
        x = x[0]
        if self.drop_cls_token:
            x = x[:, 1:, :]
        x = x.permute(0, 2, 1).reshape(x.shape[0], -1, self.Hp, self.Wp)

        x = self.fpn1(x)
        x = self.fpn2(x)

        x = x.reshape((-1, self.output_embed_dim, self.H_out, self.W_out))
        out = tuple([x])
        return out
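

# Hedged usage sketch: with Hp = Wp = 14, the four stride-2 transposed convolutions
# upsample 14x14 token maps back to 224x224 feature maps. The neck expects (and
# returns) a tuple, matching the encoder output below.
def _example_neck():
    neck = ConvTransformerTokensToEmbeddingNeck(embed_dim=768, output_embed_dim=256,
                                                Hp=14, Wp=14, drop_cls_token=True)
    tokens = torch.randn(2, 1 + 14 * 14, 768)  # cls token + 196 patch tokens
    (feat,) = neck((tokens,))
    return feat.shape  # torch.Size([2, 256, 224, 224])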


class ConvTransformerTokensToEmbeddingBottleneckNeck(nn.Module):
    """
    Neck that transforms the token-based output of a transformer into a single embedding suitable for processing with standard layers.
    Performs ConvTranspose2d operations with bottleneck layers to reduce channels.
    """

    def __init__(
        self,
        embed_dim: int,
        output_embed_dim: int,
        Hp: int = 14,
        Wp: int = 14,
        drop_cls_token: bool = True,
        bottleneck_reduction_factor: int = 4,
    ):
        """
        Args:
            embed_dim (int): Input embedding dimension
            output_embed_dim (int): Output embedding dimension
            Hp (int, optional): Height (in patches) of embedding to be upscaled. Defaults to 14.
            Wp (int, optional): Width (in patches) of embedding to be upscaled. Defaults to 14.
            drop_cls_token (bool, optional): Whether there is a cls_token, which should be dropped. Defaults to True.
            bottleneck_reduction_factor (int, optional): Factor by which channels are reduced in the bottleneck layers. Defaults to 4.
        """
        super().__init__()
        self.drop_cls_token = drop_cls_token
        self.Hp = Hp
        self.Wp = Wp
        self.H_out = Hp
        self.W_out = Wp

        kernel_size = 2
        stride = 2
        dilation = 1
        padding = 0
        output_padding = 0
        for _ in range(4):
            self.H_out = _convTranspose2dOutput(
                self.H_out, stride, padding, dilation, kernel_size, output_padding
            )
            self.W_out = _convTranspose2dOutput(
                self.W_out, stride, padding, dilation, kernel_size, output_padding
            )

        self.embed_dim = embed_dim
        self.output_embed_dim = output_embed_dim
        bottleneck_dim = self.embed_dim // bottleneck_reduction_factor

        self.fpn1 = nn.Sequential(
            nn.Conv2d(self.embed_dim, bottleneck_dim, kernel_size=1),
            Norm2d(bottleneck_dim),
            nn.GELU(),
            nn.ConvTranspose2d(
                bottleneck_dim,
                bottleneck_dim,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ),
            Norm2d(bottleneck_dim),
            nn.GELU(),
            nn.ConvTranspose2d(
                bottleneck_dim,
                bottleneck_dim,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ),
            Norm2d(bottleneck_dim),
            nn.GELU(),
            nn.Conv2d(bottleneck_dim, self.output_embed_dim, kernel_size=1),
            Norm2d(self.output_embed_dim),
            nn.GELU(),
        )

        self.fpn2 = nn.Sequential(
            nn.Conv2d(self.output_embed_dim, bottleneck_dim, kernel_size=1),
            Norm2d(bottleneck_dim),
            nn.GELU(),
            nn.ConvTranspose2d(
                bottleneck_dim,
                bottleneck_dim,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ),
            Norm2d(bottleneck_dim),
            nn.GELU(),
            nn.ConvTranspose2d(
                bottleneck_dim,
                bottleneck_dim,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ),
            Norm2d(bottleneck_dim),
            nn.GELU(),
            nn.Conv2d(bottleneck_dim, self.output_embed_dim, kernel_size=1),
            Norm2d(self.output_embed_dim),
            nn.GELU(),
        )

    def forward(self, x):
        x = x[0]
        if self.drop_cls_token:
            x = x[:, 1:, :]
        x = x.permute(0, 2, 1).reshape(x.shape[0], -1, self.Hp, self.Wp)

        x = self.fpn1(x)
        x = self.fpn2(x)

        x = x.reshape((-1, self.output_embed_dim, self.H_out, self.W_out))
        out = tuple([x])
        return out
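

# Hedged sketch: same 14x14 -> 224x224 upsampling as the plain neck above, but the
# transposed convolutions run at embed_dim // 4 channels, with 1x1 convolutions
# entering and leaving each FPN stage.
def _example_bottleneck_neck():
    neck = ConvTransformerTokensToEmbeddingBottleneckNeck(embed_dim=768, output_embed_dim=256)
    tokens = torch.randn(2, 1 + 14 * 14, 768)
    (feat,) = neck((tokens,))
    return feat.shape  # torch.Size([2, 256, 224, 224])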


class TemporalViTEncoder(nn.Module):
    """Encoder from a ViT with the capability to take in temporal input.

    This class defines an encoder taken from a ViT architecture.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        num_frames: int = 1,
        tubelet_size: int = 1,
        in_chans: int = 3,
        embed_dim: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = nn.LayerNorm,
        norm_pix_loss: bool = False,
        pretrained: str = None,
        debug: bool = False,
    ):
        """
        Args:
            img_size (int, optional): Input image size. Defaults to 224.
            patch_size (int, optional): Patch size to be used by the transformer. Defaults to 16.
            num_frames (int, optional): Number of frames (temporal dimension) to be input to the encoder. Defaults to 1.
            tubelet_size (int, optional): Tubelet size used in patch embedding. Defaults to 1.
            in_chans (int, optional): Number of input channels. Defaults to 3.
            embed_dim (int, optional): Embedding dimension. Defaults to 1024.
            depth (int, optional): Encoder depth. Defaults to 24.
            num_heads (int, optional): Number of heads used in the encoder blocks. Defaults to 16.
            mlp_ratio (float, optional): Ratio to be used for the size of the MLP in encoder blocks. Defaults to 4.0.
            norm_layer (nn.Module, optional): Norm layer to be used. Defaults to nn.LayerNorm.
            norm_pix_loss (bool, optional): Whether to use Norm Pix Loss. Defaults to False.
            pretrained (str, optional): Path to pretrained encoder weights. Defaults to None.
            debug (bool, optional): Whether to print tensor shapes during the forward pass. Defaults to False.
        """
        super().__init__()

        self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(
            img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim
        )
        num_patches = self.patch_embed.num_patches
        self.num_frames = num_frames

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False
        )

        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for _ in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        self.norm_pix_loss = norm_pix_loss
        self.pretrained = pretrained
        self.debug = debug

        self.initialize_weights()

    def initialize_weights(self):
        # Initialize (and freeze) pos_embed with a fixed 3D sin-cos embedding.
        pos_embed = get_3d_sincos_pos_embed(
            self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize the patch_embed projection like an nn.Linear.
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Load pretrained weights when a checkpoint path is provided; otherwise
        # initialize the remaining Linear / LayerNorm layers from scratch.
        if self.pretrained:
            if self.pretrained.endswith(".safetensors"):
                self._load_safetensors_weights()
            else:
                self._load_pt_weights()
        else:
            self.apply(self._init_weights)

    def _load_safetensors_weights(self):
        with safe_open(self.pretrained, framework='pt', device='cpu') as f:
            checkpoint_state_dict = {k: f.get_tensor(k) for k in f.keys()}
        missing_keys, unexpected_keys = self.load_state_dict(checkpoint_state_dict, strict=False)
        if missing_keys:
            print("TemporalViTEncoder | Warning: Missing keys in the state dict:", missing_keys)
        if unexpected_keys:
            print("TemporalViTEncoder | Warning: Unexpected keys in the state dict:", unexpected_keys)
        print(f"TemporalViTEncoder | Loaded pretrained weights from '{self.pretrained}' (safetensors).")

    def _load_pt_weights(self):
        checkpoint = torch.load(self.pretrained, map_location='cpu')
        checkpoint_state_dict = checkpoint.get('state_dict', checkpoint)
        missing_keys, unexpected_keys = self.load_state_dict(checkpoint_state_dict, strict=False)
        if missing_keys:
            print("TemporalViTEncoder | Warning: Missing keys in the state dict:", missing_keys)
        if unexpected_keys:
            print("TemporalViTEncoder | Warning: Unexpected keys in the state dict:", unexpected_keys)
        print(f"TemporalViTEncoder | Loaded pretrained weights from '{self.pretrained}' (pt file).")

    def _init_weights(self, m):
        print("TemporalViTEncoder | Newly Initializing weights...")
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        if self.debug:
            print('TemporalViTEncoder IN:', x.shape)

        # Embed patches.
        x, _, _ = self.patch_embed(x)

        if self.debug:
            print('TemporalViTEncoder EMBED:', x.shape)

        # Add the positional embedding (skipping the cls token's entry).
        x = x + self.pos_embed[:, 1:, :]

        # Prepend the cls token with its positional embedding.
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Apply the Transformer blocks and the final norm.
        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        if self.debug:
            print('TemporalViTEncoder OUT:', x.shape)

        return tuple([x])
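

# Hedged end-to-end sketch (hypothetical ViT-B sized configuration, illustrative
# 6-channel input): encode a single-frame 224x224 image and decode the tokens back
# to a dense 224x224 feature map with the ConvTranspose2d neck.
def _example_encoder_with_neck():
    encoder = TemporalViTEncoder(
        img_size=224, patch_size=16, num_frames=1, tubelet_size=1,
        in_chans=6, embed_dim=768, depth=12, num_heads=12,
    )
    neck = ConvTransformerTokensToEmbeddingNeck(embed_dim=768, output_embed_dim=256)
    x = torch.randn(1, 6, 1, 224, 224)  # (B, C, T, H, W)
    tokens = encoder(x)                 # tuple holding a (1, 1 + 196, 768) tensor
    (feat,) = neck(tokens)
    return feat.shape                   # torch.Size([1, 256, 224, 224])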