# Hosted as a Hugging Face Space (running on an A10G GPU).
# Train an SDXL LoRA "slider" from folders of images rendered at different
# attribute scales.
#
# ref:
# - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
# - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py
import argparse
import ast
import gc
import os
import random
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
import wandb
from PIL import Image
from tqdm import tqdm

import config_util
import debug_util
import model_util
import prompt_util
import train_util
from config_util import RootConfig
from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
from prompt_util import (
    PromptEmbedsCache,
    PromptEmbedsPair,
    PromptSettings,
    PromptEmbedsXL,
)

# Passed as `num_images_per_prompt` when prompts are encoded.
NUM_IMAGES_PER_PROMPT = 1
def flush():
    """Free unreachable Python objects, then release cached CUDA memory.

    ``gc.collect()`` runs first so that tensors kept alive only by reference
    cycles are returned to PyTorch's caching allocator *before*
    ``torch.cuda.empty_cache()`` hands unused cached blocks back to the
    driver. In the opposite order, memory freed by the collection would stay
    cached until the next call.
    """
    gc.collect()
    torch.cuda.empty_cache()
def _predict_noise(
    unet,
    noise_scheduler,
    timestep,
    latents,
    prompt_pair: PromptEmbedsPair,
    conditional,
    add_time_ids,
    device,
):
    """Predict noise for ``latents`` with [unconditional, ``conditional``] embeddings.

    Concatenates ``prompt_pair.unconditional`` with ``conditional`` (text and
    pooled embeddings) and duplicates ``add_time_ids``, each via
    ``train_util.concat_embeddings`` with ``prompt_pair.batch_size``. It then
    calls ``train_util.predict_noise_xl`` with ``guidance_scale=1`` and returns
    the prediction as float32 on ``device``.
    """
    return train_util.predict_noise_xl(
        unet,
        noise_scheduler,
        timestep,
        latents,
        text_embeddings=train_util.concat_embeddings(
            prompt_pair.unconditional.text_embeds,
            conditional.text_embeds,
            prompt_pair.batch_size,
        ),
        add_text_embeddings=train_util.concat_embeddings(
            prompt_pair.unconditional.pooled_embeds,
            conditional.pooled_embeds,
            prompt_pair.batch_size,
        ),
        add_time_ids=train_util.concat_embeddings(
            add_time_ids, add_time_ids, prompt_pair.batch_size
        ),
        guidance_scale=1,
    ).to(device, dtype=torch.float32)


def _encode_prompt_pairs(prompts, tokenizers, text_encoders, criteria):
    """Encode each distinct prompt once and build one PromptEmbedsPair per setting."""
    cache = PromptEmbedsCache()
    prompt_pairs: list[PromptEmbedsPair] = []
    with torch.no_grad():
        for settings in prompts:
            print(settings)
            for prompt in [
                settings.target,
                settings.positive,
                settings.neutral,
                settings.unconditional,
            ]:
                if cache[prompt] is None:
                    tex_embs, pool_embs = train_util.encode_prompts_xl(
                        tokenizers,
                        text_encoders,
                        [prompt],
                        num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
                    )
                    cache[prompt] = PromptEmbedsXL(tex_embs, pool_embs)
            prompt_pairs.append(
                PromptEmbedsPair(
                    criteria,
                    cache[settings.target],
                    cache[settings.positive],
                    cache[settings.unconditional],
                    cache[settings.neutral],
                    settings,
                )
            )
    return prompt_pairs


