DiffusionModel / sdxl_gen_img.py
thorfinn0330's picture
Upload folder using huggingface_hub
11c2c17 verified
raw
history blame
110 kB
import itertools
import json
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob
import importlib
import inspect
import time
import zipfile
from diffusers.utils import deprecate
from diffusers.configuration_utils import FrozenDict
import argparse
import math
import os
import random
import re
import diffusers
import numpy as np
import torch
import torchvision
from diffusers import (
AutoencoderKL,
DDPMScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
DDIMScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
# UNet2DConditionModel,
StableDiffusionPipeline,
)
from einops import rearrange
from tqdm import tqdm
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig
import PIL
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import library.model_util as model_util
import library.train_util as train_util
import library.sdxl_model_util as sdxl_model_util
import library.sdxl_train_util as sdxl_train_util
from networks.lora import LoRANetwork
import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
from library.sdxl_original_unet import SdxlUNet2DConditionModel
from library.original_unet import FlashAttentionFunction
# Scheduler settings.
# NOTE(review): presumably consumed where the noise scheduler is constructed, outside this section -- confirm.
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
# NOTE: the name below is misspelled ("SCHEDLER"); it is left as-is because other code may reference it.
SCHEDLER_SCHEDULE = "scaled_linear"
# Other settings
LATENT_CHANNELS = 4
DOWNSAMPLING_FACTOR = 8
# region Module replacement
"""
高速化のためのモジュール入れ替え
"""
def replace_unet_modules(unet: "SdxlUNet2DConditionModel", mem_eff_attn, xformers, sdpa):
    """Select the attention implementation used by the U-Net.

    The first truthy flag wins, in the order ``mem_eff_attn``, ``xformers``,
    ``sdpa``; when all flags are falsy the U-Net is left untouched.

    Args:
        unet: The U-Net to configure. It must expose
            ``set_use_memory_efficient_attention`` and ``set_use_sdpa``.
        mem_eff_attn: Enable the memory efficient attention path.
        xformers: Enable the xformers attention path.
        sdpa: Enable the SDPA attention path.

    Raises:
        ImportError: If ``xformers`` is requested but ``xformers.ops`` cannot be imported.
    """
    if mem_eff_attn:
        print("Enable memory efficient attention for U-Net")

        # This is the project's own U-Net rather than the Diffusers one, so no module swap is needed
        unet.set_use_memory_efficient_attention(False, True)
    elif xformers:
        print("Enable xformers for U-Net")
        try:
            # import_module is used (instead of `import xformers.ops`) so that the
            # `xformers` flag parameter is not rebound to the module object
            importlib.import_module("xformers.ops")
        except ImportError as err:
            raise ImportError("No xformers / xformersがインストールされていないようです") from err

        unet.set_use_memory_efficient_attention(True, False)
    elif sdpa:
        print("Enable SDPA for U-Net")
        unet.set_use_memory_efficient_attention(False, False)
        unet.set_use_sdpa(True)
# TODO: share this with train_util.py
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa):
    """Patch the VAE attention forward according to the first enabled flag.

    Flags are checked in the order ``mem_eff_attn``, ``xformers``, ``sdpa``;
    at most one replacement is applied. ``vae`` itself is not used here.
    """
    if mem_eff_attn:
        replace_vae_attn_to_memory_efficient()
        return
    if xformers:
        replace_vae_attn_to_xformers()
        return
    if sdpa:
        replace_vae_attn_to_sdpa()
def replace_vae_attn_to_memory_efficient():
    """Monkey-patch the Diffusers VAE attention ``forward`` to use ``FlashAttentionFunction``.

    For Diffusers releases before 0.15 the legacy ``AttentionBlock`` class is
    patched; otherwise ``attention_processor.Attention`` is patched. The patch
    is applied to the class, so it affects every instance.
    """
    print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
    flash_func = FlashAttentionFunction

    def forward_flash_attn(self, hidden_states, **kwargs):
        q_bucket_size = 512
        k_bucket_size = 1024

        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        # flatten the spatial dims: (b, c, h, w) -> (b, h*w, c)
        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.to_q(hidden_states)
        key_proj = self.to_k(hidden_states)
        value_proj = self.to_v(hidden_states)

        query_proj, key_proj, value_proj = map(
            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj)
        )

        out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)

        out = rearrange(out, "b h n d -> b n (h d)")

        # compute next hidden_states
        # linear proj -- must consume the attention output; feeding the
        # pre-attention activations here would silently skip attention
        hidden_states = self.to_out[0](out)
        # dropout
        hidden_states = self.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

    def forward_flash_attn_0_14(self, hidden_states, **kwargs):
        # Alias the legacy attribute names to the ones forward_flash_attn reads
        if not hasattr(self, "to_q"):
            self.to_q = self.query
            self.to_k = self.key
            self.to_v = self.value
            self.to_out = [self.proj_attn, torch.nn.Identity()]
            self.heads = self.num_heads
        return forward_flash_attn(self, hidden_states, **kwargs)

    # Compare the version numerically: a plain string comparison orders
    # "0.9.0" after "0.15.0"
    version_match = re.match(r"(\d+)\.(\d+)", diffusers.__version__)
    is_pre_0_15 = version_match is not None and tuple(map(int, version_match.groups())) < (0, 15)

    if is_pre_0_15:
        diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14
    else:
        diffusers.models.attention_processor.Attention.forward = forward_flash_attn
def replace_vae_attn_to_xformers():
    """Monkey-patch the Diffusers VAE attention ``forward`` to use xformers.

    For Diffusers releases before 0.15 the legacy ``AttentionBlock`` class is
    patched; otherwise ``attention_processor.Attention`` is patched. The patch
    is applied to the class, so it affects every instance.

    Raises:
        ImportError: If ``xformers.ops`` cannot be imported.
    """
    print("VAE: Attention.forward has been replaced to xformers")
    import xformers.ops

    def forward_xformers(self, hidden_states, **kwargs):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        # flatten the spatial dims: (b, c, h, w) -> (b, h*w, c)
        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.to_q(hidden_states)
        key_proj = self.to_k(hidden_states)
        value_proj = self.to_v(hidden_states)

        # memory_efficient_attention takes [batch, seq_len, heads, head_dim]
        query_proj, key_proj, value_proj = map(
            lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj)
        )

        query_proj = query_proj.contiguous()
        key_proj = key_proj.contiguous()
        value_proj = value_proj.contiguous()
        out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)

        out = rearrange(out, "b n h d -> b n (h d)")

        # compute next hidden_states
        # linear proj -- must consume the attention output; feeding the
        # pre-attention activations here would silently skip attention
        hidden_states = self.to_out[0](out)
        # dropout
        hidden_states = self.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

    def forward_xformers_0_14(self, hidden_states, **kwargs):
        # Alias the legacy attribute names to the ones forward_xformers reads
        if not hasattr(self, "to_q"):
            self.to_q = self.query
            self.to_k = self.key
            self.to_v = self.value
            self.to_out = [self.proj_attn, torch.nn.Identity()]
            self.heads = self.num_heads
        return forward_xformers(self, hidden_states, **kwargs)

    # Compare the version numerically: a plain string comparison orders
    # "0.9.0" after "0.15.0"
    version_match = re.match(r"(\d+)\.(\d+)", diffusers.__version__)
    is_pre_0_15 = version_match is not None and tuple(map(int, version_match.groups())) < (0, 15)

    if is_pre_0_15:
        diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14
    else:
        diffusers.models.attention_processor.Attention.forward = forward_xformers
