all_models / custom_nodes /ComfyUI-LTXTricks /nodes /rectified_sampler_nodes.py
jaxmetaverse's picture
Upload folder using huggingface_hub
82ea528 verified
import torch
from tqdm import trange
from comfy.samplers import KSAMPLER
def generate_trend_values(steps, start_time, end_time, eta, eta_trend):
    """Build a per-step schedule of controller strengths.

    Steps in the window ``[start_time, end_time)`` receive a value derived
    from ``eta`` according to ``eta_trend``; every other step is 0.

    Args:
        steps: Length of the returned schedule (number of sampling steps).
        start_time: First step index (inclusive) of the active window.
        end_time: Step index (exclusive) where the active window ends. The
            linear ramps are computed over the full requested window; window
            indices that fall outside ``[0, steps)`` are simply skipped
            instead of raising ``IndexError``.
        eta: Peak strength.
        eta_trend: ``'constant'``, ``'linear_increase'`` (0 -> eta across the
            window) or ``'linear_decrease'`` (eta -> 0 across the window).

    Returns:
        A list of ``steps`` numbers.

    Raises:
        ValueError: If ``eta_trend`` is not one of the names above.
    """
    if eta_trend not in ('constant', 'linear_increase', 'linear_decrease'):
        raise ValueError(f"unknown trend: {eta_trend!r}")
    eta_values = [0] * steps
    # A one-step window would divide by zero; treat it as the first point of
    # the ramp (0 for an increase, eta for a decrease).
    denom = max(end_time - start_time - 1, 1)
    # Clamp to valid indices so a window longer than the schedule (or a
    # negative start) neither crashes nor wraps around.
    for i in range(max(start_time, 0), min(end_time, steps)):
        if eta_trend == 'constant':
            eta_values[i] = eta
            continue
        progress = (i - start_time) / denom
        if eta_trend == 'linear_decrease':
            progress = 1 - progress
        eta_values[i] = eta * progress
    return eta_values
def get_sample_forward(gamma, start_step, end_step, gamma_trend, seed, attn_bank=None, order="first"):
    """Build the controlled forward ODE sampler (Algorithm 1).

    The returned function integrates the model's prediction from ``y0`` along
    ``sigmas``. At each step it blends the prediction toward the field
    ``(y1 - Y) / (1 - t_i)``, which points at a fixed Gaussian noise target
    ``y1``. The blend weight for each step comes from ``generate_trend_values``.

    Args:
        gamma: Peak controller strength.
        start_step: First step (inclusive) at which the controller is active.
        end_step: Step (exclusive) at which the controller stops.
        gamma_trend: Trend name forwarded to ``generate_trend_values``.
        seed: Seed for the noise target ``y1``. The generator is re-seeded on
            every call, so repeated runs of the same sampler draw the same ``y1``.
        attn_bank: Optional dict. Its ``'block_map'`` entries are cleared at
            the start of each run, and the dict is forwarded to the model via
            ``transformer_options['attn_bank']``.
        order: ``"first"`` for Euler steps. ``"second"`` adds a correction
            term estimated from an extra model call at the midpoint sigma.

    Returns:
        A ``sample_forward(model, y0, sigmas, extra_args=None, callback=None,
        disable=None)`` function, intended to be wrapped in ``KSAMPLER``.
    """
    # Controlled Forward ODE (Algorithm 1)
    @torch.no_grad()
    def sample_forward(model, y0, sigmas, extra_args=None, callback=None, disable=None):
        if attn_bank is not None:
            # Discard whatever a previous run stored in the bank.
            for block_idx in attn_bank['block_map']:
                attn_bank['block_map'][block_idx].clear()

        # Shallow-copy the option dicts so the caller's extra_args is not mutated.
        extra_args = {} if extra_args is None else {**extra_args}
        model_options = {**extra_args.get('model_options', {})}
        transformer_options = {
            **model_options.get('transformer_options', {}),
            'total_steps': len(sigmas) - 1,
            'sample_mode': 'forward',
            'attn_bank': attn_bank,
        }
        model_options['transformer_options'] = transformer_options
        extra_args['model_options'] = model_options

        # Seed per call. A generator seeded once at build time would keep
        # advancing, so reusing this sampler would draw a different y1 each run.
        generator = torch.Generator()
        generator.manual_seed(seed)

        Y = y0.clone()
        y1 = torch.randn(Y.shape, generator=generator).to(y0.device)
        N = len(sigmas) - 1
        s_in = y0.new_ones([y0.shape[0]])
        gamma_values = generate_trend_values(N, start_step, end_step, gamma, gamma_trend)
        model_sampling = model.inner_model.inner_model.model_sampling

        for i in trange(N, disable=disable):
            # transformer_options is the same dict referenced from extra_args,
            # so these per-step writes are visible to the model.
            transformer_options['step'] = i
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]
            t_i = model_sampling.timestep(sigma)
            conditional_vector_field = (y1 - Y) / (1 - t_i)

            transformer_options['pred_order'] = 'first'
            # This implementation takes sigma instead of timestep.
            pred = model(Y, s_in * sigma, **extra_args)
            if order == 'second':
                transformer_options['pred_order'] = 'second'
                half_dt = (sigma_next - sigma) / 2
                img_mid = Y + half_dt * pred
                sigma_mid = sigma + half_dt
                pred_mid = model(img_mid, s_in * sigma_mid, **extra_args)
                first_order = (pred_mid - pred) / half_dt
                pred = pred + gamma_values[i] * (conditional_vector_field - pred)
                # NOTE: only the first-order term is controlled here; the
                # second-order correction uses the raw model predictions.
                Y = Y + (sigma_next - sigma) * pred + 0.5 * (sigma_next - sigma) ** 2 * first_order
            else:
                pred = pred + gamma_values[i] * (conditional_vector_field - pred)
                Y = Y + pred * (sigma_next - sigma)

            if callback is not None:
                # The current state is reported as both 'x' and 'denoised'.
                callback({'x': Y, 'denoised': Y, 'i': i, 'sigma': sigma, 'sigma_hat': sigma})
        return Y

    return sample_forward
