# Author: RohitGandikota — commit 1f8beea ("testing layout"), 16.4 kB.
# ref:
# - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
# - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py
from typing import List, Optional
import argparse
import ast
from pathlib import Path
import gc
import torch
from tqdm import tqdm
import os, glob
from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
import train_util
import model_util
import prompt_util
from prompt_util import PromptEmbedsCache, PromptEmbedsPair, PromptSettings
import debug_util
import config_util
from config_util import RootConfig
import random
import numpy as np
import wandb
from PIL import Image
def flush():
    """Free cached CUDA allocator memory, then force a garbage-collection pass."""
    for release in (torch.cuda.empty_cache, gc.collect):
        release()
def prev_step(model_output, timestep, scheduler, sample):
    """Step ``sample`` from ``timestep`` back by one inference step.

    The update matches the deterministic DDIM formula (eta = 0):
    estimate the clean sample from the predicted noise, then re-noise it to
    the previous timestep's cumulative alpha.

    Args:
        model_output: Predicted noise at ``timestep``.
        timestep: Current integer timestep (index into ``alphas_cumprod``).
        scheduler: Object exposing ``config.num_train_timesteps``,
            ``num_inference_steps``, ``alphas_cumprod`` and
            ``final_alpha_cumprod``.
        sample: Current noisy sample.

    Returns:
        The sample at the previous timestep.
    """
    step_size = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    prev_timestep = timestep - step_size

    alpha_t = scheduler.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        alpha_prev = scheduler.alphas_cumprod[prev_timestep]
    else:
        # Stepping past t=0: fall back to the scheduler's final alpha.
        alpha_prev = scheduler.final_alpha_cumprod

    x0_estimate = (sample - (1 - alpha_t) ** 0.5 * model_output) / alpha_t ** 0.5
    direction = (1 - alpha_prev) ** 0.5 * model_output
    return alpha_prev ** 0.5 * x0_estimate + direction
