|
import torch |
|
from tqdm import trange |
|
|
|
from comfy.samplers import KSAMPLER |
|
|
|
|
|
def generate_trend_values(steps, start_time, end_time, eta, eta_trend):
    """Build a per-step schedule of control strengths.

    Steps outside ``[start_time, end_time)`` get ``0``. Inside that window the
    value follows ``eta_trend``:

    * ``'constant'``        -- every step gets ``eta``.
    * ``'linear_increase'`` -- ramps from 0 at ``start_time`` to ``eta`` at
      ``end_time - 1``.
    * ``'linear_decrease'`` -- ramps from ``eta`` at ``start_time`` to 0 at
      ``end_time - 1``.

    Any other ``eta_trend`` leaves the whole schedule at 0.

    The ramp is always computed over the requested window. Only indices that
    exist in the schedule (``0 <= i < steps``) are written, so a window that
    runs past the sampler's step count is truncated instead of raising
    ``IndexError``. A one-step window yields 0 for ``linear_increase`` and
    ``eta`` for ``linear_decrease`` instead of dividing by zero.

    Args:
        steps: Length of the returned list (number of sampler steps).
        start_time: First step index (inclusive) of the active window.
        end_time: Step index (exclusive) where the active window ends.
        eta: Peak strength.
        eta_trend: ``'constant'``, ``'linear_increase'`` or ``'linear_decrease'``.

    Returns:
        A list of ``steps`` numbers.
    """
    eta_values = [0] * steps

    # Clip to valid list indices. A negative index would otherwise silently
    # wrap around and overwrite steps at the end of the schedule.
    first = max(start_time, 0)
    last = min(end_time, steps)

    # Denominator of the linear ramp. It is guarded so that a one-step window
    # does not divide by zero.
    ramp_span = max(end_time - start_time - 1, 1)

    for i in range(first, last):
        if eta_trend == 'constant':
            eta_values[i] = eta
        elif eta_trend == 'linear_increase':
            eta_values[i] = eta * ((i - start_time) / ramp_span)
        elif eta_trend == 'linear_decrease':
            eta_values[i] = eta * (1 - (i - start_time) / ramp_span)

    return eta_values
|
|
|
|
|
|
|
def get_sample_forward(gamma, start_step, end_step, gamma_trend, seed, attn_bank=None, order="first"):
    """Create the forward-ODE sampler function that is wrapped in ``KSAMPLER``.

    The returned ``sample_forward`` integrates from ``y0`` along ``sigmas``.
    At step ``i`` it queries the model for a velocity ``pred`` and blends it
    toward the conditional field ``(y1 - Y) / (1 - t_i)``. Here ``y1`` is a
    Gaussian sample drawn from a generator seeded with ``seed``, and ``t_i`` is
    ``model_sampling.timestep(sigmas[i])``. The blend weight is
    ``gamma_values[i]``, taken from
    ``generate_trend_values(N, start_step, end_step, gamma, gamma_trend)``.

    Args:
        gamma: Peak blend weight toward the conditional field.
        start_step: First step (inclusive) where blending is active.
        end_step: Step (exclusive) where blending stops.
        gamma_trend: ``'constant'``, ``'linear_increase'`` or ``'linear_decrease'``.
        seed: Seed for the Gaussian target ``y1``. Every call of the returned
            sampler draws the same ``y1`` for the same seed.
        attn_bank: Optional dict. Its ``'block_map'`` entries are cleared at
            the start of every run, and the dict is exposed to the model as
            ``transformer_options['attn_bank']``.
        order: ``'first'`` takes an Euler step. ``'second'`` additionally uses
            a midpoint model evaluation to add a derivative correction term.

    Returns:
        ``sample_forward(model, y0, sigmas, extra_args=None, callback=None,
        disable=None)``, which returns the final latent ``Y``.
    """

    @torch.no_grad()
    def sample_forward(model, y0, sigmas, extra_args=None, callback=None, disable=None):
        # Start every run from empty attention storage.
        if attn_bank is not None:
            for block_idx in attn_bank['block_map']:
                attn_bank['block_map'][block_idx].clear()

        # Copy the option dicts before adding our keys, so dicts owned by the
        # caller (model_options / transformer_options) are not mutated.
        extra_args = {} if extra_args is None else extra_args
        model_options = {**extra_args.get('model_options', {})}
        transformer_options = {
            **model_options.get('transformer_options', {}),
            'total_steps': len(sigmas) - 1,
            'sample_mode': 'forward',
            'attn_bank': attn_bank,
        }
        model_options['transformer_options'] = transformer_options
        extra_args['model_options'] = model_options

        # Seed a fresh generator on every call. A generator created once in
        # the factory keeps advancing, so re-running the same sampler object
        # (e.g. when a node output is reused) would draw a different y1 each
        # time despite the fixed seed.
        generator = torch.Generator()
        generator.manual_seed(seed)

        Y = y0.clone()
        y1 = torch.randn(Y.shape, generator=generator).to(y0.device)
        N = len(sigmas) - 1
        s_in = y0.new_ones([y0.shape[0]])
        gamma_values = generate_trend_values(N, start_step, end_step, gamma, gamma_trend)

        for i in trange(N, disable=disable):
            transformer_options['step'] = i
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]
            # NOTE(review): assumes a ComfyUI-wrapped model exposing
            # inner_model.inner_model.model_sampling -- confirm with callers.
            t_i = model.inner_model.inner_model.model_sampling.timestep(sigmas[i])

            # Field pointing from the current state toward the noise target y1.
            conditional_vector_field = (y1 - Y) / (1 - t_i)

            transformer_options['pred_order'] = 'first'
            pred = model(Y, s_in * sigmas[i], **extra_args)

            if order == 'second':
                transformer_options['pred_order'] = 'second'
                img_mid = Y + (sigma_next - sigma) / 2 * pred
                sigma_mid = sigma + (sigma_next - sigma) / 2
                pred_mid = model(img_mid, s_in * sigma_mid, **extra_args)

                # Finite-difference estimate of d(pred)/d(sigma) over the half step.
                first_order = (pred_mid - pred) / ((sigma_next - sigma) / 2)
                pred = pred + gamma_values[i] * (conditional_vector_field - pred)
                Y = Y + (sigma_next - sigma) * pred + 0.5 * (sigma_next - sigma) ** 2 * first_order
            else:
                pred = pred + gamma_values[i] * (conditional_vector_field - pred)
                Y = Y + pred * (sigma_next - sigma)

            if callback is not None:
                # 'denoised' receives the current state Y, not a separate
                # denoised estimate.
                callback({'x': Y, 'denoised': Y, 'i': i, 'sigma': sigma, 'sigma_hat': sigma})

        return Y

    return sample_forward
