import os, torch
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
from .utils import easySave, get_sd_version
from .adv_encode import advanced_encode
from .controlnet import easyControlnet
from .log import log_node_warn
from ..layer_diffuse import LayerDiffuse
from ..config import RESOURCES_DIR
from nodes import CLIPTextEncode
try:
    from comfy_extras.nodes_flux import FluxGuidance
except ImportError:
    # older ComfyUI builds may not ship the Flux nodes
    FluxGuidance = None


class easyXYPlot():
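    """Builds an XY plot: re-samples one image per (x, y) value pair and
    assembles the results, plus axis labels, into a single annotated grid."""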

    def __init__(self, xyPlotData, save_prefix, image_output, prompt, extra_pnginfo, my_unique_id, sampler, easyCache):
        self.x_node_type, self.x_type = sampler.safe_split(xyPlotData.get("x_axis"), ': ')
        self.y_node_type, self.y_type = sampler.safe_split(xyPlotData.get("y_axis"), ': ')
        self.x_values = xyPlotData.get("x_vals") if self.x_type != "None" else []
        self.y_values = xyPlotData.get("y_vals") if self.y_type != "None" else []
        self.custom_font = xyPlotData.get("custom_font")

        self.grid_spacing = xyPlotData.get("grid_spacing")
        self.latent_id = 0
        self.output_individuals = xyPlotData.get("output_individuals")

        self.x_label, self.y_label = [], []
        self.max_width, self.max_height = 0, 0
        self.latents_plot = []
        self.image_list = []

        self.num_cols = len(self.x_values) if len(self.x_values) > 0 else 1
        self.num_rows = len(self.y_values) if len(self.y_values) > 0 else 1

        self.total = self.num_cols * self.num_rows
        self.num = 0

        self.save_prefix = save_prefix
        self.image_output = image_output
        self.prompt = prompt
        self.extra_pnginfo = extra_pnginfo
        self.my_unique_id = my_unique_id

        self.sampler = sampler
        self.easyCache = easyCache

    @staticmethod
    def define_variable(plot_image_vars, value_type, value, index):
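        """Record the current axis value in plot_image_vars and derive a short
        human-readable label for the grid header."""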
        plot_image_vars[value_type] = value
        if value_type in ["seed", "Seeds++ Batch"]:
            value_label = f"{value}"
        else:
            value_label = f"{value_type}: {value}"

        if "ControlNet" in value_type:
            value_label = f"ControlNet {index + 1}"

        if value_type in ['Lora', 'Checkpoint']:
            arr = value.split(',')
            model_name = os.path.basename(os.path.splitext(arr[0])[0])
            # Checkpoint values carry only 3 fields, so guard before reading
            # the Lora trigger words in the 4th field
            trigger_words = ' ' + arr[3] if len(arr) > 3 and len(arr[3]) > 2 else ''
            value_label = f"{model_name}{trigger_words}"

        if value_type in ["ModelMergeBlocks"]:
            if ":" in value:
                line = value.split(':')
                value_label = f"{line[0]}"
            elif len(value) > 16:
                value_label = f"ModelMergeBlocks {index + 1}"
            else:
                value_label = f"MMB: {value}"

        if value_type in ["Pos Condition"]:
            value_label = f"pos cond {index + 1}" if index > 0 else "pos cond"
        if value_type in ["Neg Condition"]:
            value_label = f"neg cond {index + 1}" if index > 0 else "neg cond"

        if value_type in ["Positive Prompt S/R"]:
            value_label = f"pos prompt {index + 1}" if index > 0 else "pos prompt"
        if value_type in ["Negative Prompt S/R"]:
            value_label = f"neg prompt {index + 1}" if index > 0 else "neg prompt"

        if value_type in ["steps", "cfg", "denoise", "clip_skip",
                          "lora_model_strength", "lora_clip_strength"]:
            value_label = f"{value_type}: {value}"

        if value_type == "positive":
            value_label = f"pos prompt {index + 1}"
        elif value_type == "negative":
            value_label = f"neg prompt {index + 1}"

        return plot_image_vars, value_label

    @staticmethod
    def get_font(font_size, font_path=None):
        if font_path is None:
            font_path = str(Path(os.path.join(RESOURCES_DIR, 'OpenSans-Medium.ttf')))
        return ImageFont.truetype(font_path, font_size)

    @staticmethod
    def update_label(label, value, num_items):
        if len(label) < num_items:
            return [*label, value]
        return label

    @staticmethod
    def rearrange_tensors(latent, num_cols, num_rows):
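        """Reorder latents from sampling order (the x loop is outermost, so the
        list is column-major) to row-major display order. E.g. with 2 cols and
        2 rows, [x0y0, x0y1, x1y0, x1y1] -> [x0y0, x1y0, x0y1, x1y1]."""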
        new_latent = []
        for i in range(num_rows):
            for j in range(num_cols):
                index = j * num_rows + i
                new_latent.append(latent[index])
        return new_latent

    def calculate_background_dimensions(self):
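        """Compute the canvas size for the assembled grid, reserving a border
        strip for axis labels whenever the corresponding axis is in use."""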
        border_size = int((self.max_width // 8) * 1.5) if self.y_type != "None" or self.x_type != "None" else 0
        bg_width = self.num_cols * (self.max_width + self.grid_spacing) - self.grid_spacing + border_size * (
                self.y_type != "None")
        bg_height = self.num_rows * (self.max_height + self.grid_spacing) - self.grid_spacing + border_size * (
                self.x_type != "None")

        x_offset_initial = border_size if self.y_type != "None" else 0
        y_offset = border_size if self.x_type != "None" else 0

        return bg_width, bg_height, x_offset_initial, y_offset

    def adjust_font_size(self, text, initial_font_size, label_width):
        font = self.get_font(initial_font_size, self.custom_font)
        # getbbox returns (left, top, right, bottom); the right edge is the width
        bbox = font.getbbox(text)
        text_width = bbox[2] if bbox else 0
        if not text_width:
            return initial_font_size

        scaling_factor = 0.9
        if text_width > (label_width * scaling_factor):
            return int(initial_font_size * (label_width / text_width) * scaling_factor)
        else:
            return initial_font_size

    def textsize(self, d, text, font):
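        # Pillow 10 removed ImageDraw.textsize; textbbox is the supported way
        # to measure rendered text, so wrap it here for reuse.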
        _, _, width, height = d.textbbox((0, 0), text=text, font=font)
        return width, height

    def create_label(self, img, text, initial_font_size, is_x_label=True, max_font_size=70, min_font_size=10):
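        """Render a centered (and, if needed, ellipsis-truncated) text label on
        a transparent strip sized to the image edge it annotates."""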
        label_width = img.width if is_x_label else img.height

        font_size = self.adjust_font_size(text, initial_font_size, label_width)
        font_size = min(max_font_size, font_size)
        font_size = max(min_font_size, font_size)

        label_height = int(font_size * 1.5) if is_x_label else font_size

        label_bg = Image.new('RGBA', (label_width, label_height), color=(255, 255, 255, 0))
        d = ImageDraw.Draw(label_bg)

        font = self.get_font(font_size, self.custom_font)

        # Truncate with an ellipsis when the text is still too wide for the label
        if self.textsize(d, text, font=font)[0] > label_width:
            while self.textsize(d, text + '...', font=font)[0] > label_width and len(text) > 0:
                text = text[:-1]
            text = text + '...'

        # Center each line horizontally and the whole block vertically
        text_lines = text.split('\n')
        text_widths, text_heights = zip(*[self.textsize(d, line, font=font) for line in text_lines])
        total_text_height = sum(text_heights)

        lines_positions = []
        current_y = 0
        for line, line_width, line_height in zip(text_lines, text_widths, text_heights):
            text_x = (label_width - line_width) // 2
            text_y = current_y + (label_height - total_text_height) // 2
            current_y += line_height
            lines_positions.append((line, (text_x, text_y)))

        for line, (text_x, text_y) in lines_positions:
            d.text((text_x, text_y), line, fill='black', font=font)

        return label_bg

    def sample_plot_image(self, plot_image_vars, samples, preview_latent, latents_plot, image_list, disable_noise,
                          start_step, last_step, force_full_denoise, x_value=None, y_value=None):
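        """Sample one grid cell: apply the current x/y overrides (seed, steps,
        model merge, prompts, ControlNet, ...), run the ksampler, decode the
        latent with the VAE and collect the resulting PIL image."""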
        model, clip, vae, positive, negative, seed, steps, cfg = None, None, None, None, None, None, None, None
        sampler_name, scheduler, denoise = None, None, None

        a1111_prompt_style = plot_image_vars['a1111_prompt_style'] if "a1111_prompt_style" in plot_image_vars else False
        clip = clip if clip is not None else plot_image_vars["clip"]
        steps = plot_image_vars['steps'] if "steps" in plot_image_vars else 1

        sd_version = get_sd_version(plot_image_vars['model'])

        if plot_image_vars["x_node_type"] == "advanced" or plot_image_vars["y_node_type"] == "advanced":
            if self.x_type == "Seeds++ Batch" or self.y_type == "Seeds++ Batch":
                seed = int(x_value) if self.x_type == "Seeds++ Batch" else int(y_value)
            if self.x_type == "Steps" or self.y_type == "Steps":
                steps = int(x_value) if self.x_type == "Steps" else int(y_value)
            if self.x_type == "StartStep" or self.y_type == "StartStep":
                start_step = int(x_value) if self.x_type == "StartStep" else int(y_value)
            if self.x_type == "EndStep" or self.y_type == "EndStep":
                last_step = int(x_value) if self.x_type == "EndStep" else int(y_value)
            if self.x_type == "CFG Scale" or self.y_type == "CFG Scale":
                cfg = float(x_value) if self.x_type == "CFG Scale" else float(y_value)
            if self.x_type == "Sampler" or self.y_type == "Sampler":
                sampler_name = x_value if self.x_type == "Sampler" else y_value
            if self.x_type == "Scheduler" or self.y_type == "Scheduler":
                scheduler = x_value if self.x_type == "Scheduler" else y_value
            if self.x_type == "Sampler&Scheduler" or self.y_type == "Sampler&Scheduler":
                arr = x_value.split(',') if self.x_type == "Sampler&Scheduler" else y_value.split(',')
                if arr[0] and arr[0] != 'None':
                    sampler_name = arr[0]
                if arr[1] and arr[1] != 'None':
                    scheduler = arr[1]
            if self.x_type == "Denoise" or self.y_type == "Denoise":
                denoise = float(x_value) if self.x_type == "Denoise" else float(y_value)
            if self.x_type == "Pos Condition" or self.y_type == "Pos Condition":
                positive = plot_image_vars['positive_cond_stack'][int(x_value)] if self.x_type == "Pos Condition" \
                    else plot_image_vars['positive_cond_stack'][int(y_value)]
            if self.x_type == "Neg Condition" or self.y_type == "Neg Condition":
                negative = plot_image_vars['negative_cond_stack'][int(x_value)] if self.x_type == "Neg Condition" \
                    else plot_image_vars['negative_cond_stack'][int(y_value)]
if self.x_type == "ModelMergeBlocks" or self.y_type == "ModelMergeBlocks": |
|
ckpt_name_1, ckpt_name_2 = plot_image_vars['models'] |
|
model1, clip1, vae1, clip_vision = self.easyCache.load_checkpoint(ckpt_name_1) |
|
model2, clip2, vae2, clip_vision = self.easyCache.load_checkpoint(ckpt_name_2) |
|
xy_values = x_value if self.x_type == "ModelMergeBlocks" else y_value |
|
if ":" in xy_values: |
|
xy_line = xy_values.split(':') |
|
xy_values = xy_line[1] |
|
|
|
xy_arrs = xy_values.split(',') |
|
|
|
if len(xy_arrs) == 3: |
|
input, middle, out = xy_arrs |
|
kwargs = { |
|
"input": input, |
|
"middle": middle, |
|
"out": out |
|
} |
|
elif len(xy_arrs) == 30: |
|
kwargs = {} |
|
kwargs["time_embed."] = xy_arrs[0] |
|
kwargs["label_emb."] = xy_arrs[1] |
|
|
|
for i in range(12): |
|
kwargs["input_blocks.{}.".format(i)] = xy_arrs[2+i] |
|
|
|
for i in range(3): |
|
kwargs["middle_block.{}.".format(i)] = xy_arrs[14+i] |
|
|
|
for i in range(12): |
|
kwargs["output_blocks.{}.".format(i)] = xy_arrs[17+i] |
|
|
|
kwargs["out."] = xy_arrs[29] |
|
else: |
|
raise Exception("ModelMergeBlocks weight length error") |
                default_ratio = next(iter(kwargs.values()))

                m = model1.clone()
                kp = model2.get_key_patches("diffusion_model.")

                for k in kp:
                    ratio = float(default_ratio)
                    k_unet = k[len("diffusion_model."):]

                    last_arg_size = 0
                    for arg in kwargs:
                        if k_unet.startswith(arg) and last_arg_size < len(arg):
                            ratio = float(kwargs[arg])
                            last_arg_size = len(arg)

                    m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)

                vae_use = plot_image_vars['vae_use']

                clip = clip2 if vae_use == 'Use Model 2' else clip1
                if vae_use == 'Use Model 2':
                    vae = vae2
                elif vae_use == 'Use Model 1':
                    vae = vae1
                else:
                    vae = self.easyCache.load_vae(vae_use)
                model = m

                optional_lora_stack = plot_image_vars['lora_stack']
                if optional_lora_stack is not None and optional_lora_stack != []:
                    for lora in optional_lora_stack:
                        model, clip = self.easyCache.load_lora(lora)

                clip = clip.clone()
                if plot_image_vars['clip_skip'] != 0:
                    clip.clip_layer(plot_image_vars['clip_skip'])

            if self.x_type == "Checkpoint" or self.y_type == "Checkpoint":
                xy_values = x_value if self.x_type == "Checkpoint" else y_value
                ckpt_name, clip_skip, vae_name = xy_values.split(",")
                ckpt_name = ckpt_name.replace('*', ',')
                vae_name = vae_name.replace('*', ',')
                model, clip, vae, clip_vision = self.easyCache.load_checkpoint(ckpt_name)
                if vae_name != 'None':
                    vae = self.easyCache.load_vae(vae_name)

                optional_lora_stack = plot_image_vars['lora_stack']
                if optional_lora_stack is not None and optional_lora_stack != []:
                    for lora in optional_lora_stack:
                        lora['model'] = model
                        lora['clip'] = clip
                        model, clip = self.easyCache.load_lora(lora)

                clip = clip.clone()
                if clip_skip != 'None':
                    clip.clip_layer(int(clip_skip))
                positive = plot_image_vars['positive']
                negative = plot_image_vars['negative']
                a1111_prompt_style = plot_image_vars['a1111_prompt_style']
                steps = plot_image_vars['steps']
                clip = clip if clip is not None else plot_image_vars["clip"]
                positive = advanced_encode(clip, positive,
                                           plot_image_vars['positive_token_normalization'],
                                           plot_image_vars['positive_weight_interpretation'],
                                           w_max=1.0,
                                           apply_to_pooled="enable",
                                           a1111_prompt_style=a1111_prompt_style, steps=steps)

                negative = advanced_encode(clip, negative,
                                           plot_image_vars['negative_token_normalization'],
                                           plot_image_vars['negative_weight_interpretation'],
                                           w_max=1.0,
                                           apply_to_pooled="enable",
                                           a1111_prompt_style=a1111_prompt_style, steps=steps)
                if "positive_cond" in plot_image_vars:
                    positive = positive + plot_image_vars["positive_cond"]
                if "negative_cond" in plot_image_vars:
                    negative = negative + plot_image_vars["negative_cond"]

            if self.x_type == "Lora" or self.y_type == "Lora":
                model = model if model is not None else plot_image_vars["model"]
                clip = clip if clip is not None else plot_image_vars["clip"]

                xy_values = x_value if self.x_type == "Lora" else y_value
                lora_name, lora_model_strength, lora_clip_strength, _ = xy_values.split(",")
                lora_stack = [{"lora_name": lora_name, "model": model, "clip": clip,
                               "model_strength": float(lora_model_strength),
                               "clip_strength": float(lora_clip_strength)}]
                if 'lora_stack' in plot_image_vars:
                    lora_stack = lora_stack + plot_image_vars['lora_stack']

                if lora_stack is not None and lora_stack != []:
                    for lora in lora_stack:
                        model, clip = self.easyCache.load_lora(lora)
if "Positive" in self.x_type or "Positive" in self.y_type: |
|
if self.x_type == 'Positive Prompt S/R' or self.y_type == 'Positive Prompt S/R': |
|
positive = x_value if self.x_type == "Positive Prompt S/R" else y_value |
|
|
|
if sd_version == 'flux': |
|
positive, = CLIPTextEncode().encode(clip, positive) |
|
else: |
|
positive = advanced_encode(clip, positive, |
|
plot_image_vars['positive_token_normalization'], |
|
plot_image_vars['positive_weight_interpretation'], |
|
w_max=1.0, |
|
apply_to_pooled="enable", a1111_prompt_style=a1111_prompt_style, steps=steps) |
|
|
|
|
|
|
|
|
|
if "Negative" in self.x_type or "Negative" in self.y_type: |
|
if self.x_type == 'Negative Prompt S/R' or self.y_type == 'Negative Prompt S/R': |
|
negative = x_value if self.x_type == "Negative Prompt S/R" else y_value |
|
|
|
if sd_version == 'flux': |
|
negative, = CLIPTextEncode().encode(clip, negative) |
|
else: |
|
negative = advanced_encode(clip, negative, |
|
plot_image_vars['negative_token_normalization'], |
|
plot_image_vars['negative_weight_interpretation'], |
|
w_max=1.0, |
|
apply_to_pooled="enable", a1111_prompt_style=a1111_prompt_style, steps=steps) |
|
|
|
|
|
|
|
|
|
if "ControlNet" in self.x_type or "ControlNet" in self.y_type: |
|
cnet = plot_image_vars["cnet"] if "cnet" in plot_image_vars else None |
|
positive = plot_image_vars["positive_cond"] if "positive" in plot_image_vars else None |
|
negative = plot_image_vars["negative_cond"] if "negative" in plot_image_vars else None |
|
if cnet: |
|
index = x_value if "ControlNet" in self.x_type else y_value |
|
controlnet = cnet[index] |
|
for index, item in enumerate(controlnet): |
|
control_net_name = item[0] |
|
image = item[1] |
|
strength = item[2] |
|
start_percent = item[3] |
|
end_percent = item[4] |
|
positive, negative = easyControlnet().apply(control_net_name, image, positive, negative, strength, start_percent, end_percent, None, 1) |
|
|

            if self.x_type == "Flux Guidance" or self.y_type == "Flux Guidance":
                positive = plot_image_vars["positive_cond"] if "positive_cond" in plot_image_vars else None
                flux_guidance = float(x_value) if self.x_type == "Flux Guidance" else float(y_value)
                if FluxGuidance is None:
                    # the import at the top of the file is optional
                    raise Exception("Flux Guidance requires comfy_extras.nodes_flux, which is not available")
                positive, = FluxGuidance().append(positive, flux_guidance)
if plot_image_vars["x_node_type"] == "loader" or plot_image_vars["y_node_type"] == "loader": |
|
model, clip, vae, clip_vision = self.easyCache.load_checkpoint(plot_image_vars['ckpt_name']) |
|
|
|
if plot_image_vars['lora_name'] != "None": |
|
lora = {"lora_name": plot_image_vars['lora_name'], "model": model, "clip": clip, "model_strength": plot_image_vars['lora_model_strength'], "clip_strength": plot_image_vars['lora_clip_strength']} |
|
model, clip = self.easyCache.load_lora(lora) |
|
|
|
|
|
if plot_image_vars['vae_name'] not in ["Baked-VAE", "Baked VAE"]: |
|
vae = self.easyCache.load_vae(plot_image_vars['vae_name']) |
|
|
|
|
|
if not clip: |
|
raise Exception("No CLIP found") |
|
clip = clip.clone() |
|
clip.clip_layer(plot_image_vars['clip_skip']) |
|
|
|
if sd_version == 'flux': |
|
positive, = CLIPTextEncode().encode(clip, positive) |
|
else: |
|
positive = advanced_encode(clip, plot_image_vars['positive'], |
|
plot_image_vars['positive_token_normalization'], |
|
plot_image_vars['positive_weight_interpretation'], w_max=1.0, |
|
apply_to_pooled="enable",a1111_prompt_style=a1111_prompt_style, steps=steps) |
|
|
|
if sd_version == 'flux': |
|
negative, = CLIPTextEncode().encode(clip, negative) |
|
else: |
|
negative = advanced_encode(clip, plot_image_vars['negative'], |
|
plot_image_vars['negative_token_normalization'], |
|
plot_image_vars['negative_weight_interpretation'], w_max=1.0, |
|
apply_to_pooled="enable", a1111_prompt_style=a1111_prompt_style, steps=steps) |

        # Anything not overridden above falls back to the pipeline defaults
        model = model if model is not None else plot_image_vars["model"]
        vae = vae if vae is not None else plot_image_vars["vae"]
        positive = positive if positive is not None else plot_image_vars["positive_cond"]
        negative = negative if negative is not None else plot_image_vars["negative_cond"]

        seed = seed if seed is not None else plot_image_vars["seed"]
        steps = steps if steps is not None else plot_image_vars["steps"]
        cfg = cfg if cfg is not None else plot_image_vars["cfg"]
        sampler_name = sampler_name if sampler_name is not None else plot_image_vars["sampler_name"]
        scheduler = scheduler if scheduler is not None else plot_image_vars["scheduler"]
        denoise = denoise if denoise is not None else plot_image_vars["denoise"]

        noise_device = plot_image_vars["noise_device"] if "noise_device" in plot_image_vars else 'cpu'

        layer_diffusion_method = plot_image_vars["layer_diffusion_method"] if "layer_diffusion_method" in plot_image_vars else None
        empty_samples = plot_image_vars["empty_samples"] if "empty_samples" in plot_image_vars else None

        if layer_diffusion_method:
            samp_blend_samples = plot_image_vars["blend_samples"] if "blend_samples" in plot_image_vars else None
            additional_cond = plot_image_vars["layer_diffusion_cond"] if "layer_diffusion_cond" in plot_image_vars else None

            images = plot_image_vars["images"].movedim(-1, 1) if "images" in plot_image_vars else None
            weight = plot_image_vars['layer_diffusion_weight'] if 'layer_diffusion_weight' in plot_image_vars else 1.0
            model, positive, negative = LayerDiffuse().apply_layer_diffusion(model, layer_diffusion_method, weight,
                                                                             samples, samp_blend_samples, positive,
                                                                             negative, images, additional_cond)

        samples = empty_samples if layer_diffusion_method is not None and empty_samples is not None else samples

        # Sample, decode and collect this cell's image
        samples = self.sampler.common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative,
                                               samples, denoise=denoise, disable_noise=disable_noise,
                                               preview_latent=preview_latent, start_step=start_step,
                                               last_step=last_step, force_full_denoise=force_full_denoise,
                                               noise_device=noise_device)

        latent = samples["samples"]
        latents_plot.append(latent)

        image = vae.decode(latent).cpu()

        if self.output_individuals in [True, "True"]:
            easySave(image, self.save_prefix, self.image_output)

        pil_image = self.sampler.tensor2pil(image)
        image_list.append(pil_image)

        # Track the largest cell so the grid canvas can fit every image
        self.max_width = max(self.max_width, pil_image.width)
        self.max_height = max(self.max_height, pil_image.height)

        return image_list, self.max_width, self.max_height, latents_plot

    def validate_xy_plot(self):
        if self.x_type == 'None' and self.y_type == 'None':
            log_node_warn(f'#{self.my_unique_id}', 'No Valid Plot Types - Reverting to default sampling...')
            return False
        else:
            return True

    def get_latent(self, samples):
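        """Pick the latent selected by self.latent_id out of a batch, clamping
        to the last entry when the id is out of range."""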
        latent_image_tensor = samples["samples"]

        image_tensors = torch.split(latent_image_tensor, 1, dim=0)
        latent_list = [{'samples': image} for image in image_tensors]

        if self.latent_id >= len(latent_list):
            log_node_warn(f'#{self.my_unique_id}', f'The selected latent_id ({self.latent_id}) is out of range.')
            log_node_warn(f'#{self.my_unique_id}', f'Automatically setting the latent_id to the last image in the list (index: {len(latent_list) - 1}).')
            self.latent_id = len(latent_list) - 1

        return latent_list[self.latent_id]

    def get_labels_and_sample(self, plot_image_vars, latent_image, preview_latent, start_step, last_step,
                              force_full_denoise, disable_noise):
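        """Walk every (x, y) combination: record the axis labels, sample each
        cell, then reorder and concatenate all cell latents into one batch."""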
        for x_index, x_value in enumerate(self.x_values):
            plot_image_vars, x_value_label = self.define_variable(plot_image_vars, self.x_type, x_value, x_index)
            self.x_label = self.update_label(self.x_label, x_value_label, len(self.x_values))
            if self.y_type != 'None':
                for y_index, y_value in enumerate(self.y_values):
                    plot_image_vars, y_value_label = self.define_variable(plot_image_vars, self.y_type, y_value, y_index)
                    self.y_label = self.update_label(self.y_label, y_value_label, len(self.y_values))

                    self.image_list, self.max_width, self.max_height, self.latents_plot = self.sample_plot_image(
                        plot_image_vars, latent_image, preview_latent, self.latents_plot, self.image_list,
                        disable_noise, start_step, last_step, force_full_denoise, x_value, y_value)
                    self.num += 1
            else:
                self.image_list, self.max_width, self.max_height, self.latents_plot = self.sample_plot_image(
                    plot_image_vars, latent_image, preview_latent, self.latents_plot, self.image_list, disable_noise,
                    start_step, last_step, force_full_denoise, x_value)
                self.num += 1

        # Sampling walks column-by-column; reorder to row-major for display
        self.latents_plot = self.rearrange_tensors(self.latents_plot, self.num_cols, self.num_rows)
        self.latents_plot = torch.cat(self.latents_plot, dim=0)

        return self.latents_plot

    def plot_images_and_labels(self):
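        """Paste the sampled images onto a white canvas in grid order and
        composite the x/y axis labels into the reserved border strips."""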
        bg_width, bg_height, x_offset_initial, y_offset = self.calculate_background_dimensions()

        background = Image.new('RGBA', (int(bg_width), int(bg_height)), color=(255, 255, 255, 255))

        output_image = []
        for row_index in range(self.num_rows):
            x_offset = x_offset_initial

            for col_index in range(self.num_cols):
                index = col_index * self.num_rows + row_index
                img = self.image_list[index]
                output_image.append(self.sampler.pil2tensor(img))
                background.paste(img, (x_offset, y_offset))

                # x-axis labels go above the first row
                if row_index == 0 and self.x_type != "None":
                    label_bg = self.create_label(img, self.x_label[col_index], int(48 * img.width / 512))
                    label_y = (y_offset - label_bg.height) // 2
                    background.alpha_composite(label_bg, (x_offset, label_y))

                # y-axis labels are rotated and placed beside the first column
                if col_index == 0 and self.y_type != "None":
                    label_bg = self.create_label(img, self.y_label[row_index], int(48 * img.height / 512), False)
                    label_bg = label_bg.rotate(90, expand=True)

                    label_x = (x_offset - label_bg.width) // 2
                    label_y = y_offset + (img.height - label_bg.height) // 2
                    background.alpha_composite(label_bg, (label_x, label_y))

                x_offset += img.width + self.grid_spacing

            y_offset += img.height + self.grid_spacing

        return (self.sampler.pil2tensor(background), output_image)