# Tianyinus's picture
# init submit
# edcf5ee verified
"""
SAE Model Script ver: Oct 28th 2023 15:30
SAE stands for shuffled autoencoder, designed for PuzzleTuning
# References:
Based on MAE code.
https://github.com/facebookresearch/mae
"""
from functools import partial
import torch
import torch.nn as nn
from timm.models.vision_transformer import PatchEmbed, Block
from SSL_structures.pos_embed import get_2d_sincos_pos_embed
from Backbone.VPT_structure import VPT_ViT
class ShuffledAutoEncoderViT(VPT_ViT):
    """
    Shuffled Autoencoder (SAE) with a VisionTransformer backbone.

    The model shuffles part of the image patches across the batch (making a "puzzle"),
    encodes the puzzled image, and decodes it back to pixel patches. The loss is the MSE
    between the prediction and the original image on the shuffled patches only.

    Args:
        img_size, patch_size, in_chans: geometry passed to the patch embedding.
        embed_dim, depth, num_heads: encoder width / number of blocks / attention heads.
        decoder_embed_dim, decoder_depth, decoder_num_heads: built-in Transformer decoder
            settings (only used to build blocks when ``decoder`` is None).
        mlp_ratio, norm_layer: passed to every Transformer ``Block``.
        norm_pix_loss: if True, normalise each target patch before computing the loss.
        group_shuffle_size: -1 (or the batch size) shuffles the puzzle patches over the whole
            batch; otherwise patches are shuffled inside groups of this many samples.
        prompt_mode: "Deep" / "Shallow" to build the encoder through ``VPT_ViT``; None (default)
            builds a plain ViT encoder.
        Prompt_Token_num: number of prompt tokens handed to ``VPT_ViT``.
        basic_state_dict: optional state dict loaded (non-strict) after initialisation.
        decoder: optional external decoder module applied on the unpatchified encoder output.
        decoder_rep_dim: feature size of the patchified external-decoder output; required when
            ``decoder`` is given.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, group_shuffle_size=-1,
                 prompt_mode=None, Prompt_Token_num=20, basic_state_dict=None, decoder=None, decoder_rep_dim=None):
        if prompt_mode is None:
            super().__init__()
            # SAE encoder specifics (this part is just the same as ViT)
            # --------------------------------------------------------------------------
            self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)  # BCHW -> BNC
            num_patches = self.patch_embed.num_patches

            # a learnable cls token is still kept in the token sequence
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

            # set and freeze encoder pos_embed: fixed sin-cos embedding for cls token + patch tokens
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)

            # Encoder blocks
            self.blocks = nn.ModuleList([  # qk_scale=None fixme related to timm version
                Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
                for _ in range(depth)])
            self.norm = norm_layer(embed_dim)

            self.prompt_mode = prompt_mode
            # --------------------------------------------------------------------------
        else:
            # The encoder state_dict is set to None here first; `basic_state_dict` is loaded
            # at the end of __init__ once the whole model has been built.
            super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
                             embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                             norm_layer=norm_layer, Prompt_Token_num=Prompt_Token_num, VPT_type=prompt_mode,
                             basic_state_dict=None)
            num_patches = self.patch_embed.num_patches  # patch_embed was built by VPT_ViT

            # set and freeze encoder pos_embed: fixed sin-cos embedding for cls token + patch tokens
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)

            self.prompt_mode = prompt_mode
            # Freeze the Encoder parameters except for the Prompt Tokens.
            # This is called before the decoder modules exist.
            self.Freeze()

        # SAE decoder specifics
        # todo: as a low-level backbone, future segmentation use still needs to be explored
        # --------------------------------------------------------------------------
        # if the feature dimensions of encoder and decoder differ, use decoder_embed to align them
        if embed_dim != decoder_embed_dim:
            self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        else:
            self.decoder_embed = nn.Identity()

        # set decoder
        if decoder is not None:
            if decoder_rep_dim is None:
                # nn.Linear(None, ...) would otherwise fail with an obscure TypeError
                raise ValueError("decoder_rep_dim must be given when a custom decoder is used")
            self.decoder = decoder
            # A FC layer maps the patchified decoder output to pixel patches
            self.decoder_pred = nn.Linear(decoder_rep_dim, patch_size ** 2 * in_chans, bias=True)
        else:
            self.decoder = None
            # set and freeze decoder_pos_embed: fixed sin-cos embedding for cls token + patch tokens
            self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
                                                  requires_grad=False)
            self.decoder_blocks = nn.ModuleList([  # qk_scale=None fixme related to timm version
                Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
                for _ in range(decoder_depth)])
            self.decoder_norm = norm_layer(decoder_embed_dim)
            # The decoder uses a FC layer to reconstruct pixel patches
            # (unlike the Encoder, which uses a conv to split patches)
            self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)
        # --------------------------------------------------------------------------

        # this controls the puzzle group
        self.group_shuffle_size = group_shuffle_size
        # whether or not to use norm_pix_loss
        self.norm_pix_loss = norm_pix_loss

        # parameter initialization
        self.initialize_weights()

        # load basic state_dict of backbone for Transfer-learning-based tuning (strict=False)
        if basic_state_dict is not None:
            self.load_state_dict(basic_state_dict, False)

    def initialize_weights(self):
        """Fill the fixed sin-cos positional embeddings and initialise the learnable weights."""
        grid_size = int(self.patch_embed.num_patches ** .5)

        # 2d sin-cos positional encoding for the encoder (with a slot for the cls token)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size, cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        if self.decoder is None:
            # 2d sin-cos positional encoding for the built-in Transformer decoder
            decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], grid_size,
                                                        cls_token=True)
            self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        # xavier_uniform keeps the variance of inputs and outputs the same (forward and backward)
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        # torch.nn.init.normal_(self.prompt_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm in every sub-module
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one module: xavier-uniform for nn.Linear, (weight=1, bias=0) for nn.LayerNorm."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs, patch_size=None):
        """
        Break images into patch tokens.

        input:
        imgs: (B, C, H, W), with H == W and H divisible by patch_size
        patch_size: patch edge; defaults to the encoder patch size

        output:
        x: (B, num_patches, patch_size**2 * C) AKA [B, num_patches, flatten_dim]
        """
        patch_size = self.patch_embed.patch_size[0] if patch_size is None else patch_size

        # H == W and the image edge must be divisible by the patch edge
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % patch_size == 0

        batch, chans = imgs.shape[0], imgs.shape[1]  # channel count is read from the input
        # number of patches per row / column
        h = w = imgs.shape[2] // patch_size

        # split into patches [B, C, H, W] -> [B, C, h_p, patch_size, w_p, patch_size]
        x = imgs.reshape(shape=(batch, chans, h, patch_size, w, patch_size))
        # [B, C, h_p, patch_size, w_p, patch_size] -> [B, h_p, w_p, patch_size, patch_size, C]
        x = torch.einsum('nchpwq->nhwpqc', x)
        # [B, h_p, w_p, patch_size, patch_size, C] -> [B, num_patches, flatten_dim]
        x = x.reshape(shape=(batch, h * w, patch_size ** 2 * chans))
        return x

    def patchify_decoder(self, imgs, patch_size=None):
        """
        Break the (external) decoder output into patch tokens.

        fixme: it would be more reasonable for the default patch_size here to follow the
        decoder network's own setting rather than the encoder's.

        input:
        imgs: (B, CLS, H, W)

        output:
        x: (B, num_patches, -1) AKA [B, num_patches, patch_size**2 * CLS]
        """
        # patchify handles any channel count, so the two routines share one implementation
        return self.patchify(imgs, patch_size)

    def unpatchify(self, x, patch_size=None):
        """
        Compose patch tokens back into images (inverse of patchify).

        input:
        x: (B, num_patches, patch_size**2 * C) AKA [B, num_patches, flatten_dim], without CLS token
        patch_size: patch edge; defaults to the encoder patch size

        output:
        imgs: (B, C, H, W)
        """
        p = self.patch_embed.patch_size[0] if patch_size is None else patch_size

        # square root of num_patches (the CLS token must have been removed)
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        # channel count is inferred from the flattened patch dimension
        chans = x.shape[2] // (p * p)

        # [B, num_patches, flatten_dim] -> [B, h_p, w_p, patch_size, patch_size, C]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, chans))
        # [B, h_p, w_p, patch_size, patch_size, C] -> [B, C, h_p, patch_size, w_p, patch_size]
        x = torch.einsum('nhwpqc->nchpwq', x)
        # [B, C, h_p, patch_size, w_p, patch_size] -> [B, C, H, W]
        imgs = x.reshape(shape=(x.shape[0], chans, h * p, w * p))
        return imgs

    def fix_position_shuffling(self, x, fix_position_ratio, puzzle_patch_size):
        """
        Fix-position shuffling.

        Per sample, a random subset of puzzle patches is kept at its position (Positional
        Tokens); the remaining patches are shuffled among the batch (Relation Tokens).
        The per-sample split is done by argsort of random noise; the batch-wise shuffle is
        done by permuting the indices of all patches to be shuffled.

        input:
        x: [B, C, H, W], input image tensor
        fix_position_ratio: float, fraction of puzzle patches kept in place
        puzzle_patch_size: int, edge of a puzzle patch

        output: x_puzzled, mask
        x_puzzled: [B, C, H, W]
        mask: [B, C, H, W], binary mask, 0 at kept pixels and 1 at shuffled pixels
        """
        # Break img into puzzle patches: [B, num_puzzle_patches, D]
        x = self.patchify(x, puzzle_patch_size)
        B, num_puzzle_patches, D = x.shape

        # number of fixed / shuffled puzzle patches
        len_fix_position = int(num_puzzle_patches * fix_position_ratio)
        num_shuffled_patches = num_puzzle_patches - len_fix_position

        # noise in [0, 1) used to draw a random order of the puzzle patches for each sample
        noise = torch.rand(B, num_puzzle_patches, device=x.device)
        # per sample: indices that sort the noise ascending (different for every sample)
        ids_shuffle = torch.argsort(noise, dim=1)
        # argsort again gives, for every original position, its rank in the shuffled order
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # the first part is kept fixed, the rest becomes the puzzle
        ids_fix = ids_shuffle[:, :len_fix_position]  # [B, fix_patches]
        ids_puzzle = ids_shuffle[:, len_fix_position:]  # [B, puzzle_patches]

        # [B, ?_patches] -> [B, ?_patches, 1] -> [B, ?_patches, D] as gather index
        # patches to keep (different for each sample in the batch)
        x_fixed = torch.gather(x, dim=1, index=ids_fix.unsqueeze(-1).repeat(1, 1, D))
        # patches to shuffle (different for each sample in the batch)
        x_puzzle = torch.gather(x, dim=1, index=ids_puzzle.unsqueeze(-1).repeat(1, 1, D))

        # batch & patch-wise shuffle is needed, otherwise the restore step would undo the puzzle
        if self.group_shuffle_size == -1 or self.group_shuffle_size == B:
            # a random permutation of 0 .. B * num_shuffled_patches - 1 over the whole batch
            puzzle_shuffle_indices = torch.randperm(B * num_shuffled_patches, device=x.device,
                                                    requires_grad=False)
        else:
            assert B > self.group_shuffle_size > 0 and B % self.group_shuffle_size == 0
            num_groups = B // self.group_shuffle_size
            group_len = num_shuffled_patches * self.group_shuffle_size

            # noise in [0, 1) of shape [num_groups, group_len]
            group_noise = torch.rand(num_groups, group_len, device=x.device)
            # shuffled index inside each group
            group_ids_shuffle = torch.argsort(group_noise, dim=1)
            # shift every group's local indices by the group's start offset
            group_offsets = torch.arange(num_groups, device=x.device).unsqueeze(1) * group_len
            # flatten to be the index over all (B * num_shuffled_patches) patches
            puzzle_shuffle_indices = (group_ids_shuffle + group_offsets).reshape(-1)

        # shuffle all to-be-shuffled patches, then restore the per-sample layout
        x_puzzle = x_puzzle.view(B * num_shuffled_patches, D)[puzzle_shuffle_indices].view(
            B, num_shuffled_patches, D)

        # pack up all puzzle patches: fixed ones first, shuffled ones after
        x = torch.cat([x_fixed, x_puzzle], dim=1)

        # generate the binary mask: 0 is keep, 1 is shuffled
        mask = torch.ones([B, num_puzzle_patches, D], device=x.device, requires_grad=False)  # no grad
        mask[:, :len_fix_position, :] = 0  # the first len_fix_position tokens are the fixed ones

        # un-shuffle so the fixed patches return to their original positions
        restore_index = ids_restore.unsqueeze(-1).repeat(1, 1, D)
        x = torch.gather(x, dim=1, index=restore_index)
        mask = torch.gather(mask, dim=1, index=restore_index)

        # unpatchify to obtain the puzzle images and their mask
        x = self.unpatchify(x, puzzle_patch_size)
        mask = self.unpatchify(mask, puzzle_patch_size)

        return x, mask  # x_puzzled and mask

    def forward_puzzle(self, imgs, fix_position_ratio=0.25, puzzle_patch_size=32):
        """
        Transform the input images to puzzle images.

        input:
        imgs: [B, C, H, W], input image tensor
        fix_position_ratio: float
        puzzle_patch_size: int

        output: x_puzzled, mask
        x_puzzled: [B, C, H, W]
        mask: [B, C, H, W], binary mask, 0 at kept pixels and 1 at shuffled pixels
        """
        x_puzzled, mask = self.fix_position_shuffling(imgs, fix_position_ratio, puzzle_patch_size)
        return x_puzzled, mask

    def forward_encoder(self, imgs):
        """
        Encode the (puzzled) images.

        :param imgs: [B, C, H, W], batch of images

        :return: x, CLS_token, embed_puzzle
        x: [B, num_patches, D], encoded patch tokens (cls token removed)
        CLS_token: [B, 1, D]
        embed_puzzle: [B, num_patches, D], detached patch embedding + pos embed taken
            before the Transformer blocks
        """
        # embed patches and add pos embed before concatenating the cls token
        x = self.patch_embed(imgs)
        x = x + self.pos_embed[:, 1:, :]

        # detached copy of the embedded puzzle (used later as hint tokens / for illustration)
        embed_puzzle = x.detach()

        # prepend the cls token (with its own pos embed), expanded over the batch
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        if self.prompt_mode is None:  # ViT
            for blk in self.blocks:
                x = blk(x)
        elif self.VPT_type == "Deep":  # VPT with a separate prompt group for every block
            Prompt_Token_num = self.Prompt_Tokens.shape[1]
            for i, blk in enumerate(self.blocks):
                # firstly concatenate this block's Prompt_Tokens
                Prompt_Tokens = self.Prompt_Tokens[i].unsqueeze(0)
                x = torch.cat((x, Prompt_Tokens.expand(x.shape[0], -1, -1)), dim=1)
                num_tokens = x.shape[1]
                # lastly remove them again, so the sequence length stays the same
                x = blk(x)[:, :num_tokens - Prompt_Token_num]
        else:  # self.VPT_type == "Shallow"
            Prompt_Token_num = self.Prompt_Tokens.shape[1]
            # concatenate Prompt_Tokens once
            Prompt_Tokens = self.Prompt_Tokens.expand(x.shape[0], -1, -1)
            x = torch.cat((x, Prompt_Tokens), dim=1)
            num_tokens = x.shape[1]
            # run all blocks in one call, then drop the prompt tokens
            x = self.blocks(x)[:, :num_tokens - Prompt_Token_num]

        # last norm of Transformer
        x = self.norm(x)
        CLS_token = x[:, :1, :]
        x = x[:, 1:, :]

        return x, CLS_token, embed_puzzle

    def forward_decoder(self, x):
        """
        Decoder to reconstruct the puzzle image.
        [B, 1 + num_patches, D_Encoder] -> [B, 1 + num_patches, D_Decoder] -> [B, num_patches, p*p*in_chans]

        :param x: [B, 1 + num_patches, D_Encoder], sequence of Tokens (including the cls token)

        :return: Decoder output: reconstructed tokens
        x: [B, num_patches, patch_size ** 2 * in_chans], sequence of Patch Tokens
        """
        if self.decoder is None:
            # embed tokens: [B, num_encoded_tokens, D_Encoder] -> [B, num_encoded_tokens, D_Decoder]
            x = self.decoder_embed(x)
            # add pos embed
            x = x + self.decoder_pos_embed

            # apply Transformer blocks
            for blk in self.decoder_blocks:
                x = blk(x)
            x = self.decoder_norm(x)

            # Reconstruction projection
            x = self.decoder_pred(x)
            # remove cls token -> [B, N, p*p*in_chans]
            x = x[:, 1:, :]
        else:
            # remove cls token
            x = x[:, 1:, :]
            # embed tokens: [B, num_encoded_tokens, D_Encoder] -> [B, num_encoded_tokens, D_Decoder]
            x = self.decoder_embed(x)
            # unpatchify the tokens into image form [B, C, H, W]
            x = self.unpatchify(x)
            # apply the external decoder module on the encoder output
            x = self.decoder(x)  # one-hot seg decoder [B, CLS, H, W]
            # the decoder output is transformed to [B, N, Dec]
            x = self.patchify_decoder(x)  # TODO: a more meaningful design is needed here
            # Convert the number of channels to match the image for the loss function
            x = self.decoder_pred(x)  # [B, N, Dec] -> [B, N, p*p*in_chans]

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        MSE loss towards the original image, averaged over the shuffled patches.

        Input:
        imgs: [B, C, H, W], original (un-puzzled) image
        pred: [B, num_patches, p*p*C], Decoder reconstructed patches
        mask: [B, num_patches, D], 0 is keep, 1 is puzzled (only [:, :, 0] is used)

        Returns:
        scalar loss; 0 (instead of NaN) when no patch is marked as puzzled
        """
        # target imgs: [B, C, H, W] -> [B, num_patches, p*p*C]
        target = self.patchify(imgs)

        # use mask as a patch indicator [B, num_patches, D] -> [B, num_patches]
        mask = mask[:, :, 0]  # 1 for shuffled patches, 0 for fixed patches

        if self.norm_pix_loss:  # Normalize the target image patches
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        # MSE loss
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [B, num_patches], mean loss over the pixels of each patch

        # Mean loss over the shuffled patches. The denominator is clamped so that a mask
        # without any shuffled patch gives 0 rather than 0/0 = NaN; a binary mask with at
        # least one shuffled patch sums to >= 1 and is therefore unaffected.
        num_shuffled = mask.sum().clamp(min=torch.finfo(loss.dtype).eps)
        loss = (loss * mask).sum() / num_shuffled
        return loss

    def forward(self, imgs, fix_position_ratio=0.25, puzzle_patch_size=32, combined_pred_illustration=False):
        """
        Make the puzzle, encode it, decode it and compute the reconstruction loss.

        :param imgs: [B, C, H, W] input images
        :param fix_position_ratio: fraction of puzzle patches kept at their position
        :param puzzle_patch_size: edge of a puzzle patch
        :param combined_pred_illustration: if True, the second return value is the prediction
            with the fixed (hint) patches filled in from the puzzled image

        :return: loss, pred (or pred combined with hints), imgs_puzzled_patches
            the last two are [B, num_patches, p*p*C]
        """
        # STEP 1: Puzzle making
        # create puzzle images: [B, C, H, W]
        imgs_puzzled, mask = self.forward_puzzle(imgs, fix_position_ratio, puzzle_patch_size)

        # imgs_puzzled_patches sequence for visualization: [B, num_patches, p*p*C]
        # (patch size here follows the encoder embedding)
        imgs_puzzled_patches = self.patchify(imgs_puzzled)

        # STEP 2: Puzzle understanding
        # Encoder to obtain latent tokens and embed_puzzle: [B, num_patches, D]
        # (VPT prompt tokens are concatenated and removed inside the encoder)
        latent_puzzle, CLS_token, embed_puzzle = self.forward_encoder(imgs_puzzled)

        # STEP 3: Puzzle restoring
        # step 3.(a) prepare the indicator mask at the encoder output stage:
        mask_patches_pp3 = self.patchify(mask)  # relation tokens marked with 1 [B, num_patches, p*p*C]

        # Reassign the mask indicator shape to the encoder output dim
        if mask_patches_pp3.shape[-1] != latent_puzzle.shape[-1]:
            # [B, num_patches, p*p*C] -> [B, num_patches, 1] -> [B, num_patches, D]
            mask_patches = mask_patches_pp3[:, :, :1].expand(-1, -1, latent_puzzle.shape[-1])
        else:
            mask_patches = mask_patches_pp3

        # anti_mask: [B, num_patches, D], fixed positions marked with 1 instead of 0
        anti_mask = 1 - mask_patches

        # Position hint
        # in mask, 0 is Position Tokens, therefore this keeps only Relation Tokens
        latent_tokens = latent_puzzle * mask_patches
        # in anti_mask, 0 is Relation Tokens, therefore this keeps only Position Tokens
        hint_tokens = embed_puzzle * anti_mask
        # group decoder tokens: [B, num_patches, D]
        latent = latent_tokens + hint_tokens
        # append back the cls token at the first position -> [B, 1+num_patches, D]
        x = torch.cat([CLS_token, latent], dim=1)

        # step 3.(b) Decoder to obtain the reconstructed image patches:
        # [B, 1+num_patches, D] -> [B, 1+num_patches, D_Decoder] -> [B, num_patches, p*p*C]
        pred = self.forward_decoder(x)

        # MSE loss on the shuffled patches towards the original image
        loss = self.forward_loss(imgs, pred, mask_patches)

        if combined_pred_illustration:
            # combined pred: fixed patches from the puzzled image, shuffled patches from pred
            anti_mask_patches_pp3 = 1 - mask_patches_pp3  # fixed positions 1, relation patches 0
            hint_img_patches = imgs_puzzled_patches * anti_mask_patches_pp3
            pred_img_patches = pred * mask_patches_pp3
            pred_with_hint_imgs = hint_img_patches + pred_img_patches
            return loss, pred_with_hint_imgs, imgs_puzzled_patches

        return loss, pred, imgs_puzzled_patches
