|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import collections.abc as clabc |
|
|
|
import torch |
|
|
|
|
|
class ImageSelector:
    """Select a subset of the images in a batch and pass them on.

    The subset is described by a comma-separated string of 1-based
    positions and ranges; see ``run`` for the accepted syntax.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs.

        ``selected_indexes`` is a comma (",") separated list of image
        positions. A colon (":") separated pair denotes a range with the
        left bound included and the right bound excluded. Positions start
        at 1 for simplicity.
        """
        return {
            "required": {
                "images": ("IMAGE", ),
                "selected_indexes": ("STRING", {
                    "multiline": False,
                    "default": "1,2,3"
                }),
            },
        }

    RETURN_TYPES = ("IMAGE", )

    FUNCTION = "run"

    OUTPUT_NODE = False

    CATEGORY = "image"

    @staticmethod
    def _parse_selection(spec: str, count: int) -> list[int]:
        """Translate a 1-based selection string into 0-based batch indices.

        Args:
            spec: Comma-separated tokens, each either a single position
                ("3") or a range ("2:5", "2:", ":5").
            count: Size of the batch being indexed.

        Returns:
            The selected 0-based indices, in the order the tokens appear.
            Duplicates are kept.
        """
        positions = range(count)
        chosen: list[int] = []
        for token in spec.strip().split(','):
            token = token.strip()
            if not token:
                # Stray or trailing commas are harmless; skip them quietly.
                continue
            try:
                if ":" in token:
                    start, end = token.split(':', maxsplit=1)
                    if not start and not end:
                        # A bare ":" has never selected anything.
                        continue
                    lower = int(start) - 1 if start else None
                    upper = int(end) - 1 if end else None
                    # Slicing clamps out-of-range bounds, so a range can
                    # never yield an index outside the batch.
                    chosen.extend(positions[lower:upper])
                else:
                    index = int(token) - 1
                    # Reject positions below 1 as well as beyond the batch:
                    # a negative index would wrap around or raise
                    # IndexError when the tensor is indexed.
                    if 0 <= index < count:
                        chosen.append(index)
            except ValueError:
                # int() is the only call here that can fail; selection is
                # best-effort, so report the bad token and carry on.
                print(f"ImageSelector: ignored invalid index {token!r}")
        return chosen

    def run(self, images: torch.Tensor, selected_indexes: str):
        """Pick the images named by ``selected_indexes`` out of ``images``.

        Args:
            images: Batch of images, indexed along its first dimension.
                The remaining dimensions are passed through untouched.
            selected_indexes: Comma-separated 1-based positions and
                ranges, e.g. "1,3,5:7,9" selects images 1, 3, 5, 6 and 9.
                Ranges include the left bound and exclude the right one.
                Malformed tokens and out-of-range positions are ignored.

        Returns:
            A one-element tuple holding the selected images. When nothing
            valid is selected, the input batch is returned unchanged.
        """
        chosen = self._parse_selection(selected_indexes, images.shape[0])

        if chosen:
            print(f"ImageSelector: selected: {len(chosen)} images")
            return (images[chosen], )

        print("ImageSelector: selected no images, passthrough")
        return (images, )