def replace_vae_attn_to_sdpa():
    """Monkey-patch the Diffusers VAE attention ``forward`` to use PyTorch SDPA.

    For Diffusers releases before 0.15 the legacy ``AttentionBlock`` class is
    patched; otherwise ``attention_processor.Attention`` is patched. The patch
    is applied to the class, so it affects every instance.
    """
    print("VAE: Attention.forward has been replaced to sdpa")

    def forward_sdpa(self, hidden_states, **kwargs):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        # flatten the spatial dims: (b, c, h, w) -> (b, h*w, c)
        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.to_q(hidden_states)
        key_proj = self.to_k(hidden_states)
        value_proj = self.to_v(hidden_states)

        # scaled_dot_product_attention attends over the second-to-last axis,
        # so heads must precede the sequence axis: (b, h, n, d)
        query_proj, key_proj, value_proj = map(
            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj)
        )

        out = torch.nn.functional.scaled_dot_product_attention(
            query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        out = rearrange(out, "b h n d -> b n (h d)")

        # compute next hidden_states
        # linear proj -- must consume the attention output; feeding the
        # pre-attention activations here would silently skip attention
        hidden_states = self.to_out[0](out)
        # dropout
        hidden_states = self.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states

    def forward_sdpa_0_14(self, hidden_states, **kwargs):
        # Alias the legacy attribute names to the ones forward_sdpa reads
        if not hasattr(self, "to_q"):
            self.to_q = self.query
            self.to_k = self.key
            self.to_v = self.value
            self.to_out = [self.proj_attn, torch.nn.Identity()]
            self.heads = self.num_heads
        return forward_sdpa(self, hidden_states, **kwargs)

    # Compare the version numerically: a plain string comparison orders
    # "0.9.0" after "0.15.0"
    version_match = re.match(r"(\d+)\.(\d+)", diffusers.__version__)
    is_pre_0_15 = version_match is not None and tuple(map(int, version_match.groups())) < (0, 15)

    if is_pre_0_15:
        diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14
    else:
        diffusers.models.attention_processor.Attention.forward = forward_sdpa
# endregion
# region Image generation core: copied from lpw_stable_diffusion.py (ASL) and modified
# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
# Copied and modified because the Pipeline cannot be used on its own and because features are added
class PipelineLike:
    """Self-contained SDXL generation pipeline (txt2img, img2img and inpainting).

    Holds the VAE, the text encoders with their tokenizers, the U-Net and the
    scheduler, and runs the whole denoising loop when called.
    """

    def __init__(
        self,
        device,
        vae: "AutoencoderKL",
        text_encoders: "List[CLIPTextModel]",
        tokenizers: "List[CLIPTokenizer]",
        unet: "SdxlUNet2DConditionModel",
        scheduler: "Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]",
        clip_skip: int,
    ):
        """Store the components and normalise outdated scheduler configs.

        A scheduler whose config has ``steps_offset != 1`` or
        ``clip_sample is True`` gets a deprecation warning and its internal
        config is rewritten (``steps_offset=1`` / ``clip_sample=False``).
        """
        super().__init__()
        self.device = device
        self.clip_skip = clip_skip

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        self.vae = vae
        self.text_encoders = text_encoders
        self.tokenizers = tokenizers
        self.unet: SdxlUNet2DConditionModel = unet
        self.scheduler = scheduler
        self.safety_checker = None

        # Textual Inversion: one {target_token_id: replacement_ids} table per text encoder
        self.token_replacements_list = [{} for _ in self.text_encoders]

        # ControlNet # not supported yet
        self.control_nets: List[ControlNetInfo] = []
        # When control_nets is empty ControlNet never runs, whatever this flag says
        self.control_net_enabled = True

    # Textual Inversion
    def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids):
        """Register ``rep_token_ids`` as the expansion of ``target_token_id`` for one text encoder."""
        self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids

    def set_enable_control_net(self, en: bool):
        """Turn ControlNet usage on or off for subsequent calls."""
        self.control_net_enabled = en

    def get_token_replacer(self, tokenizer):
        """Return a function that expands registered tokens for ``tokenizer``.

        The returned callable accepts a list (or tensor) of token ids and
        returns a new list in which every registered id is replaced by its
        replacement ids.

        Raises:
            ValueError: If ``tokenizer`` is not one of ``self.tokenizers``.
        """
        tokenizer_index = self.tokenizers.index(tokenizer)
        token_replacements = self.token_replacements_list[tokenizer_index]

        def replace_tokens(tokens):
            # print("replace_tokens", tokens, "=>", token_replacements)
            if isinstance(tokens, torch.Tensor):
                tokens = tokens.tolist()

            new_tokens = []
            for token in tokens:
                if token in token_replacements:
                    new_tokens.extend(token_replacements[token])
                else:
                    new_tokens.append(token)
            return new_tokens

        return replace_tokens

    def set_control_nets(self, ctrl_nets):
        """Replace the list of ControlNets used by the pipeline."""
        self.control_nets = ctrl_nets

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        init_image: "Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]]" = None,
        mask_image: "Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]]" = None,
        height: int = 1024,
        width: int = 1024,
        original_height: int = None,
        original_width: int = None,
        crop_top: int = 0,
        crop_left: int = 0,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_scale: float = None,
        strength: float = 0.8,
        # num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        vae_batch_size: float = None,
        return_latents: bool = False,
        # return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        is_cancelled_callback: Optional[Callable[[], bool]] = None,
        callback_steps: Optional[int] = 1,
        img2img_noise=None,
        clip_guide_images=None,
        **kwargs,
    ):
        """Generate images for ``prompt``.

        Args:
            prompt: One prompt or a list of prompts; the list length is the batch size.
            negative_prompt: Negative prompt(s); a single string is repeated for the batch.
            init_image: Source image(s) or tensor for img2img. ``None`` selects txt2img.
            mask_image: Inpainting mask(s); only used together with ``init_image``.
            height, width: Output size in pixels; both must be divisible by 8.
            original_height, original_width: Size conditioning; default to ``height`` / ``width``.
            crop_top, crop_left: Crop conditioning; ``None`` is treated as 0.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance is used when this is > 1.0.
            negative_scale: When set (and guidance is on), a third pass with an
                empty prompt is run and the negative prompt is weighted by this value.
            strength: img2img strength in [0, 1].
            eta: Passed to ``scheduler.step`` only if it accepts ``eta``.
            generator: RNG used for the initial noise and VAE sampling.
            latents: Optional pre-made initial latents (txt2img only).
            max_embeddings_multiples: Forwarded to ``get_weighted_text_embeddings``.
            output_type: ``"pil"`` returns PIL images, anything else a numpy array.
            vae_batch_size: VAE batch size; a value below 1 is a fraction of the batch size.
            return_latents: Return ``(latents, False)`` instead of decoding.
            callback: Called as ``callback(i, t, latents)`` every ``callback_steps`` steps.
            is_cancelled_callback: If it returns True on a callback step, ``None`` is returned.
            callback_steps: Positive integer interval for the callbacks.
            img2img_noise: Noise tensor used by ``scheduler.add_noise`` in img2img/inpainting.
            clip_guide_images: Guide image(s) reused as ControlNet hints.
            **kwargs: Forwarded to ``get_weighted_text_embeddings``.

        Returns:
            ``None`` when cancelled, ``(latents, False)`` when ``return_latents``
            is set, otherwise a list of PIL images or a numpy array.

        Raises:
            ValueError: On invalid ``prompt`` type, ``strength``, size,
                ``callback_steps``, negative prompt batch size, latents shape
                or mask shape.
        """
        # TODO support secondary prompt
        num_images_per_prompt = 1  # fixed because already prompt is repeated

        if isinstance(prompt, str):
            batch_size = 1
            prompt = [prompt]
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        regional_network = " AND " in prompt[0]

        vae_batch_size = (
            batch_size
            if vae_batch_size is None
            else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
        )

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
            )

        # get prompt text embeddings

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        if not do_classifier_free_guidance and negative_scale is not None:
            print("negative_scale is ignored if guidance scalle <= 1.0")
            negative_scale = None

        # get unconditional embeddings for classifier free guidance
        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        if batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        tes_text_embs = []
        tes_uncond_embs = []
        tes_real_uncond_embs = []
        text_pool = uncond_pool = real_uncond_pool = None
        # use last pool: the pooled outputs of the last text encoder are kept
        for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
            token_replacer = self.get_token_replacer(tokenizer)

            text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings(
                tokenizer,
                text_encoder,
                prompt=prompt,
                uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
                max_embeddings_multiples=max_embeddings_multiples,
                clip_skip=self.clip_skip,
                token_replacer=token_replacer,
                device=self.device,
                **kwargs,
            )
            tes_text_embs.append(text_embeddings)
            tes_uncond_embs.append(uncond_embeddings)

            if negative_scale is not None:
                # Same call shape and 5-value result as above, with an empty uncond prompt
                _, _, real_uncond_embeddings, real_uncond_pool, _ = get_weighted_text_embeddings(
                    tokenizer,
                    text_encoder,
                    prompt=prompt,  # uncond is built to this token length, so it is required beyond 75 tokens
                    uncond_prompt=[""] * batch_size,
                    max_embeddings_multiples=max_embeddings_multiples,
                    clip_skip=self.clip_skip,
                    token_replacer=token_replacer,
                    device=self.device,
                    **kwargs,
                )
                tes_real_uncond_embs.append(real_uncond_embeddings)

        # concat text encoder outputs along the feature axis
        text_embeddings = torch.cat(tes_text_embs, dim=2)  # n,77,2048

        if do_classifier_free_guidance:
            # uncond embeddings only exist when an uncond prompt was encoded
            uncond_embeddings = torch.cat(tes_uncond_embs, dim=2)  # n,77,2048
            if negative_scale is None:
                text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
            else:
                real_uncond_embeddings = torch.cat(tes_real_uncond_embs, dim=2)
                text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])

        if self.control_nets:
            if isinstance(clip_guide_images, PIL.Image.Image):
                clip_guide_images = [clip_guide_images]
            # The guide image is reused as the ControlNet hint
            # Preprocessing is done on the ControlNet side

        # create size embs
        if original_height is None:
            original_height = height
        if original_width is None:
            original_width = width
        if crop_top is None:
            crop_top = 0
        if crop_left is None:
            crop_left = 0
        emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
        emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
        emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256)
        size_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype)

        def with_size_conditioning(pool):
            # NOTE(review): assumes the size embedding is 2-D with a single row
            # (it is built from one-row inputs); it is expanded to the pool's row count
            return torch.cat([pool, size_vector.expand(pool.shape[0], -1)], dim=1)

        # One vector group per latent group fed to the U-Net (uncond, cond[, real uncond])
        c_vector = with_size_conditioning(text_pool)
        if not do_classifier_free_guidance:
            vector_embeddings = c_vector
        elif negative_scale is None:
            vector_embeddings = torch.cat([with_size_conditioning(uncond_pool), c_vector])
        else:
            vector_embeddings = torch.cat(
                [with_size_conditioning(uncond_pool), c_vector, with_size_conditioning(real_uncond_pool)]
            )

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps, self.device)

        latents_dtype = text_embeddings.dtype
        init_latents_orig = None
        mask = None

        if init_image is None:
            # get the initial random noise unless the user supplied it

            # Unlike in other pipelines, latents need to be generated in the target device
            # for 1-to-1 results reproducibility with the CompVis implementation.
            # However this currently doesn't work in `mps`.
            latents_shape = (
                batch_size * num_images_per_prompt,
                self.unet.in_channels,
                height // 8,
                width // 8,
            )

            if latents is None:
                if self.device.type == "mps":
                    # randn does not exist on mps
                    latents = torch.randn(
                        latents_shape,
                        generator=generator,
                        device="cpu",
                        dtype=latents_dtype,
                    ).to(self.device)
                else:
                    latents = torch.randn(
                        latents_shape,
                        generator=generator,
                        device=self.device,
                        dtype=latents_dtype,
                    )
            else:
                if latents.shape != latents_shape:
                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
                latents = latents.to(self.device)

            timesteps = self.scheduler.timesteps.to(self.device)

            # scale the initial noise by the standard deviation required by the scheduler
            latents = latents * self.scheduler.init_noise_sigma
        else:
            # image to tensor
            if isinstance(init_image, PIL.Image.Image):
                init_image = [init_image]
            if isinstance(init_image[0], PIL.Image.Image):
                init_image = [preprocess_image(im) for im in init_image]
                init_image = torch.cat(init_image)
            if isinstance(init_image, list):
                init_image = torch.stack(init_image)

            # mask image to tensor
            if mask_image is not None:
                if isinstance(mask_image, PIL.Image.Image):
                    mask_image = [mask_image]
                if isinstance(mask_image[0], PIL.Image.Image):
                    mask_image = torch.cat([preprocess_mask(im) for im in mask_image])  # H*W, 0 for repaint

            # encode the init image into latents and scale the latents
            init_image = init_image.to(device=self.device, dtype=latents_dtype)
            if init_image.size()[-2:] == (height // 8, width // 8):
                # already latent-sized: use as-is
                init_latents = init_image
            else:
                if vae_batch_size >= batch_size:
                    init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist
                    init_latents = init_latent_dist.sample(generator=generator)
                else:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    init_latents = []
                    for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
                        init_latent_dist = self.vae.encode(
                            (init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)).to(
                                self.vae.dtype
                            )
                        ).latent_dist
                        init_latents.append(init_latent_dist.sample(generator=generator))
                    init_latents = torch.cat(init_latents)

                init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents

            if len(init_latents) == 1:
                init_latents = init_latents.repeat((batch_size, 1, 1, 1))
            init_latents_orig = init_latents

            # preprocess mask
            if mask_image is not None:
                mask = mask_image.to(device=self.device, dtype=latents_dtype)
                if len(mask) == 1:
                    mask = mask.repeat((batch_size, 1, 1, 1))

                # check sizes
                if not mask.shape == init_latents.shape:
                    raise ValueError("The mask and init_image should be the same size!")

            # get the original timestep using init_timestep
            offset = self.scheduler.config.get("steps_offset", 0)
            init_timestep = int(num_inference_steps * strength) + offset
            init_timestep = min(init_timestep, num_inference_steps)

            timesteps = self.scheduler.timesteps[-init_timestep]
            timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)

            # add noise to latents using the timesteps
            # NOTE(review): img2img_noise is passed through unchanged; callers presumably must supply it -- confirm
            latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps)

            t_start = max(num_inference_steps - init_timestep + offset, 0)
            timesteps = self.scheduler.timesteps[t_start:].to(self.device)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1

        if self.control_nets:
            guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)

        for i, t in enumerate(tqdm(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            if self.control_nets and self.control_net_enabled:
                if regional_network:
                    num_sub_and_neg_prompts = len(text_embeddings) // batch_size
                    text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts]  # last subprompt
                else:
                    text_emb_last = text_embeddings

                # not working yet
                noise_pred = original_control_net.call_unet_and_control_net(
                    i,
                    num_latent_input,
                    self.unet,
                    self.control_nets,
                    guided_hints,
                    i / len(timesteps),
                    latent_model_input,
                    t,
                    text_emb_last,
                ).sample
            else:
                noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)

            # perform guidance
            if do_classifier_free_guidance:
                if negative_scale is None:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input)  # uncond by negative prompt
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                else:
                    noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(
                        num_latent_input
                    )  # uncond is real uncond
                    noise_pred = (
                        noise_pred_uncond
                        + guidance_scale * (noise_pred_text - noise_pred_uncond)
                        - negative_scale * (noise_pred_negative - noise_pred_uncond)
                    )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            if mask is not None:
                # masking: blend the re-noised original latents with the
                # denoised latents, weighted by the mask
                init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t]))
                latents = (init_latents_proper * mask) + (latents * (1 - mask))

            # call the callback, if provided
            if i % callback_steps == 0:
                if callback is not None:
                    callback(i, t, latents)
                if is_cancelled_callback is not None and is_cancelled_callback():
                    return None

        if return_latents:
            return (latents, False)

        latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
        if vae_batch_size >= batch_size:
            image = self.vae.decode(latents.to(self.vae.dtype)).sample
        else:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            images = []
            for i in tqdm(range(0, batch_size, vae_batch_size)):
                images.append(
                    self.vae.decode(
                        (latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).to(self.vae.dtype)
                    ).sample
                )
            image = torch.cat(images)

        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            # image = self.numpy_to_pil(image)
            image = (image * 255).round().astype("uint8")
            image = [Image.fromarray(im) for im in image]

        return image

        # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# Tokenizer for the prompt-attention syntax: escaped brackets/backslashes,
