|
import torch |
|
from tqdm import trange |
|
|
|
from comfy.samplers import KSAMPLER |
|
|
|
|
|
def generate_eta_values(steps, start_time, end_time, eta, eta_trend):
    """Build a per-step strength schedule of length ``steps``.

    Steps inside the half-open window ``[start_time, end_time)`` receive a
    value derived from ``eta`` according to ``eta_trend``; every other step
    stays at ``0``.

    Args:
        steps: Number of entries in the returned schedule.
        start_time: First step index (inclusive) of the active window.
        end_time: Step index (exclusive) at which the active window ends.
        eta: Strength applied inside the window (the peak value for the
            linear trends).
        eta_trend: One of ``'constant'`` (``eta`` on every window step),
            ``'linear_increase'`` (ramp from ``0`` up to ``eta``) or
            ``'linear_decrease'`` (ramp from ``eta`` down to ``0``). Any
            other value leaves the schedule at all zeros.

    Returns:
        A list with ``steps`` entries.

    Notes:
        The window is clipped to ``[0, steps)`` so a window that extends past
        the schedule no longer raises ``IndexError``. The ramp is still
        measured against the requested window, so clipping only truncates it.
        A one-step window has no ramp length; it yields the ramp's starting
        value (``0`` for an increase, ``eta`` for a decrease) instead of
        dividing by zero.
    """
    eta_values = [0] * steps

    # Only indices that actually exist in the schedule may be written.
    first = max(start_time, 0)
    last = min(end_time, steps)

    # Number of intervals in the ramp; floored at 1 so a single-step window
    # does not divide by zero.
    ramp_span = max(end_time - start_time - 1, 1)

    if eta_trend == 'constant':
        for i in range(first, last):
            eta_values[i] = eta
    elif eta_trend == 'linear_increase':
        for i in range(first, last):
            progress = (i - start_time) / ramp_span
            eta_values[i] = eta * progress
    elif eta_trend == 'linear_decrease':
        for i in range(first, last):
            progress = 1 - (i - start_time) / ramp_span
            eta_values[i] = eta * progress

    return eta_values
|
|
|
|
|
|
|
def get_sample_forward(gamma, start_step, end_step, gamma_trend, seed):
    """Create the forward sampling function with its settings bound in.

    Args:
        gamma: Strength passed to ``generate_eta_values`` for the per-step
            blend weight between the model output and the noise-directed
            field.
        start_step: First step (inclusive) at which ``gamma`` applies.
        end_step: Step (exclusive) at which ``gamma`` stops applying.
        gamma_trend: Trend name forwarded to ``generate_eta_values``.
        seed: Seed for the generator that draws the target noise.

    Returns:
        A ``sample_forward(model, y0, sigmas, extra_args, callback, disable)``
        function that returns the final state tensor.
    """

    @torch.no_grad()
    def sample_forward(model, y0, sigmas, extra_args=None, callback=None, disable=None):
        """Integrate ``y0`` along ``sigmas`` with explicit Euler steps.

        On each step the model output is blended with a field pointing at a
        seeded noise sample, weighted by that step's gamma value.
        """
        # A fresh, explicitly seeded generator per call keeps the target
        # noise reproducible for a given seed.
        generator = torch.Generator()
        generator.manual_seed(seed)
        extra_args = {} if extra_args is None else extra_args
        Y = y0.clone()
        # Noise is drawn with the default-constructed generator and then
        # moved to the input's device.
        y1 = torch.randn(Y.shape, generator=generator).to(y0.device)
        N = len(sigmas) - 1
        s_in = y0.new_ones([y0.shape[0]])
        gamma_values = generate_eta_values(N, start_step, end_step, gamma, gamma_trend)

        # The largest sigma is loop-invariant; compute it once instead of
        # re-scanning the whole schedule on every step. Skipped when the loop
        # will not run, so a degenerate schedule behaves as before.
        sigma_max = max(sigmas) if N > 0 else None

        for i in trange(N, disable=disable):
            # Normalised position of the current sigma.
            # NOTE(review): if sigmas[i] equals the maximum for some i < N,
            # (1 - t_i) below is zero; presumably sigmas increase so the
            # maximum is the last entry -- confirm against the scheduler.
            t_i = sigmas[i] / sigma_max

            # Field predicted by the model at the current state.
            unconditional_vector_field = model(Y, s_in * sigmas[i], **extra_args)

            # Field pointing from the current state towards the noise sample.
            conditional_vector_field = (y1 - Y) / (1 - t_i)

            # Interpolate between the two fields with this step's gamma.
            controlled_vector_field = unconditional_vector_field + gamma_values[i] * (conditional_vector_field - unconditional_vector_field)

            # Explicit Euler update over the sigma interval.
            Y = Y + controlled_vector_field * (sigmas[i+1] - sigmas[i])

            if callback is not None:
                callback({'x': Y, 'denoised': Y, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i]})

        return Y

    return sample_forward
