|
import torch
|
|
import torch.nn.functional as F
|
|
from torchvision.transforms import functional as TF
|
|
from PIL import Image, ImageDraw, ImageFilter, ImageFont
|
|
import scipy.ndimage
|
|
import numpy as np
|
|
from contextlib import nullcontext
|
|
import os
|
|
|
|
import model_management
|
|
from comfy.utils import ProgressBar
|
|
from nodes import MAX_RESOLUTION
|
|
|
|
import folder_paths
|
|
|
|
from ..utility.utility import tensor2pil, pil2tensor
|
|
|
|
# Directory two levels above this file (presumably the node-pack root);
# used to resolve bundled assets such as the default "audio.wav".
_module_path = os.path.abspath(__file__)
script_directory = os.path.dirname(os.path.dirname(_module_path))
|
|
|
|
class BatchCLIPSeg:
    """ComfyUI node: segment an image batch with CLIPSeg from a text prompt."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {
                        "images": ("IMAGE",),
                        "text": ("STRING", {"multiline": False}),
                        "threshold": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 10.0, "step": 0.001}),
                        "binary_mask": ("BOOLEAN", {"default": True}),
                        "combine_mask": ("BOOLEAN", {"default": False}),
                        "use_cuda": ("BOOLEAN", {"default": True}),
                    },
                "optional":
                    {
                        "blur_sigma": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
                        "opt_model": ("CLIPSEGMODEL", ),
                        "prev_mask": ("MASK", {"default": None}),
                        "image_bg_level": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                        "invert": ("BOOLEAN", {"default": False}),
                    }
                }

    CATEGORY = "KJNodes/masking"
    RETURN_TYPES = ("MASK", "IMAGE", )
    RETURN_NAMES = ("Mask", "Image", )
    FUNCTION = "segment_image"
    DESCRIPTION = """
Segments an image or batch of images using CLIPSeg.
"""

    @staticmethod
    def _normalize_logits(logits, threshold):
        """Turn CLIPSeg logits into a [0, 1] mask with values <= threshold zeroed.

        The sigmoid output is min-max normalised over the whole tensor. The
        range is clamped away from zero, so a constant prediction gives an
        all-zero mask instead of NaNs.
        """
        mask = torch.sigmoid(logits.float())
        low = mask.min()
        value_range = (mask.max() - low).clamp_min(1e-8)
        mask = (mask - low) / value_range
        return torch.where(mask > threshold, mask, torch.zeros_like(mask))

    @staticmethod
    def _merge_prev_mask(mask_tensor, prev_mask):
        """Add prev_mask to mask_tensor (B, H, W) and clamp the result to [0, 1].

        prev_mask is moved to mask_tensor's device/dtype and resized
        (nearest) to mask_tensor's spatial size when needed. The channel
        dimension added for F.interpolate is removed again, so the sum stays
        (B, H, W) instead of broadcasting to (B, B, H, W).
        """
        prev_mask = prev_mask.to(device=mask_tensor.device, dtype=mask_tensor.dtype)
        if prev_mask.dim() == 2:
            prev_mask = prev_mask.unsqueeze(0)
        if prev_mask.shape[-2:] != mask_tensor.shape[-2:]:
            prev_mask = F.interpolate(prev_mask.unsqueeze(1), size=tuple(mask_tensor.shape[-2:]), mode='nearest').squeeze(1)
        return torch.clamp(mask_tensor + prev_mask, min=0.0, max=1.0)

    def _load_default_model(self):
        """Load (once) the default CLIPSeg model and return (model, processor).

        Tries the local fp16 copy under models/clip_seg, downloading it from
        the Hugging Face Hub if it is missing. On any failure it falls back to
        the original CIDAS repo. The path that actually worked is remembered,
        so the processor is always loaded from the same place as the cached
        model.
        """
        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

        if getattr(self, "model", None) is None:
            checkpoint_path = os.path.join(folder_paths.models_dir,'clip_seg', 'clipseg-rd64-refined-fp16')
            try:
                if not os.path.exists(checkpoint_path):
                    from huggingface_hub import snapshot_download
                    snapshot_download(repo_id="Kijai/clipseg-rd64-refined-fp16", local_dir=checkpoint_path, local_dir_use_symlinks=False)
                self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
            except Exception:
                checkpoint_path = "CIDAS/clipseg-rd64-refined"
                self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
            self.checkpoint_path = checkpoint_path
        processor = CLIPSegProcessor.from_pretrained(self.checkpoint_path)
        return self.model, processor

    def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0, opt_model=None, prev_mask=None, invert= False, image_bg_level=0.5):
        """Segment `images` for `text` and return (mask, masked_image).

        images is indexed as (B, H, W, C) (see the shape unpack below). The
        returned mask is (B, H, W) float32 on CPU. The returned image keeps
        masked-in pixels and fills the rest with `image_bg_level`.
        """
        import torchvision.transforms as transforms

        offload_device = model_management.unet_offload_device()
        device = model_management.get_torch_device()
        if not use_cuda:
            device = torch.device("cpu")
        dtype = model_management.unet_dtype()

        # A model passed in via opt_model is used as-is; it no longer replaces
        # this node's cached default model.
        if opt_model is None:
            model, processor = self._load_default_model()
        else:
            model = opt_model['model']
            processor = opt_model['processor']

        model.to(dtype).to(device)

        B, H, W, C = images.shape
        images = images.to(device)

        autocast_condition = (dtype != torch.float32) and not model_management.is_device_mps(device)
        with torch.autocast(model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
            PIL_images = [Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) for image in images]
            prompt = [text] * len(images)
            input_prc = processor(text=prompt, images=PIL_images, return_tensors="pt")
            for key in input_prc:
                input_prc[key] = input_prc[key].to(device)
            outputs = model(**input_prc)

            mask_tensor = self._normalize_logits(outputs.logits, threshold)
            # With a single image the logits come back without a batch dimension.
            if mask_tensor.dim() == 2:
                mask_tensor = mask_tensor.unsqueeze(0)
            mask_tensor = F.interpolate(mask_tensor.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)

        model.to(offload_device)

        if binary_mask:
            mask_tensor = (mask_tensor > 0).float()
        if blur_sigma > 0:
            # Odd kernel covering +/-3 sigma. This equals the previous
            # 6 * sigma + 1 for integer sigmas, but no longer truncates
            # fractional ones (e.g. 0.5 used to give a 1x1 kernel).
            kernel_size = 2 * int(np.ceil(3 * blur_sigma)) + 1
            blur = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(blur_sigma, blur_sigma))
            mask_tensor = blur(mask_tensor)

        if combine_mask:
            mask_tensor = mask_tensor.amax(dim=0)
            mask_tensor = mask_tensor.unsqueeze(0).repeat(len(images), 1, 1)

        del outputs
        model_management.soft_empty_cache()

        if prev_mask is not None:
            mask_tensor = self._merge_prev_mask(mask_tensor, prev_mask)

        if invert:
            mask_tensor = 1 - mask_tensor

        image_tensor = images * mask_tensor.unsqueeze(-1) + (1 - mask_tensor.unsqueeze(-1)) * image_bg_level
        image_tensor = torch.clamp(image_tensor, min=0.0, max=1.0).cpu().float()

        mask_tensor = mask_tensor.cpu().float()

        return mask_tensor, image_tensor,
|
|
|
|
class DownloadAndLoadCLIPSeg:
    """ComfyUI node: download (if needed) and load a CLIPSeg model + processor."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {
                        "model": (
                            [ 'Kijai/clipseg-rd64-refined-fp16',
                              'CIDAS/clipseg-rd64-refined',
                            ],
                        ),
                    },
                }

    CATEGORY = "KJNodes/masking"
    RETURN_TYPES = ("CLIPSEGMODEL",)
    RETURN_NAMES = ("clipseg_model",)
    FUNCTION = "segment_image"
    DESCRIPTION = """
Downloads and loads CLIPSeg model with huggingface_hub,
to ComfyUI/models/clip_seg
"""

    def segment_image(self, model):
        """Return ({'model': ..., 'processor': ...},) for the selected repo.

        The repo is snapshot-downloaded into models/clip_seg/<repo name> when
        that folder is missing. The loaded model is cached on the node
        instance together with its path. The cache is reused only while the
        same repo stays selected; previously the first model loaded was
        returned forever, even after the dropdown was switched.
        """
        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

        checkpoint_path = os.path.join(folder_paths.models_dir,'clip_seg', os.path.basename(model))

        if getattr(self, "model", None) is None or getattr(self, "model_path", None) != checkpoint_path:
            if not os.path.exists(checkpoint_path):
                from huggingface_hub import snapshot_download
                snapshot_download(repo_id=model, local_dir=checkpoint_path, local_dir_use_symlinks=False)
            self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
            self.model_path = checkpoint_path

        processor = CLIPSegProcessor.from_pretrained(checkpoint_path)

        clipseg_model = {
            'model': self.model,
            'processor': processor,
        }
        return clipseg_model,
