from typing import Optional, Sequence, Tuple, Union, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from monai.networks.blocks.dynunet_block import get_padding, get_output_padding
from monai.networks.layers import Pool, Conv
from monai.networks.layers.utils import get_act_layer, get_norm_layer, get_dropout_layer
from monai.utils.misc import ensure_tuple_rep

from medical_diffusion.models.utils.attention_blocks import Attention, zero_module


def save_add(*args):
    """Sum all arguments that are not None; return None if every argument is None."""
    args = [arg for arg in args if arg is not None]
    return sum(args) if len(args) > 0 else None


class SequentialEmb(nn.Sequential):
    """Sequential container whose modules take an additional embedding argument."""

    def forward(self, input, emb):
        for module in self:
            input = module(input, emb)
        return input


class BasicDown(nn.Module):
    def __init__(
        self,
        spatial_dims,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=2,
        learnable_interpolation=True,
        use_res=False,
    ) -> None:
        super().__init__()

        if learnable_interpolation:
            Convolution = Conv[Conv.CONV, spatial_dims]
            self.down_op = Convolution(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=get_padding(kernel_size, stride),
                dilation=1,
                groups=1,
                bias=True,
            )

            if use_res:
                # WARNING: PixelUnshuffle only supports 2D and requires out_channels == 4*in_channels
                self.down_skip = nn.PixelUnshuffle(2)
        else:
            Pooling = Pool['avg', spatial_dims]
            self.down_op = Pooling(
                kernel_size=kernel_size,
                stride=stride,
                padding=get_padding(kernel_size, stride),
            )

    def forward(self, x, emb=None):
        y = self.down_op(x)
        if hasattr(self, 'down_skip'):
            y = y + self.down_skip(x)
        return y


class BasicUp(nn.Module):
    def __init__(
        self,
        spatial_dims,
        in_channels,
        out_channels,
        kernel_size=2,
        stride=2,
        learnable_interpolation=True,
        use_res=False,
    ) -> None:
        super().__init__()
        self.learnable_interpolation = learnable_interpolation
        if learnable_interpolation:
            # TransConvolution = Conv[Conv.CONVTRANS, spatial_dims]
            # padding = get_padding(kernel_size, stride)
            # output_padding = get_output_padding(kernel_size, stride, padding)
            # self.up_op = TransConvolution(
            #     in_channels,
            #     out_channels,
            #     kernel_size=kernel_size,
            #     stride=stride,
            #     padding=padding,
            #     output_padding=output_padding,
            #     groups=1,
            #     bias=True,
            #     dilation=1,
            # )

            # Target spatial size equals what the transposed convolution above would produce.
            self.calc_shape = lambda x: tuple(
                (np.asarray(x) - 1) * np.atleast_1d(stride)
                + np.atleast_1d(kernel_size)
                - 2 * np.atleast_1d(get_padding(kernel_size, stride))
            )
            Convolution = Conv[Conv.CONV, spatial_dims]
            self.up_op = Convolution(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                dilation=1,
                groups=1,
                bias=True,
            )

            if use_res:
                # WARNING: PixelShuffle only supports 2D and requires out_channels == in_channels/4
                self.up_skip = nn.PixelShuffle(2)
        else:
            # Same target size, but upsampled with non-learnable interpolation only.
            self.calc_shape = lambda x: tuple(
                (np.asarray(x) - 1) * np.atleast_1d(stride)
                + np.atleast_1d(kernel_size)
                - 2 * np.atleast_1d(get_padding(kernel_size, stride))
            )

    def forward(self, x, emb=None):
        if self.learnable_interpolation:
            new_size = self.calc_shape(x.shape[2:])
            x_res = F.interpolate(x, size=new_size, mode='nearest-exact')
            y = self.up_op(x_res)
            if hasattr(self, 'up_skip'):
                y = y + self.up_skip(x)
            return y
        else:
            new_size = self.calc_shape(x.shape[2:])
            return F.interpolate(x, size=new_size, mode='nearest-exact')


class BasicBlock(nn.Module):
    """
    A block that consists of Conv-Norm-Drop-Act, similar to blocks.Convolution.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        norm_name: feature normalization type and arguments.
        act_name: activation layer type and arguments.
        dropout: dropout probability.
        zero_conv: zero out the parameters of the convolution.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int] = 1,
        norm_name: Union[Tuple, str, None] = None,
        act_name: Union[Tuple, str, None] = None,
        dropout: Optional[Union[Tuple, str, float]] = None,
        zero_conv: bool = False,
    ):
        super().__init__()
        Convolution = Conv[Conv.CONV, spatial_dims]
        conv = Convolution(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=get_padding(kernel_size, stride),
            dilation=1,
            groups=1,
            bias=True,
        )
        self.conv = zero_module(conv) if zero_conv else conv
        if norm_name is not None:
            self.norm = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)
        if dropout is not None:
            self.drop = get_dropout_layer(name=dropout, dropout_dim=spatial_dims)
        if act_name is not None:
            self.act = get_act_layer(name=act_name)

    def forward(self, inp):
        out = self.conv(inp)
        if hasattr(self, "norm"):
            out = self.norm(out)
        if hasattr(self, "drop"):
            out = self.drop(out)
        if hasattr(self, "act"):
            out = self.act(out)
        return out


class BasicResBlock(nn.Module):
    """
    A block that consists of a BasicBlock (Conv-Norm-Drop-Act) plus a residual skip connection.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        norm_name: feature normalization type and arguments.
        act_name: activation layer type and arguments.
        dropout: dropout probability.
        zero_conv: zero out the parameters of the convolution.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int] = 1,
        norm_name: Union[Tuple, str, None] = None,
        act_name: Union[Tuple, str, None] = None,
        dropout: Optional[Union[Tuple, str, float]] = None,
        zero_conv: bool = False,
    ):
        super().__init__()
        self.basic_block = BasicBlock(
            spatial_dims, in_channels, out_channels, kernel_size, stride,
            norm_name, act_name, dropout, zero_conv,
        )
        # 1x1 convolution to match channels on the skip path; identity when the
        # channel counts already agree.
        Convolution = Conv[Conv.CONV, spatial_dims]
        self.conv_res = Convolution(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=stride,
            padding=get_padding(1, stride),
            dilation=1,
            groups=1,
            bias=True,
        ) if in_channels != out_channels else nn.Identity()

    def forward(self, inp):
        out = self.basic_block(inp)
        residual = self.conv_res(inp)
        out = out + residual
        return out


class UnetBasicBlock(nn.Module):
    """
    A modified version of monai.networks.blocks.UnetBasicBlock with an additional embedding input.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        norm_name: feature normalization type and arguments.
        act_name: activation layer type and arguments.
        dropout: dropout probability.
        emb_channels: number of embedding channels.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int] = 1,
        norm_name: Union[Tuple, str, None] = None,
        act_name: Union[Tuple, str, None] = None,
        dropout: Optional[Union[Tuple, str, float]] = None,
        emb_channels: Optional[int] = None,
        blocks: int = 2,
    ):
        super().__init__()
        # Stack of BasicBlocks: only the first block changes the channel count,
        # and only the last block uses a zero-initialized convolution.
        self.block_seq = nn.ModuleList([
            BasicBlock(
                spatial_dims,
                in_channels if i == 0 else out_channels,
                out_channels,
                kernel_size,
                stride,
                norm_name,
                act_name,
                dropout,
                i == blocks - 1,
            )
            for i in range(blocks)
        ])

        if emb_channels is not None:
            self.local_embedder = nn.Sequential(
                get_act_layer(name=act_name),
                nn.Linear(emb_channels, out_channels),
            )

    def forward(self, x, emb=None):
        # ------------ Embedding ----------
        if emb is not None:
            emb = self.local_embedder(emb)
            b, c, *_ = emb.shape
            sp_dim = x.ndim - 2
            # Reshape to (B, C, 1, ..., 1) so the embedding broadcasts over the spatial dimensions.
            emb = emb.reshape(b, c, *((1,) * sp_dim))
            # scale, shift = emb.chunk(2, dim=1)
            # x = x * (scale + 1) + shift
            # x = x + emb

        # ----------- Convolution ---------
        n_blocks = len(self.block_seq)
        for i, block in enumerate(self.block_seq):
            x = block(x)
            if (emb is not None) and i < n_blocks - 1:  # add embedding after every block except the last
                x = x + emb
        return x
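

# ---------------------------------------------------------------------------
# Minimal usage sketch: wires UnetBasicBlock, BasicDown and BasicUp together on
# a dummy 2D tensor. The channel sizes, the ("GROUP", ...) norm and the
# ("SWISH", {}) activation below are assumed example settings, not values this
# module prescribes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    x = torch.randn(1, 32, 64, 64)   # (B, C, H, W) dummy feature map
    t_emb = torch.randn(1, 128)      # dummy time/condition embedding

    block = UnetBasicBlock(
        spatial_dims=2,
        in_channels=32,
        out_channels=64,
        kernel_size=3,
        stride=1,
        norm_name=("GROUP", {"num_groups": 8}),
        act_name=("SWISH", {}),
        emb_channels=128,
    )
    down = BasicDown(2, 64, 128, kernel_size=3, stride=2, learnable_interpolation=True)
    up = BasicUp(2, 128, 64, kernel_size=2, stride=2, learnable_interpolation=True)

    h = block(x, t_emb)  # -> (1, 64, 64, 64)
    h = down(h)          # -> (1, 128, 32, 32)
    h = up(h)            # -> (1, 64, 64, 64)
    print(h.shape)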