def _sample_image_pair(folder_main, folder_low, folder_high):
    """Pick a random image file in ``folder_low`` and load it and its same-named twin.

    Both images are resized to 512x512. The file name is chosen from
    ``folder_low`` only; the same name is assumed to exist in ``folder_high``.

    Raises:
        ValueError: if ``folder_low`` contains no .png/.jpg/.jpeg/.webp files.
    """
    names = [
        name
        for name in os.listdir(f'{folder_main}/{folder_low}/')
        if '.png' in name or '.jpg' in name or '.jpeg' in name or '.webp' in name
    ]
    if not names:
        raise ValueError(
            f"no .png/.jpg/.jpeg/.webp images found in {folder_main}/{folder_low}/"
        )
    picked = names[random.randint(0, len(names) - 1)]
    img_low = Image.open(f'{folder_main}/{folder_low}/{picked}').resize((512, 512))
    img_high = Image.open(f'{folder_main}/{folder_high}/{picked}').resize((512, 512))
    return img_low, img_high


def train(
    config: RootConfig,
    prompts: list[PromptSettings],
    device,
    folder_main: str,
    folders,
    scales,
):
    """Train a LoRA slider on an SDXL UNet from paired, attribute-scaled images.

    Each iteration proceeds as follows:
    1. Pick a scale magnitude ``s`` from ``scales``.
    2. Load a random image from the folder whose scale is ``-s`` ("low") and
       the same-named image from the folder whose scale is ``+s`` ("high").
    3. Noise both with ``train_util.get_noisy_image`` under the same seed.
    4. With the slider at ``+s``, minimise the MSE between the LoRA-enabled
       prediction on the high latents and the high noise.
    5. With the slider at ``-s``, do the same for the low latents and the
       low noise.

    Weights are saved every ``config.save.per_steps`` steps (skipping step 0
    and the last step), then once more as ``<name>_last.pt``.

    Args:
        config: Root training config (model, network, optimizer, save, logging).
        prompts: Prompt settings; each is encoded once into a PromptEmbedsPair.
        device: Torch device the models and tensors are moved to.
        folder_main: Root directory containing the image sub-folders.
        folders: Sub-folder names, matched to ``scales`` by position.
        scales: Slider scale of each folder. Every magnitude must appear with
            both signs, e.g. ``[-2, -1, 1, 2]``.

    Raises:
        ValueError: if ``folders`` and ``scales`` differ in length, ``scales`` is
            empty, or some magnitude lacks its opposite-signed counterpart.
    """
    scales = np.array(scales)
    folders = np.array(folders)
    if folders.shape != scales.shape:
        raise ValueError("folders and scales must have the same length")
    if scales.size == 0:
        raise ValueError("at least one scale is required")
    # Each step looks up folders for both -|s| and +|s|; check up front rather
    # than failing with IndexError the first time an unpaired |s| is sampled.
    for magnitude in np.unique(np.abs(scales)):
        if not (np.any(scales == magnitude) and np.any(scales == -magnitude)):
            raise ValueError(
                f"scale {magnitude} needs both {magnitude} and {-magnitude} in scales"
            )
    # Sampling pool for the scale magnitude (duplicates weight the draw).
    scale_choices = list(scales)

    metadata = {
        "prompts": ",".join([prompt.json() for prompt in prompts]),
        "config": config.json(),
    }
    save_path = Path(config.save.path)

    # NOTE(review): `modules` is never read below. If DEFAULT_TARGET_REPLACE is
    # a list, `+=` extends the object imported from `lora` in place (again on
    # every call to `train`). Left as-is in case LoRANetwork depends on that
    # shared list -- confirm before removing.
    modules = DEFAULT_TARGET_REPLACE
    if config.network.type == "c3lier":
        modules += UNET_TARGET_REPLACE_MODULE_CONV

    if config.logging.verbose:
        print(metadata)

    if config.logging.use_wandb:
        wandb.init(project=f"LECO_{config.save.name}", config=metadata)

    weight_dtype = config_util.parse_precision(config.train.precision)
    save_weight_dtype = config_util.parse_precision(config.train.precision)

    (
        tokenizers,
        text_encoders,
        unet,
        noise_scheduler,
        vae,
    ) = model_util.load_models_xl(
        config.pretrained_model.name_or_path,
        scheduler_name=config.train.noise_scheduler,
    )

    # All pretrained models are frozen; only the LoRA network is trained.
    for text_encoder in text_encoders:
        text_encoder.to(device, dtype=weight_dtype)
        text_encoder.requires_grad_(False)
        text_encoder.eval()

    unet.to(device, dtype=weight_dtype)
    if config.other.use_xformers:
        unet.enable_xformers_memory_efficient_attention()
    unet.requires_grad_(False)
    unet.eval()

    vae.to(device)
    vae.requires_grad_(False)
    vae.eval()

    network = LoRANetwork(
        unet,
        rank=config.network.rank,
        multiplier=1.0,
        alpha=config.network.alpha,
        train_method=config.network.training_method,
    ).to(device, dtype=weight_dtype)

    optimizer_module = train_util.get_optimizer(config.train.optimizer)
    # optimizer_args is a whitespace-separated list of key=value pairs; values
    # are parsed with ast.literal_eval (Python literals only, no code execution).
    optimizer_kwargs = {}
    if config.train.optimizer_args:
        for arg in config.train.optimizer_args.split():
            key, value = arg.split("=", 1)
            optimizer_kwargs[key] = ast.literal_eval(value)

    optimizer = optimizer_module(
        network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs
    )
    lr_scheduler = train_util.get_lr_scheduler(
        config.train.lr_scheduler,
        optimizer,
        max_iterations=config.train.iterations,
        lr_min=config.train.lr / 100,
    )
    criteria = torch.nn.MSELoss()

    print("Prompts")
    for settings in prompts:
        print(settings)

    # debug
    debug_util.check_requires_grad(network)
    debug_util.check_training_mode(network)

    prompt_pairs = _encode_prompt_pairs(prompts, tokenizers, text_encoders, criteria)

    # The text encoders are not needed after encoding. Drop every reference
    # (both lists and the loop variable above) so flush() can actually free
    # them; deleting only the loop variables would leave them alive.
    tokenizers = text_encoders = text_encoder = None
    flush()

    pbar = tqdm(range(config.train.iterations))
    loss = None

    for i in pbar:
        with torch.no_grad():
            noise_scheduler.set_timesteps(
                config.train.max_denoising_steps, device=device
            )

            optimizer.zero_grad()

            prompt_pair: PromptEmbedsPair = prompt_pairs[
                torch.randint(0, len(prompt_pairs), (1,)).item()
            ]

            # Random step count in [1, max_denoising_steps), e.g. 1..49 for 50.
            timesteps_to = torch.randint(
                1, config.train.max_denoising_steps, (1,)
            ).item()

            height, width = prompt_pair.resolution, prompt_pair.resolution
            if prompt_pair.dynamic_resolution:
                height, width = train_util.get_random_resolution_in_bucket(
                    prompt_pair.resolution
                )

            if config.logging.verbose:
                print("guidance_scale:", prompt_pair.guidance_scale)
                print("resolution:", prompt_pair.resolution)
                print("dynamic_resolution:", prompt_pair.dynamic_resolution)
                if prompt_pair.dynamic_resolution:
                    print("bucketed resolution:", (height, width))
                print("batch_size:", prompt_pair.batch_size)
                print("dynamic_crops:", prompt_pair.dynamic_crops)

            # "low" image comes from the -s folder, "high" from the +s folder.
            scale_to_look = abs(random.choice(scale_choices))
            folder1 = folders[scales == -scale_to_look][0]
            folder2 = folders[scales == scale_to_look][0]
            img1, img2 = _sample_image_pair(folder_main, folder1, folder2)

            # Seed in [0, 2**15]; `2*15` would allow only 31 distinct seeds.
            seed = random.randint(0, 2**15)

            # Both images are noised after re-seeding with the same seed.
            generator = torch.manual_seed(seed)
            denoised_latents_low, low_noise = train_util.get_noisy_image(
                img1,
                vae,
                generator,
                unet,
                noise_scheduler,
                start_timesteps=0,
                total_timesteps=timesteps_to,
            )
            denoised_latents_low = denoised_latents_low.to(device, dtype=weight_dtype)
            low_noise = low_noise.to(device, dtype=weight_dtype)

            generator = torch.manual_seed(seed)
            denoised_latents_high, high_noise = train_util.get_noisy_image(
                img2,
                vae,
                generator,
                unet,
                noise_scheduler,
                start_timesteps=0,
                total_timesteps=timesteps_to,
            )
            denoised_latents_high = denoised_latents_high.to(device, dtype=weight_dtype)
            high_noise = high_noise.to(device, dtype=weight_dtype)

            noise_scheduler.set_timesteps(1000)

            add_time_ids = train_util.get_add_time_ids(
                height,
                width,
                dynamic_crops=prompt_pair.dynamic_crops,
                dtype=weight_dtype,
            ).to(device, dtype=weight_dtype)

            # Map the sampled step onto the 1000-step schedule.
            current_timestep = noise_scheduler.timesteps[
                int(timesteps_to * 1000 / config.train.max_denoising_steps)
            ]

            # Outside `with network:` only the empty LoRA is active.
            try:
                high_latents = _predict_noise(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents_high,
                    prompt_pair,
                    prompt_pair.positive,
                    add_time_ids,
                    device,
                )
            except Exception as err:
                # Skip this sample (e.g. an image that fails the forward pass)
                # but keep Ctrl-C working and report the underlying cause.
                flush()
                print(f'Error Occured!: {np.array(img1).shape} {np.array(img2).shape}')
                print(f'  caused by {type(err).__name__}: {err}')
                continue

            low_latents = _predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents_low,
                prompt_pair,
                prompt_pair.neutral,
                add_time_ids,
                device,
            )

        if config.logging.verbose:
            print("high_latents:", high_latents[0, 0, :5, :5])
            print("low_latents:", low_latents[0, 0, :5, :5])

        network.set_lora_slider(scale=scale_to_look)
        with network:
            target_latents_high = _predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents_high,
                prompt_pair,
                prompt_pair.positive,
                add_time_ids,
                device,
            )

        high_latents.requires_grad = False
        low_latents.requires_grad = False

        loss_high = criteria(target_latents_high, high_noise.to(torch.float32))
        pbar.set_description(f"Loss*1k: {loss_high.item()*1000:.4f}")
        loss_high.backward()

        # opposite direction
        network.set_lora_slider(scale=-scale_to_look)
        with network:
            target_latents_low = _predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents_low,
                prompt_pair,
                prompt_pair.neutral,
                add_time_ids,
                device,
            )

        loss_low = criteria(target_latents_low, low_noise.to(torch.float32))
        pbar.set_description(f"Loss*1k: {loss_low.item()*1000:.4f}")
        loss_low.backward()

        # zero_grad() only runs at the top of the iteration, so the gradients
        # of both slider directions accumulate into this single step.
        optimizer.step()
        lr_scheduler.step()

        del (
            high_latents,
            low_latents,
            target_latents_low,
            target_latents_high,
        )
        flush()

        if (
            i % config.save.per_steps == 0
            and i != 0
            and i != config.train.iterations - 1
        ):
            print("Saving...")
            save_path.mkdir(parents=True, exist_ok=True)
            network.save_weights(
                save_path / f"{config.save.name}_{i}steps.pt",
                dtype=save_weight_dtype,
            )

    print("Saving...")
    save_path.mkdir(parents=True, exist_ok=True)
    network.save_weights(
        save_path / f"{config.save.name}_last.pt",
        dtype=save_weight_dtype,
    )

    del (
        unet,
        noise_scheduler,
        loss,
        optimizer,
        network,
    )

    flush()

    print("Done.")
