# Polaris / polaris/model/polarisnet.py
# Uploaded by rr-ss via huggingface_hub ("Upload folder using huggingface_hub"), commit 3290550 (verified).
import torch
import torch.nn as nn
from operator import itemgetter
from typing import Type, Callable, Tuple, Optional, Set, List, Union
from timm.models.layers import drop_path, trunc_normal_, Mlp, DropPath
from timm.models.efficientnet_blocks import SqueezeExcite, DepthwiseSeparableConv
def exists(val):
    """Return ``True`` for any value other than ``None`` (falsy values included)."""
    if val is None:
        return False
    return True
def map_el_ind(arr, ind):
    """Extract element ``ind`` from every item of ``arr``.

    Args:
        arr: Iterable of indexable items (tuples, lists, dicts, ...).
        ind: Index or key looked up in each item.

    Returns:
        A list of ``item[ind]`` for each item, in iteration order.
    """
    return [item[ind] for item in arr]
def sort_and_return_indices(arr):
    """Sort ``arr`` and report the original position of each sorted element.

    Elements are paired with their positions before sorting, so equal values
    keep ascending position order.

    Args:
        arr: Sized iterable of mutually comparable values.

    Returns:
        ``(sorted_values, source_positions)`` as two lists. When ``arr`` is a
        permutation of ``range(n)``, ``source_positions`` is its inverse.
    """
    ranked = sorted(zip(arr, range(len(arr))))
    sorted_values = [value for value, _ in ranked]
    source_positions = [position for _, position in ranked]
    return sorted_values, source_positions
def calculate_permutations(num_dimensions, emb_dim):
    """Build one axis permutation per axial dimension for axial attention.

    The target tensor has ``num_dimensions + 2`` axes. Axis 0 is never treated
    as axial (it is the batch axis in this file's usage), and ``emb_dim`` is
    the embedding/channel axis. For every other axis, the returned
    permutation moves that axis and then ``emb_dim`` to the end. The remaining
    axes stay in front in ascending order.

    Args:
        num_dimensions: Number of axial (e.g. spatial) dimensions.
        emb_dim: Index of the embedding axis.

    Returns:
        A list of permutations (lists of ints), one per axial axis, ordered by
        axis index.
    """
    total_dimensions = num_dimensions + 2
    axial_dims = [dim for dim in range(1, total_dimensions) if dim != emb_dim]
    permutations = []
    for axial_dim in axial_dims:
        last_two_dims = [axial_dim, emb_dim]
        # Ordered comprehension rather than a set difference, so the result
        # does not depend on set iteration order (an implementation detail).
        dims_rest = [dim for dim in range(total_dimensions) if dim not in last_two_dims]
        permutations.append([*dims_rest, *last_two_dims])
    return permutations
