|
import comfy.samplers |
|
import comfy.utils |
|
import torch |
|
from comfy.model_patcher import ModelPatcher |
|
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy |
|
|
|
from .nodes_registry import comfy_node |
|
|
|
|
|
def stg(
    noise_pred_pos,
    noise_pred_neg,
    noise_pred_pertubed,
    cfg_scale,
    stg_scale,
    rescale_scale,
):
    """Combine CFG and STG noise predictions, optionally rescaling the result.

    The guided prediction is::

        neg + cfg_scale * (pos - neg) + stg_scale * (pos - perturbed)

    When ``rescale_scale`` is non-zero the result is multiplied by
    ``rescale_scale * pos.std() / guided.std() + (1 - rescale_scale)``, i.e. it is
    blended toward having the same (whole-tensor) standard deviation as the
    positive prediction.

    Args:
        noise_pred_pos: prediction for the positive conditioning.
        noise_pred_neg: prediction for the negative conditioning; a zero tensor
            or scalar 0 when no negative pass was run.
        noise_pred_pertubed: prediction for the perturbed positive pass.
        cfg_scale: classifier-free guidance scale.
        stg_scale: spatio-temporal guidance scale.
        rescale_scale: blend factor between no rescaling (0) and full std
            matching (1). ``None`` is treated as 0.

    Returns:
        The guided (and possibly rescaled) noise prediction.
    """
    guided = (
        noise_pred_neg
        + cfg_scale * (noise_pred_pos - noise_pred_neg)
        + stg_scale * (noise_pred_pos - noise_pred_pertubed)
    )

    # ``None`` (the default of STGGuider.set_cfg) and 0 both mean "no rescaling";
    # previously ``None`` slipped past the ``!= 0`` check and crashed below.
    if not rescale_scale:
        return guided

    factor = noise_pred_pos.std() / guided.std()
    factor = rescale_scale * factor + (1 - rescale_scale)
    return guided * factor