|
|
|
|
|
def get_sample_reverse(latent_image, eta, start_time, end_time, eta_trend, attn_bank=None, order='first'):
    """Create the reverse-ODE sampler function that is wrapped in ``KSAMPLER``.

    The returned ``sample_reverse`` integrates from ``y1`` along ``sigmas``.
    At step ``i`` the negated model velocity is blended toward the conditional
    field ``(y0 - X) / (1 - t_i)``. Here ``y0`` is ``latent_image`` and
    ``t_i = 1 - model_sampling.timestep(sigmas[i])``. The blend weight is
    ``eta_values[i]``, taken from
    ``generate_trend_values(N, start_time, end_time, eta, eta_trend)``.

    Args:
        latent_image: Reference latent tensor that the conditional field
            points toward. It is cloned and moved to the device of ``y1``.
        eta: Peak blend weight toward the conditional field.
        start_time: First step (inclusive) where blending is active.
        end_time: Step (exclusive) where blending stops.
        eta_trend: ``'constant'``, ``'linear_increase'`` or ``'linear_decrease'``.
        attn_bank: Optional object, exposed to the model as
            ``transformer_options['attn_bank']``. It is not cleared here.
        order: ``'first'`` takes an Euler step. ``'second'`` additionally uses
            a midpoint model evaluation to add a derivative correction term.

    Returns:
        ``sample_reverse(model, y1, sigmas, extra_args=None, callback=None,
        disable=None)``, which returns the final latent ``X``.
    """

    @torch.no_grad()
    def sample_reverse(model, y1, sigmas, extra_args=None, callback=None, disable=None):
        # Copy the option dicts before adding our keys, so dicts owned by the
        # caller (model_options / transformer_options) are not mutated.
        extra_args = {} if extra_args is None else extra_args
        model_options = {**extra_args.get('model_options', {})}
        transformer_options = {
            **model_options.get('transformer_options', {}),
            'total_steps': len(sigmas) - 1,
            'sample_mode': 'reverse',
            'attn_bank': attn_bank,
        }
        model_options['transformer_options'] = transformer_options
        extra_args['model_options'] = model_options

        X = y1.clone()
        N = len(sigmas) - 1
        y0 = latent_image.clone().to(y1.device)
        s_in = y0.new_ones([y0.shape[0]])
        eta_values = generate_trend_values(N, start_time, end_time, eta, eta_trend)

        for i in trange(N, disable=disable):
            transformer_options['step'] = i
            # NOTE(review): assumes a ComfyUI-wrapped model exposing
            # inner_model.inner_model.model_sampling -- confirm with callers.
            t_i = 1 - model.inner_model.inner_model.model_sampling.timestep(sigmas[i])
            sigma = sigmas[i]
            sigma_prev = sigmas[i + 1]

            # Field pointing from the current state toward the reference latent y0.
            conditional_vector_field = (y0 - X) / (1 - t_i)

            transformer_options['pred_order'] = 'first'
            pred = model(X, sigma * s_in, **extra_args)

            if order == 'second':
                transformer_options['pred_order'] = 'second'
                img_mid = X + (sigma_prev - sigma) / 2 * pred
                sigma_mid = sigma + (sigma_prev - sigma) / 2
                pred_mid = model(img_mid, s_in * sigma_mid, **extra_args)

                # Finite-difference estimate of d(pred)/d(sigma) over the half step.
                first_order = (pred_mid - pred) / ((sigma_prev - sigma) / 2)
                pred = -pred + eta_values[i] * (conditional_vector_field + pred)

                # The step below uses the reversed increment (sigma - sigma_prev)
                # together with the negated velocity. The derivative term is
                # scaled by the SQUARE of that increment, which does not change
                # sign, so the derivative itself must not be negated
                # (d(-v)/d(-sigma) == dv/dsigma). Negating it flipped the sign of
                # the second-order correction.
                # NOTE(review): blending the conditional field into the derivative
                # term follows the original heuristic; only its sign is fixed here.
                first_order = first_order + eta_values[i] * (conditional_vector_field - first_order)
                X = X + (sigma - sigma_prev) * pred + 0.5 * (sigma - sigma_prev) ** 2 * first_order
            else:
                controlled_vector_field = -pred + eta_values[i] * (conditional_vector_field + pred)
                X = X + controlled_vector_field * (sigma - sigma_prev)

            if callback is not None:
                # 'denoised' receives the current state X, not a separate
                # denoised estimate.
                callback({'x': X, 'denoised': X, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i]})

        return X

    return sample_reverse