def sae_vit_base_patch16_dec512d8b(dec_idx=None, **kwargs):
    """
    Build the base SAE: patch 16, 768-dim / 12-block / 12-head encoder with the built-in
    Transformer decoder (512 dim, 8 blocks, 16 heads).

    dec_idx is only printed; every other keyword is forwarded to ShuffledAutoEncoderViT.
    """
    print("Decoder:", dec_idx)
    arch = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return ShuffledAutoEncoderViT(**arch, **kwargs)
def sae_vit_large_patch16_dec512d8b(dec_idx=None, **kwargs):
    """
    Build the large SAE: patch 16, 1024-dim / 24-block / 16-head encoder with the built-in
    Transformer decoder (512 dim, 8 blocks, 16 heads).

    dec_idx is only printed; every other keyword is forwarded to ShuffledAutoEncoderViT.
    """
    print("Decoder:", dec_idx)
    arch = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return ShuffledAutoEncoderViT(**arch, **kwargs)
def sae_vit_huge_patch14_dec512d8b(dec_idx=None, **kwargs):
    """
    Build the huge SAE: patch 14, 1280-dim / 32-block / 16-head encoder with the built-in
    Transformer decoder (512 dim, 8 blocks, 16 heads).

    dec_idx is only printed; every other keyword is forwarded to ShuffledAutoEncoderViT.
    """
    print("Decoder:", dec_idx)
    arch = dict(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return ShuffledAutoEncoderViT(**arch, **kwargs)
# decoder
def sae_vit_base_patch16_dec(dec_idx=None, num_classes=3, img_size=224, **kwargs):
    """
    Build the base SAE (patch 16, 768-dim / 12-block / 12-head encoder) with an external decoder.

    Args:
        dec_idx: 'swin_unet', 'transunet' or 'UTNetV2'.
        num_classes: number of output channels of the external decoder.
        img_size: image edge; used for the decoder AND forwarded to the encoder so that both
            are built for the same resolution.
        **kwargs: forwarded to ShuffledAutoEncoderViT.

    Returns:
        The model, or -1 when dec_idx is not one of the supported names (kept as-is for
        backward compatibility with existing callers).
    """
    # num_classes gives a one-hot segmentation output rather than a reconstruction; how to
    # turn it into a reconstruction for pre-training still needs to be designed.
    if dec_idx == 'swin_unet':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet as ViT_seg
        decoder = ViT_seg(num_classes=num_classes, img_size=img_size, patch_size=16)

    elif dec_idx == 'transunet':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        transunet_name = 'R50-ViT-B_16'
        transunet_patches_size = 16
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS as CONFIGS_Transunet_seg
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import VisionTransformer as Transunet_seg

        # NOTE(review): this mutates the config object stored in CONFIGS - confirm that is intended
        config_vit = CONFIGS_Transunet_seg[transunet_name]
        config_vit.n_classes = num_classes
        config_vit.n_skip = 3

        if transunet_name.find('R50') != -1:
            config_vit.patches.grid = (
                int(img_size / transunet_patches_size), int(img_size / transunet_patches_size))
        decoder = Transunet_seg(config_vit, num_classes=config_vit.n_classes)

    elif dec_idx == 'UTNetV2':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        from SSL_structures.UtnetV2.utnetv2 import UTNetV2 as UTNetV2_seg
        decoder = UTNetV2_seg(in_chan=3, num_classes=num_classes)

    else:
        print('no effective decoder!')
        return -1

    print('dec_idx: ', dec_idx)

    # img_size is a named parameter and therefore never part of **kwargs: forward it
    # explicitly, otherwise the encoder silently falls back to its own default size.
    model = ShuffledAutoEncoderViT(
        img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), decoder_rep_dim=decoder_rep_dim, decoder=decoder,
        **kwargs)
    return model