class ChanLayerNorm(nn.Module):
    """Layer normalisation across the channel axis (dim 1).

    Mean and biased (population) variance are taken over dim 1 only. ``eps``
    is added to the standard deviation, not the variance. The result is then
    scaled by a learnable gain ``g`` and shifted by a learnable bias ``b``,
    both of shape ``(1, dim, 1, 1)``. Inputs are therefore expected to be
    ``(batch, dim, H, W)``-like tensors (assumption based on the parameter
    shapes).
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        gain = torch.ones(1, dim, 1, 1)
        shift = torch.zeros(1, dim, 1, 1)
        self.g = nn.Parameter(gain)
        self.b = nn.Parameter(shift)

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_std = x.var(dim=1, unbiased=False, keepdim=True).sqrt()
        normalised = (x - channel_mean) / (channel_std + self.eps)
        return normalised * self.g + self.b
class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm`` over the last axis, then call the wrapped ``fn``.

    Args:
        dim: Size of the last axis, normalised by ``nn.LayerNorm(dim)``.
        fn: Callable (typically an ``nn.Module``) applied to the normalised
            tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # fn is registered before norm so parameter order stays the same.
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        """Normalise ``x`` and pass it, with any keyword arguments, to ``fn``.

        Keyword arguments (e.g. ``kv`` for ``SelfAttention``) are forwarded
        unchanged. This lets the wrapper compose with ``PermuteToFrom``, which
        passes ``**kwargs`` to its own ``fn``.
        """
        return self.fn(self.norm(x), **kwargs)
class PermuteToFrom(nn.Module):
    """Apply ``fn`` along one axis by permuting that axis next to the features.

    ``permutation`` should place the attended axis at position -2 and the
    feature axis at position -1, which is the layout produced by
    ``calculate_permutations``. All leading axes are flattened into one, and
    ``fn`` sees a 3-D ``(N, T, D)`` tensor. Its output is reshaped and
    permuted back to the input layout.

    Args:
        permutation: Axis order passed to ``Tensor.permute``.
        fn: Callable mapping ``(N, T, D)`` to a tensor of the same shape.
    """

    def __init__(self, permutation, fn):
        super().__init__()
        self.fn = fn
        # Argsort of the permutation is its inverse.
        inverse = sort_and_return_indices(permutation)[1]
        self.permutation = permutation
        self.inv_permutation = inverse

    def forward(self, x, **kwargs):
        permuted = x.permute(*self.permutation).contiguous()
        permuted_shape = permuted.shape
        *_, seq_len, features = permuted_shape
        flat_out = self.fn(permuted.reshape(-1, seq_len, features), **kwargs)
        restored = flat_out.reshape(*permuted_shape)
        return restored.permute(*self.inv_permutation).contiguous()
class AxialPositionalEmbedding(nn.Module):
    """Learnable axial positional embedding added to the input.

    One parameter is created per axial dimension, initialised from
    ``torch.randn``. Parameter ``i`` has:

    * size ``dim`` on axis ``emb_dim_index``;
    * size ``shape[i]`` on the ``i``-th non-embedding axis, counting from
      axis 1;
    * size 1 everywhere else.

    Their sum therefore broadcasts over the full axial grid.

    Args:
        dim: Embedding (channel) size.
        shape: Sizes of the axial dimensions, e.g. ``(H, W)``.
        emb_dim_index: Axis of the input that holds the embedding.
    """

    def __init__(self, dim, shape, emb_dim_index=1):
        super().__init__()
        total_dimensions = len(shape) + 2
        ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]
        self.num_axials = len(shape)
        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
            # Separate name so the ``shape`` argument is not shadowed.
            param_shape = [1] * total_dimensions
            param_shape[emb_dim_index] = dim
            param_shape[axial_dim_index] = axial_dim
            # Registered as ``param_{i}`` (not a ParameterList) so existing
            # state_dict keys keep working.
            setattr(self, f'param_{i}', nn.Parameter(torch.randn(*param_shape)))

    def forward(self, x):
        """Return ``x`` plus every axial parameter (broadcast)."""
        for i in range(self.num_axials):
            x = x + getattr(self, f'param_{i}')
        return x
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention on ``(batch, tokens, dim)`` input.

    Queries come from ``x``. Keys and values come from ``kv`` when given,
    otherwise from ``x``.

    Args:
        dim: Input and output feature size.
        heads: Number of attention heads.
        dim_heads: Per-head size; ``None`` means ``dim // heads``.
        drop: Rate for the ``DropPath`` applied after the output projection.
    """

    def __init__(self, dim, heads, dim_heads=None, drop=0):
        super().__init__()
        if dim_heads is None:
            dim_heads = dim // heads
        self.dim_heads = dim_heads
        inner_dim = dim_heads * heads
        self.drop_rate = drop  # kept as an attribute; not read within this class
        self.heads = heads
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, 2 * inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)
        # NOTE(review): timm's DropPath is stochastic depth (per entry of the
        # leading axis), not element-wise dropout, despite the ``proj_drop``
        # name. Confirm this is intended.
        self.proj_drop = DropPath(drop)

    def forward(self, x, kv=None):
        source = x if kv is None else kv
        queries = self.to_q(x)
        keys, values = self.to_kv(source).chunk(2, dim=-1)
        batch, _, inner_dim = queries.shape
        n_heads, head_dim = self.heads, self.dim_heads

        def split_heads(t):
            # (b, n, h*e) -> (b*h, n, e)
            t = t.reshape(batch, -1, n_heads, head_dim).transpose(1, 2)
            return t.reshape(batch * n_heads, -1, head_dim)

        queries, keys, values = split_heads(queries), split_heads(keys), split_heads(values)
        scores = torch.einsum('bie,bje->bij', queries, keys) * (head_dim ** -0.5)
        weights = scores.softmax(dim=-1)
        mixed = torch.einsum('bij,bje->bie', weights, values)
        # (b*h, n, e) -> (b, n, h*e)
        mixed = mixed.reshape(batch, n_heads, -1, head_dim).transpose(1, 2).reshape(batch, -1, inner_dim)
        return self.proj_drop(self.to_out(mixed))
