# Polaris — polaris/model/polarisnet.py
# Uploaded via huggingface_hub by rr-ss (commit 3290550, verified); raw file, 16.9 kB.
import torch
import torch.nn as nn
from operator import itemgetter
from typing import Type, Callable, Tuple, Optional, Set, List, Union
from timm.models.layers import drop_path, trunc_normal_, Mlp, DropPath
from timm.models.efficientnet_blocks import SqueezeExcite, DepthwiseSeparableConv
def exists(val):
    """Return True when *val* is not None."""
    return not (val is None)
def map_el_ind(arr, ind):
    """Extract element *ind* from every item of *arr*, returning a list."""
    return [item[ind] for item in arr]
def sort_and_return_indices(arr):
    """Sort *arr* ascending and return (sorted values, their original indices)."""
    paired = sorted((value, index) for index, value in enumerate(arr))
    values = [value for value, _ in paired]
    order = [index for _, index in paired]
    return values, order
def calculate_permutations(num_dimensions, emb_dim):
    """For each spatial axis of a (batch, ..., emb) tensor, build a permutation that
    moves that axis and the embedding axis to the last two positions.

    num_dimensions: number of spatial axes (2 for images).
    emb_dim: index of the embedding/channel axis in the full tensor.
    """
    total_dimensions = num_dimensions + 2  # batch + spatial axes + embedding
    axial_dims = [d for d in range(1, total_dimensions) if d != emb_dim]
    permutations = []
    for axial_dim in axial_dims:
        tail = [axial_dim, emb_dim]
        # remaining axes keep ascending order at the front
        head = sorted(set(range(total_dimensions)) - set(tail))
        permutations.append(head + tail)
    return permutations
class ChanLayerNorm(nn.Module):
    """Layer norm over the channel axis (dim=1) of NCHW tensors, with learnable
    per-channel scale ``g`` and shift ``b``."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        mu = x.mean(dim=1, keepdim=True)
        # population (unbiased=False) standard deviation over channels
        sigma = x.var(dim=1, unbiased=False, keepdim=True).sqrt()
        normed = (x - mu) / (sigma + self.eps)
        return normed * self.g + self.b
class PreNorm(nn.Module):
    """Pre-normalization wrapper: LayerNorm the input, then apply ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))
class PermuteToFrom(nn.Module):
    """Permute the input, flatten all leading axes, run ``fn`` on the resulting
    (batch', tokens, dim) tensor, then restore the original layout."""

    def __init__(self, permutation, fn):
        super().__init__()
        self.fn = fn
        self.permutation = permutation
        # the inverse of a permutation is its argsort
        self.inv_permutation = sorted(range(len(permutation)), key=permutation.__getitem__)

    def forward(self, x, **kwargs):
        out = x.permute(*self.permutation).contiguous()
        permuted_shape = out.shape
        *_, t, d = permuted_shape
        out = out.reshape(-1, t, d)       # fold leading axes into one batch axis
        out = self.fn(out, **kwargs)
        out = out.reshape(*permuted_shape)
        return out.permute(*self.inv_permutation).contiguous()
class AxialPositionalEmbedding(nn.Module):
    """Additive axial positional embedding: one learnable parameter per spatial axis,
    each broadcastable against the input, summed onto the input in forward().

    Args:
        dim: size of the embedding/channel axis.
        shape: per-axis lengths of the spatial axes (e.g. (H, W)).
        emb_dim_index: index of the embedding axis in the full tensor (default 1, NCHW).
    """

    def __init__(self, dim, shape, emb_dim_index = 1):
        super().__init__()
        total_dimensions = len(shape) + 2  # batch + spatial axes + embedding
        ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]
        self.num_axials = len(shape)
        # BUG FIX (fragility): the original rebound the *shape* parameter inside this
        # loop while zip(shape, ...) was still iterating it; that only worked because
        # zip had captured the original list object. Use a distinct name instead.
        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
            param_shape = [1] * total_dimensions
            param_shape[emb_dim_index] = dim
            param_shape[axial_dim_index] = axial_dim
            parameter = nn.Parameter(torch.randn(*param_shape))
            # keep attribute names param_{i} for state-dict compatibility
            setattr(self, f'param_{i}', parameter)

    def forward(self, x):
        for i in range(self.num_axials):
            x = x + getattr(self, f'param_{i}')
        return x
