#!/usr/bin/env python # coding=utf-8 # Copyright 2024 Hui Lu, Fang Dai, Siqiong Yao. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import logging import math import os import random import shutil from pathlib import Path from pynvml import * import accelerate import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint import transformers123 from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version from PIL import Image from torchvision import transforms from tqdm.auto import tqdm import transformers from transformers import AutoTokenizer, PretrainedConfig import tensorflow as tf tf.get_logger().setLevel('ERROR') from collections import Counter import diffusers_Tiger from diffusers_Tiger import ( AutoencoderKL, ControlNetModel, DDPMScheduler, StableDiffusionControlNetPipeline, StableDiffusionControlNetInpaintPipeline, UNet2DConditionModel, UniPCMultistepScheduler, DDIMScheduler ) from diffusers_Tiger.optimization import get_scheduler from diffusers_Tiger.utils import check_min_version, is_wandb_available from diffusers_Tiger.utils.import_utils import is_xformers_available from diffusers_Tiger import fuse if is_wandb_available(): import wandb import warnings warnings.filterwarnings('ignore') # Will error if the minimal version of diffusers123 is not installed. Remove at your own risks. check_min_version("0.19.0.dev0") logger = get_logger(__name__) def image_grid(imgs, rows, cols): assert len(imgs) == rows * cols w, h = imgs[0].sizeelerator grid = Image.new("RGB", size=(cols * w, rows * h)) for i, img in enumerate(imgs): grid.paste(img, box=(i % cols * w, i // cols * h)) return grid def make_inpaint_condition(image, image_mask): image = np.array(image.convert("RGB")).astype(np.float32) / 255.0 image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size" image[image_mask > 0.5] = -1.0 # set as masked pixel image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) image = torch.from_numpy(image) return image def log_validation(vae, text_encoder, tokenizer, unet, controlnet_nd, controlnet_bg, args, accelerator, weight_dtype, step): logger.info("Running validation... ") controlnet_nd = accelerator.unwrap_model(controlnet) pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, safety_checker=None, revision=args.revision, torch_dtype=weight_dtype, ) pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) if args.enable_xformers_memory_efficient_attention: pipeline.enable_xformers_memory_efficient_attention() if args.seed is None: generator = None else: generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if len(args.validation_image) == len(args.validation_prompt): validation_images = args.validation_image validation_prompts = args.validation_prompt elif len(args.validation_image) == 1: validation_images = args.validation_image * len(args.validation_prompt) validation_prompts = args.validation_prompt elif len(args.validation_prompt) == 1: validation_images = args.validation_image validation_prompts = args.validation_prompt * len(args.validation_image) else: raise ValueError( "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" ) image_logs = [] images = [] for validation_prompt, validation_image1 in zip(validation_prompts, validation_images): validation_image = Image.open(validation_image1).convert("RGB").resize((512, 512)) mask_image = Image.open(validation_image1).convert("RGB").resize((512, 512)) control_image = make_inpaint_condition(validation_image, mask_image) for _ in range(args.num_validation_images): with torch.autocast("cuda"): seed = random.randint(1,1000000) generator = torch.Generator(device='cuda').manual_seed(seed) image = pipeline( validation_prompt, num_inference_steps=50, generator=generator, eta=1.0, image=validation_image, mask_image=mask_image, control_image=control_image, guidance_scale = 7 ).images[0] images.append(image) image_logs.append( {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} ) for tracker in accelerator.trackers: if tracker.name == "tensorboard": for log in image_logs: images = log["images"] validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] formatted_images = [] formatted_images.append(np.asarray(validation_image)) for image in images: formatted_images.append(np.asarray(image)) formatted_images = np.stack(formatted_images) tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") elif tracker.name == "wandb": formatted_images = [] for log in image_logs: images = log["images"] validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) for image in images: image = wandb.Image(image, caption=validation_prompt) formatted_images.append(image) tracker.log({"validation": formatted_images}) else: logger.warn(f"image logging not implemented for {tracker.name}") return image_logs def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, ) model_class = text_encoder_config.architectures[0] if model_class == "CLIPTextModel": from transformers123 import CLIPTextModel return CLIPTextModel elif model_class == "RobertaSeriesModelWithTransformation": from diffusers123.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation return RobertaSeriesModelWithTransformation else: raise ValueError(f"{model_class} is not supported.") def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): img_str = "" if image_logs is not None: img_str = "You can find some example images below.\n" for i, log in enumerate(image_logs): images = log["images"] validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] validation_image.save(os.path.join(repo_folder, "image_control.png")) img_str += f"prompt: {validation_prompt}\n" images = [validation_image] + images image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) img_str += f"![images_{i})](./images_{i}.png)\n" yaml = f""" --- license: creativeml-openrail-m base_model: {base_model} tags: - stable-diffusion - stable-diffusion-diffusers - text-to-image - diffusers - controlnet inference: true --- """ model_card = f""" # controlnet-{repo_id} These are controlnet weights trained on {base_model} with new type of conditioning. {img_str} """ with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") parser.add_argument( "--pretrained_model_name_or_path", type=str, default=None, required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( "--controlnet_model_name_or_path", type=str, default=None, help="Path to pretrained controlnet model or model identifier from huggingface.co/models." " If not specified controlnet weights are initialized from unet.", ) parser.add_argument( "--revision", type=str, default=None, required=False, help=( "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" " float32 precision." ), ) parser.add_argument( "--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name", ) parser.add_argument( "--output_dir", type=str, default="controlnet-model", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( "--cache_dir", type=str, default="/export/home/daifang/Diffusion/own_code/dataset", help="The directory where the downloaded models and datasets will be stored.", ) parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument( "--resolution", type=int, default=512, help=( "The resolution for input images, all the images in the train/validation dataset will be resized to this" " resolution" ), ) parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( "--max_train_steps", type=int, default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( "--checkpointing_steps", type=int, default=500, help=( "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." "See https://huggingface.co/docs/diffusers123/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" "instructions." ), ) parser.add_argument( "--checkpoints_total_limit", type=int, default=None, help=("Max number of checkpoints to store."), ) parser.add_argument( "--resume_from_checkpoint", type=str, default=None, help=( "Whether training should be resumed from a previous checkpoint. Use a path saved by" ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' ), ) parser.add_argument( "--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) parser.add_argument( "--gradient_checkpointing", action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) parser.add_argument( "--learning_rate", type=float, default=5e-6, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--scale_lr", action="store_true", default=False, help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) parser.add_argument( "--lr_scheduler", type=str, default="constant", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup"]' ), ) parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) parser.add_argument( "--lr_num_cycles", type=int, default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") parser.add_argument( "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) parser.add_argument( "--dataloader_num_workers", type=int, default=0, help=( "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ), ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", type=str, default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) parser.add_argument( "--logging_dir", type=str, default="logs", help=( "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." ), ) parser.add_argument( "--allow_tf32", action="store_true", help=( "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) parser.add_argument( "--report_to", type=str, default="tensorboard", help=( 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) parser.add_argument( "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) parser.add_argument( "--set_grads_to_none", action="store_true", help=( "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" " behaviors, so disable this argument if it causes any problems. More info:" " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" ), ) parser.add_argument( "--dataset_name", type=str, default=None, help=( "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," " or to a folder containing files that 🤗 Datasets can understand." ), ) parser.add_argument( "--dataset_config_name", type=str, default=None, help="The config of the Dataset, leave as None if there's only one config.", ) parser.add_argument( "--train_data_dir", type=str, default=None, help=( "A folder containing the training data. Folder contents must follow the structure described in" " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." ), ) ############################################################################################################## parser.add_argument( "--image_column", type=str, default="image", help="The column of the dataset containing the target image." ) parser.add_argument( "--conditioning_nd_column", type=str, default="condition_nd", help="The column of the dataset containing the controlnet conditioning image.", ) parser.add_argument( "--conditioning_bg_column", type=str, default="condition_bg", help="The column of the dataset containing the controlnet conditioning image.", ) parser.add_argument( "--caption_column_nd", type=str, default="text_nd", help="The column of the dataset containing a caption or a list of captions.", ) parser.add_argument( "--caption_column_bg", type=str, default="text_nd", help="The column of the dataset containing a caption or a list of captions.", ) ############################################################################################################## parser.add_argument( "--max_train_samples", type=int, default=None, help=( "For debugging purposes or quicker training, truncate the number of training examples to this " "value if set." ), ) parser.add_argument( "--proportion_empty_prompts", type=float, default=0, help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", ) parser.add_argument( "--validation_prompt", type=str, default=None, nargs="+", help=( "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." " Provide either a matching number of `--validation_image`s, a single `--validation_image`" " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." ), ) parser.add_argument( "--validation_image", type=str, default=None, nargs="+", help=( "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" " `--validation_image` that will be used with all `--validation_prompt`s." ), ) parser.add_argument( "--num_validation_images", type=int, default=4, help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", ) parser.add_argument( "--validation_steps", type=int, default=100, help=( "Run validation every X steps. Validation consists of running the prompt" " `args.validation_prompt` multiple times: `args.num_validation_images`" " and logging the images." ), ) parser.add_argument( "--tracker_project_name", type=str, default="train_controlnet", help=( "The `project_name` argument passed to Accelerator.init_trackers for" " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" ), ) if input_args is not None: args = parser.parse_args(input_args) else: args = parser.parse_args() if args.dataset_name is None and args.train_data_dir is None: raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") if args.dataset_name is not None and args.train_data_dir is not None: raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") if args.validation_prompt is not None and args.validation_image is None: raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") if args.validation_prompt is None and args.validation_image is not None: raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") if ( args.validation_image is not None and args.validation_prompt is not None and len(args.validation_image) != 1 and len(args.validation_prompt) != 1 and len(args.validation_image) != len(args.validation_prompt) ): raise ValueError( "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," " or the same number of `--validation_prompt`s and `--validation_image`s" ) if args.resolution % 8 != 0: raise ValueError( "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." ) return args def make_train_dataset(args, tokenizer, accelerator): if args.dataset_name is not None: dataset = load_dataset( args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, ) else: if args.train_data_dir is not None: dataset = load_dataset( args.train_data_dir, cache_dir=args.cache_dir, ) column_names = dataset["train"].column_names ########################################################################################################################################################################## # Get the column names for input/target. # target image if args.image_column is None: image_column = column_names[0] logger.info(f"image column defaulting to {image_column}") else: image_column = args.image_column if image_column not in column_names: raise ValueError( f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" ) # condition nodule image if args.conditioning_nd_column is None: conditioning_nd_column = column_names[1] logger.info(f"conditioning image column defaulting to {conditioning_nd_column}") else: conditioning_nd_column = args.conditioning_nd_column if conditioning_nd_column not in column_names: raise ValueError( f"`--conditioning_nd_column` value '{args.conditioning_nd_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" ) # condition background image if args.conditioning_bg_column is None: conditioning_bg_column = column_names[2] logger.info(f"conditioning bg column defaulting to {conditioning_bg_column}") else: conditioning_bg_column = args.conditioning_bg_column if conditioning_bg_column not in column_names: raise ValueError( f"`--conditioning_bg_column` value '{args.conditioning_bg_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" ) # condition nodule text if args.caption_column_nd is None: caption_column_nd = column_names[3] logger.info(f"caption column defaulting to {caption_column_nd}") else: caption_column_nd = args.caption_column_nd if caption_column_nd not in column_names: raise ValueError( f"`--caption_column` value '{args.caption_column_nd}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" ) # condition backgrorund text if args.caption_column_bg is None: caption_column_bg = column_names[4] logger.info(f"caption column defaulting to {caption_column_bg}") else: caption_column_bg = args.caption_column_bg if caption_column_bg not in column_names: raise ValueError( f"`--caption_column` value '{args.caption_column_bg}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" ) ########################################################################################################################################################################## def tokenize_captions(examples, caption_column, names, is_train=True): captions = [] for caption in examples[caption_column]: if random.random() < args.proportion_empty_prompts: captions.append("") elif isinstance(caption, str): captions.append(caption) elif isinstance(caption, (list, np.ndarray)): # take a random caption if there are multiple captions.append(random.choice(caption) if is_train else caption[0]) else: raise ValueError( f"Caption column `{caption_column_nd}` should contain either strings or lists of strings." ) inputs = tokenizer( captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" ) def calculate_word_frequencies(phrases): total_counts = Counter() total_words = 0 for phrase in phrases: words = phrase.replace(',', '').split() total_counts.update(words) total_words += len(words) frequencies = {word: count / total_words for word, count in total_counts.items()} return frequencies, total_words def calculate_average_frequencies(phrases, word_frequencies): average_frequencies = [] for phrase in phrases: words = phrase.replace(',', '').split() total_freq = sum(word_frequencies[word] for word in words) avg_freq = total_freq / len(words) if words else 0 average_frequencies.append((phrase, avg_freq)) return average_frequencies if names == 'nd': word_frequencies, total_word_count = calculate_word_frequencies(captions) weight_matrix = calculate_average_frequencies(captions, word_frequencies) # Extract the values to replace values = [desc[1] for desc in weight_matrix] # Replace the first zero in each row with the corresponding value for i in range(inputs.input_ids.shape[0]): weight = int(values[i]*10**5) inputs.input_ids[i][0] = weight assert not torch.isnan(inputs.input_ids).any(), "inputs.input_ids contains NaN values" return inputs.input_ids image_transforms = transforms.Compose( [ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(args.resolution), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] ) conditioning_image_transforms = transforms.Compose( [ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(args.resolution), transforms.ToTensor(), ] ) def preprocess_train(examples): images = [image.convert("RGB") for image in examples[image_column]] images = [image_transforms(image) for image in images] conditioning_nd = [Image.open(image).convert("RGB") for image in examples[conditioning_nd_column]] conditioning_nd = [conditioning_image_transforms(image) for image in conditioning_nd] conditioning_bg = [Image.open(image).convert("RGB") for image in examples[conditioning_bg_column]] conditioning_bg = [conditioning_image_transforms(image) for image in conditioning_bg] examples["pixel_values"] = images examples["conditioning_pixel_values_nd"] = conditioning_nd examples["conditioning_pixel_values_bg"] = conditioning_bg examples["input_ids_nd"] = tokenize_captions(examples, caption_column = caption_column_nd, names = 'nd') examples["input_ids_bg"] = tokenize_captions(examples, caption_column = caption_column_bg, names = 'bg') return examples with accelerator.main_process_first(): if args.max_train_samples is not None: dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) # Set the training transforms train_dataset = dataset["train"].with_transform(preprocess_train) return train_dataset def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() conditioning_pixel_values_nd = torch.stack([example["conditioning_pixel_values_nd"] for example in examples]) conditioning_pixel_values_nd = conditioning_pixel_values_nd.to(memory_format=torch.contiguous_format).float() conditioning_pixel_values_bg = torch.stack([example["conditioning_pixel_values_bg"] for example in examples]) conditioning_pixel_values_bg = conditioning_pixel_values_bg.to(memory_format=torch.contiguous_format).float() input_ids_nd = torch.stack([example["input_ids_nd"] for example in examples]) input_ids_bg = torch.stack([example["input_ids_bg"] for example in examples]) return { "pixel_values": pixel_values, "conditioning_pixel_values_nd": conditioning_pixel_values_nd, "conditioning_pixel_values_bg": conditioning_pixel_values_bg, "input_ids_nd": input_ids_nd, "input_ids_bg": input_ids_bg, } def main(args): logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, project_config=accelerator_project_config, ) # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: transformers.utils.logging.set_verbosity_warning() diffusers_Tiger.utils.logging.set_verbosity_info() else: transformers.utils.logging.set_verbosity_error() diffusers_Tiger.utils.logging.set_verbosity_error() # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed) # Handle the repository creation if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) # Load the tokenizer if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) elif args.pretrained_model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False, ) # import correct text encoder class text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) if args.controlnet_model_name_or_path: logger.info("Loading existing controlnet weights") controlnet_nd = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path) controlnet_bg = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path) else: logger.info("Initializing controlnet weights from unet") controlnet_nd = ControlNetModel.from_unet(unet) controlnet_bg = ControlNetModel.from_unet(unet) if version.parse(accelerate.__version__) >= version.parse("0.16.0"): def save_model_hook(models, weights, output_dir): weights.pop() model1 = models[0] sub_dir = "controlnet_nd" model1.save_pretrained(os.path.join(output_dir, sub_dir)) def load_model_hook(models, input_dir): while len(models) > 0: # pop models so that they are not loaded again model = models.pop() # load diffusers123 style into model load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) del load_model accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) vae.requires_grad_(False) unet.requires_grad_(False) text_encoder.requires_grad_(False) controlnet_nd.requires_grad_(True).train() controlnet_bg.requires_grad_(True).train() if args.gradient_checkpointing: controlnet_nd.enable_gradient_checkpointing() controlnet_bg.enable_gradient_checkpointing() # Check that all trainable models are in full precision low_precision_error_string = ( " Please make sure to always have all model weights in full float32 precision when starting training - even if" " doing mixed precision training, copy of the weights should still be float32." ) if accelerator.unwrap_model(controlnet_nd).dtype != torch.float32: raise ValueError( f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet_nd).dtype}. {low_precision_error_string}" ) if accelerator.unwrap_model(controlnet_bg).dtype != torch.float32: raise ValueError( f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet_bg).dtype}. {low_precision_error_string}" ) # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: raise ImportError( "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." ) optimizer_class = bnb.optim.AdamW8bit else: optimizer_class = torch.optim.AdamW # Optimizer creation params_to_optimize_nd = controlnet_nd.parameters() params_to_optimize_bg = controlnet_bg.parameters() optimizer_nd = optimizer_class( params_to_optimize_nd, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) optimizer_bg = optimizer_class( params_to_optimize_bg, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) train_dataset = make_train_dataset(args, tokenizer, accelerator) train_dataloader = torch.utils.data.DataLoader( train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers, ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer_nd, num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, num_training_steps=args.max_train_steps * accelerator.num_processes, num_cycles=args.lr_num_cycles, power=args.lr_power) # Prepare everything with our `accelerator`. controlnet_nd, controlnet_bg, optimizer_nd, optimizer_bg, train_dataloader, lr_scheduler = accelerator.prepare( controlnet_nd, controlnet_bg, optimizer_nd, optimizer_bg, train_dataloader, lr_scheduler ) # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Move vae, unet and text_encoder to device and cast to weight_dtype vae.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: tracker_config = dict(vars(args)) # tensorboard cannot handle list types for config tracker_config.pop("validation_prompt") tracker_config.pop("validation_image") accelerator.init_trackers(args.tracker_project_name, config=tracker_config) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 first_epoch = 0 # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: # Get the most recent checkpoint dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None if path is None: accelerator.print( f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch else: initial_global_step = 0 progress_bar = tqdm( range(0, args.max_train_steps), initial=initial_global_step, desc="Steps", # Only show the progress bar once on each machine. disable=not accelerator.is_local_main_process, ) image_logs = None for epoch in range(first_epoch, args.num_train_epochs): for step, batch in enumerate(train_dataloader): # with accelerator.accumulate(controlnet_nd): # Convert images to latent space latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning weight_nd = batch["input_ids_nd"][:, 0] weight_nd = weight_nd / 10**5 batch["input_ids_nd"][:, 0] = 49406 encoder_hidden_states_nd = text_encoder(batch["input_ids_nd"])[0] encoder_hidden_states_bg = text_encoder(batch["input_ids_bg"])[0] controlnet_image_nd = batch["conditioning_pixel_values_nd"].to(dtype=weight_dtype) controlnet_image_bg = batch["conditioning_pixel_values_bg"].to(dtype=weight_dtype) # print(weight_nd) down_block_res_samples_nd, mid_block_res_sample_nd = controlnet_nd( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states_nd, # text controlnet_cond=controlnet_image_nd, return_dict=False, weight=weight_nd) down_block_res_samples_bg, mid_block_res_sample_bg = controlnet_bg( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states_bg, # text controlnet_cond=controlnet_image_bg, return_dict=False) # Predict the noise residual samples_nd_list, samples_bg_list = [], [] for number in range(len(down_block_res_samples_nd)): if number > 1 : sample = down_block_res_samples_nd[number] samples_nd = torch.stack((down_block_res_samples_nd[number][0].to('cpu'), \ down_block_res_samples_nd[number][0].to('cpu'))) samples_bg = torch.stack((down_block_res_samples_bg[number][0].to('cpu'), \ down_block_res_samples_bg[number][0].to('cpu'))) channels = sample.shape[1] model_fuse_down = fuse.AFF(channels=channels).to(device='cpu') output = model_fuse_down(samples_nd, samples_bg)[0].unsqueeze(0) samples_nd_list.append(output) samples_bg_list.append(output) else: samples_nd_list.append(down_block_res_samples_nd[number]) samples_bg_list.append(down_block_res_samples_bg[number]) mid_block_res_sample = mid_block_res_sample_bg + mid_block_res_sample_nd model_pred_nd = unet( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states_nd.to('cuda'), down_block_additional_residuals=[ sample.to(dtype=weight_dtype).to('cuda') for sample in samples_nd_list], mid_block_additional_residual=mid_block_res_sample.to('cuda').to(dtype=weight_dtype), ).sample model_pred_bg = unet( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states_bg.to('cuda'), down_block_additional_residuals=[ sample.to(dtype=weight_dtype).to('cuda') for sample in samples_bg_list], mid_block_additional_residual=mid_block_res_sample.to('cuda').to(dtype=weight_dtype), ).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": # use target = noise_scheduler.get_velocity(latents, noise, timesteps) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") loss_nd = F.mse_loss(model_pred_nd.to('cuda').float(), target.float(), reduction="mean") loss_bg = F.mse_loss(model_pred_bg.to('cuda').float(), target.float(), reduction="mean") optimizer_nd.zero_grad(set_to_none=args.set_grads_to_none) optimizer_bg.zero_grad(set_to_none=args.set_grads_to_none) # h0, h1 = nvmlDeviceGetHandleByIndex(0), nvmlDeviceGetHandleByIndex(1) # info0, info1 = nvmlDeviceGetMemoryInfo(h0), nvmlDeviceGetMemoryInfo(h1) # print(f'0free : {info0.free} 1free : {info1.free}') loss = loss_nd + loss_bg accelerator.backward(loss) # loss_nd.backward() # loss_bg.backward() # if accelerator.sync_gradients: # params_to_clip_nd = controlnet_nd.parameters() # accelerator.clip_grad_norm_(params_to_clip_nd, args.max_grad_norm) # params_to_clip_bg = controlnet_bg.parameters() # accelerator.clip_grad_norm_(params_to_clip_bg, args.max_grad_norm) optimizer_nd.step() optimizer_bg.step() lr_scheduler.step() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 if accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) shutil.rmtree(removing_checkpoint) save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") # if args.validation_prompt is not None : # image_logs = log_validation( # vae, # text_encoder, # tokenizer, # unet, # controlnet_nd, # controlnet_bg, # args, # accelerator, # weight_dtype, # global_step, # ) logs = {"loss": loss.detach().item()} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break # Create the pipeline using using the trained modules and save it. # accelerator.wait_for_everyone() if accelerator.is_main_process: controlnet_nd = accelerator.unwrap_model(controlnet_nd) controlnet_nd.save_pretrained(args.output_dir) controlnet_bg = accelerator.unwrap_model(controlnet_bg) controlnet_bg.save_pretrained(args.output_dir) accelerator.end_training() if __name__ == "__main__": args = parse_args() main(args)