def train(
    config: RootConfig,
    prompts: list[PromptSettings],
    device: int,
    folder_main: str,
    folders,
    scales,
):
    """Train a LoRA slider from pairs of attribute-scaled images.

    Each iteration draws a scale magnitude ``s`` from ``scales``. It loads an
    image from the folder paired with ``-s`` and the same filename from the
    folder paired with ``+s``. Both images are passed through
    ``train_util.get_noisy_image`` with the same seed and timestep. The LoRA
    is then applied at slider scale ``+s`` (positive prompt) and ``-s``
    (neutral prompt), and the MSE between the UNet prediction and the
    second value returned by ``get_noisy_image`` is back-propagated.
    That second value is presumably the added noise; confirm in train_util.

    Args:
        config: Parsed root config (model, network, train, save, logging).
        prompts: Prompt settings whose embeddings condition the UNet.
        device: Device to train on (``main`` passes a ``torch.device``).
        folder_main: Directory containing the per-scale image sub-folders.
        folders: Sub-folder names, parallel to ``scales``.
        scales: Integer scales, parallel to ``folders``. Every sampled
            magnitude must be present with both signs.

    Raises:
        FileNotFoundError: if a sampled low-scale folder contains no images.

    Side effects:
        Writes intermediate and final LoRA weights under ``config.save.path``;
        optionally initialises a wandb run.
    """
    scales = np.array(scales)
    folders = np.array(folders)
    # A signed scale is drawn each step and only its magnitude is used.
    scales_unique = list(scales)

    metadata = {
        "prompts": ",".join([prompt.json() for prompt in prompts]),
        "config": config.json(),
    }
    save_path = Path(config.save.path)

    # NOTE(review): `modules` is never read below. If DEFAULT_TARGET_REPLACE is
    # a list, `+=` extends the imported module-level object in place. That may
    # be what actually enables conv LoRA layers for "c3lier" inside lora.py.
    # Confirm against lora.py before removing or changing this.
    modules = DEFAULT_TARGET_REPLACE
    if config.network.type == "c3lier":
        modules += UNET_TARGET_REPLACE_MODULE_CONV

    if config.logging.verbose:
        print(metadata)

    if config.logging.use_wandb:
        wandb.init(project=f"LECO_{config.save.name}", config=metadata)

    weight_dtype = config_util.parse_precision(config.train.precision)
    # NOTE(review): saved weights reuse the *training* precision; confirm
    # whether a separate save precision was intended.
    save_weight_dtype = config_util.parse_precision(config.train.precision)

    tokenizer, text_encoder, unet, noise_scheduler, vae = model_util.load_models(
        config.pretrained_model.name_or_path,
        scheduler_name=config.train.noise_scheduler,
        v2=config.pretrained_model.v2,
        v_pred=config.pretrained_model.v_pred,
    )

    # Base models are only run for inference (unet/vae explicitly frozen);
    # the optimizer below only receives the LoRA network's parameters.
    text_encoder.to(device, dtype=weight_dtype)
    text_encoder.eval()

    unet.to(device, dtype=weight_dtype)
    unet.enable_xformers_memory_efficient_attention()
    unet.requires_grad_(False)
    unet.eval()

    vae.to(device)
    vae.requires_grad_(False)
    vae.eval()

    network = LoRANetwork(
        unet,
        rank=config.network.rank,
        multiplier=1.0,
        alpha=config.network.alpha,
        train_method=config.network.training_method,
    ).to(device, dtype=weight_dtype)

    optimizer_module = train_util.get_optimizer(config.train.optimizer)
    # Optional extra optimizer kwargs: whitespace-separated "key=value" pairs
    # whose values are Python literals (parsed safely with literal_eval).
    optimizer_kwargs = {}
    if config.train.optimizer_args:
        for arg in config.train.optimizer_args.split():
            key, value = arg.split("=", 1)
            optimizer_kwargs[key] = ast.literal_eval(value)

    optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs)
    lr_scheduler = train_util.get_lr_scheduler(
        config.train.lr_scheduler,
        optimizer,
        max_iterations=config.train.iterations,
        lr_min=config.train.lr / 100,
    )
    criteria = torch.nn.MSELoss()

    print("Prompts")
    for settings in prompts:
        print(settings)

    # debug
    debug_util.check_requires_grad(network)
    debug_util.check_training_mode(network)

    # Encode every prompt once up front; the text encoder is freed afterwards.
    cache = PromptEmbedsCache()
    prompt_pairs: list[PromptEmbedsPair] = []

    with torch.no_grad():
        for settings in prompts:
            print(settings)
            for prompt in [
                settings.target,
                settings.positive,
                settings.neutral,
                settings.unconditional,
            ]:
                print(prompt)
                if isinstance(prompt, list):
                    # List-valued prompts are cached under a fixed key.
                    # NOTE(review): PromptEmbedsPair below looks embeddings up
                    # by the prompt value itself, not by this key; confirm
                    # PromptEmbedsCache handles list-valued prompts.
                    if prompt == settings.positive:
                        key_setting = 'positive'
                    else:
                        key_setting = 'attributes'
                    if len(prompt) == 0:
                        cache[key_setting] = []
                    elif cache[key_setting] is None:
                        cache[key_setting] = train_util.encode_prompts(
                            tokenizer, text_encoder, prompt
                        )
                elif cache[prompt] is None:
                    cache[prompt] = train_util.encode_prompts(
                        tokenizer, text_encoder, [prompt]
                    )

            prompt_pairs.append(
                PromptEmbedsPair(
                    criteria,
                    cache[settings.target],
                    cache[settings.positive],
                    cache[settings.unconditional],
                    cache[settings.neutral],
                    settings,
                )
            )

    del tokenizer
    del text_encoder

    flush()

    image_exts = ('.png', '.jpg', '.jpeg', '.webp')

    pbar = tqdm(range(config.train.iterations))
    for i in pbar:
        with torch.no_grad():
            noise_scheduler.set_timesteps(
                config.train.max_denoising_steps, device=device
            )

            optimizer.zero_grad()

            prompt_pair: PromptEmbedsPair = prompt_pairs[
                torch.randint(0, len(prompt_pairs), (1,)).item()
            ]

            # Random step in [1, max_denoising_steps - 2]
            # (torch.randint's upper bound is exclusive).
            timesteps_to = torch.randint(
                1, config.train.max_denoising_steps - 1, (1,)
            ).item()

            height, width = (
                prompt_pair.resolution,
                prompt_pair.resolution,
            )
            if prompt_pair.dynamic_resolution:
                height, width = train_util.get_random_resolution_in_bucket(
                    prompt_pair.resolution
                )

            if config.logging.verbose:
                print("guidance_scale:", prompt_pair.guidance_scale)
                print("resolution:", prompt_pair.resolution)
                print("dynamic_resolution:", prompt_pair.dynamic_resolution)
                if prompt_pair.dynamic_resolution:
                    print("bucketed resolution:", (height, width))
                print("batch_size:", prompt_pair.batch_size)

            # Pick a magnitude and the folders holding its -s ("low") and
            # +s ("high") versions of the same images.
            scale_to_look = abs(random.choice(list(scales_unique)))
            folder1 = folders[scales == -scale_to_look][0]
            folder2 = folders[scales == scale_to_look][0]

            ims = os.listdir(f'{folder_main}/{folder1}/')
            ims = [im_ for im_ in ims if any(ext in im_ for ext in image_exts)]
            if not ims:
                raise FileNotFoundError(f"no images found in {folder_main}/{folder1}/")
            random_sampler = random.randint(0, len(ims) - 1)

            # Assumes both folders contain identically named files.
            # `with` closes the file handles; resize() loads pixel data first.
            with Image.open(f'{folder_main}/{folder1}/{ims[random_sampler]}') as im:
                img1 = im.resize((256, 256))
            with Image.open(f'{folder_main}/{folder2}/{ims[random_sampler]}') as im:
                img2 = im.resize((256, 256))

            # Same seed for both images so they share identical noise.
            seed = random.randint(0, 2**15)

            generator = torch.manual_seed(seed)
            denoised_latents_low, low_noise = train_util.get_noisy_image(
                img1,
                vae,
                generator,
                unet,
                noise_scheduler,
                start_timesteps=0,
                total_timesteps=timesteps_to)
            denoised_latents_low = denoised_latents_low.to(device, dtype=weight_dtype)
            low_noise = low_noise.to(device, dtype=weight_dtype)

            generator = torch.manual_seed(seed)
            denoised_latents_high, high_noise = train_util.get_noisy_image(
                img2,
                vae,
                generator,
                unet,
                noise_scheduler,
                start_timesteps=0,
                total_timesteps=timesteps_to)
            denoised_latents_high = denoised_latents_high.to(device, dtype=weight_dtype)
            high_noise = high_noise.to(device, dtype=weight_dtype)

            # Map the coarse step index onto the 1000-step schedule.
            noise_scheduler.set_timesteps(1000)

            current_timestep = noise_scheduler.timesteps[
                int(timesteps_to * 1000 / config.train.max_denoising_steps)
            ]

            # Outside `with network:` only an empty LoRA is in effect.
            # NOTE(review): high_latents/low_latents are only used for verbose
            # logging below; they cost two extra UNet forward passes per step.
            high_latents = train_util.predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents_high,
                train_util.concat_embeddings(
                    prompt_pair.unconditional,
                    prompt_pair.positive,
                    prompt_pair.batch_size,
                ),
                guidance_scale=1,
            ).to("cpu", dtype=torch.float32)
            # Outside `with network:` only an empty LoRA is in effect.
            low_latents = train_util.predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents_low,
                train_util.concat_embeddings(
                    prompt_pair.unconditional,
                    prompt_pair.unconditional,
                    prompt_pair.batch_size,
                ),
                guidance_scale=1,
            ).to("cpu", dtype=torch.float32)

        if config.logging.verbose:
            # These names exist in this scope (the old positive/neutral/
            # unconditional names were undefined and raised NameError).
            print("high_latents:", high_latents[0, 0, :5, :5])
            print("low_latents:", low_latents[0, 0, :5, :5])

        # "High" image, slider at +s, positive prompt.
        network.set_lora_slider(scale=scale_to_look)
        with network:
            target_latents_high = train_util.predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents_high,
                train_util.concat_embeddings(
                    prompt_pair.unconditional,
                    prompt_pair.positive,
                    prompt_pair.batch_size,
                ),
                guidance_scale=1,
            ).to("cpu", dtype=torch.float32)

        loss_high = criteria(target_latents_high, high_noise.cpu().to(torch.float32))
        pbar.set_description(f"Loss*1k: {loss_high.item()*1000:.4f}")
        loss_high.backward()

        # "Low" image, slider at -s, neutral prompt.
        network.set_lora_slider(scale=-scale_to_look)
        with network:
            target_latents_low = train_util.predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents_low,
                train_util.concat_embeddings(
                    prompt_pair.unconditional,
                    prompt_pair.neutral,
                    prompt_pair.batch_size,
                ),
                guidance_scale=1,
            ).to("cpu", dtype=torch.float32)

        loss_low = criteria(target_latents_low, low_noise.cpu().to(torch.float32))
        pbar.set_description(f"Loss*1k: {loss_low.item()*1000:.4f}")
        loss_low.backward()

        # No zero_grad between the two backward passes: gradients from both
        # directions accumulate into a single optimizer step, following
        # Ostris (https://github.com/ostris/ai-toolkit).
        optimizer.step()
        lr_scheduler.step()

        del (
            high_latents,
            low_latents,
            target_latents_low,
            target_latents_high,
        )
        flush()

        if (
            i % config.save.per_steps == 0
            and i != 0
            and i != config.train.iterations - 1
        ):
            print("Saving...")
            save_path.mkdir(parents=True, exist_ok=True)
            network.save_weights(
                save_path / f"{config.save.name}_{i}steps.pt",
                dtype=save_weight_dtype,
            )

    print("Saving...")
    save_path.mkdir(parents=True, exist_ok=True)
    network.save_weights(
        save_path / f"{config.save.name}_last.pt",
        dtype=save_weight_dtype,
    )

    del (
        unet,
        noise_scheduler,
        optimizer,
        network,
    )

    flush()

    print("Done.")