|
|
|
|
class CreateTextMask:
    """ComfyUI node: render wrapped text into an image batch and matching mask."""

    RETURN_TYPES = ("IMAGE", "MASK",)
    FUNCTION = "createtextmask"
    CATEGORY = "KJNodes/text"
    DESCRIPTION = """
Creates a text image and mask.
Looks for fonts from this folder:
ComfyUI/custom_nodes/ComfyUI-KJNodes/fonts

If start_rotation and/or end_rotation are different values,
creates animation between them.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "invert": ("BOOLEAN", {"default": False}),
                "frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
                "text_x": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
                "text_y": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
                "font_size": ("INT", {"default": 32,"min": 8, "max": 4096, "step": 1}),
                "font_color": ("STRING", {"default": "white"}),
                "text": ("STRING", {"default": "HELLO!", "multiline": True}),
                "font": (folder_paths.get_filename_list("kjnodes_fonts"), ),
                "width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                "height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                "start_rotation": ("INT", {"default": 0,"min": 0, "max": 359, "step": 1}),
                "end_rotation": ("INT", {"default": 0,"min": -359, "max": 359, "step": 1}),
            },
        }

    @staticmethod
    def _wrap_text(text, font, max_width):
        """Greedy word-wrap `text` into lines no wider than `max_width` pixels.

        Widths come from `font.getbbox(s)[2]`, or `font.getsize(s)[0]` for
        fonts without `getbbox` (older Pillow). A word wider than max_width
        gets a line of its own; no empty lines are produced.
        """
        if hasattr(font, "getbbox"):
            def measure(s):
                return font.getbbox(s)[2]
        else:
            def measure(s):
                return font.getsize(s)[0]

        space_width = measure(" ")
        lines = []
        current_line = []
        current_line_width = 0
        for word in text.split():
            word_width = measure(word)
            if current_line_width + word_width <= max_width:
                current_line.append(word)
                current_line_width += word_width + space_width
            else:
                if current_line:
                    lines.append(" ".join(current_line))
                current_line = [word]
                # Count the trailing space here too, as in the branch above;
                # otherwise the next word on a wrapped line can overflow.
                current_line_width = word_width + space_width
        if current_line:
            lines.append(" ".join(current_line))
        return lines

    def createtextmask(self, frames, width, height, invert, text_x, text_y, text, font_size, font_color, font, start_rotation, end_rotation):
        """Render `frames` images of wrapped text and return (images, masks).

        The mask is the red channel of each rendered frame. When the start
        and end rotations differ, frames are rotated progressively from
        start_rotation to end_rotation around the centre of the last text line.
        """
        batch_size = frames
        out = []
        masks = []
        animate_rotation = start_rotation != end_rotation
        rotation = start_rotation
        # With a single frame there is nothing to interpolate; avoid dividing by zero.
        rotation_increment = (end_rotation - start_rotation) / (batch_size - 1) if batch_size > 1 else 0

        font_path = folder_paths.get_full_path("kjnodes_fonts", font)
        # The font and the line layout are the same for every frame: compute them once.
        pil_font = ImageFont.truetype(font_path, font_size)
        lines = self._wrap_text(text, pil_font, width - 2 * text_x)

        for i in range(batch_size):
            image = Image.new("RGB", (width, height), "black")
            draw = ImageDraw.Draw(image)

            # Rotation pivot fallback for empty text (previously a NameError).
            text_center_x, text_center_y = text_x, text_y
            y_offset = text_y
            for line in lines:
                text_width = pil_font.getlength(line)
                text_height = font_size
                text_center_x = text_x + text_width / 2
                text_center_y = y_offset + text_height / 2
                try:
                    # Disabling ligatures needs libraqm; fall back to plain drawing without it.
                    draw.text((text_x, y_offset), line, font=pil_font, fill=font_color, features=['-liga'])
                except Exception:
                    draw.text((text_x, y_offset), line, font=pil_font, fill=font_color)
                y_offset += text_height

            if animate_rotation:
                image = image.rotate(rotation, center=(text_center_x, text_center_y))
                rotation += rotation_increment

            image = np.array(image).astype(np.float32) / 255.0
            image = torch.from_numpy(image)[None,]
            mask = image[:, :, :, 0]
            masks.append(mask)
            out.append(image)

        images_out = torch.cat(out, dim=0)
        masks_out = torch.cat(masks, dim=0)
        if invert:
            return (1.0 - images_out, 1.0 - masks_out,)
        return (images_out, masks_out,)
|
|
|
|
class ColorToMask:
    """ComfyUI node: mask pixels whose RGB colour is close to a target colour."""

    RETURN_TYPES = ("MASK",)
    FUNCTION = "clip"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Converts chosen RGB value to a mask.
With batch inputs, the **per_batch**
controls the number of images processed at once.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "invert": ("BOOLEAN", {"default": False}),
                "red": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
                "green": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
                "blue": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
                "threshold": ("INT", {"default": 10,"min": 0, "max": 255, "step": 1}),
                "per_batch": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
            },
        }

    @staticmethod
    def _color_mask(batch, color, threshold, invert):
        """Return a float32 0/1 mask of pixels within `threshold` of `color`.

        batch is indexed (..., C) with values in 0..1. Distances are Euclidean
        in 0..255 RGB space. The target colour is built on the batch's
        device and dtype, so GPU inputs no longer fail on a device mismatch.
        """
        target = torch.tensor(color, dtype=batch.dtype, device=batch.device)
        distances = torch.norm(batch * 255 - target, dim=-1)
        matches = distances <= threshold
        if invert:
            matches = ~matches
        return matches.float()

    def clip(self, images, red, green, blue, threshold, invert, per_batch):
        """Build the colour mask for `images`, `per_batch` images at a time.

        Returns a (B, H, W) float32 CPU mask with 1 where the pixel is within
        `threshold` of (red, green, blue), or the opposite when `invert`.
        """
        steps = images.shape[0]
        pbar = ProgressBar(steps)
        tensors_out = []

        for start_idx in range(0, steps, per_batch):
            mask_out = self._color_mask(images[start_idx:start_idx + per_batch], (red, green, blue), threshold, invert)
            tensors_out.append(mask_out.cpu())
            pbar.update(mask_out.shape[0])

        if not tensors_out:
            # Empty input batch: return an empty mask batch rather than failing in torch.cat.
            return torch.zeros(images.shape[:3], dtype=torch.float32),

        tensors_out = torch.cat(tensors_out, dim=0)
        return tensors_out,
|
|
|
|
class CreateFluidMask:
    """ComfyUI node: animate a stable-fluids simulation into an image batch and mask."""

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "createfluidmask"
    CATEGORY = "KJNodes/masking/generate"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "invert": ("BOOLEAN", {"default": False}),
                "frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
                "width": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                "height": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                "inflow_count": ("INT", {"default": 3,"min": 0, "max": 255, "step": 1}),
                "inflow_velocity": ("INT", {"default": 1,"min": 0, "max": 255, "step": 1}),
                "inflow_radius": ("INT", {"default": 8,"min": 0, "max": 255, "step": 1}),
                "inflow_padding": ("INT", {"default": 50,"min": 0, "max": 255, "step": 1}),
                "inflow_duration": ("INT", {"default": 60,"min": 0, "max": 255, "step": 1}),
            },
        }

    def createfluidmask(self, frames, width, height, invert, inflow_count, inflow_velocity, inflow_radius, inflow_padding, inflow_duration):
        """Simulate dye inflows and return (images, masks), one per frame.

        `inflow_count` circular inflows are placed on a ring around the
        centre, pointing inward. They inject velocity and dye for the first
        `inflow_duration` frames. Each image stacks (curl, 1, dye) as RGB;
        the mask is the curl channel.
        """
        from ..utility.fluid import Fluid
        # erf lives in scipy.special. The old fallback imported it from
        # scipy.spatial, which has no erf and could never succeed.
        from scipy.special import erf

        out = []
        masks = []
        # Arrays produced below take fluid.shape, which appears to be the
        # shape passed to Fluid (TODO confirm against utility/fluid.py).
        # Passing (rows, cols) = (height, width) makes frames come out (H, W)
        # like other ComfyUI images; (width, height) transposed non-square output.
        RESOLUTION = height, width
        DURATION = frames

        INFLOW_PADDING = inflow_padding
        INFLOW_DURATION = inflow_duration
        INFLOW_RADIUS = inflow_radius
        INFLOW_VELOCITY = inflow_velocity
        INFLOW_COUNT = inflow_count

        print('Generating fluid solver, this may take some time.')
        fluid = Fluid(RESOLUTION, 'dye')

        center = np.floor_divide(RESOLUTION, 2)
        r = np.min(center) - INFLOW_PADDING

        # Evenly spaced inflow positions on a circle of radius r, each with an inward normal.
        points = np.linspace(-np.pi, np.pi, INFLOW_COUNT, endpoint=False)
        points = tuple(np.array((np.cos(p), np.sin(p))) for p in points)
        normals = tuple(-p for p in points)
        points = tuple(r * p + center for p in points)

        inflow_velocity = np.zeros_like(fluid.velocity)
        inflow_dye = np.zeros(fluid.shape)
        for p, n in zip(points, normals):
            mask = np.linalg.norm(fluid.indices - p[:, None, None], axis=0) <= INFLOW_RADIUS
            inflow_velocity[:, mask] += n[:, None] * INFLOW_VELOCITY
            inflow_dye[mask] = 1

        for f in range(DURATION):
            print(f'Computing frame {f + 1} of {DURATION}.')
            if f <= INFLOW_DURATION:
                fluid.velocity += inflow_velocity
                fluid.dye += inflow_dye

            curl = fluid.step()[1]
            # Squash the unbounded curl into roughly [0, 0.5] for display.
            curl = (erf(curl * 2) + 1) / 4

            color = np.dstack((curl, np.ones(fluid.shape), fluid.dye))
            color = (np.clip(color, 0, 1) * 255).astype('uint8')
            image = np.array(color).astype(np.float32) / 255.0
            image = torch.from_numpy(image)[None,]
            mask = image[:, :, :, 0]
            masks.append(mask)
            out.append(image)

        images_out = torch.cat(out, dim=0)
        masks_out = torch.cat(masks, dim=0)
        if invert:
            return (1.0 - images_out, 1.0 - masks_out,)
        return (images_out, masks_out,)
|
|
|
|
class CreateAudioMask:
    """Deprecated ComfyUI node: draw a circle whose radius follows audio loudness."""

    # The function has always returned a mask as a second value (in the
    # non-inverted path); it is now declared and returned consistently.
    RETURN_TYPES = ("IMAGE", "MASK",)
    FUNCTION = "createaudiomask"
    CATEGORY = "KJNodes/deprecated"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "invert": ("BOOLEAN", {"default": False}),
                "frames": ("INT", {"default": 16,"min": 1, "max": 255, "step": 1}),
                "scale": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 2.0, "step": 0.01}),
                "audio_path": ("STRING", {"default": "audio.wav"}),
                "width": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                "height": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
            },
        }

    def createaudiomask(self, frames, width, height, invert, audio_path, scale):
        """Return (images, masks): one centred white circle per STFT column.

        The radius is height * mean(|STFT column|) * scale. When more frames
        are requested than the spectrogram has columns, the last column is
        reused instead of raising IndexError. The relative default
        "audio.wav" is resolved against script_directory.
        """
        try:
            import librosa
        except ImportError:
            raise Exception("Can not import librosa. Install it with 'pip install librosa'")
        batch_size = frames
        out = []
        masks = []
        if audio_path == "audio.wav":
            audio_path = os.path.join(script_directory, audio_path)
        audio, sr = librosa.load(audio_path)
        spectrogram = np.abs(librosa.stft(audio))
        last_column = spectrogram.shape[1] - 1

        for i in range(batch_size):
            image = Image.new("RGB", (width, height), "black")
            draw = ImageDraw.Draw(image)
            frame = spectrogram[:, min(i, last_column)]
            circle_radius = int(height * np.mean(frame))
            circle_radius *= scale
            circle_center = (width // 2, height // 2)

            draw.ellipse([(circle_center[0] - circle_radius, circle_center[1] - circle_radius),
                          (circle_center[0] + circle_radius, circle_center[1] + circle_radius)],
                         fill='white')

            image = np.array(image).astype(np.float32) / 255.0
            image = torch.from_numpy(image)[None,]
            mask = image[:, :, :, 0]
            masks.append(mask)
            out.append(image)

        images_out = torch.cat(out, dim=0)
        masks_out = torch.cat(masks, dim=0)
        if invert:
            return (1.0 - images_out, 1.0 - masks_out,)
        return (images_out, masks_out,)
|
|
|
|
class CreateGradientMask:
    """ComfyUI node: horizontal 1 -> 0 gradient masks that shift down per frame."""

    RETURN_TYPES = ("MASK",)
    FUNCTION = "createmask"
    CATEGORY = "KJNodes/masking/generate"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "invert": ("BOOLEAN", {"default": False}),
                "frames": ("INT", {"default": 0,"min": 0, "max": 255, "step": 1}),
                "width": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                "height": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
            },
        }

    def createmask(self, frames, width, height, invert):
        """Return a (frames, height, width) float32 mask batch.

        Every row of frame k is the ramp 1.0 -> 0.0 across the width minus
        k / frames. Values are not clamped, so later frames go negative.
        With frames == 0 an empty batch is returned.
        """
        ramp = np.linspace(1.0, 0.0, width, dtype=np.float32)
        # float32 shifts, so the subtraction rounds exactly like the
        # per-frame scalar subtraction on a float32 array.
        shifts = np.array([k / frames for k in range(frames)], dtype=np.float32)
        rows = ramp[np.newaxis, :] - shifts[:, np.newaxis]
        stacked = np.broadcast_to(rows[:, np.newaxis, :], (frames, height, width)).copy()
        masks = torch.from_numpy(stacked)
        return (1.0 - masks,) if invert else (masks,)
|
|
|
|
class CreateFadeMask:
    """Deprecated ComfyUI node: solid masks fading start -> midpoint -> end level."""

    RETURN_TYPES = ("MASK",)
    FUNCTION = "createfademask"
    CATEGORY = "KJNodes/deprecated"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "invert": ("BOOLEAN", {"default": False}),
                "frames": ("INT", {"default": 2,"min": 2, "max": 255, "step": 1}),
                "width": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                "height": ("INT", {"default": 256,"min": 16, "max": 4096, "step": 1}),
                "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],),
                "start_level": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 1.0, "step": 0.01}),
                "midpoint_level": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 1.0, "step": 0.01}),
                "end_level": ("FLOAT", {"default": 0.0,"min": 0.0, "max": 1.0, "step": 0.01}),
                "midpoint_frame": ("INT", {"default": 0,"min": 0, "max": 4096, "step": 1}),
            },
        }

    def createfademask(self, frames, width, height, invert, interpolation, start_level, midpoint_level, end_level, midpoint_frame):
        """Return a (frames, height, width) batch of uniform masks.

        Frames 0..midpoint_frame go from start_level to midpoint_level. The
        remaining frames go from midpoint_level to end_level, and the last
        frame now reaches end_level exactly (the old divisor was off by
        one). midpoint_frame == 0 means "frames // 2". Levels are clipped
        to [0, 1].
        """
        easings = {
            "ease_in": lambda t: t * t,
            "ease_out": lambda t: 1 - (1 - t) * (1 - t),
            "ease_in_out": lambda t: 3 * t * t - 2 * t * t * t,
        }
        ease = easings.get(interpolation, lambda t: t)

        batch_size = frames
        last_frame = batch_size - 1
        if midpoint_frame == 0:
            midpoint_frame = batch_size // 2

        levels = np.empty(batch_size, dtype=np.float32)
        for i in range(batch_size):
            if i <= midpoint_frame:
                t = ease(i / midpoint_frame)
                color = start_level - t * (start_level - midpoint_level)
            else:
                # i > midpoint_frame implies last_frame - midpoint_frame >= 1.
                t = ease((i - midpoint_frame) / (last_frame - midpoint_frame))
                color = midpoint_level - t * (midpoint_level - end_level)
            levels[i] = min(max(color, 0.0), 1.0)

        masks = torch.from_numpy(np.broadcast_to(levels[:, None, None], (batch_size, height, width)).copy())
        if invert:
            return (1.0 - masks,)
        return (masks,)
|
|
|
|
class CreateFadeMaskAdvanced:
    """ComfyUI node: uniform masks interpolated between key frames."""

    RETURN_TYPES = ("MASK",)
    FUNCTION = "createfademask"
    CATEGORY = "KJNodes/masking/generate"
    DESCRIPTION = """
