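# NOTE(review): "Spaces: Runtime error" -- this looks like a status banner captured together with the file, not part of the script; confirm and remove.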
# Manual smoke test for GaussianNoiseScheduler: build x_t from a random batch
# at two timesteps, run one estimate_x_t_prior_from_x_T call on it with a random
# stand-in prediction, and print the result.
from pathlib import Path

import torch
from torchvision.utils import save_image  # used only by the disabled save_image call at the end

from medical_diffusion.models.noise_schedulers import GaussianNoiseScheduler

# NOTE(review): `device` is referenced only by the disabled `scheduler.to(device)`
# line below; every tensor in this script is created without a device argument.
device = torch.device('cuda')
scheduler = GaussianNoiseScheduler()
# scheduler.to(device)

# Target directory for the optional image dump at the end of the script.
path_out = Path.cwd()/'results/test'

# Debug aid: inspect one of the scheduler's precomputed coefficients.
# print(scheduler.posterior_mean_coef1)

# Fixed seed so the random inputs below are reproducible between runs.
torch.manual_seed(0)

# Batch of 2 inputs with shape (2, 3, 64, 64); the all-ones variant is kept as
# a deterministic alternative.
# x_0 = torch.ones((2, 3, 64, 64))
x_0 = torch.rand((2, 3, 64, 64))
noise = torch.randn_like(x_0)

# One timestep per batch element.
# NOTE(review): 999 is presumably the last step of a 1000-step schedule --
# confirm against the scheduler's configuration.
t = torch.tensor([0, 999])
x_t = scheduler.estimate_x_t(x_0=x_0, t=t, x_T=noise)

# Disabled sanity check for t=0.
# NOTE(review): it passes `noise=` while the live call above passes `x_T=`, so
# the keyword is probably stale -- verify before re-enabling.
# x_0_pred = scheduler.estimate_x_t(x_0=x_0, t=torch.full_like(t, 0) , noise=noise)
# assert (x_0_pred == x_0).all(), "For t=0, function should return x_0"

# Disabled alternative path: sample via the scheduler, then estimate x_0 back.
# x_t, noise, t = scheduler.sample(x_0)
# print(x_t)
# x_0 = scheduler.estimate_x_0(x_t, noise, t)
# print(x_0)
# print(x_0.shape)

# Random tensor passed as the third argument of the call below (presumably
# standing in for a model prediction -- confirm). Clipping of x_0 is turned off,
# and the second return value is discarded.
pred = torch.randn_like(x_t)
x_t_prior, _ = scheduler.estimate_x_t_prior_from_x_T(x_t, t, pred, clip_x0=False)
print(x_t_prior)
# save_image(x_t_prior, path_out/'test2.png')