import hashlib
from pathlib import Path
from typing import Iterable

import numpy as np
import torch
from PIL import Image, ImageOps

import folder_paths


def _to_rgb_tensor(pil_image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float32 RGB tensor of shape (1, H, W, 3) in [0, 1]."""
    rgb = np.array(pil_image.convert("RGB")).astype(np.float32) / 255.0
    # Leading axis of size 1 so that several results can be concatenated.
    return torch.from_numpy(rgb)[None,]


def _to_alpha_mask(pil_image: Image.Image) -> torch.Tensor:
    """Return the inverted alpha channel of *pil_image* as an (H, W) float32 tensor.

    Images without an 'A' band yield an all-zero 64x64 CPU tensor instead.
    """
    if 'A' in pil_image.getbands():
        alpha = np.array(pil_image.getchannel('A')).astype(np.float32) / 255.0
        # The mask is the complement of alpha: fully opaque pixels become 0.
        return 1. - torch.from_numpy(alpha)
    return torch.zeros((64, 64), dtype=torch.float32, device="cpu")


class LoadImageFromPath:
    """Node that loads an image (and its alpha-derived mask) from an annotated path."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image": ("STRING", {"default": r"ComfyUI_00001_-assets\ComfyUI_00001_.png [output]"})}, }

    CATEGORY = "image"
    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "load_image"

    def load_image(self, image):
        """Load the file referenced by *image* and return ``(image_tensor, mask)``.

        Args:
            image: Annotated file path string, resolved through
                ``folder_paths.get_annotated_filepath``.

        Returns:
            A 2-tuple of the RGB tensor (1, H, W, 3) and the inverted-alpha
            mask (H, W), or a 64x64 zero mask when the image has no alpha band.
        """
        image_path = LoadImageFromPath._resolve_path(image)
        # Do every pixel access inside the ``with`` so the underlying file
        # handle is released deterministically once the tensors are built.
        with Image.open(image_path) as source:
            oriented = ImageOps.exif_transpose(source)
            image_tensor = _to_rgb_tensor(oriented)
            mask = _to_alpha_mask(oriented)
        return (image_tensor, mask)

    @staticmethod
    def _resolve_path(image) -> Path:
        """Resolve an annotated image string to a filesystem ``Path``."""
        return Path(folder_paths.get_annotated_filepath(image))

    @classmethod
    def IS_CHANGED(s, image):
        """Return the SHA-256 hex digest of the file contents referenced by *image*."""
        image_path = LoadImageFromPath._resolve_path(image)
        digest = hashlib.sha256()
        with open(image_path, 'rb') as f:
            # Hash in chunks so large images are not read into memory at once;
            # the resulting digest is identical to hashing the whole file.
            for chunk in iter(lambda: f.read(1 << 20), b""):
                digest.update(chunk)
        return digest.hexdigest()

    @classmethod
    def VALIDATE_INPUTS(s, image):
        """Return True if *image* is acceptable, otherwise an error message string."""
        # If image is an output of another node, it will be None during validation
        if image is None:
            return True

        image_path = LoadImageFromPath._resolve_path(image)
        if not image_path.exists():
            return "Invalid image path: {}".format(image_path)

        return True


class PILToImage:
    """Node that converts PIL images into a batched IMAGE tensor."""

    @classmethod
    def INPUT_TYPES(s):
        return {'required': {'images': ('PIL_IMAGE', )}, }

    RETURN_TYPES = ('IMAGE',)
    FUNCTION = 'pil_images_to_images'

    CATEGORY = 'image/PIL'

    # Declared static: the function takes no ``self``, so without the
    # decorator a call on a node instance with ``images=...`` would fail with
    # "got multiple values for argument 'images'".
    @staticmethod
    def pil_images_to_images(images: Iterable[Image.Image]) -> tuple[torch.Tensor]:
        """Convert *images* to one float32 tensor of shape (N, H, W, 3).

        Args:
            images: PIL images; they must share the same size to be batched.

        Returns:
            A 1-tuple holding the batched tensor.

        Raises:
            IndexError: If *images* is empty.
        """
        tensors = []
        for pil_image in images:
            oriented = ImageOps.exif_transpose(pil_image)
            if oriented.mode == 'I':
                # Scale 32-bit integer pixels down before the RGB conversion.
                oriented = oriented.point(lambda value: value * (1 / 255))
            tensors.append(_to_rgb_tensor(oriented))

        if len(tensors) > 1:
            return (torch.cat(tensors, dim=0),)
        return (tensors[0],)


class PILToMask:
    """Node that extracts inverted-alpha masks from PIL images."""

    @classmethod
    def INPUT_TYPES(s):
        return {'required': {'images': ('PIL_IMAGE', )}, }

    # NOTE(review): this node returns mask tensors but advertises 'IMAGE';
    # 'MASK' looks like the intended type. Left unchanged here because it is
    # the node's graph-facing interface -- confirm before changing.
    RETURN_TYPES = ('IMAGE',)
    FUNCTION = 'pil_images_to_masks'

    CATEGORY = 'image/PIL'

    # Declared static for the same reason as PILToImage.pil_images_to_images:
    # the function takes no ``self``.
    @staticmethod
    def pil_images_to_masks(images: Iterable[Image.Image]) -> tuple[torch.Tensor]:
        """Build masks from the alpha channel of each image in *images*.

        Args:
            images: PIL images; images without an 'A' band contribute a
                64x64 zero mask.

        Returns:
            A 1-tuple holding an (H, W) mask for a single image, or an
            (N, H, W) batch when several images are given.

        Raises:
            IndexError: If *images* is empty.
        """
        masks = [
            _to_alpha_mask(ImageOps.exif_transpose(pil_image))
            for pil_image in images
        ]

        if len(masks) > 1:
            # Each mask is 2-D, so stack along a new batch axis; concatenating
            # on dim 0 would glue the masks together along the height axis.
            return (torch.stack(masks, dim=0),)
        return (masks[0],)


class ImageToPIL:
    """Node that converts an IMAGE tensor batch into a list of PIL images."""

    @classmethod
    def INPUT_TYPES(s):
        return {'required': {'images': ('IMAGE', )}, }

    RETURN_TYPES = ('PIL_IMAGE',)
    FUNCTION = 'images_to_pil_images'

    CATEGORY = 'image/PIL'

    def images_to_pil_images(self, images: torch.Tensor) -> tuple[list[Image.Image]]:
        """Convert each tensor in *images* to an 8-bit PIL image.

        Args:
            images: Batch of image tensors with values scaled to [0, 1];
                values outside that range are clipped.

        Returns:
            A 1-tuple holding the list of PIL images (empty for an empty batch).
        """
        pil_images = [
            Image.fromarray(
                np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)
            )
            for image in images
        ]
        return (pil_images,)