import torch
import torch.nn as nn

from operator import itemgetter

from typing import Type, Tuple, List, Union

from timm.models.layers import drop_path, DropPath
from timm.models.efficientnet_blocks import SqueezeExcite, DepthwiseSeparableConv


def exists(val):
    return val is not None


def map_el_ind(arr, ind):
    # Pick the ind-th element out of every tuple in arr.
    return list(map(itemgetter(ind), arr))


def sort_and_return_indices(arr):
    # Sort arr, returning the sorted values and the indices that produce that order
    # (for a permutation, this argsort is its inverse permutation).
    indices = [ind for ind in range(len(arr))]
    arr = zip(arr, indices)
    arr = sorted(arr)
    return map_el_ind(arr, 0), map_el_ind(arr, 1)


def calculate_permutations(num_dimensions, emb_dim):
    # Build one permutation per spatial axis that moves (axis, emb_dim) into the
    # last two positions, so attention can be applied along that axis.
    total_dimensions = num_dimensions + 2
    axial_dims = [ind for ind in range(1, total_dimensions) if ind != emb_dim]

    permutations = []

    for axial_dim in axial_dims:
        last_two_dims = [axial_dim, emb_dim]
        dims_rest = set(range(0, total_dimensions)) - set(last_two_dims)
        permutation = [*dims_rest, *last_two_dims]
        permutations.append(permutation)

    return permutations


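# Illustration: for the 4D tensors used below (batch, channel, height, width) with
# emb_dim = 1, calculate_permutations(2, 1) yields one permutation per spatial axis,
# e.g. [0, 3, 2, 1] to expose the height axis and [0, 2, 3, 1] to expose the width
# axis (remaining axes stay in front, the chosen axis and the embedding axis go last).

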
class ChanLayerNorm(nn.Module):
    # Layer normalisation over the channel dimension of (batch, channel, height, width) tensors.
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        std = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt()
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (std + self.eps) * self.g + self.b


class PreNorm(nn.Module):
    # Apply LayerNorm before the wrapped function (pre-norm transformer convention).
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)


class PermuteToFrom(nn.Module):
    # Permute the chosen axis next to the embedding axis, run fn on the flattened
    # (batch * rest, axis, dim) view, then undo the permutation.
    def __init__(self, permutation, fn):
        super().__init__()
        self.fn = fn
        _, inv_permutation = sort_and_return_indices(permutation)
        self.permutation = permutation
        self.inv_permutation = inv_permutation

    def forward(self, x, **kwargs):
        axial = x.permute(*self.permutation).contiguous()
        shape = axial.shape
        *_, t, d = shape

        # merge all leading dimensions into the batch and attend along t
        axial = axial.reshape(-1, t, d)
        axial = self.fn(axial, **kwargs)
        axial = axial.reshape(*shape)
        axial = axial.permute(*self.inv_permutation).contiguous()
        return axial


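# Illustration of the data movement: with x of shape (B, C, H, W) and the width
# permutation [0, 2, 3, 1], the wrapped attention sees a tensor of shape (B * H, W, C),
# i.e. every image row becomes an independent sequence along which attention is computed.

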
class AxialPositionalEmbedding(nn.Module):
    # One learnable positional embedding per spatial axis, broadcast-added to the input.
    def __init__(self, dim, shape, emb_dim_index=1):
        super().__init__()
        total_dimensions = len(shape) + 2
        ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]

        self.num_axials = len(shape)

        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
            # The parameter has size 1 everywhere except along the embedding axis and this spatial axis.
            param_shape = [1] * total_dimensions
            param_shape[emb_dim_index] = dim
            param_shape[axial_dim_index] = axial_dim
            parameter = nn.Parameter(torch.randn(*param_shape))
            setattr(self, f'param_{i}', parameter)

    def forward(self, x):
        for i in range(self.num_axials):
            x = x + getattr(self, f'param_{i}')
        return x


class SelfAttention(nn.Module):
    def __init__(self, dim, heads, dim_heads=None, drop=0.):
        super().__init__()
        self.dim_heads = (dim // heads) if dim_heads is None else dim_heads
        dim_hidden = self.dim_heads * heads

        self.drop_rate = drop
        self.heads = heads
        self.to_q = nn.Linear(dim, dim_hidden, bias=False)
        self.to_kv = nn.Linear(dim, 2 * dim_hidden, bias=False)
        self.to_out = nn.Linear(dim_hidden, dim)
        # element-wise dropout on the output projection
        self.proj_drop = nn.Dropout(drop)

    def forward(self, x, kv=None):
        kv = x if kv is None else kv
        q, k, v = (self.to_q(x), *self.to_kv(kv).chunk(2, dim=-1))

        b, t, d, h, e = *q.shape, self.heads, self.dim_heads

        # split the hidden dimension into heads and fold the heads into the batch
        merge_heads = lambda x: x.reshape(b, -1, h, e).transpose(1, 2).reshape(b * h, -1, e)
        q, k, v = map(merge_heads, (q, k, v))

        # scaled dot-product attention
        dots = torch.einsum('bie,bje->bij', q, k) * (e ** -0.5)
        dots = dots.softmax(dim=-1)

        out = torch.einsum('bij,bje->bie', dots, v)
        out = out.reshape(b, h, -1, e).transpose(1, 2).reshape(b, -1, d)
        out = self.to_out(out)
        out = self.proj_drop(out)
        return out


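# Quick shape check (illustrative, not part of the model): SelfAttention is
# shape-preserving on (batch, tokens, dim) inputs, e.g.
#   attn = SelfAttention(dim=64, heads=8)
#   attn(torch.randn(2, 49, 64)).shape  # torch.Size([2, 49, 64])

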
class AxialTransformerBlock(nn.Module):
    # Axial attention block: height attention, width attention, then a convolutional FFN,
    # each wrapped in a residual connection with optional stochastic depth.
    def __init__(self,
                 dim,
                 axial_pos_emb_shape,
                 pos_embed,
                 heads=8,
                 dim_heads=None,
                 drop=0.,
                 drop_path_rate=0.,
                 ):
        super().__init__()

        dim_index = 1
        permutations = calculate_permutations(2, dim_index)

        self.pos_emb = AxialPositionalEmbedding(dim, axial_pos_emb_shape, dim_index) if pos_embed else nn.Identity()

        self.height_attn, self.width_attn = nn.ModuleList([
            PermuteToFrom(permutation, PreNorm(dim, SelfAttention(dim, heads, dim_heads, drop=drop)))
            for permutation in permutations
        ])

        # Two stacked convolutional feed-forward sub-blocks.
        self.FFN = nn.Sequential(
            ChanLayerNorm(dim),
            nn.Conv2d(dim, dim * 4, 3, padding=1),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Conv2d(dim * 4, dim, 3, padding=1),
            nn.Dropout(drop),

            ChanLayerNorm(dim),
            nn.Conv2d(dim, dim * 4, 3, padding=1),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Conv2d(dim * 4, dim, 3, padding=1),
            nn.Dropout(drop),
        )

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        x = self.pos_emb(x)
        x = x + self.drop_path(self.height_attn(x))
        x = x + self.drop_path(self.width_attn(x))
        x = x + self.drop_path(self.FFN(x))
        return x


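# Shape behaviour (illustrative): the block preserves (B, dim, H, W) shapes, e.g.
#   blk = AxialTransformerBlock(dim=64, axial_pos_emb_shape=(56, 56), pos_embed=True)
#   blk(torch.randn(1, 64, 56, 56)).shape  # torch.Size([1, 64, 56, 56])
# With pos_embed=True the input's spatial size must match axial_pos_emb_shape so the
# positional parameters broadcast correctly.

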
def pair(t):
    return t if isinstance(t, tuple) else (t, t)


def _gelu_ignore_parameters(*args, **kwargs) -> nn.Module:
    # timm blocks pass constructor arguments (e.g. inplace=...) to the activation;
    # nn.GELU does not accept them, so swallow the arguments and return a plain GELU.
    activation = nn.GELU()
    return activation


class DoubleConv(nn.Module):
    # MBConv-style residual block: 1x1 conv, depthwise-separable conv (optionally strided),
    # squeeze-excitation, 1x1 conv, plus a matching skip path.
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            downscale: bool = False,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            drop_path: float = 0.,
    ) -> None:
        super(DoubleConv, self).__init__()
        self.drop_path_rate: float = drop_path

        if act_layer == nn.GELU:
            act_layer = _gelu_ignore_parameters

        self.main_path = nn.Sequential(
            norm_layer(in_channels),
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=(1, 1)),
            DepthwiseSeparableConv(in_chs=in_channels, out_chs=out_channels, stride=2 if downscale else 1,
                                   act_layer=act_layer, norm_layer=norm_layer, drop_path_rate=drop_path),
            SqueezeExcite(in_chs=out_channels, rd_ratio=0.25),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(1, 1))
        )

        # Skip path matches the main path's channels (and spatial size when downscaling).
        if downscale:
            self.skip_path = nn.Sequential(
                nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1))
            )
        else:
            self.skip_path = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self.main_path(x)
        if self.drop_path_rate > 0.:
            output = drop_path(output, self.drop_path_rate, self.training)
        x = output + self.skip_path(x)
        return x


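# Shape behaviour (illustrative): with even H and W, DoubleConv(32, 64, downscale=True)
# maps (B, 32, H, W) -> (B, 64, H // 2, W // 2); with downscale=False the spatial
# resolution is unchanged and only the channel count changes.

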
class DeconvModule(nn.Module):
    # Transposed-convolution upsampling followed by normalisation and activation.
    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_layer=nn.BatchNorm2d,
                 act_layer=nn.Mish,
                 kernel_size=4,
                 scale_factor=2):
        super(DeconvModule, self).__init__()

        assert (kernel_size - scale_factor >= 0) and \
               (kernel_size - scale_factor) % 2 == 0, \
            f'kernel_size should be greater than or equal to scale_factor ' \
            f'and (kernel_size - scale_factor) should be even, ' \
            f'got kernel_size={kernel_size} and scale_factor={scale_factor}.'

        stride = scale_factor
        padding = (kernel_size - scale_factor) // 2
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)

        norm = norm_layer(out_channels)
        activate = act_layer()
        self.deconv_upsampling = nn.Sequential(deconv, norm, activate)

    def forward(self, x):
        out = self.deconv_upsampling(x)
        return out


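# With the defaults (kernel_size=4, scale_factor=2) the module doubles the spatial
# resolution: (B, in_channels, H, W) -> (B, out_channels, 2 * H, 2 * W).

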
class Stage(nn.Module):
    # One encoder or decoder stage: a DoubleConv (plus upsampling and skip concatenation
    # in the decoder) followed by a stack of axial transformer blocks.
    def __init__(self,
                 image_size: int,
                 depth: int,
                 in_channels: int,
                 out_channels: int,
                 type_name: str,
                 pos_embed: bool,
                 num_heads: int = 32,
                 drop: float = 0.,
                 drop_path: Union[List[float], float] = 0.,
                 act_layer: Type[nn.Module] = nn.GELU,
                 norm_layer: Type[nn.Module] = nn.BatchNorm2d,
                 ):
        super().__init__()
        self.type_name = type_name

        if self.type_name == "encoder":
            self.conv = DoubleConv(
                in_channels=in_channels,
                out_channels=out_channels,
                downscale=True,
                act_layer=act_layer,
                norm_layer=norm_layer,
                drop_path=drop_path[0],
            )

            self.blocks = nn.Sequential(*[
                AxialTransformerBlock(
                    dim=out_channels,
                    axial_pos_emb_shape=pair(image_size),
                    heads=num_heads,
                    drop=drop,
                    drop_path_rate=drop_path[index],
                    dim_heads=None,
                    pos_embed=pos_embed
                )
                for index in range(depth)
            ])

        elif self.type_name == "decoder":
            self.upsample = DeconvModule(
                in_channels=in_channels,
                out_channels=out_channels,
                norm_layer=norm_layer,
                act_layer=act_layer
            )

            # After upsampling, the skip connection is concatenated along the channel
            # dimension, so the DoubleConv consumes in_channels (= 2 * out_channels for
            # the channel schedule used below).
            self.conv = DoubleConv(
                in_channels=in_channels,
                out_channels=out_channels,
                downscale=False,
                act_layer=act_layer,
                norm_layer=norm_layer,
                drop_path=drop_path[0],
            )

            self.blocks = nn.Sequential(*[
                AxialTransformerBlock(
                    dim=out_channels,
                    axial_pos_emb_shape=pair(image_size),
                    heads=num_heads,
                    drop=drop,
                    drop_path_rate=drop_path[index],
                    dim_heads=None,
                    pos_embed=pos_embed
                )
                for index in range(depth)
            ])

    def forward(self, x, skip=None):
        if self.type_name == "encoder":
            x = self.conv(x)
            x = self.blocks(x)
        elif self.type_name == "decoder":
            x = self.upsample(x)
            x = torch.cat([skip, x], dim=1)
            x = self.conv(x)
            x = self.blocks(x)
        return x


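# Example (illustrative): an encoder Stage with in_channels=64, out_channels=128 and
# image_size=56 maps (B, 64, 112, 112) -> (B, 128, 56, 56); the corresponding decoder
# Stage reverses this, upsampling and fusing the encoder skip tensor of the same resolution.

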
class FinalExpand(nn.Module):
    # Final upsampling back to the input resolution, fused with the stem skip features.
    def __init__(
            self,
            in_channels,
            embed_dim,
            out_channels,
            norm_layer,
            act_layer,
    ):
        super().__init__()
        self.upsample = DeconvModule(
            in_channels=in_channels,
            out_channels=embed_dim,
            norm_layer=norm_layer,
            act_layer=act_layer
        )

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=embed_dim * 2, out_channels=embed_dim, kernel_size=3, stride=1, padding=1),
            act_layer(),
            nn.Conv2d(in_channels=embed_dim, out_channels=embed_dim, kernel_size=3, stride=1, padding=1),
            act_layer(),
        )

    def forward(self, skip, x):
        x = self.upsample(x)
        x = torch.cat([skip, x], dim=1)
        x = self.conv(x)
        return x


class polarisnet(nn.Module):
    # U-shaped encoder/decoder built from axial-attention stages.
    def __init__(
            self,
            image_size=224,
            in_channels=1,
            out_channels=1,
            embed_dim=64,
            depths=[2, 2, 2, 2],
            channels=[64, 128, 256, 512],
            num_heads=16,
            drop=0.,
            drop_path=0.1,
            act_layer=nn.GELU,
            norm_layer=nn.BatchNorm2d,
            pos_embed=False
    ):
        super(polarisnet, self).__init__()
        self.num_stages = len(depths)
        self.num_features = channels[-1]
        self.embed_dim = channels[0]

        # Stem: two 3x3 convolutions at full resolution.
        self.conv_first = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=embed_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            act_layer(),
            nn.Conv2d(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            act_layer(),
        )

        # Per-block stochastic depth rates, increasing linearly with depth.
        drop_path = torch.linspace(0.0, drop_path, sum(depths)).tolist()

        encoder_stages = []
        for index in range(self.num_stages):
            encoder_stages.append(
                Stage(
                    image_size=image_size // (pow(2, 1 + index)),
                    depth=depths[index],
                    in_channels=embed_dim if index == 0 else channels[index - 1],
                    out_channels=channels[index],
                    num_heads=num_heads,
                    drop=drop,
                    drop_path=drop_path[sum(depths[:index]):sum(depths[:index + 1])],
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    type_name="encoder",
                    pos_embed=pos_embed
                )
            )
        self.encoder_stages = nn.ModuleList(encoder_stages)

        decoder_stages = []
        for index in range(self.num_stages - 1):
            decoder_stages.append(
                Stage(
                    image_size=image_size // (pow(2, self.num_stages - index - 1)),
                    depth=depths[self.num_stages - index - 2],
                    in_channels=channels[self.num_stages - index - 1],
                    out_channels=channels[self.num_stages - index - 2],
                    num_heads=num_heads,
                    drop=drop,
                    drop_path=drop_path[sum(depths[:(self.num_stages - 2 - index)]):sum(depths[:(self.num_stages - 2 - index) + 1])],
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    type_name="decoder",
                    pos_embed=pos_embed
                )
            )
        self.decoder_stages = nn.ModuleList(decoder_stages)

        self.norm = norm_layer(self.num_features)
        self.norm_up = norm_layer(self.embed_dim)

        self.up = FinalExpand(
            in_channels=channels[0],
            embed_dim=embed_dim,
            out_channels=embed_dim,
            norm_layer=norm_layer,
            act_layer=act_layer
        )

        self.output = nn.Conv2d(embed_dim, out_channels, kernel_size=3, padding=1)

    def encoder_forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        # Run the stem and the encoder stages, keeping every stage's input as a skip tensor.
        outs = []
        x = self.conv_first(x)

        for stage in self.encoder_stages:
            outs.append(x)
            x = stage(x)

        x = self.norm(x)
        return x, outs

    def decoder_forward(self, x: torch.Tensor, x_downsample: List[torch.Tensor]) -> torch.Tensor:
        # Decode from the deepest features, consuming the skip tensors in reverse order.
        for inx, stage in enumerate(self.decoder_stages):
            x = stage(x, x_downsample[len(x_downsample) - 1 - inx])

        x = self.norm_up(x)
        return x

    def up_x4(self, x: torch.Tensor, x_downsample: List[torch.Tensor]) -> torch.Tensor:
        # Final upsampling with the stem features, then project to the output channels.
        x = self.up(x_downsample[0], x)
        x = self.output(x)
        return x

    def forward(self, x):
        x, x_downsample = self.encoder_forward(x)
        x = self.decoder_forward(x, x_downsample)
        x = self.up_x4(x, x_downsample)
        return x


if __name__ == '__main__':
    # Smoke test: run a forward pass on GPU if available, otherwise on CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = polarisnet(in_channels=1, embed_dim=64, pos_embed=True).to(device)

    X = torch.randn(5, 1, 224, 224).to(device)
    y = net(X)
    print(y.shape)  # expected: torch.Size([5, 1, 224, 224])
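    # Optional: report the model size as a rough sanity check of the configuration.
    num_params = sum(p.numel() for p in net.parameters())
    print(f'parameters: {num_params / 1e6:.1f}M')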