def main(args):
    """Build the run configuration from CLI ``args`` and launch training.

    Loads the YAML config and applies the CLI overrides (name, LoRA
    alpha/rank). It derives the run name and save path, parses the
    comma-separated ``--folders``/``--scales`` lists, then calls ``train``.
    With ``--stylecheck START-END`` it trains once per folder
    ``args.folder_main + str(i)`` for each ``i`` in ``range(START, END)``.

    Raises:
        ValueError: if ``--folders`` and ``--scales`` differ in length.
    """
    config_file = args.config_file

    config = config_util.load_config_from_yaml(config_file)
    if args.name is not None:
        config.save.name = args.name

    attributes = []
    if args.attributes is not None:
        attributes = [a.strip() for a in args.attributes.split(',')]

    config.network.alpha = args.alpha
    config.network.rank = args.rank
    config.save.name += f'_alpha{args.alpha}'
    config.save.name += f'_rank{config.network.rank}'
    config.save.name += f'_{config.network.training_method}'
    config.save.path += f'/{config.save.name}'

    prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)

    device = torch.device(f"cuda:{args.device}")

    folders = [f.strip() for f in args.folders.split(',')]
    scales = [int(s.strip()) for s in args.scales.split(',')]
    print(folders, scales)
    if len(scales) != len(folders):
        # ValueError subclasses Exception, so existing handlers still catch it.
        raise ValueError('the number of folders need to match the number of scales')

    if args.stylecheck is not None:
        check = args.stylecheck.split('-')
        for i in range(int(check[0]), int(check[1])):
            folder_main = args.folder_main + f'{i}'
            config.save.name = f'{os.path.basename(folder_main)}'
            config.save.name += f'_alpha{args.alpha}'
            config.save.name += f'_rank{config.network.rank}'
            config.save.path = f'models/{config.save.name}'
            # train() requires folders/scales; omitting them raised TypeError.
            train(
                config=config,
                prompts=prompts,
                device=device,
                folder_main=folder_main,
                folders=folders,
                scales=scales,
            )
    else:
        train(
            config=config,
            prompts=prompts,
            device=device,
            folder_main=args.folder_main,
            folders=folders,
            scales=scales,
        )
