import contextlib
import os
import traceback

import comfy
import comfy.model_management
import comfy.utils
import folder_paths
from folder_paths import folder_names_and_paths
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as TT

from ..utils.resampler import Resampler

model_path = folder_paths.models_dir
# Register "<models_dir>/ipadapter" as the lookup folder for ".bin" IPAdapter checkpoints.
folder_names_and_paths["ipadapter"] = ([os.path.join(model_path, "ipadapter")], ['.bin'])

# attention_channels: output width of each IP key/value projection, in the order
# the checkpoint's "ip_adapter" tensors are assigned by To_KV.load_state_dict.
SD_V12_CHANNELS = [320] * 4 + [640] * 4 + [1280] * 4 + [1280] * 6 + [640] * 6 + [320] * 6 + [1280] * 2
SD_XL_CHANNELS = [640] * 8 + [1280] * 40 + [1280] * 60 + [640] * 12 + [1280] * 20


def get_filename_list(path):
    """Return the names of the ``.bin`` files directly inside directory *path*."""
    return [f for f in os.listdir(path) if f.endswith('.bin')]


class ImageProjModel(nn.Module):
    """Project a CLIP image embedding into a few extra cross-attention context tokens.

    The embedding (last dim ``clip_embeddings_dim``) goes through a linear layer.
    The result is reshaped to ``(-1, clip_extra_context_tokens, cross_attention_dim)``
    and then layer-normalised.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        tokens = self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        return self.norm(tokens)


class To_KV(nn.Module):
    """Bias-free linear projections the IP-Adapter adds to every cross-attention layer.

    ``to_kvs[2 * n]`` is used as the key projection and ``to_kvs[2 * n + 1]`` as the
    value projection of attention block ``n`` (see ``CrossAttentionPatch``).
    """

    def __init__(self, cross_attention_dim):
        super().__init__()
        # A 2048-wide context identifies SDXL; anything else uses the SD1.x layout.
        channels = SD_XL_CHANNELS if cross_attention_dim == 2048 else SD_V12_CHANNELS
        self.to_kvs = nn.ModuleList([nn.Linear(cross_attention_dim, channel, bias=False) for channel in channels])

    def load_state_dict(self, state_dict):
        """Assign the checkpoint tensors to ``to_kvs`` positionally, in dict iteration order.

        NOTE(review): this overrides ``nn.Module.load_state_dict`` with a different
        signature and performs no shape checks. It relies on the checkpoint's key
        order matching the channel lists above; confirm for new checkpoint formats.
        """
        for i, key in enumerate(state_dict.keys()):
            self.to_kvs[i].weight.data = state_dict[key]


def set_model_patch_replace(model, patch_kwargs, key):
    """Install (or extend) the ``attn2`` replacement patch for block *key* on *model*.

    The first call for a key creates a ``CrossAttentionPatch``. Later calls for the
    same key add another condition to the existing patch.
    """
    transformer_options = model.model_options["transformer_options"]
    attn2_patches = transformer_options.setdefault("patches_replace", {}).setdefault("attn2", {})
    if key not in attn2_patches:
        attn2_patches[key] = CrossAttentionPatch(**patch_kwargs)
    else:
        attn2_patches[key].set_new_condition(**patch_kwargs)


def attention(q, k, v, extra_options):
    """Multi-head scaled dot-product attention without a mask.

    ``q``, ``k`` and ``v`` are ``(batch, tokens, n_heads * dim_head)``; ``k`` and ``v``
    may have a different token count from ``q``. The result has ``q``'s layout.
    ``n_heads`` and ``dim_head`` are read from *extra_options*.
    """
    n_heads = extra_options["n_heads"]
    dim_head = extra_options["dim_head"]
    b, _, _ = q.shape
    # (b, tokens, h*d) -> (b, h, tokens, d)
    q, k, v = (t.reshape(b, -1, n_heads, dim_head).transpose(1, 2) for t in (q, k, v))
    if hasattr(F, "scaled_dot_product_attention"):
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    else:
        # Fallback for torch builds without SDPA (added in torch 2.0).
        sim = torch.matmul(q, k.transpose(-2, -1)) * (dim_head ** -0.5)
        out = torch.matmul(sim.softmax(dim=-1), v)
    return out.transpose(1, 2).reshape(b, -1, n_heads * dim_head)


# TODO: still have to find the best way to add noise to the uncond image
def image_add_noise(image, noise):
    """Build a distorted copy of *image* to encode as the negative (uncond) image.

    The input is treated as channels-last (it is permuted to channels-first for the
    torchvision transforms and back afterwards). The image is processed in order:
    - center-cropped to a square and resized to 224x224;
    - elastically warped, with strength scaled by *noise*;
    - flipped vertically and horizontally;
    - given additive Gaussian noise, which gets weaker as *noise* grows.

    NOTE: calls ``torch.manual_seed(0)`` so the result is reproducible. This
    reseeds torch's global RNG as a side effect.
    """
    image = image.permute([0, 3, 1, 2])
    torch.manual_seed(0)  # use a fixed random for reproducible results
    transforms = TT.Compose([
        TT.CenterCrop(min(image.shape[2], image.shape[3])),
        TT.Resize((224, 224), interpolation=TT.InterpolationMode.BICUBIC, antialias=True),
        TT.ElasticTransform(alpha=75.0, sigma=noise * 3.5),  # shuffle the image
        TT.RandomVerticalFlip(p=1.0),  # flip the image to change the geometry even more
        TT.RandomHorizontalFlip(p=1.0),
    ])
    image = transforms(image.cpu())
    image = image.permute([0, 2, 3, 1])
    image = image + ((0.25 * (1 - noise) + 0.05) * torch.randn_like(image))  # add random noise
    return image


def zeroed_hidden_states(clip_vision):
    """Return the penultimate hidden states (on CPU) of *clip_vision* for an all-zero image.

    Used as the uncond embedding for "plus" models when no noise image is requested.

    Raises:
        KeyError: if the vision model output has no ``hidden_states``.
    """
    image = torch.zeros([1, 3, 224, 224])
    # The processor is only used to obtain the expected pixel_values shape.
    inputs = clip_vision.processor(images=image, return_tensors="pt")
    comfy.model_management.load_model_gpu(clip_vision.patcher)
    pixel_values = torch.zeros_like(inputs['pixel_values']).to(clip_vision.load_device)

    if clip_vision.dtype != torch.float32:
        precision_scope = torch.autocast
    else:
        precision_scope = lambda a, b: contextlib.nullcontext(a)

    with precision_scope(comfy.model_management.get_autocast_device(clip_vision.load_device), torch.float32):
        outputs = clip_vision.model(pixel_values, output_hidden_states=True)

    # Look the key up directly instead of inserting into `outputs` while iterating it.
    hidden_states = outputs["hidden_states"] if "hidden_states" in outputs else None
    if hidden_states is None:
        raise KeyError("CLIP vision model did not return hidden_states")
    # we only need the penultimate hidden states
    return hidden_states[-2].cpu()


class IPAdapter(nn.Module):
    """IP-Adapter weights: an image projection model plus per-layer IP key/value projections.

    *ipadapter_model* is the loaded checkpoint dict with ``"image_proj"`` and
    ``"ip_adapter"`` sub-dicts. The cross-attention width is read from the
    ``"1.to_k_ip.weight"`` tensor rather than taken from the ``cross_attention_dim``
    argument, which is kept only for interface compatibility.
    """

    def __init__(self, ipadapter_model, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.clip_embeddings_dim = clip_embeddings_dim
        self.cross_attention_dim = ipadapter_model["ip_adapter"]["1.to_k_ip.weight"].shape[1]
        self.clip_extra_context_tokens = clip_extra_context_tokens

        self.image_proj_model = self.init_proj()
        self.image_proj_model.load_state_dict(ipadapter_model["image_proj"])
        # Use the width derived from the weights so To_KV always matches the checkpoint.
        self.ip_layers = To_KV(self.cross_attention_dim)
        self.ip_layers.load_state_dict(ipadapter_model["ip_adapter"])

    def init_proj(self):
        """Create the (untrained) image projection model; subclasses may override."""
        return ImageProjModel(
            cross_attention_dim=self.cross_attention_dim,
            clip_embeddings_dim=self.clip_embeddings_dim,
            clip_extra_context_tokens=self.clip_extra_context_tokens,
        )

    @torch.inference_mode()
    def get_image_embeds(self, clip_embed, clip_embed_zeroed):
        """Project the cond and uncond CLIP embeddings; returns ``(cond, uncond)``."""
        image_prompt_embeds = self.image_proj_model(clip_embed)
        uncond_image_prompt_embeds = self.image_proj_model(clip_embed_zeroed)
        return image_prompt_embeds, uncond_image_prompt_embeds


class IPAdapterPlus(IPAdapter):
    """IP-Adapter variant whose image projection is a ``Resampler``.

    It consumes CLIP hidden states rather than pooled embeddings.
    """

    def init_proj(self):
        # NOTE(review): heads is fixed at 12. Checkpoints trained with a different
        # Resampler head count would not load here; confirm which models are supported.
        return Resampler(
            dim=self.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.clip_extra_context_tokens,
            embedding_dim=self.clip_embeddings_dim,
            output_dim=self.cross_attention_dim,
            ff_mult=4,
        )


class CrossAttentionPatch:
    """Replacement for a cross-attention ("attn2") call that adds IP-Adapter image attention.

    One instance is registered per attention block. Every registered condition
    contributes ``weight * attention(q, ip_k, ip_v)`` on top of the normal text
    attention.
    """

    # forward for patching
    def __init__(self, weight, ipadapter, dtype, number, cond, uncond, mask=None):
        self.weights = [weight]
        self.ipadapters = [ipadapter]
        self.conds = [cond]
        self.unconds = [uncond]
        self.dtype = dtype
        # Attention-block index; selects ip_layers.to_kvs[2*number] (key) / [2*number + 1] (value).
        self.number = number
        # NOTE(review): masks are stored but not applied in __call__.
        self.masks = [mask]

    def set_new_condition(self, weight, ipadapter, cond, uncond, dtype, number, mask=None):
        """Add another IP-Adapter condition to this block; ``number`` is not updated."""
        self.weights.append(weight)
        self.ipadapters.append(ipadapter)
        self.conds.append(cond)
        self.unconds.append(uncond)
        self.masks.append(mask)
        self.dtype = dtype

    def __call__(self, n, context_attn2, value_attn2, extra_options):
        org_dtype = n.dtype
        with torch.autocast("cuda", dtype=self.dtype):
            q = n
            k = context_attn2
            v = value_attn2
            b, _, _ = q.shape

            out = attention(q, k, v, extra_options)

            for weight, cond, uncond, ipadapter in zip(self.weights, self.conds, self.unconds, self.ipadapters):
                # Assumes the batch is laid out as [uncond half, cond half] — TODO confirm for odd batch sizes.
                uncond_cond = torch.cat([uncond.repeat(b // 2, 1, 1), cond.repeat(b // 2, 1, 1)], dim=0)

                # k, v for ip_adapter
                ip_k = ipadapter.ip_layers.to_kvs[self.number * 2](uncond_cond)
                ip_v = ipadapter.ip_layers.to_kvs[self.number * 2 + 1](uncond_cond)

                ip_out = attention(q, ip_k, ip_v, extra_options)

                out = out + ip_out * weight

        return out.to(dtype=org_dtype)


class IPAdapterModelLoader:
    """ComfyUI node: load an IPAdapter checkpoint from the "ipadapter" models folder."""

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"ipadapter_file": (folder_paths.get_filename_list("ipadapter"),)}}

    RETURN_TYPES = ("IPADAPTER",)
    FUNCTION = "load_ipadapter_model"

    CATEGORY = "Vyro/IPAdapter"

    def load_ipadapter_model(self, ipadapter_file):
        """Load the checkpoint dict; raises ``Exception`` if it has no "ip_adapter" entry."""
        ckpt_path = folder_paths.get_full_path("ipadapter", ipadapter_file)

        model = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
        if "ip_adapter" not in model:
            raise Exception("invalid IPAdapter model {}".format(ckpt_path))

        return (model,)


class IPAdapterApply:
    """ComfyUI node: patch a model's cross-attention with IP-Adapter image conditions.

    Each image in the input batch becomes its own condition, weighted by
    ``weight * weights_per_image[i]``.
    """

    def __init__(self) -> None:
        self.ipadapter = None
        # Source checkpoint dict and dtype the cached self.ipadapter was built from.
        self._ipadapter_source = None
        self._ipadapter_dtype = None

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "ipadapter": ("IPADAPTER", ),
                "clip_vision": ("CLIP_VISION",),
                "image": ("IMAGE",),
                "model": ("MODEL", ),
                "weight": ("FLOAT", {"default": 1.0, "min": -1, "max": 3, "step": 0.05}),
                "noise": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "weights_per_image": ("STRING", {"default": "1.0"}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_ipadapter"
    CATEGORY = "Vyro/IPAdapter"
    # Not referenced within this file.
    uncond_hidden_states = None

    def apply_ipadapter(self, ipadapter, clip_vision, image, model, weight, noise, weights_per_image):
        """Return ``(patched_clone,)``, or ``(model,)`` unchanged if anything fails.

        Failures are reported on stdout together with the traceback.
        """
        try:
            self.dtype = model.model.diffusion_model.dtype
            self.device = comfy.model_management.get_torch_device()
            self.weight = weight
            # Only the Resampler ("plus") projection has "latents" in its state dict.
            self.is_plus = "latents" in ipadapter["image_proj"]

            cross_attention_dim = ipadapter["ip_adapter"]["1.to_k_ip.weight"].shape[1]
            self.is_sdxl = cross_attention_dim == 2048

            work_model = model.clone()

            # split image into list of tensors, one condition per image
            images = torch.split(image, 1, dim=0)
            image_weights = self._parse_weights(weights_per_image, len(images))

            for single_image, image_weight in zip(images, image_weights):
                clip_embed, clip_embed_zeroed, clip_extra_context_tokens = self._encode_image(
                    clip_vision, single_image, noise)

                self._ensure_ipadapter(ipadapter, cross_attention_dim, clip_embed.shape[-1],
                                       clip_extra_context_tokens)

                image_prompt_embeds, uncond_image_prompt_embeds = self.ipadapter.get_image_embeds(
                    clip_embed.to(self.device, self.dtype),
                    clip_embed_zeroed.to(self.device, self.dtype),
                )

                patch_kwargs = {
                    "number": 0,
                    "weight": image_weight * self.weight,
                    "ipadapter": self.ipadapter,
                    "dtype": self.dtype,
                    "cond": image_prompt_embeds.to(self.device, dtype=self.dtype),
                    "uncond": uncond_image_prompt_embeds.to(self.device, dtype=self.dtype),
                }
                self._patch_model(work_model, patch_kwargs)

            return (work_model, )
        except Exception as e:
            # Best effort: report the failure and fall back to the unpatched model.
            print(f'[IPAdapterApply] {e}')
            traceback.print_exc()
            return (model, )

    @staticmethod
    def _parse_weights(weights_per_image, count):
        """Parse comma-separated per-image weights and pad the list to *count* entries.

        Blank entries are ignored, and an empty list becomes ``[1.0]``. If fewer
        weights than images are given, the last weight is repeated. This means the
        default ``"1.0"`` applies to every image in a batch.
        """
        weights = [float(w) for w in weights_per_image.split(",") if w.strip()]
        if not weights:
            weights = [1.0]
        if len(weights) < count:
            weights.extend([weights[-1]] * (count - len(weights)))
        return weights

    def _encode_image(self, clip_vision, image, noise):
        """Return ``(cond_embed, uncond_embed, clip_extra_context_tokens)`` for one image.

        "Plus" models use the last hidden state with 16 tokens; the others use the
        pooled image embedding with 4 tokens. The uncond embedding is built as follows:
        - if *noise* > 0, encode a noised copy of the image;
        - otherwise, for "plus" models, use ``zeroed_hidden_states``;
        - otherwise, use zeros.
        """
        encoded = clip_vision.encode_image(image)
        neg_encoded = clip_vision.encode_image(image_add_noise(image, noise)) if noise > 0 else None

        if self.is_plus:
            clip_embed = encoded.last_hidden_state
            if neg_encoded is not None:
                clip_embed_zeroed = neg_encoded.last_hidden_state
            else:
                clip_embed_zeroed = zeroed_hidden_states(clip_vision)
            return clip_embed, clip_embed_zeroed, 16

        clip_embed = encoded.image_embeds
        if neg_encoded is not None:
            clip_embed_zeroed = neg_encoded.image_embeds
        else:
            clip_embed_zeroed = torch.zeros_like(clip_embed)
        return clip_embed, clip_embed_zeroed, 4

    def _ensure_ipadapter(self, ipadapter, cross_attention_dim, clip_embeddings_dim, clip_extra_context_tokens):
        """(Re)build ``self.ipadapter`` unless it was already built from *ipadapter* in ``self.dtype``."""
        if (self.ipadapter is not None
                and self._ipadapter_source is ipadapter
                and self._ipadapter_dtype == self.dtype):
            return
        ipa_cls = IPAdapterPlus if self.is_plus else IPAdapter
        self.ipadapter = ipa_cls(
            ipadapter,
            cross_attention_dim=cross_attention_dim,
            clip_embeddings_dim=clip_embeddings_dim,
            clip_extra_context_tokens=clip_extra_context_tokens,
        )
        self.ipadapter.to(self.device, dtype=self.dtype)
        self._ipadapter_source = ipadapter
        self._ipadapter_dtype = self.dtype

    def _patch_model(self, work_model, patch_kwargs):
        """Register *patch_kwargs* on every cross-attention block of *work_model*.

        ``patch_kwargs["number"]`` is incremented per block, in the same order as
        the channel lists used by ``To_KV``. The order is input blocks, then output
        blocks, then the middle block.
        """
        if not self.is_sdxl:
            for block_id in [1, 2, 4, 5, 7, 8]:  # id of input_blocks that have cross attention
                set_model_patch_replace(work_model, patch_kwargs, ("input", block_id))
                patch_kwargs["number"] += 1
            for block_id in [3, 4, 5, 6, 7, 8, 9, 10, 11]:  # id of output_blocks that have cross attention
                set_model_patch_replace(work_model, patch_kwargs, ("output", block_id))
                patch_kwargs["number"] += 1
            set_model_patch_replace(work_model, patch_kwargs, ("middle", 0))
            return

        for block_id in [4, 5, 7, 8]:  # id of input_blocks that have cross attention
            depth = 2 if block_id in (4, 5) else 10  # transformer_depth
            for index in range(depth):
                set_model_patch_replace(work_model, patch_kwargs, ("input", block_id, index))
                patch_kwargs["number"] += 1
        for block_id in range(6):  # id of output_blocks that have cross attention
            depth = 2 if block_id in (3, 4, 5) else 10  # transformer_depth
            for index in range(depth):
                set_model_patch_replace(work_model, patch_kwargs, ("output", block_id, index))
                patch_kwargs["number"] += 1
        for index in range(10):
            # Was misspelled "midlle", so the SDXL middle block was never patched.
            set_model_patch_replace(work_model, patch_kwargs, ("middle", 0, index))
            patch_kwargs["number"] += 1


NODE_CLASS_MAPPINGS = {
    "IPAdapterModelLoader": IPAdapterModelLoader,
    "IPAdapterApply": IPAdapterApply,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "IPAdapterModelLoader": "Load IPAdapter Model",
    "IPAdapterApply": "Apply IPAdapter",
}