import torch
import torch.nn as nn

from monai.networks.blocks import UnetOutBlock

from medical_diffusion.models.utils.conv_blocks import BasicBlock, UpBlock, DownBlock, UnetBasicBlock, UnetResBlock, save_add
from medical_diffusion.models.embedders import TimeEmbbeding
from medical_diffusion.models.utils.attention_blocks import SpatialTransformer, LinearTransformer
class UNet(nn.Module):

    def __init__(self,
            in_ch=1,
            out_ch=1,
            spatial_dims=3,
            hid_chs=[32, 64, 128, 256],
            kernel_sizes=[1, 3, 3, 3],
            strides=[1, 2, 2, 2],
            downsample_kernel_sizes=None,
            upsample_kernel_sizes=None,
            act_name=("SWISH", {}),
            norm_name=("GROUP", {'num_groups': 32, "affine": True}),
            time_embedder=TimeEmbbeding,
            time_embedder_kwargs={},
            cond_embedder=None,
            cond_embedder_kwargs={},
            deep_supervision=True,  # True = all but the last layer, 0/False = disable, 1 = only the first layer, ...
            use_res_block=True,
            estimate_variance=False,
            use_self_conditioning=False,
            dropout=0.0,
            learnable_interpolation=True,
            use_attention='none',
        ):
        super().__init__()
        use_attention = use_attention if isinstance(use_attention, list) else [use_attention] * len(strides)
        self.use_self_conditioning = use_self_conditioning
        self.use_res_block = use_res_block
        self.depth = len(strides)

        # Per-level defaults: downsampling reuses the conv kernel sizes,
        # upsampling reuses the strides.
        if downsample_kernel_sizes is None:
            downsample_kernel_sizes = kernel_sizes
        if upsample_kernel_sizes is None:
            upsample_kernel_sizes = strides
        # ------------- Time-Embedder -----------
        if time_embedder is not None:
            self.time_embedder = time_embedder(**time_embedder_kwargs)
            time_emb_dim = self.time_embedder.emb_dim
        else:
            self.time_embedder = None
            time_emb_dim = None  # blocks receive no embedding channels

        # ------------- Condition-Embedder -----------
        if cond_embedder is not None:
            self.cond_embedder = cond_embedder(**cond_embedder_kwargs)
        else:
            self.cond_embedder = None
        # ----------- In-Convolution ------------
        in_ch = in_ch * 2 if self.use_self_conditioning else in_ch  # self-conditioning concatenates along channels
        ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock
        self.inc = ConvBlock(
            spatial_dims=spatial_dims,
            in_channels=in_ch,
            out_channels=hid_chs[0],
            kernel_size=kernel_sizes[0],
            stride=strides[0],
            act_name=act_name,
            norm_name=norm_name,
            emb_channels=time_emb_dim,
        )
        # ----------- Encoder ----------------
        self.encoders = nn.ModuleList([
            DownBlock(
                spatial_dims=spatial_dims,
                in_channels=hid_chs[i - 1],
                out_channels=hid_chs[i],
                kernel_size=kernel_sizes[i],
                stride=strides[i],
                downsample_kernel_size=downsample_kernel_sizes[i],
                norm_name=norm_name,
                act_name=act_name,
                dropout=dropout,
                use_res_block=use_res_block,
                learnable_interpolation=learnable_interpolation,
                use_attention=use_attention[i],
                emb_channels=time_emb_dim,
            )
            for i in range(1, self.depth)
        ])
        # ------------ Decoder ----------
        self.decoders = nn.ModuleList([
            UpBlock(
                spatial_dims=spatial_dims,
                in_channels=hid_chs[i + 1],
                out_channels=hid_chs[i],
                kernel_size=kernel_sizes[i + 1],
                stride=strides[i + 1],
                upsample_kernel_size=upsample_kernel_sizes[i + 1],
                norm_name=norm_name,
                act_name=act_name,
                dropout=dropout,
                use_res_block=use_res_block,
                learnable_interpolation=learnable_interpolation,
                use_attention=use_attention[i],
                emb_channels=time_emb_dim,
                skip_channels=hid_chs[i],  # channels of the encoder skip connection
            )
            for i in range(self.depth - 1)
        ])
        # --------------- Out-Convolution ----------------
        out_ch_hor = out_ch * 2 if estimate_variance else out_ch  # extra channels for the variance estimate
        self.outc = UnetOutBlock(spatial_dims, hid_chs[0], out_ch_hor, dropout=None)

        # Deep supervision: one additional output head per intermediate decoder level.
        if isinstance(deep_supervision, bool):
            deep_supervision = self.depth - 1 if deep_supervision else 0
        self.outc_ver = nn.ModuleList([
            UnetOutBlock(spatial_dims, hid_chs[i], out_ch, dropout=None)
            for i in range(1, deep_supervision + 1)
        ])
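
    # Usage note: `spatial_dims` is forwarded to every block, so 2D inputs of
    # shape [B, C, H, W] should also work with spatial_dims=2, assuming the
    # underlying conv/attention blocks support it (not exercised in this file).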
    def forward(self, x_t, t=None, condition=None, self_cond=None):
        # x_t       [B, C, *]
        # t         [B,]
        # condition [B,]
        # self_cond [B, C, *]
        x = [None for _ in range(len(self.encoders) + 1)]

        # -------- Time Embedding (Global) -----------
        if (t is None) or (self.time_embedder is None):
            time_emb = None
        else:
            time_emb = self.time_embedder(t)  # [B, C]

        # -------- Condition Embedding (Global) -----------
        if (condition is None) or (self.cond_embedder is None):
            cond_emb = None
        else:
            cond_emb = self.cond_embedder(condition)  # [B, C]

        # ----------- Embedding Summation --------
        emb = save_add(time_emb, cond_emb)  # None-safe addition

        # ---------- Self-Conditioning -----------
        if self.use_self_conditioning:
            self_cond = torch.zeros_like(x_t) if self_cond is None else self_cond
            x_t = torch.cat([x_t, self_cond], dim=1)

        # -------- In-Convolution --------------
        x[0] = self.inc(x_t, emb)

        # --------- Encoder --------------
        for i in range(len(self.encoders)):
            x[i + 1] = self.encoders[i](x[i], emb)

        # -------- Decoder -----------
        for i in range(len(self.decoders), 0, -1):
            x[i - 1] = self.decoders[i - 1](x[i], x[i - 1], emb)  # upsampled features + encoder skip

        # --------- Out-Convolution ------------
        y = self.outc(x[0])
        y_ver = [outc_ver_i(x[i + 1]) for i, outc_ver_i in enumerate(self.outc_ver)]

        return y, y_ver
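
# A minimal sketch (not part of the original module) of how the horizontal and
# vertical (deep-supervision) outputs could be combined into a training loss.
# `target`, the 0.5**k weighting, and nearest-neighbour upsampling are
# illustrative assumptions, not the repository's training code.
def deep_supervision_loss(y, y_ver, target, loss_fn=nn.functional.mse_loss):
    loss = loss_fn(y, target)  # full-resolution prediction vs. target
    for k, y_k in enumerate(y_ver):
        # Upsample each coarser prediction to the target resolution before comparing.
        y_k = nn.functional.interpolate(y_k, size=target.shape[2:])
        loss = loss + 0.5 ** (k + 1) * loss_fn(y_k, target)
    return loss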
if __name__ == '__main__':
    model = UNet(in_ch=3, use_res_block=False, learnable_interpolation=False)
    x = torch.randn((1, 3, 16, 128, 128))  # [B, C, D, H, W] for spatial_dims=3
    time = torch.randn((1,))
    out_hor, out_ver = model(x, time)
    print(out_hor[0].shape)
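    # Deep supervision is on by default, so the model also returns one coarser
    # prediction per intermediate decoder level; print their shapes as well.
    print([o.shape for o in out_ver])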