""" MAE Model Script ver: Oct 23rd 15:00 # References: Based on MAE code. https://github.com/facebookresearch/mae timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm DeiT: https://github.com/facebookresearch/deit July 16th Add patchify_decoder to form B,N,D Add a parameter for MAE to import segmentation network """ from functools import partial import torch import torch.nn as nn from timm.models.vision_transformer import PatchEmbed, Block from Backbone.VPT_structure import VPT_ViT from SSL_structures.pos_embed import get_2d_sincos_pos_embed class MaskedAutoencoderViT(VPT_ViT): """ Masked Autoencoder with VisionTransformer backbone """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=24, num_heads=16, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, prompt_mode=None, Prompt_Token_num=20, basic_state_dict=None, decoder=None, decoder_rep_dim=None): # model = MaskedAutoencoderViT( # patch_size=16, embed_dim=768, depth=12, num_heads=12, # decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, # mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) if prompt_mode is None: super().__init__() # MAE encoder specifics (this part just the same as ViT) # -------------------------------------------------------------------------- self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) # BCHW -> BNC num_patches = self.patch_embed.num_patches # learnable cls token is still used but on cls head need self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # set and freeze encoder_pos_embed, use the fixed sin-cos embedding for tokens + mask_token self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # Encoder blocks self.blocks = nn.ModuleList([ # qk_scale=None fixme related to timm version Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for i in range(depth)]) self.norm = norm_layer(embed_dim) self.prompt_mode = prompt_mode # -------------------------------------------------------------------------- else: super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, norm_layer=norm_layer, Prompt_Token_num=Prompt_Token_num, VPT_type=prompt_mode, basic_state_dict=None) # Firstly, set then Encoder state_dict to none here. num_patches = self.patch_embed.num_patches # set patch_embed of VPT # set and freeze encoder_pos_embed, use the fixed sin-cos embedding for tokens + mask_token self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) self.prompt_mode = prompt_mode # Freeze Encoder parameters except of the Prompt Tokens self.Freeze() # MAE decoder specifics # -------------------------------------------------------------------------- # if the feature dimension of encoder and decoder are different, use decoder_embed to align them if embed_dim != decoder_embed_dim: self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) else: self.decoder_embed = nn.Identity() if decoder is not None: self.decoder = decoder # set mask_token (learnable mask token for reconstruction) self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # Decoder use a FC to reconstruct image, unlike the Encoder which use a CNN to split patch self.decoder_pred = nn.Linear(decoder_rep_dim, patch_size ** 2 * in_chans, bias=True) # decoder to patch else: self.decoder = None # 未传入decoder则与encoder流程一致,但是更改了通道数量,构建block(原版MAE) # set mask_token (learnable mask token for reconstruction) self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) # set and freeze decoder_pos_embed, use the fixed sin-cos embedding for tokens + mask_token self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) self.decoder_blocks = nn.ModuleList([Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for i in range(decoder_depth)]) # qk_scale=None fixme related to timm version self.decoder_norm = norm_layer(decoder_embed_dim) # Decoder use a FC to reconstruct image, unlike the Encoder which use a CNN to split patch self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True) # decoder to patch # -------------------------------------------------------------------------- # wether or not to use norm_pix_loss self.norm_pix_loss = norm_pix_loss # parameter initialization self.initialize_weights() # load basic state_dict of backbone for Transfer-learning-based tuning if basic_state_dict is not None: self.load_state_dict(basic_state_dict, False) def initialize_weights(self): # initialization # initialize a 2d positional encoding of (embed_dim, grid) by sin-cos embedding pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5), cls_token=True) # return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) if self.decoder is None: # initialize a 2d positional encoding of (embed_dim, grid) by sin-cos embedding decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5), cls_token=True) self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) # initialize patch_embed like nn.Linear (instead of nn.Conv2d) w = self.patch_embed.proj.weight.data torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) # xavier_uniform,让输入输出的方差相同,包括前后向传播 # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) torch.nn.init.normal_(self.cls_token, std=.02) torch.nn.init.normal_(self.mask_token, std=.02) # initialize nn.Linear and nn.LayerNorm self.apply(self._init_weights) def _init_weights(self, m): # initialize nn.Linear and nn.LayerNorm if isinstance(m, nn.Linear): # we use xavier_uniform following official JAX ViT: torch.nn.init.xavier_uniform_(m.weight) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def patchify(self, imgs): """ Encode image to patch tokens input: imgs: (B, 3, H, W) output: x: (B, num_patches, patch_size**2 *3) AKA [B, num_patches, flatten_dim] """ # patch_size p = self.patch_embed.patch_size[0] # assert H == W and image shape is dividedable by patch assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 # patch num in rol or column h = w = imgs.shape[2] // p # use reshape to split patch [B, C, H, W] -> [B, C, h_p, p, w_p, p] x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) # ReArrange dimensions [B, C, h_p, p, w_p, p] -> [B, h_p, w_p, p, p, C] x = torch.einsum('nchpwq->nhwpqc', x) # ReArrange dimensions [B, h_p, w_p, p, p, C] -> [B, num_patches, flatten_dim] x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3)) return x def patchify_decoder(self, imgs, patch_size=None): # TODO 这里目的很大,需要实现预训练! """ Break image to patch tokens fixme,注意,这里patch_size应该是按照decoder的网络设置来作为default input: imgs: (B, CLS, H, W) output: x: (B, num_patches, -1) AKA [B, num_patches, -1] """ # patch_size patch_size = self.patch_embed.patch_size[0] if patch_size is None else patch_size # assert H == W and image shape is divided-able by patch assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % patch_size == 0 # patch num in rol or column h = w = imgs.shape[2] // patch_size # use reshape to split patch [B, C, H, W] -> [B, C, h_p, patch_size, w_p, patch_size] x = imgs.reshape(shape=(imgs.shape[0], -1, h, patch_size, w, patch_size)) # ReArrange dimensions [B, C, h_p, patch_size, w_p, patch_size] -> [B, h_p, w_p, patch_size, patch_size, C] x = torch.einsum('nchpwq->nhwpqc', x) # ReArrange dimensions [B, h_p, w_p, patch_size, patch_size, C] -> [B, num_patches, flatten_dim] x = x.reshape(shape=(imgs.shape[0], h * w, -1)) return x def unpatchify(self, x, patch_size=None): """ Decoding encoded patch tokens input: x: (B, num_patches, patch_size**2 *3) AKA [B, num_patches, flatten_dim] output: imgs: (B, 3, H, W) """ # patch_size p = self.patch_embed.patch_size[0] if patch_size is None else patch_size # squre root of num_patches(without CLS token required) h = w = int(x.shape[1] ** .5) # assert num_patches is without CLS token assert h * w == x.shape[1] # ReArrange dimensions [B, num_patches, flatten_dim] -> [B, h_p, w_p, p, p, C] x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) # ReArrange dimensions [B, h_p, w_p, p, p, C] -> [B, C, h_p, p, w_p, p] x = torch.einsum('nhwpqc->nchpwq', x) # use reshape to compose patch [B, C, h_p, p, w_p, p] -> [B, C, H, W] imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) return imgs def random_masking(self, x, mask_ratio): """ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random noise. 注意torch.argsort返回的是: 在每个指定dim,按原tensor每个位置数值大小升序排列后,的原本位置的idx组成的矩阵 input: x: [B, num_patches, D], sequence of Tokens output: x_remained, mask, ids_restore x_remained: [B, num_patches * (1-mask_ratio), D], sequence of Tokens mask: [B, num_patches], binary mask ids_restore: [B, num_patches], idx of restoring all position """ B, num_patches, D = x.shape # batch, length, dim # 计算需要保留的位置的个数 len_keep = int(num_patches * (1 - mask_ratio)) # 做一个随机序列[B,num_patches],用于做位置标号 noise = torch.rand(B, num_patches, device=x.device) # noise in [0, 1] # 在Batch里面每个序列上获得noise tensor经过升序排列后原本位置的idx矩阵 在batch内进行升序排列 ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove # 再对idx矩阵继续升序排列可获得:原始noise tensor的每个位置的排序顺位 ids_restore = torch.argsort(ids_shuffle, dim=1) # keep the first subset ids_keep = ids_shuffle[:, :len_keep] # 设置需要的patch的索引 # ids_keep.unsqueeze(-1).repeat(1, 1, D): # [B,num_patches] -> [B,keep_patches] -> [B,keep_patches,1] 每个位置数字为idx of ori patch -> [B,keep_patches,D] # torch.gather 按照索引取值构建新tensor: x_remained [B,keep_patches,D] 表示被标记需要保留的位置, 原文是x_masked x_remained = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) # generate the binary mask: 0 is keep, 1 is remove mask = torch.ones([B, num_patches], device=x.device) mask[:, :len_keep] = 0 # 设置mask矩阵,前len_keep个为0,后面为1 # 按照noise tensor每个位置的大小顺序,来设置mask符号为0的位置,获得mask矩阵 mask = torch.gather(mask, dim=1, index=ids_restore) return x_remained, mask, ids_restore # x_remained原文是x_masked def forward_encoder(self, imgs, mask_ratio): """ :param imgs: [B, C, H, W], sequence of imgs :param mask_ratio: mask_ratio :return: Encoder output: encoded tokens, mask position, restore idxs x: [B, 1 + num_patches * (1-mask_ratio), D], sequence of Tokens (including the cls token) mask: [B, num_patches], binary mask ids_restore: [B, num_patches], idx of restoring all position """ if self.prompt_mode is None: # ViT # embed patches x = self.patch_embed(imgs) # BCHW -> BNC # add pos embed w/o cls token x = x + self.pos_embed[:, 1:, :] # add pos embed before concatenate the cls token # masking: length -> length * (1-mask_ratio) # x_remained: [B, num_patches * (1-mask_ratio), D], sequence of Tokens x, mask, ids_restore = self.random_masking(x, mask_ratio) # append cls token cls_token = self.cls_token + self.pos_embed[:, :1, :] cls_tokens = cls_token.expand(x.shape[0], -1, -1) # batch fix 调整batch x = torch.cat((cls_tokens, x), dim=1) # apply Transformer Encoders for blk in self.blocks: x = blk(x) else: # VPT x = self.patch_embed(imgs) # add pos embed before concatenate the cls token x = x + self.pos_embed[:, 1:, :] # masking: length -> length * (1-mask_ratio) # x_remained: [B, num_patches * (1-mask_ratio), D], sequence of Tokens x, mask, ids_restore = self.random_masking(x, mask_ratio) # append cls token cls_token = self.cls_token + self.pos_embed[:, :1, :] cls_tokens = cls_token.expand(x.shape[0], -1, -1) # batch fix 调整batch x = torch.cat((cls_tokens, x), dim=1) if self.VPT_type == "Deep": Prompt_Token_num = self.Prompt_Tokens.shape[1] for i in range(len(self.blocks)): # concatenate Prompt_Tokens Prompt_Tokens = self.Prompt_Tokens[i].unsqueeze(0) # firstly concatenate x = torch.cat((x, Prompt_Tokens.expand(x.shape[0], -1, -1)), dim=1) num_tokens = x.shape[1] # lastly remove, a good trick x = self.blocks[i](x)[:, :num_tokens - Prompt_Token_num] else: # self.VPT_type == "Shallow" Prompt_Token_num = self.Prompt_Tokens.shape[1] # concatenate Prompt_Tokens Prompt_Tokens = self.Prompt_Tokens.expand(x.shape[0], -1, -1) x = torch.cat((x, Prompt_Tokens), dim=1) num_tokens = x.shape[1] # A whole sequential process x = self.blocks(x)[:, :num_tokens - Prompt_Token_num] # last norm of Transformer x = self.norm(x) # Encoder output: encoded tokens, mask position, restore idxs return x, mask, ids_restore def forward_decoder(self, x, ids_restore): """ :param x: [B, 1 + num_patches * (1-mask_ratio), D], sequence of Tokens (including the cls token) :param ids_restore: restore idxs for torch.gather(mask, dim=1, index=ids_restore) :return: Decoder output: reconstracted tokens x: [B, num_patches * (1-mask_ratio), D], sequence of Tokens """ if self.decoder is None: # embed tokens: [B, num_encoded_tokens, embed_dim] -> [B, num_encoded_tokens, D_Decoder] x = self.decoder_embed(x) # 更改适合的通道数 # append mask tokens to sequence as place holder: [B, num_patches + 1 - num_encoded_tokens, D_Decoder] # number of mask token need is the requirement to fill the num_patches mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) # 这里ids_restore.shape[1] + 1 - x.shape[1] 其实意思是ids_restore.shape[1] - (x.shape[1]-1), 因为不要CLS token # -> [B, num_patches, D_Decoder] x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # stripe the cls token in Decoder for restore position # unshuffle to restore the position of tokens x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # torch.gather 按照索引取值构建新tensor: x_ [B,num_patches,D_Decoder] 表示位置还原之后的图,此时数值还不对 # append back the cls token at the first -> [B,1+num_patches,D_Decoder] x = torch.cat([x[:, :1, :], x_], dim=1) # add pos embed x = x + self.decoder_pos_embed # apply Transformer blocks for blk in self.decoder_blocks: x = blk(x) x = self.decoder_norm(x) # Reconstruction projection [B, num_patches, D_Decoder] -> [B, num_patches, p*p*3] x = self.decoder_pred(x) # remove cls token x = x[:, 1:, :] else: # append mask tokens to sequence as place holder: [B, num_patches + 1 - num_encoded_tokens, D] # number of mask token need is the requirement to fill the num_patches mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) # 这里ids_restore.shape[1] + 1 - x.shape[1] 其实意思是ids_restore.shape[1] - (x.shape[1]-1), 因为不要CLS token # -> [B, num_patches, D] x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # stripe the cls token in Decoder for restore position # unshuffle to restore the position of tokens x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # torch.gather 按照索引取值构建新tensor: x_ [B,num_patches,D] 表示位置还原之后的图,此时数值还不对 # embed tokens: [B, num_encoded_tokens, D_Encoder] -> [B, num_encoded_tokens, D_Decoder] x_ = self.decoder_embed(x_) # unpatchify to make image form [B, N, Enc] to [B,H,W,C] x = self.unpatchify(x_) # restore image by Encoder # apply decoder module to segment the output of encoder x = self.decoder(x) # [B, CLS, H, W] # the output of segmentation is transformed to [B, N, Dec] x = self.patchify_decoder(x) # TODO 做一个有意义的设计 # Convert the number of channels to match image for loss function x = self.decoder_pred(x) # [B, N, Dec] -> [B, N, p*p*3] return x def forward_loss(self, imgs, pred, mask): # 通过把loss放到model里面,把model变成了一个训练框架 """ MSE loss for all patches towards the ori image Input: imgs: [B, 3, H, W], Encoder input image pred: [B, num_patches, p*p*3], Decoder reconstructed image mask: [B, num_patches], 0 is keep, 1 is remove, """ target = self.patchify(imgs) if self.norm_pix_loss: # 把target image patches 标准化 mean = target.mean(dim=-1, keepdim=True) var = target.var(dim=-1, keepdim=True) target = (target - mean) / (var + 1.e-6) ** .5 # MSE loss loss = (pred - target) ** 2 loss = loss.mean(dim=-1) # [N, L], mean loss per patch # binary mask, 1 for removed patches loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches return loss def forward(self, imgs, mask_ratio=0.75): # Encoder to obtain latent tokens latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) # Decoder to obtain Reconstructed image patches pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] # MSE loss for all patches towards the ori image loss = self.forward_loss(imgs, pred, mask) # print(loss) # todo 这里原文是为了关注loss爆炸, 可能有坑 return loss, pred, mask def mae_vit_base_patch16_dec512d8b(dec_idx=None, **kwargs): print("Decoder:", dec_idx) model = MaskedAutoencoderViT( patch_size=16, embed_dim=768, depth=12, num_heads=12, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) return model def mae_vit_large_patch16_dec512d8b(dec_idx=None, **kwargs): print("Decoder:", dec_idx) model = MaskedAutoencoderViT( patch_size=16, embed_dim=1024, depth=24, num_heads=16, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) return model def mae_vit_huge_patch14_dec512d8b(dec_idx=None, **kwargs): print("Decoder:", dec_idx) model = MaskedAutoencoderViT( patch_size=14, embed_dim=1280, depth=32, num_heads=16, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) return model def mae_vit_base_patch16_decoder(dec_idx=None, num_classes=3, img_size=224, **kwargs): # num_classes做的是one-hot seg但是不是做还原,我们得设计一下如何去做这个还原才能实现预训练 if dec_idx == 'swin_unet': decoder_embed_dim = 768 decoder_rep_dim = 16 * 16 * 3 from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet as ViT_seg decoder = ViT_seg(num_classes=num_classes, **kwargs) elif dec_idx == 'transunet': decoder_embed_dim = 768 decoder_rep_dim = 16 * 16 * 3 transunet_name = 'R50-ViT-B_16' transunet_patches_size = 16 from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS as CONFIGS_Transunet_seg from SSL_structures.TransUNet_main.networks.vit_seg_modeling import VisionTransformer as Transunet_seg config_vit = CONFIGS_Transunet_seg[transunet_name] config_vit.n_classes = num_classes config_vit.n_skip = 3 if transunet_name.find('R50') != -1: config_vit.patches.grid = ( int(img_size / transunet_patches_size), int(img_size / transunet_patches_size)) decoder = Transunet_seg(config_vit, num_classes=config_vit.n_classes) elif dec_idx == 'UTNetV2': decoder_embed_dim = 768 decoder_rep_dim = 16 * 16 * 3 from SSL_structures.UtnetV2.utnetv2 import UTNetV2 as UTNetV2_seg decoder = UTNetV2_seg(in_chan=3, num_classes=num_classes) else: print('no effective decoder!') return -1 print('dec_idx: ', dec_idx) model = MaskedAutoencoderViT( patch_size=16, embed_dim=768, depth=12, num_heads=12, decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), decoder_rep_dim=decoder_rep_dim, decoder=decoder, **kwargs) return model def mae_vit_large_patch16_decoder(dec_idx=None, num_classes=3, img_size=224, **kwargs): # num_classes做的是one-hot seg但是不是做还原,我们得设计一下如何去做这个还原才能实现预训练 if dec_idx == 'swin_unet': decoder_embed_dim = 768 decoder_rep_dim = 16 * 16 * 3 from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet as ViT_seg decoder = ViT_seg(num_classes=num_classes, **kwargs) elif dec_idx == 'transunet': decoder_embed_dim = 768 decoder_rep_dim = 16 * 16 * 3 transunet_name = 'R50-ViT-B_16' transunet_patches_size = 16 from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS as CONFIGS_Transunet_seg from SSL_structures.TransUNet_main.networks.vit_seg_modeling import VisionTransformer as Transunet_seg config_vit = CONFIGS_Transunet_seg[transunet_name] config_vit.n_classes = num_classes config_vit.n_skip = 3 if transunet_name.find('R50') != -1: config_vit.patches.grid = ( int(img_size / transunet_patches_size), int(img_size / transunet_patches_size)) decoder = Transunet_seg(config_vit, num_classes=config_vit.n_classes) elif dec_idx == 'UTNetV2': decoder_embed_dim = 768 decoder_rep_dim = 16 * 16 * 3 from SSL_structures.UtnetV2.utnetv2 import UTNetV2 as UTNetV2_seg decoder = UTNetV2_seg(in_chan=3, num_classes=num_classes) else: print('no effective decoder!') return -1 print('dec_idx: ', dec_idx) model = MaskedAutoencoderViT( patch_size=16, embed_dim=1024, depth=24, num_heads=16, decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), decoder_rep_dim=decoder_rep_dim, decoder=decoder, **kwargs) return model def mae_vit_huge_patch14_decoder(dec_idx=None, num_classes=3, img_size=224, **kwargs): # num_classes做的是one-hot seg但是不是做还原,我们得设计一下如何去做这个还原才能实现预训练 if dec_idx == 'swin_unet': decoder_embed_dim = 588 # 1280 14*14*3 decoder_rep_dim = 14 * 14 * 3 from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet as ViT_seg decoder = ViT_seg(num_classes=num_classes, **kwargs) elif dec_idx == 'transunet': decoder_embed_dim = 768 decoder_rep_dim = 16 * 16 * 3 transunet_name = 'R50-ViT-B_16' transunet_patches_size = 16 from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS as CONFIGS_Transunet_seg from SSL_structures.TransUNet_main.networks.vit_seg_modeling import VisionTransformer as Transunet_seg config_vit = CONFIGS_Transunet_seg[transunet_name] config_vit.n_classes = num_classes config_vit.n_skip = 3 if transunet_name.find('R50') != -1: config_vit.patches.grid = ( int(img_size / transunet_patches_size), int(img_size / transunet_patches_size)) decoder = Transunet_seg(config_vit, num_classes=config_vit.n_classes) elif dec_idx == 'UTNetV2': decoder_embed_dim = 768 decoder_rep_dim = 14 * 14 * 3 from SSL_structures.UtnetV2.utnetv2 import UTNetV2 as UTNetV2_seg decoder = UTNetV2_seg(in_chan=3, num_classes=num_classes) else: print('no effective decoder!') return -1 print('dec_idx: ', dec_idx) model = MaskedAutoencoderViT( patch_size=14, embed_dim=1280, depth=32, num_heads=16, decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), decoder_rep_dim=decoder_rep_dim, decoder=decoder, **kwargs) return model # set recommended archs mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks # Equiped with decoders mae_vit_base_patch16_decoder = mae_vit_base_patch16_decoder # decoder: 768 dim, HYF mae_vit_large_patch16_decoder = mae_vit_large_patch16_decoder # decoder: 768 dim, HYF mae_vit_huge_patch14_decoder = mae_vit_huge_patch14_decoder # decoder: 768 dim, HYF if __name__ == '__main__': device = torch.device("cuda" if torch.cuda.is_available() else "cpu") img_size = 224 num_classes = 3 x = torch.rand(8, 3, img_size, img_size, device=device) # model = mae_vit_base_patch16(img_size=224, decoder=None) # decoder_embed_dim=512 model = mae_vit_base_patch16_decoder(prompt_mode='Deep', Prompt_Token_num=20, basic_state_dict=None, dec_idx='UTNetV2', img_size=img_size) model.to(device) loss, pred, mask_patch_indicators = model(x) print(loss, '\n') print(loss.shape, '\n') print(pred.shape, '\n') print(mask_patch_indicators.shape, '\n')