Create a batch of masks interpolated between given frames and values.
Uses same syntax as Fizz' BatchValueSchedule.
First value is the frame index (not that this starts from 0, not 1)
and the second value inside the brackets is the float value of the mask in range 0.0 - 1.0

For example the default values:
0:(0.0)
7:(1.0)
15:(0.0)

Would create a mask batch fo 16 frames, starting from black,
interpolating with the chosen curve to fully white at the 8th frame,
and interpolating from that to fully black at the 16th frame.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "points_string": ("STRING", {"default": "0:(0.0),\n7:(1.0),\n15:(0.0)\n", "multiline": True}),
                "invert": ("BOOLEAN", {"default": False}),
                "frames": ("INT", {"default": 16,"min": 2, "max": 255, "step": 1}),
                "width": ("INT", {"default": 512,"min": 1, "max": 4096, "step": 1}),
                "height": ("INT", {"default": 512,"min": 1, "max": 4096, "step": 1}),
                "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],),
            },
        }

    def createfademask(self, frames, width, height, invert, points_string, interpolation):
        """Return a (frames, height, width) mask batch following `points_string`.

        points_string is a comma-separated list of "frame:(value)" entries;
        blank entries (e.g. trailing commas or newlines) are ignored. If the
        last listed point is not at frames - 1, a point holding its value is
        added there.

        Edge cases that used to crash or overshoot are now handled:
        - frames before the first key hold its value (t is clamped to [0, 1]);
        - duplicate key frames jump to the later value instead of dividing by zero;
        - a lone key frame is held for all frames.

        Values are clipped to [0, 1].
        """
        easings = {
            "ease_in": lambda t: t * t,
            "ease_out": lambda t: 1 - (1 - t) * (1 - t),
            "ease_in_out": lambda t: 3 * t * t - 2 * t * t * t,
        }
        ease = easings.get(interpolation, lambda t: t)

        points = []
        for point_str in points_string.split(','):
            if not point_str.strip():
                continue
            frame_str, color_str = point_str.split(':')
            frame = int(frame_str.strip())
            color = float(color_str.strip()[1:-1])
            points.append((frame, color))

        if len(points) == 0 or points[-1][0] != frames - 1:
            points.append((frames - 1, points[-1][1] if points else 0))

        points.sort(key=lambda x: x[0])

        batch_size = frames
        levels = np.empty(batch_size, dtype=np.float32)

        next_point = 1
        for i in range(batch_size):
            while next_point < len(points) and i > points[next_point][0]:
                next_point += 1

            if next_point >= len(points):
                color = points[-1][1]
            else:
                prev_frame, prev_color = points[next_point - 1]
                next_frame, next_color = points[next_point]
                span = next_frame - prev_frame
                t = (i - prev_frame) / span if span else 1.0
                t = ease(min(max(t, 0.0), 1.0))
                color = prev_color - t * (prev_color - next_color)

            levels[i] = min(max(color, 0.0), 1.0)

        masks = torch.from_numpy(np.broadcast_to(levels[:, None, None], (batch_size, height, width)).copy())
        if invert:
            return (1.0 - masks,)
        return (masks,)
