# Snapshot of a Hugging Face Spaces file (reported status: runtime error).
# File size: 1,026 bytes; commit f85e212.
from medical_diffusion.models.estimators import UNet
from medical_diffusion.models.embedders import LabelEmbedder
import torch
# Smoke test: build a class-conditional 2D UNet noise estimator and run one
# forward pass on a random image to verify shapes/wiring.

# Conditioning: embed an integer class label (2 classes) into a 64-dim vector.
cond_embedder = LabelEmbedder
cond_embedder_kwargs = {
    'emb_dim': 64,
    'num_classes': 2,
}

noise_estimator = UNet
noise_estimator_kwargs = {
    'in_ch': 3,
    'out_ch': 3,
    'spatial_dims': 2,
    'hid_chs': [32, 64, 128, 256],
    'kernel_sizes': [1, 3, 3, 3],
    'strides': [1, 2, 2, 2],
    # Alternative 3D / deeper configurations kept for reference:
    # 'kernel_sizes': [(1,3,3), (1,3,3), (1,3,3), 3, 3],
    # 'strides':      [1, (1,2,2), (1,2,2), 2, 2],
    # 'kernel_sizes': [3, 3, 3, 3, 3],
    # 'strides':      [1, 2, 2, 2, 2],
    'cond_embedder': cond_embedder,
    'cond_embedder_kwargs': cond_embedder_kwargs,
    'use_attention': 'linear',  # ['none', 'spatial', 'spatial', 'spatial', 'linear'],
}

# Fix: instantiate via the configured `noise_estimator` variable instead of
# hard-coding UNet, so swapping the estimator class above actually takes effect.
model = noise_estimator(**noise_estimator_kwargs)
# print(model)

# Renamed from `input`/`time` to avoid shadowing the builtin and stdlib module.
x = torch.randn((1, 3, 256, 256))   # one 3-channel 256x256 image
timestep = torch.randn([1])         # diffusion timestep (one per batch element)
cond = torch.tensor([0])            # class label in [0, num_classes)

# Model returns two outputs — presumably main and auxiliary/vertical heads;
# TODO(review): confirm against the UNet definition.
out_hor, out_ver = model(x, timestep, cond)
# print(out_hor)