from monai.networks.blocks.patchembedding import PatchEmbeddingBlock |
|
from monai.networks.layers import Conv |
|
from monai.utils import ensure_tuple_rep |
|
|
|
from typing import Sequence, Union |
|
import torch |
|
import torch.nn as nn |
|
|
|
from ..nn.blocks import TransformerBlock |
|
from icecream import ic |
|
ic.disable() |
|
|
|
__all__ = ["ViTAutoEnc"] |
|
|
|
|
|
class ViTAutoEnc(nn.Module): |
|
""" |
|
Vision Transformer (ViT), based on: "Dosovitskiy et al., |
|
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>" |
|
|
|
    Modified to also produce outputs with the same spatial size as the input image.
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels: int, |
|
img_size: Union[Sequence[int], int], |
|
patch_size: Union[Sequence[int], int], |
|
out_channels: int = 1, |
|
deconv_chns: int = 16, |
|
hidden_size: int = 768, |
|
mlp_dim: int = 3072, |
|
num_layers: int = 12, |
|
num_heads: int = 12, |
|
pos_embed: str = "conv", |
|
dropout_rate: float = 0.0, |
|
spatial_dims: int = 3, |
|
) -> None: |
|
""" |
|
Args: |
|
            in_channels: number of input channels.
|
            img_size: spatial size of the input image.

            patch_size: spatial size of each patch.
|
            out_channels: number of output channels.

            deconv_chns: number of channels for the deconvolution layers.

            hidden_size: dimension of hidden layer.
|
mlp_dim: dimension of feedforward layer. |
|
num_layers: number of transformer blocks. |
|
num_heads: number of attention heads. |
|
pos_embed: position embedding layer type. |
|
            dropout_rate: fraction of the input units to drop.
|
spatial_dims: number of spatial dimensions. |
|
|
|
Examples:: |
|
|
|
# for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone |
|
            # It will provide an output of the same size as the input
|
>>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), pos_embed='conv') |
|
|
|
            # for 3-channel input with image size of (128,128,128), output will be the same size as the input
|
>>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), pos_embed='conv') |
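
            # a forward pass then returns a reconstruction with ``out_channels`` channels (1 by default)
            # and the same spatial size as the input, i.e. (2, 1, 128, 128, 128) here
            >>> out = net(torch.randn(2, 3, 128, 128, 128))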
|
|
|
""" |
|
|
|
super().__init__() |
|
|
|
self.patch_size = ensure_tuple_rep(patch_size, spatial_dims) |
|
self.spatial_dims = spatial_dims |
|
self.hidden_size = hidden_size |
|
|
|
self.patch_embedding = PatchEmbeddingBlock( |
|
in_channels=in_channels, |
|
img_size=img_size, |
|
patch_size=patch_size, |
|
hidden_size=hidden_size, |
|
num_heads=num_heads, |
|
pos_embed=pos_embed, |
|
dropout_rate=dropout_rate, |
|
spatial_dims=self.spatial_dims, |
|
) |
|
self.blocks = nn.ModuleList( |
|
[TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)] |
|
) |
|
self.norm = nn.LayerNorm(hidden_size) |
|
|
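        # decoder: two transposed convolutions, each with kernel = stride = 4, up-sample the
        # patch-grid feature map by 16x overall (so the output matches the input resolution
        # for a 16-voxel patch size, as in the examples above)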
|
new_patch_size = [4] * self.spatial_dims |
|
conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims] |
|
|
|
self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=new_patch_size, stride=new_patch_size) |
|
self.conv3d_transpose_1 = conv_trans( |
|
in_channels=deconv_chns, out_channels=out_channels, kernel_size=new_patch_size, stride=new_patch_size |
|
) |
|
|
|
def forward(self, x, return_emb=False, return_hiddens=False): |
|
""" |
|
Args: |
|
x: input tensor must have isotropic spatial dimensions, |
|
                such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.

            return_emb: if True, return the transformer patch embeddings before the deconvolution decoder.

            return_hiddens: if True, also return the hidden states from every transformer block.
|
""" |
|
spatial_size = x.shape[2:] |
|
x = self.patch_embedding(x) |
|
hidden_states_out = [] |
|
for blk in self.blocks: |
|
x = blk(x) |
|
hidden_states_out.append(x) |
|
x = self.norm(x) |
|
x = x.transpose(1, 2) |
|
if return_emb: |
|
return x |
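
        # fold the token sequence (batch, hidden_size, n_patches) back into a spatial feature map of
        # shape (batch, hidden_size, *(spatial_size // patch_size)), then decode it with the two
        # stride-4 transposed convolutions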
|
d = [s // p for s, p in zip(spatial_size, self.patch_size)] |
|
x = torch.reshape(x, [x.shape[0], x.shape[1], *d]) |
|
x = self.conv3d_transpose(x) |
|
x = self.conv3d_transpose_1(x) |
|
if return_hiddens: |
|
return x, hidden_states_out |
|
return x |
|
|
|
def get_last_selfattention(self, x): |
|
""" |
|
Args: |
|
x: input tensor must have isotropic spatial dimensions, |
|
                such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.

        Returns:

            the self-attention weights of the final transformer block.

        """
|
x = self.patch_embedding(x) |
|
ic(x.size()) |
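
        # run every block except the last as usual; the final block is asked for its
        # self-attention map (``return_attention=True``) instead of the transformed tokens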
|
for i, blk in enumerate(self.blocks): |
|
if i < len(self.blocks) - 1: |
|
x = blk(x) |
|
                ic(x.size())
|
else: |
|
return blk(x, return_attention=True) |
|
|
|
def load(self, ckpt_path, map_location='cpu', checkpoint_key='state_dict'): |
|
""" |
|
Args: |
|
ckpt_path: path to the pretrained weights |
|
            map_location: device to load the checkpoint on

            checkpoint_key: key in the checkpoint dict that holds the model weights
|
""" |
|
state_dict = torch.load(ckpt_path, map_location=map_location) |
|
        ic(state_dict.get('epoch'), state_dict.get('train_loss'))
|
if checkpoint_key in state_dict: |
|
            print(f"Taking key '{checkpoint_key}' from the provided checkpoint dict")
|
state_dict = state_dict[checkpoint_key] |
|
|
|
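        # strip ``module.`` (DataParallel/DDP) and ``backbone.`` prefixes so the keys match this module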
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} |
|
|
|
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} |
|
msg = self.load_state_dict(state_dict, strict=False) |
|
        print(f"Pretrained weights found at {ckpt_path} and loaded with msg: {msg}")
|
|
|
|
|
|
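
# Illustrative usage sketch: a quick shape check of the reconstruction, embedding and hidden-state
# outputs. Because of the relative ``TransformerBlock`` import, this only runs when the module is
# executed as part of its package (e.g. via ``python -m``), not as a standalone script.
if __name__ == "__main__":
    net = ViTAutoEnc(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16))
    vol = torch.randn(1, 1, 96, 96, 96)

    out = net(vol)                              # reconstruction, same spatial size as the input
    emb = net(vol, return_emb=True)             # (batch, hidden_size, n_patches) transformer embeddings
    _, hiddens = net(vol, return_hiddens=True)  # hidden states from every transformer block

    print(out.shape, emb.shape, len(hiddens))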