if __name__ == "__main__":
    # Command-line entry point: parse arguments and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        required=False,
        default='data/config.yaml',
        help="Config file for training.",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        required=True,
        help="LoRA weight.",
    )
    parser.add_argument(
        "--rank",
        type=int,
        required=False,
        help="Rank of LoRA.",
        default=4,
    )
    parser.add_argument(
        "--device",
        type=int,
        required=False,
        default=0,
        help="Device to train on.",
    )
    parser.add_argument(
        "--name",
        type=str,
        required=False,
        default=None,
        help="Run name; overrides config.save.name (alpha/rank/method suffixes are appended).",
    )
    parser.add_argument(
        "--attributes",
        type=str,
        required=False,
        default=None,
        help="attributes to disentangle",
    )
    parser.add_argument(
        "--folder_main",
        type=str,
        required=True,
        help="Directory containing the attribute-scaled image sub-folders.",
    )
    parser.add_argument(
        "--stylecheck",
        type=str,
        required=False,
        default=None,
        help="Range START-END; trains once per folder folder_main + str(i) for i in range(START, END).",
    )
    parser.add_argument(
        "--folders",
        type=str,
        required=False,
        default='verylow, low, high, veryhigh',
        help="folders with different attribute-scaled images",
    )
    parser.add_argument(
        "--scales",
        type=str,
        required=False,
        default='-2, -1,1, 2',
        help="scales for different attribute-scaled images",
    )

    args = parser.parse_args()

    main(args)