|
|
|
|
|
class LTXRFForwardODESamplerNode:
    """ComfyUI node that packages ``get_sample_forward`` as a ``SAMPLER``.

    The sampler blends the model velocity toward a seeded noise target with
    strength ``gamma``, scheduled over ``[start_step, end_step)`` according to
    ``gamma_trend``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's widgets. Dict order is the display order."""
        trend_choices = ['linear_decrease', 'linear_increase', 'constant']
        required = {
            "gamma": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step": 0.01}),
            "start_step": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
            "end_step": ("INT", {"default": 5, "min": 0, "max": 1000, "step": 1}),
            "gamma_trend": (trend_choices,),
        }
        optional = {
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            "attn_bank": ("ATTN_BANK",),
            "order": (["first", "second"],),
        }
        return {"required": required, "optional": optional}

    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "build"
    CATEGORY = "ltxtricks"

    def build(self, gamma, start_step, end_step, gamma_trend, seed=0, attn_bank=None, order="first"):
        """Return a one-element tuple holding the configured forward sampler."""
        sample_fn = get_sample_forward(
            gamma,
            start_step,
            end_step,
            gamma_trend,
            seed,
            attn_bank=attn_bank,
            order=order,
        )
        return (KSAMPLER(sample_fn),)
|
|
|
|
|
class LTXRFReverseODESamplerNode:
    """ComfyUI node that packages ``get_sample_reverse`` as a ``SAMPLER``.

    ``latent_image['samples']`` is converted with the model's
    ``process_latent_in`` and becomes the reference latent that the reverse
    ODE is blended toward with strength ``eta`` over ``[start_step, end_step)``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's widgets. Dict order is the display order."""
        trend_choices = ['linear_decrease', 'linear_increase', 'constant']
        required = {
            "model": ("MODEL",),
            "latent_image": ("LATENT",),
            "eta": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 100.0, "step": 0.01}),
            "start_step": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
            "end_step": ("INT", {"default": 15, "min": 0, "max": 1000, "step": 1}),
        }
        optional = {
            "eta_trend": (trend_choices,),
            "attn_inj": ("ATTN_INJ",),
            "order": (["first", "second"],),
        }
        return {"required": required, "optional": optional}

    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "build"
    CATEGORY = "ltxtricks"

    def build(self, model, latent_image, eta, start_step, end_step, eta_trend='constant', attn_inj=None, order='first'):
        """Return a one-element tuple holding the configured reverse sampler.

        ``attn_inj`` is forwarded to the sampler factory as its ``attn_bank``.
        """
        to_model_space = model.get_model_object("process_latent_in")
        reference = to_model_space(latent_image['samples'])
        sample_fn = get_sample_reverse(
            reference,
            eta,
            start_step,
            end_step,
            eta_trend,
            attn_bank=attn_inj,
            order=order,
        )
        return (KSAMPLER(sample_fn),)
|
|