|
|
|
|
class CreateMagicMask:
    """ComfyUI node: animated 'magic texture' masks rendered through matplotlib."""

    RETURN_TYPES = ("MASK", "MASK",)
    RETURN_NAMES = ("mask", "mask_inverted",)
    FUNCTION = "createmagicmask"
    CATEGORY = "KJNodes/masking/generate"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "frames": ("INT", {"default": 16,"min": 2, "max": 4096, "step": 1}),
                "depth": ("INT", {"default": 12,"min": 1, "max": 500, "step": 1}),
                "distortion": ("FLOAT", {"default": 1.5,"min": 0.0, "max": 100.0, "step": 0.01}),
                "seed": ("INT", {"default": 123,"min": 0, "max": 99999999, "step": 1}),
                "transitions": ("INT", {"default": 1,"min": 1, "max": 20, "step": 1}),
                "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
            },
        }

    def createmagicmask(self, frames, transitions, depth, distortion, seed, frame_width, frame_height):
        """Return (masks, 1 - masks) blending random texture transforms.

        Each transition linearly blends between two random coordinate
        transforms over frames // transitions frames (at least 1). Every
        texture is rendered with matplotlib, converted to grayscale, and
        sized to (frame_height, frame_width).
        """
        from ..utility.magictex import coordinate_grid, random_transform, magic
        import matplotlib.pyplot as plt

        rng = np.random.default_rng(seed)
        out = []
        coords = coordinate_grid((frame_width, frame_height))

        # Never 0: with frames < transitions the floor division gave no
        # frames at all and torch.stack([]) raised.
        frames_per_transition = max(1, frames // transitions)

        # base_params' coords are always overwritten below, but the
        # random_transform call is kept so the RNG stream (and therefore
        # the output for a given seed) is unchanged.
        base_params = {
            "coords": random_transform(coords, rng),
            "depth": depth,
            "distortion": distortion,
        }
        for _ in range(transitions):
            params1 = base_params.copy()
            params2 = base_params.copy()
            params1['coords'] = random_transform(coords, rng)
            params2['coords'] = random_transform(coords, rng)

            for i in range(frames_per_transition):
                alpha = i / frames_per_transition
                params = params1.copy()
                params['coords'] = (1 - alpha) * params1['coords'] + alpha * params2['coords']
                tex = magic(**params)

                # 10 inches at frame_width / 10 dpi gives frame_width pixels.
                # The height now follows frame_height (it was always square).
                dpi = frame_width / 10
                fig = plt.figure(figsize=(10, 10 * frame_height / frame_width), dpi=dpi)
                ax = fig.add_subplot(111)
                fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
                ax.get_yaxis().set_ticks([])
                ax.get_xaxis().set_ticks([])
                ax.imshow(tex, aspect='auto')

                fig.canvas.draw()
                img = np.array(fig.canvas.buffer_rgba())
                plt.close(fig)

                pil_img = Image.fromarray(img).convert("L")
                # Float dpi can round the canvas a pixel off; enforce the requested size.
                if pil_img.size != (frame_width, frame_height):
                    pil_img = pil_img.resize((frame_width, frame_height))
                mask = torch.tensor(np.array(pil_img)) / 255.0
                out.append(mask)

        masks = torch.stack(out, dim=0)
        return (masks, 1.0 - masks,)
|
|
|
|
class CreateShapeMask:
    """ComfyUI node: masks containing a (optionally growing) circle, square or triangle."""

    RETURN_TYPES = ("MASK", "MASK",)
    RETURN_NAMES = ("mask", "mask_inverted",)
    FUNCTION = "createshapemask"
    CATEGORY = "KJNodes/masking/generate"
    DESCRIPTION = """
Creates a mask or batch of masks with the specified shape.
Locations are center locations.
Grow value is the amount to grow the shape on each frame, creating animated masks.
"""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "shape": (
                    [ 'circle',
                      'square',
                      'triangle',
                    ],
                    {
                        "default": 'circle'
                    }),
                "frames": ("INT", {"default": 1,"min": 1, "max": 4096, "step": 1}),
                "location_x": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
                "location_y": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
                "grow": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                "shape_width": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
                "shape_height": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
            },
        }

    def createshapemask(self, frames, frame_width, frame_height, location_x, location_y, shape_width, shape_height, grow, shape):
        """Return (masks, 1 - masks) with one centred shape per frame.

        The shape size on frame i is (shape_width + i*grow,
        shape_height + i*grow), floored at 0. Once either dimension reaches
        0 nothing is drawn, so a fully shrunk shape leaves an empty frame
        instead of a degenerate PIL primitive (a rectangle with identical
        corners still paints one pixel).
        """
        batch_size = frames
        out = []
        color = "white"
        for i in range(batch_size):
            image = Image.new("RGB", (frame_width, frame_height), "black")
            draw = ImageDraw.Draw(image)

            current_width = max(0, shape_width + i*grow)
            current_height = max(0, shape_height + i*grow)

            if current_width > 0 and current_height > 0:
                half_w = current_width // 2
                half_h = current_height // 2
                if shape == 'circle' or shape == 'square':
                    two_points = [(location_x - half_w, location_y - half_h),
                                  (location_x + half_w, location_y + half_h)]
                    if shape == 'circle':
                        draw.ellipse(two_points, fill=color)
                    else:
                        draw.rectangle(two_points, fill=color)
                elif shape == 'triangle':
                    # Apex at the top centre, base along the bottom edge.
                    left_point = (location_x - half_w, location_y + half_h)
                    right_point = (location_x + half_w, location_y + half_h)
                    top_point = (location_x, location_y - half_h)
                    draw.polygon([top_point, left_point, right_point], fill=color)

            image = pil2tensor(image)
            mask = image[:, :, :, 0]
            out.append(mask)
        outstack = torch.cat(out, dim=0)
        return (outstack, 1.0 - outstack,)
