# ref:
# - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
# - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py

from typing import List, Optional
import argparse
import ast
from pathlib import Path
import gc

import torch
from tqdm import tqdm

from trainscripts.textsliders.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
import trainscripts.textsliders.train_util as train_util
import trainscripts.textsliders.model_util as model_util
import trainscripts.textsliders.prompt_util as prompt_util
from trainscripts.textsliders.prompt_util import (
    PromptEmbedsCache,
    PromptEmbedsPair,
    PromptSettings,
    PromptEmbedsXL,
)
import trainscripts.textsliders.debug_util as debug_util
import trainscripts.textsliders.config_util as config_util
from trainscripts.textsliders.config_util import RootConfig

import wandb

NUM_IMAGES_PER_PROMPT = 1


def flush():
    """Run the garbage collector, then release cached CUDA memory.

    Collecting first means blocks freed by the collector are already back in
    the allocator's cache when ``empty_cache()`` runs.
    """
    gc.collect()
    torch.cuda.empty_cache()


def _parse_optimizer_args(raw_args: Optional[str]) -> dict:
    """Turn a ``"key=value key2=value2"`` string into optimizer kwargs.

    Args:
        raw_args: Whitespace-separated ``key=value`` pairs, or ``None``/empty.

    Returns:
        A dict mapping each key to its value parsed with ``ast.literal_eval``.

    Raises:
        ValueError: If an entry has no ``=`` or its value is not a literal.
        SyntaxError: If a value is not valid literal syntax.
    """
    kwargs = {}
    if not raw_args:
        return kwargs
    # split() with no argument tolerates repeated/leading/trailing whitespace.
    for arg in raw_args.split():
        # Split on the first "=" only so values may themselves contain "=".
        key, value = arg.split("=", 1)
        kwargs[key] = ast.literal_eval(value)
    return kwargs


def _predict_noise(
    unet,
    noise_scheduler,
    timestep,
    latents,
    prompt_pair,
    conditional,
    add_time_ids,
    device,
    dtype,
):
    """Call ``train_util.predict_noise_xl`` for one conditional prompt.

    The unconditional embeddings of ``prompt_pair`` are paired with those of
    ``conditional`` via ``train_util.concat_embeddings``; guidance scale is
    fixed at 1 and the result is moved to ``device``/``dtype``.
    """
    batch_size = prompt_pair.batch_size
    return train_util.predict_noise_xl(
        unet,
        noise_scheduler,
        timestep,
        latents,
        text_embeddings=train_util.concat_embeddings(
            prompt_pair.unconditional.text_embeds,
            conditional.text_embeds,
            batch_size,
        ),
        add_text_embeddings=train_util.concat_embeddings(
            prompt_pair.unconditional.pooled_embeds,
            conditional.pooled_embeds,
            batch_size,
        ),
        add_time_ids=train_util.concat_embeddings(
            add_time_ids, add_time_ids, batch_size
        ),
        guidance_scale=1,
    ).to(device, dtype=dtype)


