Spaces:
Sleeping
Sleeping
import math | |
import re | |
from collections import OrderedDict | |
from functools import partial | |
from typing import Dict, Iterable, List, Optional, Tuple, Union | |
import torch | |
import torch.nn.functional as F | |
from torchvision.ops import MLP | |
from einops import rearrange, repeat | |
from torch import Tensor, nn | |
from definition import PRETRAINED_BACKBONE | |
from ..configs.base_config import base_cfg | |
from ..utils import count_parameters | |
from .components import ( | |
build_2d_sincos_posemb, | |
drop_path, | |
pair, | |
trunc_normal_, | |
) | |
class PatchedInputAdapter(nn.Module):
    """Adapter for spatial inputs, like images or feature maps.

    Creates tokens from non-overlapping patches over the image and adds
    positional embeddings to them.

    :param num_channels: Number of input channels of the image/feature map
    :param stride_level: Stride level compared to the full-sized image.
        E.g. 4 for 1/4th the size of the image.
    :param patch_size_full: Int or (height, width) tuple of the patch size over
        the full image size. Patch size for smaller inputs is computed accordingly.
    :param dim_tokens: Dimension of output tokens. Can be set later using the
        ``init`` method.
    :param sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings
    :param learnable_pos_emb: Set to True to learn positional embeddings instead
    :param image_size: Default image size (int or (height, width) tuple). Used to
        initialize the size of the positional embeddings.
    """

    def __init__(
        self,
        num_channels: int,
        stride_level: int,
        patch_size_full: Union[int, Tuple[int, int]],
        dim_tokens: Optional[int] = None,
        sincos_pos_emb: bool = True,
        learnable_pos_emb: bool = False,
        image_size: Union[int, Tuple[int, int]] = 224,
    ):
        super().__init__()
        self.num_channels = num_channels
        self.stride_level = stride_level
        self.patch_size_full = pair(patch_size_full)
        self.dim_tokens = dim_tokens
        self.sincos_pos_emb = sincos_pos_emb
        self.learnable_pos_emb = learnable_pos_emb
        self.image_size = pair(image_size)
        # Use the normalized (height, width) pair rather than the raw argument:
        # ``int // tuple`` raises TypeError when a tuple patch size is passed.
        self.num_patches = (self.image_size[0] // self.patch_size_full[0]) * (
            self.image_size[1] // self.patch_size_full[1]
        )

        # Actual patch height and width, taking into account stride of input
        self.P_H = max(1, self.patch_size_full[0] // stride_level)
        self.P_W = max(1, self.patch_size_full[1] // stride_level)

        if self.dim_tokens is not None:
            self.init(dim_tokens=dim_tokens)

    def init(self, dim_tokens: int = 768):
        """Initialize the parts of the adapter that depend on the token dimension.

        Creates the positional embedding and the patch projection. Should be
        called when setting up MultiMAE; calling it again re-creates both.

        :param dim_tokens: Dimension of tokens
        """
        self.dim_tokens = dim_tokens

        # Positional-embedding grid for the default image size; forward()
        # interpolates it to the actual number of patches.
        h_posemb = self.image_size[0] // (self.stride_level * self.P_H)
        w_posemb = self.image_size[1] // (self.stride_level * self.P_W)
        if self.sincos_pos_emb:
            # Wrap directly: assigning a plain tensor to ``self.pos_emb`` when it
            # is already a registered Parameter (i.e. on re-init) is rejected
            # by nn.Module.__setattr__.
            self.pos_emb = nn.Parameter(
                build_2d_sincos_posemb(
                    h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens
                ),
                requires_grad=self.learnable_pos_emb,
            )
        else:
            self.pos_emb = nn.Parameter(
                torch.zeros(1, self.dim_tokens, h_posemb, w_posemb)
            )
            trunc_normal_(self.pos_emb, std=0.02)

        # Image -> tokens projection: one conv step per (P_H, P_W) patch
        self.proj = nn.Conv2d(
            in_channels=self.num_channels,
            out_channels=self.dim_tokens,
            kernel_size=(self.P_H, self.P_W),
            stride=(self.P_H, self.P_W),
        )

    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {"pos_emb"}

    def forward(self, x: Tensor) -> Tensor:
        """Transform an image/feature map into a sequence of tokens.

        Projects non-overlapping patches to ``dim_tokens`` channels and adds the
        positional embedding, bicubically resized to the patch grid.

        :param x: Input tensor of shape (B, C, H, W)
        :return: Tokens of shape (B, N_H * N_W, dim_tokens)
        """
        _, _, H, W = x.shape
        assert (
            self.dim_tokens is not None
        ), "Need to call init(dim_tokens) function first"
        assert (H % self.P_H == 0) and (
            W % self.P_W == 0
        ), f"Image sizes {H}x{W} must be divisible by patch sizes {self.P_H}x{self.P_W}"
        N_H, N_W = H // self.P_H, W // self.P_W  # Number of patches in height and width

        # Create patches [B, C, H, W] -> [B, (N_H*N_W), D]
        x_patch = rearrange(self.proj(x), "b d nh nw -> b (nh nw) d")

        # Resize the positional embedding to the current patch grid
        x_pos_emb = F.interpolate(
            self.pos_emb, size=(N_H, N_W), mode="bicubic", align_corners=False
        )
        x_pos_emb = rearrange(x_pos_emb, "b d nh nw -> b (nh nw) d")

        # Add patches and positional embeddings
        return x_patch + x_pos_emb
class DropPath(nn.Module):
    """Stochastic depth: randomly drop whole residual paths, per sample.

    Intended to wrap the main (non-shortcut) branch of a residual block. The
    actual dropping is delegated to the ``drop_path`` helper, which is given
    the module's current ``training`` flag.

    :param drop_prob: Probability forwarded to ``drop_path``
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: Tensor) -> Tensor:
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
class ConvNeXtBlock(nn.Module):
    r"""ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch.

    Args:
        dim (int): Number of input channels.
        drop_path: Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 0 (disabled for isotropic ConvNeXt).

    Code from: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
    """

    def __init__(self, dim, drop_path=0.0, layer_scale_init_value=0.0):
        super().__init__()
        # 7x7 depthwise conv (one filter per channel); padding 3 keeps H and W.
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convs expressed as Linear layers over the channel axis.
        expanded = 4 * dim
        self.pwconv1 = nn.Linear(dim, expanded)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(expanded, dim)
        # Optional per-channel layer scale; disabled when the init value is <= 0.
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma = None
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        shortcut = x
        # (N, C, H, W) -> (N, H, W, C) so LayerNorm/Linear act on channels
        y = self.dwconv(x).permute(0, 2, 3, 1)
        y = self.pwconv2(self.act(self.pwconv1(self.norm(y))))
        if self.gamma is not None:
            y = self.gamma * y
        # (N, H, W, C) -> (N, C, H, W)
        y = y.permute(0, 3, 1, 2)
        return shortcut + self.drop_path(y)
class ConvNeXtAdapter(nn.Module):
    """Output adapter with ConvNeXt blocks for dense prediction (e.g. segmentation).

    Encoder tokens of the ``main_tasks`` are concatenated along the channel
    dimension, projected to ``embed_dim``, reshaped into a feature map, refined
    by ``depth`` ConvNeXt blocks and decoded by a small conv head that
    upsamples the map to ``image_size``.

    :param image_size: Output spatial size given to the final ``nn.Upsample``
    :param num_classes: Number of output channels (classes)
    :param embed_dim: Token dimension after projection, and before reshaping operation.
    :param preds_per_patch: Increases size of feature map by reshaping each patch.
        Each patch gets reshaped from embed_dim x 1 x 1 to
        (embed_dim / preds_per_patch) x (preds_per_patch ** 0.5) x (preds_per_patch ** 0.5),
        so it must be a perfect square.
    :param main_tasks: Tasks to use for the adapter. Only tokens coming from these tasks are kept.
    :param patch_size: Size of patches
    :param depth: Number of ConvNeXt blocks
    :param interpolate_mode: Interpolation mode of the upsampling layers
    :param act_fn: Activation class instantiated inside the decoding head
    :param dec_kernel: Kernel size of the decoding-head convolutions. Must be a
        positive odd integer; padding ``dec_kernel // 2`` keeps the spatial size.
    :raises ValueError: If ``dec_kernel`` is not a positive odd integer
    """

    def __init__(
        self,
        image_size: int,
        num_classes: int,
        embed_dim: int = 6144,
        preds_per_patch: int = 16,
        main_tasks: Iterable[str] = ("rgb",),
        patch_size: int = 16,
        depth: int = 4,
        interpolate_mode: str = "bilinear",
        act_fn: nn.Module = nn.GELU,
        dec_kernel: int = 1,
    ):
        super().__init__()
        if dec_kernel < 1 or dec_kernel % 2 != 1:
            raise ValueError(f"Unsupported dec_kernel {dec_kernel}")
        kernel = int(dec_kernel)
        padding = kernel // 2

        self.main_tasks = main_tasks
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.preds_per_patch = preds_per_patch
        self.class_dim = embed_dim // preds_per_patch
        self.num_classes = num_classes
        self.interpolate_mode = interpolate_mode
        self.image_size = image_size

        self.blocks = nn.Sequential(
            *[ConvNeXtBlock(dim=self.class_dim) for _ in range(depth)]
        )

        # Decoding head: C -> C/4 (x2 upsample) -> C/16 (upsample to image_size)
        # -> num_classes. For dec_kernel 1 and 3 this builds exactly the layers
        # of the former per-kernel branches.
        self.final_layer_1 = nn.Sequential(
            nn.Conv2d(
                self.class_dim,
                self.class_dim // 4,
                kernel_size=kernel,
                stride=1,
                padding=padding,
            ),
            nn.BatchNorm2d(self.class_dim // 4),
            act_fn(),
            nn.Upsample(scale_factor=2, mode=self.interpolate_mode),
        )
        self.final_layer_2 = nn.Sequential(
            nn.Conv2d(
                self.class_dim // 4,
                self.class_dim // 16,
                kernel_size=kernel,
                stride=1,
                padding=padding,
            ),
            nn.BatchNorm2d(self.class_dim // 16),
            act_fn(),
            nn.Upsample(size=image_size, mode=self.interpolate_mode),
        )
        self.final_layer = nn.Conv2d(
            self.class_dim // 16,
            self.num_classes,
            kernel_size=kernel,
            stride=1,
            padding=padding,
        )

        self.apply(self._init_weights)

    def init(self, dim_tokens_enc: int = 768):
        """
        Initialize parts of decoder that are dependent on dimension of encoder tokens.
        Should be called when setting up MultiMAE.

        :param dim_tokens_enc: Dimension of tokens coming from encoder
        """
        # Tokens of all main tasks are concatenated along the channel axis.
        self.in_channels = dim_tokens_enc * len(self.main_tasks)

        # Projection of encoder tokens to the patch dimension
        self.proj_dec = nn.Linear(self.in_channels, self.embed_dim)
        self._init_weights(self.proj_dec)

    def _init_weights(self, m: nn.Module):
        """Truncated-normal Linear weights, zero biases, identity LayerNorms."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def adapt_tokens(self, encoder_tokens: Tensor, input_info: Dict):
        """Select each main task's token span and concatenate them channel-wise.

        :param encoder_tokens: Encoder output; tokens are indexed along dim 1
        :param input_info: Dict whose ``"tasks"`` entry maps each task to its
            ``start_idx``/``end_idx`` in the token sequence
        """
        tasks_info = input_info["tasks"]
        spans = [
            encoder_tokens[:, tasks_info[task]["start_idx"] : tasks_info[task]["end_idx"]]
            for task in self.main_tasks
        ]
        return torch.cat(spans, dim=-1)

    def forward(self, encoder_tokens: Tensor, input_info: Dict) -> Tensor:
        """Decode encoder tokens into a ``num_classes``-channel prediction map.

        :param encoder_tokens: Output tokens of the transformer encoder
        :param input_info: Token layout description; ``input_info["image_size"]``
            gives the (H, W) used to compute the patch grid
        :return: Prediction upsampled to ``self.image_size``
        """
        H, W = input_info["image_size"]
        N_H, N_W = H // self.patch_size, W // self.patch_size
        # Exact integer side of the square sub-patch grid.
        side = math.isqrt(self.preds_per_patch)

        x = self.adapt_tokens(encoder_tokens, input_info)
        x = self.proj_dec(x)
        # Split each projected token into preds_per_patch sub-tokens of class_dim
        x = rearrange(
            x,
            "b n (p c) -> b (n p) c",
            n=N_H * N_W,
            p=self.preds_per_patch,
            c=self.class_dim,
        )
        # Lay the sub-tokens out as a (N_H * side) x (N_W * side) feature map
        x = rearrange(
            x,
            "b (nh nw ph pw) c -> b c (nh ph) (nw pw)",
            nh=N_H,
            nw=N_W,
            ph=side,
            pw=side,
        )
        x = self.blocks(x)
        x = self.final_layer_1(x)
        x = self.final_layer_2(x)
        return self.final_layer(x)
class Attention(nn.Module):
    """Multi-head self-attention over a token sequence of shape (B, N, C).

    :param dim: Token dimension C
    :param num_heads: Number of attention heads; each head uses C // num_heads channels
    :param qkv_bias: Whether the fused query/key/value projection has a bias
    :param attn_drop: Dropout probability applied to the attention weights
    :param proj_drop: Dropout probability applied after the output projection
    """

    def __init__(
        self,
        dim: int,
        num_heads=8,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, per_head)
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        fused = fused.permute(2, 0, 3, 1, 4)
        query, key, value = fused.unbind(0)  # torchscript-friendly split

        weights = (query @ key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        mixed = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear -> Dropout.

    :param in_features: Input feature size
    :param hidden_features: Hidden size; falls back to ``in_features`` when falsy
    :param out_features: Output size; falls back to ``in_features`` when falsy
    :param act_layer: Activation class, instantiated once
    :param drop: Dropout probability applied to the output
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: nn.Module = nn.GELU,
        drop: float = 0.0,
    ):
        super().__init__()
        width_out = out_features if out_features else in_features
        width_hidden = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        # Dropout is deliberately applied only once, after fc2 (the original
        # note says this matches the original BERT implementation).
        hidden = self.act(self.fc1(x))
        return self.drop(self.fc2(hidden))
class Block(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``x + DropPath(Attention(norm1(x)))`` followed by
    ``x + DropPath(Mlp(norm2(x)))``; both residual branches share one
    DropPath module.

    :param dim: Token dimension
    :param num_heads: Number of attention heads
    :param mlp_ratio: Hidden size of the MLP relative to ``dim``
    :param qkv_bias: Whether the attention qkv projection has a bias
    :param drop: Dropout after the attention projection and after the MLP
    :param attn_drop: Dropout on the attention weights
    :param drop_path: Stochastic depth rate (0 disables it)
    :param act_layer: Activation class used in the MLP
    :param norm_layer: Normalization layer class
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x: Tensor) -> Tensor:
        for norm, sublayer in ((self.norm1, self.attn), (self.norm2, self.mlp)):
            x = x + self.drop_path(sublayer(norm(x)))
        return x
class MultiMAE(nn.Module):
    """MultiMAE: Multi-task Multi-modal Masked Autoencoder (shared backbone).

    Builds the input/output adapters, global tokens, optional ground-truth
    tokens and the transformer encoder. No forward pass is defined here; the
    MultiViT module defined below inherits from this module, performs a
    regular forward pass, and should be used instead for downstream tasks.

    :param input_adapters: Dictionary of task -> input adapters
    :param output_adapters: Dictionary of task -> output adapters
    :param num_global_tokens: Number of additional global tokens to add (like cls tokens), default is 1
    :param dim_tokens: Dimension of encoder tokens
    :param depth: Depth of encoder
    :param num_heads: Number of attention heads
    :param mlp_ratio: MLP hidden dim ratio
    :param qkv_bias: Set to False to disable bias
    :param drop_rate: Dropout after MLPs and Attention
    :param attn_drop_rate: Attention matrix drop rate
    :param drop_path_rate: DropPath drop rate (increases linearly up to this value)
    :param norm_layer: Type of normalization layer
    :param freeze_encoder: Set to True to disable gradients of the encoder blocks
    :param num_additional_gt_tokens: (deprecated) Enables ground-truth tokens when
        non-zero and sets the cosine period of the precomputed ones
    :param actual_num_additional_gt_tokens: (deprecated) Number of precomputed
        ground-truth tokens
    :param learnable_additional_gt_tokens: Whether precomputed ground-truth tokens are trainable
    :param additional_gt_tokens_mlp_channels: Hidden channels of an MLP applied to
        ground-truth tokens; None or empty means identity
    :param ground_truth_version: Ground-truth encoding variant; version 6 does not
        precompute ground-truth tokens
    :param A: Amplitude of the cosine ground-truth tokens
    """

    def __init__(
        self,
        input_adapters: Dict[str, PatchedInputAdapter],
        output_adapters: Dict[str, ConvNeXtAdapter],
        num_global_tokens: int = 1,
        dim_tokens: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
        freeze_encoder: bool = False,
        num_additional_gt_tokens: int = 0,  # @deprecated
        actual_num_additional_gt_tokens: int = 0,  # @deprecated
        learnable_additional_gt_tokens: bool = False,
        additional_gt_tokens_mlp_channels: Optional[List[int]] = None,
        ground_truth_version: int = -1,
        A: float = 0.5,
    ):
        super().__init__()
        self.dim_tokens = dim_tokens
        self.ground_truth_version = ground_truth_version

        # Initialize input and output adapters for the encoder token dimension
        for adapter in input_adapters.values():
            adapter.init(dim_tokens=dim_tokens)
        self.input_adapters = nn.ModuleDict(input_adapters)
        for adapter in output_adapters.values():
            adapter.init(dim_tokens_enc=dim_tokens)
        self.output_adapters = nn.ModuleDict(output_adapters)

        # Additional learnable tokens that can be used by encoder to process/store global information
        self.num_global_tokens = num_global_tokens
        self.global_tokens = nn.Parameter(torch.zeros(1, num_global_tokens, dim_tokens))
        trunc_normal_(self.global_tokens, std=0.02)

        self.num_additional_gt_tokens = num_additional_gt_tokens  # @deprecated
        self.actual_num_additional_gt_tokens = (
            actual_num_additional_gt_tokens  # @deprecated
        )
        self.A = A
        # Copy, so neither a shared default nor the caller's list is aliased.
        self.additional_gt_tokens_mlp_channels = list(
            additional_gt_tokens_mlp_channels or []
        )
        self.learnable_additional_gt_tokens = learnable_additional_gt_tokens
        self.init_gt_tokens()

        # Transformer encoder; stochastic depth decay rule (linear in depth)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.encoder = nn.Sequential(
            *[
                Block(
                    dim=dim_tokens,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        print(f"Encoder {count_parameters(self.encoder)}")

        if freeze_encoder:
            print("Freeze encoder")
            for param in self.encoder.parameters():
                param.requires_grad = False

        self.apply(self._init_weights)
        for name, m in self.named_modules():
            if isinstance(m, nn.Linear):
                if "qkv" in name:
                    # treat the weights of Q, K, V separately
                    val = math.sqrt(
                        6.0 / float(m.weight.shape[0] // 3 + m.weight.shape[1])
                    )
                    nn.init.uniform_(m.weight, -val, val)
                elif "kv" in name:
                    # treat the weights of K, V separately
                    val = math.sqrt(
                        6.0 / float(m.weight.shape[0] // 2 + m.weight.shape[1])
                    )
                    nn.init.uniform_(m.weight, -val, val)

            if isinstance(m, nn.Conv2d):
                if ".proj" in name:
                    # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
                    w = m.weight.data
                    nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        print(f"Total params: {count_parameters(self)}")

    def init_gt_tokens(self):
        """Set up ``token_mlp`` and precompute the cosine ground-truth tokens.

        Token ``i`` is ``A * cos(2*pi*(d / dim_tokens - i * T))`` over channel
        index ``d`` with ``T = 1 / (num_additional_gt_tokens * 4)``. Precomputing
        them saves time during training; for ``ground_truth_version == 6`` none
        are precomputed. Does nothing beyond an identity ``token_mlp`` when
        ``num_additional_gt_tokens == 0``.
        """
        if self.num_additional_gt_tokens == 0:
            self.token_mlp = nn.Identity()
            return

        if self.additional_gt_tokens_mlp_channels:
            self.token_mlp = MLP(
                self.dim_tokens,
                list(self.additional_gt_tokens_mlp_channels) + [self.dim_tokens],
            )
        else:
            self.token_mlp = nn.Identity()

        gt_tokens: List[nn.Parameter] = []
        if self.ground_truth_version != 6:
            T = 1 / (self.num_additional_gt_tokens * 4)
            for i in range(self.actual_num_additional_gt_tokens):
                t = [
                    2 * math.pi * (offset / self.dim_tokens - i * T)
                    for offset in range(self.dim_tokens)
                ]
                gt_tokens.append(
                    nn.Parameter(
                        self.A * torch.cos(Tensor(t).unsqueeze(0).unsqueeze(0)),
                        requires_grad=self.learnable_additional_gt_tokens,
                    )
                )
        # NOTE: attribute name (typo included) is kept for checkpoint compatibility.
        self.addtional_gt_tokens = nn.ParameterList(gt_tokens)

    def _init_weights(self, m: nn.Module) -> None:
        """Xavier-uniform Linear weights, zero biases, identity LayerNorms."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def no_weight_decay(self):
        """Collect fully-qualified parameter names excluded from weight decay."""
        no_wd_set = {"global_tokens"}
        for prefix, adapters in (
            ("input_adapters", self.input_adapters),
            ("output_adapters", self.output_adapters),
        ):
            for task, adapter in adapters.items():
                if hasattr(adapter, "no_weight_decay"):
                    no_wd_set |= {
                        f"{prefix}.{task}.{name}" for name in adapter.no_weight_decay()
                    }
        return no_wd_set

    def generate_input_info(
        self, input_task_tokens: Dict[str, Tensor], image_size: Tuple[int, int]
    ) -> Dict:
        """Describe where each task's tokens live in the concatenated sequence.

        :param input_task_tokens: Ordered mapping task -> token tensor with the
            token axis at dim 1
        :param image_size: (H, W) of the full-resolution input
        :return: OrderedDict with ``tasks`` (per task: ``num_tokens``,
            ``has_2d_posemb``, ``start_idx``, ``end_idx``), ``image_size``,
            ``num_task_tokens`` and ``num_global_tokens``
        """
        input_info = OrderedDict()
        input_info["tasks"] = {}
        start = 0
        for domain, tensor in input_task_tokens.items():
            num_tokens: Union[int, Tensor] = tensor.shape[1]
            if isinstance(num_tokens, Tensor):
                num_tokens = num_tokens.item()
            input_info["tasks"][domain] = {
                "num_tokens": num_tokens,
                "has_2d_posemb": True,
                "start_idx": start,
                "end_idx": start + num_tokens,
            }
            start += num_tokens

        input_info["image_size"] = image_size
        input_info["num_task_tokens"] = start
        input_info["num_global_tokens"] = self.num_global_tokens
        return input_info
class MultiViT(MultiMAE):
    """MultiMAE variant with a regular (unmasked) forward pass.

    All provided inputs that have an input adapter are tokenized, concatenated
    with the global tokens (and, depending on the configuration, ground-truth
    tokens), encoded, and decoded by every output adapter.
    """

    def extract_B_H_W(self, x: Dict[str, Tensor]) -> Tuple[int, int, int]:
        """Return the batch size and full-resolution image height/width of ``x``.

        Uses ``"rgb"`` (B, C, H, W) if present, else ``"sod"`` (B, H, W) scaled
        by its input adapter's stride level, else the first entry of ``x``
        (unpacked as (B, C, H, W)).
        """
        if "rgb" in x:
            B, _, H, W = x["rgb"].shape
        elif "sod" in x:
            B, H, W = x["sod"].shape
            stride = self.input_adapters["sod"].stride_level
            H *= stride
            W *= stride
        else:
            B, _, H, W = next(iter(x.values())).shape
        return B, H, W

    def _cosine_gt_token(self, index, period) -> Tensor:
        """Build one (1, 1, dim_tokens) token ``A * cos(2*pi*(d/D - index*period))``."""
        t = [
            2 * math.pi * (offset / self.dim_tokens - index * period)
            for offset in range(self.dim_tokens)
        ]
        return self.A * torch.cos(Tensor(t).unsqueeze(0).unsqueeze(0))

    def process_input(
        self,
        x: Dict[str, Tensor],
        gt_index_lst: Optional[List[int]] = None,
        num_gts_lst: Optional[List[int]] = None,
    ) -> Tuple[Tensor, Dict]:
        """Tokenize the inputs and append global and ground-truth tokens.

        :param x: Inputs keyed by modality; keys without an input adapter are ignored
        :param gt_index_lst: Per-sample ground-truth index. Required when
            ``ground_truth_version == 6``, or when ``num_additional_gt_tokens > 0``
            (then its length must equal the batch size).
        :param num_gts_lst: Per-sample number of ground truths; required (same
            length as ``gt_index_lst``) when ``ground_truth_version == 6``
        :return: ``(tokens, input_info)`` where tokens are ordered as
            [task tokens, global tokens, optional ground-truth tokens]
        """
        B, H, W = self.extract_B_H_W(x)

        # Encode selected inputs to tokens
        input_task_tokens: Dict[str, Tensor] = {
            domain: self.input_adapters[domain](tensor)
            for domain, tensor in x.items()
            if domain in self.input_adapters
        }
        input_info = self.generate_input_info(
            input_task_tokens=input_task_tokens, image_size=(H, W)
        )
        input_tokens = torch.cat(list(input_task_tokens.values()), dim=1)

        # Global tokens are shared across the batch
        global_tokens = repeat(self.global_tokens, "() n d -> b n d", b=B)
        pieces = [input_tokens, global_tokens]

        if self.ground_truth_version == 6:
            # Tokens are built on the fly from (gt_index, num_gts) per sample.
            assert (
                gt_index_lst is not None and num_gts_lst is not None
            ), "gt_index_lst and num_gts_lst are required for ground_truth_version 6"
            assert len(gt_index_lst) == len(num_gts_lst)
            gt_tokens = torch.cat(
                [
                    self._cosine_gt_token(gt_index, 1 / num_gts)
                    for gt_index, num_gts in zip(gt_index_lst, num_gts_lst)
                ],
                dim=0,
            ).to(input_tokens.device)
            pieces.append(self.token_mlp(gt_tokens))
        elif self.num_additional_gt_tokens > 0:
            # Tokens are looked up from the precomputed list.
            assert gt_index_lst is not None and len(gt_index_lst) == B
            gt_tokens = torch.cat(
                [self.addtional_gt_tokens[gt_i] for gt_i in gt_index_lst], dim=0
            )
            pieces.append(self.token_mlp(gt_tokens))

        return torch.cat(pieces, dim=1), input_info

    def forward(
        self,
        x: Dict[str, Tensor],
        gt_index_lst: Optional[List[int]] = None,
        max_gts_lst: Optional[List[int]] = None,
    ) -> Dict[str, Tensor]:
        """Forward pass through input adapters, transformer encoder and output adapters.

        :param x: Dictionary of input tensors keyed by modality
        :param gt_index_lst: Per-sample ground-truth indices (see ``process_input``)
        :param max_gts_lst: Per-sample number of ground truths, forwarded to
            ``process_input`` as ``num_gts_lst``
        :return: Dictionary domain -> prediction of every output adapter
        """
        input_tokens, input_info = self.process_input(x, gt_index_lst, max_gts_lst)

        # Pass tokens through Transformer
        encoder_tokens = self.encoder(input_tokens)

        # Decode tokens for each task using task-specific output adapters
        return {
            domain: self.output_adapters[domain](
                encoder_tokens=encoder_tokens,
                input_info=input_info,
            )
            for domain in self.output_adapters
        }
def interpolate_pos_embed_multimae(
    model: "MultiViT",
    checkpoint_model: Dict[str, Tensor],
) -> None:
    """Resize input-adapter positional embeddings of a checkpoint, in place.

    For every key matching ``input_adapters.<domain>.pos_emb`` whose domain also
    exists in ``model.input_adapters``, the stored embedding is bicubically
    interpolated to the spatial size of the model's ``pos_emb`` when the two
    sizes differ. All other entries are left untouched.

    :param model: Model whose input adapters define the target sizes
    :param checkpoint_model: State dict to patch in place; matched entries are
        unpacked as 4-D (_, _, H, W) tensors
    """
    # Raw string: "\." in a normal literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12).
    pattern = re.compile(r"input_adapters\.(.*)\.pos_emb")
    for key in list(checkpoint_model):
        match = pattern.match(key)
        if match is None:
            continue
        domain = match.group(1)  # group(0) is entire matched regex
        adapter = getattr(model.input_adapters, domain, None)
        if adapter is None:
            continue

        pos_embed_checkpoint = checkpoint_model[key]
        _, _, orig_H, orig_W = pos_embed_checkpoint.shape
        _, _, new_H, new_W = adapter.pos_emb.shape
        if (orig_H, orig_W) != (new_H, new_W):
            print(
                f"Key {key}: Position interpolate from {orig_H}x{orig_W} to {new_H}x{new_W}"
            )
            pos_embed_checkpoint = torch.nn.functional.interpolate(
                pos_embed_checkpoint,
                size=(new_H, new_W),
                mode="bicubic",
                align_corners=False,
            )
        checkpoint_model[key] = pos_embed_checkpoint
def construct_adapters(cfg: base_cfg):
    """Build every input adapter and an output-adapter factory per modality.

    Input adapters are ready-made ``PatchedInputAdapter`` instances. Output
    adapters are ``functools.partial`` factories of ``ConvNeXtAdapter`` so the
    caller can still override keyword arguments (e.g. ``main_tasks``) per output.

    :param cfg: Experiment configuration
    :return: ``(input_adapters, output_adapter_factories)``, both dicts keyed by
        modality name. With ``ground_truth_version == 3`` the output dict also
        holds one single-channel ``sod{i}`` factory per class.
    """

    def input_adapter(channels: int) -> PatchedInputAdapter:
        # All inputs are full-resolution (stride 1) and share patch/image sizes.
        return PatchedInputAdapter(
            num_channels=channels,
            stride_level=1,
            patch_size_full=cfg.input_patch_size,
            image_size=cfg.image_size,
            learnable_pos_emb=cfg.learnable_pos_emb,
        )

    def output_factory(classes: int) -> partial:
        # Every decoder shares the same hyper-parameters except its channel count.
        return partial(
            ConvNeXtAdapter,
            num_classes=classes,
            image_size=cfg.image_size,
            embed_dim=cfg.embed_dim,
            patch_size=cfg.input_patch_size,
            preds_per_patch=cfg.output_patch_size,
            depth=cfg.decoder_depth,
            interpolate_mode=cfg.decoder_interpolate_mode,
            main_tasks=cfg.decoder_main_tasks,
            act_fn=cfg.act_fn,
            dec_kernel=cfg.dec_kernel,
        )

    inputs = {
        "rgb": input_adapter(3),
        "depth": input_adapter(1),
    }

    sod_classes = cfg.num_classes
    if cfg.ground_truth_version in [5, 6]:
        # Ground-truth versions 5 and 6 use a single-channel "sod" output.
        sod_classes = 1

    outputs = {
        "sod": output_factory(sod_classes),
        "rgb": output_factory(3),
        "depth": output_factory(1),
    }
    if cfg.ground_truth_version == 3:
        for class_idx in range(cfg.num_classes):
            outputs[f"sod{class_idx}"] = output_factory(1)

    return inputs, outputs
def generate_smultimae_model(cfg: base_cfg) -> Tuple[MultiViT, List[Dict]]:
    """Build the MultiViT model described by ``cfg``.

    Picks the input adapters named in ``cfg.inputs`` and instantiates one output
    adapter per entry of ``cfg.outputs``, each with its own entry of
    ``cfg.decoder_main_tasks``.

    :param cfg: Experiment configuration
    :return: ``(model, [])``; the second element is always an empty list here
        (pretrained-backbone loading is currently disabled)
    """
    assert len(cfg.decoder_main_tasks) == len(
        cfg.outputs
    ), "Length of decoder main tasks must match length of outputs"

    available_inputs, output_factories = construct_adapters(cfg)
    input_adapters = {name: available_inputs[name] for name in cfg.inputs}
    output_adapters = {
        name: output_factories[name](main_tasks=tasks)
        for name, tasks in zip(cfg.outputs, cfg.decoder_main_tasks)
    }

    # @deprecated: ground-truth token counts are only set for versions 5 and 6.
    if cfg.ground_truth_version in [5, 6]:
        num_additional_gt_tokens = cfg.num_classes
        actual_num_additional_gt_tokens = cfg.actual_num_classes
    else:
        num_additional_gt_tokens = 0
        actual_num_additional_gt_tokens = 0

    model = MultiViT(
        input_adapters=input_adapters,
        output_adapters=output_adapters,
        freeze_encoder=cfg.freeze_encoder,
        drop_path_rate=0.1,
        dim_tokens=cfg.dim_tokens,
        depth=cfg.encoder_depth,
        num_heads=cfg.num_heads,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        num_additional_gt_tokens=num_additional_gt_tokens,  # @deprecated
        actual_num_additional_gt_tokens=actual_num_additional_gt_tokens,  # @deprecated
        ground_truth_version=cfg.ground_truth_version,
    )

    # Pretrained-backbone loading is disabled:
    # return load_pretrained_backbone(cfg, model)
    return model, []