class SelfAttention(nn.Module):
    """Multi-head attention over (batch, tokens, dim) sequences.

    Self-attention when forward() gets only ``x``; cross-attention when ``kv`` is given.

    Args:
        dim: input/output feature width.
        heads: number of attention heads.
        dim_heads: per-head width; defaults to dim // heads.
        drop: dropout rate applied to the output projection.
    """

    def __init__(self, dim, heads, dim_heads=None, drop=0):
        super().__init__()
        self.dim_heads = (dim // heads) if dim_heads is None else dim_heads
        dim_hidden = self.dim_heads * heads
        self.drop_rate = drop
        self.heads = heads
        self.to_q = nn.Linear(dim, dim_hidden, bias = False)
        self.to_kv = nn.Linear(dim, 2 * dim_hidden, bias = False)
        self.to_out = nn.Linear(dim_hidden, dim)
        # BUG FIX: the original used DropPath(drop) here. DropPath is stochastic depth —
        # it zeroes ENTIRE samples and belongs on residual branches; projection dropout
        # should be element-wise. Identical behavior at drop == 0 (this file's default).
        self.proj_drop = nn.Dropout(drop)

    def forward(self, x, kv = None):
        kv = x if kv is None else kv
        q, k, v = (self.to_q(x), *self.to_kv(kv).chunk(2, dim=-1))
        b, t, d, h, e = *q.shape, self.heads, self.dim_heads
        # fold heads into the batch axis: (b, t, h*e) -> (b*h, t, e)
        merge_heads = lambda x: x.reshape(b, -1, h, e).transpose(1, 2).reshape(b * h, -1, e)
        q, k, v = map(merge_heads, (q, k, v))
        # scaled dot-product attention
        dots = torch.einsum('bie,bje->bij', q, k) * (e ** -0.5)
        dots = dots.softmax(dim=-1)
        out = torch.einsum('bij,bje->bie', dots, v)
        # unfold heads: (b*h, t, e) -> (b, t, h*e)
        out = out.reshape(b, h, -1, e).transpose(1, 2).reshape(b, -1, d)
        out = self.to_out(out)
        out = self.proj_drop(out)
        return out
class AxialTransformerBlock(nn.Module):
    """Axial transformer block for NCHW feature maps: optional axial positional
    embedding, height attention, width attention, and a convolutional FFN — each
    attention/FFN wrapped in a residual connection with stochastic depth.

    Args:
        dim: channel width of the input.
        axial_pos_emb_shape: (H, W) for the axial positional embedding.
        pos_embed: add learnable positional embeddings when True.
        heads: attention heads.
        dim_heads: per-head width (dim // heads when None).
        drop: dropout rate inside attention projections and the FFN.
        drop_path_rate: stochastic-depth rate on the three residual branches.
    """

    def __init__(self,
                 dim,
                 axial_pos_emb_shape,
                 pos_embed,
                 heads = 8,
                 dim_heads = None,
                 drop = 0.,
                 drop_path_rate=0.,
                 ):
        super().__init__()
        dim_index = 1  # channel axis of NCHW
        permutations = calculate_permutations(2, dim_index)
        self.pos_emb = AxialPositionalEmbedding(dim, axial_pos_emb_shape, dim_index) if pos_embed else nn.Identity()
        self.height_attn, self.width_attn = nn.ModuleList([
            PermuteToFrom(permutation, PreNorm(dim, SelfAttention(dim, heads, dim_heads, drop=drop)))
            for permutation in permutations
        ])
        # BUG FIX: the original interleaved DropPath(drop) between the convolutions.
        # DropPath (stochastic depth) zeroes whole samples and is meant for residual
        # branches; in-FFN regularization should be element-wise nn.Dropout. Same
        # module count, so nn.Sequential indices / state-dict keys are unchanged;
        # identical behavior at drop == 0 (this file's default).
        self.FFN = nn.Sequential(
            ChanLayerNorm(dim),
            nn.Conv2d(dim, dim * 4, 3, padding = 1),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Conv2d(dim * 4, dim, 3, padding = 1),
            nn.Dropout(drop),
            ChanLayerNorm(dim),
            nn.Conv2d(dim, dim * 4, 3, padding = 1),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Conv2d(dim * 4, dim, 3, padding = 1),
            nn.Dropout(drop),
        )
        # residual-branch stochastic depth — the correct use of DropPath
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        x = self.pos_emb(x)
        x = x + self.drop_path(self.height_attn(x))
        x = x + self.drop_path(self.width_attn(x))
        x = x + self.drop_path(self.FFN(x))
        return x
def pair(t):
    """Return *t* unchanged if it is already a tuple, else duplicate it as (t, t)."""
    if isinstance(t, tuple):
        return t
    return (t, t)
def _gelu_ignore_parameters(*args, **kwargs) -> nn.Module:
activation = nn.GELU()
return activation
class DoubleConv(nn.Module):
    """MBConv-style residual unit: norm → 1x1 conv → depthwise-separable conv
    (optionally strided) → squeeze-excite → 1x1 conv, added to a 1x1-projected
    (and max-pooled, when downscaling) skip connection.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        downscale: halve the spatial resolution when True.
        act_layer: activation class (nn.GELU is wrapped to ignore timm's kwargs).
        norm_layer: normalization class.
        drop_path: stochastic-depth rate applied to the main path.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        downscale: bool = False,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = nn.BatchNorm2d,
        drop_path: float = 0.,
    ) -> None:
        super().__init__()
        self.drop_path_rate: float = drop_path
        # nn.GELU rejects the kwargs timm passes to act_layer — use the ignoring shim.
        if act_layer == nn.GELU:
            act_layer = _gelu_ignore_parameters
        self.main_path = nn.Sequential(
            norm_layer(in_channels),
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=(1, 1)),
            DepthwiseSeparableConv(
                in_chs=in_channels,
                out_chs=out_channels,
                stride=2 if downscale else 1,
                act_layer=act_layer,
                norm_layer=norm_layer,
                drop_path_rate=drop_path,
            ),
            SqueezeExcite(in_chs=out_channels, rd_ratio=0.25),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(1, 1)),
        )
        # skip path matches the main path's spatial size and channel count
        if downscale:
            self.skip_path = nn.Sequential(
                nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1)),
            )
        else:
            self.skip_path = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.main_path(x)
        if self.drop_path_rate > 0.:
            out = drop_path(out, self.drop_path_rate, self.training)
        return out + self.skip_path(x)
class DeconvModule(nn.Module):
    """Upsampling unit: ConvTranspose2d (stride = scale_factor) → norm → activation.

    Requires kernel_size >= scale_factor with an even difference so the transposed
    convolution produces exactly ``scale_factor`` x larger feature maps.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_layer=nn.BatchNorm2d,
                 act_layer=nn.Mish,
                 kernel_size=4,
                 scale_factor=2):
        super().__init__()
        padding, remainder = divmod(kernel_size - scale_factor, 2)
        assert kernel_size >= scale_factor and remainder == 0,\
            f'kernel_size should be greater than or equal to scale_factor '\
            f'and (kernel_size - scale_factor) should be even numbers, '\
            f'while the kernel size is {kernel_size} and scale_factor is '\
            f'{scale_factor}.'
        # NOTE: attribute name (with its historical typo) kept for state-dict compatibility
        self.deconv_upsamping = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=scale_factor,
                padding=padding),
            norm_layer(out_channels),
            act_layer(),
        )

    def forward(self, x):
        return self.deconv_upsamping(x)