def train(
    config: RootConfig,
    prompts: list[PromptSettings],
    device,
):
    """Train a LoRA network on a frozen UNet and save its weights.

    Loads tokenizers, text encoders, UNet and noise scheduler through
    ``model_util.load_models_xl``, freezes the encoders and the UNet, wraps
    the UNet in a ``LoRANetwork``, caches the prompt embeddings, runs
    ``config.train.iterations`` optimisation steps and finally writes the
    weights to ``config.save.path / config.save.name``.

    Args:
        config: Root configuration (model, network, train, logging, save).
        prompts: Prompt settings; one ``PromptEmbedsPair`` is built per entry.
        device: Device the models and tensors are moved to.
    """
    metadata = {
        "prompts": ",".join([prompt.json() for prompt in prompts]),
        "config": config.json(),
    }
    save_path = Path(config.save.path)

    # NOTE(review): for a list, `+=` extends DEFAULT_TARGET_REPLACE in place,
    # and `modules` is not used anywhere else in this function. Presumably
    # LoRANetwork reads the module-level list, so this mutation may be what
    # enables the conv modules -- confirm against lora.py before refactoring.
    modules = DEFAULT_TARGET_REPLACE
    if config.network.type == "c3lier":
        modules += UNET_TARGET_REPLACE_MODULE_CONV

    if config.logging.verbose:
        print(metadata)

    if config.logging.use_wandb:
        wandb.init(project=f"LECO_{config.save.name}", config=metadata)

    weight_dtype = config_util.parse_precision(config.train.precision)
    # NOTE(review): parsed from the same field as weight_dtype; confirm
    # whether a separate save precision was intended.
    save_weight_dtype = config_util.parse_precision(config.train.precision)

    (
        tokenizers,
        text_encoders,
        unet,
        noise_scheduler,
    ) = model_util.load_models_xl(
        config.pretrained_model.name_or_path,
        scheduler_name=config.train.noise_scheduler,
    )

    # The text encoders are only used to pre-compute embeddings: freeze them.
    text_encoder = None
    for text_encoder in text_encoders:
        text_encoder.to(device, dtype=weight_dtype)
        text_encoder.requires_grad_(False)
        text_encoder.eval()

    # The UNet itself stays frozen; only the LoRA network is optimised.
    unet.to(device, dtype=weight_dtype)
    if config.other.use_xformers:
        unet.enable_xformers_memory_efficient_attention()
    unet.requires_grad_(False)
    unet.eval()

    network = LoRANetwork(
        unet,
        rank=config.network.rank,
        multiplier=1.0,
        alpha=config.network.alpha,
        train_method=config.network.training_method,
    ).to(device, dtype=weight_dtype)

    optimizer_module = train_util.get_optimizer(config.train.optimizer)
    optimizer_kwargs = _parse_optimizer_args(config.train.optimizer_args)
    optimizer = optimizer_module(
        network.prepare_optimizer_params(),
        lr=config.train.lr,
        **optimizer_kwargs,
    )
    lr_scheduler = train_util.get_lr_scheduler(
        config.train.lr_scheduler,
        optimizer,
        max_iterations=config.train.iterations,
        lr_min=config.train.lr / 100,
    )
    criteria = torch.nn.MSELoss()

    print("Prompts")
    for settings in prompts:
        print(settings)

    # debug
    debug_util.check_requires_grad(network)
    debug_util.check_training_mode(network)

    cache = PromptEmbedsCache()
    prompt_pairs: list[PromptEmbedsPair] = []

    # Encode each distinct prompt once and build one embeds pair per setting.
    with torch.no_grad():
        for settings in prompts:
            print(settings)
            for prompt in [
                settings.target,
                settings.positive,
                settings.neutral,
                settings.unconditional,
            ]:
                if cache[prompt] is None:
                    tex_embs, pool_embs = train_util.encode_prompts_xl(
                        tokenizers,
                        text_encoders,
                        [prompt],
                        num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
                    )
                    cache[prompt] = PromptEmbedsXL(tex_embs, pool_embs)

            prompt_pairs.append(
                PromptEmbedsPair(
                    criteria,
                    cache[settings.target],
                    cache[settings.positive],
                    cache[settings.unconditional],
                    cache[settings.neutral],
                    settings,
                )
            )

    # The tokenizers and encoders are not used past this point. Drop the
    # containers and the leftover loop variable so this function no longer
    # keeps the models alive; deleting per-iteration loop variables alone
    # would leave the lists holding them. Rebinding (instead of `del`) also
    # works when the encoder list is empty.
    tokenizers = text_encoders = text_encoder = None

    flush()

    pbar = tqdm(range(config.train.iterations))

    loss = None

    for i in pbar:
        with torch.no_grad():
            noise_scheduler.set_timesteps(
                config.train.max_denoising_steps, device=device
            )

            optimizer.zero_grad()

            prompt_pair: PromptEmbedsPair = prompt_pairs[
                torch.randint(0, len(prompt_pairs), (1,)).item()
            ]

            # Random step count in [1, max_denoising_steps)
            # (original note: "random from 1 to 49").
            timesteps_to = torch.randint(
                1, config.train.max_denoising_steps, (1,)
            ).item()

            height, width = prompt_pair.resolution, prompt_pair.resolution
            if prompt_pair.dynamic_resolution:
                height, width = train_util.get_random_resolution_in_bucket(
                    prompt_pair.resolution
                )

            if config.logging.verbose:
                print("gudance_scale:", prompt_pair.guidance_scale)
                print("resolution:", prompt_pair.resolution)
                print("dynamic_resolution:", prompt_pair.dynamic_resolution)
                if prompt_pair.dynamic_resolution:
                    print("bucketed resolution:", (height, width))
                print("batch_size:", prompt_pair.batch_size)
                print("dynamic_crops:", prompt_pair.dynamic_crops)

            latents = train_util.get_initial_latents(
                noise_scheduler, prompt_pair.batch_size, height, width, 1
            ).to(device, dtype=weight_dtype)

            add_time_ids = train_util.get_add_time_ids(
                height,
                width,
                dynamic_crops=prompt_pair.dynamic_crops,
                dtype=weight_dtype,
            ).to(device, dtype=weight_dtype)

            with network:
                # Returns latents that have been denoised a little.
                denoised_latents = train_util.diffusion_xl(
                    unet,
                    noise_scheduler,
                    latents,  # pass plain-noise latents
                    text_embeddings=train_util.concat_embeddings(
                        prompt_pair.unconditional.text_embeds,
                        prompt_pair.target.text_embeds,
                        prompt_pair.batch_size,
                    ),
                    add_text_embeddings=train_util.concat_embeddings(
                        prompt_pair.unconditional.pooled_embeds,
                        prompt_pair.target.pooled_embeds,
                        prompt_pair.batch_size,
                    ),
                    add_time_ids=train_util.concat_embeddings(
                        add_time_ids, add_time_ids, prompt_pair.batch_size
                    ),
                    start_timesteps=0,
                    total_timesteps=timesteps_to,
                    guidance_scale=3,
                )

            # Map the sampled step onto a 1000-step schedule.
            noise_scheduler.set_timesteps(1000)

            current_timestep = noise_scheduler.timesteps[
                int(timesteps_to * 1000 / config.train.max_denoising_steps)
            ]

            # Outside of `with network:` only the empty LoRA is in effect.
            positive_latents = _predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents,
                prompt_pair,
                prompt_pair.positive,
                add_time_ids,
                device,
                weight_dtype,
            )
            neutral_latents = _predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents,
                prompt_pair,
                prompt_pair.neutral,
                add_time_ids,
                device,
                weight_dtype,
            )
            unconditional_latents = _predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents,
                prompt_pair,
                prompt_pair.unconditional,
                add_time_ids,
                device,
                weight_dtype,
            )

            if config.logging.verbose:
                print("positive_latents:", positive_latents[0, 0, :5, :5])
                print("neutral_latents:", neutral_latents[0, 0, :5, :5])
                print("unconditional_latents:", unconditional_latents[0, 0, :5, :5])

        # The target prediction runs with the LoRA active and outside
        # no_grad, so the loss can be back-propagated through it.
        with network:
            target_latents = _predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents,
                prompt_pair,
                prompt_pair.target,
                add_time_ids,
                device,
                weight_dtype,
            )

            if config.logging.verbose:
                print("target_latents:", target_latents[0, 0, :5, :5])

        positive_latents.requires_grad = False
        neutral_latents.requires_grad = False
        unconditional_latents.requires_grad = False

        loss = prompt_pair.loss(
            target_latents=target_latents,
            positive_latents=positive_latents,
            neutral_latents=neutral_latents,
            unconditional_latents=unconditional_latents,
        )

        # Without the x1000 scaling the display stays at 0.000..., which is
        # not interesting to look at.
        loss_value = loss.item()
        pbar.set_description(f"Loss*1k: {loss_value*1000:.4f}")
        if config.logging.use_wandb:
            # Log the plain float rather than the live tensor.
            wandb.log(
                {
                    "loss": loss_value,
                    "iteration": i,
                    "lr": lr_scheduler.get_last_lr()[0],
                }
            )

        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        del (
            positive_latents,
            neutral_latents,
            unconditional_latents,
            target_latents,
            latents,
        )
        flush()

        # Periodic checkpointing (currently disabled):
        # if (
        #     i % config.save.per_steps == 0
        #     and i != 0
        #     and i != config.train.iterations - 1
        # ):
        #     print("Saving...")
        #     save_path.mkdir(parents=True, exist_ok=True)
        #     network.save_weights(
        #         save_path / f"{config.save.name}_{i}steps.pt",
        #         dtype=save_weight_dtype,
        #     )

    print("Saving...")
    save_path.mkdir(parents=True, exist_ok=True)
    network.save_weights(
        save_path / f"{config.save.name}",
        dtype=save_weight_dtype,
    )

    del (
        unet,
        noise_scheduler,
        loss,
        optimizer,
        network,
    )

    flush()

    print("Done.")