# opening brackets, a ":<number>)" weight suffix, closing brackets, plain
# text runs and a lone colon. Compiled in verbose mode, so the whitespace
# inside the pattern is ignored.
re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    r"""
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    The word BREAK is always emitted as its own ``['BREAK', 1.0]`` entry and
    is never merged with its neighbours. Unclosed brackets are treated as if
    they were closed at the end of the text.

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        # Scale every entry appended since the bracket at start_position was opened
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    # keep break as separate token: the surrounding backslashes make the
    # tokenizer split BREAK off from the adjacent text
    text = text.replace("BREAK", "\\BREAK\\")

    for m in re_attention.finditer(text):
        token = m.group(0)
        weight = m.group(1)

        if token.startswith("\\"):
            # escaped character: keep what follows the backslash literally
            res.append([token[1:], 1.0])
        elif token == "(":
            round_brackets.append(len(res))
        elif token == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif token == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([token, 1.0])

    # brackets that were never closed apply up to the end of the text
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights (BREAK entries always stay separate)
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK":
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: List[str], max_length: int):
    r"""
    Convert each prompt into token ids plus one weight per token.

    The prompts are parsed with ``parse_prompt_attention``; the returned
    sequences contain no padding and no starting or ending token. A ``BREAK``
    entry is expanded into pad tokens (weight 1.0). Sequences longer than
    ``max_length`` are cut and a warning is printed once.

    Returns a ``(tokens, weights)`` pair of lists with one entry per prompt.
    """
    all_tokens = []
    all_weights = []
    was_truncated = False
    for prompt_text in prompt:
        ids = []
        id_weights = []
        for fragment, fragment_weight in parse_prompt_attention(prompt_text):
            if fragment.strip() == "BREAK":
                # pad until next multiple of tokenizer's max token length
                pad_len = tokenizer.model_max_length - (len(ids) % tokenizer.model_max_length)
                print(f"BREAK pad_len: {pad_len}")
                # Not sure whether an EOS should be added here for v2, so only PAD is used
                ids.extend([tokenizer.pad_token_id] * pad_len)
                id_weights.extend([1.0] * pad_len)
                continue

            # tokenize, drop the starting and the ending token, then apply
            # the Textual Inversion replacements
            fragment_ids = token_replacer(tokenizer(fragment).input_ids[1:-1])
            ids += fragment_ids
            # every token of the fragment carries the fragment's weight
            id_weights += [fragment_weight] * len(fragment_ids)

            # no point in tokenizing further once the limit is exceeded
            if len(ids) > max_length:
                break

        # truncate
        if len(ids) > max_length:
            was_truncated = True
            ids = ids[:max_length]
            id_weights = id_weights[:max_length]

        all_tokens.append(ids)
        all_weights.append(id_weights)

    if was_truncated:
        print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return all_tokens, all_weights
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
    r"""
    Wrap every token list in ``bos``/``eos``, pad it with ``pad`` up to
    ``max_length``, and pad the matching weights with 1.0.

    With ``no_boseos_middle`` the weights get a single leading 1.0 and are
    padded to ``max_length``. Otherwise each chunk of ``chunk_length - 2``
    weights is surrounded by 1.0 for its own starting and ending token, and
    the result is padded to ``chunks * chunk_length``.

    Both input lists are modified in place and returned.
    """
    chunk_body = chunk_length - 2
    num_chunks = (max_length - 2) // chunk_body
    weights_length = max_length if no_boseos_middle else num_chunks * chunk_length

    for row, token_row in enumerate(tokens):
        tokens[row] = [bos] + token_row + [eos] + [pad] * (max_length - 2 - len(token_row))
        weight_row = weights[row]

        if no_boseos_middle:
            weights[row] = [1.0] + weight_row + [1.0] * (max_length - 1 - len(weight_row))
            continue

        if len(weight_row) == 0:
            weights[row] = [1.0] * weights_length
            continue

        chunked = []
        for chunk_index in range(num_chunks):
            begin = chunk_index * chunk_body
            chunked.append(1.0)  # weight for starting token in this chunk
            chunked += weight_row[begin : begin + chunk_body]
            chunked.append(1.0)  # weight for ending token in this chunk
        chunked += [1.0] * (weights_length - len(chunked))
        weights[row] = chunked

    return tokens, weights
def get_unweighted_text_embeddings(
    text_encoder: "CLIPTextModel",
    text_input: torch.Tensor,
    chunk_length: int,
    clip_skip: int,
    eos: int,
    pad: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.

    Args:
        text_encoder: Callable returning a mapping with ``"hidden_states"`` and,
            optionally, ``"text_embeds"``.
        text_input: Token ids; indexed as ``[row, position]``.
        chunk_length: Encoder capacity including the starting and ending token.
        clip_skip: Not used by this function; the second-to-last hidden state
            is always taken.
        eos: Ending token id.
        pad: Padding token id.
        no_boseos_middle: When chunking, drop the starting/ending positions
            that fall between chunks.

    Returns:
        ``(text_embeddings, pool)``. ``pool`` is the ``"text_embeds"`` output
        (taken from the first chunk when chunking), or ``None`` when the
        encoder does not provide it.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        pool = None
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            if pad == eos:  # v1
                text_input_chunk[:, -1] = text_input[0, -1]
            else:  # v2
                for j in range(len(text_input_chunk)):
                    # an ordinary token sits in the last slot
                    if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:
                        text_input_chunk[j, -1] = eos
                    # only the starting token followed by padding
                    if text_input_chunk[j, 1] == pad:
                        text_input_chunk[j, 1] = eos

            # -2 is same for Text Encoder 1 and 2
            enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
            text_embedding = enc_out["hidden_states"][-2]
            if i == 0:
                # use 1st chunk; text encoder 1 doesn't return this, so it must
                # be looked up with .get() exactly like the single-chunk path
                pool = enc_out.get("text_embeds", None)

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        text_embeddings = torch.cat(text_embeddings, dim=1)
    else:
        enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
        text_embeddings = enc_out["hidden_states"][-2]
        pool = enc_out.get("text_embeds", None)  # text encoder 1 doesn't return this
    return text_embeddings, pool
def get_weighted_text_embeddings(
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    prompt: Union[str, List[str]],
    uncond_prompt: Optional[Union[str, List[str]]] = None,
    max_embeddings_multiples: Optional[int] = 1,
    no_boseos_middle: Optional[bool] = False,
    skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
    clip_skip=None,
    token_replacer=None,
    device=None,
    **kwargs,
):
    """
    Tokenize, pad and encode the prompt (and optionally the unconditional prompt),
    then scale the embeddings by the per-token prompt weights.

    Each prompt string is first split on " AND " into sub-prompts. Unless
    ``skip_parsing`` is set, tokens and weights come from ``get_prompts_with_weights``;
    otherwise the tokenizer is called directly and every weight is 1.0. After
    weighting, every embedding is rescaled so its mean over the last two axes
    equals the mean it had before weighting.

    Returns:
        ``(text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens)``.
        The two ``uncond_*`` entries are ``None`` when ``uncond_prompt`` is ``None``.
    """
    window = tokenizer.model_max_length
    chunk_capacity = window - 2
    token_budget = chunk_capacity * max_embeddings_multiples + 2
    with_uncond = uncond_prompt is not None

    if isinstance(prompt, str):
        prompt = [prompt]
    # split the prompts with "AND". each prompt must have the same number of splits
    prompt = [part for text in prompt for part in text.split(" AND ")]
    if with_uncond and isinstance(uncond_prompt, str):
        uncond_prompt = [uncond_prompt]

    def tokenize(texts):
        # parsed path: weights come from the prompt syntax; raw path: all weights are 1.0
        if not skip_parsing:
            return get_prompts_with_weights(tokenizer, token_replacer, texts, token_budget - 2)
        ids = [seq[1:-1] for seq in tokenizer(texts, max_length=token_budget, truncation=True).input_ids]
        return ids, [[1.0] * len(seq) for seq in ids]

    prompt_tokens, prompt_weights = tokenize(prompt)
    if with_uncond:
        uncond_tokens, uncond_weights = tokenize(uncond_prompt)

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    longest = max(len(seq) for seq in prompt_tokens)
    if with_uncond:
        longest = max(longest, max(len(seq) for seq in uncond_tokens))
    chunk_count = max(1, min(max_embeddings_multiples, (longest - 1) // chunk_capacity + 1))
    padded_length = chunk_capacity * chunk_count + 2

    bos = tokenizer.bos_token_id
    eos = tokenizer.eos_token_id
    pad = tokenizer.pad_token_id

    def pad_to_tensor(tokens, weights):
        # pad the length of tokens and weights, then move the ids to a tensor
        tokens, weights = pad_tokens_and_weights(
            tokens,
            weights,
            padded_length,
            bos,
            eos,
            pad,
            no_boseos_middle=no_boseos_middle,
            chunk_length=window,
        )
        return torch.tensor(tokens, dtype=torch.long, device=device), weights

    def encode(token_tensor, weights):
        # run the encoder and build the weight tensor in the embeddings' dtype
        embeddings, pooled = get_unweighted_text_embeddings(
            text_encoder,
            token_tensor,
            window,
            clip_skip,
            eos,
            pad,
            no_boseos_middle=no_boseos_middle,
        )
        return embeddings, pooled, torch.tensor(weights, dtype=embeddings.dtype, device=device)

    def apply_weights_in_place(embeddings, weights):
        # weight the tokens, then restore the original mean over the whole sequence
        # (normalization is done over the whole sequence, not per chunk)
        mean_before = embeddings.float().mean(axis=[-2, -1]).to(embeddings.dtype)
        embeddings *= weights.unsqueeze(-1)
        mean_after = embeddings.float().mean(axis=[-2, -1]).to(embeddings.dtype)
        embeddings *= (mean_before / mean_after).unsqueeze(-1).unsqueeze(-1)

    prompt_tokens, prompt_weights = pad_to_tensor(prompt_tokens, prompt_weights)
    if with_uncond:
        uncond_tokens, uncond_weights = pad_to_tensor(uncond_tokens, uncond_weights)

    # get the embeddings
    text_embeddings, text_pool, prompt_weights = encode(prompt_tokens, prompt_weights)
    if with_uncond:
        uncond_embeddings, uncond_pool, uncond_weights = encode(uncond_tokens, uncond_weights)

    # assign weights to the prompts and normalize in the sense of mean
    if not (skip_parsing or skip_weighting):
        apply_weights_in_place(text_embeddings, prompt_weights)
        if with_uncond:
            apply_weights_in_place(uncond_embeddings, uncond_weights)

    if not with_uncond:
        return text_embeddings, text_pool, None, None, prompt_tokens
    return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens
def preprocess_image(image):
    """
    Convert a PIL image into a float tensor scaled to [-1, 1].

    The image is first resized so both sides are integer multiples of 32
    (rounded down), then laid out with a leading batch axis of size 1 and
    the channel axis moved in front of height and width.
    """
    width, height = (side - side % 32 for side in image.size)  # resize to integer multiple of 32
    resized = image.resize((width, height), resample=PIL.Image.LANCZOS)
    pixels = np.array(resized).astype(np.float32) / 255.0
    batched = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2))
    return 2.0 * batched - 1.0
def preprocess_mask(mask):
    """
    Convert a PIL mask image into an inverted 4-channel float tensor.

    The mask is converted to grayscale, snapped down to a multiple of 32 per
    side, resized to 1/8 of that size, repeated over 4 channels and given a
    leading batch axis of size 1. Values are inverted (white -> 0, black -> 1)
    so that white areas are repainted and black areas are kept.
    """
    gray = mask.convert("L")
    width, height = (side - side % 32 for side in gray.size)  # resize to integer multiple of 32
    small = gray.resize((width // 8, height // 8), resample=PIL.Image.BILINEAR)
    values = np.array(small).astype(np.float32) / 255.0
    stacked = np.tile(values, (4, 1, 1))[None]
    inverted = 1 - stacked  # repaint white, keep black
    return torch.from_numpy(inverted)
# Regular expression for a dynamic prompt group:
#   - starts with "{" and ends with "}"
#   - contains at least one variant; variants are divided by "|"
#   - may begin with optional fragments, each terminated by "$$":
#       * "E" or "e": enumerate all variants instead of choosing randomly
#       * a number or two numbers joined by "-": how many variants to combine
#       * any other string: separator used to join the chosen variants
# Capture groups used by the code: 2 = enumerate flag, 4 = count range,
# 6 = separator, 7 = the "|"-separated variants.
RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}")


def _parse_variant_count_range(spec, num_variants):
    """Turn the count fragment ("2", "1-3", or None) into a [low, high] list clamped to 0..num_variants."""
    if spec is None:
        return [1, 1]

    parts = spec.split("-")
    try:
        if len(parts) == 1:
            bounds = [int(parts[0]), int(parts[0])]
        elif len(parts) == 2:
            bounds = [int(parts[0]), int(parts[1])]
        else:
            raise ValueError(spec)
    except ValueError:
        # too many "-" or an empty bound such as "2-" / "-3": fall back instead of crashing
        print(f"invalid count range: {parts}")
        bounds = [1, 1]

    if bounds[0] > bounds[1]:
        bounds = [bounds[1], bounds[0]]
    if bounds[0] < 0:
        bounds[0] = 0
    if bounds[1] > num_variants:
        bounds[1] = num_variants
    if bounds[0] > bounds[1]:
        # the lower bound exceeded the number of variants; without this clamp
        # random.randint() raises and the enumeration yields no prompt at all
        bounds[0] = bounds[1]
    return bounds


def _make_enum_replacer(variants, count_range, separator):
    """Build a callable returning every combination of the variants for each allowed count."""

    def replacer():
        values = []
        for count in range(count_range[0], count_range[1] + 1):
            for comb in itertools.combinations(variants, count):
                values.append(separator.join(comb))
        return values

    return replacer


def _make_random_replacer(variants, count_range, separator):
    """Build a callable returning a one-element list holding one random combination."""

    def replacer():
        count = random.randint(count_range[0], count_range[1])
        comb = random.sample(variants, count)
        return [separator.join(comb)]

    return replacer


def handle_dynamic_prompt_variants(prompt, repeat_count):
    """
    Expand the dynamic prompt groups ("{a|b|c}") contained in ``prompt``.

    Args:
        prompt: prompt text, possibly containing groups matched by ``RE_DYNAMIC_PROMPT``.
        repeat_count: number of prompts to produce when no group asks for enumeration.

    Returns:
        A list of prompts:
        - ``[prompt]`` unchanged when it contains no group;
        - ``repeat_count`` prompts with every group replaced by a random choice when
          no group uses "e$$";
        - otherwise one prompt per combination of the enumerated groups, with the
          non-enumerated groups replaced randomly in each of them.

    An unparsable count fragment is reported with a message and treated as 1; counts
    are clamped to the number of available variants.
    """
    founds = list(RE_DYNAMIC_PROMPT.finditer(prompt))
    if not founds:
        return [prompt]

    # make each replacement for each variant
    enumerating = False
    replacers = []
    for found in founds:
        # if "e$$" is found, enumerate all variants
        found_enumerating = found.group(2) is not None
        enumerating = enumerating or found_enumerating

        separator = ", " if found.group(6) is None else found.group(6)
        variants = found.group(7).split("|")
        count_range = _parse_variant_count_range(found.group(4), len(variants))

        make_replacer = _make_enum_replacer if found_enumerating else _make_random_replacer
        replacers.append(make_replacer(variants, count_range, separator))

    # make each prompt
    if not enumerating:
        # if not enumerating, repeat the prompt, replace each variant randomly
        prompts = []
        for _ in range(repeat_count):
            current = prompt
            for found, replacer in zip(founds, replacers):
                current = current.replace(found.group(0), replacer()[0], 1)
            prompts.append(current)
        return prompts

    # if enumerating, iterate all combinations for previous prompts
    prompts = [prompt]
    for found, replacer in zip(founds, replacers):
        if found.group(2) is not None:
            # make all combinations for existing prompts
            new_prompts = []
            for current in prompts:
                for replacement in replacer():
                    new_prompts.append(current.replace(found.group(0), replacement, 1))
            prompts = new_prompts

    for found, replacer in zip(founds, replacers):
        # make random selection for existing prompts
        if found.group(2) is None:
            for i in range(len(prompts)):
                prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1)

    return prompts
# endregion
# def load_clip_l14_336(dtype):
# print(f"loading CLIP: {CLIP_ID_L14_336}")
# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype)
# return text_encoder
class BatchDataBase(NamedTuple):
    """Per-image part of a batch entry."""

    # Data that does not require the batch to be split
    step: int  # step index of this entry; process_batch uses the first entry's value for sequential file names
    prompt: str
    negative_prompt: str
    seed: int  # passed to torch.manual_seed before the noises of this entry are generated
    init_image: Any  # img2img source image, or None
    mask_image: Any  # inpainting mask image, or None
    clip_prompt: str  # may be None; written to the PNG metadata when set
    guide_image: Any  # guide image (or list of images), or None
class BatchDataExt(NamedTuple):
    """Generation settings of a batch entry."""

    # Data that requires the batch to be split
    # NOTE(review): process_batch reads these values from the first entry of a batch only,
    # so presumably entries with different values must not share a batch - confirm against the caller.
    width: int
    height: int
    original_width: int
    original_height: int
    crop_left: int
    crop_top: int
    steps: int
    scale: float
    negative_scale: float
    strength: float
    network_muls: Tuple[float, ...]  # one multiplier per additional network (zipped with the network list)
    num_sub_prompts: int
class BatchData(NamedTuple):
    """One entry of a generation batch: a flag plus the two field groups."""

    return_latents: bool  # forwarded to the pipeline as return_latents
    base: BatchDataBase  # per-image data
    ext: BatchDataExt  # generation settings
def main(args):
if args.fp16:
dtype = torch.float16
elif args.bf16:
dtype = torch.bfloat16
else:
dtype = torch.float32
highres_fix = args.highres_fix_scale is not None
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
# モデルを読み込む
if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う
files = glob.glob(args.ckpt)
if len(files) == 1:
args.ckpt = files[0]
use_stable_diffusion_format = os.path.isfile(args.ckpt)
assert use_stable_diffusion_format, "Diffusers pretrained models are not supported yet"
print("load StableDiffusion checkpoint")
text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt, "cpu"
)
# else:
# print("load Diffusers pretrained models")
# TODO use Diffusers 0.18.1 and support SDXL pipeline
# raise NotImplementedError("Diffusers pretrained models are not supported yet")
# loading_pipe = StableDiffusionXLPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
# text_encoder = loading_pipe.text_encoder
# vae = loading_pipe.vae
# unet = loading_pipe.unet
# tokenizer = loading_pipe.tokenizer
# del loading_pipe
# # Diffusers U-Net to original U-Net
# original_unet = SdxlUNet2DConditionModel(
# unet.config.sample_size,
# unet.config.attention_head_dim,
# unet.config.cross_attention_dim,
# unet.config.use_linear_projection,
# unet.config.upcast_attention,
# )
# original_unet.load_state_dict(unet.state_dict())
# unet = original_unet
# VAEを読み込む
if args.vae is not None:
vae = model_util.load_vae(args.vae, dtype)
print("additional VAE loaded")
# xformers、Hypernetwork対応
if not args.diffusers_xformers:
mem_eff = not (args.xformers or args.sdpa)
replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa)
replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)
# tokenizerを読み込む
print("loading tokenizer")
if use_stable_diffusion_format:
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
# schedulerを用意する
sched_init_args = {}
scheduler_num_noises_per_step = 1
if args.sampler == "ddim":
scheduler_cls = DDIMScheduler
scheduler_module = diffusers.schedulers.scheduling_ddim
elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
scheduler_cls = DDPMScheduler
scheduler_module = diffusers.schedulers.scheduling_ddpm
elif args.sampler == "pndm":
scheduler_cls = PNDMScheduler
scheduler_module = diffusers.schedulers.scheduling_pndm
elif args.sampler == "lms" or args.sampler == "k_lms":
scheduler_cls = LMSDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_lms_discrete
elif args.sampler == "euler" or args.sampler == "k_euler":
scheduler_cls = EulerDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_euler_discrete
elif args.sampler == "euler_a" or args.sampler == "k_euler_a":
scheduler_cls = EulerAncestralDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete
elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++":
scheduler_cls = DPMSolverMultistepScheduler
sched_init_args["algorithm_type"] = args.sampler
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep
elif args.sampler == "dpmsingle":
scheduler_cls = DPMSolverSinglestepScheduler
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep
elif args.sampler == "heun":
scheduler_cls = HeunDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_heun_discrete
elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2":
scheduler_cls = KDPM2DiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete
elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a":
scheduler_cls = KDPM2AncestralDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete
scheduler_num_noises_per_step = 2
# samplerの乱数をあらかじめ指定するための処理
# replace randn
class NoiseManager:
    """
    Hands out pre-generated sampler noises in order, as a replacement for ``torch.randn``.

    Each call to ``randn`` consumes one slot of the list given to
    ``reset_sampler_noises``. If no list is set, the list is exhausted, or the
    requested shape does not match the stored noise, fresh noise is generated
    with ``torch.randn`` and a message is printed.
    """

    def __init__(self):
        self.sampler_noises = None
        self.sampler_noise_index = 0

    def reset_sampler_noises(self, noises):
        """Install a new list of noises and restart from its first element."""
        self.sampler_noise_index = 0
        self.sampler_noises = noises

    def randn(self, shape, device=None, dtype=None, layout=None, generator=None):
        """
        Return the next prepared noise, or fresh random noise if none fits.

        ``layout`` is accepted for signature compatibility and is not used.
        The slot index advances on every call, including fallback calls.
        """
        noise = None
        if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises):
            candidate = self.sampler_noises[self.sampler_noise_index]
            if shape == candidate.shape:
                noise = candidate

        # identity check: "== None" on a tensor goes through tensor comparison instead
        if noise is None:
            print(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
            noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)

        self.sampler_noise_index += 1
        return noise
class TorchRandReplacer:
    """
    Stand-in for the ``torch`` module that redirects ``randn`` to a noise manager.

    Every other attribute is looked up on the real ``torch`` module.
    """

    def __init__(self, noise_manager):
        self.noise_manager = noise_manager

    def __getattr__(self, item):
        # only "randn" is intercepted; everything else is delegated to torch
        if item == "randn":
            return self.noise_manager.randn

        missing = object()
        delegated = getattr(torch, item, missing)
        if delegated is missing:
            raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
        return delegated
noise_manager = NoiseManager()
if scheduler_module is not None:
scheduler_module.torch = TorchRandReplacer(noise_manager)
scheduler = scheduler_cls(
num_train_timesteps=SCHEDULER_TIMESTEPS,
beta_start=SCHEDULER_LINEAR_START,
beta_end=SCHEDULER_LINEAR_END,
beta_schedule=SCHEDLER_SCHEDULE,
**sched_init_args,
)
# clip_sample=Trueにする
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
print("set clip_sample to True")
scheduler.config.clip_sample = True
# deviceを決定する
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない
# custom pipelineをコピったやつを生成する
if args.vae_slices:
from library.slicing_vae import SlicingAutoencoderKL
sli_vae = SlicingAutoencoderKL(
act_fn="silu",
block_out_channels=(128, 256, 512, 512),
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
in_channels=3,
latent_channels=4,
layers_per_block=2,
norm_num_groups=32,
out_channels=3,
sample_size=512,
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
num_slices=args.vae_slices,
)
sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする
vae = sli_vae
del sli_vae
vae_dtype = dtype
if args.no_half_vae:
print("set vae_dtype to float32")
vae_dtype = torch.float32
vae.to(vae_dtype).to(device)
text_encoder1.to(dtype).to(device)
text_encoder2.to(dtype).to(device)
unet.to(dtype).to(device)
# networkを組み込む
if args.network_module:
networks = []
network_default_muls = []
network_pre_calc = args.network_pre_calc
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_default_muls.append(network_mul)
net_kwargs = {}
if args.network_args and i < len(args.network_args):
network_args = args.network_args[i]
# TODO escape special chars
network_args = network_args.split(";")
for net_arg in network_args:
key, value = net_arg.split("=")
net_kwargs[key] = value
if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
)
else:
raise ValueError("No weight. Weight is required.")
if network is None:
return
mergeable = network.is_mergeable()
if args.network_merge and not mergeable:
print("network is not mergiable. ignore merge option.")
if not args.network_merge or not mergeable:
network.apply_to([text_encoder1, text_encoder2], unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}")
if args.opt_channels_last:
network.to(memory_format=torch.channels_last)
network.to(dtype).to(device)
if network_pre_calc:
print("backup original weights")
network.backup_weights()
networks.append(network)
else:
network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device)
else:
networks = []
# upscalerの指定があれば取得する
upscaler = None
if args.highres_fix_upscaler:
print("import upscaler module:", args.highres_fix_upscaler)
imported_module = importlib.import_module(args.highres_fix_upscaler)
us_kwargs = {}
if args.highres_fix_upscaler_args:
for net_arg in args.highres_fix_upscaler_args.split(";"):
key, value = net_arg.split("=")
us_kwargs[key] = value
print("create upscaler")
upscaler = imported_module.create_upscaler(**us_kwargs)
upscaler.to(dtype).to(device)
# ControlNetの処理
control_nets: List[ControlNetInfo] = []
if args.control_net_models:
for i, model in enumerate(args.control_net_models):
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model)
prep = original_control_net.load_preprocess(prep_type)
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
if args.opt_channels_last:
print(f"set optimizing: channels last")
text_encoder1.to(memory_format=torch.channels_last)
text_encoder2.to(memory_format=torch.channels_last)
vae.to(memory_format=torch.channels_last)
unet.to(memory_format=torch.channels_last)
if networks:
for network in networks:
network.to(memory_format=torch.channels_last)
for cn in control_nets:
cn.unet.to(memory_format=torch.channels_last)
cn.net.to(memory_format=torch.channels_last)
pipe = PipelineLike(
device,
vae,
[text_encoder1, text_encoder2],
[tokenizer1, tokenizer2],
unet,
scheduler,
args.clip_skip,
)
pipe.set_control_nets(control_nets)
print("pipeline is ready.")
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()
# Textual Inversionを処理する
if args.textual_inversion_embeddings:
token_ids_embeds1 = []
token_ids_embeds2 = []
for embeds_file in args.textual_inversion_embeddings:
if model_util.is_safetensors(embeds_file):
from safetensors.torch import load_file
data = load_file(embeds_file)
else:
data = torch.load(embeds_file, map_location="cpu")
if "string_to_param" in data:
data = data["string_to_param"]
embeds1 = data["clip_l"] # text encoder 1
embeds2 = data["clip_g"] # text encoder 2
num_vectors_per_token = embeds1.size()[0]
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
# remove non-alphabet characters to avoid splitting by tokenizer
# TODO make random alphabet string
token_string = "".join([c for c in token_string if c.isalpha()])
token_strings = [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)]
# add new word to tokenizer, count is num_vectors_per_token
num_added_tokens1 = tokenizer1.add_tokens(token_strings)
num_added_tokens2 = tokenizer2.add_tokens(token_strings)
assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, (
f"tokenizer has same word to token string (filename). characters except alphabet are removed: {embeds_file}"
+ f" / 指定した名前(ファイル名)のトークンが既に存在します。アルファベット以外の文字は削除されます: {embeds_file}"
)
token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings)
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}")
assert (
min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1
), f"token ids1 is not ordered"
assert (
min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1
), f"token ids2 is not ordered"
assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}"
assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}"
if num_vectors_per_token > 1:
pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ...
pipe.add_token_replacement(1, token_ids2[0], token_ids2)
token_ids_embeds1.append((token_ids1, embeds1))
token_ids_embeds2.append((token_ids2, embeds2))
text_encoder1.resize_token_embeddings(len(tokenizer1))
text_encoder2.resize_token_embeddings(len(tokenizer2))
token_embeds1 = text_encoder1.get_input_embeddings().weight.data
token_embeds2 = text_encoder2.get_input_embeddings().weight.data
for token_ids, embeds in token_ids_embeds1:
for token_id, embed in zip(token_ids, embeds):
token_embeds1[token_id] = embed
for token_ids, embeds in token_ids_embeds2:
for token_id, embed in zip(token_ids, embeds):
token_embeds2[token_id] = embed
# promptを取得する
if args.from_file is not None:
print(f"reading prompts from {args.from_file}")
with open(args.from_file, "r", encoding="utf-8") as f:
prompt_list = f.read().splitlines()
prompt_list = [d for d in prompt_list if len(d.strip()) > 0]
elif args.prompt is not None:
prompt_list = [args.prompt]
else:
prompt_list = []
if args.interactive:
args.n_iter = 1
# img2imgの前処理、画像の読み込みなど
def load_images(path):
    """
    Load the image at ``path``, or every png/jpg/jpeg/webp image directly inside it.

    When ``path`` is a directory the files are loaded in sorted path order.
    Images whose mode is not "RGB" are converted to RGB; the ``filename``
    attribute of the opened file is carried over to the converted image,
    because ``convert`` returns a new image object without it and the
    file name is needed later when outputs are named after their sources.

    Returns:
        list of PIL images (empty if nothing matched).
    """
    if os.path.isfile(path):
        paths = [path]
    else:
        paths = (
            glob.glob(os.path.join(path, "*.png"))
            + glob.glob(os.path.join(path, "*.jpg"))
            + glob.glob(os.path.join(path, "*.jpeg"))
            + glob.glob(os.path.join(path, "*.webp"))
        )
        paths.sort()

    images = []
    for p in paths:
        image = Image.open(p)
        if image.mode != "RGB":
            print(f"convert image to RGB from {image.mode}: {p}")
            converted = image.convert("RGB")
            # keep the source file name on the converted copy
            converted.filename = getattr(image, "filename", p)
            image = converted
        images.append(image)

    return images
def resize_images(imgs, size):
    """
    Return LANCZOS-resized copies of ``imgs`` at ``size`` (width, height).

    The ``filename`` attribute is copied to the resized image when the
    source image has one.
    """

    def resize_one(source):
        scaled = source.resize(size, Image.Resampling.LANCZOS)
        if hasattr(source, "filename"):  # some images do not carry a filename attribute
            scaled.filename = source.filename
        return scaled

    return [resize_one(source) for source in imgs]
if args.image_path is not None:
print(f"load image for img2img: {args.image_path}")
init_images = load_images(args.image_path)
assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
print(f"loaded {len(init_images)} images for img2img")
else:
init_images = None
if args.mask_path is not None:
print(f"load mask for inpainting: {args.mask_path}")
mask_images = load_images(args.mask_path)
assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}"
print(f"loaded {len(mask_images)} mask images for inpainting")
else:
mask_images = None
# promptがないとき、画像のPngInfoから取得する
if init_images is not None and len(prompt_list) == 0 and not args.interactive:
print("get prompts from images' meta data")
for img in init_images:
if "prompt" in img.text:
prompt = img.text["prompt"]
if "negative-prompt" in img.text:
prompt += " --n " + img.text["negative-prompt"]
prompt_list.append(prompt)
# プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する)
l = []
for im in init_images:
l.extend([im] * args.images_per_prompt)
init_images = l
if mask_images is not None:
l = []
for im in mask_images:
l.extend([im] * args.images_per_prompt)
mask_images = l
# 画像サイズにオプション指定があるときはリサイズする
if args.W is not None and args.H is not None:
# highres fix を考慮に入れる
w, h = args.W, args.H
if highres_fix:
w = int(w * args.highres_fix_scale + 0.5)
h = int(h * args.highres_fix_scale + 0.5)
if init_images is not None:
print(f"resize img2img source images to {w}*{h}")
init_images = resize_images(init_images, (w, h))
if mask_images is not None:
print(f"resize img2img mask images to {w}*{h}")
mask_images = resize_images(mask_images, (w, h))
regional_network = False
if networks and mask_images:
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
regional_network = True
print("use mask as region")
size = None
for i, network in enumerate(networks):
if i < 3:
np_mask = np.array(mask_images[0])
np_mask = np_mask[:, :, i]
size = np_mask.shape
else:
np_mask = np.full(size, 255, dtype=np.uint8)
mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
network.set_region(i, i == len(networks) - 1, mask)
mask_images = None
prev_image = None # for VGG16 guided
if args.guide_image_path is not None:
print(f"load image for ControlNet guidance: {args.guide_image_path}")
guide_images = []
for p in args.guide_image_path:
guide_images.extend(load_images(p))
print(f"loaded {len(guide_images)} guide images for guidance")
if len(guide_images) == 0:
print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
guide_images = None
else:
guide_images = None
# seed指定時はseedを決めておく
if args.seed is not None:
# dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう
random.seed(args.seed)
predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)]
if len(predefined_seeds) == 1:
predefined_seeds[0] = args.seed
else:
predefined_seeds = None
# デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み)
if args.W is None:
args.W = 1024
if args.H is None:
args.H = 1024
# 画像生成のループ
os.makedirs(args.outdir, exist_ok=True)
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
for gen_iter in range(args.n_iter):
print(f"iteration {gen_iter+1}/{args.n_iter}")
iter_seed = random.randint(0, 0x7FFFFFFF)
# バッチ処理の関数
def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
    """
    Generate and save the images for one batch.

    With ``highres_fix`` set (and ``highres_1st`` false), a reduced-size 1st stage
    is generated first by calling this function recursively; its results are
    upscaled and used as the init images of the 2nd stage, which then continues
    below. The generation settings (size, steps, scale, ...) are taken from the
    first entry of the batch.

    Args:
        batch: entries to generate together.
        highres_fix: whether highres fix is active.
        highres_1st: true only for the recursive 1st-stage call.

    Returns:
        The pipeline output (images, or latents when the entries request them).
        A 1st-stage call returns before saving unless ``args.highres_fix_save_1st`` is set.
    """
    batch_size = len(batch)

    # highres fix handling
    if highres_fix and not highres_1st:
        # build the 1st stage batch with reduced size and run it
        is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling

        print("process 1st stage")
        batch_1st = []
        for _, base, ext in batch:

            def scale_and_round(x):
                if x is None:
                    return None
                return int(x * args.highres_fix_scale + 0.5)

            width_1st = scale_and_round(ext.width)
            height_1st = scale_and_round(ext.height)
            width_1st = width_1st - width_1st % 32
            height_1st = height_1st - height_1st % 32

            original_width_1st = scale_and_round(ext.original_width)
            original_height_1st = scale_and_round(ext.original_height)
            crop_left_1st = scale_and_round(ext.crop_left)
            crop_top_1st = scale_and_round(ext.crop_top)

            strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength

            ext_1st = BatchDataExt(
                width_1st,
                height_1st,
                original_width_1st,
                original_height_1st,
                crop_left_1st,
                crop_top_1st,
                args.highres_fix_steps,
                ext.scale,
                ext.negative_scale,
                strength_1st,
                ext.network_muls,
                ext.num_sub_prompts,
            )
            batch_1st.append(BatchData(is_1st_latent, base, ext_1st))

        pipe.set_enable_control_net(True)  # ControlNet is enabled in the 1st stage
        images_1st = process_batch(batch_1st, True, True)

        # build the 2nd stage batch and continue with the processing below
        print("process 2nd stage")
        width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height

        if upscaler:
            # enlarge the images with the upscaler
            lowreso_imgs = None if is_1st_latent else images_1st
            lowreso_latents = None if not is_1st_latent else images_1st

            # the return value is PIL.Image.Image or torch.Tensor latents
            batch_size = len(images_1st)
            vae_batch_size = (
                batch_size
                if args.vae_batch_size is None
                else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size)
            )
            vae_batch_size = int(vae_batch_size)
            images_1st = upscaler.upscale(
                vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size
            )

        elif args.highres_fix_latents_upscaling:
            # enlarge the latents
            org_dtype = images_1st.dtype
            if images_1st.dtype == torch.bfloat16:
                images_1st = images_1st.to(torch.float)  # interpolate does not support bf16
            images_1st = torch.nn.functional.interpolate(
                images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear"
            )  # , antialias=True)
            images_1st = images_1st.to(org_dtype)

        else:
            # enlarge the images with LANCZOS
            images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st]

        batch_2nd = []
        for i, (bd, image) in enumerate(zip(batch, images_1st)):
            # 2nd stage entry: seed + 1, 1st stage result as init image, no mask
            bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext)
            batch_2nd.append(bd_2nd)
        batch = batch_2nd

        if args.highres_fix_disable_control_net:
            pipe.set_enable_control_net(False)  # when the option is given, ControlNet is disabled in the 2nd stage

    # take the settings of this batch from its first entry
    (
        return_latents,
        (step_first, _, _, _, init_image, mask_image, _, guide_image),
        (
            width,
            height,
            original_width,
            original_height,
            crop_left,
            crop_top,
            steps,
            scale,
            negative_scale,
            strength,
            network_muls,
            num_sub_prompts,
        ),
    ) = batch[0]
    noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)

    prompts = []
    negative_prompts = []
    start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
    noises = [
        torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
        for _ in range(steps * scheduler_num_noises_per_step)
    ]
    seeds = []
    clip_prompts = []

    if init_image is not None:  # img2img?
        i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
        init_images = []

        if mask_image is not None:
            mask_images = []
        else:
            mask_images = None
    else:
        i2i_noises = None
        init_images = None
        mask_images = None

    if guide_image is not None:  # CLIP image guided?
        guide_images = []
    else:
        guide_images = None

    # generate the random numbers here so an entry gets the same noise regardless of its position
    # in the batch; also check whether image / mask / guide image are identical within the batch
    all_images_are_same = True
    all_masks_are_same = True
    all_guide_images_are_same = True
    for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
        prompts.append(prompt)
        negative_prompts.append(negative_prompt)
        seeds.append(seed)
        clip_prompts.append(clip_prompt)

        if init_image is not None:
            init_images.append(init_image)
            if i > 0 and all_images_are_same:
                all_images_are_same = init_images[-2] is init_image

        if mask_image is not None:
            mask_images.append(mask_image)
            if i > 0 and all_masks_are_same:
                all_masks_are_same = mask_images[-2] is mask_image

        if guide_image is not None:
            if type(guide_image) is list:
                guide_images.extend(guide_image)
                all_guide_images_are_same = False
            else:
                guide_images.append(guide_image)
                if i > 0 and all_guide_images_are_same:
                    all_guide_images_are_same = guide_images[-2] is guide_image

        # make start code
        torch.manual_seed(seed)
        start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype)

        # make each noises
        for j in range(steps * scheduler_num_noises_per_step):
            noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype)

        if i2i_noises is not None:  # img2img noise
            i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype)

    noise_manager.reset_sampler_noises(noises)

    # if all images are the same, pass only one to the pipe so the pipe can process faster
    if init_images is not None and all_images_are_same:
        init_images = init_images[0]
    if mask_images is not None and all_masks_are_same:
        mask_images = mask_images[0]
    if guide_images is not None and all_guide_images_are_same:
        guide_images = guide_images[0]

    # resize the guide images when ControlNet is used
    if control_nets:
        # TODO resample method
        guide_images = guide_images if type(guide_images) == list else [guide_images]
        guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
        if len(guide_images) == 1:
            guide_images = guide_images[0]

    # generate
    if networks:
        # additional network handling
        shared = {}
        for n, m in zip(networks, network_muls if network_muls else network_default_muls):
            n.set_multiplier(m)
            if regional_network:
                n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)

        if not regional_network and network_pre_calc:
            for n in networks:
                n.restore_weights()
            for n in networks:
                n.pre_calculation()
            print("pre-calculation... done")

    images = pipe(
        prompts,
        negative_prompts,
        init_images,
        mask_images,
        height,
        width,
        original_height,
        original_width,
        crop_top,
        crop_left,
        steps,
        scale,
        negative_scale,
        strength,
        latents=start_code,
        output_type="pil",
        max_embeddings_multiples=max_embeddings_multiples,
        img2img_noise=i2i_noises,
        vae_batch_size=args.vae_batch_size,
        return_latents=return_latents,
        clip_prompts=clip_prompts,
        clip_guide_images=guide_images,
    )
    if highres_1st and not args.highres_fix_save_1st:  # return images or latents
        return images

    # save image
    highres_prefix = ("0" if highres_1st else "1") if highres_fix else ""
    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
    # bind each image's own negative prompt; the list itself must not be shadowed
    for i, (image, prompt, negative_prompt, seed, clip_prompt) in enumerate(
        zip(images, prompts, negative_prompts, seeds, clip_prompts)
    ):
        metadata = PngInfo()
        metadata.add_text("prompt", prompt)
        metadata.add_text("seed", str(seed))
        metadata.add_text("sampler", args.sampler)
        metadata.add_text("steps", str(steps))
        metadata.add_text("scale", str(scale))
        if negative_prompt is not None:
            metadata.add_text("negative-prompt", negative_prompt)
        if negative_scale is not None:
            metadata.add_text("negative-scale", str(negative_scale))
        if clip_prompt is not None:
            metadata.add_text("clip-prompt", clip_prompt)
        metadata.add_text("original-height", str(original_height))
        metadata.add_text("original-width", str(original_width))
        metadata.add_text("crop-top", str(crop_top))
        metadata.add_text("crop-left", str(crop_left))

        if args.use_original_file_name and init_images is not None:
            if type(init_images) is list:
                fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png"
            else:
                fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png"
        elif args.sequential_file_name:
            fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png"
        else:
            fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png"

        image.save(os.path.join(args.outdir, fln), pnginfo=metadata)

    if not args.no_preview and not highres_1st and args.interactive:
        try:
            import cv2

            for prompt, image in zip(prompts, images):
                cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1])  # fails when the prompt is too long
                cv2.waitKey()
                cv2.destroyAllWindows()
        except ImportError:
            print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません")

    return images
# 画像生成のプロンプトが一周するまでのループ
prompt_index = 0
global_step = 0
batch_data = []
while args.interactive or prompt_index < len(prompt_list):
if len(prompt_list) == 0:
# interactive
valid = False
while not valid:
print("\nType prompt:")
try:
raw_prompt = input()
except EOFError:
break
valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0
if not valid: # EOF, end app
break
else:
raw_prompt = prompt_list[prompt_index]
# sd-dynamic-prompts like variants:
# count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt)
# repeat prompt
for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0]
if pi == 0 or len(raw_prompts) > 1:
# parse prompt: if prompt is not changed, skip parsing
width = args.W
height = args.H
original_width = args.original_width
original_height = args.original_height
crop_top = args.crop_top
crop_left = args.crop_left
scale = args.scale
negative_scale = args.negative_scale
steps = args.steps
seed = None
seeds = None
strength = 0.8 if args.strength is None else args.strength
negative_prompt = ""
clip_prompt = None
network_muls = None
prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
for parg in prompt_args[1:]:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
print(f"width: {width}")
continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
print(f"height: {height}")
continue
m = re.match(r"ow (\d+)", parg, re.IGNORECASE)
if m:
original_width = int(m.group(1))
print(f"original width: {width}")
continue
m = re.match(r"oh (\d+)", parg, re.IGNORECASE)
if m:
original_height = int(m.group(1))
print(f"original height: {height}")
continue
m = re.match(r"ct (\d+)", parg, re.IGNORECASE)
if m:
crop_top = int(m.group(1))
print(f"crop top: {crop_top}")
continue
m = re.match(r"cl (\d+)", parg, re.IGNORECASE)
if m:
crop_left = int(m.group(1))
print(f"crop left: {crop_left}")
continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
steps = max(1, min(1000, int(m.group(1))))
print(f"steps: {steps}")
continue
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
if m: # seed
seeds = [int(d) for d in m.group(1).split(",")]
print(f"seeds: {seeds}")
continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
scale = float(m.group(1))
print(f"scale: {scale}")
continue
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
if m: # negative scale
if m.group(1).lower() == "none":
negative_scale = None
else:
negative_scale = float(m.group(1))
print(f"negative scale: {negative_scale}")
continue
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
if m: # strength
strength = float(m.group(1))
print(f"strength: {strength}")
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
print(f"negative prompt: {negative_prompt}")
continue
m = re.match(r"c (.+)", parg, re.IGNORECASE)
if m: # clip prompt
clip_prompt = m.group(1)
print(f"clip prompt: {clip_prompt}")
continue
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # network multiplies
network_muls = [float(v) for v in m.group(1).split(",")]
while len(network_muls) < len(networks):
network_muls.append(network_muls[-1])
print(f"network mul: {network_muls}")
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
# prepare seed
if seeds is not None: # given in prompt
# 数が足りないなら前のをそのまま使う
if len(seeds) > 0:
seed = seeds.pop(0)
else:
if predefined_seeds is not None:
if len(predefined_seeds) > 0:
seed = predefined_seeds.pop(0)
else:
print("predefined seeds are exhausted")
seed = None
elif args.iter_same_seed:
seeds = iter_seed
else:
seed = None # 前のを消す
if seed is None:
seed = random.randint(0, 0x7FFFFFFF)
if args.interactive:
print(f"seed: {seed}")
# prepare init image, guide image and mask
init_image = mask_image = guide_image = None
# 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する
if init_images is not None:
init_image = init_images[global_step % len(init_images)]
# img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する
# 32単位に丸めたやつにresizeされるので踏襲する
if not highres_fix:
width, height = init_image.size
width = width - width % 32
height = height - height % 32
if width != init_image.size[0] or height != init_image.size[1]:
print(
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
)
if mask_images is not None:
mask_image = mask_images[global_step % len(mask_images)]
if guide_images is not None:
if control_nets: # 複数件の場合あり
c = len(control_nets)
p = global_step % (len(guide_images) // c)
guide_image = guide_images[p * c : p * c + c]
else:
guide_image = guide_images[global_step % len(guide_images)]
if regional_network:
num_sub_prompts = len(prompt.split(" AND "))
assert (
len(networks) <= num_sub_prompts
), "Number of networks must be less than or equal to number of sub prompts."
else:
num_sub_prompts = None
b1 = BatchData(
False,
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
BatchDataExt(
width,
height,
original_width,
original_height,
crop_left,
crop_top,
steps,
scale,
negative_scale,
strength,
tuple(network_muls) if network_muls else None,
num_sub_prompts,
),
)
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
process_batch(batch_data, highres_fix)
batch_data.clear()
batch_data.append(b1)
if len(batch_data) == args.batch_size:
prev_image = process_batch(batch_data, highres_fix)[0]
batch_data.clear()
global_step += 1
prompt_index += 1
if len(batch_data) > 0:
process_batch(batch_data, highres_fix)
batch_data.clear()
print("done!")
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for the image generation script.

    Only registers options; nothing is parsed here. Each option's help text is
    bilingual (English / Japanese) and is shown to the user by ``--help``.

    Returns:
        argparse.ArgumentParser: a parser with every generation option registered.
    """
    parser = argparse.ArgumentParser()

    # --- prompt sources, input images and output naming ---
    parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト")
    parser.add_argument(
        "--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む"
    )
    parser.add_argument(
        "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)"
    )
    parser.add_argument(
        "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない"
    )
    parser.add_argument(
        "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像"
    )
    parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク")
    parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength")
    parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数")
    parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先")
    parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする")
    parser.add_argument(
        "--use_original_file_name",
        action="store_true",
        help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける",
    )
    # disabled option, kept for reference:
    # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", )

    # --- image size, SDXL size/crop conditioning and batching ---
    parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数")
    parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
    parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
    parser.add_argument(
        "--original_height", type=int, default=None, help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値"
    )
    parser.add_argument(
        "--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値"
    )
    parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値")
    parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
    # float on purpose: per the help text, a value below 1.0 is a ratio of the normal batch size
    parser.add_argument(
        "--vae_batch_size",
        type=float,
        default=None,
        help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率",
    )
    parser.add_argument(
        "--vae_slices",
        type=int,
        default=None,
        help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨",
    )
    parser.add_argument("--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない")

    # --- sampling ---
    # help no longer claims DDIM: the sampler is selectable via --sampler
    parser.add_argument("--steps", type=int, default=50, help="number of sampling steps / サンプリングステップ数")
    parser.add_argument(
        "--sampler",
        type=str,
        default="ddim",
        choices=[
            "ddim",
            "pndm",
            "lms",
            "euler",
            "euler_a",
            "heun",
            "dpm_2",
            "dpm_2_a",
            "dpmsolver",
            "dpmsolver++",
            "dpmsingle",
            "k_lms",
            "k_euler",
            "k_euler_a",
            "k_dpm_2",
            "k_dpm_2_a",
        ],
        help="sampler (scheduler) type / サンプラー(スケジューラ)の種類",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=7.5,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale",
    )

    # --- model files ---
    parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
    parser.add_argument(
        "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ"
    )
    parser.add_argument(
        "--tokenizer_cache_dir",
        type=str,
        default=None,
        help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
    )
    # disabled option, kept for reference:
    # parser.add_argument("--replace_clip_l14_336", action="store_true",
    #                     help="Replace CLIP (Text Encoder) to l/14@336")

    # --- seeds ---
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed",
    )
    parser.add_argument(
        "--iter_same_seed",
        action="store_true",
        help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)",
    )

    # --- precision and attention / memory optimizations ---
    parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する")
    parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する")
    parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する")
    parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa")
    parser.add_argument(
        "--diffusers_xformers",
        action="store_true",
        help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)",
    )
    parser.add_argument(
        "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する"
    )

    # --- additional networks and textual inversion ---
    # the nargs="*" options below each accept a list of values
    parser.add_argument(
        "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名"
    )
    parser.add_argument(
        "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み"
    )
    parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率")
    parser.add_argument(
        "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
    )
    parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
    parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
    parser.add_argument(
        "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
    )
    parser.add_argument(
        "--textual_inversion_embeddings",
        type=str,
        default=None,
        nargs="*",
        help="Embeddings files of Textual Inversion / Textual Inversionのembeddings",
    )

    # --- text encoder options and guide images ---
    parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う")
    parser.add_argument(
        "--max_embeddings_multiples",
        type=int,
        default=None,
        help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
    )
    parser.add_argument(
        "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガイドに使う画像"
    )

    # --- highres fix (two-stage generation) ---
    parser.add_argument(
        "--highres_fix_scale",
        type=float,
        default=None,
        help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする",
    )
    parser.add_argument(
        "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数"
    )
    parser.add_argument(
        "--highres_fix_strength",
        type=float,
        default=None,
        help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ",
    )
    parser.add_argument(
        "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する"
    )
    parser.add_argument(
        "--highres_fix_latents_upscaling",
        action="store_true",
        help="use latents upscaling for highres fix / highres fixでlatentで拡大する",
    )
    parser.add_argument(
        "--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名"
    )
    parser.add_argument(
        "--highres_fix_upscaler_args",
        type=str,
        default=None,
        help="additional arguments for upscaler (key=value) / upscalerへの追加の引数",
    )
    parser.add_argument(
        "--highres_fix_disable_control_net",
        action="store_true",
        help="disable ControlNet for highres fix / highres fixでControlNetを使わない",
    )

    # --- negative guidance and ControlNet ---
    parser.add_argument(
        "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する"
    )
    parser.add_argument(
        "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名"
    )
    parser.add_argument(
        "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名"
    )
    parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み")
    parser.add_argument(
        "--control_net_ratios",
        type=float,
        default=None,
        nargs="*",
        help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
    )
    # disabled option, kept for reference:
    # parser.add_argument(
    #     "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance"
    # )

    return parser
if __name__ == "__main__":
    # Script entry point: build the CLI parser, parse the command line,
    # and pass the resulting options to main().
    parser = setup_parser()
    args = parser.parse_args()
    main(args)