class AxialTransformerBlock(nn.Module):
    """Axial-attention block for channel-first 4-D feature maps.

    An optional learnable axial positional embedding is added first. Three
    residual branches follow, each wrapped in ``self.drop_path``:

    1. attention along axis 2 (height);
    2. attention along axis 3 (width);
    3. a convolutional feed-forward stack.

    Args:
        dim: Channel count of input and output.
        axial_pos_emb_shape: ``(H, W)`` for the positional embedding; only
            used when ``pos_embed`` is true.
        pos_embed: Whether to add an ``AxialPositionalEmbedding``.
        heads: Attention heads per axial attention.
        dim_heads: Per-head size (``None`` means ``dim // heads``).
        drop: Rate for the ``DropPath`` layers inside attention and the FFN.
        drop_path_rate: Rate for the ``DropPath`` on each residual branch.
    """

    def __init__(self,
                 dim,
                 axial_pos_emb_shape,
                 pos_embed,
                 heads=8,
                 dim_heads=None,
                 drop=0.,
                 drop_path_rate=0.,
                 ):
        super().__init__()
        channel_axis = 1
        # For a 4-D tensor these are [0, 3, 2, 1] (attend over axis 2) and
        # [0, 2, 3, 1] (attend over axis 3).
        height_perm, width_perm = calculate_permutations(2, channel_axis)
        if pos_embed:
            self.pos_emb = AxialPositionalEmbedding(dim, axial_pos_emb_shape, channel_axis)
        else:
            self.pos_emb = nn.Identity()

        def axial_attention(permutation):
            attention = SelfAttention(dim, heads, dim_heads, drop=drop)
            return PermuteToFrom(permutation, PreNorm(dim, attention))

        # Built height-first, then width, preserving construction order.
        self.height_attn = axial_attention(height_perm)
        self.width_attn = axial_attention(width_perm)

        def ffn_stage():
            return [
                ChanLayerNorm(dim),
                nn.Conv2d(dim, dim * 4, 3, padding=1),
                nn.GELU(),
                DropPath(drop),
                nn.Conv2d(dim * 4, dim, 3, padding=1),
                DropPath(drop),
            ]

        # Two identical norm -> 3x3 expand -> GELU -> 3x3 project stacks in a
        # single Sequential (indices 0-11, same layout as before).
        self.FFN = nn.Sequential(*ffn_stage(), *ffn_stage())
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        x = self.pos_emb(x)
        for branch in (self.height_attn, self.width_attn, self.FFN):
            x = x + self.drop_path(branch(x))
        return x
def pair(t):
    """Return ``t`` as-is when it is a tuple, else duplicate it into ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)
def _gelu_ignore_parameters(*args, **kwargs) -> nn.Module:
activation = nn.GELU()
return activation
class DoubleConv(nn.Module):
    """Residual convolution block with an optional 2x spatial downscale.

    The main path is:

        norm -> 1x1 conv -> depthwise-separable conv (stride 2 when
        ``downscale``) -> squeeze-excite -> 1x1 conv

    The skip path is a 1x1 conv, preceded by a 2x2 max-pool when
    ``downscale``. The two paths are summed.

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        downscale: Halve the spatial size when true.
        act_layer: Activation factory; ``nn.GELU`` is swapped for
            ``_gelu_ignore_parameters``.
        norm_layer: Normalisation factory.
        drop_path: Stochastic-depth rate applied to the main path (and also
            passed to timm's ``DepthwiseSeparableConv``).
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            downscale: bool = False,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            drop_path: float = 0.,
    ) -> None:
        super().__init__()
        self.drop_path_rate: float = drop_path
        if act_layer == nn.GELU:
            act_layer = _gelu_ignore_parameters
        conv_stride = 2 if downscale else 1
        # NOTE(review): timm presumably applies drop_path_rate only on its own
        # internal residual, which should be inactive when stride/channels
        # change; verify against the installed timm version.
        self.main_path = nn.Sequential(
            norm_layer(in_channels),
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=(1, 1)),
            DepthwiseSeparableConv(in_chs=in_channels, out_chs=out_channels, stride=conv_stride,
                                   act_layer=act_layer, norm_layer=norm_layer, drop_path_rate=drop_path),
            SqueezeExcite(in_chs=out_channels, rd_ratio=0.25),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(1, 1)),
        )
        if not downscale:
            self.skip_path = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1))
        else:
            self.skip_path = nn.Sequential(
                nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1)),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.main_path(x)
        if self.drop_path_rate > 0.:
            residual = drop_path(residual, self.drop_path_rate, self.training)
        return residual + self.skip_path(x)