def get_sample_reverse(latent_image, eta, start_time, end_time, eta_trend, attn_bank=None, order='first'):
    """Build the controlled reverse ODE sampler (Algorithm 2).

    The returned function integrates from ``y1`` along ``sigmas`` using the
    negated model prediction. At each step that direction is blended toward
    the field ``(y0 - X) / (1 - t_i)``, which points at the reference latent
    ``y0 = latent_image``. The blend weight for each step comes from
    ``generate_trend_values``.

    Args:
        latent_image: Reference latent tensor that the controller pulls
            toward. It is moved to the device of ``y1`` at run time.
        eta: Peak controller strength.
        start_time: First step (inclusive) at which the controller is active.
        end_time: Step (exclusive) at which the controller stops.
        eta_trend: Trend name forwarded to ``generate_trend_values``.
        attn_bank: Optional object forwarded to the model via
            ``transformer_options['attn_bank']``. Unlike the forward sampler,
            it is not cleared here.
        order: ``"first"`` for Euler steps. ``"second"`` adds a (controlled)
            correction term estimated from an extra model call at the
            midpoint sigma.

    Returns:
        A ``sample_reverse(model, y1, sigmas, extra_args=None, callback=None,
        disable=None)`` function, intended to be wrapped in ``KSAMPLER``.
    """
    # Controlled Reverse ODE (Algorithm 2)
    @torch.no_grad()
    def sample_reverse(model, y1, sigmas, extra_args=None, callback=None, disable=None):
        # Shallow-copy the option dicts so the caller's extra_args is not mutated.
        extra_args = {} if extra_args is None else {**extra_args}
        model_options = {**extra_args.get('model_options', {})}
        transformer_options = {
            **model_options.get('transformer_options', {}),
            'total_steps': len(sigmas) - 1,
            'sample_mode': 'reverse',
            'attn_bank': attn_bank,
        }
        model_options['transformer_options'] = transformer_options
        extra_args['model_options'] = model_options

        X = y1.clone()
        N = len(sigmas) - 1
        y0 = latent_image.clone().to(y1.device)
        # Size the sigma vector from the state fed to the model, not from the
        # reference latent, whose batch size may differ.
        s_in = X.new_ones([X.shape[0]])
        eta_values = generate_trend_values(N, start_time, end_time, eta, eta_trend)
        model_sampling = model.inner_model.inner_model.model_sampling

        for i in trange(N, disable=disable):
            # transformer_options is the same dict referenced from extra_args,
            # so these per-step writes are visible to the model.
            transformer_options['step'] = i
            sigma = sigmas[i]
            sigma_prev = sigmas[i + 1]  # the next sigma in the schedule
            t_i = 1 - model_sampling.timestep(sigma)
            conditional_vector_field = (y0 - X) / (1 - t_i)

            transformer_options['pred_order'] = 'first'
            # This implementation takes sigma instead of timestep.
            pred = model(X, sigma * s_in, **extra_args)
            if order == 'second':
                transformer_options['pred_order'] = 'second'
                half_dt = (sigma_prev - sigma) / 2
                img_mid = X + half_dt * pred
                sigma_mid = sigma + half_dt
                pred_mid = model(img_mid, s_in * sigma_mid, **extra_args)
                first_order = (pred_mid - pred) / half_dt
                # Both terms are blended: (1 - eta) * (-v) + eta * conditional field.
                pred = -pred + eta_values[i] * (conditional_vector_field + pred)
                first_order = -first_order + eta_values[i] * (conditional_vector_field + first_order)
                X = X + (sigma - sigma_prev) * pred + 0.5 * (sigma - sigma_prev) ** 2 * first_order
            else:
                controlled_vector_field = -pred + eta_values[i] * (conditional_vector_field + pred)
                X = X + controlled_vector_field * (sigma - sigma_prev)

            if callback is not None:
                # The current state is reported as both 'x' and 'denoised'.
                callback({'x': X, 'denoised': X, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i]})
        return X

    return sample_reverse