def main(args):
    """Build the run configuration from parsed CLI ``args`` and launch ``train``.

    The YAML config named by ``args.config_file`` is loaded and overridden with
    the CLI values:
    - the save name (if ``--name`` was given);
    - the LoRA alpha and rank;
    - a save name suffixed with alpha, rank and training method;
    - a save path extended by that name.

    ``args.folders`` and ``args.scales`` are comma-separated lists of equal
    length. When ``args.stylecheck`` is ``"start-end"``, ``train`` is called
    once per ``args.folder_main + str(i)`` for ``i`` in ``range(start, end)``,
    each saved under ``models/<basename>_alpha<alpha>_rank<rank>``.
    """
    config = config_util.load_config_from_yaml(args.config_file)
    if args.name is not None:
        config.save.name = args.name

    if args.attributes is None:
        attributes = []
    else:
        attributes = [attr.strip() for attr in args.attributes.split(',')]

    config.network.alpha = args.alpha
    config.network.rank = args.rank
    name_suffix = (
        f'_alpha{args.alpha}'
        f'_rank{config.network.rank}'
        f'_{config.network.training_method}'
    )
    config.save.name += name_suffix
    config.save.path += f'/{config.save.name}'

    prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)
    device = torch.device(f"cuda:{args.device}")

    folders = [entry.strip() for entry in args.folders.split(',')]
    scales = [int(entry.strip()) for entry in args.scales.split(',')]
    print(folders, scales)
    if len(folders) != len(scales):
        raise Exception('the number of folders need to match the number of scales')

    if args.stylecheck is None:
        train(
            config=config,
            prompts=prompts,
            device=device,
            folder_main=args.folder_main,
            folders=folders,
            scales=scales,
        )
        return

    # "start-end": train once per numbered variant of folder_main.
    bounds = args.stylecheck.split('-')
    for index in range(int(bounds[0]), int(bounds[1])):
        variant_folder = args.folder_main + f'{index}'
        config.save.name = (
            f'{os.path.basename(variant_folder)}'
            f'_alpha{args.alpha}'
            f'_rank{config.network.rank}'
        )
        config.save.path = f'models/{config.save.name}'
        train(
            config=config,
            prompts=prompts,
            device=device,
            folder_main=variant_folder,
            folders=folders,
            scales=scales,
        )
