|
import torch |
|
from tqdm import tqdm, trange |
|
|
|
from ..utils.sampling_utils import generate_eta_values |
|
|
|
|
|
@torch.no_grad()
def mochi_sample(model, z, sigmas, callback=None):
    """Integrate the latent ``z`` across a sigma schedule with explicit Euler steps.

    For each consecutive pair ``(sigmas[k], sigmas[k + 1])`` the model is queried
    at the current sigma and the latent is advanced by
    ``model_output * (sigmas[k] - sigmas[k + 1])``.

    Args:
        model: Callable invoked as ``model(z=..., sigma=...)`` where ``sigma`` is a
            1-D tensor holding the current sigma once per batch element
            (batch size taken from ``z.shape[0]``). Its return value must be
            broadcast-compatible with ``z``.
        z: Initial latent tensor.
        sigmas: Indexable schedule with ``len(sigmas) - 1`` steps; presumably
            decreasing toward 0 (NOTE(review): confirm with callers).
        callback: Optional ``callback(step_index, z)`` called after every step.

    Returns:
        The latent after the final step.
    """
    num_steps = len(sigmas) - 1
    batch_size = z.shape[0]  # captured once, before z is ever updated

    for step in tqdm(range(num_steps), desc="Processing Samples", total=num_steps):
        sigma_now = sigmas[step]
        sigma_next = sigmas[step + 1]

        sigma_batch = torch.full([batch_size], sigma_now, device=z.device)
        velocity = model(z=z, sigma=sigma_batch)

        # Euler step: move by the model output scaled by the sigma decrement.
        z = z + velocity * (sigma_now - sigma_next)

        if callback is not None:
            callback(step, z)

    return z
|
|
|
|
|
def get_rf_forward_sample_fn(gamma, seed, correction=True):
    """Build a forward sampler that drives a clean latent toward seeded noise.

    The returned ``sample_forward`` integrates ``Y`` over ``sigmas`` with Euler
    steps. At each step the "unconditional" field is the negated model output.
    When ``correction`` is true, it is blended with the "conditional" field
    ``(y1 - Y) / (1 - t_i)``, which points at the noise target ``y1``, using
    weight ``gamma``. NOTE(review): this appears to be the controlled-ODE
    inversion formulation; confirm against the reference implementation.

    A single CPU ``torch.Generator`` is seeded once, here in the factory. Every
    call of the returned sampler therefore advances that generator, so calling
    the same sampler twice draws different noise.

    Args:
        gamma: Blend weight between the unconditional (0) and conditional (1)
            vector fields. It is only used when ``correction`` is true.
        seed: Seed for the noise generator.
        correction: If false, integrate only the unconditional field.

    Returns:
        ``sample_forward(model, y0, sigmas, extra_args=None, callback=None,
        disable=None)``.
    """
    generator = torch.Generator()
    generator.manual_seed(seed)

    @torch.no_grad()
    def sample_forward(model, y0, sigmas, extra_args=None, callback=None, disable=None):
        """Run the forward integration starting from ``y0``.

        Args:
            model: Callable ``model(Y, sigma_batch, **extra_args)``.
            y0: Starting latent. Its device and dtype are kept throughout.
            sigmas: Schedule with ``len(sigmas) - 1`` steps. Because of the
                ``1 - sigmas[i]`` divisor, ``sigmas[i]`` must not equal 1 for any
                step when ``correction`` is true. The schedule is presumably
                ascending toward 1 (TODO confirm with callers).
            extra_args: Extra keyword arguments forwarded to ``model``.
            callback: Optional callable receiving a dict with keys ``x``,
                ``denoised``, ``i``, ``sigma`` and ``sigma_hat``.
            disable: Forwarded to ``trange`` to silence the progress bar.

        Returns:
            The latent after the last step.
        """
        # Avoid a shared mutable default; a fresh dict is used per call.
        extra_args = {} if extra_args is None else extra_args

        Y = y0.clone()
        # The seeded generator lives on the CPU, so the noise is drawn there in
        # float32 (keeping the seed -> noise mapping stable) and then moved.
        # It is cast to y0's dtype too; otherwise half/bf16 latents get silently
        # promoted to float32 by the arithmetic below.
        y1 = torch.randn(Y.shape, generator=generator).to(device=y0.device, dtype=y0.dtype)

        num_steps = len(sigmas) - 1
        s_in = y0.new_ones([y0.shape[0]])  # per-batch multiplier for sigma

        for i in trange(num_steps, disable=disable):
            t_i = sigmas[i]

            # The model output is negated to obtain the field used for this direction.
            unconditional_vector_field = -model(Y, t_i * s_in, **extra_args)

            if correction:
                # Field pointing from the current state to the noise target y1.
                conditional_vector_field = (y1 - Y) / (1 - t_i)
                controlled_vector_field = unconditional_vector_field + gamma * (
                    conditional_vector_field - unconditional_vector_field
                )
            else:
                controlled_vector_field = unconditional_vector_field

            Y = Y + controlled_vector_field * (sigmas[i + 1] - t_i)

            if callback is not None:
                callback({'x': Y, 'denoised': Y, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i]})

        return Y

    return sample_forward