|
|
|
|
class CreateVoronoiMask:
    """ComfyUI node: animated Voronoi-edge masks (unseeded, random each run)."""

    RETURN_TYPES = ("MASK", "MASK",)
    RETURN_NAMES = ("mask", "mask_inverted",)
    FUNCTION = "createvoronoi"
    CATEGORY = "KJNodes/masking/generate"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "frames": ("INT", {"default": 16,"min": 2, "max": 4096, "step": 1}),
                "num_points": ("INT", {"default": 15,"min": 1, "max": 4096, "step": 1}),
                "line_width": ("INT", {"default": 4,"min": 1, "max": 4096, "step": 1}),
                "speed": ("FLOAT", {"default": 0.5,"min": 0.0, "max": 1.0, "step": 0.01}),
                "frame_width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                "frame_height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
            },
        }

    def createvoronoi(self, frames, num_points, line_width, speed, frame_width, frame_height):
        """Return (masks, 1 - masks) of Voronoi ridges drawn as black lines on white.

        Seed points move linearly from random start to random end positions
        (x scaled to the aspect ratio). Each frame is rendered with
        matplotlib and sized to (frame_height, frame_width).
        """
        from scipy.spatial import Voronoi
        # pyplot was previously used here without being imported (NameError).
        import matplotlib.pyplot as plt

        batch_size = frames
        out = []

        # Keep the drawing coordinate system in [0, aspect_ratio] x [0, 1].
        aspect_ratio = frame_width / frame_height

        start_points = np.random.rand(num_points, 2)
        start_points[:, 0] *= aspect_ratio
        end_points = np.random.rand(num_points, 2)
        end_points[:, 0] *= aspect_ratio

        for i in range(batch_size):
            t = np.clip((i * speed) / (batch_size - 1), 0, 1)
            # start/end points are already scaled to the aspect ratio; scaling
            # again here pushed points outside the visible x range.
            points = (1 - t) * start_points + t * end_points

            vor = Voronoi(points)

            fig, ax = plt.subplots(dpi=100)
            fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
            ax.set_xlim([0, aspect_ratio]); ax.set_ylim([0, 1])
            ax.axis('off')
            ax.margins(0, 0)
            fig.set_size_inches(frame_width / 100, frame_height / 100)
            ax.fill_between([0, 1], [0, 1], color='white')

            # Draw only finite ridges (index -1 marks a vertex at infinity).
            for simplex in vor.ridge_vertices:
                simplex = np.asarray(simplex)
                if np.all(simplex >= 0):
                    ax.plot(vor.vertices[simplex, 0], vor.vertices[simplex, 1], 'k-', linewidth=line_width)

            fig.canvas.draw()
            img = np.array(fig.canvas.buffer_rgba())
            plt.close(fig)

            pil_img = Image.fromarray(img).convert("L")
            # Inch-based sizing can round a pixel off; enforce the requested size.
            if pil_img.size != (frame_width, frame_height):
                pil_img = pil_img.resize((frame_width, frame_height))
            mask = torch.tensor(np.array(pil_img)) / 255.0
            out.append(mask)

        masks = torch.stack(out, dim=0)
        return (masks, 1.0 - masks,)
