File size: 15,148 Bytes
6fecfbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
import contextlib
import os

import comfy
import comfy.model_management
import comfy.utils
import folder_paths
from folder_paths import folder_names_and_paths
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as TT

from ..utils.resampler import Resampler


# Register an "ipadapter" model folder (<models_dir>/ipadapter, *.bin files) with
# ComfyUI's folder_paths so the loader node can list and resolve checkpoints there.
# NOTE: this assignment replaces any "ipadapter" entry registered earlier.
model_path = folder_paths.models_dir
folder_names_and_paths["ipadapter"] = ([os.path.join(model_path, "ipadapter")], ['.bin'])

# attention_channels
# Output width of every to_k / to_v projection built by To_KV, listed in the same
# order the checkpoint's "ip_adapter" tensors are assigned to them. Entries come in
# (key, value) pairs, one pair per patched cross-attention layer, in the order
# apply_ipadapter patches layers (inputs, outputs, then middle):
#   - SD_V12_CHANNELS (cross_attention_dim != 2048): 16 layers -> 32 entries.
#   - SD_XL_CHANNELS  (cross_attention_dim == 2048): 70 layers -> 140 entries.
SD_V12_CHANNELS = [320] * 4 + [640] * 4 + [1280] * 4 + [1280] * 6 + [640] * 6 + [320] * 6 + [1280] * 2
SD_XL_CHANNELS = [640] * 8 + [1280] * 40 + [1280] * 60 + [640] * 12 + [1280] * 20

def get_filename_list(path):
    """Return the names of entries directly inside ``path`` that end in ``.bin``.

    The listing is non-recursive and keeps ``os.listdir`` order.
    """
    return list(filter(lambda entry: entry.endswith('.bin'), os.listdir(path)))

