# ConceptSliders/trainscripts/imagesliders/train_lora-scale-xl.py
# ref:
# - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
# - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py
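#
# Trains an SDXL LoRA "slider" from paired images: each subfolder of
# folder_main holds the same images rendered at one attribute strength, and
# the LoRA is optimized so that a positive slider scale matches the
# high-attribute images and a negative scale the low-attribute ones.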
from typing import List, Optional
import argparse
import ast
from pathlib import Path
import gc, os
import numpy as np
import torch
from tqdm import tqdm
from PIL import Image
import train_util
import random
import model_util
import prompt_util
from prompt_util import (
PromptEmbedsCache,
PromptEmbedsPair,
PromptSettings,
PromptEmbedsXL,
)
import debug_util
import config_util
from config_util import RootConfig
import wandb
from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV

NUM_IMAGES_PER_PROMPT = 1
def flush():
torch.cuda.empty_cache()
gc.collect()
def train(
config: RootConfig,
prompts: list[PromptSettings],
device,
folder_main: str,
folders,
scales,
):
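    """Train an SDXL LoRA slider.

    folder_main is the root directory whose subfolders (named in `folders`)
    hold images of the same scenes at the attribute strengths listed in
    `scales`; `folders` and `scales` must line up one-to-one.
    """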
scales = np.array(scales)
folders = np.array(folders)
    scales_unique = list(set(scales))  # deduplicate scale values before sampling below
metadata = {
"prompts": ",".join([prompt.json() for prompt in prompts]),
"config": config.json(),
}
save_path = Path(config.save.path)
    modules = DEFAULT_TARGET_REPLACE
    if config.network.type == "c3lier":
        # build a new list instead of `+=`, which would mutate the shared default list
        modules = modules + UNET_TARGET_REPLACE_MODULE_CONV
if config.logging.verbose:
print(metadata)
if config.logging.use_wandb:
wandb.init(project=f"LECO_{config.save.name}", config=metadata)
weight_dtype = config_util.parse_precision(config.train.precision)
save_weight_dtype = config_util.parse_precision(config.train.precision)
(
tokenizers,
text_encoders,
unet,
noise_scheduler,
vae
) = model_util.load_models_xl(
config.pretrained_model.name_or_path,
scheduler_name=config.train.noise_scheduler,
)
for text_encoder in text_encoders:
text_encoder.to(device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
unet.to(device, dtype=weight_dtype)
if config.other.use_xformers:
unet.enable_xformers_memory_efficient_attention()
unet.requires_grad_(False)
unet.eval()
vae.to(device)
vae.requires_grad_(False)
vae.eval()
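    # The base UNet, VAE, and text encoders stay frozen; only the LoRA
    # adapter created below receives gradients.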
network = LoRANetwork(
unet,
rank=config.network.rank,
multiplier=1.0,
alpha=config.network.alpha,
train_method=config.network.training_method,
).to(device, dtype=weight_dtype)
optimizer_module = train_util.get_optimizer(config.train.optimizer)
    # parse extra optimizer kwargs given as a space-separated "key=value" string
optimizer_kwargs = {}
if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0:
for arg in config.train.optimizer_args.split(" "):
key, value = arg.split("=")
value = ast.literal_eval(value)
optimizer_kwargs[key] = value
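    # Illustrative only: optimizer_args of "weight_decay=0.01 betas=(0.9,0.99)"
    # yields optimizer_kwargs == {"weight_decay": 0.01, "betas": (0.9, 0.99)},
    # forwarded to the optimizer constructor below.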
optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs)
lr_scheduler = train_util.get_lr_scheduler(
config.train.lr_scheduler,
optimizer,
max_iterations=config.train.iterations,
lr_min=config.train.lr / 100,
)
criteria = torch.nn.MSELoss()
print("Prompts")
for settings in prompts:
print(settings)
# debug
debug_util.check_requires_grad(network)
debug_util.check_training_mode(network)
cache = PromptEmbedsCache()
prompt_pairs: list[PromptEmbedsPair] = []
with torch.no_grad():
for settings in prompts:
print(settings)
for prompt in [
settings.target,
settings.positive,
settings.neutral,
settings.unconditional,
]:
                if cache[prompt] is None:
tex_embs, pool_embs = train_util.encode_prompts_xl(
tokenizers,
text_encoders,
[prompt],
num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
)
cache[prompt] = PromptEmbedsXL(
tex_embs,
pool_embs
)
prompt_pairs.append(
PromptEmbedsPair(
criteria,
cache[settings.target],
cache[settings.positive],
cache[settings.unconditional],
cache[settings.neutral],
settings,
)
)
    # the prompt embeddings are cached above, so the tokenizers and text
    # encoders can be released (deleting the loop variables alone would not
    # free them)
    del tokenizers, text_encoders
    flush()
pbar = tqdm(range(config.train.iterations))
loss = None
for i in pbar:
with torch.no_grad():
noise_scheduler.set_timesteps(
config.train.max_denoising_steps, device=device
)
optimizer.zero_grad()
prompt_pair: PromptEmbedsPair = prompt_pairs[
torch.randint(0, len(prompt_pairs), (1,)).item()
]
            # pick a random timestep in [1, max_denoising_steps)
timesteps_to = torch.randint(
1, config.train.max_denoising_steps, (1,)
).item()
height, width = prompt_pair.resolution, prompt_pair.resolution
if prompt_pair.dynamic_resolution:
height, width = train_util.get_random_resolution_in_bucket(
prompt_pair.resolution
)
if config.logging.verbose:
print("guidance_scale:", prompt_pair.guidance_scale)
print("resolution:", prompt_pair.resolution)
print("dynamic_resolution:", prompt_pair.dynamic_resolution)
if prompt_pair.dynamic_resolution:
print("bucketed resolution:", (height, width))
print("batch_size:", prompt_pair.batch_size)
print("dynamic_crops:", prompt_pair.dynamic_crops)
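            # sample a scale magnitude, then load the matching low/high image
            # pair; pairing relies on both folders containing files with the
            # same names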
scale_to_look = abs(random.choice(list(scales_unique)))
folder1 = folders[scales==-scale_to_look][0]
folder2 = folders[scales==scale_to_look][0]
            ims = os.listdir(f'{folder_main}/{folder1}/')
            ims = [im_ for im_ in ims if im_.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]
random_sampler = random.randint(0, len(ims)-1)
img1 = Image.open(f'{folder_main}/{folder1}/{ims[random_sampler]}').resize((512,512))
img2 = Image.open(f'{folder_main}/{folder2}/{ims[random_sampler]}').resize((512,512))
            seed = random.randint(0, 2**15)  # the original 2*15 capped seeds at 30; 2**15 appears intended
generator = torch.manual_seed(seed)
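            # encode each image with the VAE and add the same seeded noise up
            # to timesteps_to, so the low/high latents differ only in content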
denoised_latents_low, low_noise = train_util.get_noisy_image(
img1,
vae,
generator,
unet,
noise_scheduler,
start_timesteps=0,
total_timesteps=timesteps_to)
denoised_latents_low = denoised_latents_low.to(device, dtype=weight_dtype)
low_noise = low_noise.to(device, dtype=weight_dtype)
generator = torch.manual_seed(seed)
denoised_latents_high, high_noise = train_util.get_noisy_image(
img2,
vae,
generator,
unet,
noise_scheduler,
start_timesteps=0,
total_timesteps=timesteps_to)
denoised_latents_high = denoised_latents_high.to(device, dtype=weight_dtype)
high_noise = high_noise.to(device, dtype=weight_dtype)
noise_scheduler.set_timesteps(1000)
add_time_ids = train_util.get_add_time_ids(
height,
width,
dynamic_crops=prompt_pair.dynamic_crops,
dtype=weight_dtype,
).to(device, dtype=weight_dtype)
current_timestep = noise_scheduler.timesteps[
int(timesteps_to * 1000 / config.train.max_denoising_steps)
]
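            # map the step index from the short max_denoising_steps schedule
            # onto the full 1000-step schedule set just above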
try:
                # outside `with network:` only the zero-scale (empty) LoRA is active
high_latents = train_util.predict_noise_xl(
unet,
noise_scheduler,
current_timestep,
denoised_latents_high,
text_embeddings=train_util.concat_embeddings(
prompt_pair.unconditional.text_embeds,
prompt_pair.positive.text_embeds,
prompt_pair.batch_size,
),
add_text_embeddings=train_util.concat_embeddings(
prompt_pair.unconditional.pooled_embeds,
prompt_pair.positive.pooled_embeds,
prompt_pair.batch_size,
),
add_time_ids=train_util.concat_embeddings(
add_time_ids, add_time_ids, prompt_pair.batch_size
),
guidance_scale=1,
).to(device, dtype=torch.float32)
            except Exception as e:
                flush()
                print(f'Error occurred!: {e} {np.array(img1).shape} {np.array(img2).shape}')
                continue
            # outside `with network:` only the zero-scale (empty) LoRA is active
low_latents = train_util.predict_noise_xl(
unet,
noise_scheduler,
current_timestep,
denoised_latents_low,
text_embeddings=train_util.concat_embeddings(
prompt_pair.unconditional.text_embeds,
prompt_pair.neutral.text_embeds,
prompt_pair.batch_size,
),
add_text_embeddings=train_util.concat_embeddings(
prompt_pair.unconditional.pooled_embeds,
prompt_pair.neutral.pooled_embeds,
prompt_pair.batch_size,
),
add_time_ids=train_util.concat_embeddings(
add_time_ids, add_time_ids, prompt_pair.batch_size
),
guidance_scale=1,
).to(device, dtype=torch.float32)
            if config.logging.verbose:
                print("high_latents:", high_latents[0, 0, :5, :5])
                print("low_latents:", low_latents[0, 0, :5, :5])
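        # slider objective: with the LoRA at +scale the UNet should reproduce
        # the noise added to the high-attribute image, and at -scale the noise
        # added to the low-attribute image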
network.set_lora_slider(scale=scale_to_look)
with network:
target_latents_high = train_util.predict_noise_xl(
unet,
noise_scheduler,
current_timestep,
denoised_latents_high,
text_embeddings=train_util.concat_embeddings(
prompt_pair.unconditional.text_embeds,
prompt_pair.positive.text_embeds,
prompt_pair.batch_size,
),
add_text_embeddings=train_util.concat_embeddings(
prompt_pair.unconditional.pooled_embeds,
prompt_pair.positive.pooled_embeds,
prompt_pair.batch_size,
),
add_time_ids=train_util.concat_embeddings(
add_time_ids, add_time_ids, prompt_pair.batch_size
),
guidance_scale=1,
).to(device, dtype=torch.float32)
high_latents.requires_grad = False
low_latents.requires_grad = False
loss_high = criteria(target_latents_high, high_noise.to(torch.float32))
pbar.set_description(f"Loss*1k: {loss_high.item()*1000:.4f}")
loss_high.backward()
# opposite
network.set_lora_slider(scale=-scale_to_look)
with network:
target_latents_low = train_util.predict_noise_xl(
unet,
noise_scheduler,
current_timestep,
denoised_latents_low,
text_embeddings=train_util.concat_embeddings(
prompt_pair.unconditional.text_embeds,
prompt_pair.neutral.text_embeds,
prompt_pair.batch_size,
),
add_text_embeddings=train_util.concat_embeddings(
prompt_pair.unconditional.pooled_embeds,
prompt_pair.neutral.pooled_embeds,
prompt_pair.batch_size,
),
add_time_ids=train_util.concat_embeddings(
add_time_ids, add_time_ids, prompt_pair.batch_size
),
guidance_scale=1,
).to(device, dtype=torch.float32)
high_latents.requires_grad = False
low_latents.requires_grad = False
loss_low = criteria(target_latents_low, low_noise.to(torch.float32))
pbar.set_description(f"Loss*1k: {loss_low.item()*1000:.4f}")
loss_low.backward()
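        # gradients from the +scale and -scale passes accumulate into a single
        # optimizer step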
optimizer.step()
lr_scheduler.step()
del (
high_latents,
low_latents,
target_latents_low,
target_latents_high,
)
flush()
if (
i % config.save.per_steps == 0
and i != 0
and i != config.train.iterations - 1
):
print("Saving...")
save_path.mkdir(parents=True, exist_ok=True)
network.save_weights(
save_path / f"{config.save.name}_{i}steps.pt",
dtype=save_weight_dtype,
)
print("Saving...")
save_path.mkdir(parents=True, exist_ok=True)
network.save_weights(
save_path / f"{config.save.name}_last.pt",
dtype=save_weight_dtype,
)
del (
unet,
noise_scheduler,
loss,
optimizer,
network,
)
flush()
print("Done.")
def main(args):
config_file = args.config_file
config = config_util.load_config_from_yaml(config_file)
if args.name is not None:
config.save.name = args.name
attributes = []
if args.attributes is not None:
attributes = args.attributes.split(',')
attributes = [a.strip() for a in attributes]
config.network.alpha = args.alpha
config.network.rank = args.rank
config.save.name += f'_alpha{args.alpha}'
    config.save.name += f'_rank{config.network.rank}'
config.save.name += f'_{config.network.training_method}'
config.save.path += f'/{config.save.name}'
prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)
device = torch.device(f"cuda:{args.device}")
folders = args.folders.split(',')
folders = [f.strip() for f in folders]
    scales = args.scales.split(',')
    scales = [int(s.strip()) for s in scales]
print(folders, scales)
    if len(scales) != len(folders):
        raise ValueError('the number of folders needs to match the number of scales')
if args.stylecheck is not None:
check = args.stylecheck.split('-')
for i in range(int(check[0]), int(check[1])):
            folder_main = args.folder_main + f'{i}'
            config.save.name = f'{os.path.basename(folder_main)}'
            config.save.name += f'_alpha{args.alpha}'
            config.save.name += f'_rank{config.network.rank}'
            config.save.path = f'models/{config.save.name}'
            train(config=config, prompts=prompts, device=device, folder_main=folder_main, folders=folders, scales=scales)
else:
        train(config=config, prompts=prompts, device=device, folder_main=args.folder_main, folders=folders, scales=scales)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_file",
required=True,
help="Config file for training.",
)
# config_file 'data/config.yaml'
parser.add_argument(
"--alpha",
type=float,
required=True,
help="LoRA weight.",
)
# --alpha 1.0
parser.add_argument(
"--rank",
type=int,
required=False,
help="Rank of LoRA.",
default=4,
)
# --rank 4
parser.add_argument(
"--device",
type=int,
required=False,
default=0,
help="Device to train on.",
)
# --device 0
parser.add_argument(
"--name",
type=str,
required=False,
default=None,
        help="Name for the trained slider model (overrides the config save name).",
)
# --name 'eyesize_slider'
parser.add_argument(
"--attributes",
type=str,
required=False,
default=None,
        help="Attributes to disentangle (comma-separated string).",
)
parser.add_argument(
"--folder_main",
type=str,
required=True,
        help="Main folder containing the attribute-scaled image subfolders.",
)
parser.add_argument(
"--stylecheck",
type=str,
required=False,
        default=None,
        help="Range 'start-end' of numbered style folders to iterate over (appended to folder_main).",
)
parser.add_argument(
"--folders",
type=str,
required=False,
        default='verylow, low, high, veryhigh',
help="folders with different attribute-scaled images",
)
parser.add_argument(
"--scales",
type=str,
required=False,
        default='-2, -1, 1, 2',
help="scales for different attribute-scaled images",
)
args = parser.parse_args()
main(args)
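
# Example invocation (paths and folder names are illustrative):
#   python train_lora-scale-xl.py --config_file 'data/config.yaml' \
#       --alpha 1.0 --rank 4 --name 'eyesize_slider' \
#       --folder_main 'datasets/eyesize' \
#       --folders 'verylow, low, high, veryhigh' --scales '-2, -1, 1, 2'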