def build_arg_parser():
    """Return the command-line parser for this training script.

    Defaults: ``--rank`` 4, ``--device`` 0, ``--folders``
    ``'verylow, low, high, veryhigh'`` and ``--scales`` ``'-2, -1, 1, 2'``
    (comma-separated, matched to each other by position).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        required=True,
        help="Config file for training.",
    )
    # e.g. --config_file 'data/config.yaml'
    parser.add_argument(
        "--alpha",
        type=float,
        required=True,
        help="LoRA alpha (stored as network.alpha in the config).",
    )
    # e.g. --alpha 1.0
    parser.add_argument(
        "--rank",
        type=int,
        required=False,
        help="Rank of LoRA.",
        default=4,
    )
    # e.g. --rank 4
    parser.add_argument(
        "--device",
        type=int,
        required=False,
        default=0,
        help="Device to train on.",
    )
    # e.g. --device 0
    parser.add_argument(
        "--name",
        type=str,
        required=False,
        default=None,
        help="Base name for saved weights (overrides save.name in the config).",
    )
    # e.g. --name 'eyesize_slider'
    parser.add_argument(
        "--attributes",
        type=str,
        required=False,
        default=None,
        help="attributes to disentangle (comma separated string)",
    )
    parser.add_argument(
        "--folder_main",
        type=str,
        required=True,
        help="Root folder containing the attribute-scaled image sub-folders.",
    )
    parser.add_argument(
        "--stylecheck",
        type=str,
        required=False,
        default=None,
        help="Optional 'start-end' range: train once per folder_main<i> for i in [start, end).",
    )
    parser.add_argument(
        "--folders",
        type=str,
        required=False,
        default='verylow, low, high, veryhigh',
        help="folders with different attribute-scaled images",
    )
    parser.add_argument(
        "--scales",
        type=str,
        required=False,
        default='-2, -1, 1, 2',
        help="scales for different attribute-scaled images",
    )
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())