from typing import Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from monai.networks.blocks.dynunet_block import get_padding
from monai.networks.layers import Pool, Conv
from monai.networks.layers.utils import get_act_layer, get_norm_layer, get_dropout_layer
from monai.utils.misc import ensure_tuple_rep
from medical_diffusion.models.utils.attention_blocks import Attention, zero_module
def save_add(*args):
    """Sum all non-None arguments; return None if every argument is None."""
    args = [arg for arg in args if arg is not None]
    return sum(args) if len(args) > 0 else None
class SequentialEmb(nn.Sequential):
    """Like nn.Sequential, but each submodule is called with an additional embedding."""
    def forward(self, input, emb):
for module in self:
input = module(input, emb)
return input
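
# Usage sketch (illustrative; all names below are placeholders): save_add folds
# optional residuals, and SequentialEmb threads a (feature, embedding) pair
# through every submodule.
#   h = save_add(None, skip)               # -> skip
#   seq = SequentialEmb(block_a, block_b)  # each block_* takes (x, emb)
#   y = seq(x, emb)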
class BasicDown(nn.Module):
def __init__(
self,
spatial_dims,
in_channels,
out_channels,
kernel_size=3,
stride=2,
learnable_interpolation=True,
use_res=False
) -> None:
super().__init__()
if learnable_interpolation:
Convolution = Conv[Conv.CONV, spatial_dims]
self.down_op = Convolution(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=get_padding(kernel_size, stride),
dilation=1,
groups=1,
bias=True,
)
if use_res:
                self.down_skip = nn.PixelUnshuffle(2)  # WARNING: Only supports 2D, out_channels == 4*in_channels
else:
Pooling = Pool['avg', spatial_dims]
self.down_op = Pooling(
kernel_size=kernel_size,
stride=stride,
padding=get_padding(kernel_size, stride)
)
def forward(self, x, emb=None):
y = self.down_op(x)
if hasattr(self, 'down_skip'):
y = y+self.down_skip(x)
return y
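
# Example (sketch; placeholder shapes, 2D): with learnable_interpolation=True a
# strided convolution halves the spatial size.
#   down = BasicDown(spatial_dims=2, in_channels=8, out_channels=16)
#   down(torch.randn(1, 8, 32, 32)).shape   # -> (1, 16, 16, 16)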
class BasicUp(nn.Module):
def __init__(
self,
spatial_dims,
in_channels,
out_channels,
kernel_size=2,
stride=2,
learnable_interpolation=True,
use_res=False,
) -> None:
super().__init__()
self.learnable_interpolation = learnable_interpolation
        if learnable_interpolation:
            # A transposed convolution (Conv[Conv.CONVTRANS, spatial_dims]) would be an
            # alternative here; nearest interpolation followed by a convolution is used
            # instead, which avoids checkerboard artifacts.
            Convolution = Conv[Conv.CONV, spatial_dims]
            self.up_op = Convolution(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                dilation=1,
                groups=1,
                bias=True,
            )

            if use_res:
                self.up_skip = nn.PixelShuffle(2)  # WARNING: Only supports 2D, out_channels == in_channels/4

        # Output size of the transposed convolution that (kernel_size, stride)
        # would produce; used as the interpolation target in forward().
        self.calc_shape = lambda x: tuple(
            (np.asarray(x) - 1) * np.atleast_1d(stride)
            + np.atleast_1d(kernel_size)
            - 2 * np.atleast_1d(get_padding(kernel_size, stride))
        )
    def forward(self, x, emb=None):
        new_size = self.calc_shape(x.shape[2:])
        if self.learnable_interpolation:
            x_res = F.interpolate(x, size=new_size, mode='nearest-exact')
            y = self.up_op(x_res)
            if hasattr(self, 'up_skip'):
                y = y + self.up_skip(x)
            return y
        return F.interpolate(x, size=new_size, mode='nearest-exact')
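
# Example (sketch; placeholder shapes, 2D): interpolate to the size a matching
# transposed convolution would produce, then refine with a 3x3 convolution.
#   up = BasicUp(spatial_dims=2, in_channels=16, out_channels=8)
#   up(torch.randn(1, 16, 16, 16)).shape    # -> (1, 8, 32, 32)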
class BasicBlock(nn.Module):
"""
    A block that consists of Conv-Norm-Drop-Act, similar to monai.networks.blocks.Convolution.
Args:
spatial_dims: number of spatial dimensions.
in_channels: number of input channels.
out_channels: number of output channels.
kernel_size: convolution kernel size.
stride: convolution stride.
norm_name: feature normalization type and arguments.
act_name: activation layer type and arguments.
dropout: dropout probability.
zero_conv: zero out the parameters of the convolution.
"""
def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int] = 1,
        norm_name: Union[Tuple, str, None] = None,
act_name: Union[Tuple, str, None] = None,
dropout: Optional[Union[Tuple, str, float]] = None,
zero_conv: bool = False,
):
super().__init__()
Convolution = Conv[Conv.CONV, spatial_dims]
conv = Convolution(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=get_padding(kernel_size, stride),
dilation=1,
groups=1,
bias=True,
)
self.conv = zero_module(conv) if zero_conv else conv
if norm_name is not None:
self.norm = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
if dropout is not None:
self.drop = get_dropout_layer(name=dropout, dropout_dim=spatial_dims)
if act_name is not None:
self.act = get_act_layer(name=act_name)
def forward(self, inp):
out = self.conv(inp)
if hasattr(self, "norm"):
out = self.norm(out)
if hasattr(self, 'drop'):
out = self.drop(out)
if hasattr(self, "act"):
out = self.act(out)
return out
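
# Example (sketch; MONAI layer-name syntax, placeholder values): norm, dropout
# and activation are only attached when the corresponding argument is given.
#   block = BasicBlock(2, 8, 16, kernel_size=3,
#                      norm_name=("GROUP", {"num_groups": 4}),
#                      act_name="SWISH", dropout=0.1)
#   block(torch.randn(1, 8, 32, 32)).shape  # -> (1, 16, 32, 32)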
class BasicResBlock(nn.Module):
"""
    A block that consists of Conv-Norm-Drop-Act plus a residual (skip) connection.
Args:
spatial_dims: number of spatial dimensions.
in_channels: number of input channels.
out_channels: number of output channels.
kernel_size: convolution kernel size.
stride: convolution stride.
norm_name: feature normalization type and arguments.
act_name: activation layer type and arguments.
dropout: dropout probability.
zero_conv: zero out the parameters of the convolution.
"""
def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int] = 1,
        norm_name: Union[Tuple, str, None] = None,
act_name: Union[Tuple, str, None] = None,
dropout: Optional[Union[Tuple, str, float]] = None,
zero_conv: bool = False
):
super().__init__()
self.basic_block = BasicBlock(spatial_dims, in_channels, out_channels, kernel_size, stride, norm_name, act_name, dropout, zero_conv)
        Convolution = Conv[Conv.CONV, spatial_dims]
        # The skip path needs a projection whenever the channel count or the
        # spatial size (stride != 1) changes; nn.Identity would break otherwise.
        needs_projection = in_channels != out_channels or ensure_tuple_rep(stride, spatial_dims) != ensure_tuple_rep(1, spatial_dims)
        self.conv_res = Convolution(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=stride,
            padding=get_padding(1, stride),
            dilation=1,
            groups=1,
            bias=True,
        ) if needs_projection else nn.Identity()
def forward(self, inp):
out = self.basic_block(inp)
residual = self.conv_res(inp)
out = out+residual
return out
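
# Example (sketch; placeholder values): as BasicBlock, plus a residual path;
# a 1x1 convolution projects the input when the channel count changes.
#   res = BasicResBlock(2, 8, 16, kernel_size=3,
#                       norm_name=("GROUP", {"num_groups": 4}), act_name="SWISH")
#   res(torch.randn(1, 8, 32, 32)).shape    # -> (1, 16, 32, 32)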
class UnetBasicBlock(nn.Module):
"""
A modified version of monai.networks.blocks.UnetBasicBlock with additional embedding
Args:
spatial_dims: number of spatial dimensions.
in_channels: number of input channels.
out_channels: number of output channels.
kernel_size: convolution kernel size.
stride: convolution stride.
norm_name: feature normalization type and arguments.
act_name: activation layer type and arguments.
dropout: dropout probability.
emb_channels: Number of embedding channels
"""
def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
kernel_size: Union[Sequence[int], int],
stride: Union[Sequence[int], int]=1,
        norm_name: Optional[Union[Tuple, str]] = None,
        act_name: Optional[Union[Tuple, str]] = None,
        dropout: Optional[Union[Tuple, str, float]] = None,
        emb_channels: Optional[int] = None,
        blocks: int = 2
):
super().__init__()
        self.block_seq = nn.ModuleList([
            BasicBlock(spatial_dims, in_channels if i == 0 else out_channels, out_channels, kernel_size, stride,
                       norm_name, act_name, dropout, zero_conv=(i == blocks - 1))  # zero-init the last conv
            for i in range(blocks)
        ])
if emb_channels is not None:
self.local_embedder = nn.Sequential(
get_act_layer(name=act_name),
nn.Linear(emb_channels, out_channels),
)
    def forward(self, x, emb=None):
        # ------------ Embedding ----------
        if emb is not None:
            emb = self.local_embedder(emb)
            b, c, *_ = emb.shape
            sp_dim = x.ndim - 2
            # Broadcast the embedding over all spatial dimensions.
            emb = emb.reshape(b, c, *((1,) * sp_dim))
            # scale, shift = emb.chunk(2, dim=1)
            # x = x * (scale + 1) + shift
        # ----------- Convolution ---------
        n_blocks = len(self.block_seq)
        for i, block in enumerate(self.block_seq):
            x = block(x)
            if (emb is not None) and i < n_blocks - 1:  # not after the zero-initialized last block
                x = x + emb
        return x
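
# Example (sketch; placeholder values): conditioning on an embedding, e.g. a
# diffusion time-step embedding, which is projected to out_channels and
# broadcast-added between the convolution blocks.
#   blk = UnetBasicBlock(2, 8, 16, kernel_size=3,
#                        norm_name=("GROUP", {"num_groups": 4}),
#                        act_name="SWISH", emb_channels=64)
#   blk(torch.randn(1, 8, 32, 32), emb=torch.randn(1, 64)).shape  # -> (1, 16, 32, 32)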
class UnetResBlock(nn.Module):
"""
A modified version of monai.networks.blocks.UnetResBlock with additional skip connection and embedding
Args:
spatial_dims: number of spatial dimensions.
in_channels: number of input channels.
out_channels: number of output channels.
kernel_size: convolution kernel size.
stride: convolution stride.
norm_name: feature normalization type and arguments.
act_name: activation layer type and arguments.
dropout: dropout probability.
emb_channels: Number of embedding channels
"""
def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
kernel_size: Union[Sequence[int], int],
stride: Union[Sequence[int], int]=1,
        norm_name: Optional[Union[Tuple, str]] = None,
        act_name: Optional[Union[Tuple, str]] = None,
        dropout: Optional[Union[Tuple, str, float]] = None,
        emb_channels: Optional[int] = None,
        blocks: int = 2
):
super().__init__()
        self.block_seq = nn.ModuleList([
            BasicResBlock(spatial_dims, in_channels if i == 0 else out_channels, out_channels, kernel_size, stride,
                          norm_name, act_name, dropout, zero_conv=(i == blocks - 1))  # zero-init the last conv
            for i in range(blocks)
        ])
if emb_channels is not None:
self.local_embedder = nn.Sequential(
get_act_layer(name=act_name),
nn.Linear(emb_channels, out_channels),
)
    def forward(self, x, emb=None):
        # ------------ Embedding ----------
        if emb is not None:
            emb = self.local_embedder(emb)
            b, c, *_ = emb.shape
            sp_dim = x.ndim - 2
            # Broadcast the embedding over all spatial dimensions.
            emb = emb.reshape(b, c, *((1,) * sp_dim))
            # scale, shift = emb.chunk(2, dim=1)
            # x = x * (scale + 1) + shift
        # ----------- Convolution ---------
        n_blocks = len(self.block_seq)
        for i, block in enumerate(self.block_seq):
            x = block(x)
            if (emb is not None) and i < n_blocks - 1:  # not after the zero-initialized last block
                x = x + emb
        return x
class DownBlock(nn.Module):
def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
kernel_size: Union[Sequence[int], int],
stride: Union[Sequence[int], int],
downsample_kernel_size: Union[Sequence[int], int],
norm_name: Union[Tuple, str],
act_name: Union[Tuple, str],
dropout: Optional[Union[Tuple, str, float]] = None,
use_res_block: bool = False,
learnable_interpolation: bool = True,
use_attention: str = 'none',
        emb_channels: Optional[int] = None
    ):
        super().__init__()
        # Downsampling is active whenever the stride differs from 1.
        enable_down = ensure_tuple_rep(stride, spatial_dims) != ensure_tuple_rep(1, spatial_dims)
        down_out_channels = out_channels if learnable_interpolation and enable_down else in_channels
# -------------- Down ----------------------
self.down_op = BasicDown(
spatial_dims,
in_channels,
out_channels,
kernel_size=downsample_kernel_size,
stride=stride,
learnable_interpolation=learnable_interpolation,
use_res=False
) if enable_down else nn.Identity()
# ---------------- Attention -------------
self.attention = Attention(
spatial_dims=spatial_dims,
in_channels=down_out_channels,
out_channels=down_out_channels,
num_heads=8,
ch_per_head=down_out_channels//8,
depth=1,
norm_name=norm_name,
dropout=dropout,
emb_dim=emb_channels,
attention_type=use_attention
)
# -------------- Convolution ----------------------
ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock
self.conv_block = ConvBlock(
spatial_dims,
down_out_channels,
out_channels,
kernel_size=kernel_size,
stride=1,
dropout=dropout,
norm_name=norm_name,
act_name=act_name,
emb_channels=emb_channels
)
def forward(self, x, emb=None):
# ----------- Down ---------
x = self.down_op(x)
# ----------- Attention -------------
if self.attention is not None:
x = self.attention(x, emb)
# ------------- Convolution --------------
x = self.conv_block(x, emb)
return x
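
# Example (sketch; placeholder values): one encoder stage. Assumes the imported
# Attention block treats attention_type='none' as a pass-through.
#   down = DownBlock(2, 8, 16, kernel_size=3, stride=2, downsample_kernel_size=2,
#                    norm_name=("GROUP", {"num_groups": 4}), act_name="SWISH")
#   down(torch.randn(1, 8, 32, 32)).shape   # -> (1, 16, 16, 16)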
class UpBlock(nn.Module):
def __init__(
self,
spatial_dims,
in_channels: int,
out_channels: int,
kernel_size: Union[Sequence[int], int],
stride: Union[Sequence[int], int],
upsample_kernel_size: Union[Sequence[int], int],
norm_name: Union[Tuple, str],
act_name: Union[Tuple, str],
dropout: Optional[Union[Tuple, str, float]] = None,
use_res_block: bool = False,
learnable_interpolation: bool = True,
use_attention: str = 'none',
        emb_channels: Optional[int] = None,
        skip_channels: int = 0
    ):
        super().__init__()
        # Upsampling is active whenever the stride differs from 1.
        enable_up = ensure_tuple_rep(stride, spatial_dims) != ensure_tuple_rep(1, spatial_dims)
skip_out_channels = out_channels if learnable_interpolation and enable_up else in_channels+skip_channels
self.learnable_interpolation = learnable_interpolation
# -------------- Up ----------------------
self.up_op = BasicUp(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=upsample_kernel_size,
stride=stride,
learnable_interpolation=learnable_interpolation,
use_res=False
) if enable_up else nn.Identity()
# ---------------- Attention -------------
self.attention = Attention(
spatial_dims=spatial_dims,
in_channels=skip_out_channels,
out_channels=skip_out_channels,
num_heads=8,
ch_per_head=skip_out_channels//8,
depth=1,
norm_name=norm_name,
dropout=dropout,
emb_dim=emb_channels,
attention_type=use_attention
)
# -------------- Convolution ----------------------
ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock
self.conv_block = ConvBlock(
spatial_dims,
skip_out_channels,
out_channels,
kernel_size=kernel_size,
stride=1,
dropout=dropout,
norm_name=norm_name,
act_name=act_name,
emb_channels=emb_channels
)
def forward(self, x_enc, x_skip=None, emb=None):
# ----------- Up -------------
x = self.up_op(x_enc)
# ----------- Skip Connection ------------
        if x_skip is not None:
            if self.learnable_interpolation:  # channels of x and x_skip match, so a summation is possible
                x = x + x_skip
            else:
                x = torch.cat((x, x_skip), dim=1)
# ----------- Attention -------------
if self.attention is not None:
x = self.attention(x, emb)
# ----------- Convolution ------------
x = self.conv_block(x, emb)
return x
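
if __name__ == "__main__":
    # Smoke test (a minimal sketch with placeholder shapes): check that a
    # DownBlock/UpBlock pair restores the input resolution. Assumes the
    # imported Attention block treats attention_type='none' as a pass-through.
    down = DownBlock(
        2, 8, 16, kernel_size=3, stride=2, downsample_kernel_size=2,
        norm_name=("GROUP", {"num_groups": 4}), act_name="SWISH",
    )
    up = UpBlock(
        2, 16, 8, kernel_size=3, stride=2, upsample_kernel_size=2,
        norm_name=("GROUP", {"num_groups": 4}), act_name="SWISH",
    )
    x = torch.randn(1, 8, 32, 32)
    h = down(x)
    y = up(h, x_skip=x)  # with learnable_interpolation, the skip is summed
    print(h.shape, y.shape)  # expected: (1, 16, 16, 16) (1, 8, 32, 32)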