File size: 3,983 Bytes
82ea528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112

import math

from einops import rearrange
import torch
import torch.nn.functional as F

import comfy.model_patcher
import comfy.samplers


DEFAULT_SEG_FLUX = { 'double': set([]), 'single': set(['0'])}


def gaussian_blur_2d(img, kernel_size, sigma):
    height = img.shape[-1]
    kernel_size = min(kernel_size, height - (height % 2 - 1))
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)

    pdf = torch.exp(-0.5 * (x / sigma).pow(2))

    x_kernel = pdf / pdf.sum()
    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)

    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])

    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]

    img = F.pad(img, padding, mode="reflect")
    img = F.conv2d(img, kernel2d, groups=img.shape[-3])

    return img


class SEGAttentionNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
                "blur": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 999.0, "step": 0.01, "round": 0.01}),
                "inf_blur": ("BOOLEAN", {"default": False} ),
            },
            "optional": {
                "attn_override": ("ATTN_OVERRIDE",)
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "fluxtapoz/attn"

    def patch(self, model, scale, blur, inf_blur, attn_override=DEFAULT_SEG_FLUX):
        m = model.clone()

        def seg_attention(q, extra_options, txt_shape=256):
            _, _, sequence_length, _ = q.shape
            shape = extra_options['original_shape']
            patch_size = extra_options['patch_size']
            oh, ow = shape[-2:]
            h = oh // patch_size
            q_img = q[:, :, txt_shape:]
            num_heads = q_img.shape[1]
            q_img = rearrange(q_img, 'b a (h w) d -> b (a d) w h', h=h)
            if not inf_blur:
                kernel_size = math.ceil(6 * blur) + 1 - math.ceil(6 * blur) % 2
                q_img = gaussian_blur_2d(q_img, kernel_size, blur)
            else:
                q_img = q_img.mean(dim=(-2, -1), keepdim=True)
            q_img = rearrange(q_img, 'b (a d) w h -> b a (h w) d', a=num_heads)
            q[:,:, txt_shape:] = q_img
            return q

        def post_cfg_function(args):
            model = args["model"]

            cond_pred = args["cond_denoised"]
            uncond_pred = args["uncond_denoised"]

            if scale == 0 or blur == 0:
                return uncond_pred + (cond_pred - uncond_pred)
            
            cond = args["cond"]
            sigma = args["sigma"]
            model_options = args["model_options"].copy()
            x = args["input"]
            # Hack since comfy doesn't pass in conditionals and unconditionals to cfg_function
            # and doesn't pass in cond_scale to post_cfg_function
            len_conds = 1 if args.get('uncond', None) is None else 2 
            
            for block_idx in attn_override['double']:
                model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, seg_attention, f"double", "post_q", int(block_idx))

            for block_idx in attn_override['single']:
                model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, seg_attention, f"single", "post_q", int(block_idx))

            (seg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)

            if len_conds == 1:
                return cond_pred + scale * (cond_pred - seg)

            return cond_pred + (scale-1.0) * (cond_pred - uncond_pred) + scale * (cond_pred - seg)

        m.set_model_sampler_post_cfg_function(post_cfg_function)

        return (m,)