|
""" |
|
SAE Model Script ver: Oct 28th 2023 15:30 |
|
SAE stands for shuffled autoencoder, designed for PuzzleTuning |
|
|
|
# References: |
|
Based on MAE code. |
|
https://github.com/facebookresearch/mae |
|
|
|
""" |
|
|
|
from functools import partial |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from timm.models.vision_transformer import PatchEmbed, Block |
|
|
|
from SSL_structures.pos_embed import get_2d_sincos_pos_embed |
|
|
|
from Backbone.VPT_structure import VPT_ViT |
|
|
|
|
|
class ShuffledAutoEncoderViT(VPT_ViT):
    """
    Shuffled Autoencoder (SAE) with a VisionTransformer backbone.

    A forward pass (see ``forward``) keeps a fraction of the puzzle patches at
    their original position, shuffles the remaining ones across the batch,
    encodes the puzzled images, and decodes a prediction that is compared with
    the patches of the original (un-shuffled) images.

    Args:
        img_size, patch_size, in_chans: forwarded to the patch embedding.
        embed_dim, depth, num_heads, mlp_ratio, norm_layer: encoder settings.
        decoder_embed_dim, decoder_depth, decoder_num_heads: settings of the
            built-in transformer decoder (only built when ``decoder`` is None).
        norm_pix_loss: normalise every target patch before computing the MSE.
        group_shuffle_size: number of samples whose relation patches are mixed
            together; ``-1`` (default) or the batch size mixes the whole batch.
        prompt_mode: "Deep" / "Shallow", by default None. It is forwarded to the
            parent class as ``VPT_type``; ``None`` builds a plain ViT encoder.
        Prompt_Token_num: forwarded to the parent class in prompt mode.
        basic_state_dict: optional state dict, loaded non-strictly at the end of
            construction.
        decoder: optional external decoder module operating on image-shaped
            tensors; when given, ``decoder_rep_dim`` is required.
        decoder_rep_dim: size of the last dimension produced by
            ``patchify_decoder`` on the external decoder's output.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, group_shuffle_size=-1,
                 prompt_mode=None, Prompt_Token_num=20, basic_state_dict=None, decoder=None, decoder_rep_dim=None):

        # Fail early with a readable message instead of a TypeError raised from
        # inside nn.Linear further down.
        if decoder is not None and decoder_rep_dim is None:
            raise ValueError("decoder_rep_dim must be given when an external decoder is used")

        if prompt_mode is None:
            # Plain ViT encoder: the parent is built with its defaults and the
            # encoder modules are then replaced by the ones configured here.
            super().__init__()

            self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
            num_patches = self.patch_embed.num_patches

            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            # Fixed (non-trainable) positional embedding, filled in initialize_weights().
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)

            self.blocks = nn.ModuleList([
                Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
                for _ in range(depth)])
            self.norm = norm_layer(embed_dim)

            self.prompt_mode = prompt_mode

        else:
            # Prompt-tuning encoder: the parent class builds the backbone. The
            # state dict is deliberately withheld here and loaded at the end.
            super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
                             embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                             norm_layer=norm_layer, Prompt_Token_num=Prompt_Token_num, VPT_type=prompt_mode,
                             basic_state_dict=None)
            num_patches = self.patch_embed.num_patches

            # Fixed (non-trainable) positional embedding, filled in initialize_weights().
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)

            self.prompt_mode = prompt_mode

            # Defined by the parent class; presumably freezes everything except
            # the prompt tokens -- confirm in VPT_ViT.
            self.Freeze()

        # Project encoder tokens to the decoder width only when the widths differ.
        if embed_dim != decoder_embed_dim:
            self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        else:
            self.decoder_embed = nn.Identity()

        if decoder is not None:
            # External decoder working on image-shaped tensors.
            self.decoder = decoder
            self.decoder_pred = nn.Linear(decoder_rep_dim, patch_size ** 2 * in_chans, bias=True)
        else:
            # Built-in transformer decoder.
            self.decoder = None

            self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
                                                  requires_grad=False)
            self.decoder_blocks = nn.ModuleList([
                Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
                for _ in range(decoder_depth)])
            self.decoder_norm = norm_layer(decoder_embed_dim)

            self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)

        self.group_shuffle_size = group_shuffle_size
        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

        # Loaded after initialisation so the given weights are not overwritten;
        # strict=False tolerates missing / unexpected keys.
        if basic_state_dict is not None:
            self.load_state_dict(basic_state_dict, False)

    def initialize_weights(self):
        """Fill the fixed positional embeddings and initialise the trainable weights."""
        grid_size = int(self.patch_embed.num_patches ** .5)

        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size, cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # The decoder positional embedding only exists for the built-in decoder.
        if self.decoder is None:
            decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], grid_size,
                                                        cls_token=True)
            self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # Initialise the patch projection like a linear layer by viewing its
        # weight as a 2-D (out_channels, -1) matrix.
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.cls_token, std=.02)

        # Applies to every sub-module, including an external decoder if present.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform for nn.Linear weights, zero biases, identity nn.LayerNorm."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs, patch_size=None):
        """
        Break images into flattened patch tokens.

        input:
        imgs: (B, C, H, W) with H == W and H divisible by the patch size
        patch_size: int, defaults to the patch size of ``self.patch_embed``

        output:
        x: (B, num_patches, patch_size**2 * C) AKA [B, num_patches, flatten_dim]
        """
        patch_size = self.patch_embed.patch_size[0] if patch_size is None else patch_size

        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % patch_size == 0, \
            "images must be square and divisible by the patch size"

        batch, channels = imgs.shape[0], imgs.shape[1]
        h = w = imgs.shape[2] // patch_size

        # (B, C, H, W) -> (B, C, h, p, w, p) -> (B, h, w, p, p, C) -> (B, h*w, p*p*C)
        x = imgs.reshape(shape=(batch, channels, h, patch_size, w, patch_size))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(batch, h * w, patch_size ** 2 * channels))
        return x

    def patchify_decoder(self, imgs, patch_size=None):
        """
        Break the image-shaped output of an external decoder into patch tokens.

        FIXME: the default patch size would more sensibly follow the decoder's
        own configuration rather than the encoder's patch embedding.

        input:
        imgs: (B, CLS, H, W) with H == W and H divisible by the patch size
        patch_size: int, defaults to the patch size of ``self.patch_embed``

        output:
        x: (B, num_patches, -1) AKA [B, num_patches, -1]
        """
        patch_size = self.patch_embed.patch_size[0] if patch_size is None else patch_size

        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % patch_size == 0, \
            "images must be square and divisible by the patch size"

        h = w = imgs.shape[2] // patch_size

        # Same layout change as patchify(); the channel count is inferred.
        x = imgs.reshape(shape=(imgs.shape[0], -1, h, patch_size, w, patch_size))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, -1))
        return x

    def unpatchify(self, x, patch_size=None):
        """
        Rebuild images from flattened patch tokens (inverse of ``patchify``).

        input:
        x: (B, num_patches, patch_size**2 * C) AKA [B, num_patches, flatten_dim];
           num_patches must be a perfect square
        patch_size: int, defaults to the patch size of ``self.patch_embed``

        output:
        imgs: (B, C, H, W)
        """
        p = self.patch_embed.patch_size[0] if patch_size is None else patch_size

        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1], "the number of patches must be a perfect square"

        # Channel count is derived from the token width instead of assuming RGB.
        channels = x.shape[2] // (p * p)

        # (B, h*w, p*p*C) -> (B, h, w, p, p, C) -> (B, C, h, p, w, p) -> (B, C, H, W)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, channels))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], channels, h * p, h * p))
        return imgs

    def fix_position_shuffling(self, x, fix_position_ratio, puzzle_patch_size):
        """
        Fix-position shuffling.

        Patches are randomly assigned per sample (argsort of random noise) to a
        fixed set and a puzzle set. The fixed patches stay where they are
        (Positional Tokens); the puzzle patches are shuffled across the batch,
        or within groups of ``self.group_shuffle_size`` samples (Relation Tokens).

        input:
        x: [B, C, H, W], input image tensor
        fix_position_ratio: float, fraction of patches that keep their content
        puzzle_patch_size: int, edge length of a puzzle patch

        output: x_puzzled, mask
        x_puzzled: [B, C, H, W]
        mask: [B, C, H, W], binary mask, 0 marks pixels of fixed patches and
              1 marks pixels of shuffled patches
        """
        x = self.patchify(x, puzzle_patch_size)
        B, num_puzzle_patches, D = x.shape

        len_fix_position = int(num_puzzle_patches * fix_position_ratio)
        num_shuffled_patches = num_puzzle_patches - len_fix_position

        # Per-sample random permutation of the patch positions and its inverse.
        noise = torch.rand(B, num_puzzle_patches, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # The first len_fix_position entries are kept, the rest are puzzled.
        ids_fix = ids_shuffle[:, :len_fix_position]
        ids_puzzle = ids_shuffle[:, len_fix_position:]

        x_fixed = torch.gather(x, dim=1, index=ids_fix.unsqueeze(-1).repeat(1, 1, D))
        x_puzzle = torch.gather(x, dim=1, index=ids_puzzle.unsqueeze(-1).repeat(1, 1, D))

        if self.group_shuffle_size == -1 or self.group_shuffle_size == B:
            # Mix the puzzle patches of the whole batch.
            puzzle_shuffle_indices = torch.randperm(B * num_shuffled_patches, device=x.device,
                                                    requires_grad=False)
        else:
            assert B > self.group_shuffle_size > 0 and B % self.group_shuffle_size == 0, \
                "group_shuffle_size must be positive, smaller than and a divisor of the batch size"

            num_groups = B // self.group_shuffle_size
            group_len = num_shuffled_patches * self.group_shuffle_size

            # One independent permutation per group of consecutive samples ...
            group_noise = torch.rand(num_groups, group_len, device=x.device)
            group_ids_shuffle = torch.argsort(group_noise, dim=1)

            # ... shifted by the group's start offset in the flattened batch.
            offsets = torch.arange(num_groups, device=x.device).unsqueeze(1) * group_len
            puzzle_shuffle_indices = (group_ids_shuffle + offsets).reshape(-1)

        x_puzzle = x_puzzle.reshape(B * num_shuffled_patches, D)[puzzle_shuffle_indices]
        x_puzzle = x_puzzle.reshape(B, num_shuffled_patches, D)

        x = torch.cat([x_fixed, x_puzzle], dim=1)

        # 0 for fixed patches, 1 for shuffled ones (same ordering as x above).
        mask = torch.ones([B, num_puzzle_patches, D], device=x.device, requires_grad=False)
        mask[:, :len_fix_position, :] = 0

        # Undo the per-sample permutation so every slot is back at its position.
        restore_index = ids_restore.unsqueeze(-1).repeat(1, 1, D)
        x = torch.gather(x, dim=1, index=restore_index)
        mask = torch.gather(mask, dim=1, index=restore_index)

        x = self.unpatchify(x, puzzle_patch_size)
        mask = self.unpatchify(mask, puzzle_patch_size)

        return x, mask

    def forward_puzzle(self, imgs, fix_position_ratio=0.25, puzzle_patch_size=32):
        """
        Transform the input images to puzzle images.

        input:
        imgs: [B, C, H, W], input image tensor
        fix_position_ratio: float
        puzzle_patch_size: int

        output: x_puzzled, mask
        x_puzzled: [B, C, H, W]
        mask: [B, C, H, W], binary mask, 0 marks pixels of fixed patches
        """
        x_puzzled, mask = self.fix_position_shuffling(imgs, fix_position_ratio, puzzle_patch_size)
        return x_puzzled, mask

    def forward_encoder(self, imgs):
        """
        :param imgs: [B, C, H, W], sequence of imgs

        :return: x, CLS_token, embed_puzzle
            x: [B, num_patches, D], encoded patch tokens (cls token removed)
            CLS_token: [B, 1, D], encoded cls token
            embed_puzzle: [B, num_patches, D], patch embeddings plus positional
                embedding taken before the transformer blocks, detached from
                the autograd graph
        """
        # Patch embedding plus positional embedding (without the cls slot).
        x = self.patch_embed(imgs)
        x = x + self.pos_embed[:, 1:, :]

        # Gradient-free copy of the embedded puzzle patches (used as hint tokens).
        embed_puzzle = x.detach()

        # Prepend the cls token with its own positional embedding.
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        if self.prompt_mode is None:
            for blk in self.blocks:
                x = blk(x)

        elif self.VPT_type == "Deep":
            Prompt_Token_num = self.Prompt_Tokens.shape[1]

            for i in range(len(self.blocks)):
                # Append this layer's prompt tokens, run the block, then drop them again.
                Prompt_Tokens = self.Prompt_Tokens[i].unsqueeze(0)
                x = torch.cat((x, Prompt_Tokens.expand(x.shape[0], -1, -1)), dim=1)
                num_tokens = x.shape[1]
                x = self.blocks[i](x)[:, :num_tokens - Prompt_Token_num]

        else:
            # Any other prompt mode: one set of prompt tokens is appended once
            # and removed after all blocks.
            Prompt_Token_num = self.Prompt_Tokens.shape[1]

            Prompt_Tokens = self.Prompt_Tokens.expand(x.shape[0], -1, -1)
            x = torch.cat((x, Prompt_Tokens), dim=1)
            num_tokens = x.shape[1]

            x = self.blocks(x)[:, :num_tokens - Prompt_Token_num]

        x = self.norm(x)

        CLS_token = x[:, :1, :]
        x = x[:, 1:, :]

        return x, CLS_token, embed_puzzle

    def forward_decoder(self, x):
        """
        Decoder to reconstruct the image patches.
        [B, 1 + num_patches, D_Encoder] -> [B, 1 + num_patches, D_Decoder] -> [B, num_patches, p*p*in_chans]

        :param x: [B, 1 + num_patches, D_Encoder], sequence of Tokens (including the cls token)

        :return: Decoder output: reconstructed tokens
            x: [B, num_patches, patch_size ** 2 * in_chans], sequence of Patch Tokens
        """
        if self.decoder is None:
            # Built-in transformer decoder.
            x = self.decoder_embed(x)
            x = x + self.decoder_pos_embed

            for blk in self.decoder_blocks:
                x = blk(x)
            x = self.decoder_norm(x)

            x = self.decoder_pred(x)

            # Drop the cls token.
            x = x[:, 1:, :]

        else:
            # External decoder: drop the cls token, fold the tokens back into an
            # image-shaped tensor, decode, and split into patch tokens again.
            x = x[:, 1:, :]
            x = self.decoder_embed(x)
            x = self.unpatchify(x)
            x = self.decoder(x)
            x = self.patchify_decoder(x)
            x = self.decoder_pred(x)

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        MSE loss between the prediction and the original image, averaged over
        the shuffled patches only.

        Input:
        imgs: [B, C, H, W], original (un-shuffled) image
        pred: [B, num_patches, p*p*C], Decoder reconstructed image
        mask: [B, num_patches, D], 0 is keep, 1 is puzzled; only index 0 of the
              last dimension is read

        Returns a scalar tensor. When no patch is marked as puzzled the loss is
        0 instead of NaN.
        """
        target = self.patchify(imgs)

        # One mask value per patch.
        mask = mask[:, :, 0]

        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [B, num_patches], mean loss per patch

        # Guard the denominator: with an all-zero mask the plain division is 0/0.
        loss = (loss * mask).sum() / mask.sum().clamp(min=1e-8)

        return loss

    def forward(self, imgs, fix_position_ratio=0.25, puzzle_patch_size=32, combined_pred_illustration=False):
        """
        Run the full puzzle -> encode -> decode -> loss pipeline.

        :param imgs: [B, C, H, W] input images
        :param fix_position_ratio: fraction of puzzle patches kept in place
        :param puzzle_patch_size: edge length of a puzzle patch
        :param combined_pred_illustration: when True, the returned prediction has
            its fixed patches replaced by the corresponding puzzled-input patches

        :return: loss, pred, imgs_puzzled_patches
            loss: scalar reconstruction loss on the shuffled patches
            pred: [B, num_patches, p*p*C] prediction (or prediction merged with
                  the hint patches, see above)
            imgs_puzzled_patches: [B, num_patches, p*p*C] patchified puzzled input
        """
        imgs_puzzled, mask = self.forward_puzzle(imgs, fix_position_ratio, puzzle_patch_size)

        # Puzzled images as patch tokens at the encoder patch size (for illustration).
        imgs_puzzled_patches = self.patchify(imgs_puzzled)

        latent_puzzle, CLS_token, embed_puzzle = self.forward_encoder(imgs_puzzled)

        # Pixel mask as patch tokens at the encoder patch size.
        mask_patches_pp3 = self.patchify(mask)

        # Match the mask width to the latent width when they differ.
        if mask_patches_pp3.shape[-1] != latent_puzzle.shape[-1]:
            mask_patches = mask_patches_pp3[:, :, :1].expand(-1, -1, latent_puzzle.shape[-1])
        else:
            mask_patches = mask_patches_pp3

        anti_mask = 1 - mask_patches

        # Shuffled positions carry the encoder output; fixed positions carry the
        # detached input embedding as a hint.
        latent_tokens = latent_puzzle * mask_patches
        hint_tokens = embed_puzzle * anti_mask
        latent = latent_tokens + hint_tokens

        x = torch.cat([CLS_token, latent], dim=1)

        pred = self.forward_decoder(x)

        loss = self.forward_loss(imgs, pred, mask_patches)

        if combined_pred_illustration:
            # Fixed patches are copied from the puzzled input, shuffled patches
            # come from the prediction.
            hint_img_patches = imgs_puzzled_patches * (1 - mask_patches_pp3)
            pred_img_patches = pred * mask_patches_pp3
            pred_with_hint_imgs = hint_img_patches + pred_img_patches
            return loss, pred_with_hint_imgs, imgs_puzzled_patches

        return loss, pred, imgs_puzzled_patches
