File size: 429 Bytes
f85e212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

from medical_diffusion.external.stable_diffusion.unet_openai import UNetModel
from medical_diffusion.models.embedders import LabelEmbedder

import torch 


# Smoke test: build the UNetModel noise estimator with its default arguments,
# print its architecture, and run a single forward pass on random inputs.

noise_estimator = UNetModel
noise_estimator_kwargs = {}  # empty -> every constructor argument takes its default


model = noise_estimator(**noise_estimator_kwargs)
print(model)

# `x` / `t` rather than `input` / `time`: avoids shadowing the builtin
# `input()` and the stdlib `time` module name at module scope.
x = torch.randn((1, 4, 32, 32))  # presumably (batch, channels, H, W) — TODO confirm against the default in_channels
t = torch.randn([1])  # one random float timestep; can be negative — TODO confirm UNetModel accepts float timesteps
# Unconditional pass. A label tensor such as torch.tensor([0]) was the
# alternative here; presumably that needs a conditioning embedder configured
# in noise_estimator_kwargs (cf. the LabelEmbedder import) — TODO confirm.
cond = None

# Inference only: no autograd graph is needed to inspect the output.
with torch.no_grad():
    # assumes forward returns a 2-tuple (main output, auxiliary outputs)
    # — TODO confirm against UNetModel.forward
    out_hor, out_ver = model(x, t, cond)
print(out_hor)