|
|
|
|
|
|
|
def get_sample_reverse(latent_image, eta, start_time, end_time, eta_trend):
    """Create the reverse sampling function with its settings bound in.

    Args:
        latent_image: Reference latent tensor that the guided field points
            towards.
        eta: Strength passed to ``generate_eta_values`` for the per-step
            blend weight.
        start_time: First step (inclusive) at which ``eta`` applies.
        end_time: Step (exclusive) at which ``eta`` stops applying.
        eta_trend: Trend name forwarded to ``generate_eta_values``.

    Returns:
        A ``sample_reverse(model, y1, sigmas, extra_args, callback, disable)``
        function that returns the final state tensor.
    """

    @torch.no_grad()
    def sample_reverse(model, y1, sigmas, extra_args=None, callback=None, disable=None):
        """Integrate ``y1`` along ``sigmas`` with explicit Euler steps.

        On each step the negated model output is blended with a field
        pointing at the reference latent, weighted by that step's eta value.
        """
        if extra_args is None:
            extra_args = {}

        state = y1.clone()
        num_steps = len(sigmas) - 1
        # Work on a private copy of the reference latent on the input's device.
        target = latent_image.clone().to(y1.device)
        batch_ones = target.new_ones([target.shape[0]])
        schedule = generate_eta_values(num_steps, start_time, end_time, eta, eta_trend)

        for step in trange(num_steps, disable=disable):
            # Fraction of the run completed so far; always below 1 in the loop.
            time_frac = step / num_steps
            sigma_now = sigmas[step]

            # Negated model prediction at the current state.
            free_field = -model(state, sigma_now * batch_ones, **extra_args)

            # Field pointing from the current state towards the reference.
            guided_field = (target - state) / (1 - time_frac)

            # Interpolate between the two fields with this step's eta.
            blended_field = free_field + schedule[step] * (guided_field - free_field)

            # Explicit Euler update over the sigma interval.
            state = state + blended_field * (sigmas[step] - sigmas[step + 1])

            if callback is not None:
                callback({'x': state, 'denoised': state, 'i': step, 'sigma': sigmas[step], 'sigma_hat': sigmas[step]})

        return state

    return sample_reverse
|
|
|
|
|
class HYForwardODESamplerNode:
    """Node that wraps the forward ODE sampling function in a ``KSAMPLER``."""

    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "build"

    CATEGORY = "hunyuanloom"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: gamma schedule settings plus an optional seed."""
        required = {
            "gamma": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step": 0.01}),
            "start_step": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
            "end_step": ("INT", {"default": 5, "min": 0, "max": 1000, "step": 1}),
            "gamma_trend": (['constant', 'linear_increase', 'linear_decrease'],),
        }
        optional = {
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
        }
        return {"required": required, "optional": optional}

    def build(self, gamma, start_step, end_step, gamma_trend, seed=0):
        """Return a one-element tuple holding the configured forward sampler."""
        sample_fn = get_sample_forward(gamma, start_step, end_step, gamma_trend, seed)
        return (KSAMPLER(sample_fn),)
|
|
|
|
|
class HYReverseODESamplerNode:
    """Node that wraps the reverse ODE sampling function in a ``KSAMPLER``."""

    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "build"

    CATEGORY = "hunyuanloom"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: model, reference latent and eta schedule settings."""
        required = {
            "model": ("MODEL",),
            "latent_image": ("LATENT",),
            "eta": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 100.0, "step": 0.01}),
            "start_step": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
            "end_step": ("INT", {"default": 5, "min": 0, "max": 1000, "step": 1}),
        }
        optional = {
            "eta_trend": (['constant', 'linear_increase', 'linear_decrease'],),
        }
        return {"required": required, "optional": optional}

    def build(self, model, latent_image, eta, start_step, end_step, eta_trend='constant'):
        """Return a one-element tuple holding the configured reverse sampler.

        The ``'samples'`` entry of ``latent_image`` is passed through the
        model's ``process_latent_in`` object before being handed to the
        sampling function.
        """
        to_model_space = model.get_model_object("process_latent_in")
        reference = to_model_space(latent_image['samples'])
        sample_fn = get_sample_reverse(reference, eta, start_step, end_step, eta_trend)
        return (KSAMPLER(sample_fn),)
|
|