|
|
|
|
|
def sae_vit_base_patch16_dec512d8b(dec_idx=None, **kwargs):
    """
    Build the patch-16 SAE with a 768-d, 12-block, 12-head encoder and the
    built-in transformer decoder (512-d, 8 blocks, 16 heads).

    ``dec_idx`` is only echoed to stdout; all other keyword arguments are
    forwarded to ``ShuffledAutoEncoderViT``.
    """
    print("Decoder:", dec_idx)

    encoder_cfg = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
    decoder_cfg = dict(decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16)
    layer_norm = partial(nn.LayerNorm, eps=1e-6)

    return ShuffledAutoEncoderViT(**encoder_cfg, **decoder_cfg,
                                  mlp_ratio=4, norm_layer=layer_norm, **kwargs)
|
|
|
|
|
def sae_vit_large_patch16_dec512d8b(dec_idx=None, **kwargs):
    """
    Build the patch-16 SAE with a 1024-d, 24-block, 16-head encoder and the
    built-in transformer decoder (512-d, 8 blocks, 16 heads).

    ``dec_idx`` is only echoed to stdout; all other keyword arguments are
    forwarded to ``ShuffledAutoEncoderViT``.
    """
    print("Decoder:", dec_idx)

    encoder_cfg = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
    decoder_cfg = dict(decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16)
    layer_norm = partial(nn.LayerNorm, eps=1e-6)

    return ShuffledAutoEncoderViT(**encoder_cfg, **decoder_cfg,
                                  mlp_ratio=4, norm_layer=layer_norm, **kwargs)
