import math

from einops import rearrange
import torch
import torch.nn.functional as F

from comfy.ldm.modules.attention import optimized_attention
import comfy.model_patcher
import comfy.samplers


# Default: perturb only transformer layer 14's self-attention.
DEFAULT_PAG_LTX = {'layers': {14}}
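
# Example (hypothetical values): targeting layers 10 and 14 instead of the
# default, via the node class defined below:
#
#   attn_override = {'layers': {10, 14}}
#   (patched,) = LTXPerturbedAttentionNode().patch(
#       model, scale=2.0, rescale=0.5, cfg=3.0, attn_override=attn_override)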


def gaussian_blur_2d(img, kernel_size, sigma):
    # Gaussian blur over the last two dims of a (..., C, H, W) tensor.
    # Clamp the kernel so it never exceeds the spatial extent and stays odd.
    height = img.shape[-1]
    kernel_size = min(kernel_size, height - (height % 2 - 1))
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)

    # Normalized 1D Gaussian; the outer product yields the 2D kernel.
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))

    x_kernel = pdf / pdf.sum()
    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)

    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
    # One kernel per channel, applied as a depthwise (grouped) convolution.
    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])

    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]

    # Reflect-pad so the output keeps the input's spatial size.
    img = F.pad(img, padding, mode="reflect")
    img = F.conv2d(img, kernel2d, groups=img.shape[-3])

    return img
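
# Quick shape sanity check (hypothetical sizes, mirroring how seg_fn below
# calls this helper on a (batch, channels, w, h) layout):
#
#   x = torch.randn(1, 128, 32, 32)
#   y = gaussian_blur_2d(x, kernel_size=9, sigma=1.5)
#   assert y.shape == x.shape  # reflect padding preserves spatial dims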



class LTXPerturbedAttentionNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "scale": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
                "rescale": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
                "cfg": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
            },
            "optional": {
                "attn_override": ("ATTN_OVERRIDE",),
                # "attn_type": (["PAG", "SEG"],),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "ltxtricks/attn"

    def patch(self, model, scale, rescale, cfg, attn_override=DEFAULT_PAG_LTX, attn_type="PAG"):
        m = model.clone()

        # PAG: perturb by replacing self-attention with an identity map,
        # i.e. return V directly.
        def pag_fn(q, k, v, heads, attn_precision=None, transformer_options=None):
            return v

        # SEG: perturb by blurring the queries spatially before attention;
        # sigma is tied to the guidance scale.
        def seg_fn(q, k, v, heads, attn_precision=None, transformer_options=None):
            b, c, f, h, w = transformer_options['original_shape']

            # Unflatten the token sequence back onto its spatial grid.
            q = rearrange(q, 'b (f h w) d -> b (f d) w h', h=h, w=w)
            # Smallest odd kernel size covering roughly +/- 3 sigma.
            kernel_size = math.ceil(6 * scale) + 1 - math.ceil(6 * scale) % 2
            q = gaussian_blur_2d(q, kernel_size, scale)
            q = rearrange(q, 'b (f d) w h -> b (f h w) d', f=f)
            return optimized_attention(q, k, v, heads, attn_precision=attn_precision)

        def post_cfg_function(args):
            model = args["model"]

            cond_pred = args["cond_denoised"]
            uncond_pred = args["uncond_denoised"]

            # Only the cond batch ran if the sampler skipped the uncond pass.
            len_conds = 1 if args.get('uncond', None) is None else 2

            cond = args["cond"]
            sigma = args["sigma"]
            model_options = args["model_options"].copy()
            x = args["input"]

            # Without a perturbation weight there is nothing extra to compute;
            # return the plain CFG combination.
            if scale == 0:
                if len_conds == 1:
                    return cond_pred
                return uncond_pred + cfg * (cond_pred - uncond_pred)

            # Swap the chosen layers' self-attention for the perturbed variant,
            # then run one extra forward pass to get the perturbed prediction.
            attn_fn = pag_fn if attn_type == 'PAG' else seg_fn
            for block_idx in attn_override['layers']:
                model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, attn_fn, "layer", "self_attn", int(block_idx))

            (perturbed,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)

            # Standard CFG plus the perturbed-attention guidance term; skip
            # the CFG term when no negative prompt was run.
            if len_conds == 1:
                output = cond_pred + scale * (cond_pred - perturbed)
            else:
                output = uncond_pred + cfg * (cond_pred - uncond_pred) \
                    + scale * (cond_pred - perturbed)

            # CFG-rescale: pull the output's std back toward the conditional
            # prediction's std to reduce over-saturation at high scales.
            if rescale > 0:
                factor = cond_pred.std() / output.std()
                factor = rescale * factor + (1 - rescale)
                output = output * factor

            return output

        m.set_model_sampler_post_cfg_function(post_cfg_function)

        return (m,)
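
# Minimal registration sketch (assumption: this module is loaded directly as a
# ComfyUI custom-node package and nodes are not registered in a separate file;
# the mapping key and display name are illustrative, not canonical):
#
# NODE_CLASS_MAPPINGS = {
#     "LTXPerturbedAttention": LTXPerturbedAttentionNode,
# }
# NODE_DISPLAY_NAME_MAPPINGS = {
#     "LTXPerturbedAttention": "LTX Perturbed Attention",
# }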