|
|
|
|
class GetMaskSizeAndCount:
    """ComfyUI node: report a mask batch's width, height and count, passing it through."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "mask": ("MASK",),
        }}

    RETURN_TYPES = ("MASK","INT", "INT", "INT",)
    RETURN_NAMES = ("mask", "width", "height", "count",)
    FUNCTION = "getsize"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Returns the width, height and batch size of the mask,
and passes it through unchanged.

"""

    def getsize(self, mask):
        """Return the UI summary "<count>x<width>x<height>" plus (mask, width, height, count).

        Dims are read as (count, height, width) = sizes of dims 0, 1, 2;
        width is queried first, so a 2-D mask raises IndexError as before.
        """
        width, height, count = (mask.size(dim) for dim in (2, 1, 0))
        summary = f"{count}x{width}x{height}"
        return {
            "ui": {"text": [summary]},
            "result": (mask, width, height, count),
        }
|
|
|
|
class GrowMaskWithBlur:
    """ComfyUI node: grow/shrink, optionally fill, temporally smooth and blur masks."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
                "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
                "incremental_expandrate": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
                "tapered_corners": ("BOOLEAN", {"default": True}),
                "flip_input": ("BOOLEAN", {"default": False}),
                "blur_radius": ("FLOAT", {
                    "default": 0.0,
                    "min": 0.0,
                    "max": 100,
                    "step": 0.1
                }),
                "lerp_alpha": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "decay_factor": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
            "optional": {
                "fill_holes": ("BOOLEAN", {"default": False}),
            },
        }

    CATEGORY = "KJNodes/masking"
    RETURN_TYPES = ("MASK", "MASK",)
    RETURN_NAMES = ("mask", "mask_inverted",)
    FUNCTION = "expand_mask"
    DESCRIPTION = """
