import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import _load_weights
from timm.models.layers import trunc_normal_
from typing import List
from src.models.vit.utils import init_weights, resize_pos_embed
from src.models.vit.blocks import Block
from src.models.vit.decoder import DecoderLinear


class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and project each to a vector.

    The projection is a single ``nn.Conv2d`` whose kernel size and stride both
    equal ``patch_size``, so every output position corresponds to one patch.

    Args:
        image_size: ``(height, width)`` of the nominal input image.
        patch_size: Side length of a square patch, in pixels.
        embed_dim: Number of output channels of the projection.
        channels: Number of input image channels.

    Raises:
        ValueError: If either image dimension is not divisible by
            ``patch_size``.
    """

    def __init__(self, image_size, patch_size, embed_dim, channels):
        super().__init__()

        self.image_size = image_size
        if image_size[0] % patch_size != 0 or image_size[1] % patch_size != 0:
            raise ValueError("image dimensions must be divisible by the patch size")
        self.grid_size = image_size[0] // patch_size, image_size[1] // patch_size
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.patch_size = patch_size

        self.proj = nn.Conv2d(channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, im):
        """Return patch embeddings of shape ``(B, num_patches, embed_dim)``.

        Args:
            im: Image batch of shape ``(B, C, H, W)``.
        """
        # The unpacking is kept as a cheap guard: it fails fast on inputs that
        # are not 4-D, before they reach the convolution.
        _batch, _channels, _height, _width = im.shape
        # (B, embed_dim, H', W') -> (B, embed_dim, H'*W') -> (B, H'*W', embed_dim)
        x = self.proj(im).flatten(2).transpose(1, 2)
        return x


class VisionTransformer(nn.Module):
    """Vision Transformer encoder with optional intermediate-feature output.

    Args:
        image_size: ``(height, width)`` used to size the positional embedding.
        patch_size: Side length of a square patch, in pixels.
        n_layers: Number of transformer blocks.
        d_model: Token embedding width.
        d_ff: Feed-forward width, passed through to ``Block``.
        n_heads: Number of attention heads, passed through to ``Block``.
        n_cls: Number of output classes of the classification head(s).
        dropout: Dropout probability applied after adding the positional
            embedding; also passed through to ``Block``.
        drop_path_rate: Upper end of the linearly spaced per-block drop-path
            rates passed through to ``Block``.
        distilled: If true, add a distillation token and a second head.
        channels: Number of input image channels.
    """

    def __init__(
        self,
        image_size,
        patch_size,
        n_layers,
        d_model,
        d_ff,
        n_heads,
        n_cls,
        dropout=0.1,
        drop_path_rate=0.0,
        distilled=False,
        channels=3,
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            image_size,
            patch_size,
            d_model,
            channels,
        )
        self.patch_size = patch_size
        self.n_layers = n_layers
        self.d_model = d_model
        self.d_ff = d_ff
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.n_cls = n_cls

        # cls and pos tokens
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.distilled = distilled
        if self.distilled:
            self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
            # Two extra positions: class token and distillation token.
            self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 2, d_model))
            self.head_dist = nn.Linear(d_model, n_cls)
        else:
            # One extra position for the class token.
            self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, d_model))

        # transformer blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
        self.blocks = nn.ModuleList(
            [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)]
        )

        # output head
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, n_cls)

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        if self.distilled:
            trunc_normal_(self.dist_token, std=0.02)
        self.pre_logits = nn.Identity()

        self.apply(init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of parameters to exclude from weight decay."""
        return {"pos_embed", "cls_token", "dist_token"}

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=""):
        """Load weights from ``checkpoint_path`` via timm's ``_load_weights``."""
        _load_weights(self, checkpoint_path, prefix)

    def _embed_tokens(self, im):
        """Patch-embed ``im``, prepend the extra tokens and add positions.

        Shared by ``forward`` and ``get_attention_map``. Dropout is NOT applied
        here; callers decide whether to apply it.
        """
        B, _, H, W = im.shape
        PS = self.patch_size

        x = self.patch_embed(im)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        if self.distilled:
            dist_tokens = self.dist_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
        else:
            x = torch.cat((cls_tokens, x), dim=1)

        pos_embed = self.pos_embed
        num_extra_tokens = 1 + self.distilled
        # Input resolution differs from the one the embedding was built for:
        # resize the positional embedding to the actual patch grid.
        if x.shape[1] != pos_embed.shape[1]:
            pos_embed = resize_pos_embed(
                pos_embed,
                self.patch_embed.grid_size,
                (H // PS, W // PS),
                num_extra_tokens,
            )
        return x + pos_embed

    def forward(self, im, head_out_idx: List[int], n_dim_output=3, return_features=False):
        """Run the encoder and return class logits or intermediate features.

        Args:
            im: Image batch of shape ``(B, C, H, W)``.
            head_out_idx: Indices of the blocks whose outputs are collected.
            n_dim_output: ``3`` to collect full token sequences
                ``(B, tokens, d_model)``; ``4`` to collect patch tokens
                reshaped to ``(B, d_model, H // patch, W // patch)``.
            return_features: If true, return the collected features instead
                of the classification logits.

        Returns:
            The stacked features (leading dimension ``len(head_out_idx)``)
            when ``return_features`` is true, otherwise logits of shape
            ``(B, n_cls)``. Slots for indices in ``head_out_idx`` that match
            no block stay zero.
        """
        B, _, H, W = im.shape
        PS = self.patch_size
        assert n_dim_output == 3 or n_dim_output == 4, "n_dim_output must be 3 or 4"

        grid_h, grid_w = H // PS, W // PS
        num_extra_tokens = 1 + self.distilled

        x = self._embed_tokens(im)
        x = self.dropout(x)

        # Size the feature buffer from the actual token count / patch grid
        # rather than assuming a square image with a single class token.
        if n_dim_output == 3:
            out_shape = (len(head_out_idx), B, x.shape[1], self.d_model)
        else:
            out_shape = (len(head_out_idx), B, self.d_model, grid_h, grid_w)
        heads_out = torch.zeros(size=out_shape, device=x.device)
        # Kept reachable as ``self.heads_out`` for existing callers, but marked
        # non-persistent so a batch-sized activation tensor does not end up in
        # ``state_dict()`` and break loading of saved checkpoints.
        self.register_buffer("heads_out", heads_out, persistent=False)

        head_idx = 0
        for idx_layer, blk in enumerate(self.blocks):
            x = blk(x)
            if idx_layer in head_out_idx:
                if n_dim_output == 3:
                    heads_out[head_idx] = x
                else:
                    # Drop class (and distillation) tokens, then lay the patch
                    # tokens back out on the 2-D grid, channels first.
                    patch_tokens = x[:, num_extra_tokens:, :]
                    heads_out[head_idx] = patch_tokens.reshape(
                        (-1, grid_h, grid_w, self.d_model)
                    ).permute(0, 3, 1, 2)
                head_idx += 1

        x = self.norm(x)

        if return_features:
            return heads_out

        if self.distilled:
            x, x_dist = x[:, 0], x[:, 1]
            x = self.head(x)
            x_dist = self.head_dist(x_dist)
            # Average the class-token and distillation-token predictions.
            x = (x + x_dist) / 2
        else:
            x = x[:, 0]
            x = self.head(x)
        return x

    def get_attention_map(self, im, layer_id):
        """Return the output of block ``layer_id`` called with ``return_attention=True``.

        Args:
            im: Image batch of shape ``(B, C, H, W)``.
            layer_id: Index of the block to query, ``0 <= layer_id < n_layers``.

        Raises:
            ValueError: If ``layer_id`` is out of range.
        """
        if layer_id >= self.n_layers or layer_id < 0:
            raise ValueError(f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}.")

        # No dropout on this path, matching the original behaviour.
        x = self._embed_tokens(im)

        for i, blk in enumerate(self.blocks):
            if i < layer_id:
                x = blk(x)
            else:
                return blk(x, return_attention=True)


class FeatureTransform(nn.Module):
    """Decode four encoder feature sequences with four ``DecoderLinear`` heads.

    Args:
        img_size: Image size forwarded to every decoder call.
        d_encoder: Encoder embedding width expected by the decoders.
        nls_list: ``n_cls`` for decoders 0..3.
        scale_factor_list: ``scale_factor`` for decoders 0..3.
    """

    def __init__(
        self,
        img_size,
        d_encoder,
        nls_list=(128, 256, 512, 512),
        scale_factor_list=(8, 4, 2, 1),
    ):
        super().__init__()
        self.img_size = img_size

        self.decoder_0 = DecoderLinear(n_cls=nls_list[0], d_encoder=d_encoder, scale_factor=scale_factor_list[0])
        self.decoder_1 = DecoderLinear(n_cls=nls_list[1], d_encoder=d_encoder, scale_factor=scale_factor_list[1])
        self.decoder_2 = DecoderLinear(n_cls=nls_list[2], d_encoder=d_encoder, scale_factor=scale_factor_list[2])
        self.decoder_3 = DecoderLinear(n_cls=nls_list[3], d_encoder=d_encoder, scale_factor=scale_factor_list[3])

    def forward(self, x_list):
        """Decode ``x_list[0..3]`` and return ``(feat_0, feat_1, feat_2, feat_3)``.

        The first token of every sequence is dropped before decoding
        (presumably the class token; assumes a non-distilled encoder —
        TODO confirm against the caller).
        """
        # Shape comments below are examples carried over from the original
        # author (batch of 2); they are not enforced here.
        feat_3 = self.decoder_3(x_list[3][:, 1:, :], self.img_size)  # (2, 512, 24, 24)
        feat_2 = self.decoder_2(x_list[2][:, 1:, :], self.img_size)  # (2, 512, 48, 48)
        feat_1 = self.decoder_1(x_list[1][:, 1:, :], self.img_size)  # (2, 256, 96, 96)
        feat_0 = self.decoder_0(x_list[0][:, 1:, :], self.img_size)  # (2, 128, 192, 192)
        return feat_0, feat_1, feat_2, feat_3