File size: 486 Bytes
f85e212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23

import torch 
from medical_diffusion.external.diffusers.vae import VQModel, VQVAEWrapper, VAEWrapper


# model = AutoencoderKL(in_channels=3, out_channels=3)

input = torch.randn((1, 3, 128, 128)) # [B, C, H, W]

# model = VQModel(in_channels=3, out_channels=3)
# output = model(input, sample_posterior=True)
# print(output)

model = VQVAEWrapper(in_ch=3, out_ch=3)
output = model(input)
print(output)




# model = VAEWrapper(in_ch=3, out_ch=3)
# output = model(input)
# print(output)