class Stage(nn.Module):
    """One encoder or decoder stage.

    Encoder: DoubleConv (downscaling) followed by ``depth`` AxialTransformerBlocks.
    Decoder: DeconvModule upsample, concat with the skip tensor, DoubleConv, then
    ``depth`` AxialTransformerBlocks.

    Args:
        image_size: spatial size the transformer blocks run at.
        depth: number of axial transformer blocks.
        in_channels / out_channels: stage channel widths.
        type_name: "encoder" or "decoder".
        pos_embed: enable axial positional embeddings inside the blocks.
        num_heads, drop, drop_path, act_layer, norm_layer: forwarded to submodules;
            ``drop_path`` must be a list with one rate per block (rate [0] also
            regularizes the DoubleConv).
    """

    def __init__(self,
                 image_size: int,
                 depth: int,
                 in_channels: int,
                 out_channels: int,
                 type_name: str,
                 pos_embed: bool,
                 num_heads: int = 32,
                 drop: float = 0.,
                 drop_path: Union[List[float], float] = 0.,
                 act_layer: Type[nn.Module] = nn.GELU,
                 norm_layer: Type[nn.Module] = nn.BatchNorm2d,
                 ):
        super().__init__()
        self.type_name = type_name
        # module creation order (upsample → conv → blocks) matches the original,
        # preserving the RNG stream for parameter initialization
        if type_name == "decoder":
            self.upsample = DeconvModule(
                in_channels=in_channels,
                out_channels=out_channels,
                norm_layer=norm_layer,
                act_layer=act_layer,
            )
        if type_name in ("encoder", "decoder"):
            self.conv = DoubleConv(
                in_channels=in_channels,
                out_channels=out_channels,
                downscale=(type_name == "encoder"),
                act_layer=act_layer,
                norm_layer=norm_layer,
                drop_path=drop_path[0],
            )
            self.blocks = nn.Sequential(*(
                AxialTransformerBlock(
                    dim=out_channels,
                    axial_pos_emb_shape=pair(image_size),
                    heads=num_heads,
                    drop=drop,
                    drop_path_rate=drop_path[i],
                    dim_heads=None,
                    pos_embed=pos_embed,
                )
                for i in range(depth)
            ))

    def forward(self, x, skip=None):
        if self.type_name == "encoder":
            return self.blocks(self.conv(x))
        if self.type_name == "decoder":
            x = self.upsample(x)
            # channel-wise fusion with the encoder skip tensor; works because
            # in_channels == 2 * out_channels in this architecture's channel schedule
            x = self.conv(torch.cat([skip, x], dim=1))
            return self.blocks(x)
        # unknown type_name: pass through unchanged (matches original behavior)
        return x
