File size: 655 Bytes
eaefa93 ee90412 eaefa93 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import torch
from torch.nn import functional as F
# Precomputed noise-schedule constants for a DDPM-style diffusion process.
# Every tensor below is 1-D with T entries, indexed by timestep t in [0, T).

# Number of diffusion timesteps.
# NOTE(review): the DDPM paper (Ho et al., 2020) pairs the linear
# 1e-4 .. 0.02 beta range below with T = 1000; confirm which paper / setting
# the value 300 comes from.
T = 300

# Linear variance schedule beta_t, evenly spaced from 1e-4 to 0.02 inclusive.
betas = torch.linspace(1e-4, 0.02, T)
# alpha_t = 1 - beta_t
alphas = 1. - betas
# alpha_bar_t = alpha_0 * alpha_1 * ... * alpha_t (running product over t).
alphas_cumulative_products = torch.cumprod(alphas, dim=0)
# alpha_bar_{t-1}: the cumulative products shifted right by one step.
# The leading slot is padded with 1.0 (the empty product), so entry 0 means
# "nothing has been multiplied in yet".
alphas_cumulative_products_prev = F.pad(
    alphas_cumulative_products[:-1], (1, 0), value=1.0
)
# 1 / sqrt(alpha_t)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
# sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t); their squares sum to 1 per step.
sqrt_alphas_cumulative_products = torch.sqrt(alphas_cumulative_products)
sqrt_one_minus_alphas_cumulative_products = torch.sqrt(1. - alphas_cumulative_products)
# beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).
# Because alpha_bar_{-1} is padded as 1.0, posterior_variance[0] is exactly 0.
posterior_variance = (
    betas * (1. - alphas_cumulative_products_prev) / (1. - alphas_cumulative_products)
)