|
import torch |
|
from tqdm import trange |
|
|
|
from comfy.samplers import KSAMPLER, CFGGuider, sampling_function |
|
|
|
|
|
class FlowEditGuider(CFGGuider):
    """CFG guider that keeps separate conditionings and CFG scales per latent branch.

    The branch name is read from ``model_options['transformer_options']['latent_type']``
    on every prediction. Conditionings are looked up under ``<branch>_positive`` /
    ``<branch>_negative`` and the CFG scale under ``<branch>``.
    """

    def __init__(self, model_patcher):
        super().__init__(model_patcher)
        # Branch name -> CFG scale; branches missing from this mapping use self.cfg.
        self.cfgs = {}

    def set_conds(self, **kwargs):
        """Register conditionings under arbitrary names, e.g. ``source_positive=...``."""
        self.inner_set_conds(kwargs)

    def set_cfgs(self, **kwargs):
        """Replace the branch -> CFG scale mapping, e.g. ``source=2.0, target=4.5``."""
        self.cfgs = dict(kwargs)

    def predict_noise(self, x, timestep, model_options={}, seed=None):
        """Run ``sampling_function`` with the conditionings and CFG of the active branch.

        Raises KeyError if ``model_options`` has no ``transformer_options`` entry or
        that entry has no ``latent_type`` key.
        """
        branch = model_options['transformer_options']['latent_type']
        cond_pos = self.conds.get(f'{branch}_positive', None)
        cond_neg = self.conds.get(f'{branch}_negative', None)
        return sampling_function(
            self.inner_model,
            x,
            timestep,
            cond_neg,
            cond_pos,
            self.cfgs.get(branch, self.cfg),
            model_options=model_options,
            seed=seed,
        )