class FinalExpand(nn.Module):
    """Final decoder step: upsample to full resolution, concatenate with the stem
    skip tensor, and refine with two 3x3 convolutions.

    NOTE(review): ``out_channels`` is accepted but never used — the output width is
    ``embed_dim``; kept in the signature for caller compatibility.
    """

    def __init__(
        self,
        in_channels,
        embed_dim,
        out_channels,
        norm_layer,
        act_layer,
    ):
        super().__init__()
        self.upsample = DeconvModule(
            in_channels=in_channels,
            out_channels=embed_dim,
            norm_layer=norm_layer,
            act_layer=act_layer,
        )
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=embed_dim * 2, out_channels=embed_dim, kernel_size=3, stride=1, padding=1),
            act_layer(),
            nn.Conv2d(in_channels=embed_dim, out_channels=embed_dim, kernel_size=3, stride=1, padding=1),
            act_layer(),
        )

    def forward(self, skip, x):
        upsampled = self.upsample(x)
        fused = torch.cat([skip, upsampled], dim=1)
        return self.conv(fused)
class polarisnet(nn.Module):
    """U-Net-shaped network: a conv stem, encoder Stages that halve resolution while
    widening channels, decoder Stages that upsample and fuse encoder skips, a
    FinalExpand back to full resolution, and a 3x3 conv output head.

    Args:
        image_size: square input resolution; must be divisible by 2**len(depths).
        in_channels: input image channels.
        out_channels: channels of the final prediction.
        embed_dim: stem width (also the width after FinalExpand).
        depths: number of AxialTransformerBlocks per stage.
        channels: per-encoder-stage channel widths (the decoder mirrors them).
            NOTE(review): decoder fusion assumes channels roughly double per stage
            (channels[i] == 2 * channels[i-1]) — confirm before changing.
        num_heads: attention heads in every axial attention.
        drop: dropout rate inside attention/FFN.
        drop_path: maximum stochastic-depth rate, linearly scheduled over all blocks.
        act_layer / norm_layer: activation and normalization classes.
        pos_embed: enable axial positional embeddings in the transformer blocks.
    """
    def __init__(
        self,
        image_size=224,
        in_channels=1,
        out_channels=1,
        embed_dim=64,
        depths=[2,2,2,2],
        channels=[64,128,256,512],
        num_heads = 16,
        drop=0.,
        drop_path=0.1,
        act_layer=nn.GELU,
        norm_layer=nn.BatchNorm2d,
        pos_embed=False
    ):
        super(polarisnet, self).__init__()
        self.num_stages = len(depths)
        self.num_features = channels[-1]  # bottleneck width
        self.embed_dim = channels[0]
        # Stem: two 3x3 convs at full resolution, in_channels -> embed_dim.
        self.conv_first = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=embed_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            act_layer(),
            nn.Conv2d(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            act_layer(),
        )
        # Linear stochastic-depth schedule: one rate per transformer block, 0 .. drop_path.
        drop_path = torch.linspace(0.0, drop_path, sum(depths)).tolist()
        encoder_stages = []
        for index in range(self.num_stages):
            encoder_stages.append(
                Stage(
                    # each encoder stage downsamples once, so stage `index` runs at
                    # image_size / 2**(index + 1)
                    image_size=image_size//(pow(2,1+index)),
                    depth=depths[index],
                    in_channels=embed_dim if index == 0 else channels[index - 1],
                    out_channels=channels[index],
                    num_heads=num_heads,
                    drop=drop,
                    # this stage's contiguous slice of the per-block drop-path schedule
                    drop_path=drop_path[sum(depths[:index]):sum(depths[:index + 1])],
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    type_name = "encoder",
                    pos_embed=pos_embed
                )
            )
        self.encoder_stages = nn.ModuleList(encoder_stages)
        decoder_stages = []
        # Decoder mirrors the encoder deepest-first; one fewer stage because the final
        # upsample back to full resolution is handled by FinalExpand below.
        for index in range(self.num_stages-1):
            decoder_stages.append(
                Stage(
                    image_size=image_size//(pow(2,self.num_stages-index-1)),
                    depth=depths[self.num_stages - index - 2],
                    in_channels=channels[self.num_stages - index - 1],
                    out_channels=channels[self.num_stages - index - 2],
                    num_heads=num_heads,
                    drop=drop,
                    # reuse the drop-path slice of the mirrored encoder stage
                    drop_path=drop_path[sum(depths[:(self.num_stages-2-index)]):sum(depths[:(self.num_stages-2-index) + 1])],
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    type_name = "decoder",
                    pos_embed=pos_embed
                )
            )
        self.decoder_stages = nn.ModuleList(decoder_stages)
        self.norm = norm_layer(self.num_features)  # applied after the encoder (bottleneck)
        self.norm_up= norm_layer(self.embed_dim)   # applied after the decoder
        # NOTE(review): FinalExpand's out_channels argument appears unused there; the
        # fused output width is embed_dim — verify against FinalExpand.
        self.up = FinalExpand(
            in_channels=channels[0],
            embed_dim=embed_dim,
            out_channels=embed_dim,
            norm_layer=norm_layer,
            act_layer=act_layer
        )
        self.output = nn.Conv2d(embed_dim, out_channels, kernel_size=3, padding=1)

    def encoder_forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run stem + encoder; return (bottleneck features, skip tensors).

        Skips are collected BEFORE each stage runs: outs[0] is the stem output at
        full resolution and outs[i] is the input to encoder stage i.
        """
        outs = []
        x = self.conv_first(x)
        for stage in self.encoder_stages:
            outs.append(x)
            x = stage(x)
        x = self.norm(x)
        return x, outs

    def decoder_forward(self, x: torch.Tensor, x_downsample: list) -> torch.Tensor:
        """Run decoder stages, consuming skips deepest-first."""
        for inx, stage in enumerate(self.decoder_stages):
            # decoder stage inx pairs with the input of encoder stage (last - inx)
            x = stage(x, x_downsample[len(x_downsample)-1-inx])
        x = self.norm_up(x)
        return x

    def up_x4(self, x: torch.Tensor, x_downsample: list) -> torch.Tensor:
        """Fuse with the full-resolution stem features and project to out_channels."""
        x = self.up(x_downsample[0],x)
        x = self.output(x)
        return x

    def forward(self, x):
        """Full pass: encoder -> decoder (with skips) -> final expansion -> head."""
        x, x_downsample = self.encoder_forward(x)
        x = self.decoder_forward(x,x_downsample)
        x = self.up_x4(x,x_downsample)
        return x
if __name__ == '__main__':
    # Smoke test: push one random batch through the network on the GPU and
    # print the output shape (expected: torch.Size([5, 1, 224, 224])).
    model = polarisnet(in_channels=1, embed_dim=64, pos_embed=True).cuda()
    dummy = torch.randn(5, 1, 224, 224).cuda()
    out = model(dummy)
    print(out.shape)