class LTXRFForwardODESamplerNode:
    """ComfyUI node producing a SAMPLER that runs the controlled forward ODE."""

    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "build"
    CATEGORY = "ltxtricks"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's widgets; key order is preserved for the UI."""
        trend_choices = ['linear_decrease', 'linear_increase', 'constant']
        required = {
            "gamma": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step": 0.01}),
            "start_step": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
            "end_step": ("INT", {"default": 5, "min": 0, "max": 1000, "step": 1}),
            "gamma_trend": (trend_choices,),
        }
        optional = {
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            "attn_bank": ("ATTN_BANK",),
            "order": (["first", "second"],),
        }
        return {"required": required, "optional": optional}

    def build(self, gamma, start_step, end_step, gamma_trend, seed=0, attn_bank=None, order="first"):
        """Wrap the forward sampling function in a KSAMPLER and return it as a 1-tuple."""
        sample_fn = get_sample_forward(
            gamma, start_step, end_step, gamma_trend, seed,
            attn_bank=attn_bank, order=order,
        )
        return (KSAMPLER(sample_fn),)
class LTXRFReverseODESamplerNode:
    """ComfyUI node producing a SAMPLER that runs the controlled reverse ODE."""

    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "build"
    CATEGORY = "ltxtricks"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's widgets; key order is preserved for the UI."""
        required = {
            "model": ("MODEL",),
            "latent_image": ("LATENT",),
            "eta": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 100.0, "step": 0.01}),
            "start_step": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
            "end_step": ("INT", {"default": 15, "min": 0, "max": 1000, "step": 1}),
        }
        optional = {
            "eta_trend": (['linear_decrease', 'linear_increase', 'constant'],),
            "attn_inj": ("ATTN_INJ",),
            "order": (["first", "second"],),
        }
        return {"required": required, "optional": optional}

    def build(self, model, latent_image, eta, start_step, end_step, eta_trend='constant', attn_inj=None, order='first'):
        """Wrap the reverse sampling function in a KSAMPLER and return it as a 1-tuple.

        The reference latent is ``latent_image['samples']`` after the model's
        ``process_latent_in`` hook. ``attn_inj`` is passed to the sampler as
        its ``attn_bank``.
        """
        latent_in = model.get_model_object("process_latent_in")
        reference = latent_in(latent_image['samples'])
        sample_fn = get_sample_reverse(
            reference, eta, start_step, end_step, eta_trend,
            attn_bank=attn_inj, order=order,
        )
        return (KSAMPLER(sample_fn),)