# def main(args):
#     config_file = args.config_file
#     config = config_util.load_config_from_yaml(config_file)
#     if args.name is not None:
#         config.save.name = args.name
#     attributes = []
#     if args.attributes is not None:
#         attributes = args.attributes.split(',')
#         attributes = [a.strip() for a in attributes]
#     config.network.alpha = args.alpha
#     config.network.rank = args.rank
#     config.save.name += f'_alpha{args.alpha}'
#     config.save.name += f'_rank{config.network.rank }'
#     config.save.name += f'_{config.network.training_method}'
#     config.save.path += f'/{config.save.name}'
#     prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)
#     device = torch.device(f"cuda:{args.device}")
#     train(config, prompts, device)


def train_xl(target, positive, negative, lr, iterations, config_file, rank, device, attributes, save_name):
    """Load a YAML config, override it from the arguments and run ``train``.

    Args:
        target: Forwarded to ``prompt_util.load_prompts_from_yaml``.
        positive: Forwarded to ``prompt_util.load_prompts_from_yaml``.
        negative: Forwarded to ``prompt_util.load_prompts_from_yaml``.
        lr: Learning rate; converted with ``float``.
        iterations: Number of training iterations; converted with ``int``.
        config_file: Path handed to ``config_util.load_config_from_yaml``.
        rank: LoRA rank; converted with ``int``.
        device: Anything accepted by ``torch.device``.
        attributes: Comma-separated attribute string, or ``None``.
        save_name: Stored as ``config.save.name`` and appended to
            ``config.save.path``.
    """
    config = config_util.load_config_from_yaml(config_file)

    # NOTE(review): the drawn value is never used. The call is kept only so
    # the global torch RNG stream seen by train() is unchanged; remove it if
    # that is not relied upon.
    torch.randint(1, 10000000, (1,)).item()

    config.save.name = save_name
    config.train.lr = float(lr)
    config.train.iterations = int(iterations)

    if attributes is not None:
        # Drop blank entries so "" or a trailing comma does not yield an
        # empty attribute.
        attributes = [a.strip() for a in attributes.split(',') if a.strip()]
    else:
        attributes = []

    config.network.alpha = 1.0
    config.network.rank = int(rank)
    config.save.path += f'/{config.save.name}'

    prompts = prompt_util.load_prompts_from_yaml(
        path=config.prompts_file,
        target=target,
        positive=positive,
        negative=negative,
        attributes=attributes,
    )
    device = torch.device(device)
    train(config, prompts, device)