class DeconvModule(nn.Module):
    """Learned upsampling: ``ConvTranspose2d`` -> norm -> activation.

    Uses ``stride = scale_factor`` and
    ``padding = (kernel_size - scale_factor) // 2``. The transposed-conv
    output size ``(H - 1) * s - 2p + k`` then equals exactly
    ``H * scale_factor``.

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        norm_layer: Normalisation factory, called with ``out_channels``.
        act_layer: Activation factory, called with no arguments.
        kernel_size: Transposed-conv kernel size.
        scale_factor: Upsampling factor (the transposed-conv stride).

    Raises:
        AssertionError: If ``kernel_size < scale_factor`` or their difference
            is odd.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_layer=nn.BatchNorm2d,
                 act_layer=nn.Mish,
                 kernel_size=4,
                 scale_factor=2):
        super().__init__()
        surplus = kernel_size - scale_factor
        assert surplus >= 0 and surplus % 2 == 0, (
            f'kernel_size should be greater than or equal to scale_factor '
            f'and (kernel_size - scale_factor) should be even numbers, '
            f'while the kernel size is {kernel_size} and scale_factor is '
            f'{scale_factor}.'
        )
        layers = (
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                               stride=scale_factor, padding=surplus // 2),
            norm_layer(out_channels),
            act_layer(),
        )
        # Attribute name (spelling included) kept for state_dict compatibility.
        self.deconv_upsamping = nn.Sequential(*layers)

    def forward(self, x):
        return self.deconv_upsamping(x)
class Stage(nn.Module):
    """One encoder or decoder stage of ``polarisnet``.

    * ``"encoder"``: ``DoubleConv`` with 2x downscale, then ``depth``
      ``AxialTransformerBlock``s.
    * ``"decoder"``: ``DeconvModule`` upsampling, concatenation
      ``[skip, upsampled]`` along channels, ``DoubleConv`` (no downscale),
      then ``depth`` blocks.

    Args:
        image_size: Spatial size of this stage's output (int or ``(H, W)``),
            used for the optional positional embedding.
        depth: Number of ``AxialTransformerBlock``s.
        in_channels: Input channels. For decoders this is also the input
            width of ``DoubleConv`` after concatenation, so it must equal
            skip channels + ``out_channels``.
        out_channels: Output channels.
        type_name: ``"encoder"`` or ``"decoder"``.
        pos_embed: Whether blocks add an axial positional embedding.
        num_heads: Attention heads per block.
        drop: ``drop`` rate forwarded to each block.
        drop_path: Stochastic-depth rate(s). Either one number used for
            ``DoubleConv`` and every block, or a sequence with at least
            ``depth`` entries (entry 0 is also used for ``DoubleConv``).
        act_layer: Activation factory forwarded to the conv modules.
        norm_layer: Normalisation factory forwarded to the conv modules.

    Raises:
        ValueError: If ``type_name`` is neither ``"encoder"`` nor ``"decoder"``.
    """

    def __init__(self,
                 image_size: int,
                 depth: int,
                 in_channels: int,
                 out_channels: int,
                 type_name: str,
                 pos_embed: bool,
                 num_heads: int = 32,
                 drop: float = 0.,
                 drop_path: Union[List[float], float] = 0.,
                 act_layer: Type[nn.Module] = nn.GELU,
                 norm_layer: Type[nn.Module] = nn.BatchNorm2d,
                 ):
        super().__init__()
        if type_name not in ("encoder", "decoder"):
            # Previously an unknown name silently built a no-op stage.
            raise ValueError(f'type_name must be "encoder" or "decoder", got {type_name!r}')
        self.type_name = type_name

        # A scalar rate (the default is the float 0.) cannot be indexed;
        # broadcast it to every block instead.
        if isinstance(drop_path, (int, float)):
            conv_drop_path = float(drop_path)
            block_drop_paths = [conv_drop_path] * depth
        else:
            block_drop_paths = list(drop_path)
            conv_drop_path = block_drop_paths[0] if block_drop_paths else 0.

        # Module construction/registration order matches the previous layout:
        # encoder = conv, blocks; decoder = upsample, conv, blocks.
        if type_name == "encoder":
            self.conv = DoubleConv(
                in_channels=in_channels,
                out_channels=out_channels,
                downscale=True,
                act_layer=act_layer,
                norm_layer=norm_layer,
                drop_path=conv_drop_path,
            )
        else:
            self.upsample = DeconvModule(
                in_channels=in_channels,
                out_channels=out_channels,
                norm_layer=norm_layer,
                act_layer=act_layer,
            )
            # Input width after concatenation is assumed to be in_channels,
            # i.e. skip channels + out_channels.
            self.conv = DoubleConv(
                in_channels=in_channels,
                out_channels=out_channels,
                downscale=False,
                act_layer=act_layer,
                norm_layer=norm_layer,
                drop_path=conv_drop_path,
            )
        self.blocks = self._make_blocks(
            depth, out_channels, image_size, num_heads, drop, block_drop_paths, pos_embed
        )

    @staticmethod
    def _make_blocks(depth, dim, image_size, num_heads, drop, drop_paths, pos_embed):
        """Build ``depth`` AxialTransformerBlocks, block ``i`` using ``drop_paths[i]``."""
        return nn.Sequential(*[
            AxialTransformerBlock(
                dim=dim,
                axial_pos_emb_shape=pair(image_size),
                heads=num_heads,
                drop=drop,
                drop_path_rate=drop_paths[index],
                dim_heads=None,
                pos_embed=pos_embed,
            )
            for index in range(depth)
        ])

    def forward(self, x, skip=None):
        """Run the stage. ``skip`` is required for decoders and ignored for encoders."""
        if self.type_name == "encoder":
            x = self.conv(x)
        else:
            x = self.upsample(x)
            # Skip tensor first, upsampled tensor second.
            x = torch.cat([skip, x], dim=1)
            x = self.conv(x)
        return self.blocks(x)
class FinalExpand(nn.Module):
    """Final upsampling head that merges the full-resolution skip tensor.

    ``x`` is upsampled by ``DeconvModule`` (default scale factor 2) from
    ``in_channels`` to ``embed_dim`` channels. It is concatenated after
    ``skip``, so ``skip`` must have ``embed_dim`` channels, and refined by two
    3x3 conv + activation layers.

    Args:
        in_channels: Channels of the decoder output ``x``.
        embed_dim: Channels of ``skip`` and of the result.
        out_channels: Accepted for interface compatibility; not used here.
        norm_layer: Normalisation factory for the upsampler.
        act_layer: Activation factory for the upsampler and the conv stack.
    """

    def __init__(
            self,
            in_channels,
            embed_dim,
            out_channels,
            norm_layer,
            act_layer,
    ):
        super().__init__()
        self.upsample = DeconvModule(
            in_channels=in_channels,
            out_channels=embed_dim,
            norm_layer=norm_layer,
            act_layer=act_layer,
        )
        concat_channels = embed_dim * 2
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=concat_channels, out_channels=embed_dim, kernel_size=3, stride=1, padding=1),
            act_layer(),
            nn.Conv2d(in_channels=embed_dim, out_channels=embed_dim, kernel_size=3, stride=1, padding=1),
            act_layer(),
        )

    def forward(self, skip, x):
        upsampled = self.upsample(x)
        merged = torch.cat([skip, upsampled], dim=1)
        return self.conv(merged)
class polarisnet(nn.Module):
    """U-shaped encoder/decoder network built from axial-attention ``Stage``s.

    Pipeline:

    1. ``conv_first``: two 3x3 convs at input resolution.
    2. ``len(depths)`` encoder stages, each halving the resolution.
    3. ``len(depths) - 1`` decoder stages, each doubling the resolution and
       consuming the matching encoder feature as a skip.
    4. ``FinalExpand`` back to input resolution, using the ``conv_first``
       output as its skip.
    5. A 3x3 output conv.

    Args:
        image_size: Input height/width. Per-stage positional-embedding sizes
            are derived from it. It presumably needs to be divisible by
            ``2 ** len(depths)`` for skip shapes to line up; verify.
        in_channels: Input image channels.
        out_channels: Output image channels.
        embed_dim: Width of ``conv_first`` and of the final head.
        depths: Transformer blocks per encoder stage.
        channels: Output channels per encoder stage. Decoder stages assume
            ``channels[i + 1] == 2 * channels[i]``, because the concatenated
            skip + upsampled tensor has ``2 * channels[i]`` channels.
        num_heads: Attention heads in every block.
        drop: ``drop`` rate forwarded to every block.
        drop_path: Maximum stochastic-depth rate. Rates rise linearly from 0
            over all encoder blocks; each decoder stage reuses the rates of
            the encoder stage at the same level.
        act_layer: Activation factory.
        norm_layer: Normalisation factory.
        pos_embed: Whether blocks add axial positional embeddings.
    """

    def __init__(
            self,
            image_size=224,
            in_channels=1,
            out_channels=1,
            embed_dim=64,
            depths=(2, 2, 2, 2),
            channels=(64, 128, 256, 512),
            num_heads=16,
            drop=0.,
            drop_path=0.1,
            act_layer=nn.GELU,
            norm_layer=nn.BatchNorm2d,
            pos_embed=False
    ):
        super().__init__()
        self.num_stages = len(depths)
        self.num_features = channels[-1]
        # NOTE: this is the first stage width, not the ``embed_dim`` argument.
        self.embed_dim = channels[0]
        self.conv_first = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=embed_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            act_layer(),
            nn.Conv2d(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            act_layer(),
        )
        # Separate name so the ``drop_path`` argument is not shadowed.
        block_rates = torch.linspace(0.0, drop_path, sum(depths)).tolist()

        def stage_rates(level):
            """Slice of ``block_rates`` belonging to encoder stage ``level``."""
            return block_rates[sum(depths[:level]):sum(depths[:level + 1])]

        encoder_stages = []
        for index in range(self.num_stages):
            encoder_stages.append(
                Stage(
                    image_size=image_size // (2 ** (index + 1)),
                    depth=depths[index],
                    in_channels=embed_dim if index == 0 else channels[index - 1],
                    out_channels=channels[index],
                    num_heads=num_heads,
                    drop=drop,
                    drop_path=stage_rates(index),
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    type_name="encoder",
                    pos_embed=pos_embed,
                )
            )
        self.encoder_stages = nn.ModuleList(encoder_stages)

        decoder_stages = []
        for index in range(self.num_stages - 1):
            # Encoder level this decoder stage mirrors (deepest first).
            level = self.num_stages - index - 2
            decoder_stages.append(
                Stage(
                    image_size=image_size // (2 ** (level + 1)),
                    depth=depths[level],
                    in_channels=channels[level + 1],
                    out_channels=channels[level],
                    num_heads=num_heads,
                    drop=drop,
                    drop_path=stage_rates(level),
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    type_name="decoder",
                    pos_embed=pos_embed,
                )
            )
        self.decoder_stages = nn.ModuleList(decoder_stages)

        self.norm = norm_layer(self.num_features)
        self.norm_up = norm_layer(self.embed_dim)
        self.up = FinalExpand(
            in_channels=channels[0],
            embed_dim=embed_dim,
            out_channels=embed_dim,
            norm_layer=norm_layer,
            act_layer=act_layer,
        )
        self.output = nn.Conv2d(embed_dim, out_channels, kernel_size=3, padding=1)

    def encoder_forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run the encoder.

        Returns the normalised bottleneck feature and a list with each encoder
        stage's input. Element 0 is the ``conv_first`` output.
        """
        outs = []
        x = self.conv_first(x)
        for stage in self.encoder_stages:
            outs.append(x)
            x = stage(x)
        x = self.norm(x)
        return x, outs

    def decoder_forward(self, x: torch.Tensor, x_downsample: list) -> torch.Tensor:
        """Run the decoder stages, pairing decoder ``i`` with skip ``x_downsample[-1 - i]``."""
        for inx, stage in enumerate(self.decoder_stages):
            x = stage(x, x_downsample[len(x_downsample) - 1 - inx])
        x = self.norm_up(x)
        return x

    def up_x4(self, x: torch.Tensor, x_downsample: list):
        """Upsample to input resolution with the ``conv_first`` skip, then project to the output."""
        x = self.up(x_downsample[0], x)
        x = self.output(x)
        return x

    def forward(self, x):
        x, x_downsample = self.encoder_forward(x)
        x = self.decoder_forward(x, x_downsample)
        x = self.up_x4(x, x_downsample)
        return x
if __name__ == '__main__':
    # Smoke test: push a random batch through the default network and print
    # the output shape. Falls back to CPU when CUDA is unavailable instead of
    # crashing on an unconditional .cuda().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = polarisnet(in_channels=1, embed_dim=64, pos_embed=True).to(device)
    X = torch.randn(5, 1, 224, 224).to(device)
    with torch.no_grad():  # shape check only; no gradients needed
        y = net(X)
    print(y.shape)