# UniCell / models / unicell_modules.py
# (source-viewer residue from the original paste: author "junma", commit 56afa1a, "add app")
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import collections.abc
import math
import warnings
from functools import partial
from itertools import repeat
from typing import Tuple, Union

import torch
import torch.nn as nn
from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock, UnetrPrUpBlock
from monai.networks.blocks.dynunet_block import get_conv_layer
# From PyTorch internals
def _ntuple(n):
    """Return a converter that expands a scalar into an ``n``-tuple.

    Values that are already iterable (tuple, list, str, ...) are passed through
    untouched; any other value ``v`` becomes ``(v,) * n``.
    """
    def parse(value):
        if not isinstance(value, collections.abc.Iterable):
            return (value,) * n
        return value

    return parse


# Ready-made converters for the common arities.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place with samples from a truncated normal distribution.

    Samples follow N(mean, std**2) restricted to ``[a, b]``. They are generated
    with the inverse-CDF method: draw uniformly between the CDF values of the
    bounds, then map back through the inverse normal CDF. This is a copy of
    PyTorch's ``nn.init.trunc_normal_``. Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: tensor to fill (modified in place).
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower truncation bound.
        b: upper truncation bound.

    Returns:
        The same ``tensor``, for chaining.

    Warns:
        UserWarning: if ``mean`` lies more than 2 std outside ``[a, b]``. In that
        case the sampled distribution may be inaccurate.
    """
    def norm_cdf(x):
        # Standard normal cumulative distribution function.
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        # Was `print(..., stacklevel=2)`: print() has no `stacklevel` keyword, so
        # this branch raised TypeError instead of warning.
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # CDF values of the truncation bounds.
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)
        # Uniform on [2*lower - 1, 2*upper - 1]: the domain of erfinv for the
        # bounds' CDF range.
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)
        # Inverse CDF -> truncated standard normal.
        tensor.erfinv_()
        # Scale/shift to the requested mean and std.
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        # Guard against round-off pushing values outside [a, b].
        tensor.clamp_(min=a, max=b)
        return tensor
#%%
class Mlp(nn.Module):
    """Feed-forward block with a depthwise conv between its two linear layers.

    Pipeline on a ``(B, N, C)`` token sequence laid out on an ``H x W`` grid:
    fc1 -> DWConv -> activation -> dropout -> fc2 -> dropout.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Falsy widths default to the input width.
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features
        # Registration order is kept as-is: it fixes state_dict layout and the
        # order in which _init_weights consumes the RNG.
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.dwconv = DWConv(hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated normal (std 0.02) weights, zero bias.
        # LayerNorm: weight 1, bias 0.
        # Conv2d: normal(0, sqrt(2 / fan_out)) with fan_out counted per group; zero bias.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        hidden = self.act(self.dwconv(self.fc1(x), H, W))
        projected = self.fc2(self.drop(hidden))
        return self.drop(projected)
class Attention(nn.Module):
    """Multi-head self-attention over tokens, with optional spatial reduction of keys/values.

    Queries always come from the full token sequence. When ``sr_ratio > 1``,
    keys and values are instead computed from the token grid after a strided
    ``Conv2d`` (kernel = stride = ``sr_ratio``) and a LayerNorm. Each query
    therefore attends to a coarser set of tokens.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        # A truthy qk_scale overrides the default 1/sqrt(head_dim).
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        # Registration order preserved (state_dict layout / init RNG order).
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated normal (std 0.02) weights, zero bias.
        # LayerNorm: weight 1, bias 0.
        # Conv2d: normal(0, sqrt(2 / fan_out)) with fan_out counted per group; zero bias.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        heads = self.num_heads
        head_dim = C // heads

        # (B, N, C) -> (B, heads, N, head_dim)
        q = self.q(x).reshape(B, N, heads, head_dim).permute(0, 2, 1, 3)

        kv_input = x
        if self.sr_ratio > 1:
            # Tokens -> (B, C, H, W) grid -> strided conv -> tokens again -> LayerNorm.
            grid = x.permute(0, 2, 1).reshape(B, C, H, W)
            kv_input = self.norm(self.sr(grid).reshape(B, C, -1).permute(0, 2, 1))

        # (B, M, 2C) -> (2, B, heads, M, head_dim), split into keys and values.
        k, v = self.kv(kv_input).reshape(B, -1, 2, heads, head_dim).permute(2, 0, 3, 1, 4).unbind(0)

        weights = self.attn_drop(((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1))
        merged = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(merged))
class Block(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm1(x))``, then ``x + mlp(norm2(x))``.

    ``drop_path`` is accepted but has no effect: both residual branches go
    through ``nn.Identity`` because DropPath is not available in this module.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # NOTE: stochastic depth is disabled; the original line was
        # `DropPath(drop_path) if drop_path > 0. else nn.Identity()`.
        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated normal (std 0.02) weights, zero bias.
        # LayerNorm: weight 1, bias 0.
        # Conv2d: normal(0, sqrt(2 / fan_out)) with fan_out counted per group; zero bias.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        attn_branch = self.attn(self.norm1(x), H, W)
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x), H, W)
        return x + self.drop_path(mlp_branch)
#%%
class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Overlapping patch embedding. A ``Conv2d`` with kernel ``patch_size``, the
    given ``stride`` and ``patch_size // 2`` padding maps the image to a feature
    grid. The grid is then flattened into tokens and layer-normalised.
    ``forward`` returns ``(tokens, H, W)``: ``tokens`` has shape
    ``(B, H*W, embed_dim)`` and ``H, W`` is the conv output grid.
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        # Nominal grid size, computed from patch_size rather than stride. forward()
        # does not use it and reports the actual grid from the conv output.
        self.H = self.img_size[0] // self.patch_size[0]
        self.W = self.img_size[1] // self.patch_size[1]
        self.num_patches = self.H * self.W

        pad_h, pad_w = self.patch_size[0] // 2, self.patch_size[1] // 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=stride,
                              padding=(pad_h, pad_w))
        self.norm = nn.LayerNorm(embed_dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated normal (std 0.02) weights, zero bias.
        # LayerNorm: weight 1, bias 0.
        # Conv2d: normal(0, sqrt(2 / fan_out)) with fan_out counted per group; zero bias.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        feature_map = self.proj(x)  # e.g. [2, 3, 224, 224] -> [2, 64, 56, 56]
        _, _, H, W = feature_map.shape
        # [B, C, H, W] -> [B, H*W, C], then LayerNorm over channels.
        tokens = self.norm(feature_map.flatten(2).transpose(1, 2))
        return tokens, H, W
# embed_dims=[64, 128, 256, 512]
# patch_embed1 = OverlapPatchEmbed(img_size=224,patch_size=7,stride=4,in_chans=in_chans, embed_dim=64)
# x1, H, W = patch_embed1(input_img)
# x1 = x1.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
# patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
# embed_dim=embed_dims[1])
# x2, H, W = patch_embed2(x1)
# x2 = x2.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
# patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
# embed_dim=embed_dims[2])
# x3, H, W = patch_embed3(x2)
# x3 = x3.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
# patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],embed_dim=embed_dims[3])
# x4, H, W = patch_embed4(x3)
# x4 = x4.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
#%%
class MixVisionTransformer(nn.Module):
    """Four-stage hierarchical transformer encoder (MiT backbone, no classification head).

    Each stage ``k`` runs four steps:
    - an ``OverlapPatchEmbed`` (stage 1: 7x7 kernel, stride 4; stages 2-4: 3x3
      kernel, stride 2);
    - ``depths[k]`` transformer ``Block``s;
    - a norm layer;
    - a reshape of the tokens back into a ``(B, C_k, H_k, W_k)`` map.

    ``forward`` returns the list of the four stage maps, highest resolution first.

    Args:
        img_size: input size, an int or an ``(H, W)`` pair. It only sets the
            nominal ``H``/``W``/``num_patches`` attributes of the patch embeddings.
        patch_size: accepted for interface compatibility but unused. Stage 1
            always embeds with a 7x7 kernel and stride 4.
        in_chans: channels of the input image.
        embed_dims: channel width of each stage.
        num_heads, mlp_ratios, sr_ratios: per-stage attention heads, MLP
            expansion ratios and key/value spatial-reduction ratios.
        qkv_bias, qk_scale, drop_rate, attn_drop_rate: forwarded to every ``Block``.
        drop_path_rate: largest stochastic-depth rate. Per-block rates rise
            linearly from 0 and are passed to ``Block`` as ``drop_path``.
        norm_layer: factory for the per-block and per-stage norm layers.
        depths: number of ``Block``s per stage.
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dims=(64, 128, 256, 512),
                 num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=(3, 4, 6, 3), sr_ratios=(8, 4, 2, 1)):
        super().__init__()
        self.depths = depths

        # Accept a square int size or an (H, W) pair. The previous code computed
        # `img_size // 4`, which raised TypeError for tuple sizes.
        if isinstance(img_size, collections.abc.Iterable):
            img_h, img_w = img_size
        else:
            img_h = img_w = img_size

        # Patch embeddings: stage 1 downsamples by 4, each later stage by a further 2.
        self.patch_embed1 = OverlapPatchEmbed(img_size=(img_h, img_w), patch_size=7, stride=4,
                                              in_chans=in_chans, embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(img_size=(img_h // 4, img_w // 4), patch_size=3, stride=2,
                                              in_chans=embed_dims[0], embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(img_size=(img_h // 8, img_w // 8), patch_size=3, stride=2,
                                              in_chans=embed_dims[1], embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(img_size=(img_h // 16, img_w // 16), patch_size=3, stride=2,
                                              in_chans=embed_dims[2], embed_dim=embed_dims[3])

        # Stochastic depth decay rule: one linearly increasing rate per block.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        def make_stage(stage, start):
            # The depths[stage] Blocks of one stage; `start` is its first index in dpr.
            return nn.ModuleList([Block(
                dim=embed_dims[stage], num_heads=num_heads[stage], mlp_ratio=mlp_ratios[stage],
                qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[start + i], norm_layer=norm_layer, sr_ratio=sr_ratios[stage])
                for i in range(depths[stage])])

        # Registration order (block_k, then norm_k) is unchanged, so state_dict
        # keys and the order of random initialisation match the previous version.
        cur = 0
        self.block1 = make_stage(0, cur)
        self.norm1 = norm_layer(embed_dims[0])
        cur += depths[0]
        self.block2 = make_stage(1, cur)
        self.norm2 = norm_layer(embed_dims[1])
        cur += depths[1]
        self.block3 = make_stage(2, cur)
        self.norm3 = norm_layer(embed_dims[2])
        cur += depths[2]
        self.block4 = make_stage(3, cur)
        self.norm4 = norm_layer(embed_dims[3])

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear: truncated normal (std 0.02) weights, zero bias.
        # LayerNorm: weight 1, bias 0.
        # Conv2d: normal(0, sqrt(2 / fan_out)) with fan_out counted per group; zero bias.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def init_weights(self, pretrained=None):
        """Load weights non-strictly from the checkpoint file at ``pretrained``.

        Non-string values, such as the default ``None``, leave the weights
        untouched. A checkpoint that nests its weights under a ``'state_dict'``
        key is unwrapped first. Missing and unexpected keys are tolerated
        (``strict=False``).
        """
        if not isinstance(pretrained, str):
            return
        # Previously the loaded object was discarded, so this method had no effect.
        # NOTE: depending on the PyTorch version, torch.load may unpickle arbitrary
        # objects; only load trusted checkpoints.
        checkpoint = torch.load(pretrained, map_location='cpu')
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']
        self.load_state_dict(checkpoint, strict=False)

    def reset_drop_path(self, drop_path_rate):
        """Recompute per-block drop-path rates and store them on each block's ``drop_path``."""
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for blocks, depth in zip((self.block1, self.block2, self.block3, self.block4), self.depths):
            for i in range(depth):
                blocks[i].drop_path.drop_prob = dpr[cur + i]
            cur += depth

    def freeze_patch_emb(self):
        """Stop gradient updates for the stage-1 patch embedding parameters."""
        # Assigning `module.requires_grad = False` only set a plain attribute and
        # froze nothing; requires_grad_ updates every parameter of the module.
        self.patch_embed1.requires_grad_(False)

    @torch.jit.ignore
    def no_weight_decay(self):
        # None of these names exist on this model (no positional embeddings or cls token).
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # has pos_embed may be better

    def get_classifier(self):
        # NOTE(review): no `head` module is ever created (the classification head
        # was removed), so this raises AttributeError as written.
        return self.head

    def forward_features(self, x):
        """Run all four stages and return their ``(B, C_k, H_k, W_k)`` feature maps."""
        B = x.shape[0]
        outs = []
        stages = (
            (self.patch_embed1, self.block1, self.norm1),
            (self.patch_embed2, self.block2, self.norm2),
            (self.patch_embed3, self.block3, self.norm3),
            (self.patch_embed4, self.block4, self.norm4),
        )
        for patch_embed, blocks, norm in stages:
            x, H, W = patch_embed(x)  # tokens: (B, H*W, C)
            for blk in blocks:
                x = blk(x, H, W)
            x = norm(x)
            # Tokens -> channel-first map; the next stage's patch embed consumes it.
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)
        return outs

    def forward(self, x):
        return self.forward_features(x)
class DWConv(nn.Module):
    """Depthwise 3x3 convolution over a token sequence laid out on an ``H x W`` grid.

    Input and output are ``(B, N, C)`` token tensors with ``N == H * W``.
    """

    def __init__(self, dim=768):
        super().__init__()
        # groups=dim -> one 3x3 filter per channel; padding=1 keeps the grid size.
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim)

    def forward(self, x, H, W):
        batch, _, channels = x.shape
        grid = x.transpose(1, 2).view(batch, channels, H, W)  # (B, C, H, W)
        return self.dwconv(grid).flatten(2).transpose(1, 2)  # back to (B, H*W, C)
def _mit_config(embed_dims, depths, overrides):
    """Build ``MixVisionTransformer`` keyword arguments for a MiT-B* preset.

    All presets share the patch size, heads, MLP ratios, qkv bias,
    LayerNorm(eps=1e-6), spatial-reduction ratios and drop rates. Only
    ``embed_dims`` and ``depths`` differ. Entries in ``overrides`` (the caller's
    ``**kwargs``) take precedence over the preset values. They used to be
    silently ignored, so e.g. ``img_size``/``in_chans`` had no effect.
    """
    config = dict(
        patch_size=4, embed_dims=embed_dims, num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=depths, sr_ratios=[8, 4, 2, 1],
        drop_rate=0.0, drop_path_rate=0.1)
    config.update(overrides)
    return config


class mit_b0(MixVisionTransformer):
    """MiT-B0: embed_dims [32, 64, 160, 256], depths [2, 2, 2, 2]; kwargs override the preset."""

    def __init__(self, **kwargs):
        super().__init__(**_mit_config([32, 64, 160, 256], [2, 2, 2, 2], kwargs))


class mit_b1(MixVisionTransformer):
    """MiT-B1: embed_dims [64, 128, 320, 512], depths [2, 2, 2, 2]; kwargs override the preset."""

    def __init__(self, **kwargs):
        super().__init__(**_mit_config([64, 128, 320, 512], [2, 2, 2, 2], kwargs))


class mit_b2(MixVisionTransformer):
    """MiT-B2: embed_dims [64, 128, 320, 512], depths [3, 4, 6, 3]; kwargs override the preset."""

    def __init__(self, **kwargs):
        super().__init__(**_mit_config([64, 128, 320, 512], [3, 4, 6, 3], kwargs))


class mit_b3(MixVisionTransformer):
    """MiT-B3: embed_dims [64, 128, 320, 512], depths [3, 4, 18, 3]; kwargs override the preset."""

    def __init__(self, **kwargs):
        super().__init__(**_mit_config([64, 128, 320, 512], [3, 4, 18, 3], kwargs))


class mit_b4(MixVisionTransformer):
    """MiT-B4: embed_dims [64, 128, 320, 512], depths [3, 8, 27, 3]; kwargs override the preset."""

    def __init__(self, **kwargs):
        super().__init__(**_mit_config([64, 128, 320, 512], [3, 8, 27, 3], kwargs))


class mit_b5(MixVisionTransformer):
    """MiT-B5: embed_dims [64, 128, 320, 512], depths [3, 6, 40, 3]; kwargs override the preset."""

    def __init__(self, **kwargs):
        super().__init__(**_mit_config([64, 128, 320, 512], [3, 6, 40, 3], kwargs))
#%% B2
class MiT_B2_UNet_MultiHead(nn.Module):
    """U-Net style 2-D network with a MixVisionTransformer encoder and two output heads.

    The MiT backbone uses depths [3, 4, 6, 3] (the depths of the ``mit_b2``
    preset). Its stage widths are ``feature_size * (2, 4, 8, 16)``. A
    full-resolution convolutional branch (``encoder1``) provides the skip
    connection for the last decoder stage.

    Outputs (non-debug mode):
        logits: ``out_channels`` maps from ``out_interior``.
        dist: a single-channel map from ``out_dist``.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the ``logits`` head.
        regress_class: accepted but unused.
        img_size: forwarded to the MixVisionTransformer backbone.
        feature_size: base width; the MiT stages use 2x, 4x, 8x and 16x this.
        spatial_dims: used only by the ``UnetrBasicBlock`` encoders. The decoders,
            output heads and the MiT backbone (Conv2d-based) are hard-coded 2-D.
        num_heads: attention heads per MiT stage.
        norm_name: normalization used by the MONAI blocks.
        conv_block: accepted but unused.
        res_block: whether the decoder blocks are residual (encoders always are).
        dropout_rate: accepted but unused.
        debug: if True, ``forward`` returns intermediate tensors (see ``forward``).
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 regress_class: int = 1,
                 img_size: Tuple[int, int] = (256,256),
                 feature_size: int = 16,
                 spatial_dims: int = 2,
                 # hidden_size: int = 768,
                 # mlp_dim: int = 3072,
                 num_heads = [1, 2, 4, 8],
                 # pos_embed: str = "perceptron",
                 norm_name: Union[Tuple, str] = "instance",
                 conv_block: bool = False,
                 res_block: bool = True,
                 dropout_rate: float = 0.0,
                 debug: bool = False
                 ):
        super().__init__()
        self.debug = debug
        # MiT encoder with B2 depths [3, 4, 6, 3]. The attribute is named `mit_b3`,
        # presumably copied from the B3 variant. Renaming it would change the
        # parameter names in state_dict and break existing checkpoints.
        self.mit_b3 = MixVisionTransformer(img_size=img_size, patch_size=4, embed_dims=[feature_size*2, feature_size*4, feature_size*8, feature_size*16],
                                           num_heads=num_heads, mlp_ratios=[4, 4, 4, 4], qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                                           depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)
        # Full-resolution convolutional branch on the raw input; skip for decoder1.
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        # encoder2..encoder5 refine the four MiT stage outputs (stride 1, channels
        # unchanged). They serve as decoder skips / decoder input.
        self.encoder2 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=2 * feature_size,
            out_channels=2 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        self.encoder3 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=4 * feature_size,
            out_channels=4 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        self.encoder4 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=8 * feature_size,
            out_channels=8 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        self.encoder5 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=16 * feature_size,
            out_channels=16 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        # Decoder stages: each halves the channel count, and per the forward()
        # notes does up -> concat skip -> (res)conv.
        self.decoder4 = UnetrUpBlock(
            spatial_dims=2,
            in_channels=feature_size * 16,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=2,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=2,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        # Extra stride-2 transposed conv. MiT stage 1 is already at 1/4 resolution,
        # so dec2 needs one more 2x upsampling before decoder1 can join the
        # full-resolution enc1 skip.
        self.transp_conv = get_conv_layer(
            spatial_dims=2,
            in_channels=feature_size*2,
            out_channels=feature_size*2,
            kernel_size=3,
            stride=2,
            conv_only=True,
            is_transposed=True,
        )
        self.decoder1 = UnetrUpBlock(
            spatial_dims=2,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        # Two heads on the shared decoder output: `out_channels` maps and a 1-channel map.
        self.out_interior = UnetOutBlock(spatial_dims=2, in_channels=feature_size, out_channels=out_channels) # type: ignore
        self.out_dist = UnetOutBlock(spatial_dims=2, in_channels=feature_size, out_channels=1) # type: ignore

    def forward(self, x_in):
        """Return ``(logits, dist)``.

        With ``debug=True``, return instead ``(hidden_states_out, enc1, enc2,
        enc3, enc4, dec4, dec3, dec2, dec1, logits)``; enc5 and dist are not
        included.
        """
        # Shape notes assume a (B, in_channels, 256, 256) input with feature_size=16,
        # and that the MONAI blocks keep spatial size at stride 1 (TODO confirm).
        hidden_states_out = self.mit_b3(x_in) # list of 4 stage maps at 1/4, 1/8, 1/16, 1/32 resolution
        enc1 = self.encoder1(x_in) # (B, 16, 256, 256)
        x1 = hidden_states_out[0] # (B, 32, 64, 64)
        enc2 = self.encoder2(x1) # (B, 32, 64, 64)
        x2 = hidden_states_out[1] # (B, 64, 32, 32)
        enc3 = self.encoder3(x2) # (B, 64, 32, 32)
        x3 = hidden_states_out[2] # (B, 128, 16, 16)
        enc4 = self.encoder4(x3) # (B, 128, 16, 16)
        x4 = hidden_states_out[3] # (B, 256, 8, 8)
        enc5 = self.encoder5(x4) # (B, 256, 8, 8)
        # print(f"{enc1.shape=}, {enc2.shape=}, {enc3.shape=}, {enc4.shape=}, {enc5.shape=}")
        dec4 = self.decoder4(enc5, enc4) # up -> cat -> ResConv; (B, 128, 16, 16)
        dec3 = self.decoder3(dec4, enc3) # (B, 64, 32, 32)
        dec2 = self.decoder2(dec3, enc2) # (B, 32, 64, 64)
        dec2_up = self.transp_conv(dec2) # (B, 32, 128, 128)
        dec1 = self.decoder1(dec2_up, enc1) # (B, 16, 256, 256)
        logits = self.out_interior(dec1)
        dist = self.out_dist(dec1)
        if self.debug:
            return hidden_states_out, enc1, enc2, enc3, enc4, dec4, dec3, dec2, dec1, logits
        else:
            return logits, dist
        # print(f"{dec1.shape=}, {dec2.shape=}, {dec3.shape=}, {dec4.shape=}, {logits.shape=}")
if __name__ == "__main__":
    # Smoke test: one forward pass of the B2 UNet on random input. It is guarded
    # so that importing this module no longer builds a model and runs inference
    # (and no longer leaks these names as module globals).
    img_size = 256
    in_chans = 3
    B = 2
    input_img = torch.randn((B, in_chans, img_size, img_size))
    b2 = MiT_B2_UNet_MultiHead(3, 3, img_size=img_size)
    logits, dist = b2(input_img)
    print(tuple(logits.shape), tuple(dist.shape))
#%% B3
class MiT_B3_UNet_MultiHead(nn.Module):
    """U-Net style 2-D network with a MixVisionTransformer encoder and two output heads.

    The MiT backbone uses depths [3, 4, 18, 3] (the depths of the ``mit_b3``
    preset). Its stage widths are ``feature_size * (2, 4, 8, 16)``. A
    full-resolution convolutional branch (``encoder1``) provides the skip
    connection for the last decoder stage.

    Outputs (non-debug mode):
        logits: ``out_channels`` maps from ``out_interior``.
        dist: a single-channel map from ``out_dist``.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the ``logits`` head.
        regress_class: accepted but unused.
        img_size: forwarded to the MixVisionTransformer backbone.
        feature_size: base width; the MiT stages use 2x, 4x, 8x and 16x this.
        spatial_dims: used only by the ``UnetrBasicBlock`` encoders. The decoders,
            output heads and the MiT backbone (Conv2d-based) are hard-coded 2-D.
        num_heads: attention heads per MiT stage.
        norm_name: normalization used by the MONAI blocks.
        conv_block: accepted but unused.
        res_block: whether the decoder blocks are residual (encoders always are).
        dropout_rate: accepted but unused.
        debug: if True, ``forward`` returns intermediate tensors (see ``forward``).
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 regress_class: int = 1,
                 img_size: Tuple[int, int] = (256,256),
                 feature_size: int = 16,
                 spatial_dims: int = 2,
                 # hidden_size: int = 768,
                 # mlp_dim: int = 3072,
                 num_heads = [1, 2, 4, 8],
                 # pos_embed: str = "perceptron",
                 norm_name: Union[Tuple, str] = "instance",
                 conv_block: bool = False,
                 res_block: bool = True,
                 dropout_rate: float = 0.0,
                 debug: bool = False
                 ):
        super().__init__()
        self.debug = debug
        # MiT encoder with B3 depths [3, 4, 18, 3]. The attribute name is part of
        # the state_dict keys, so checkpoints depend on it.
        self.mit_b3 = MixVisionTransformer(img_size=img_size, patch_size=4, embed_dims=[feature_size*2, feature_size*4, feature_size*8, feature_size*16],
                                           num_heads=num_heads, mlp_ratios=[4, 4, 4, 4], qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
                                           drop_rate=0.0, drop_path_rate=0.1)
        # Full-resolution convolutional branch on the raw input; skip for decoder1.
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        # encoder2..encoder5 refine the four MiT stage outputs (stride 1, channels
        # unchanged). They serve as decoder skips / decoder input.
        self.encoder2 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=2 * feature_size,
            out_channels=2 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        self.encoder3 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=4 * feature_size,
            out_channels=4 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        self.encoder4 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=8 * feature_size,
            out_channels=8 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        self.encoder5 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=16 * feature_size,
            out_channels=16 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        # Decoder stages: each halves the channel count, and per the forward()
        # notes does up -> concat skip -> (res)conv.
        self.decoder4 = UnetrUpBlock(
            spatial_dims=2,
            in_channels=feature_size * 16,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=2,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=2,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        # Extra stride-2 transposed conv. MiT stage 1 is already at 1/4 resolution,
        # so dec2 needs one more 2x upsampling before decoder1 can join the
        # full-resolution enc1 skip.
        self.transp_conv = get_conv_layer(
            spatial_dims=2,
            in_channels=feature_size*2,
            out_channels=feature_size*2,
            kernel_size=3,
            stride=2,
            conv_only=True,
            is_transposed=True,
        )
        self.decoder1 = UnetrUpBlock(
            spatial_dims=2,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        # Two heads on the shared decoder output: `out_channels` maps and a 1-channel map.
        self.out_interior = UnetOutBlock(spatial_dims=2, in_channels=feature_size, out_channels=out_channels) # type: ignore
        self.out_dist = UnetOutBlock(spatial_dims=2, in_channels=feature_size, out_channels=1) # type: ignore

    def forward(self, x_in):
        """Return ``(logits, dist)``.

        With ``debug=True``, return instead ``(hidden_states_out, enc1, enc2,
        enc3, enc4, dec4, dec3, dec2, dec1, logits)``; enc5 and dist are not
        included.
        """
        # Shape notes assume a (B, in_channels, 256, 256) input with feature_size=16,
        # and that the MONAI blocks keep spatial size at stride 1 (TODO confirm).
        hidden_states_out = self.mit_b3(x_in) # list of 4 stage maps at 1/4, 1/8, 1/16, 1/32 resolution
        enc1 = self.encoder1(x_in) # (B, 16, 256, 256)
        x1 = hidden_states_out[0] # (B, 32, 64, 64)
        enc2 = self.encoder2(x1) # (B, 32, 64, 64)
        x2 = hidden_states_out[1] # (B, 64, 32, 32)
        enc3 = self.encoder3(x2) # (B, 64, 32, 32)
        x3 = hidden_states_out[2] # (B, 128, 16, 16)
        enc4 = self.encoder4(x3) # (B, 128, 16, 16)
        x4 = hidden_states_out[3] # (B, 256, 8, 8)
        enc5 = self.encoder5(x4) # (B, 256, 8, 8)
        # print(f"{enc1.shape=}, {enc2.shape=}, {enc3.shape=}, {enc4.shape=}, {enc5.shape=}")
        dec4 = self.decoder4(enc5, enc4) # up -> cat -> ResConv; (B, 128, 16, 16)
        dec3 = self.decoder3(dec4, enc3) # (B, 64, 32, 32)
        dec2 = self.decoder2(dec3, enc2) # (B, 32, 64, 64)
        dec2_up = self.transp_conv(dec2) # (B, 32, 128, 128)
        dec1 = self.decoder1(dec2_up, enc1) # (B, 16, 256, 256)
        logits = self.out_interior(dec1)
        dist = self.out_dist(dec1)
        if self.debug:
            return hidden_states_out, enc1, enc2, enc3, enc4, dec4, dec3, dec2, dec1, logits
        else:
            return logits, dist
        # print(f"{dec1.shape=}, {dec2.shape=}, {dec3.shape=}, {dec4.shape=}, {logits.shape=}")
#%% head
class MLPEmbedding(nn.Module):
    """
    Linear Embedding
    used in head

    Flattens a ``(B, C, H, W)`` feature map into ``(B, H*W, C)`` tokens and
    projects the channels to ``embed_dim``.
    """

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)  # (B, C, H, W) -> (B, H*W, C)
        x = self.proj(x)                  # -> (B, H*W, embed_dim)
        return x


class All_MLP_Head(nn.Module):
    """
    All MLP head in segformer
    Simple and Efficient Design for Semantic Segmentation with Transformers

    The head proceeds as follows:
    - project each of the four input maps ``(c1, c2, c3, c4)`` to
      ``embedding_dim`` channels;
    - resize the projections to the spatial size of ``c1`` (the
      highest-resolution map);
    - concatenate them (c4, c3, c2, c1 order);
    - fuse with a 1x1 conv + BatchNorm + ReLU (+ Dropout2d);
    - predict ``num_classes`` channels with a final 1x1 conv.

    Args:
        in_channels: channel counts of the four input maps.
        in_index: stored but not used by ``forward``.
        feature_strides: strides of the inputs. They are only validated: same
            length as ``in_channels``, smallest first.
        dropout_ratio: ``Dropout2d`` probability after fusion; ``<= 0`` disables dropout.
        num_classes: number of output channels.
        embedding_dim: common channel width after projection.
        output_hidden_states: if True, ``forward`` also returns the fused features.
    """

    def __init__(self, in_channels=(64, 128, 320, 512),  # channel number of multi-scale features
                 in_index=(0, 1, 2, 3),
                 feature_strides=(4, 8, 16, 32),
                 dropout_ratio=0.1,
                 num_classes=3,
                 embedding_dim=768,
                 output_hidden_states=False):
        super().__init__()
        self.in_channels = in_channels
        assert len(feature_strides) == len(self.in_channels)
        assert min(feature_strides) == feature_strides[0]
        self.in_index = in_index
        self.feature_strides = feature_strides
        self.dropout_ratio = dropout_ratio
        self.num_classes = num_classes
        self.output_hidden_states = output_hidden_states
        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels

        # Project every stage to the common width `embedding_dim`.
        self.linear_c4 = MLPEmbedding(input_dim=c4_in_channels, embed_dim=embedding_dim)
        self.linear_c3 = MLPEmbedding(input_dim=c3_in_channels, embed_dim=embedding_dim)
        self.linear_c2 = MLPEmbedding(input_dim=c2_in_channels, embed_dim=embedding_dim)
        self.linear_c1 = MLPEmbedding(input_dim=c1_in_channels, embed_dim=embedding_dim)

        # Fuse the 4 concatenated projections back down to `embedding_dim` channels.
        self.linear_fuse = nn.Conv2d(in_channels=embedding_dim * 4, out_channels=embedding_dim, kernel_size=1,
                                     bias=False)
        self.batch_norm = nn.BatchNorm2d(embedding_dim)
        self.activation = nn.ReLU()
        # forward() always applies self.dropout. Previously it was only created when
        # dropout_ratio > 0, so dropout_ratio=0 raised AttributeError. Identity has
        # no parameters, so state_dict is unaffected.
        self.dropout = nn.Dropout2d(self.dropout_ratio) if dropout_ratio > 0 else nn.Identity()
        self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)

    def forward(self, inputs):
        """Map four feature maps (highest resolution first) to ``num_classes`` maps at ``c1``'s size.

        Returns ``x``, or ``(x, hidden_states)`` when ``output_hidden_states`` is set.
        """
        c1, c2, c3, c4 = inputs
        n = c4.shape[0]
        target_size = c1.shape[2:]

        def project(feature, embedding):
            # (n, C, h, w) -> (n, h*w, E) -> (n, E, h, w)
            tokens = embedding(feature)
            return tokens.permute(0, 2, 1).reshape(n, -1, feature.shape[2], feature.shape[3])

        # Coarser stages are projected, then bilinearly resized to c1's resolution.
        resized = [
            nn.functional.interpolate(project(feature, embedding), size=target_size, mode='bilinear',
                                      align_corners=False)
            for feature, embedding in ((c4, self.linear_c4), (c3, self.linear_c3), (c2, self.linear_c2))
        ]
        _c1 = project(c1, self.linear_c1)

        # Concatenate in c4, c3, c2, c1 order and fuse.
        hidden_states = self.linear_fuse(torch.cat([*resized, _c1], dim=1))
        hidden_states = self.batch_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.dropout(hidden_states)

        x = self.linear_pred(hidden_states)
        if self.output_hidden_states:
            return x, hidden_states
        return x
#%% test different networks
# img_size = 256
# in_chans = 3
# B = 2
# input_img = torch.randn((B,in_chans,img_size,img_size))
# b3 = mit_b3_demo(img_size=img_size)
# b3_out = b3(input_img)
# for feature in b3_out:
# print(f"{feature.shape=}")
# head = All_MLP_Head()
# outputs = head(b3_out)
# print(f"{outputs.shape = }")