import torch

from medical_diffusion.models.embedders import LabelEmbedder
from medical_diffusion.models.estimators import UNet

# Conditioning: embed a class label (2 classes) into a 64-dim vector.
cond_embedder = LabelEmbedder
cond_embedder_kwargs = {
    'emb_dim': 64,
    'num_classes': 2,
}

# Noise estimator: a 2D UNet with four resolution levels and linear attention.
noise_estimator = UNet
noise_estimator_kwargs = {
    'in_ch': 3,
    'out_ch': 3,
    'spatial_dims': 2,
    'hid_chs': [32, 64, 128, 256],
    'kernel_sizes': [1, 3, 3, 3],
    'strides': [1, 2, 2, 2],
    # Alternative kernel/stride configurations:
    # 'kernel_sizes': [(1, 3, 3), (1, 3, 3), (1, 3, 3), 3, 3],
    # 'strides': [1, (1, 2, 2), (1, 2, 2), 2, 2],
    # 'kernel_sizes': [3, 3, 3, 3, 3],
    # 'strides': [1, 2, 2, 2, 2],
    'cond_embedder': cond_embedder,
    'cond_embedder_kwargs': cond_embedder_kwargs,
    'use_attention': 'linear',  # ['none', 'spatial', 'spatial', 'spatial', 'linear'],
}

model = UNet(**noise_estimator_kwargs)
# print(model)

# Dummy forward pass: one 3-channel 256x256 image, a random timestep, and class label 0.
input = torch.randn((1, 3, 256, 256))
time = torch.randn([1, ])
cond = torch.tensor([0, ])
out_hor, out_ver = model(input, time, cond)
# print(out_hor)
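
# A minimal, hedged sanity check (not part of the original snippet): the two
# return values may be tensors, lists of per-level deep-supervision outputs,
# or None, depending on the UNet configuration, so inspect them defensively.
for name, out in [('out_hor', out_hor), ('out_ver', out_ver)]:
    if out is None:
        print(name, None)
    elif torch.is_tensor(out):
        print(name, tuple(out.shape))
    else:  # assumed: an iterable of tensors
        print(name, [tuple(o.shape) for o in out])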