|
|
|
|
|
def sae_vit_huge_patch14_dec512d8b(dec_idx=None, **kwargs):
    """
    Build the patch-14 SAE with a 1280-d, 32-block, 16-head encoder and the
    built-in transformer decoder (512-d, 8 blocks, 16 heads).

    ``dec_idx`` is only echoed to stdout; all other keyword arguments are
    forwarded to ``ShuffledAutoEncoderViT``.
    """
    print("Decoder:", dec_idx)

    encoder_cfg = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16)
    decoder_cfg = dict(decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16)
    layer_norm = partial(nn.LayerNorm, eps=1e-6)

    return ShuffledAutoEncoderViT(**encoder_cfg, **decoder_cfg,
                                  mlp_ratio=4, norm_layer=layer_norm, **kwargs)
|
|
|
|
|
|
|
def sae_vit_base_patch16_dec(dec_idx=None, num_classes=3, img_size=224, **kwargs):
    """
    Build the patch-16 SAE (768-d, 12-block, 12-head encoder) with an external
    decoder selected by ``dec_idx``.

    :param dec_idx: 'swin_unet', 'transunet' or 'UTNetV2'
    :param num_classes: number of output channels requested from the decoder
    :param img_size: image edge length handed to the decoder configuration
    :param kwargs: forwarded to ``ShuffledAutoEncoderViT``

    :return: the model, or -1 (after printing a message) for any other ``dec_idx``
    """
    if dec_idx not in ('swin_unet', 'transunet', 'UTNetV2'):
        print('no effective decoder!')
        return -1

    # All supported decoders share the same token widths here.
    decoder_embed_dim = 768
    decoder_rep_dim = 16 * 16 * 3

    # The decoder packages are imported lazily so only the selected one is needed.
    if dec_idx == 'swin_unet':
        from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet

        decoder = SwinUnet(num_classes=num_classes, img_size=img_size, patch_size=16)

    elif dec_idx == 'transunet':
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS, VisionTransformer

        backbone_name = 'R50-ViT-B_16'
        backbone_patch = 16

        config_vit = CONFIGS[backbone_name]
        config_vit.n_classes = num_classes
        config_vit.n_skip = 3

        if 'R50' in backbone_name:
            grid_side = int(img_size / backbone_patch)
            config_vit.patches.grid = (grid_side, int(img_size / backbone_patch))

        decoder = VisionTransformer(config_vit, num_classes=config_vit.n_classes)

    else:  # 'UTNetV2'
        from SSL_structures.UtnetV2.utnetv2 import UTNetV2

        decoder = UTNetV2(in_chan=3, num_classes=num_classes)

    print('dec_idx: ', dec_idx)

    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return ShuffledAutoEncoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=layer_norm, decoder_rep_dim=decoder_rep_dim, decoder=decoder,
        **kwargs)