def sae_vit_large_patch16_dec(dec_idx=None, num_classes=3, img_size=224, **kwargs):
    """
    Build the large SAE (patch 16, 1024-dim / 24-block / 16-head encoder) with an external decoder.

    Args:
        dec_idx: 'swin_unet', 'transunet' or 'UTNetV2'.
        num_classes: number of output channels of the external decoder.
        img_size: image edge; used for the decoder AND forwarded to the encoder so that both
            are built for the same resolution.
        **kwargs: forwarded to ShuffledAutoEncoderViT.

    Returns:
        The model, or -1 when dec_idx is not one of the supported names (kept as-is for
        backward compatibility with existing callers).
    """
    # num_classes gives a one-hot segmentation output rather than a reconstruction; how to
    # turn it into a reconstruction for pre-training still needs to be designed.
    if dec_idx == 'swin_unet':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet as ViT_seg
        decoder = ViT_seg(num_classes=num_classes, img_size=img_size, patch_size=16)

    elif dec_idx == 'transunet':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        transunet_name = 'R50-ViT-B_16'
        transunet_patches_size = 16
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS as CONFIGS_Transunet_seg
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import VisionTransformer as Transunet_seg

        # NOTE(review): this mutates the config object stored in CONFIGS - confirm that is intended
        config_vit = CONFIGS_Transunet_seg[transunet_name]
        config_vit.n_classes = num_classes
        config_vit.n_skip = 3

        if transunet_name.find('R50') != -1:
            config_vit.patches.grid = (
                int(img_size / transunet_patches_size), int(img_size / transunet_patches_size))
        decoder = Transunet_seg(config_vit, num_classes=config_vit.n_classes)

    elif dec_idx == 'UTNetV2':
        decoder_embed_dim = 768
        decoder_rep_dim = 16 * 16 * 3

        from SSL_structures.UtnetV2.utnetv2 import UTNetV2 as UTNetV2_seg
        decoder = UTNetV2_seg(in_chan=3, num_classes=num_classes)

    else:
        print('no effective decoder!')
        return -1

    print('dec_idx: ', dec_idx)

    # img_size is a named parameter and therefore never part of **kwargs: forward it
    # explicitly, otherwise the encoder silently falls back to its own default size.
    model = ShuffledAutoEncoderViT(
        img_size=img_size, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), decoder_rep_dim=decoder_rep_dim, decoder=decoder,
        **kwargs)
    return model