|
|
|
|
|
class ImageDuplicator:
    """Repeat an image batch several times and pass the result on.

    The whole batch is repeated back-to-back (A, B, A, B, ...), not each
    image individually.
    """

    def __init__(self):
        self._name = "ImageDuplicator"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs.

        ``dup_times`` is the total number of copies of the batch wanted
        in the output.
        """
        return {
            "required": {
                "images": ("IMAGE", ),
                "dup_times": ("INT", {
                    "default": 2,
                    "min": 1,
                    "max": 16,
                    "step": 1,
                }),
            },
        }

    RETURN_TYPES = ("IMAGE", )

    FUNCTION = "run"

    OUTPUT_NODE = False

    CATEGORY = "image"

    def run(self, images: torch.Tensor, dup_times: int):
        """Concatenate ``dup_times`` copies of ``images`` along dimension 0.

        Args:
            images: Batch of images, indexed along its first dimension.
            dup_times: Total number of copies in the output. Values below
                1 are treated as 1, so the batch is never dropped.

        Returns:
            A one-element tuple holding a tensor whose first dimension is
            ``images.shape[0] * dup_times``; all other dimensions are
            unchanged.
        """
        copies = max(dup_times, 1)
        # torch.cat already writes into fresh storage, so cloning each copy
        # beforehand would only allocate memory that is thrown away.
        repeated = torch.cat([images] * copies)

        print(
            f"ImageDuplicator: dup {dup_times} times,",
            # Report the real number of images, not the number of copies.
            f"return {repeated.shape[0]} images",
        )
        return (repeated, )
|
|
|
|
|
class LatentSelector:
    """Select a subset of the latents in a batch and pass them on.

    The subset is described by a comma-separated string of 1-based
    positions and ranges; see ``run`` for the accepted syntax.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs.

        ``selected_indexes`` is a comma (",") separated list of latent
        positions. A colon (":") separated pair denotes a range with the
        left bound included and the right bound excluded. Positions start
        at 1 for simplicity.
        """
        return {
            "required": {
                "latent_image": ("LATENT", ),
                "selected_indexes": ("STRING", {
                    "multiline": False,
                    "default": "1,2,3"
                }),
            },
        }

    RETURN_TYPES = ("LATENT", )

    FUNCTION = "run"

    OUTPUT_NODE = False

    CATEGORY = "latent"

    @staticmethod
    def _parse_selection(spec: str, count: int) -> list[int]:
        """Translate a 1-based selection string into 0-based batch indices.

        Args:
            spec: Comma-separated tokens, each either a single position
                ("3") or a range ("2:5", "2:", ":5").
            count: Size of the batch being indexed.

        Returns:
            The selected 0-based indices, in the order the tokens appear.
            Duplicates are kept.
        """
        positions = range(count)
        chosen: list[int] = []
        for token in spec.strip().split(','):
            token = token.strip()
            if not token:
                # Stray or trailing commas are harmless; skip them quietly.
                continue
            try:
                if ":" in token:
                    start, end = token.split(':', maxsplit=1)
                    if not start and not end:
                        # A bare ":" has never selected anything.
                        continue
                    lower = int(start) - 1 if start else None
                    upper = int(end) - 1 if end else None
                    # Slicing clamps out-of-range bounds, so a range can
                    # never yield an index outside the batch.
                    chosen.extend(positions[lower:upper])
                else:
                    index = int(token) - 1
                    # Reject positions below 1 as well as beyond the batch:
                    # a negative index would wrap around or raise
                    # IndexError when the tensor is indexed.
                    if 0 <= index < count:
                        chosen.append(index)
            except ValueError:
                # int() is the only call here that can fail; selection is
                # best-effort, so report the bad token and carry on.
                print(f"LatentSelector: ignored invalid index {token!r}")
        return chosen

    def run(self, latent_image: clabc.Mapping[str, torch.Tensor],
            selected_indexes: str):
        """Pick the latents named by ``selected_indexes``.

        Args:
            latent_image: Mapping that holds the latent batch under the
                'samples' key, indexed along its first dimension.
            selected_indexes: Comma-separated 1-based positions and
                ranges, e.g. "1,3,5:7,9" selects latents 1, 3, 5, 6 and 9.
                Ranges include the left bound and exclude the right one.
                Malformed tokens and out-of-range positions are ignored.

        Returns:
            A one-element tuple holding a new dict with only the
            'samples' key, containing the selected latents. When nothing
            valid is selected, ``latent_image`` itself is returned
            unchanged.
        """
        samples = latent_image['samples']
        chosen = self._parse_selection(selected_indexes, samples.shape[0])

        if chosen:
            print(f"LatentSelector: selected: {len(chosen)} latents")
            # Index only the batch dimension so the node does not depend
            # on the number of trailing dimensions.
            return ({'samples': samples[chosen]}, )

        print("LatentSelector: selected no latents, passthrough")
        return (latent_image, )
|
|
|
|
|
class LatentDuplicator:
    """Repeat a latent batch several times and pass the result on.

    The whole batch is repeated back-to-back (A, B, A, B, ...), not each
    latent individually.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs.

        ``dup_times`` is the total number of copies of the batch wanted
        in the output.
        """
        return {
            "required": {
                "latent_image": ("LATENT", ),
                "dup_times": ("INT", {
                    "default": 2,
                    "min": 1,
                    "max": 16,
                    "step": 1,
                }),
            },
        }

    RETURN_TYPES = ("LATENT", )

    FUNCTION = "run"

    OUTPUT_NODE = False

    CATEGORY = "latent"

    def run(self, latent_image: clabc.Mapping[str, torch.Tensor],
            dup_times: int):
        """Concatenate ``dup_times`` copies of the latent batch.

        Args:
            latent_image: Mapping that holds the latent batch under the
                'samples' key, indexed along its first dimension.
            dup_times: Total number of copies in the output. Values below
                1 are treated as 1, so the batch is never dropped.

        Returns:
            A one-element tuple holding a new dict with only the
            'samples' key. Its tensor has a first dimension of
            ``samples.shape[0] * dup_times``; all other dimensions are
            unchanged.
        """
        samples = latent_image['samples']

        copies = max(dup_times, 1)
        # torch.cat already writes into fresh storage, so cloning each copy
        # beforehand would only allocate memory that is thrown away.
        repeated = torch.cat([samples] * copies)

        print(
            f"LatentDuplicator: dup {dup_times} times,",
            # Report the real batch size, not the number of copies.
            f"return {repeated.shape[0]} images",
        )
        return ({
            'samples': repeated,
        }, )
|
|
|
|
|
|
|
|
|
# Registry exported by this module: maps each node's name to the class
# that implements it.
NODE_CLASS_MAPPINGS = dict((
    ("ImageSelector", ImageSelector),
    ("ImageDuplicator", ImageDuplicator),
    ("LatentSelector", LatentSelector),
    ("LatentDuplicator", LatentDuplicator),
))
|
|