|
|
|
|
|
def sae_vit_large_patch16_dec(dec_idx=None, num_classes=3, img_size=224, **kwargs):
    """
    Build the patch-16 SAE (1024-d, 24-block, 16-head encoder) with an external
    decoder selected by ``dec_idx``.

    :param dec_idx: 'swin_unet', 'transunet' or 'UTNetV2'
    :param num_classes: number of output channels requested from the decoder
    :param img_size: image edge length handed to the decoder configuration
    :param kwargs: forwarded to ``ShuffledAutoEncoderViT``

    :return: the model, or -1 (after printing a message) for any other ``dec_idx``
    """
    if dec_idx not in ('swin_unet', 'transunet', 'UTNetV2'):
        print('no effective decoder!')
        return -1

    # All supported decoders share the same token widths here.
    decoder_embed_dim = 768
    decoder_rep_dim = 16 * 16 * 3

    # The decoder packages are imported lazily so only the selected one is needed.
    if dec_idx == 'swin_unet':
        from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet

        decoder = SwinUnet(num_classes=num_classes, img_size=img_size, patch_size=16)

    elif dec_idx == 'transunet':
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS, VisionTransformer

        backbone_name = 'R50-ViT-B_16'
        backbone_patch = 16

        config_vit = CONFIGS[backbone_name]
        config_vit.n_classes = num_classes
        config_vit.n_skip = 3

        if 'R50' in backbone_name:
            grid_side = int(img_size / backbone_patch)
            config_vit.patches.grid = (grid_side, int(img_size / backbone_patch))

        decoder = VisionTransformer(config_vit, num_classes=config_vit.n_classes)

    else:  # 'UTNetV2'
        from SSL_structures.UtnetV2.utnetv2 import UTNetV2

        decoder = UTNetV2(in_chan=3, num_classes=num_classes)

    print('dec_idx: ', dec_idx)

    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return ShuffledAutoEncoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=layer_norm, decoder_rep_dim=decoder_rep_dim, decoder=decoder,
        **kwargs)
