import torch
import torch.nn as nn

from monai.networks.blocks import UnetOutBlock

from medical_diffusion.models.utils.conv_blocks import BasicBlock, UpBlock, DownBlock, UnetBasicBlock, UnetResBlock, save_add
from medical_diffusion.models.embedders import TimeEmbbeding
from medical_diffusion.models.utils.attention_blocks import SpatialTransformer, LinearTransformer

class UNet(nn.Module):
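    """UNet backbone for (medical) diffusion models.

    Supports 2D/3D inputs via `spatial_dims`, optional time and condition
    embeddings, self-conditioning, per-level attention, and deep supervision
    (auxiliary outputs from intermediate decoder levels).
    """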
    def __init__(self,
                 in_ch=1,
                 out_ch=1,
                 spatial_dims=3,
                 hid_chs=[32, 64, 128, 256],
                 kernel_sizes=[1, 3, 3, 3],
                 strides=[1, 2, 2, 2],
                 downsample_kernel_sizes=None,
                 upsample_kernel_sizes=None,
                 act_name=("SWISH", {}),
                 norm_name=("GROUP", {'num_groups': 32, "affine": True}),
                 time_embedder=TimeEmbbeding,
                 time_embedder_kwargs={},
                 cond_embedder=None,
                 cond_embedder_kwargs={},
                 deep_supervision=True,  # True = all but last layer, 0/False = disable, 1 = only first layer, ...
                 use_res_block=True,
                 estimate_variance=False,
                 use_self_conditioning=False,
                 dropout=0.0,
                 learnable_interpolation=True,
                 use_attention='none',
                 ):
        super().__init__()
        use_attention = use_attention if isinstance(use_attention, list) else [use_attention] * len(strides)
        self.use_self_conditioning = use_self_conditioning
        self.use_res_block = use_res_block
        self.depth = len(strides)
        if downsample_kernel_sizes is None:
            downsample_kernel_sizes = kernel_sizes
        if upsample_kernel_sizes is None:
            upsample_kernel_sizes = strides
        # ------------- Time-Embedder -----------
        if time_embedder is not None:
            self.time_embedder = time_embedder(**time_embedder_kwargs)
            time_emb_dim = self.time_embedder.emb_dim
        else:
            self.time_embedder = None
            time_emb_dim = None
        # ------------- Condition-Embedder -----------
        if cond_embedder is not None:
            self.cond_embedder = cond_embedder(**cond_embedder_kwargs)
        else:
            self.cond_embedder = None

        # ----------- In-Convolution ------------
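        # Self-conditioning concatenates the previous prediction to x_t in forward(),
        # so the in-convolution must accept twice as many input channels.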
        in_ch = in_ch * 2 if self.use_self_conditioning else in_ch
        ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock
        self.inc = ConvBlock(
            spatial_dims=spatial_dims,
            in_channels=in_ch,
            out_channels=hid_chs[0],
            kernel_size=kernel_sizes[0],
            stride=strides[0],
            act_name=act_name,
            norm_name=norm_name,
            emb_channels=time_emb_dim
        )

        # ----------- Encoder ----------------
        self.encoders = nn.ModuleList([
            DownBlock(
                spatial_dims=spatial_dims,
                in_channels=hid_chs[i - 1],
                out_channels=hid_chs[i],
                kernel_size=kernel_sizes[i],
                stride=strides[i],
                downsample_kernel_size=downsample_kernel_sizes[i],
                norm_name=norm_name,
                act_name=act_name,
                dropout=dropout,
                use_res_block=use_res_block,
                learnable_interpolation=learnable_interpolation,
                use_attention=use_attention[i],
                emb_channels=time_emb_dim
            )
            for i in range(1, self.depth)
        ])

        # ------------ Decoder ----------
        self.decoders = nn.ModuleList([
            UpBlock(
                spatial_dims=spatial_dims,
                in_channels=hid_chs[i + 1],
                out_channels=hid_chs[i],
                kernel_size=kernel_sizes[i + 1],
                stride=strides[i + 1],
                upsample_kernel_size=upsample_kernel_sizes[i + 1],
                norm_name=norm_name,
                act_name=act_name,
                dropout=dropout,
                use_res_block=use_res_block,
                learnable_interpolation=learnable_interpolation,
                use_attention=use_attention[i],
                emb_channels=time_emb_dim,
                skip_channels=hid_chs[i]
            )
            for i in range(self.depth - 1)
        ])

        # --------------- Out-Convolution ----------------
        out_ch_hor = out_ch * 2 if estimate_variance else out_ch
        self.outc = UnetOutBlock(spatial_dims, hid_chs[0], out_ch_hor, dropout=None)
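        # Deep-supervision heads: auxiliary output convolutions applied to
        # intermediate feature maps; how many are built follows the
        # deep_supervision argument documented in the signature above.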
        if isinstance(deep_supervision, bool):
            deep_supervision = self.depth - 1 if deep_supervision else 0
        self.outc_ver = nn.ModuleList([
            UnetOutBlock(spatial_dims, hid_chs[i], out_ch, dropout=None)
            for i in range(1, deep_supervision + 1)
        ])

    def forward(self, x_t, t=None, condition=None, self_cond=None):
        # x_t       [B, C, *]
        # t         [B,]
        # condition [B,]
        # self_cond [B, C, *]
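        # returns: y [B, out_ch, *] plus y_ver, a list of auxiliary deep-supervision
        # outputs from intermediate levels (empty if deep supervision is disabled)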
        x = [None for _ in range(len(self.encoders) + 1)]

        # -------- Time Embedding (Global) -----------
        if (t is None) or (self.time_embedder is None):
            time_emb = None
        else:
            time_emb = self.time_embedder(t)  # [B, C]

        # -------- Condition Embedding (Global) -----------
        if (condition is None) or (self.cond_embedder is None):
            cond_emb = None
        else:
            cond_emb = self.cond_embedder(condition)  # [B, C]

        # ----------- Embedding Summation --------
        emb = save_add(time_emb, cond_emb)
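        # NOTE: save_add (project helper) is assumed to sum its arguments while
        # ignoring None, so emb stays None when neither embedding is provided.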
        # ---------- Self-conditioning -----------
        if self.use_self_conditioning:
            self_cond = torch.zeros_like(x_t) if self_cond is None else self_cond
            x_t = torch.cat([x_t, self_cond], dim=1)
        # -------- In-Convolution --------------
        x[0] = self.inc(x_t, emb)

        # --------- Encoder --------------
        for i in range(len(self.encoders)):
            x[i + 1] = self.encoders[i](x[i], emb)

        # -------- Decoder -----------
        for i in range(len(self.decoders), 0, -1):
            x[i - 1] = self.decoders[i - 1](x[i], x[i - 1], emb)

        # --------- Out-Convolution ------------
        y = self.outc(x[0])
        y_ver = [outc_ver_i(x[i + 1]) for i, outc_ver_i in enumerate(self.outc_ver)]

        return y, y_ver


if __name__ == '__main__':
    model = UNet(in_ch=3, use_res_block=False, learnable_interpolation=False)
    input = torch.randn((1, 3, 16, 128, 128))
    time = torch.randn((1,))
    out_hor, out_ver = model(input, time)
    print(out_hor[0].shape)
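
    # Deep supervision is enabled by default, so out_ver also holds the
    # lower-resolution auxiliary predictions from intermediate levels.
    for i, o in enumerate(out_ver):
        print(f"deep supervision output {i}: {o.shape}")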