|
|
|
|
|
class LTXFlowEditCFGGuiderNode:
    """ComfyUI node that builds a FlowEditGuider from source and target conditionings.

    The source pair is registered as ``source_positive`` / ``source_negative`` with CFG
    ``source_cfg`` under the ``source`` key. The target pair is registered the same way
    under ``target``.
    """

    RETURN_TYPES = ("GUIDER",)
    FUNCTION = "get_guider"
    CATEGORY = "ltxtricks"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the required inputs: a model, four conditionings and two CFG scales."""
        required = {"model": ("MODEL",)}
        # Insertion order matters: ComfyUI shows the inputs in this order.
        for cond_name in ("source_pos", "source_neg", "target_pos", "target_neg"):
            required[cond_name] = ("CONDITIONING", )
        required["source_cfg"] = ("FLOAT", {"default": 2, "min": 0, "max": 0xffffffffffffffff, "step": 0.01})
        required["target_cfg"] = ("FLOAT", {"default": 4.5, "min": 0, "max": 0xffffffffffffffff, "step": 0.01})
        return {"required": required}

    def get_guider(self, model, source_pos, source_neg, target_pos, target_neg, source_cfg, target_cfg):
        """Return a 1-tuple holding a FlowEditGuider configured for both branches."""
        conds = {
            "source_positive": source_pos,
            "source_negative": source_neg,
            "target_positive": target_pos,
            "target_negative": target_neg,
        }
        guider = FlowEditGuider(model)
        guider.set_conds(**conds)
        guider.set_cfgs(source=source_cfg, target=target_cfg)
        return (guider,)
|
|
|
|
|
def get_flowedit_sample(skip_steps, refine_steps, seed):
    """Build a FlowEdit sampling function suitable for wrapping in ``KSAMPLER``.

    Args:
        skip_steps: Number of leading entries dropped from ``sigmas`` before the
            loop starts.
        refine_steps: Number of final steps that run only the target branch, with
            the source velocity treated as 0. Values larger than the number of
            remaining steps are clamped, so that every step refines.
        seed: Seed for the noise drawn at each step to re-noise ``x_init``. A fresh
            generator is seeded from it on every call, so repeated runs of the same
            sampler draw the same noise.

    Returns:
        ``flowedit_sample(model, x_init, sigmas, extra_args=None, callback=None,
        disable=None)``, which returns the edited latent.
    """

    @torch.no_grad()
    def flowedit_sample(model, x_init, sigmas, extra_args=None, callback=None, disable=None):
        extra_args = {} if extra_args is None else extra_args

        # Build independent argument dicts for the source and target passes.
        # The copies are shallow, so the caller's extra_args / model_options are not
        # mutated. Every other model_options entry still reaches BOTH passes.
        model_options = extra_args.get('model_options', {})
        base_transformer_options = model_options.get('transformer_options', {})
        source_transformer_options = {**base_transformer_options, 'latent_type': 'source'}
        target_transformer_options = {**base_transformer_options, 'latent_type': 'target'}
        source_extra_args = {
            **extra_args,
            'model_options': {**model_options, 'transformer_options': source_transformer_options},
        }
        target_extra_args = {
            **extra_args,
            'model_options': {**model_options, 'transformer_options': target_transformer_options},
        }

        # Use a private generator, seeded per call. torch.manual_seed() reseeds and
        # returns the global default generator, whose state other code can reseed
        # or advance between building this sampler and running it.
        generator = torch.Generator(device='cpu').manual_seed(seed)

        sigmas = sigmas[skip_steps:]
        x_tgt = x_init.clone()
        N = len(sigmas) - 1
        # Index of the first refine-only step. It is clamped at 0 so that when
        # refine_steps > N, x_tgt is still re-noised once (at step 0) rather than
        # being passed to the model un-noised.
        refine_start = max(N - refine_steps, 0)
        s_in = x_init.new_ones([x_init.shape[0]])

        for i in trange(N, disable=disable):
            sigma = sigmas[i]
            noise = torch.randn(x_init.shape, generator=generator, device='cpu').to(x_init.device)

            # Source latent re-noised to the current sigma.
            zt_src = (1 - sigma) * x_init + sigma * noise

            if i < refine_start:
                # Shift the target by the same noise offset that was applied to the source.
                zt_tgt = x_tgt + zt_src - x_init
                # Re-assert the branch tag defensively, in case downstream code edits the dict.
                source_transformer_options['latent_type'] = 'source'
                vt_src = model(zt_src, sigma * s_in, **source_extra_args)
            else:
                if i == refine_start:
                    # Entering the refine phase: add this step's noise offset once.
                    x_tgt = x_tgt + zt_src - x_init
                zt_tgt = x_tgt
                vt_src = 0

            target_transformer_options['latent_type'] = 'target'
            vt_tgt = model(zt_tgt, sigma * s_in, **target_extra_args)

            # Euler step on the velocity difference between the two branches.
            x_tgt += (sigmas[i + 1] - sigmas[i]) * (vt_tgt - vt_src)

            if callback is not None:
                callback({'x': x_tgt, 'denoised': x_tgt, 'i': i + skip_steps, 'sigma': sigmas[i], 'sigma_hat': sigmas[i]})

        return x_tgt

    return flowedit_sample
|
|
|
|
|
class LTXFlowEditSamplerNode:
    """ComfyUI node that wraps the FlowEdit sampling loop in a ``KSAMPLER``."""

    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "build"
    CATEGORY = "ltxtricks"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare three non-negative integer inputs: skip_steps, refine_steps and seed."""
        bounds = {"min": 0, "max": 0xffffffffffffffff}
        required = {
            "skip_steps": ("INT", {"default": 4, **bounds}),
            "refine_steps": ("INT", {"default": 0, **bounds}),
            "seed": ("INT", {"default": 0, **bounds}),
        }
        return {"required": required, "optional": {}}

    def build(self, skip_steps, refine_steps, seed):
        """Return a 1-tuple holding a KSAMPLER built around the FlowEdit sample function."""
        sample_fn = get_flowedit_sample(skip_steps, refine_steps, seed)
        return (KSAMPLER(sample_fn), )
|
|