def sae_vit_huge_patch14_dec(dec_idx=None, num_classes=3, img_size=224, **kwargs):
    """
    Build the huge SAE (patch 14, 1280-dim / 32-block / 16-head encoder) with an external decoder.

    Args:
        dec_idx: 'swin_unet', 'transunet' or 'UTNetV2'.
        num_classes: number of output channels of the external decoder.
        img_size: image edge; used for the decoder AND forwarded to the encoder so that both
            are built for the same resolution.
        **kwargs: forwarded to ShuffledAutoEncoderViT.

    Returns:
        The model, or -1 when dec_idx is not one of the supported names (kept as-is for
        backward compatibility with existing callers).
    """
    # num_classes gives a one-hot segmentation output rather than a reconstruction; how to
    # turn it into a reconstruction for pre-training still needs to be designed.
    # NOTE(review): the decoders below are built with their own patch size of 16 while the
    # encoder uses 14 - presumably intentional, confirm.
    if dec_idx == 'swin_unet':
        decoder_embed_dim = 14 * 14 * 3
        decoder_rep_dim = 14 * 14 * 3

        from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet as ViT_seg
        decoder = ViT_seg(num_classes=num_classes, img_size=img_size, patch_size=16)

    elif dec_idx == 'transunet':
        decoder_embed_dim = 14 * 14 * 3
        decoder_rep_dim = 14 * 14 * 3

        transunet_name = 'R50-ViT-B_16'
        transunet_patches_size = 16
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS as CONFIGS_Transunet_seg
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import VisionTransformer as Transunet_seg

        # NOTE(review): this mutates the config object stored in CONFIGS - confirm that is intended
        config_vit = CONFIGS_Transunet_seg[transunet_name]
        config_vit.n_classes = num_classes
        config_vit.n_skip = 3

        if transunet_name.find('R50') != -1:
            config_vit.patches.grid = (
                int(img_size / transunet_patches_size), int(img_size / transunet_patches_size))
        decoder = Transunet_seg(config_vit, num_classes=config_vit.n_classes)

    elif dec_idx == 'UTNetV2':
        decoder_embed_dim = 14 * 14 * 3
        decoder_rep_dim = 14 * 14 * 3

        from SSL_structures.UtnetV2.utnetv2 import UTNetV2 as UTNetV2_seg
        decoder = UTNetV2_seg(in_chan=3, num_classes=num_classes)

    else:
        print('no effective decoder!')
        return -1

    print('dec_idx: ', dec_idx)

    # img_size is a named parameter and therefore never part of **kwargs: forward it
    # explicitly, otherwise the encoder silently falls back to its own default size.
    model = ShuffledAutoEncoderViT(
        img_size=img_size, patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), decoder_rep_dim=decoder_rep_dim, decoder=decoder,
        **kwargs)
    return model
