import math

from einops import rearrange
import torch
import torch.nn.functional as F

from comfy.ldm.modules.attention import optimized_attention
import comfy.model_patcher
import comfy.samplers


# Default Flux blocks to perturb: no double-stream blocks, single-stream block 0.
DEFAULT_PAG_FLUX = {'double': set(), 'single': {'0'}}


class PAGAttentionNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
            },
            "optional": {
                "attn_override": ("ATTN_OVERRIDE",),
                "rescale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "fluxtapoz/attn"

    def patch(self, model, scale, attn_override=DEFAULT_PAG_FLUX, rescale=0):
        m = model.clone()

        def pag_mask(q, extra_options, txt_size=256):
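            # Perturbed-Attention Guidance: build an additive attention bias that reduces
            # image-to-image self-attention to an identity map (each image patch attends only
            # to itself among the image patches); all other entries are left untouched.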
            # From diffusers implementation
            identity_block_size = q.shape[1] - txt_size
            # create a full mask with all entries set to 0
            seq_len = q.size(2)
            full_mask = torch.zeros((seq_len, seq_len), device=q.device, dtype=q.dtype)

            # set the attention value between image patches to -inf
            full_mask[:identity_block_size, :identity_block_size] = float("-inf")

            # set the diagonal of the attention value between image patches to 0
            full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)

            # expand the mask to match the attention weights shape
            full_mask = full_mask.unsqueeze(0).unsqueeze(0)  # Add batch and num_heads dimensions
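            # Intended as an additive bias on the pre-softmax attention scores; the two
            # singleton dims let it broadcast over batch and heads.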

            return full_mask

        def post_cfg_function(args):
            model = args["model"]

            cond_pred = args["cond_denoised"]
            uncond_pred = args["uncond_denoised"]

            len_conds = 1 if args.get('uncond', None) is None else 2 
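            # len_conds == 2 only when a real uncond prediction exists (CFG in use).
            # With scale == 0 there is nothing to perturb, so skip the extra forward pass.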
            if scale == 0:
                if len_conds == 1:
                    return cond_pred
                return uncond_pred + (cond_pred - uncond_pred)
            
            cond = args["cond"]
            sigma = args["sigma"]
            model_options = args["model_options"].copy()
            x = args["input"]
            # Hack since comfy doesn't pass in conditionals and unconditionals to cfg_function
            # and doesn't pass in cond_scale to post_cfg_function
            
            # Install the identity-attention mask on the selected Flux blocks; the copy of
            # model_options keeps the patch local to this extra forward pass.
            for block_idx in attn_override['double']:
                model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, pag_mask, "double", "mask_fn", int(block_idx))

            for block_idx in attn_override['single']:
                model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, pag_mask, "single", "mask_fn", int(block_idx))

            (pag,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
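            # pag: the conditional prediction recomputed with perturbed (identity) image attention.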

            if len_conds == 1:
                output = cond_pred + scale * (cond_pred - pag)
            else:
                output = cond_pred + (scale-1.0) * (cond_pred - uncond_pred) + scale * (cond_pred - pag)

            # Optional rescale: blend the output's standard deviation back toward cond_pred's
            # (similar in spirit to the CFG-rescale trick) to limit over-saturation.
            if rescale > 0:
                factor = cond_pred.std() / output.std()
                factor = rescale * factor + (1 - rescale)
                output = output * factor

            return output

        m.set_model_sampler_post_cfg_function(post_cfg_function)

        return (m,)
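

# Illustrative use outside the UI (assumes `model` is a MODEL already loaded by ComfyUI):
#   (patched_model,) = PAGAttentionNode().patch(model, scale=3.0, rescale=0.5)
#   # `patched_model` then goes to a sampler node; the PAG term is applied after regular CFG.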