|
|
|
|
|
def sae_vit_huge_patch14_dec(dec_idx=None, num_classes=3, img_size=224, **kwargs):
    """
    Build the patch-14 SAE (1280-d, 32-block, 16-head encoder) with an external
    decoder selected by ``dec_idx``.

    :param dec_idx: 'swin_unet', 'transunet' or 'UTNetV2'
    :param num_classes: number of output channels requested from the decoder
    :param img_size: image edge length handed to the decoder configuration
    :param kwargs: forwarded to ``ShuffledAutoEncoderViT``

    :return: the model, or -1 (after printing a message) for any other ``dec_idx``
    """
    if dec_idx not in ('swin_unet', 'transunet', 'UTNetV2'):
        print('no effective decoder!')
        return -1

    # All supported decoders share the same token widths here (14x14 patches, 3 channels).
    decoder_embed_dim = 14 * 14 * 3
    decoder_rep_dim = 14 * 14 * 3

    # The decoder packages are imported lazily so only the selected one is needed.
    if dec_idx == 'swin_unet':
        from SSL_structures.Swin_Unet_main.networks.vision_transformer import SwinUnet

        # NOTE(review): the decoder is built with patch_size=16 although the
        # encoder uses 14 -- confirm this is intended.
        decoder = SwinUnet(num_classes=num_classes, img_size=img_size, patch_size=16)

    elif dec_idx == 'transunet':
        from SSL_structures.TransUNet_main.networks.vit_seg_modeling import CONFIGS, VisionTransformer

        backbone_name = 'R50-ViT-B_16'
        backbone_patch = 16

        config_vit = CONFIGS[backbone_name]
        config_vit.n_classes = num_classes
        config_vit.n_skip = 3

        if 'R50' in backbone_name:
            grid_side = int(img_size / backbone_patch)
            config_vit.patches.grid = (grid_side, int(img_size / backbone_patch))

        decoder = VisionTransformer(config_vit, num_classes=config_vit.n_classes)

    else:  # 'UTNetV2'
        from SSL_structures.UtnetV2.utnetv2 import UTNetV2

        decoder = UTNetV2(in_chan=3, num_classes=num_classes)

    print('dec_idx: ', dec_idx)

    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return ShuffledAutoEncoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=decoder_embed_dim, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=layer_norm, decoder_rep_dim=decoder_rep_dim, decoder=decoder,
        **kwargs)
