import torch import hashlib import os import logging import numpy as np import comfy.clip_vision import comfy.clip_model import comfy.model_management import comfy.utils import comfy.sd import folder_paths import torchvision.transforms.v2 as T from comfy.sd import CLIP from typing import Union from collections import Counter from torch import Tensor from transformers import CLIPImageProcessor from transformers.image_utils import PILImageResampling from .insightface_package import analyze_faces, insightface_loader from .model import PhotoMakerIDEncoder from .model_v2 import PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken from .utils import LoadImageCustom, load_image, prepImage, crop_image_pil, tokenize_with_trigger_word from .style_template import styles class PhotoMakerLoaderPlus: def __init__(self): self.loaded_lora = None self.loaded_clipvision = None @classmethod def INPUT_TYPES(s): return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), ), }, } RETURN_TYPES = ("PHOTOMAKER", ) FUNCTION = "load_photomaker_model" CATEGORY = "PhotoMaker" def load_photomaker_model(self, photomaker_model_name): self.load_data(None, None, photomaker_model_name, 0, 0)[0] if 'qformer_perceiver.token_norm.weight' in self.loaded_clipvision[1].keys(): photomaker_model = PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken() else: photomaker_model = PhotoMakerIDEncoder() photomaker_model.load_state_dict(self.loaded_clipvision[1]) photomaker_model.loader = self photomaker_model.filename = photomaker_model_name return (photomaker_model,) def load_data(self, model, clip, name, strength_model, strength_clip): model_lora, clip_lora = model, clip path = folder_paths.get_full_path("photomaker", name) lora = None if self.loaded_lora is not None: if self.loaded_lora[0] == path: lora = self.loaded_lora[1] else: temp = self.loaded_lora self.loaded_lora = None del temp temp = self.loaded_clipvision self.loaded_clipvision = None del temp if lora is None: data = comfy.utils.load_torch_file(path, safe_load=True) clipvision = data.get("id_encoder", None) lora = data.get("lora_weights", None) self.loaded_lora = (path, lora) self.loaded_clipvision = (path, clipvision) if model is not None and (strength_model > 0 or strength_clip > 0): model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) return (model_lora, clip_lora) class PhotoMakerLoraLoaderPlus: def __init__(self): self.loaded_lora = None self.loaded_clipvision = None @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "photomaker": ("PHOTOMAKER",), "lora_strength": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}), }, } RETURN_TYPES = ("MODEL", ) FUNCTION = "load_photomaker_lora" CATEGORY = "PhotoMaker" def load_photomaker_lora(self, model, photomaker, lora_strength): return (photomaker.loader.load_data(model, None, photomaker.filename, lora_strength, 0)[0],) class PhotoMakerInsightFaceLoader: @classmethod def INPUT_TYPES(s): return { "required": { "provider": (["CPU", "CUDA", "ROCM"], ), }, } RETURN_TYPES = ("INSIGHTFACE",) FUNCTION = "load_insightface" CATEGORY = "PhotoMaker" def load_insightface(self, provider): return (insightface_loader(provider),) class PhotoMakerEncodePlus: @classmethod def INPUT_TYPES(s): return {"required": { "clip": ("CLIP",), "photomaker": ("PHOTOMAKER",), "image": ("IMAGE",), "trigger_word": ("STRING", {"default": "img"}), "text": ("STRING", {"multiline": True, "default": "photograph of a man img", "dynamicPrompts": True}), }, "optional": { "insightface_opt": ("INSIGHTFACE",), }, } RETURN_TYPES = ("CONDITIONING",) FUNCTION = "apply_photomaker" CATEGORY = "PhotoMaker" @torch.no_grad() def apply_photomaker(self, clip: CLIP, photomaker: Union[PhotoMakerIDEncoder, PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken], image: Tensor, trigger_word: str, text: str, insightface_opt=None): if (num_images := len(image)) == 0: raise ValueError("No image provided or found.") trigger_word=trigger_word.strip() tokens = clip.tokenize(text) class_tokens_mask = {} out_tokens = {} num_tokens = getattr(photomaker, 'num_tokens', 1) for key, val in tokens.items(): clip_tokenizer = getattr(clip.tokenizer, f'clip_{key}', clip.tokenizer) img_token = clip_tokenizer.tokenizer(trigger_word, truncation=False, add_special_tokens=False)["input_ids"][0] # only get the first token _tokens = torch.tensor([[tpy[0] for tpy in tpy_] for tpy_ in val ] , dtype=torch.int32) _weights = torch.tensor([[tpy[1] for tpy in tpy_] for tpy_ in val] , dtype=torch.float32) start_token = clip_tokenizer.start_token end_token = clip_tokenizer.end_token pad_token = clip_tokenizer.pad_token tokens_mask = tokenize_with_trigger_word(_tokens, _weights, num_images, num_tokens, img_token,start_token, end_token, pad_token, return_mask=True)[0] tokens_new, weights_new, num_trigger_tokens_processed = tokenize_with_trigger_word(_tokens, _weights, num_images, num_tokens, img_token,start_token, end_token, pad_token) token_weight_pairs = [[(tt,ww) for tt,ww in zip(x.tolist(), y.tolist())] for x,y in zip(tokens_new, weights_new)] mask = (tokens_mask == -1).tolist() class_tokens_mask[key] = mask out_tokens[key] = token_weight_pairs cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True) if num_trigger_tokens_processed == 0 or not trigger_word: logging.warning("\033[33mWarning:\033[0m No trigger token found.") return ([[cond, {"pooled_output": pooled}]],) prompt_embeds = cond device_orig = prompt_embeds.device first_key = next(iter(tokens.keys())) class_tokens_mask = class_tokens_mask[first_key] if num_trigger_tokens_processed > 1: image = image.repeat([num_trigger_tokens_processed] + [1] * (len(image.shape) - 1)) photomaker = photomaker.to(device=photomaker.load_device) image.clamp_(0.0, 1.0) input_id_images = image _, h, w, _ = image.shape do_resize = (h, w) != (224, 224) image_bak = image try: if do_resize: clip_preprocess = CLIPImageProcessor(resample=PILImageResampling.LANCZOS, do_normalize=False, do_rescale=False, do_convert_rgb=False) image = clip_preprocess(image, return_tensors="pt").pixel_values.movedim(1,-1) except RuntimeError as e: image = image_bak pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float() if photomaker.__class__.__name__ == 'PhotoMakerIDEncoder': cond = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device), class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0)) else: if insightface_opt is None: raise ValueError(f"InsightFace is required for PhotoMaker V2") face_detector = insightface_opt if not hasattr(face_detector, 'get_'): face_detector.get_ = face_detector.get def get(self, img, max_num=0, det_size=(640, 640)): if det_size is not None: self.det_model.input_size = det_size return self.get_(img, max_num) face_detector.get = get.__get__(face_detector, face_detector.__class__) id_embed_list = [] ToPILImage = T.ToPILImage() def tensor_to_pil_np(_img): nonlocal ToPILImage img_pil = ToPILImage(_img.movedim(-1,0)) if img_pil.mode != 'RGB': img_pil = img_pil.convert('RGB') return np.asarray(img_pil) for img in input_id_images: faces = analyze_faces(face_detector, tensor_to_pil_np(img)) if len(faces) > 0: id_embed_list.append(torch.from_numpy((faces[0]['embedding']))) if len(id_embed_list) == 0: raise ValueError(f"No face detected in input image pool") id_embeds = torch.stack(id_embed_list).to(device=photomaker.load_device) class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0) cond = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device), class_tokens_mask=class_tokens_mask, id_embeds=id_embeds) cond = cond.to(device=device_orig) return ([[cond, {"pooled_output": pooled}]],) class PhotoMakerStyles: @classmethod def INPUT_TYPES(s): return {"required": { "style_name": (list(styles.keys()), {"default": "Photographic (Default)"}), }, "optional": { "positive": ("STRING", {"multiline": True, "forceInput": True, "dynamicPrompts": True}), "negative": ("STRING", {"multiline": True, "forceInput": True, "dynamicPrompts": True}), }, } RETURN_TYPES = ("STRING","STRING",) RETURN_NAMES = ("POSITIVE","NEGATIVE",) FUNCTION = "apply_photomaker_style" CATEGORY = "PhotoMaker" def apply_photomaker_style(self, style_name, positive: str = '', negative: str = ''): p, n = styles.get(style_name, "Photographic (Default)") return p.replace("{prompt}", positive), n + ' ' + negative class PrepImagesForClipVisionFromPath: def __init__(self) -> None: self.image_loader = LoadImageCustom() self.load_device = comfy.model_management.text_encoder_device() self.offload_device = comfy.model_management.text_encoder_offload_device() @classmethod def INPUT_TYPES(s): return {"required": { "path": ("STRING", {"multiline": False}), "interpolation": (["nearest", "bilinear", "box", "bicubic", "lanczos", "hamming"], {"default": "lanczos"}), "crop_position": (["top", "bottom", "left", "right", "center", "pad"], {"default": "center"}), }, } @classmethod def IS_CHANGED(s, path:str, interpolation, crop_position): image_path_list = s.get_images_paths(path) hashes = [] for image_path in image_path_list: if not (path.startswith("http://") or path.startswith("https://")): m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) hashes.append(m.digest().hex()) return Counter(hashes) @classmethod def VALIDATE_INPUTS(s, path:str, interpolation, crop_position): image_path_list = s.get_images_paths(path) if len(image_path_list) == 0: return "No image provided or found." return True RETURN_TYPES = ("IMAGE",) FUNCTION = "prep_images_for_clip_vision_from_path" CATEGORY = "image" @classmethod def get_images_paths(self, path:str): image_path_list = [] path = path.strip() if path: image_path_list = [path] if not (path.startswith("http://") or path.startswith("https://")) and os.path.isdir(path): image_basename_list = os.listdir(path) image_path_list = [ os.path.join(path, basename) for basename in image_basename_list if not basename.startswith('.') and basename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp', '.gif')) ] return image_path_list def prep_images_for_clip_vision_from_path(self, path:str, interpolation:str, crop_position,): image_path_list = self.get_images_paths(path) if len(image_path_list) == 0: raise ValueError("No image provided or found.") interpolation=interpolation.upper() size = (224, 224) try: input_id_images = [img if (img:=load_image(image_path)).size == size else crop_image_pil(img, crop_position) for image_path in image_path_list] do_resize = not all(img.size == size for img in input_id_images) resample = getattr(PILImageResampling, interpolation) clip_preprocess = CLIPImageProcessor(resample=resample, do_normalize=False, do_resize=do_resize) id_pixel_values = clip_preprocess(input_id_images, return_tensors="pt").pixel_values.movedim(1,-1) except TypeError as err: logging.warning('[PhotoMaker]:', err) logging.warning('[PhotoMaker]: You may need to update transformers.') input_id_images = [self.image_loader.load_image(image_path)[0] for image_path in image_path_list] do_resize = not all(img.shape[-3:-3+2] == size for img in input_id_images) if do_resize: id_pixel_values = torch.cat([prepImage(img, interpolation=interpolation, crop_position=crop_position) for img in input_id_images]) else: id_pixel_values = torch.cat(input_id_images) return (id_pixel_values,) NODE_CLASS_MAPPINGS = { "PhotoMakerLoaderPlus": PhotoMakerLoaderPlus, "PhotoMakerEncodePlus": PhotoMakerEncodePlus, "PhotoMakerStyles": PhotoMakerStyles, "PhotoMakerLoraLoaderPlus": PhotoMakerLoraLoaderPlus, "PrepImagesForClipVisionFromPath": PrepImagesForClipVisionFromPath, "PhotoMakerInsightFaceLoader": PhotoMakerInsightFaceLoader, } NODE_DISPLAY_NAME_MAPPINGS = { "PhotoMakerLoaderPlus": "PhotoMaker Loader Plus", "PhotoMakerEncodePlus": "PhotoMaker Encode Plus", "PhotoMakerStyles": "Apply PhotoMaker Style", "PhotoMakerLoraLoaderPlus": "PhotoMaker LoRA Loader Plus", "PrepImagesForClipVisionFromPath": "Prepare Images For CLIP Vision From Path", "PhotoMakerInsightFaceLoader": "PhotoMaker InsightFace Loader", }