# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit

import logging

from dataclasses import dataclass
from functools import partial

from timm.models.vision_transformer import PatchEmbed, Block

import torch
import torch.nn as nn

import numpy as np

from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec.wav2vec2 import TransformerSentenceEncoderLayer

try:
    from apex.normalization import FusedLayerNorm
except ImportError:
    FusedLayerNorm = nn.LayerNorm

import torch.nn.functional as F


logger = logging.getLogger(__name__)


@dataclass
class MaeConfig(FairseqDataclass):
    input_size: int = 224
    in_chans: int = 3
    patch_size: int = 16
    embed_dim: int = 768
    depth: int = 12
    num_heads: int = 12
    decoder_embed_dim: int = 512
    decoder_depth: int = 8
    decoder_num_heads: int = 16
    mlp_ratio: int = 4
    norm_eps: float = 1e-6

    drop_path_rate: float = 0.0

    mask_ratio: float = 0.75
    norm_pix_loss: bool = True

    w2v_block: bool = False
    alt_block: bool = False
    alt_block2: bool = False
    alt_attention: bool = False
    block_dropout: float = 0
    attention_dropout: float = 0
    activation_dropout: float = 0
    layer_norm_first: bool = False

    fused_ln: bool = True
    end_of_block_targets: bool = True

    no_decoder_embed: bool = False
    no_decoder_pos_embed: bool = False
    mask_noise_std: float = 0

    single_qkv: bool = False
    use_rel_pos_bias: bool = False
    # when True, blocks do not build their own relative position bias tables
    shared_rel_pos_bias: bool = False
    no_cls: bool = False


def modify_relative_position_bias(orig_bias, bsz, mask):
    if mask is None:
        return orig_bias.unsqueeze(0).repeat(
            bsz, 1, 1, 1
        )  # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
    heads, max_seq_len, max_seq_len = orig_bias.shape  # includes CLS token
    mask_for_rel_pos_bias = torch.cat(
        (torch.zeros(bsz, 1, dtype=mask.dtype, device=mask.device), mask), dim=1
    ).bool()  # bsz x seqlen (add CLS token)
    unmasked_for_rel_pos_bias = ~mask_for_rel_pos_bias
    unmasked_for_rel_pos_bias = unmasked_for_rel_pos_bias.unsqueeze(1).repeat(
        1, heads, 1
    )  # bsz x seq_len => bsz x heads x seq_len
    b_t_t_rel_pos_bias = orig_bias.unsqueeze(0).repeat(
        bsz, 1, 1, 1
    )  # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
    b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
        unmasked_for_rel_pos_bias.unsqueeze(-1)
    )
    b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, -1, max_seq_len)
    new_len = b_t_t_rel_pos_bias.size(-2)
    b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
        unmasked_for_rel_pos_bias.unsqueeze(-2)
    )
    b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, new_len, new_len)
    return b_t_t_rel_pos_bias


class AltBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        layer_norm_first=True,
        ffn_targets=False,
        use_rel_pos_bias=False,
        window_size=None,
        alt_attention=False,
    ):
        super().__init__()

        self.layer_norm_first = layer_norm_first
        self.ffn_targets = ffn_targets

        from timm.models.vision_transformer import Attention, DropPath, Mlp

        self.norm1 = norm_layer(dim)
        self.use_rel_pos_bias = use_rel_pos_bias
        if use_rel_pos_bias:
            self.attn = AltAttention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
                window_size=window_size,
            )
        else:
            if alt_attention:
                from .multi.modules import AltAttention as AltAttention2
                self.attn = AltAttention2(
                    dim,
                    num_heads=num_heads,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    attn_drop=attn_drop,
                    proj_drop=drop,
                )
            else:
                self.attn = Attention(
                    dim,
                    num_heads=num_heads,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    attn_drop=attn_drop,
                    proj_drop=drop,
                )

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, rel_pos_bias=None, pos_mask=None):
        if self.layer_norm_first:
            if self.use_rel_pos_bias:
                x = x + self.drop_path(
                    self.attn(
                        self.norm1(x), rel_pos_bias=rel_pos_bias, pos_mask=pos_mask
                    )
                )
            else:
                x = x + self.drop_path(self.attn(self.norm1(x)))
            t = self.mlp(self.norm2(x))
            x = x + self.drop_path(t)
            if not self.ffn_targets:
                t = x
            return x, t
        else:
            if self.use_rel_pos_bias:
                x = x + self.drop_path(
                    self.attn(x, rel_pos_bias=rel_pos_bias, pos_mask=pos_mask)
                )
            else:
                x = x + self.drop_path(self.attn(x))
            r = x = self.norm1(x)
            x = self.mlp(x)
            t = x
            x = self.norm2(r + self.drop_path(x))
            if not self.ffn_targets:
                t = x
            return x, t


class AltAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        window_size=None,
        attn_head_dim=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (
                2 * window_size[1] - 1
            ) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads)
            )  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = (
                coords_flatten[:, :, None] - coords_flatten[:, None, :]
            )  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(
                1, 2, 0
            ).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = torch.zeros(
                size=(window_size[0] * window_size[1] + 1,) * 2,
                dtype=relative_coords.dtype,
            )
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None, pos_mask=None):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat(
                (
                    self.q_bias,
                    torch.zeros_like(self.v_bias, requires_grad=False),
                    self.v_bias,
                )
            )
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        if self.relative_position_bias_table is not None:
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)
            ].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1,
                -1,
            )  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(
                2, 0, 1
            ).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + modify_relative_position_bias(
                relative_position_bias, x.size(0), pos_mask
            )

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class RelativePositionBias(nn.Module):
    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (
            2 * window_size[1] - 1
        ) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads)
        )

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        )  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0
        ).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = torch.zeros(
            size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
        )
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self):
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1] + 1,
            self.window_size[0] * self.window_size[1] + 1,
            -1,
        )  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
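# Illustrative shape note for get_2d_sincos_pos_embed (not used at runtime):
# with the default ViT-B setup above (224x224 images, 16x16 patches,
# embed_dim=768) the patch grid is 14x14, so
#     get_2d_sincos_pos_embed(768, 14, cls_token=True)
# returns a numpy array of shape (1 + 14 * 14, 768) whose first row is all
# zeros for the CLS token; half of the channels encode one grid axis and the
# other half encode the other axis.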
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def interpolate_pos_embed(model, checkpoint_model):
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size, orig_size, new_size, new_size)
            )
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(
                -1, orig_size, orig_size, embedding_size
            ).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens,
                size=(new_size, new_size),
                mode="bicubic",
                align_corners=False,
            )
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed


@register_model("mae", dataclass=MaeConfig)
class MaeModel(BaseFairseqModel):
    def __init__(self, cfg: MaeConfig):
        super().__init__()
        self.cfg = cfg

        self.mask_ratio = cfg.mask_ratio

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(
            cfg.input_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = (
            nn.Parameter(torch.zeros(1, 1, cfg.embed_dim)) if not cfg.no_cls else None
        )

        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + int(not cfg.no_cls), cfg.embed_dim),
            requires_grad=False,
        )  # fixed sin-cos embedding

        norm_layer = partial(nn.LayerNorm, eps=cfg.norm_eps)

        dpr = [
            x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)
        ]  # stochastic depth decay rule

        def make_block(drop_path):
            if cfg.w2v_block:
                return TransformerSentenceEncoderLayer(
                    embedding_dim=cfg.embed_dim,
                    ffn_embedding_dim=cfg.embed_dim * cfg.mlp_ratio,
                    num_attention_heads=cfg.num_heads,
                    dropout=cfg.block_dropout,
                    attention_dropout=cfg.attention_dropout,
                    activation_dropout=cfg.activation_dropout,
                    activation_fn="gelu",
                    layer_norm_first=cfg.layer_norm_first,
                    drop_path=drop_path,
                    norm_eps=1e-6,
                    single_qkv=cfg.single_qkv,
                    fused_ln=cfg.fused_ln,
                )
            elif cfg.alt_block:
                window_size = (
                    cfg.input_size // self.patch_embed.patch_size[0],
                    cfg.input_size // self.patch_embed.patch_size[1],
                )
                return AltBlock(
                    cfg.embed_dim,
                    cfg.num_heads,
                    cfg.mlp_ratio,
                    qkv_bias=True,
                    qk_scale=None,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                    layer_norm_first=cfg.layer_norm_first,
                    ffn_targets=not cfg.end_of_block_targets,
                    use_rel_pos_bias=cfg.use_rel_pos_bias,
                    window_size=window_size
                    if (self.cfg.use_rel_pos_bias and not self.cfg.shared_rel_pos_bias)
                    else None,
                    alt_attention=cfg.alt_attention,
                )
            elif cfg.alt_block2:
                from .multi.modules import AltBlock as AltBlock2

                return AltBlock2(
                    cfg.embed_dim,
                    cfg.num_heads,
                    cfg.mlp_ratio,
                    qkv_bias=True,
                    qk_scale=None,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                    layer_norm_first=cfg.layer_norm_first,
                    ffn_targets=not cfg.end_of_block_targets,
                )
            else:
                return Block(
                    cfg.embed_dim,
                    cfg.num_heads,
                    cfg.mlp_ratio,
                    qkv_bias=True,
                    qk_scale=None,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                )

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
        self.norm = norm_layer(cfg.embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = (
            nn.Linear(cfg.embed_dim, cfg.decoder_embed_dim, bias=True)
            if not cfg.no_decoder_embed
            else None
        )

        self.mask_token = (
            nn.Parameter(
                torch.zeros(
                    1,
                    1,
                    cfg.decoder_embed_dim
                    if not cfg.no_decoder_embed
                    else cfg.embed_dim,
                )
            )
            if cfg.mask_noise_std <= 0
            else None
        )

        self.decoder_pos_embed = (
            nn.Parameter(
                torch.zeros(
                    1,
                    num_patches + 1,
                    cfg.decoder_embed_dim
                    if not cfg.no_decoder_embed
                    else cfg.embed_dim,
                ),
                requires_grad=False,
            )
            if not cfg.no_decoder_pos_embed
            else None
        )

        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    cfg.decoder_embed_dim,
                    cfg.decoder_num_heads,
                    cfg.mlp_ratio,
                    qkv_bias=True,
                    qk_scale=None,
                    norm_layer=norm_layer,
                )
                for _ in range(cfg.decoder_depth)
            ]
        )

        self.decoder_norm = norm_layer(cfg.decoder_embed_dim)
        self.decoder_pred = nn.Linear(
            cfg.decoder_embed_dim, cfg.patch_size ** 2 * cfg.in_chans, bias=True
        )  # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = cfg.norm_pix_loss

        self.initialize_weights()

        for pn, p in self.named_parameters():
            if len(p.shape) == 1 or pn.endswith(".bias"):
                p.param_group = "no_decay"
            else:
                p.param_group = "with_decay"

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            int(self.patch_embed.num_patches ** 0.5),
            cls_token=not self.cfg.no_cls,
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        if self.decoder_pos_embed is not None:
            decoder_pos_embed = get_2d_sincos_pos_embed(
                self.decoder_pos_embed.shape[-1],
                int(self.patch_embed.num_patches ** 0.5),
                cls_token=not self.cfg.no_cls,
            )
            self.decoder_pos_embed.data.copy_(
                torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
            )

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        if self.cls_token is not None:
            torch.nn.init.normal_(self.cls_token, std=0.02)

        if self.mask_token is not None:
            torch.nn.init.normal_(self.mask_token, std=0.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm) or isinstance(m, FusedLayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum("nchpwq->nhwpqc", x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1
        )  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore  # x_masked is actually unmasked x

    @classmethod
    def build_model(cls, cfg: MaeConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        # if self.cls_token is not None:
        #     x = x + self.pos_embed
        # else:
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        if mask_ratio > 0:
            x, mask, ids_restore = self.random_masking(x, mask_ratio)
        else:
            mask = ids_restore = None

        # append cls token
        if self.cls_token is not None:
            cls_token = self.cls_token + self.pos_embed[:, :1, :]
            cls_tokens = cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
        )
        if self.cls_token is not None:
            x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        else:
            x_ = torch.cat([x, mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(
            x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )  # unshuffle
        if self.cls_token is not None:
            x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed
        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        if self.cls_token is not None:
            # remove cls token
            x = x[:, 1:, :]

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum()
        return loss, mask.sum()

    def forward(self, imgs, predictions_only=False):
        latent, mask, ids_restore = self.forward_encoder(
            imgs, self.mask_ratio if not predictions_only else 0
        )

        if predictions_only:
            return latent

        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        loss, sample_size = self.forward_loss(imgs, pred, mask)

        result = {
            "losses": {"regression": loss},
            "sample_size": sample_size,
        }

        return result

    def remove_pretraining_modules(self):
        self.decoder_embed = None
        self.decoder_blocks = None
        self.decoder_norm = None
        self.decoder_pos_embed = None
        self.decoder_pred = None
        self.mask_token = None
        if self.cfg.layer_norm_first:
            self.norm = None
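if __name__ == "__main__":
    # Minimal smoke-test sketch, illustrative only; it is not one of the fairseq
    # training entry points. It assumes a timm version compatible with this file
    # (i.e. whose Block/Attention still accept qk_scale). The model is built with
    # the default hyperparameters and run once on random images, printing the
    # masked reconstruction loss and the number of masked patches.
    _cfg = MaeConfig()
    _model = MaeModel.build_model(_cfg)
    _imgs = torch.randn(2, _cfg.in_chans, _cfg.input_size, _cfg.input_size)
    _out = _model(_imgs)
    print(
        "regression loss:",
        _out["losses"]["regression"].item(),
        "sample size:",
        _out["sample_size"].item(),
    )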