# medfusion-app / tests/models/test_unet.py
# Last commit by mueller-franzes: "init" (f85e212)
from medical_diffusion.models.estimators import UNet
from medical_diffusion.models.embedders import LabelEmbedder
import torch
# Smoke test for the UNet noise estimator with a class-label conditioning
# embedder. The configuration lives in module-level dicts; the forward pass
# runs inside ``test_unet_forward`` so pytest collects and reports it as a
# test instead of executing it as a side effect of importing this module.

cond_embedder = LabelEmbedder
cond_embedder_kwargs = {
    'emb_dim': 64,
    'num_classes': 2,
}

noise_estimator = UNet
noise_estimator_kwargs = {
    'in_ch': 3,
    'out_ch': 3,
    'spatial_dims': 2,
    'hid_chs': [32, 64, 128, 256],
    'kernel_sizes': [1, 3, 3, 3],
    'strides': [1, 2, 2, 2],
    # Alternative settings kept for reference. They have 5 levels, so they
    # need a matching 5-entry 'hid_chs':
    # 'kernel_sizes':[(1,3,3), (1,3,3), (1,3,3), 3, 3],
    # 'strides':[ 1, (1,2,2), (1,2,2), 2, 2],
    # 'kernel_sizes':[3, 3, 3, 3, 3],
    # 'strides': [1, 2, 2, 2, 2],
    'cond_embedder': cond_embedder,
    'cond_embedder_kwargs': cond_embedder_kwargs,
    # A per-level list is also possible, e.g.
    # ['none', 'spatial', 'spatial', 'spatial', 'linear'] (5 levels).
    'use_attention': 'linear',
}


def test_unet_forward():
    """Run one conditioned forward pass of the UNet and check its main output.

    Builds the estimator from ``noise_estimator_kwargs``. It is fed a single
    3-channel 256x256 image, a random time value and class label 0. The test
    then asserts that the first returned value has the same shape as the
    input (the config sets ``in_ch == out_ch``).
    """
    # Seed so the random weights and inputs are reproducible across runs.
    torch.manual_seed(0)
    model = noise_estimator(**noise_estimator_kwargs)

    x = torch.randn((1, 3, 256, 256))
    t = torch.randn([1])
    cond = torch.tensor([0])

    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        out_hor, out_ver = model(x, t, cond)

    # NOTE(review): assumes the first output is the full-resolution
    # prediction; with in_ch == out_ch it should match the input's shape.
    assert out_hor.shape == x.shape


if __name__ == "__main__":
    test_unet_forward()