|
|
|
|
|
|
|
# Short public names for the SAE builders that use the built-in transformer
# decoder (512-d, 8 blocks).
sae_vit_base_patch16 = sae_vit_base_patch16_dec512d8b
sae_vit_large_patch16 = sae_vit_large_patch16_dec512d8b
sae_vit_huge_patch14 = sae_vit_huge_patch14_dec512d8b

# Public names for the SAE builders that attach an external decoder chosen by
# their ``dec_idx`` argument.
sae_vit_base_patch16_decoder = sae_vit_base_patch16_dec
sae_vit_large_patch16_decoder = sae_vit_large_patch16_dec
sae_vit_huge_patch14_decoder = sae_vit_huge_patch14_dec
|
|
|
if __name__ == '__main__':
    # Manual smoke test: load a saved image batch, run one SAE forward pass and
    # write the puzzled / reconstructed samples to ./temp-figs/.
    import os

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    img_size = 224

    # Alternative random input:
    #   num_classes = 3  # set to 3 for 3 channel
    #   x = torch.rand(2, 3, img_size, img_size, device=device)

    image_tensor_path = './temp-tensors/color.pt'
    # map_location lets a tensor saved on another device be loaded here.
    x = torch.load(image_tensor_path, map_location=device)
    # Tensor.to() is not in-place: the result must be re-bound, otherwise the
    # input stays on its original device while the model is moved below.
    x = x.to(device)

    model = sae_vit_base_patch16(img_size=img_size, decoder=None, group_shuffle_size=2)

    # Alternative set-up with prompt tuning (ViT_Prompt):
    #   from pprint import pprint
    #   model_names = timm.list_models('*vit*')
    #   pprint(model_names)
    #
    #   basic_model = timm.create_model('vit_base_patch' + str(16) + '_' + str(edge_size), pretrained=True)
    #
    #   basic_state_dict = basic_model.state_dict()
    #
    #   model = sae_vit_base_patch16(img_size=384, prompt_mode='Deep', Prompt_Token_num=20,
    #                                basic_state_dict=basic_state_dict)
    #
    #   prompt_state_dict = model.obtain_prompt()
    #   VPT = VPT_ViT(img_size=384, VPT_type='Deep', Prompt_Token_num=20, basic_state_dict=basic_state_dict)
    #   VPT.load_prompt(prompt_state_dict)
    #   VPT.to(device)
    #   pred = VPT(x)
    #   print(pred)

    model.to(device)

    loss, pred, imgs_puzzled_patches = model(x, fix_position_ratio=0.25, puzzle_patch_size=32,
                                             combined_pred_illustration=True)

    # Supplies unpatchify and ToPILImage used below -- presumably defined in
    # utils.visual_usage; confirm against that module.
    from utils.visual_usage import *

    output_dir = './temp-figs/'
    # Image.save() fails when the target directory is missing.
    os.makedirs(output_dir, exist_ok=True)

    # Both batches are loop-invariant: rebuild and move them to the CPU once.
    imgs_puzzled_batch = unpatchify(imgs_puzzled_patches, patch_size=16).detach().cpu()
    recons_img_batch = unpatchify(pred, patch_size=16).detach().cpu()

    for img_idx in range(len(imgs_puzzled_batch)):
        puzzled_img = ToPILImage()(imgs_puzzled_batch[img_idx])
        puzzled_img.save(os.path.join(output_dir, 'puzzled_sample_'+str(img_idx)+'.jpg'))

        recons_img = ToPILImage()(recons_img_batch[img_idx])
        recons_img.save(os.path.join(output_dir, 'recons_sample_'+str(img_idx)+'.jpg'))

    # Optional debugging output:
    #   print(loss, '\n')
    #   print(loss.shape, '\n')
    #   print(pred.shape, '\n')
    #   print(imgs_puzzled_patches.shape, '\n')