# GrowMaskWithBlur
- mask: Input mask or mask batch
- expand: Expand or contract mask or mask batch by a given amount
- incremental_expandrate: increase expand rate by a given amount per frame
- tapered_corners: use tapered corners
- flip_input: flip input mask
- blur_radius: value higher than 0 will blur the mask
- lerp_alpha: alpha value for interpolation between frames
- decay_factor: decay value for interpolation between frames
- fill_holes: fill holes in the mask (slow)"""

    def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, incremental_expandrate, lerp_alpha, decay_factor, fill_holes=False):
        """Return (mask, 1 - mask) after per-frame morphology and smoothing.

        For each frame (mask reshaped to (-1, H, W)):
        - grey dilation (expand > 0) or erosion (expand < 0) is applied
          round(|expand|) times with a 3x3 footprint whose corners are
          dropped when tapered_corners;
        - |expand| grows by incremental_expandrate per frame;
        - holes are optionally filled, giving a 0/1 mask (this was 0/255
          before);
        - the frame is optionally lerped toward the previous output
          (lerp_alpha < 1);
        - decay < 1 adds a decayed previous output and renormalises by the
          peak, skipping the division when the frame is all zero (it used
          to produce NaNs).
        A Gaussian blur of blur_radius is applied at the end when non-zero.
        """
        alpha = lerp_alpha
        decay = decay_factor
        if flip_input:
            mask = 1.0 - mask
        c = 0 if tapered_corners else 1
        kernel = np.array([[c, 1, c],
                           [1, 1, 1],
                           [c, 1, c]])
        growmask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).cpu()
        out = []
        previous_output = None
        current_expand = expand
        for m in growmask:
            output = m.numpy().astype(np.float32)
            for _ in range(abs(round(current_expand))):
                if current_expand < 0:
                    output = scipy.ndimage.grey_erosion(output, footprint=kernel)
                else:
                    output = scipy.ndimage.grey_dilation(output, footprint=kernel)
            if current_expand < 0:
                current_expand -= abs(incremental_expandrate)
            else:
                current_expand += abs(incremental_expandrate)
            if fill_holes:
                binary_mask = output > 0
                # Keep the mask in 0..1 like the other paths.
                output = scipy.ndimage.binary_fill_holes(binary_mask).astype(np.float32)
            output = torch.from_numpy(output)

            if alpha < 1.0 and previous_output is not None:
                # Blend toward the previous frame for temporal smoothing.
                output = alpha * output + (1 - alpha) * previous_output
            if decay < 1.0 and previous_output is not None:
                # Accumulate a decayed trail of the previous frame, then renormalise.
                output += decay * previous_output
                peak = output.max()
                if peak > 0:
                    output = output / peak
            previous_output = output
            out.append(output)

        if blur_radius != 0:
            for idx, tensor in enumerate(out):
                pil_image = tensor2pil(tensor.cpu().detach())[0]
                pil_image = pil_image.filter(ImageFilter.GaussianBlur(blur_radius))
                out[idx] = pil2tensor(pil_image)
            blurred = torch.cat(out, dim=0)
            return (blurred, 1.0 - blurred)

        stacked = torch.stack(out, dim=0)
        return (stacked, 1.0 - stacked,)
|
|
|
|
class MaskBatchMulti:
    """Concatenate a variable number of mask inputs into one mask batch."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "inputcount": ("INT", {"default": 2, "min": 2, "max": 1000, "step": 1}),
                "mask_1": ("MASK", ),
                "mask_2": ("MASK", ),
            },
        }

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("masks",)
    FUNCTION = "combine"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Creates an image batch from multiple masks.
