|
""" |
|
MAE Model Script ver: Oct 23rd 15:00 |
|
|
|
# References: |
|
Based on MAE code. |
|
https://github.com/facebookresearch/mae |
|
|
|
timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm |
|
DeiT: https://github.com/facebookresearch/deit |
|
|
|
|
|
July 16th

Added patchify_decoder to form [B, N, D] token sequences

Added a parameter so MAE can take an external segmentation network as its decoder
|
""" |
|
from functools import partial |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from timm.models.vision_transformer import PatchEmbed, Block |
|
from Backbone.VPT_structure import VPT_ViT |
|
from SSL_structures.pos_embed import get_2d_sincos_pos_embed |
|
|
|
|
|
class MaskedAutoencoderViT(VPT_ViT): |
|
""" |
|
Masked Autoencoder with VisionTransformer backbone |
|
""" |
|
|
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, |
|
embed_dim=1024, depth=24, num_heads=16, |
|
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, |
|
mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, |
|
prompt_mode=None, Prompt_Token_num=20, basic_state_dict=None, decoder=None, decoder_rep_dim=None): |
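        """
        Notes on the extended parameters (summarizing the behavior in this file):

        :param prompt_mode: None builds a plain ViT encoder; 'Deep' builds a deep
            VPT encoder, and any other non-None value selects the shallow variant
            (both use Prompt_Token_num learnable prompt tokens)
        :param decoder: optional external decoder (e.g. a segmentation network);
            if None, the standard MAE Transformer decoder is built
        :param decoder_rep_dim: token dim produced by the external decoder path,
            fed into the final per-patch prediction head
        """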
|
|
|
|
|
|
|
|
|
|
|
|
|
if prompt_mode is None: |
|
super().__init__() |
|
|
|
|
|
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) |
|
num_patches = self.patch_embed.num_patches |
|
|
|
|
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) |
|
|
|
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) |
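            # fixed sin-cos position embedding (requires_grad=False); the values
            # are filled in later by initialize_weights()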
|
|
|
self.blocks = nn.ModuleList([ |
|
Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) |
|
for i in range(depth)]) |
|
self.norm = norm_layer(embed_dim) |
|
|
|
self.prompt_mode = prompt_mode |
|
|
|
|
|
else: |
|
super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans, |
|
embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, |
|
norm_layer=norm_layer, Prompt_Token_num=Prompt_Token_num, VPT_type=prompt_mode, |
|
basic_state_dict=None) |
|
num_patches = self.patch_embed.num_patches |
|
|
|
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) |
|
|
|
self.prompt_mode = prompt_mode |
|
|
|
            self.Freeze()  # freeze the VPT backbone so only prompt tokens stay trainable
|
|
|
|
|
|
|
|
|
        # project encoder tokens to decoder width (identity when the dims match)
        if embed_dim != decoder_embed_dim:
|
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) |
|
else: |
|
self.decoder_embed = nn.Identity() |
|
|
|
if decoder is not None: |
|
self.decoder = decoder |
|
|
|
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) |
|
|
|
self.decoder_pred = nn.Linear(decoder_rep_dim, patch_size ** 2 * in_chans, bias=True) |
|
|
|
else: |
|
self.decoder = None |
|
|
|
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) |
|
|
|
|
|
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), |
|
requires_grad=False) |
|
self.decoder_blocks = nn.ModuleList([Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, |
|
qkv_bias=True, norm_layer=norm_layer) |
|
for i in range(decoder_depth)]) |
|
|
|
self.decoder_norm = norm_layer(decoder_embed_dim) |
|
|
|
|
|
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True) |
|
|
|
|
|
|
|
self.norm_pix_loss = norm_pix_loss |
|
|
|
self.initialize_weights() |
|
|
|
|
|
if basic_state_dict is not None: |
|
            self.load_state_dict(basic_state_dict, strict=False)
|
|
|
def initialize_weights(self): |
|
|
|
|
|
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], |
|
int(self.patch_embed.num_patches ** .5), |
|
cls_token=True) |
|
|
|
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) |
|
|
|
if self.decoder is None: |
|
|
|
decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], |
|
int(self.patch_embed.num_patches ** .5), |
|
cls_token=True) |
|
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) |
|
|
|
|
|
w = self.patch_embed.proj.weight.data |
|
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
|
|
|
|
torch.nn.init.normal_(self.cls_token, std=.02) |
|
torch.nn.init.normal_(self.mask_token, std=.02) |
|
|
|
|
|
self.apply(self._init_weights) |
|
|
|
def _init_weights(self, m): |
|
|
|
if isinstance(m, nn.Linear): |
|
|
|
torch.nn.init.xavier_uniform_(m.weight) |
|
if isinstance(m, nn.Linear) and m.bias is not None: |
|
nn.init.constant_(m.bias, 0) |
|
elif isinstance(m, nn.LayerNorm): |
|
nn.init.constant_(m.bias, 0) |
|
nn.init.constant_(m.weight, 1.0) |
|
|
|
def patchify(self, imgs): |
|
""" |
|
        Split an image into flattened patch tokens
|
|
|
input: |
|
imgs: (B, 3, H, W) |
|
|
|
output: |
|
x: (B, num_patches, patch_size**2 *3) AKA [B, num_patches, flatten_dim] |
|
""" |
|
|
|
p = self.patch_embed.patch_size[0] |
|
|
|
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 |
|
|
|
h = w = imgs.shape[2] // p |
|
|
|
|
|
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) |
|
|
|
x = torch.einsum('nchpwq->nhwpqc', x) |
|
|
|
x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3)) |
|
return x |
|
|
|
def patchify_decoder(self, imgs, patch_size=None): |
|
""" |
|
        Break an image into patch tokens
|
|
|
        FIXME: note that patch_size here should default to the decoder network's configured patch size
|
|
|
input: |
|
imgs: (B, CLS, H, W) |
|
|
|
output: |
|
        x: (B, num_patches, patch_size**2 * CLS), flattened per-patch channels
|
""" |
|
|
|
patch_size = self.patch_embed.patch_size[0] if patch_size is None else patch_size |
|
|
|
|
|
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % patch_size == 0 |
|
|
|
h = w = imgs.shape[2] // patch_size |
|
|
|
|
|
x = imgs.reshape(shape=(imgs.shape[0], -1, h, patch_size, w, patch_size)) |
|
|
|
|
|
x = torch.einsum('nchpwq->nhwpqc', x) |
|
|
|
x = x.reshape(shape=(imgs.shape[0], h * w, -1)) |
|
return x |
|
|
|
def unpatchify(self, x, patch_size=None): |
|
""" |
|
        Decode patch tokens back into an image
|
|
|
input: |
|
x: (B, num_patches, patch_size**2 *3) AKA [B, num_patches, flatten_dim] |
|
|
|
output: |
|
imgs: (B, 3, H, W) |
|
""" |
|
|
|
p = self.patch_embed.patch_size[0] if patch_size is None else patch_size |
|
|
|
|
|
h = w = int(x.shape[1] ** .5) |
|
|
|
assert h * w == x.shape[1] |
|
|
|
|
|
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) |
|
|
|
x = torch.einsum('nhwpqc->nchpwq', x) |
|
|
|
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, w * p))
|
return imgs |
|
|
|
def random_masking(self, x, mask_ratio): |
|
""" |
|
Perform per-sample random masking by per-sample shuffling. |
|
Per-sample shuffling is done by argsort random noise. |
|
|
|
        Note: along the given dim, torch.argsort returns the indices of the original

        positions after sorting the tensor values in ascending order
|
|
|
input: |
|
x: [B, num_patches, D], sequence of Tokens |
|
|
|
output: x_remained, mask, ids_restore |
|
x_remained: [B, num_patches * (1-mask_ratio), D], sequence of Tokens |
|
mask: [B, num_patches], binary mask |
|
        ids_restore: [B, num_patches], indices for restoring the original order of all positions
|
""" |
|
B, num_patches, D = x.shape |
|
|
|
len_keep = int(num_patches * (1 - mask_ratio)) |
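        # e.g. with mask_ratio=0.75 and 196 patches, only len_keep = 49 tokens survive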
|
|
|
noise = torch.rand(B, num_patches, device=x.device) |
|
|
|
|
|
ids_shuffle = torch.argsort(noise, dim=1) |
|
|
|
ids_restore = torch.argsort(ids_shuffle, dim=1) |
|
|
|
|
|
ids_keep = ids_shuffle[:, :len_keep] |
|
|
|
|
|
|
|
|
|
|
|
|
|
x_remained = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) |
|
|
|
|
|
mask = torch.ones([B, num_patches], device=x.device) |
|
mask[:, :len_keep] = 0 |
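
        # the binary mask is built in shuffled order (kept tokens first), then
        # unshuffled below via ids_restore to align with the original patch order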
|
|
|
|
|
mask = torch.gather(mask, dim=1, index=ids_restore) |
|
|
|
return x_remained, mask, ids_restore |
|
|
|
def forward_encoder(self, imgs, mask_ratio): |
|
""" |
|
:param imgs: [B, C, H, W], sequence of imgs |
|
:param mask_ratio: mask_ratio |
|
|
|
        :return: Encoder output: encoded tokens, binary mask, restore indices
|
x: [B, 1 + num_patches * (1-mask_ratio), D], sequence of Tokens (including the cls token) |
|
mask: [B, num_patches], binary mask |
|
        ids_restore: [B, num_patches], indices for restoring the original order of all positions
|
""" |
|
if self.prompt_mode is None: |
|
|
|
x = self.patch_embed(imgs) |
|
|
|
|
|
x = x + self.pos_embed[:, 1:, :] |
|
|
|
|
|
|
|
x, mask, ids_restore = self.random_masking(x, mask_ratio) |
|
|
|
|
|
cls_token = self.cls_token + self.pos_embed[:, :1, :] |
|
cls_tokens = cls_token.expand(x.shape[0], -1, -1) |
|
x = torch.cat((cls_tokens, x), dim=1) |
|
|
|
|
|
for blk in self.blocks: |
|
x = blk(x) |
|
|
|
else: |
|
x = self.patch_embed(imgs) |
|
|
|
x = x + self.pos_embed[:, 1:, :] |
|
|
|
|
|
x, mask, ids_restore = self.random_masking(x, mask_ratio) |
|
|
|
|
|
cls_token = self.cls_token + self.pos_embed[:, :1, :] |
|
cls_tokens = cls_token.expand(x.shape[0], -1, -1) |
|
x = torch.cat((cls_tokens, x), dim=1) |
|
|
|
if self.VPT_type == "Deep": |
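                # Deep VPT: this block's own prompt tokens are appended before the
                # block runs and sliced off afterwards, so prompts never propagate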
|
Prompt_Token_num = self.Prompt_Tokens.shape[1] |
|
for i in range(len(self.blocks)): |
|
|
|
Prompt_Tokens = self.Prompt_Tokens[i].unsqueeze(0) |
|
|
|
x = torch.cat((x, Prompt_Tokens.expand(x.shape[0], -1, -1)), dim=1) |
|
num_tokens = x.shape[1] |
|
|
|
x = self.blocks[i](x)[:, :num_tokens - Prompt_Token_num] |
|
|
|
else: |
|
Prompt_Token_num = self.Prompt_Tokens.shape[1] |
|
|
|
Prompt_Tokens = self.Prompt_Tokens.expand(x.shape[0], -1, -1) |
|
x = torch.cat((x, Prompt_Tokens), dim=1) |
|
num_tokens = x.shape[1] |
|
|
|
x = self.blocks(x)[:, :num_tokens - Prompt_Token_num] |
|
|
|
|
|
x = self.norm(x) |
|
|
|
|
|
return x, mask, ids_restore |
|
|
|
def forward_decoder(self, x, ids_restore): |
|
""" |
|
:param x: [B, 1 + num_patches * (1-mask_ratio), D], sequence of Tokens (including the cls token) |
|
:param ids_restore: restore idxs for torch.gather(mask, dim=1, index=ids_restore) |
|
|
|
        :return: Decoder output: reconstructed tokens

        x: [B, num_patches, patch_size**2 * in_chans], sequence of reconstructed patch tokens
|
""" |
|
if self.decoder is None: |
|
|
|
x = self.decoder_embed(x) |
|
|
|
|
|
|
|
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) |
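            # note: ids_restore.shape[1] is the full patch count; the "+ 1 - x.shape[1]"
            # term accounts for the cls token in x, so exactly one mask token is
            # appended per removed patch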
|
|
|
|
|
|
|
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) |
|
|
|
|
|
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) |
|
|
|
|
|
|
|
x = torch.cat([x[:, :1, :], x_], dim=1) |
|
|
|
|
|
x = x + self.decoder_pos_embed |
|
|
|
|
|
for blk in self.decoder_blocks: |
|
x = blk(x) |
|
x = self.decoder_norm(x) |
|
|
|
|
|
x = self.decoder_pred(x) |
|
|
|
|
|
x = x[:, 1:, :] |
|
|
|
else: |
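            # external decoder path: fill in mask tokens, restore patch order,
            # project to decoder width, fold the tokens back into an image, run the
            # segmentation decoder, then re-patchify its output for the linear head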
|
|
|
|
|
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) |
|
|
|
|
|
|
|
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) |
|
|
|
|
|
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) |
|
|
|
|
|
|
|
x_ = self.decoder_embed(x_) |
|
|
|
|
|
x = self.unpatchify(x_) |
|
|
|
|
|
x = self.decoder(x) |
|
|
|
x = self.patchify_decoder(x) |
|
|
|
|
|
x = self.decoder_pred(x) |
|
|
|
return x |
|
|
|
def forward_loss(self, imgs, pred, mask): |
|
""" |
|
        MSE loss on masked patches, measured against the original image
|
|
|
Input: |
|
imgs: [B, 3, H, W], Encoder input image |
|
pred: [B, num_patches, p*p*3], Decoder reconstructed image |
|
        mask: [B, num_patches], 0 is keep, 1 is remove
|
|
|
""" |
|
target = self.patchify(imgs) |
|
|
|
if self.norm_pix_loss: |
|
mean = target.mean(dim=-1, keepdim=True) |
|
var = target.var(dim=-1, keepdim=True) |
|
target = (target - mean) / (var + 1.e-6) ** .5 |
|
|
|
|
|
loss = (pred - target) ** 2 |
|
loss = loss.mean(dim=-1) |
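
        # average only over removed patches: mask is 1 exactly on masked positions,
        # e.g. with mask_ratio=0.75 and 196 patches, mask.sum() is 147 per sample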
|
|
|
|
|
loss = (loss * mask).sum() / mask.sum() |
|
return loss |
|
|
|
def forward(self, imgs, mask_ratio=0.75): |
|
|
|
latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) |
|
|
|
pred = self.forward_decoder(latent, ids_restore) |
|
|
|
loss = self.forward_loss(imgs, pred, mask) |
|
|
|
return loss, pred, mask |
|
|
|
|
|
def mae_vit_base_patch16_dec512d8b(dec_idx=None, **kwargs): |
|
print("Decoder:", dec_idx) |
|
|
|
model = MaskedAutoencoderViT( |
|
patch_size=16, embed_dim=768, depth=12, num_heads=12, |
|
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, |
|
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) |
|
return model |
|
|
|
|
|
def mae_vit_large_patch16_dec512d8b(dec_idx=None, **kwargs): |
|
print("Decoder:", dec_idx) |
|
|
|
model = MaskedAutoencoderViT( |
|
patch_size=16, embed_dim=1024, depth=24, num_heads=16, |
|
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, |
|
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) |
|
return model |
|
|
|
|
|
def mae_vit_huge_patch14_dec512d8b(dec_idx=None, **kwargs): |
|
print("Decoder:", dec_idx) |
|
|
|
model = MaskedAutoencoderViT( |
|
patch_size=14, embed_dim=1280, depth=32, num_heads=16, |
|
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, |
|
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) |
|
return model |
|
|
|
|
|
def mae_vit_base_patch16_decoder(dec_idx=None, num_classes=3, img_size=224, **kwargs): |
|
|
|
|
|
if dec_idx == 'swin_unet': |
|
decoder_embed_dim = 768 |
|
decoder_rep_dim = 16 * 16 * 3 |
|
|
|
from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet as ViT_seg |
|
decoder = ViT_seg(num_classes=num_classes, **kwargs) |
|
|
|
elif dec_idx == 'transunet': |
|
decoder_embed_dim = 768 |
|
decoder_rep_dim = 16 * 16 * 3 |
|
|
|
transunet_name = 'R50-ViT-B_16' |
|
transunet_patches_size = 16 |
|
from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS as CONFIGS_Transunet_seg |
|
from SSL_structures.TransUNet_main.networks.vit_seg_modeling import VisionTransformer as Transunet_seg |
|
|
|
config_vit = CONFIGS_Transunet_seg[transunet_name] |
|
config_vit.n_classes = num_classes |
|
config_vit.n_skip = 3 |
|
|
|
if transunet_name.find('R50') != -1: |
|
config_vit.patches.grid = ( |
|
int(img_size / transunet_patches_size), int(img_size / transunet_patches_size)) |
|
decoder = Transunet_seg(config_vit, num_classes=config_vit.n_classes) |
|
|
|
elif dec_idx == 'UTNetV2': |
|
decoder_embed_dim = 768 |
|
decoder_rep_dim = 16 * 16 * 3 |
|
|
|
from SSL_structures.UtnetV2.utnetv2 import UTNetV2 as UTNetV2_seg |
|
decoder = UTNetV2_seg(in_chan=3, num_classes=num_classes) |
|
|
|
else: |
|
        print('no valid decoder specified:', dec_idx)
|
return -1 |
|
|
|
print('dec_idx: ', dec_idx) |
|
|
|
model = MaskedAutoencoderViT( |
|
patch_size=16, embed_dim=768, depth=12, num_heads=12, |
|
decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16, |
|
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), decoder_rep_dim=decoder_rep_dim, decoder=decoder, |
|
**kwargs) |
|
return model |
|
|
|
|
|
def mae_vit_large_patch16_decoder(dec_idx=None, num_classes=3, img_size=224, **kwargs): |
|
|
|
|
|
if dec_idx == 'swin_unet': |
|
decoder_embed_dim = 768 |
|
decoder_rep_dim = 16 * 16 * 3 |
|
|
|
from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet as ViT_seg |
|
decoder = ViT_seg(num_classes=num_classes, **kwargs) |
|
|
|
elif dec_idx == 'transunet': |
|
decoder_embed_dim = 768 |
|
decoder_rep_dim = 16 * 16 * 3 |
|
|
|
transunet_name = 'R50-ViT-B_16' |
|
transunet_patches_size = 16 |
|
from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS as CONFIGS_Transunet_seg |
|
from SSL_structures.TransUNet_main.networks.vit_seg_modeling import VisionTransformer as Transunet_seg |
|
|
|
config_vit = CONFIGS_Transunet_seg[transunet_name] |
|
config_vit.n_classes = num_classes |
|
config_vit.n_skip = 3 |
|
|
|
if transunet_name.find('R50') != -1: |
|
config_vit.patches.grid = ( |
|
int(img_size / transunet_patches_size), int(img_size / transunet_patches_size)) |
|
decoder = Transunet_seg(config_vit, num_classes=config_vit.n_classes) |
|
|
|
elif dec_idx == 'UTNetV2': |
|
decoder_embed_dim = 768 |
|
decoder_rep_dim = 16 * 16 * 3 |
|
|
|
from SSL_structures.UtnetV2.utnetv2 import UTNetV2 as UTNetV2_seg |
|
decoder = UTNetV2_seg(in_chan=3, num_classes=num_classes) |
|
|
|
else: |
|
        print('no valid decoder specified:', dec_idx)
|
return -1 |
|
|
|
print('dec_idx: ', dec_idx) |
|
|
|
model = MaskedAutoencoderViT( |
|
patch_size=16, embed_dim=1024, depth=24, num_heads=16, |
|
decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16, |
|
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), decoder_rep_dim=decoder_rep_dim, decoder=decoder, |
|
**kwargs) |
|
return model |
|
|
|
|
|
def mae_vit_huge_patch14_decoder(dec_idx=None, num_classes=3, img_size=224, **kwargs): |
|
|
|
|
|
if dec_idx == 'swin_unet': |
|
decoder_embed_dim = 588 |
|
decoder_rep_dim = 14 * 14 * 3 |
|
|
|
from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet as ViT_seg |
|
decoder = ViT_seg(num_classes=num_classes, **kwargs) |
|
|
|
elif dec_idx == 'transunet': |
|
decoder_embed_dim = 768 |
|
decoder_rep_dim = 16 * 16 * 3 |
|
|
|
transunet_name = 'R50-ViT-B_16' |
|
transunet_patches_size = 16 |
|
from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS as CONFIGS_Transunet_seg |
|
from SSL_structures.TransUNet_main.networks.vit_seg_modeling import VisionTransformer as Transunet_seg |
|
|
|
config_vit = CONFIGS_Transunet_seg[transunet_name] |
|
config_vit.n_classes = num_classes |
|
config_vit.n_skip = 3 |
|
|
|
if transunet_name.find('R50') != -1: |
|
config_vit.patches.grid = ( |
|
int(img_size / transunet_patches_size), int(img_size / transunet_patches_size)) |
|
decoder = Transunet_seg(config_vit, num_classes=config_vit.n_classes) |
|
|
|
elif dec_idx == 'UTNetV2': |
|
decoder_embed_dim = 768 |
|
decoder_rep_dim = 14 * 14 * 3 |
|
|
|
from SSL_structures.UtnetV2.utnetv2 import UTNetV2 as UTNetV2_seg |
|
decoder = UTNetV2_seg(in_chan=3, num_classes=num_classes) |
|
|
|
else: |
|
        print('no valid decoder specified:', dec_idx)
|
return -1 |
|
|
|
print('dec_idx: ', dec_idx) |
|
|
|
model = MaskedAutoencoderViT( |
|
patch_size=14, embed_dim=1280, depth=32, num_heads=16, |
|
decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16, |
|
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), decoder_rep_dim=decoder_rep_dim, decoder=decoder, |
|
**kwargs) |
|
return model |
|
|
|
|
|
|
|
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks

mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks

mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks
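
# usage sketch (hypothetical, mirroring the __main__ test below): a vanilla MAE
# without an external decoder reconstructs 16x16 RGB patches directly, e.g.
#   model = mae_vit_base_patch16(norm_pix_loss=False)
#   loss, pred, mask = model(torch.rand(2, 3, 224, 224), mask_ratio=0.75)
#   # loss: scalar, pred: [2, 196, 768], mask: [2, 196]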
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
img_size = 224 |
|
num_classes = 3 |
|
x = torch.rand(8, 3, img_size, img_size, device=device) |
|
|
|
|
|
model = mae_vit_base_patch16_decoder(prompt_mode='Deep', Prompt_Token_num=20, basic_state_dict=None, |
|
dec_idx='UTNetV2', img_size=img_size) |
|
|
|
model.to(device) |
|
|
|
loss, pred, mask_patch_indicators = model(x) |
|
|
|
print(loss, '\n') |
|
|
|
print(loss.shape, '\n') |
|
|
|
print(pred.shape, '\n') |
|
|
|
print(mask_patch_indicators.shape, '\n') |
|
|