# Public aliases for the recommended architectures (naming follows MAE).
# These variants use the built-in Transformer decoder.
sae_vit_base_patch16 = sae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks
sae_vit_large_patch16 = sae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks
sae_vit_huge_patch14 = sae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks
# Equipped with external decoders (selected through `dec_idx`)
sae_vit_base_patch16_decoder = sae_vit_base_patch16_dec # decoder_embed_dim: 768 (16*16*3), HYF decoders
sae_vit_large_patch16_decoder = sae_vit_large_patch16_dec # decoder_embed_dim: 768 (16*16*3), HYF decoders
sae_vit_huge_patch14_decoder = sae_vit_huge_patch14_dec # decoder_embed_dim: 588 (14*14*3), HYF decoders
if __name__ == '__main__':
    # Demo: run one SAE forward pass on a saved image batch and save the puzzled and
    # reconstructed images for visual inspection.
    import os

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    img_size = 224

    # Alternative input: a random batch instead of the saved tensor
    # num_classes = 3  # set to 3 for 3 channel
    # x = torch.rand(2, 3, img_size, img_size, device=device)

    image_tensor_path = './temp-tensors/color.pt'
    # Tensor.to() is not in-place, so the tensor is placed on `device` at load time and the
    # result is kept (the discarded `x.to(device)` left the input on its original device).
    x = torch.load(image_tensor_path, map_location=device)

    # model = sae_vit_base_patch16(img_size=img_size, decoder=None)
    # model = sae_vit_huge_patch14(img_size=img_size, decoder=None)
    # model = sae_vit_base_patch16_decoder(prompt_mode="Deep", dec_idx='swin_unet', img_size=img_size)
    model = sae_vit_base_patch16(img_size=img_size, decoder=None, group_shuffle_size=2)

    # ViT_Prompt example (kept for reference, needs `timm` and `edge_size`):
    # from pprint import pprint
    # model_names = timm.list_models('*vit*')
    # pprint(model_names)
    #
    # basic_model = timm.create_model('vit_base_patch' + str(16) + '_' + str(edge_size), pretrained=True)
    # basic_state_dict = basic_model.state_dict()
    #
    # model = sae_vit_base_patch16(img_size=384, prompt_mode='Deep', Prompt_Token_num=20, basic_state_dict=basic_state_dict)
    #
    # prompt_state_dict = model.obtain_prompt()
    # VPT = VPT_ViT(img_size=384, VPT_type='Deep', Prompt_Token_num=20, basic_state_dict=basic_state_dict)
    # VPT.load_prompt(prompt_state_dict)
    # VPT.to(device)
    #
    # pred = VPT(x)
    # print(pred)

    model.to(device)

    # combined_pred_illustration = True adds the hint tokens to the pred, False gives the raw pred
    loss, pred, imgs_puzzled_patches = model(x, fix_position_ratio=0.25, puzzle_patch_size=32,
                                             combined_pred_illustration=True)

    # visualise the result
    from utils.visual_usage import *

    output_dir = './temp-figs/'
    os.makedirs(output_dir, exist_ok=True)  # saving fails if the folder does not exist

    # both batches are unpatchified and moved to CPU once, outside the loop
    imgs_puzzled_batch = unpatchify(imgs_puzzled_patches, patch_size=16).detach().cpu()
    recons_img_batch = unpatchify(pred, patch_size=16).detach().cpu()

    for img_idx in range(len(imgs_puzzled_batch)):
        puzzled_img = ToPILImage()(imgs_puzzled_batch[img_idx])
        puzzled_img.save(os.path.join(output_dir, 'puzzled_sample_'+str(img_idx)+'.jpg'))

        recons_img = ToPILImage()(recons_img_batch[img_idx])
        recons_img.save(os.path.join(output_dir, 'recons_sample_'+str(img_idx)+'.jpg'))

    # Optional shape / value checks:
    # print(loss, '\n')
    # print(loss.shape, '\n')
    # print(pred.shape, '\n')
    # print(imgs_puzzled_patches.shape, '\n')