|
|
|
|
|
def get_rf_reverse_sample_fn(latent_image, eta, start_time, end_time, eta_trend):
    """Build a reverse sampler that integrates from ``y1`` while pulling toward ``latent_image``.

    At every step the model output (the "unconditional" field) is blended with a
    "conditional" field that points at the reference latent ``y0``. The blend
    weight for step ``i`` is ``eta_values[i]``. Those weights come from
    ``generate_eta_values(N, start_time, end_time, eta, eta_trend)``, a helper
    whose exact schedule is defined elsewhere (assumed to return at least ``N``
    indexable weights — TODO confirm).

    Args:
        latent_image: Reference latent the trajectory is guided toward.
        eta, start_time, end_time, eta_trend: Forwarded unchanged to
            ``generate_eta_values``.

    Returns:
        ``sample_reverse(model, y1, sigmas, extra_args=None, callback=None,
        disable=None)``.
    """

    @torch.no_grad()
    def sample_reverse(model, y1, sigmas, extra_args=None, callback=None, disable=None):
        """Run the reverse integration starting from ``y1``.

        Args:
            model: Callable ``model(X, sigma_batch, **extra_args)``.
            y1: Starting latent. Its device and dtype are kept throughout.
            sigmas: Schedule with ``len(sigmas) - 1`` steps. Because of the
                ``/ sigmas[i]`` divisor, ``sigmas[i]`` must be nonzero for every
                step. The schedule is presumably descending toward 0 (TODO
                confirm with callers).
            extra_args: Extra keyword arguments forwarded to ``model``.
            callback: Optional callable receiving a dict with keys ``x``,
                ``denoised``, ``i``, ``sigma`` and ``sigma_hat``.
            disable: Forwarded to ``trange`` to silence the progress bar.

        Returns:
            The latent after the last step.
        """
        # Avoid a shared mutable default; a fresh dict is used per call.
        extra_args = {} if extra_args is None else extra_args

        X = y1.clone()
        num_steps = len(sigmas) - 1

        # Match y1's device AND dtype so that (y0 - X) does not promote X to a
        # different dtype mid-trajectory. y0 is only read, so no clone is needed.
        y0 = latent_image.to(device=y1.device, dtype=y1.dtype)

        eta_values = generate_eta_values(num_steps, start_time, end_time, eta, eta_trend)

        # Batch multiplier sized from the tensor actually passed to the model.
        s_in = y1.new_ones([y1.shape[0]])

        for i in trange(num_steps, disable=disable):
            sigma_i = sigmas[i]

            unconditional_vector_field = model(X, sigma_i * s_in, **extra_args)

            # With t_i = 1 - sigma_i the conditional field is (y0 - X) / (1 - t_i).
            # Algebraically 1 - t_i == sigma_i. Dividing by sigma_i directly avoids
            # the cancellation in 1 - (1 - sigma_i), which rounds to 0 in float32
            # for very small sigmas and would produce inf/NaN.
            conditional_vector_field = (y0 - X) / sigma_i

            controlled_vector_field = unconditional_vector_field + eta_values[i] * (
                conditional_vector_field - unconditional_vector_field
            )

            X = X + controlled_vector_field * (sigma_i - sigmas[i + 1])

            if callback is not None:
                callback({'x': X, 'denoised': X, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i]})

        return X

    return sample_reverse
|
|