You can set how many inputs the node has,
with the **inputcount** and clicking update.
"""

    def combine(self, inputcount, **kwargs):
        """Concatenate ``mask_1`` .. ``mask_{inputcount}`` along dim 0.

        Every mask after the first is bicubically resized to the spatial size
        of ``mask_1`` when its trailing dimensions differ.

        Args:
            inputcount: Number of ``mask_N`` keyword inputs to read.
            **kwargs: The ``mask_1`` .. ``mask_{inputcount}`` tensors.

        Returns:
            A 1-tuple holding the concatenated mask batch.

        Raises:
            KeyError: If a ``mask_N`` input up to ``inputcount`` is missing.
        """
        first = kwargs["mask_1"]
        target_size = (first.shape[1], first.shape[2])
        batch = [first]
        for index in range(2, inputcount + 1):
            new_mask = kwargs[f"mask_{index}"]
            if first.shape[1:] != new_mask.shape[1:]:
                low, high = float(new_mask.min()), float(new_mask.max())
                new_mask = F.interpolate(new_mask.unsqueeze(1), size=target_size, mode="bicubic").squeeze(1)
                # Bicubic interpolation overshoots; keep values within the source's range.
                new_mask = torch.clamp(new_mask, min=low, max=high)
            batch.append(new_mask)
        # One concatenation instead of re-copying the growing batch for every input.
        return (torch.cat(batch, dim=0),)
|
|
|
|
class OffsetMask:
    """Rotate, translate and/or roll a mask batch, optionally per frame."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mask": ("MASK",),
                "x": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
                "y": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
                "angle": ("INT", { "default": 0, "min": -360, "max": 360, "step": 1, "display": "number" }),
                "duplication_factor": ("INT", { "default": 1, "min": 1, "max": 1000, "step": 1, "display": "number" }),
                "roll": ("BOOLEAN", { "default": False }),
                "incremental": ("BOOLEAN", { "default": False }),
                "padding_mode": (
                    [
                        'empty',
                        'border',
                        'reflection',
                    ], {
                        "default": 'empty'
                    }),
            }
        }

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "offset"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Offsets the mask by the specified amount.
- mask: Input mask or mask batch
- x: Horizontal offset
- y: Vertical offset
- angle: Angle in degrees
- roll: roll edge wrapping
- duplication_factor: Number of times to duplicate the mask to form a batch
- border padding_mode: Padding mode for the mask
"""

    @staticmethod
    def _clamp_shift(shift, size):
        """Limit a translation to at most size-1 pixels in either direction."""
        return max(-(size - 1), min(shift, size - 1))

    @staticmethod
    def _reflect_indices(indices, size):
        """Mirror integer indices into [0, size) without repeating the edge pixel."""
        if size == 1:
            return torch.zeros_like(indices)
        period = 2 * (size - 1)
        folded = torch.remainder(indices, period)
        return torch.where(folded >= size, period - folded, folded)

    @classmethod
    def _shift_frame(cls, frame, dx, dy, padding_mode):
        """Translate one (height, width) frame right by dx and down by dy pixels.

        Output pixel (r, c) reads source pixel (r - dy, c - dx). Source
        coordinates outside the frame are filled per ``padding_mode``:
        'empty' -> 0, 'border' -> nearest edge pixel, 'reflection' -> mirrored
        about the edge.
        """
        height, width = frame.shape
        rows = torch.arange(height, device=frame.device) - dy
        cols = torch.arange(width, device=frame.device) - dx
        if padding_mode == 'empty':
            inside = (((rows >= 0) & (rows < height))[:, None]
                      & ((cols >= 0) & (cols < width))[None, :])
            rows = rows.clamp(0, height - 1)
            cols = cols.clamp(0, width - 1)
            return frame[rows[:, None], cols[None, :]].masked_fill(~inside, 0)
        if padding_mode == 'border':
            rows = rows.clamp(0, height - 1)
            cols = cols.clamp(0, width - 1)
        elif padding_mode == 'reflection':
            rows = cls._reflect_indices(rows, height)
            cols = cls._reflect_indices(cols, width)
        else:
            raise ValueError(f"Unknown padding_mode: {padding_mode!r}")
        return frame[rows[:, None], cols[None, :]]

    def offset(self, mask, x, y, angle, roll=False, incremental=False, duplication_factor=1, padding_mode="empty"):
        """Rotate and shift a mask batch.

        Args:
            mask: Mask batch; its shape is unpacked as (batch, height, width).
            x, y: Pixel offset (positive = right / down).
            angle: Rotation in degrees (negative values are honoured too).
            roll: Wrap shifted-out pixels around instead of padding.
            incremental: Multiply angle/offset by the 1-based frame index.
            duplication_factor: Repeat the input batch this many times first.
            padding_mode: 'empty', 'border' or 'reflection' fill for revealed
                pixels (ignored when ``roll`` is True).

        Returns:
            A 1-tuple with the transformed mask batch.

        Raises:
            ValueError: For an unknown ``padding_mode`` when a shift is applied.
        """
        mask = mask.repeat(duplication_factor, 1, 1).clone()
        batch_size, height, width = mask.shape

        if angle != 0:
            for i in range(batch_size):
                rotation_angle = angle * (i + 1) if incremental else angle
                mask[i] = TF.rotate(mask[i].unsqueeze(0), rotation_angle).squeeze(0)

        if roll:
            # torch.roll wraps any offset, so no clamping is needed here.
            if incremental:
                for i in range(batch_size):
                    step = i + 1
                    mask[i] = torch.roll(mask[i], shifts=(y * step, x * step), dims=(0, 1))
            else:
                mask = torch.roll(mask, shifts=(y, x), dims=(1, 2))
            return (mask,)

        for i in range(batch_size):
            step = i + 1 if incremental else 1
            dx = self._clamp_shift(x * step, width)
            dy = self._clamp_shift(y * step, height)
            if dx != 0 or dy != 0:
                mask[i] = self._shift_frame(mask[i], dx, dy, padding_mode)
        return (mask,)
|
|
|
|
class RoundMask:
    """Snap mask values to the nearest integer."""

    @classmethod
    def INPUT_TYPES(cls):
        mask_input = {"mask": ("MASK",)}
        return {"required": mask_input}

    RETURN_TYPES = ("MASK",)
    FUNCTION = "round"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Rounds the mask or batch of masks to a binary mask.
<img src="https://github.com/kijai/ComfyUI-KJNodes/assets/40791699/52c85202-f74e-4b96-9dac-c8bda5ddcc40" width="300" height="250" alt="RoundMask example">

"""

    def round(self, mask):
        """Return a 1-tuple with ``mask`` rounded element-wise (torch.round)."""
        rounded = torch.round(mask)
        return (rounded,)
|
|
|
|
class ResizeMask:
    """Nearest-neighbour resize of a mask batch."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mask": ("MASK",),
                "width": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, "display": "number" }),
                "height": ("INT", { "default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8, "display": "number" }),
                "keep_proportions": ("BOOLEAN", { "default": False }),
            }
        }

    RETURN_TYPES = ("MASK", "INT", "INT",)
    RETURN_NAMES = ("mask", "width", "height",)
    FUNCTION = "resize"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Resizes the mask or batch of masks to the specified width and height.
"""

    def resize(self, mask, width, height, keep_proportions):
        """Resize a (batch, height, width) mask with nearest-neighbour sampling.

        A ``width`` or ``height`` of 0 keeps the source size on that axis (the
        UI allows 0; previously this crashed unless keep_proportions was set).

        Args:
            mask: Mask batch of shape (batch, height, width).
            width, height: Target size; 0 means "use the source size".
            keep_proportions: Fit inside width x height while preserving the
                source aspect ratio.

        Returns:
            (resized_mask, output_width, output_height).
        """
        _, source_height, source_width = mask.shape
        width = source_width if width == 0 else width
        height = source_height if height == 0 else height
        if keep_proportions:
            ratio = min(width / source_width, height / source_height)
            # Never round a side down to zero pixels.
            width = max(1, round(source_width * ratio))
            height = max(1, round(source_height * ratio))
        outputs = F.interpolate(mask.unsqueeze(1), size=(height, width), mode="nearest").squeeze(1)
        return (outputs, outputs.shape[2], outputs.shape[1],)
|
|
|
|
class RemapMaskRange:
    """Rescale a mask so its peak maps to ``max`` and zero maps to ``min``."""

    @classmethod
    def INPUT_TYPES(cls):
        required_inputs = {
            "mask": ("MASK",),
            "min": ("FLOAT", {"default": 0.0,"min": -10.0, "max": 1.0, "step": 0.01}),
            "max": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 10.0, "step": 0.01}),
        }
        return {"required": required_inputs}

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "remap"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Sets new min and max values for the mask.
"""

    def remap(self, mask, min, max):
        """Linearly rescale ``mask`` by its global peak into [min, max], then clamp to [0, 1].

        The whole tensor is divided by its single largest value; if that value
        is not positive, the mask is divided by 1 instead, which avoids a
        division by zero. The result is ``mask / peak * (max - min) + min``,
        clamped to [0.0, 1.0].
        """
        peak = torch.max(mask)
        divisor = peak if peak > 0 else 1
        span = max - min
        remapped = mask / divisor * span + min
        return (remapped.clamp(0.0, 1.0), )
|
|
|