# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

from functools import partial
import math
from typing import Sequence, Tuple, Union, Callable

import torch
import torch.nn as nn
import torch.utils.checkpoint
from mmseg.models.builder import BACKBONES
from mmengine.model import BaseModule
import torch.nn.functional as F

from .dino_layers import (
    Mlp,
    PatchEmbed,
    SwiGLUFFNFused,
    MemEffAttention,
    NestedTensorBlock as Block,
)


def named_apply(
    fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
) -> nn.Module:
    """Recursively call ``fn(module=..., name=...)`` on the children of ``module``.

    Args:
        fn: callable invoked with keyword arguments ``module`` and ``name``.
        module: root of the module tree to walk.
        name: dotted name of ``module``; child names are appended to it.
        depth_first: if True a module is visited after its children,
            otherwise before them.
        include_root: whether ``fn`` is also called on ``module`` itself.
            Children are always visited.

    Returns:
        The ``module`` that was passed in.
    """
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(
            fn=fn,
            module=child_module,
            name=child_name,
            depth_first=depth_first,
            include_root=True,
        )
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    """A ``ModuleList`` that applies its members sequentially when called."""

    def forward(self, x):
        for b in self:
            x = b(x)
        return x


@BACKBONES.register_module()
class DinoVisionTransformer(BaseModule):
    """DINO-style Vision Transformer backbone returning multi-level feature maps."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=partial(Block, attn_class=MemEffAttention),
        ffn_layer="mlp",
        block_chunks=1,
        out_indices=(7, 11, 15, 23),
        init_cfg=None,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size. NOTE: the methods of this
                class compute ``size // self.patch_size``, so an int is
                what the code below actually supports.
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            out_indices (Sequence[int]): indices of the transformer blocks
                whose outputs ``forward_features`` returns. ``forward``
                rescales entries 0, 1 and 3 of that result, so it expects
                four matching indices.
            init_cfg: forwarded unchanged to ``BaseModule.__init__``.

        Raises:
            NotImplementedError: if ``ffn_layer`` is not one of the
                supported names.
        """
        super().__init__(init_cfg)
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.out_indices = out_indices
        self.drop_path_rate = drop_path_rate
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.norm_layer = norm_layer
        self.patch_size = patch_size

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_tokens, embed_dim)
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            # stochastic depth decay rule
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        if ffn_layer == "mlp":
            ffn_layer = Mlp
        elif ffn_layer in ("swiglufused", "swiglu"):
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError(f"unsupported ffn_layer: {ffn_layer!r}")

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append(
                    [nn.Identity()] * i + blocks_list[i : i + chunksize]
                )
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

    def _iter_blocks(self):
        """Yield ``(block_index, block)`` for every transformer block in order.

        Works for both layouts of ``self.blocks``. In the chunked layout each
        chunk is left-padded with ``nn.Identity`` placeholders so that a
        block keeps its global index inside the chunk; the placeholders are
        skipped here.
        """
        if not self.chunked_blocks:
            yield from enumerate(self.blocks)
            return
        index = 0
        for block_chunk in self.blocks:
            for blk in block_chunk[index:]:  # Passing the nn.Identity()
                yield index, blk
                index += 1

    def interpolate_pos_encoding(self, x, w, h):
        """Return the positional embedding resized to the token grid of ``x``.

        Args:
            x: token tensor whose dim 1 holds one class token followed by
                the patch tokens and whose last dim is the embedding dim.
            w, h: the two spatial sizes of the input image, in the order
                the caller unpacked them from ``image.shape[2:]``.

        Returns:
            ``self.pos_embed`` itself when the patch count matches and
            ``w == h``; otherwise the class embedding concatenated with a
            bicubic resampling of the patch embeddings, cast to ``x.dtype``.
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1

        # The stored patch embeddings are treated as a square grid.
        grid = int(math.sqrt(N))
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode="bicubic",
        )

        assert (
            int(w0) == patch_pos_embed.shape[-2]
            and int(h0) == patch_pos_embed.shape[-1]
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
            previous_dtype
        )

    def prepare_tokens_with_masks(self, x, masks=None):
        """Embed an image batch into tokens ready for the transformer blocks.

        Patch-embeds ``x``, replaces the tokens selected by ``masks`` (when
        given) with ``self.mask_token``, prepends the class token and adds
        the (possibly interpolated) positional embedding.
        """
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(
                masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
            )

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        return x

    def forward_features_list(self, x_list, masks_list):
        """Run the backbone on a list of image batches.

        The whole list of token tensors is handed to each entry of
        ``self.blocks`` at once.

        Returns:
            One dict per input with keys ``"x_norm_clstoken"``,
            ``"x_norm_patchtokens"``, ``"x_prenorm"`` and ``"masks"``.
        """
        x = [
            self.prepare_tokens_with_masks(x, masks)
            for x, masks in zip(x_list, masks_list)
        ]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_patchtokens": x_norm[:, 1:],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Return the feature maps of the blocks listed in ``self.out_indices``.

        Args:
            x: image batch of shape ``(B, C, h, w)``, or a list of such
                batches (then ``forward_features_list`` is used).
            masks: optional token mask(s) forwarded to token preparation.

        Returns:
            For tensor input, a list with one tensor per selected block,
            each reshaped to
            ``(B, -1, h // patch_size, w // patch_size)`` with the class
            token dropped. For list input, the output of
            ``forward_features_list``.
        """
        # The list check must come first: a list has no ``.shape``.
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        B, _, h, w = x.shape
        grid_h, grid_w = h // self.patch_size, w // self.patch_size
        x = self.prepare_tokens_with_masks(x, masks)

        outs = []
        # Iterate individual blocks (not chunks) so that ``out_indices``
        # refers to block indices regardless of ``block_chunks``.
        for idx, blk in self._iter_blocks():
            x = blk(x)
            if idx in self.out_indices:
                outs.append(
                    x[:, 1:, :]
                    .permute(0, 2, 1)
                    .reshape(B, -1, grid_h, grid_w)
                    .contiguous()
                )
        return outs

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        """Collect block outputs when ``self.blocks`` is a flat list of blocks.

        ``n`` is either an int (take the ``n`` last blocks) or a collection
        of block indices.
        """
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = (
            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        )
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(
            blocks_to_take
        ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        """Collect block outputs when ``self.blocks`` is a list of ``BlockChunk``.

        ``n`` is either an int (take the ``n`` last blocks) or a collection
        of block indices.
        """
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = (
            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        )
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(
            blocks_to_take
        ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch tokens (and optionally class tokens) of selected blocks.

        Args:
            x: image batch of shape ``(B, C, H, W)``.
            n: number of last blocks to take, or explicit block indices.
            reshape: if True, reshape each output from a token sequence to a
                ``(B, -1, H // patch_size, W // patch_size)`` map.
            return_class_token: if True, return ``(patch_tokens, class_token)``
                pairs instead of patch tokens only.
            norm: if True, apply ``self.norm`` to every selected output.

        Returns:
            A tuple with one entry per selected block.
        """
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1:] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
                .permute(0, 3, 1, 2)
                .contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, **kwargs):
        """Run ``forward_features`` and rescale three of the returned maps.

        Entries 0, 1 and 3 of the feature list are bilinearly resized by
        factors 4, 2 and 0.5 respectively; entry 2 is left untouched. When
        the first element of the result is not a tensor, the same rescaling
        is applied in place to ``ret[0]`` instead (assumed to be an
        index-assignable sequence of tensors - TODO confirm against callers).
        """
        ret = self.forward_features(*args, **kwargs)
        target = ret if isinstance(ret[0], torch.Tensor) else ret[0]
        # (index into the feature list, spatial scale factor)
        for idx, scale in ((0, 4), (1, 2), (3, 0.5)):
            target[idx] = F.interpolate(
                target[idx], scale_factor=scale, mode="bilinear", align_corners=False
            )
        return ret