|
|
|
|
|
@comfy_node(name="LTXVApplySTG")
class LTXVApplySTG:
    """ComfyUI node that configures a clone of a model for STG.

    Writes the skip-layer strategy and the list of transformer block indices to
    skip into the clone's ``model_options["transformer_options"]`` and returns
    the clone.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": (
                    "MODEL",
                    {"tooltip": "The model to apply the STG to."},
                ),
                "stg_mode": (["attention", "residual"],),
                "block_indices": (
                    "STRING",
                    {
                        "default": "14, 19",
                        "tooltip": "Comma-separated indices of the blocks to apply the STG to.",
                    },
                ),
            }
        }

    FUNCTION = "apply_stg"
    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    CATEGORY = "lightricks/LTXV"

    @staticmethod
    def _parse_block_indices(block_indices: str) -> list:
        """Parse a comma-separated string of ints, ignoring empty entries.

        Whitespace around entries and stray/trailing commas are tolerated
        (e.g. ``"14, 19,"`` -> ``[14, 19]``).

        Raises:
            ValueError: if an entry is not an integer, or no index is given.
        """
        tokens = (part.strip() for part in block_indices.split(","))
        indices = [int(token) for token in tokens if token]
        if not indices:
            raise ValueError(
                f"block_indices must contain at least one block index, got {block_indices!r}"
            )
        return indices

    def apply_stg(self, model: ModelPatcher, stg_mode: str, block_indices: str):
        """Return a clone of ``model`` with STG skip-layer options set.

        Args:
            model: the model patcher to clone.
            stg_mode: ``"attention"`` selects ``SkipLayerStrategy.Attention``;
                any other value selects ``SkipLayerStrategy.Residual``.
            block_indices: comma-separated block indices to skip.

        Returns:
            A one-element tuple holding the configured clone.
        """
        skip_block_list = self._parse_block_indices(block_indices)
        skip_layer_strategy = (
            SkipLayerStrategy.Attention
            if stg_mode == "attention"
            else SkipLayerStrategy.Residual
        )

        new_model = model.clone()
        transformer_options = new_model.model_options.setdefault(
            "transformer_options", {}
        )
        transformer_options["skip_layer_strategy"] = skip_layer_strategy

        # Keep block indices configured by an earlier STG node (appended after
        # the newly requested ones).
        skip_block_list.extend(transformer_options.get("skip_block_list", []))
        transformer_options["skip_block_list"] = skip_block_list

        return (new_model,)
|
|
|
|
|
class STGGuider(comfy.samplers.CFGGuider):
    """CFG guider extended with Spatio-Temporal Guidance (STG).

    Besides the positive and negative conditionings it evaluates a third
    "perturbed" conditioning (a copy of the positive one) and combines the three
    predictions with ``stg``.
    """

    def set_conds(self, positive, negative):
        """Register positive/negative conds; the positive cond doubles as the perturbed one."""
        self.inner_set_conds(
            {"positive": positive, "negative": negative, "perturbed": positive}
        )

    def set_cfg(self, cfg, stg_scale, rescale_scale: "float | None" = None):
        """Store the CFG scale, STG scale and rescale blend factor.

        ``rescale_scale=None`` means no rescaling and is stored as 0.0 so that
        ``stg`` never multiplies by ``None``.
        """
        self.cfg = cfg
        self.stg_scale = stg_scale
        self.rescale_scale = 0.0 if rescale_scale is None else rescale_scale

    def predict_noise(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        model_options: dict = {},
        seed=None,
    ):
        """Return the STG-guided prediction for ``x`` at ``timestep``.

        The perturbed and positive conditionings (plus the negative one unless
        cfg == 1) are evaluated in one ``calc_cond_batch`` call, combined with
        ``stg``, and then passed through any ``sampler_post_cfg_function`` hooks
        found in ``model_options``.
        """
        import math

        perturbed_cond = self.conds.get("perturbed", None)
        positive_cond = self.conds.get("positive", None)
        negative_cond = self.conds.get("negative", None)

        # With cfg == 1 the negative term cancels out of the CFG formula, so its
        # forward pass can be skipped. Any other value -- including cfg < 1 --
        # needs the real negative prediction; substituting 0 would turn the
        # result into cfg * pos.
        use_negative = not math.isclose(self.cfg, 1.0)

        # NOTE(review): ptb_index is read elsewhere (presumably by the
        # transformer) to locate the perturbed batch entry. The perturbed cond
        # is listed first, yet the index used is the last position --
        # presumably calc_cond_batch reverses batch order. TODO confirm.
        if use_negative:
            model_options["transformer_options"]["ptb_index"] = 2
            (noise_pred_perturbed, noise_pred_pos, noise_pred_neg) = (
                comfy.samplers.calc_cond_batch(
                    self.inner_model,
                    [perturbed_cond, positive_cond, negative_cond],
                    x,
                    timestep,
                    model_options,
                )
            )
        else:
            model_options["transformer_options"]["ptb_index"] = 1
            (noise_pred_perturbed, noise_pred_pos) = comfy.samplers.calc_cond_batch(
                self.inner_model,
                [perturbed_cond, positive_cond],
                x,
                timestep,
                model_options,
            )
            # A real tensor (not the int 0) so post-CFG hooks get a tensor.
            noise_pred_neg = torch.zeros_like(noise_pred_pos)

        stg_result = stg(
            noise_pred_pos,
            noise_pred_neg,
            noise_pred_perturbed,
            self.cfg,
            self.stg_scale,
            self.rescale_scale,
        )

        for fn in model_options.get("sampler_post_cfg_function", []):
            args = {
                "denoised": stg_result,
                "cond": positive_cond,
                "uncond": negative_cond,
                "model": self.inner_model,
                "uncond_denoised": noise_pred_neg,
                "cond_denoised": noise_pred_pos,
                "sigma": timestep,
                "model_options": model_options,
                "input": x,
                # The cond that actually produced perturbed_cond_denoised.
                "perturbed_cond": perturbed_cond,
                "perturbed_cond_denoised": noise_pred_perturbed,
            }
            stg_result = fn(args)

        return stg_result
|
|
|
|
|
@comfy_node(name="STGGuider")
class STGGuiderNode:
    """ComfyUI node that builds an ``STGGuider`` from a model and its conditionings."""

    CATEGORY = "lightricks/LTXV"
    FUNCTION = "get_guider"
    RETURN_TYPES = ("GUIDER",)

    @classmethod
    def INPUT_TYPES(cls):
        # Widget specs for the three float inputs; key order of ``required``
        # controls the input order shown in the UI.
        cfg_widget = {
            "default": 1.0,
            "min": 0.0,
            "max": 100.0,
            "step": 0.1,
            "round": 0.01,
        }
        stg_widget = {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}
        rescale_widget = {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}

        required_inputs = {
            "model": ("MODEL",),
            "positive": ("CONDITIONING",),
            "negative": ("CONDITIONING",),
            "cfg": ("FLOAT", cfg_widget),
            "stg": ("FLOAT", stg_widget),
            "rescale": ("FLOAT", rescale_widget),
        }
        return {"required": required_inputs}

    def get_guider(self, model, positive, negative, cfg, stg, rescale):
        """Create an ``STGGuider`` configured with the given conds and scales.

        Returns:
            A one-element tuple holding the guider.
        """
        stg_guider = STGGuider(model)
        stg_guider.set_conds(positive, negative)
        stg_guider.set_cfg(cfg, stg, rescale)
        return (stg_guider,)
|
|