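# Provenance (Hugging Face page residue, kept as a comment so the module parses):
# uploaded by jaxmetaverse via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 82ea528 (verified).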
import torch
from tqdm import trange
from comfy.samplers import KSAMPLER
def get_sample_forward(attn_bank, save_steps, single_layers, double_layers):
    """Build the forward (inversion) sampler used for RF-Edit style editing.

    The returned function has the sampler signature that ``KSAMPLER`` wraps.
    Each step calls the model twice: once at ``sigma`` and once at the
    half-step ``sigma_mid``. It combines both predictions into a second-order
    update of ``x``. For the steps selected by ``save_steps``, the model calls
    are tagged with ``process='forward'`` in ``transformer_options['rfedit']``,
    and an empty slot is prepared in ``attn_bank``. The model-side code that
    reads ``transformer_options['rfedit']`` decides what actually gets stored.
    That code is not part of this function.

    Args:
        attn_bank: Dict shared with the reverse sampler. It is cleared at the
            start of every run. Afterwards ``attn_bank['save_steps']`` holds
            ``save_steps``, and ``attn_bank[step]`` is
            ``{'first': {}, 'mid': {}}`` for every saved step.
        save_steps: Number of steps to save. Steps are numbered from the end
            of the schedule (``step = N - i - 1``), so the last
            ``save_steps`` iterations are the ones saved.
        single_layers: Layer-selection dict, passed through to the model in
            ``transformer_options['rfedit']``.
        double_layers: Layer-selection dict, passed through to the model in
            ``transformer_options['rfedit']``.

    Returns:
        The sampler function ``sample_forward``.
    """
    @torch.no_grad()
    def sample_forward(model, x, sigmas, extra_args=None, callback=None, disable=None):
        attn_bank.clear()
        attn_bank['save_steps'] = save_steps

        # Copy every dict level that is written to below. The caller's
        # extra_args, model_options and transformer_options stay untouched.
        extra_args = {} if extra_args is None else {**extra_args}
        model_options = {**extra_args.get('model_options', {})}
        transformer_options = {**model_options.get('transformer_options', {})}
        model_options['transformer_options'] = transformer_options
        extra_args['model_options'] = model_options

        def rfedit_options(step, process, pred_name):
            # A fresh dict per model call; key order matches the original layout.
            return {
                'step': step,
                'process': process,
                'pred': pred_name,
                'bank': attn_bank,
                'single_layers': single_layers,
                'double_layers': double_layers,
            }

        N = len(sigmas) - 1
        s_in = x.new_ones([x.shape[0]])
        for i in trange(N, disable=disable):
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]
            dt = sigma_next - sigma

            # Steps are numbered from the end of the schedule. The reverse
            # sampler numbers its steps from the start, presumably so the two
            # line up when both use the same number of steps.
            step = N - i - 1
            saving = step < save_steps
            if saving:
                attn_bank[step] = {'first': {}, 'mid': {}}
            process = 'forward' if saving else None

            # First evaluation at the current sigma.
            transformer_options['rfedit'] = rfedit_options(step, process, 'first')
            pred = model(x, s_in * sigma, **extra_args)

            # Second evaluation at the half-step.
            transformer_options['rfedit'] = rfedit_options(step, process, 'mid')
            img_mid = x + dt / 2 * pred
            sigma_mid = sigma + dt / 2
            pred_mid = model(img_mid, s_in * sigma_mid, **extra_args)

            # Finite-difference estimate of how pred changes over the half-step,
            # used as the second-order correction term.
            first_order = (pred_mid - pred) / (dt / 2)
            x = x + dt * pred + 0.5 * dt ** 2 * first_order
            if callback is not None:
                callback({'x': x, 'denoised': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i]})
        return x
    return sample_forward
def get_sample_reverse(attn_bank, inject_steps, single_layers, double_layers):
    """Build the reverse (editing) sampler that consumes a filled attention bank.

    The returned function has the sampler signature that ``KSAMPLER`` wraps.
    It uses the same two-evaluation, second-order update as the forward
    sampler. For steps ``i < inject_steps``, the model calls are tagged with
    ``process='reverse'`` in ``transformer_options['rfedit']``. Otherwise
    ``process`` is ``None``. What the model does with the bank is decided by
    the model-side code reading ``transformer_options['rfedit']``.

    Args:
        attn_bank: Bank produced by the forward sampler. It must contain
            ``'save_steps'``.
        inject_steps: Number of leading reverse steps to tag for injection.
            Must not exceed ``attn_bank['save_steps']``.
        single_layers: Layer-selection dict, passed through to the model.
        double_layers: Layer-selection dict, passed through to the model.

    Returns:
        The sampler function ``sample_reverse``.

    Raises (when the sampler is called):
        ValueError: If the bank was never filled by a forward run, or if
            ``inject_steps`` exceeds the number of saved steps.
    """
    @torch.no_grad()
    def sample_reverse(model, x, sigmas, extra_args=None, callback=None, disable=None):
        save_steps = attn_bank.get('save_steps')
        if save_steps is None:
            # Previously this surfaced as a bare KeyError('save_steps').
            raise ValueError(
                'The attention bank has not been filled: run the forward '
                'sampler on the same bank before the reverse sampler.'
            )
        if inject_steps > save_steps:
            raise ValueError(f'You must save at least as many steps as you want to inject. save_steps: {save_steps}, inject_steps: {inject_steps}')

        # Copy every dict level that is written to below. The caller's
        # extra_args, model_options and transformer_options stay untouched.
        extra_args = {} if extra_args is None else {**extra_args}
        model_options = {**extra_args.get('model_options', {})}
        transformer_options = {**model_options.get('transformer_options', {})}
        model_options['transformer_options'] = transformer_options
        extra_args['model_options'] = model_options

        def rfedit_options(step, process, pred_name):
            # A fresh dict per model call; key order matches the original layout.
            return {
                'step': step,
                'process': process,
                'pred': pred_name,
                'bank': attn_bank,
                'single_layers': single_layers,
                'double_layers': double_layers,
            }

        N = len(sigmas) - 1
        s_in = x.new_ones([x.shape[0]])
        for i in trange(N, disable=disable):
            sigma = sigmas[i]
            sigma_prev = sigmas[i + 1]
            dt = sigma_prev - sigma
            process = 'reverse' if i < inject_steps else None

            # First evaluation at the current sigma.
            transformer_options['rfedit'] = rfedit_options(i, process, 'first')
            pred = model(x, s_in * sigma, **extra_args)

            # Second evaluation at the half-step.
            transformer_options['rfedit'] = rfedit_options(i, process, 'mid')
            img_mid = x + dt / 2 * pred
            sigma_mid = sigma + dt / 2
            pred_mid = model(img_mid, s_in * sigma_mid, **extra_args)

            # Finite-difference second-order correction.
            first_order = (pred_mid - pred) / (dt / 2)
            x = x + dt * pred + 0.5 * dt ** 2 * first_order
            if callback is not None:
                callback({'x': x, 'denoised': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i]})
        return x
    return sample_reverse
# Layer selections used when no override node is connected. Keys are the
# block index as a string, matching the input names of the override nodes.
# Single blocks 20-37 default to True; all 19 double blocks default to False.
# Comprehensions are used so no loop variable leaks into the module namespace.
DEFAULT_SINGLE_LAYERS = {f'{idx}': idx > 19 for idx in range(38)}
DEFAULT_DOUBLE_LAYERS = {f'{idx}': False for idx in range(19)}
class FlowEditForwardSamplerNode:
    """ComfyUI node that builds the forward (inversion) RF-Edit sampler.

    ``build`` returns the sampler together with the attention-bank dict it
    fills while sampling. The bank can then be wired into a reverse sampler
    node as ``ATTN_INJ``.
    """
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "save_steps": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff }),
                },
                "optional": {
                    "single_layers": ("SINGLE_LAYERS",),
                    "double_layers": ("DOUBLE_LAYERS",)
                }}
    RETURN_TYPES = ("SAMPLER","ATTN_INJ")
    FUNCTION = "build"
    CATEGORY = "fluxtapoz"

    def build(self, save_steps, single_layers=None, double_layers=None):
        """Return ``(sampler, attn_bank)``.

        When no layer override is connected, the module defaults are used.
        Each sampler gets its own copy of them, so module-level state is
        never shared by reference across samplers.
        """
        if single_layers is None:
            single_layers = dict(DEFAULT_SINGLE_LAYERS)
        if double_layers is None:
            double_layers = dict(DEFAULT_DOUBLE_LAYERS)
        attn_bank = {}
        sampler = KSAMPLER(get_sample_forward(attn_bank, save_steps, single_layers, double_layers))
        return (sampler, attn_bank)
class FlowEditReverseSamplerNode:
    """ComfyUI node that builds the reverse (editing) RF-Edit sampler.

    Attention from the bank in ``attn_inj`` is tagged for injection on reverse
    steps ``i < end_step``.

    NOTE(review): ``latent_image``, ``eta`` and ``start_step`` are declared in
    ``INPUT_TYPES``, but the reverse sampler built here has no parameter for
    them, so they currently do not affect the result. Confirm whether
    start-offset injection or ``eta`` was intended.
    """
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "attn_inj": ("ATTN_INJ",),
                    "latent_image": ("LATENT",),
                    "eta": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 100.0, "step": 0.01}),
                    "start_step": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
                    "end_step": ("INT", {"default": 5, "min": 0, "max": 1000, "step": 1}),
                },
                "optional": {
                }}
    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "build"
    CATEGORY = "fluxtapoz"

    def build(self, latent_image, eta, start_step, end_step, attn_inj=None):
        """Return ``(sampler,)`` injecting from ``attn_inj`` for the first ``end_step`` steps.

        The old body referenced the undefined names ``attn_inj``,
        ``inject_steps``, ``single_layers`` and ``double_layers``. Its
        signature also rejected the required ``attn_inj`` input. ``attn_inj``
        is now accepted as a keyword, added last to stay backward-compatible
        with positional callers.

        Raises:
            ValueError: If ``attn_inj`` is not supplied.
        """
        if attn_inj is None:
            raise ValueError('attn_inj is required: connect the ATTN_INJ output of the forward sampler node.')
        single_layers = dict(DEFAULT_SINGLE_LAYERS)
        double_layers = dict(DEFAULT_DOUBLE_LAYERS)
        sampler = KSAMPLER(get_sample_reverse(attn_inj, end_step, single_layers, double_layers))
        return (sampler, )
def get_sample_reverse2(attn_bank, inject_steps, single_layers, double_layers):
    """Build a reverse (editing) sampler; identical in behavior to ``get_sample_reverse``.

    This function's body was a verbatim copy of ``get_sample_reverse``. It now
    delegates to it so the two implementations cannot drift apart. The
    returned sampler and the arguments it receives are exactly those of
    ``get_sample_reverse``.

    Args:
        attn_bank: Bank produced by the forward sampler.
        inject_steps: Number of leading reverse steps to tag for injection.
        single_layers: Layer-selection dict, passed through to the model.
        double_layers: Layer-selection dict, passed through to the model.

    Returns:
        The sampler function built by ``get_sample_reverse``.
    """
    return get_sample_reverse(attn_bank, inject_steps, single_layers, double_layers)
class FlowEdit2ReverseSamplerNode:
    """ComfyUI node that builds a reverse RF-Edit sampler from an attention bank.

    ``inject_steps`` leading reverse steps are tagged for injection from
    ``attn_inj``. Optional layer overrides replace the module defaults.
    """
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "attn_inj": ("ATTN_INJ",),
                    "inject_steps": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
                },
                "optional": {
                    "single_layers": ("SINGLE_LAYERS",),
                    "double_layers": ("DOUBLE_LAYERS",)
                }}
    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "build"
    CATEGORY = "ltxtricks"

    def build(self, attn_inj, inject_steps, single_layers=None, double_layers=None):
        """Return ``(sampler,)``.

        Unconnected layer overrides fall back to fresh copies of the module
        defaults. This avoids sharing a mutable default dict across samplers.
        """
        if single_layers is None:
            single_layers = dict(DEFAULT_SINGLE_LAYERS)
        if double_layers is None:
            double_layers = dict(DEFAULT_DOUBLE_LAYERS)
        sampler = KSAMPLER(get_sample_reverse(attn_inj, inject_steps, single_layers, double_layers))
        return (sampler, )
class PrepareAttnBankNode:
    """Pass-through node: hands back its ``latent`` and ``attn_inj`` inputs unchanged.

    It exists as a hack to force the order of operations in the ComfyUI graph.
    It performs no computation of its own.
    """
    @classmethod
    def INPUT_TYPES(s):
        required_inputs = dict(latent=("LATENT",), attn_inj=("ATTN_INJ",))
        return dict(required=required_inputs)

    RETURN_TYPES = ("LATENT", "ATTN_INJ")
    FUNCTION = "prepare"
    CATEGORY = "ltxtricks"

    def prepare(self, latent, attn_inj):
        # Return the same objects, in input order.
        passthrough = (latent, attn_inj)
        return passthrough
class RFSingleBlocksOverrideNode:
    """ComfyUI node with one BOOLEAN toggle per single block, indices 0-37.

    Blocks with index above 19 default to True. ``build`` returns the keyword
    inputs it receives, keyed by block-index string, as a ``SINGLE_LAYERS``
    value. Positional arguments are ignored.
    """
    @classmethod
    def INPUT_TYPES(s):
        toggles = {str(block): ("BOOLEAN", {"default": block > 19}) for block in range(38)}
        return {"required": toggles}

    RETURN_TYPES = ("SINGLE_LAYERS",)
    FUNCTION = "build"
    CATEGORY = "ltxtricks"

    def build(self, *args, **kwargs):
        selection = kwargs
        return (selection,)
class RFDoubleBlocksOverrideNode:
    """ComfyUI node with one BOOLEAN toggle per double block, indices 0-18.

    Every toggle defaults to False. ``build`` returns the keyword inputs it
    receives, keyed by block-index string, as a ``DOUBLE_LAYERS`` value.
    Positional arguments are ignored.
    """
    @classmethod
    def INPUT_TYPES(s):
        toggles = {str(block): ("BOOLEAN", {"default": False}) for block in range(19)}
        return {"required": toggles}

    RETURN_TYPES = ("DOUBLE_LAYERS",)
    FUNCTION = "build"
    CATEGORY = "ltxtricks"

    def build(self, *args, **kwargs):
        selection = kwargs
        return (selection,)