class ImageProjModel(nn.Module):
    """Project a CLIP image embedding into a small set of extra context tokens.

    A single linear layer maps each ``clip_embeddings_dim`` vector to
    ``clip_extra_context_tokens * cross_attention_dim`` values. These are reshaped
    into ``(-1, clip_extra_context_tokens, cross_attention_dim)`` and layer-normalized.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        flat_width = clip_extra_context_tokens * cross_attention_dim
        self.proj = nn.Linear(clip_embeddings_dim, flat_width)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return normalized tokens of shape (-1, clip_extra_context_tokens, cross_attention_dim)."""
        token_shape = (-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        tokens = self.proj(image_embeds).reshape(token_shape)
        return self.norm(tokens)

class To_KV(nn.Module):
    """Bank of bias-free linear key/value projections for the IPAdapter image tokens.

    One ``nn.Linear(cross_attention_dim, width)`` is created per entry of the channel
    table. ``SD_XL_CHANNELS`` is used when ``cross_attention_dim == 2048``; otherwise
    ``SD_V12_CHANNELS`` is used.
    """

    def __init__(self, cross_attention_dim):
        super().__init__()

        if cross_attention_dim == 2048:
            widths = SD_XL_CHANNELS
        else:
            widths = SD_V12_CHANNELS
        layers = [nn.Linear(cross_attention_dim, width, bias=False) for width in widths]
        self.to_kvs = nn.ModuleList(layers)

    def load_state_dict(self, state_dict):
        """Assign the tensors of ``state_dict`` to ``to_kvs`` positionally.

        The n-th tensor in the mapping's iteration order becomes ``to_kvs[n].weight``.
        Keys are not interpreted, so the mapping's order must match the layer order.
        This overrides ``nn.Module.load_state_dict`` with a simpler signature
        (no ``strict``). More tensors than layers raises IndexError.
        """
        for position, tensor in enumerate(state_dict.values()):
            self.to_kvs[position].weight.data = tensor

def set_model_patch_replace(model, patch_kwargs, key):
    """Register or extend the attn2 replacement patch for one attention block.

    The patch is stored at
    ``model.model_options["transformer_options"]["patches_replace"]["attn2"][key]``.
    Missing intermediate dicts are created. The first time ``key`` is seen, a new
    ``CrossAttentionPatch(**patch_kwargs)`` is stored. Later calls pass
    ``patch_kwargs`` to the existing patch's ``set_new_condition`` instead.
    """
    transformer_options = model.model_options["transformer_options"]
    attn2_patches = transformer_options.setdefault("patches_replace", {}).setdefault("attn2", {})
    if key in attn2_patches:
        attn2_patches[key].set_new_condition(**patch_kwargs)
    else:
        attn2_patches[key] = CrossAttentionPatch(**patch_kwargs)

def attention(q, k, v, extra_options):
    """Multi-head scaled dot-product attention over pre-projected q, k, v.

    Args:
        q: (batch, query_len, n_heads * dim_head) tensor.
        k, v: (batch, kv_len, n_heads * dim_head) tensors; kv_len may differ from
            query_len.
        extra_options: mapping providing "n_heads" and "dim_head".

    Returns:
        (batch, query_len, n_heads * dim_head) tensor.

    Uses ``F.scaled_dot_product_attention`` when this torch build provides it
    (torch >= 2.0). Otherwise it computes the same result with explicit
    matmul/softmax, so no extra dependency is needed.
    """
    n_heads = extra_options["n_heads"]
    dim_head = extra_options["dim_head"]
    b = q.shape[0]

    # Split the packed head dimension: (b, len, h*d) -> (b, h, len, d).
    q, k, v = (t.reshape(b, -1, n_heads, dim_head).transpose(1, 2) for t in (q, k, v))

    if hasattr(F, "scaled_dot_product_attention"):
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    else:
        scores = torch.matmul(q, k.transpose(-2, -1)) * (dim_head ** -0.5)
        out = torch.matmul(scores.softmax(dim=-1), v)

    # Merge heads back: (b, h, len, d) -> (b, len, h*d).
    return out.transpose(1, 2).reshape(b, -1, n_heads * dim_head)

# TODO: still have to find the best way to add noise to the uncond image
def image_add_noise(image, noise):
    """Build a scrambled, noisy copy of ``image`` to encode as the negative image.

    Steps, in order:
      1. Center-crop to a square and resize to 224x224 (bicubic, antialiased).
      2. Elastically distort, with ``sigma = noise * 3.5``.
      3. Flip vertically and horizontally.
      4. Add Gaussian noise with std ``0.25 * (1 - noise) + 0.05``.

    All random ops use a fixed seed (0) for reproducible results. The seed is applied
    inside ``torch.random.fork_rng`` so the caller's global RNG state is restored
    afterwards instead of being reset.

    Args:
        image: 4-D image batch, presumably channels-last (B, H, W, C) like ComfyUI
            IMAGE tensors — TODO confirm. It is permuted to channels-first for
            torchvision.
        noise: noise strength; callers in this file only call this with ``noise > 0``.

    Returns:
        A CPU tensor permuted back to the input's channels-last layout.
    """
    image = image.permute([0, 3, 1, 2])  # channels-first for torchvision transforms
    # GaussianBlur and RandomSolarize were tried and disabled. Blur in the negative
    # image gave sharper results; solarize added color aberration so the negative
    # image would not repeat the same colors.
    transforms = TT.Compose([
        TT.CenterCrop(min(image.shape[2], image.shape[3])),
        TT.Resize((224, 224), interpolation=TT.InterpolationMode.BICUBIC, antialias=True),
        TT.ElasticTransform(alpha=75.0, sigma=noise * 3.5),  # shuffle the image
        TT.RandomVerticalFlip(p=1.0),  # flip the image to change the geometry even more
        TT.RandomHorizontalFlip(p=1.0),
    ])
    # Every random op below runs on CPU tensors (image.cpu()), so forking and seeding
    # only the CPU generator yields the same values torch.manual_seed(0) would,
    # without touching the caller's CPU or accelerator RNG state.
    with torch.random.fork_rng(devices=[]):
        torch.default_generator.manual_seed(0)
        image = transforms(image.cpu())
        image = image.permute([0, 2, 3, 1])
        image = image + ((0.25 * (1 - noise) + 0.05) * torch.randn_like(image))  # add random noise
    return image

def zeroed_hidden_states(clip_vision):
    """Return the CLIP-vision penultimate hidden state for an all-zero input.

    The zero image is passed through ``clip_vision.processor`` only to get the shape
    of ``pixel_values``. The model then receives zeros of that shape on
    ``clip_vision.load_device``, with ``output_hidden_states=True``. When the CLIP
    model is not float32, the forward pass runs under ``torch.autocast`` with float32.

    Returns:
        ``hidden_states[-2]`` of the model output, moved to CPU.

    Raises:
        KeyError: if the model output has no "hidden_states" entry.
    """
    blank = torch.zeros([1, 3, 224, 224])
    processed = clip_vision.processor(images=blank, return_tensors="pt")
    comfy.model_management.load_model_gpu(clip_vision.patcher)
    pixel_values = torch.zeros_like(processed['pixel_values']).to(clip_vision.load_device)

    autocast_device = comfy.model_management.get_autocast_device(clip_vision.load_device)
    if clip_vision.dtype == torch.float32:
        scope = contextlib.nullcontext(autocast_device)
    else:
        scope = torch.autocast(autocast_device, torch.float32)

    with scope:
        outputs = clip_vision.model(pixel_values, output_hidden_states=True)

    # Only the penultimate layer is needed.
    return outputs['hidden_states'][-2].cpu()

class IPAdapter(nn.Module):
    """IPAdapter checkpoint wrapped as two sub-modules.

    * ``image_proj_model`` (built by ``init_proj``) turns CLIP-vision embeddings into
      ``clip_extra_context_tokens`` tokens of width ``self.cross_attention_dim``.
    * ``ip_layers`` (a ``To_KV``) holds the per-layer key/value projections that the
      attention patches apply to those tokens.

    Args:
        ipadapter_model: checkpoint mapping with "image_proj" and "ip_adapter" entries.
        cross_attention_dim: forwarded to ``To_KV`` to choose its channel table. The
            projection width stored on ``self`` is read from the checkpoint instead.
        clip_embeddings_dim: width of the CLIP embedding fed to the projection.
        clip_extra_context_tokens: number of context tokens produced per image.
    """

    def __init__(self, ipadapter_model, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        ip_weights = ipadapter_model["ip_adapter"]
        self.clip_embeddings_dim = clip_embeddings_dim
        # Input width of the first key projection == the model's cross-attention width.
        self.cross_attention_dim = ip_weights["1.to_k_ip.weight"].shape[1]
        self.clip_extra_context_tokens = clip_extra_context_tokens

        projection = self.init_proj()
        self.image_proj_model = projection
        projection.load_state_dict(ipadapter_model["image_proj"])

        self.ip_layers = To_KV(cross_attention_dim)
        self.ip_layers.load_state_dict(ip_weights)

    def init_proj(self):
        """Build the image projection; subclasses override this to swap the architecture."""
        return ImageProjModel(
            cross_attention_dim=self.cross_attention_dim,
            clip_embeddings_dim=self.clip_embeddings_dim,
            clip_extra_context_tokens=self.clip_extra_context_tokens,
        )

    @torch.inference_mode()
    def get_image_embeds(self, clip_embed, clip_embed_zeroed):
        """Project the conditional and unconditional CLIP embeddings; returns (cond, uncond)."""
        project = self.image_proj_model
        return project(clip_embed), project(clip_embed_zeroed)

class IPAdapterPlus(IPAdapter):
    """IPAdapter variant whose image projection is a ``Resampler``.

    Only ``init_proj`` differs from ``IPAdapter``. The inherited constructor still
    loads the checkpoint weights into the module returned here.
    """

    def init_proj(self):
        """Build a Resampler sized from the attributes set in ``IPAdapter.__init__``."""
        width = self.cross_attention_dim
        return Resampler(
            dim=width,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.clip_extra_context_tokens,
            embedding_dim=self.clip_embeddings_dim,
            output_dim=width,
            ff_mult=4,
        )

class CrossAttentionPatch:
    """Callable replacement for one cross-attention (attn2) layer that adds image attention.

    Each ``set_new_condition`` call stacks another (weight, ipadapter, cond, uncond,
    mask) set. On every call, regular attention over ``context_attn2`` /
    ``value_attn2`` is computed once. For each stacked set, the weighted attention
    over that set's projected image tokens is then added.
    """

    # forward for patching
    def __init__(self, weight, ipadapter, dtype, number, cond, uncond, mask=None):
        # Parallel per-condition lists; index i of each belongs to the same condition.
        self.weights = [weight]
        self.ipadapters = [ipadapter]
        self.conds = [cond]
        self.unconds = [uncond]
        self.masks = [mask]  # stored alongside each condition; not read by __call__
        # dtype used for the torch.autocast region in __call__.
        self.dtype = dtype
        # Layer index: keys use to_kvs[2 * number], values use to_kvs[2 * number + 1].
        self.number = number

    def set_new_condition(self, weight, ipadapter, cond, uncond, dtype, number, mask=None):
        """Stack another image condition on this layer.

        ``dtype`` replaces the current autocast dtype. ``number`` is accepted for
        signature compatibility but ignored; the layer index from construction is kept.
        """
        appended = (
            (self.weights, weight),
            (self.ipadapters, ipadapter),
            (self.conds, cond),
            (self.unconds, uncond),
            (self.masks, mask),
        )
        for bucket, value in appended:
            bucket.append(value)
        self.dtype = dtype

    def __call__(self, n, context_attn2, value_attn2, extra_options):
        """Return text attention plus the weighted image attention of every stacked condition.

        ``n``, ``context_attn2`` and ``value_attn2`` are used as q, k and v. ``n`` must
        be 3-D. The result is cast back to ``n``'s original dtype.
        """
        restore_dtype = n.dtype
        with torch.autocast("cuda", dtype=self.dtype):
            batch, _, _ = n.shape
            half_batch = batch // 2
            key_layer = self.number * 2
            value_layer = key_layer + 1

            result = attention(n, context_attn2, value_attn2, extra_options)

            stacked = zip(self.weights, self.conds, self.unconds, self.ipadapters, self.masks)
            for weight, cond, uncond, ipadapter, _mask in stacked:
                # NOTE(review): assumes the batch is [uncond half, cond half] — confirm
                # against the sampler's batch layout.
                image_context = torch.cat(
                    [uncond.repeat(half_batch, 1, 1), cond.repeat(half_batch, 1, 1)], dim=0
                )
                projections = ipadapter.ip_layers.to_kvs
                ip_k = projections[key_layer](image_context)
                ip_v = projections[value_layer](image_context)

                result = result + attention(n, ip_k, ip_v, extra_options) * weight

        return result.to(dtype=restore_dtype)

class IPAdapterModelLoader:
    """ComfyUI node that loads an IPAdapter checkpoint from the "ipadapter" model folder."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "ipadapter_file": (folder_paths.get_filename_list("ipadapter"),),
            }
        }

    RETURN_TYPES = ("IPADAPTER",)
    FUNCTION = "load_ipadapter_model"

    CATEGORY = "Vyro/IPAdapter"

    def load_ipadapter_model(self, ipadapter_file):
        """Load ``ipadapter_file`` and return the checkpoint as a one-element tuple.

        The checkpoint must contain both an "ip_adapter" entry (key/value projection
        weights) and an "image_proj" entry (image projection weights). IPAdapterApply
        reads both, so a file missing either is rejected here. Otherwise it would only
        fail later, inside IPAdapterApply's catch-all handler.

        Raises:
            Exception: if either required entry is missing.
        """
        ckpt_path = folder_paths.get_full_path("ipadapter", ipadapter_file)

        model = comfy.utils.load_torch_file(ckpt_path, safe_load=True)

        if "ip_adapter" not in model or "image_proj" not in model:
            raise Exception("invalid IPAdapter model {}".format(ckpt_path))

        return (model,)

class IPAdapterApply:
    """ComfyUI node that adds IPAdapter image conditioning to a model's cross-attention.

    Each image in the batch is embedded with CLIP vision and projected by the
    IPAdapter. The result is registered as extra key/value context on every
    cross-attention (attn2) layer of a clone of the input model.
    """

    def __init__(self) -> None:
        # IPAdapter module built by the most recent apply_ipadapter call (None before).
        self.ipadapter = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ipadapter": ("IPADAPTER", ),
                "clip_vision": ("CLIP_VISION",),
                "image": ("IMAGE",),
                "model": ("MODEL", ),
                "weight": ("FLOAT", { "default": 1.0, "min": -1, "max": 3, "step": 0.05 }),
                "noise": ("FLOAT", { "default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01 }),
                "weights_per_image": ("STRING", {"default": "1.0"}),
            },
        }


    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_ipadapter"
    CATEGORY = "Vyro/IPAdapter"

    # Not read by this class; kept as a public class attribute.
    uncond_hidden_states = None

    @staticmethod
    def _parse_weights(weights_per_image, count):
        """Parse comma-separated per-image weights, padded to at least ``count`` entries.

        Blank entries are ignored. If fewer weights than images are given, the last
        weight is repeated, so the default "1.0" works for any batch size. If no
        weights are given at all, 1.0 is used.
        """
        weights = [float(part) for part in weights_per_image.split(",") if part.strip()]
        if not weights:
            weights = [1.0]
        if len(weights) < count:
            weights += [weights[-1]] * (count - len(weights))
        return weights

    @staticmethod
    def _attn2_keys(is_sdxl):
        """Return the attn2 patch keys in the order their to_k/to_v weights are stored.

        Position i in the returned list is the patch ``number`` for that key. It
        selects ``to_kvs[2 * i]`` / ``to_kvs[2 * i + 1]``.
        """
        if not is_sdxl:
            keys = [("input", block) for block in [1, 2, 4, 5, 7, 8]]  # input blocks with cross attention
            keys += [("output", block) for block in range(3, 12)]  # output blocks with cross attention
            keys.append(("middle", 0))
            return keys

        keys = []
        for block in [4, 5, 7, 8]:  # input blocks with cross attention
            depth = 2 if block in [4, 5] else 10  # transformer_depth
            keys += [("input", block, idx) for idx in range(depth)]
        for block in range(6):  # output blocks with cross attention
            depth = 2 if block in [3, 4, 5] else 10  # transformer_depth
            keys += [("output", block, idx) for idx in range(depth)]
        keys += [("middle", 0, idx) for idx in range(10)]
        return keys

    def _encode_image(self, clip_vision, single_image, noise):
        """Return ``(cond_embed, uncond_embed, clip_extra_context_tokens)`` for one image.

        The unconditional embedding comes from a noised copy of the image when
        ``noise > 0``. Otherwise it is a zero-input hidden state (plus models) or
        zeros (standard models).
        """
        clip_output = clip_vision.encode_image(single_image)
        neg_image = image_add_noise(single_image, noise) if noise > 0 else None

        if self.is_plus:
            # NOTE(review): the zero-input path (zeroed_hidden_states) returns the
            # penultimate hidden state, while this path uses last_hidden_state.
            # Confirm which one the Plus checkpoints expect.
            cond = clip_output.last_hidden_state
            if noise > 0:
                uncond = clip_vision.encode_image(neg_image).last_hidden_state
            else:
                uncond = zeroed_hidden_states(clip_vision)
            return cond, uncond, 16

        cond = clip_output.image_embeds
        if noise > 0:
            uncond = clip_vision.encode_image(neg_image).image_embeds
        else:
            uncond = torch.zeros_like(cond)
        return cond, uncond, 4

    def apply_ipadapter(self, ipadapter, clip_vision, image, model, weight, noise, weights_per_image):
        """Return a model clone with IPAdapter attn2 patches for every image in ``image``.

        Args:
            ipadapter: checkpoint mapping produced by IPAdapterModelLoader.
            clip_vision: CLIP vision model used to embed the images.
            image: image batch; split along dim 0 and processed one image at a time.
            model: model to patch; patches go onto ``model.clone()``.
            weight: global strength, multiplied into every per-image weight.
            noise: if > 0, the unconditional embedding comes from a noised copy of
                the image instead of a zero input.
            weights_per_image: comma-separated per-image weights (see _parse_weights).

        Returns:
            ``(patched_model,)``. On any exception, the error and traceback are printed
            and ``(model,)`` is returned unchanged, so the workflow keeps running.
        """
        try:
            self.dtype = model.model.diffusion_model.dtype
            self.device = comfy.model_management.get_torch_device()
            self.weight = weight
            self.is_plus = "latents" in ipadapter["image_proj"]

            cross_attention_dim = ipadapter["ip_adapter"]["1.to_k_ip.weight"].shape[1]
            self.is_sdxl = cross_attention_dim == 2048
            work_model = model.clone()

            images = torch.split(image, 1, dim=0)
            image_weights = self._parse_weights(weights_per_image, len(images))
            patch_keys = self._attn2_keys(self.is_sdxl)

            # Build the adapter fresh for each call, lazily once the CLIP embedding width
            # is known. Reusing a module cached from an earlier call would ignore a
            # changed `ipadapter` input.
            ipa_module = None
            for single_image, image_weight in zip(images, image_weights):
                clip_embed, clip_embed_zeroed, extra_tokens = self._encode_image(clip_vision, single_image, noise)

                if ipa_module is None:
                    ipa_class = IPAdapterPlus if self.is_plus else IPAdapter
                    ipa_module = ipa_class(
                        ipadapter,
                        cross_attention_dim=cross_attention_dim,
                        clip_embeddings_dim=clip_embed.shape[-1],
                        clip_extra_context_tokens=extra_tokens,
                    )
                    ipa_module.to(self.device, dtype=self.dtype)
                    self.ipadapter = ipa_module

                image_prompt_embeds, uncond_image_prompt_embeds = ipa_module.get_image_embeds(
                    clip_embed.to(self.device, self.dtype),
                    clip_embed_zeroed.to(self.device, self.dtype),
                )
                image_prompt_embeds = image_prompt_embeds.to(self.device, dtype=self.dtype)
                uncond_image_prompt_embeds = uncond_image_prompt_embeds.to(self.device, dtype=self.dtype)

                base_kwargs = {
                    "weight": image_weight * self.weight,
                    "ipadapter": ipa_module,
                    "dtype": self.dtype,
                    "cond": image_prompt_embeds,
                    "uncond": uncond_image_prompt_embeds,
                }
                for number, key in enumerate(patch_keys):
                    set_model_patch_replace(work_model, {**base_kwargs, "number": number}, key)

            return (work_model, )
        except Exception as e:
            # Deliberate best-effort at the node boundary: log, then pass the model through.
            import traceback
            print(f'[IPAdapterApply] {e}')
            traceback.print_exc()
            return (model, )

# Node registry, presumably read by ComfyUI's custom-node loader:
# node type name -> implementing class.
NODE_CLASS_MAPPINGS = dict(
    IPAdapterModelLoader=IPAdapterModelLoader,
    IPAdapterApply=IPAdapterApply,
)

# Human-readable titles for the node types registered above.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    IPAdapterModelLoader="Load IPAdapter Model",
    IPAdapterApply="Apply IPAdapter",
)