import torch
import hashlib
import os
import logging
import numpy as np
import comfy.clip_vision
import comfy.clip_model
import comfy.model_management
import comfy.utils
import comfy.sd
import folder_paths
import torchvision.transforms.v2 as T
from comfy.sd import CLIP
from typing import Union
from collections import Counter
from torch import Tensor
from transformers import CLIPImageProcessor
from transformers.image_utils import PILImageResampling
from .insightface_package import analyze_faces, insightface_loader
from .model import PhotoMakerIDEncoder
from .model_v2 import PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken
from .utils import LoadImageCustom, load_image, prepImage, crop_image_pil, tokenize_with_trigger_word
from .style_template import styles


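# Loads a PhotoMaker checkpoint from the "photomaker" models folder. The file
# bundles an "id_encoder" state dict and optional "lora_weights"; both are
# cached on this instance so the checkpoint is only read from disk once.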
class PhotoMakerLoaderPlus:
    def __init__(self):
        self.loaded_lora = None
        self.loaded_clipvision = None

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), ),
                },
        }

    RETURN_TYPES = ("PHOTOMAKER",)
    FUNCTION = "load_photomaker_model"

    CATEGORY = "PhotoMaker"

    def load_photomaker_model(self, photomaker_model_name):
        # Called for its side effect: populates self.loaded_clipvision
        # (and self.loaded_lora) from the checkpoint file.
        self.load_data(None, None, photomaker_model_name, 0, 0)
        # PhotoMaker V2 checkpoints are recognized by their QFormer keys.
        if 'qformer_perceiver.token_norm.weight' in self.loaded_clipvision[1]:
            photomaker_model = PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken()
        else:
            photomaker_model = PhotoMakerIDEncoder()
        photomaker_model.load_state_dict(self.loaded_clipvision[1])
        # Stash the loader and filename so PhotoMakerLoraLoaderPlus can reuse
        # the cached checkpoint instead of reloading it.
        photomaker_model.loader = self
        photomaker_model.filename = photomaker_model_name
        return (photomaker_model,)

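    # Loads the checkpoint at `name`, caching its LoRA weights and ID-encoder
    # state dict. If `model` is given and a strength is non-zero, the bundled
    # LoRA is also applied to the model/CLIP pair.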
    def load_data(self, model, clip, name, strength_model, strength_clip):
        model_lora, clip_lora = model, clip

        path = folder_paths.get_full_path("photomaker", name)
        lora = None
        if self.loaded_lora is not None:
            if self.loaded_lora[0] == path:
                lora = self.loaded_lora[1]
            else:
                temp = self.loaded_lora
                self.loaded_lora = None
                del temp
                temp = self.loaded_clipvision
                self.loaded_clipvision = None
                del temp

        if lora is None:
            data = comfy.utils.load_torch_file(path, safe_load=True)
            clipvision = data.get("id_encoder", None)
            lora = data.get("lora_weights", None)
            self.loaded_lora = (path, lora)
            self.loaded_clipvision = (path, clipvision)

        if model is not None and (strength_model > 0 or strength_clip > 0):
            model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
        return (model_lora, clip_lora)


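# Applies the LoRA weights bundled inside a PhotoMaker checkpoint to a model.
# Delegates to the loader cached on the PHOTOMAKER object by
# PhotoMakerLoaderPlus, so the checkpoint is not read from disk again.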
class PhotoMakerLoraLoaderPlus:
    def __init__(self):
        self.loaded_lora = None
        self.loaded_clipvision = None

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "model": ("MODEL",),
                    "photomaker": ("PHOTOMAKER",),
                    "lora_strength": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
                },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_photomaker_lora"

    CATEGORY = "PhotoMaker"

    def load_photomaker_lora(self, model, photomaker, lora_strength):
        return (photomaker.loader.load_data(model, None, photomaker.filename, lora_strength, 0)[0],)


class PhotoMakerInsightFaceLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "provider": (["CPU", "CUDA", "ROCM"], ),
            },
        }

    RETURN_TYPES = ("INSIGHTFACE",)
    FUNCTION = "load_insightface"
    CATEGORY = "PhotoMaker"

    def load_insightface(self, provider):
        return (insightface_loader(provider),)


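# Encodes a prompt containing a trigger word (e.g. "img") together with one or
# more ID images. Occurrences of the trigger word are expanded into class-token
# slots, and the PhotoMaker ID encoder fuses the image identity into the text
# embedding at those positions.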
class PhotoMakerEncodePlus:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "clip": ("CLIP",),
                    "photomaker": ("PHOTOMAKER",),
                    "image": ("IMAGE",),
                    "trigger_word": ("STRING", {"default": "img"}),
                    "text": ("STRING", {"multiline": True, "default": "photograph of a man img", "dynamicPrompts": True}),
                },
                "optional": {
                    "insightface_opt": ("INSIGHTFACE",),
                },
        }

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_photomaker"

    CATEGORY = "PhotoMaker"

    @torch.no_grad()
    def apply_photomaker(self, clip: CLIP, photomaker: Union[PhotoMakerIDEncoder, PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken], image: Tensor, trigger_word: str, text: str, insightface_opt=None):
        if (num_images := len(image)) == 0:
            raise ValueError("No image provided or found.")
        trigger_word = trigger_word.strip()
        tokens = clip.tokenize(text)
        if not trigger_word:
            # An empty trigger word would yield no token ids below, so fall
            # back to a plain text encode.
            logging.warning("\033[33mWarning:\033[0m No trigger word provided.")
            cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
            return ([[cond, {"pooled_output": pooled}]],)
        class_tokens_mask = {}
        out_tokens = {}
        num_tokens = getattr(photomaker, 'num_tokens', 1)
        for key, val in tokens.items():
            clip_tokenizer = getattr(clip.tokenizer, f'clip_{key}', clip.tokenizer)
            # First token id of the trigger word, with no special tokens added.
            img_token = clip_tokenizer.tokenizer(trigger_word, truncation=False, add_special_tokens=False)["input_ids"][0]
            _tokens = torch.tensor([[tpy[0] for tpy in tpy_] for tpy_ in val], dtype=torch.int32)
            _weights = torch.tensor([[tpy[1] for tpy in tpy_] for tpy_ in val], dtype=torch.float32)
            start_token = clip_tokenizer.start_token
            end_token = clip_tokenizer.end_token
            pad_token = clip_tokenizer.pad_token

            # Expand each trigger token into `num_tokens` class-token slots per
            # input image; the mask marks those slots with -1.
            tokens_mask = tokenize_with_trigger_word(_tokens, _weights, num_images, num_tokens, img_token, start_token, end_token, pad_token, return_mask=True)[0]
            tokens_new, weights_new, num_trigger_tokens_processed = tokenize_with_trigger_word(_tokens, _weights, num_images, num_tokens, img_token, start_token, end_token, pad_token)
            token_weight_pairs = [[(tt, ww) for tt, ww in zip(x.tolist(), y.tolist())] for x, y in zip(tokens_new, weights_new)]
            class_tokens_mask[key] = (tokens_mask == -1).tolist()
            out_tokens[key] = token_weight_pairs

        cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True)
        if num_trigger_tokens_processed == 0:
            logging.warning("\033[33mWarning:\033[0m No trigger token found.")
            return ([[cond, {"pooled_output": pooled}]],)

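        # Only the class-token mask from the first tokenizer branch is used;
        # the image batch is then tiled so one copy exists per processed
        # trigger token.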
        prompt_embeds = cond
        device_orig = prompt_embeds.device
        first_key = next(iter(tokens.keys()))
        class_tokens_mask = class_tokens_mask[first_key]
        if num_trigger_tokens_processed > 1:
            image = image.repeat([num_trigger_tokens_processed] + [1] * (len(image.shape) - 1))

        photomaker = photomaker.to(device=photomaker.load_device)

        image.clamp_(0.0, 1.0)
        input_id_images = image
        _, h, w, _ = image.shape
        do_resize = (h, w) != (224, 224)
        image_bak = image
        try:
            if do_resize:
                # Resize with LANCZOS via the HF processor; normalization and
                # rescaling are handled by comfy's clip_preprocess below.
                clip_preprocess = CLIPImageProcessor(resample=PILImageResampling.LANCZOS, do_normalize=False, do_rescale=False, do_convert_rgb=False)
                image = clip_preprocess(image, return_tensors="pt").pixel_values.movedim(1, -1)
        except RuntimeError:
            # Fall back to the original images if the processor rejects the input.
            image = image_bak
        pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()

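        # V1 encoders fuse CLIP-vision features alone; V2 additionally
        # requires an InsightFace identity embedding for each input face.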
        if photomaker.__class__.__name__ == 'PhotoMakerIDEncoder':
            cond = photomaker(id_pixel_values=pixel_values.unsqueeze(0),
                              prompt_embeds=cond.to(photomaker.load_device),
                              class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0))
        else:
            if insightface_opt is None:
                raise ValueError("InsightFace is required for PhotoMaker V2")
            face_detector = insightface_opt
            # Patch `get` once so the detection size can be set per call while
            # keeping the original method available as `get_`.
            if not hasattr(face_detector, 'get_'):
                face_detector.get_ = face_detector.get
                def get(self, img, max_num=0, det_size=(640, 640)):
                    if det_size is not None:
                        self.det_model.input_size = det_size
                    return self.get_(img, max_num)
                face_detector.get = get.__get__(face_detector, face_detector.__class__)

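            # Run face analysis on every input image and collect one identity
            # embedding per image; images without a detectable face are skipped.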
            id_embed_list = []

            to_pil = T.ToPILImage()
            def tensor_to_pil_np(_img):
                # HWC tensor -> PIL (CHW) -> RGB numpy array for InsightFace.
                img_pil = to_pil(_img.movedim(-1, 0))
                if img_pil.mode != 'RGB':
                    img_pil = img_pil.convert('RGB')
                return np.asarray(img_pil)

            for img in input_id_images:
                faces = analyze_faces(face_detector, tensor_to_pil_np(img))
                if len(faces) > 0:
                    id_embed_list.append(torch.from_numpy(faces[0]['embedding']))

            if len(id_embed_list) == 0:
                raise ValueError("No face detected in input image pool")

            id_embeds = torch.stack(id_embed_list).to(device=photomaker.load_device)

            class_tokens_mask = torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0)
            cond = photomaker(id_pixel_values=pixel_values.unsqueeze(0),
                              prompt_embeds=cond.to(photomaker.load_device),
                              class_tokens_mask=class_tokens_mask,
                              id_embeds=id_embeds)
        cond = cond.to(device=device_orig)

        return ([[cond, {"pooled_output": pooled}]],)


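# Wraps a positive/negative prompt pair in one of the predefined PhotoMaker
# style templates; "{prompt}" in the template is replaced by the positive text.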
class PhotoMakerStyles:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "style_name": (list(styles.keys()), {"default": "Photographic (Default)"}),
                },
                "optional": {
                    "positive": ("STRING", {"multiline": True, "forceInput": True, "dynamicPrompts": True}),
                    "negative": ("STRING", {"multiline": True, "forceInput": True, "dynamicPrompts": True}),
                },
        }

    RETURN_TYPES = ("STRING", "STRING",)
    RETURN_NAMES = ("POSITIVE", "NEGATIVE",)
    FUNCTION = "apply_photomaker_style"

    CATEGORY = "PhotoMaker"

    def apply_photomaker_style(self, style_name, positive: str = '', negative: str = ''):
        # Fall back to the default template itself, not its name, so the
        # two-element unpack below cannot fail on an unknown style.
        p, n = styles.get(style_name, styles["Photographic (Default)"])
        return p.replace("{prompt}", positive), n + ' ' + negative


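# Loads every image found at a path (single file or directory) and prepares a
# 224x224 batch suitable for CLIP vision, with a configurable resize filter
# and crop anchor.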
class PrepImagesForClipVisionFromPath:
    def __init__(self) -> None:
        self.image_loader = LoadImageCustom()
        self.load_device = comfy.model_management.text_encoder_device()
        self.offload_device = comfy.model_management.text_encoder_offload_device()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "path": ("STRING", {"multiline": False}),
                    "interpolation": (["nearest", "bilinear", "box", "bicubic", "lanczos", "hamming"], {"default": "lanczos"}),
                    "crop_position": (["top", "bottom", "left", "right", "center", "pad"], {"default": "center"}),
                },
        }

    @classmethod
    def IS_CHANGED(s, path: str, interpolation, crop_position):
        # Hash local files so edits re-trigger execution; remote URLs are not hashed.
        image_path_list = s.get_images_paths(path)
        hashes = []
        for image_path in image_path_list:
            if not (path.startswith("http://") or path.startswith("https://")):
                m = hashlib.sha256()
                with open(image_path, 'rb') as f:
                    m.update(f.read())
                hashes.append(m.hexdigest())
        return Counter(hashes)

    @classmethod
    def VALIDATE_INPUTS(s, path: str, interpolation, crop_position):
        image_path_list = s.get_images_paths(path)
        if len(image_path_list) == 0:
            return "No image provided or found."
        return True

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "prep_images_for_clip_vision_from_path"

    CATEGORY = "image"

    @classmethod
    def get_images_paths(s, path: str):
        image_path_list = []
        path = path.strip()
        if path:
            image_path_list = [path]
            if not (path.startswith("http://") or path.startswith("https://")) and os.path.isdir(path):
                image_basename_list = os.listdir(path)
                image_path_list = [
                    os.path.join(path, basename)
                    for basename in image_basename_list
                    if not basename.startswith('.') and basename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp', '.gif'))
                ]
        return image_path_list

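    # Preferred path: let transformers' CLIPImageProcessor crop and resize PIL
    # images as a batch. Older transformers builds may raise a TypeError here
    # (hence the warning below), in which case tensors are loaded and resized
    # with the local helpers instead.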
    def prep_images_for_clip_vision_from_path(self, path: str, interpolation: str, crop_position):
        image_path_list = self.get_images_paths(path)
        if len(image_path_list) == 0:
            raise ValueError("No image provided or found.")

        interpolation = interpolation.upper()
        size = (224, 224)
        try:
            input_id_images = [img if (img := load_image(image_path)).size == size else crop_image_pil(img, crop_position) for image_path in image_path_list]
            do_resize = not all(img.size == size for img in input_id_images)
            resample = getattr(PILImageResampling, interpolation)
            clip_preprocess = CLIPImageProcessor(resample=resample, do_normalize=False, do_resize=do_resize)
            id_pixel_values = clip_preprocess(input_id_images, return_tensors="pt").pixel_values.movedim(1, -1)
        except TypeError as err:
            logging.warning('[PhotoMaker]: %s', err)
            logging.warning('[PhotoMaker]: You may need to update transformers.')
            input_id_images = [self.image_loader.load_image(image_path)[0] for image_path in image_path_list]
            # Tensors are NHWC; compare their (H, W) against the target size.
            do_resize = not all(img.shape[-3:-1] == size for img in input_id_images)
            if do_resize:
                id_pixel_values = torch.cat([prepImage(img, interpolation=interpolation, crop_position=crop_position) for img in input_id_images])
            else:
                id_pixel_values = torch.cat(input_id_images)
        return (id_pixel_values,)


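# ComfyUI node registration.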
NODE_CLASS_MAPPINGS = {
    "PhotoMakerLoaderPlus": PhotoMakerLoaderPlus,
    "PhotoMakerEncodePlus": PhotoMakerEncodePlus,
    "PhotoMakerStyles": PhotoMakerStyles,
    "PhotoMakerLoraLoaderPlus": PhotoMakerLoraLoaderPlus,
    "PrepImagesForClipVisionFromPath": PrepImagesForClipVisionFromPath,
    "PhotoMakerInsightFaceLoader": PhotoMakerInsightFaceLoader,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "PhotoMakerLoaderPlus": "PhotoMaker Loader Plus",
    "PhotoMakerEncodePlus": "PhotoMaker Encode Plus",
    "PhotoMakerStyles": "Apply PhotoMaker Style",
    "PhotoMakerLoraLoaderPlus": "PhotoMaker LoRA Loader Plus",
    "PrepImagesForClipVisionFromPath": "Prepare Images For CLIP Vision From Path",
    "PhotoMakerInsightFaceLoader": "PhotoMaker InsightFace Loader",
}