# medfusion-app/tests/models/test_vae3d.py
# Author: mueller-franzes — commit f85e212 ("init"); original size 582 bytes.
import torch
from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN
def test_vqvae_3d_step():
    """Smoke-test a single VQVAE ``_step`` call on a random 3D volume.

    Builds a 3-in/3-out-channel VQVAE configured for volumetric data
    (``spatial_dims=3``) with deep supervision, then runs ``_step`` once on
    one random batch. The test passes if ``_step`` completes without raising;
    its return value is printed for manual inspection.
    """
    # Fixed seed so the random input (and therefore the printed loss) is
    # reproducible between runs.
    torch.manual_seed(0)

    # 5-D input for 3D convolutions: [B, C, D, H, W].
    x = torch.randn((1, 3, 16, 128, 128))

    model = VQVAE(
        in_channels=3,
        out_channels=3,
        spatial_dims=3,
        emb_channels=1,
        deep_supervision=True,
    )

    # Alternative: plain forward pass instead of the training step.
    # output = model(x)
    # print(output)

    # NOTE(review): the positional arguments after the batch dict
    # (1, 'train', 1, 1) are carried over unchanged from the original script;
    # their meaning is defined by VQVAE._step — confirm against its signature.
    loss = model._step({'source': x}, 1, 'train', 1, 1)
    print(loss)


# Disabled VQGAN variant of the same check (kept from the original script):
# model = VQGAN(in_channels=3, out_channels=3, spatial_dims=3, emb_channels=1, deep_supervision=True)
# # output = model(x)
# # print(output)
# loss = model._step({'source': x}, 1, 'train', 1, 1)
# print(loss)


if __name__ == "__main__":
    # Keep the file runnable as a plain script, as it was originally.
    test_vqvae_3d_step()