import torch | |
def get_alphacumprod(sigma): | |
return 1 / ((sigma * sigma) + 1) | |
def add_noise(src_latent, noise, sigma): | |
alphas_cumprod = get_alphacumprod(sigma) | |
sqrt_alpha_prod = alphas_cumprod ** 0.5 | |
sqrt_alpha_prod = sqrt_alpha_prod.flatten() | |
while len(sqrt_alpha_prod.shape) < len(src_latent.shape): | |
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) | |
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod) ** 0.5 | |
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() | |
while len(sqrt_one_minus_alpha_prod.shape) < len(src_latent.shape): | |
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) | |
noisy_samples = sqrt_alpha_prod * src_latent + sqrt_one_minus_alpha_prod * noise | |
return noisy_samples | |
def add_noise_flux(src_latent, noise, sigma): | |
return sigma * noise + (1.0 - sigma) * src_latent |