import math

import torch
import torch.nn.functional as F
from einops import rearrange

import comfy.model_patcher
import comfy.samplers
from comfy.ldm.modules.attention import optimized_attention
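
# Default override: perturb self-attention in LTX transformer layer 14.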
DEFAULT_PAG_LTX = {'layers': {14}}

def gaussian_blur_2d(img, kernel_size, sigma):
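    """Depthwise 2D Gaussian blur: builds the kernel as the outer product of a
    normalized 1D Gaussian and applies it per channel with a grouped conv."""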
    # Clamp the kernel so it stays odd and no larger than the last spatial dim.
    height = img.shape[-1]
    kernel_size = min(kernel_size, height - (height % 2 - 1))
    ksize_half = (kernel_size - 1) * 0.5
    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    # Normalized 1D Gaussian; its outer product gives the 2D kernel.
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    x_kernel = pdf / pdf.sum()
    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
    # One copy of the kernel per channel for a grouped (depthwise) conv.
    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
    # Reflect-pad so the output keeps the input's spatial size.
    img = F.pad(img, padding, mode="reflect")
    img = F.conv2d(img, kernel2d, groups=img.shape[-3])
    return img

class LTXPerturbedAttentionNode:
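    """ComfyUI node that adds Perturbed-Attention Guidance (PAG) or Smoothed
    Energy Guidance (SEG) on top of CFG for LTX-Video models."""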
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"scale": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
"rescale": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
"cfg": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
},
"optional": {
"attn_override": ("ATTN_OVERRIDE",),
# "attn_type": (["PAG", "SEG"],),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "ltxtricks/attn"
def patch(self, model, scale, rescale, cfg, attn_override=DEFAULT_PAG_LTX, attn_type="PAG"):
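        """Clone the model and attach a post-CFG hook that adds the
        perturbed-attention guidance term to the sampler output."""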
m = model.clone()
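
        # PAG perturbation: skip attention entirely in the patched layers and
        # pass the value tensor through unchanged.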
        def pag_fn(q, k, v, heads, attn_precision=None, transformer_options=None):
            return v
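
        # SEG perturbation: keep attention but blur the queries with a Gaussian
        # (sigma = scale), smoothing the attention energy landscape.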
        def seg_fn(q, k, v, heads, attn_precision=None, transformer_options=None):
            b, c, f, h, w = transformer_options['original_shape']
            # Fold the flat token axis back into per-frame 2D grids.
            q = rearrange(q, 'b (f h w) d -> b (f d) w h', h=h, w=w)
            # Odd kernel size covering roughly +/-3 sigma of the Gaussian.
            kernel_size = math.ceil(6 * scale) + 1 - math.ceil(6 * scale) % 2
            q = gaussian_blur_2d(q, kernel_size, scale)
            q = rearrange(q, 'b (f d) w h -> b (f h w) d', f=f)
            return optimized_attention(q, k, v, heads, attn_precision=attn_precision)
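
        # Runs after ComfyUI's built-in CFG combine. It redoes the conditional
        # pass with the perturbed attention patched in, then mixes:
        #   output = uncond + cfg * (cond - uncond) + scale * (cond - perturbed)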
def post_cfg_function(args):
model = args["model"]
cond_pred = args["cond_denoised"]
uncond_pred = args["uncond_denoised"]
len_conds = 1 if args.get('uncond', None) is None else 2
cond = args["cond"]
sigma = args["sigma"]
model_options = args["model_options"].copy()
x = args["input"]
            if scale == 0:
                if len_conds == 1:
                    return cond_pred
                # No perturbed pass requested: fall back to plain CFG.
                return uncond_pred + cfg * (cond_pred - uncond_pred)
            attn_fn = pag_fn if attn_type == 'PAG' else seg_fn
            # Swap in the perturbed attention on the chosen self-attention
            # layers for the extra conditional pass only.
            for block_idx in attn_override['layers']:
                model_options = comfy.model_patcher.set_model_options_patch_replace(
                    model_options, attn_fn, "layer", "self_attn", int(block_idx)
                )
(perturbed,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
            # Alternative combinations (unused):
            # if len_conds == 1:
            #     output = cond_pred + scale * (cond_pred - perturbed)
            # else:
            #     output = cond_pred + (scale - 1.0) * (cond_pred - uncond_pred) + scale * (cond_pred - perturbed)
output = uncond_pred + cfg * (cond_pred - uncond_pred) \
+ scale * (cond_pred - perturbed)
            if rescale > 0:
                # Blend toward matching cond_pred's std to limit over-saturation
                # from the added guidance terms (rescale=1 matches it exactly).
                factor = cond_pred.std() / output.std()
                factor = rescale * factor + (1 - rescale)
                output = output * factor
return output
m.set_model_sampler_post_cfg_function(post_cfg_function)
return (m,)
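

# Minimal usage sketch (hypothetical; normally this node runs inside a ComfyUI
# graph). `ltx_model` stands for an already-loaded LTX-Video MODEL object:
#
#   node = LTXPerturbedAttentionNode()
#   (patched,) = node.patch(ltx_model, scale=2.0, rescale=0.5, cfg=3.0)
#   # `patched` applies PAG on self-attention layer 14 during sampling.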