#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 Hui Lu, Fang Dai, Siqiong Yao.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
import os
import random
import shutil
from pathlib import Path
from pynvml import *
import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
import transformers
from transformers import AutoTokenizer, PretrainedConfig
import tensorflow as tf
tf.get_logger().setLevel('ERROR')
from collections import Counter
import diffusers_Tiger
from diffusers_Tiger import (
AutoencoderKL,
ControlNetModel,
DDPMScheduler,
StableDiffusionControlNetPipeline,
StableDiffusionControlNetInpaintPipeline,
UNet2DConditionModel,
UniPCMultistepScheduler,
DDIMScheduler
)
from diffusers_Tiger.optimization import get_scheduler
from diffusers_Tiger.utils import check_min_version, is_wandb_available
from diffusers_Tiger.utils.import_utils import is_xformers_available
from diffusers_Tiger import fuse
if is_wandb_available():
import wandb
import warnings
warnings.filterwarnings('ignore')
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.19.0.dev0")
logger = get_logger(__name__)
def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols
    w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
def make_inpaint_condition(image, image_mask):
image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
    assert image.shape[0:2] == image_mask.shape[0:2], "image and image_mask must have the same image size"
image[image_mask > 0.5] = -1.0 # set as masked pixel
image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return image
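# A minimal usage sketch for the helper above (file names are hypothetical):
# masked pixels are flagged with -1.0 so they are distinguishable from valid
# pixels in [0, 1] once the tensor reaches the controlnet.
def _example_make_inpaint_condition(image_path="ct_slice.png", mask_path="ct_mask.png"):
    image = Image.open(image_path).convert("RGB").resize((512, 512))
    mask = Image.open(mask_path).convert("L").resize((512, 512))
    control = make_inpaint_condition(image, mask)
    # control is a torch.FloatTensor of shape (1, 3, 512, 512)
    return control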
def log_validation(vae, text_encoder, tokenizer, unet, controlnet_nd, controlnet_bg, args, accelerator, weight_dtype, step):
logger.info("Running validation... ")
    controlnet_nd = accelerator.unwrap_model(controlnet_nd)
pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
        controlnet=controlnet_nd,
safety_checker=None,
revision=args.revision,
torch_dtype=weight_dtype,
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
if args.enable_xformers_memory_efficient_attention:
pipeline.enable_xformers_memory_efficient_attention()
if args.seed is None:
generator = None
else:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
if len(args.validation_image) == len(args.validation_prompt):
validation_images = args.validation_image
validation_prompts = args.validation_prompt
elif len(args.validation_image) == 1:
validation_images = args.validation_image * len(args.validation_prompt)
validation_prompts = args.validation_prompt
elif len(args.validation_prompt) == 1:
validation_images = args.validation_image
validation_prompts = args.validation_prompt * len(args.validation_image)
else:
raise ValueError(
"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
)
    image_logs = []
    for validation_prompt, validation_image1 in zip(validation_prompts, validation_images):
        images = []
validation_image = Image.open(validation_image1).convert("RGB").resize((512, 512))
mask_image = Image.open(validation_image1).convert("RGB").resize((512, 512))
control_image = make_inpaint_condition(validation_image, mask_image)
for _ in range(args.num_validation_images):
with torch.autocast("cuda"):
seed = random.randint(1,1000000)
generator = torch.Generator(device='cuda').manual_seed(seed)
image = pipeline(
validation_prompt,
num_inference_steps=50,
generator=generator,
eta=1.0,
image=validation_image,
mask_image=mask_image,
control_image=control_image,
guidance_scale = 7
).images[0]
images.append(image)
image_logs.append(
{"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
)
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
for log in image_logs:
images = log["images"]
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images = []
formatted_images.append(np.asarray(validation_image))
for image in images:
formatted_images.append(np.asarray(image))
formatted_images = np.stack(formatted_images)
tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
elif tracker.name == "wandb":
formatted_images = []
for log in image_logs:
images = log["images"]
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
for image in images:
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)
tracker.log({"validation": formatted_images})
else:
            logger.warning(f"image logging not implemented for {tracker.name}")
return image_logs
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
revision=revision,
)
model_class = text_encoder_config.architectures[0]
if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers_Tiger.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
return RobertaSeriesModelWithTransformation
else:
raise ValueError(f"{model_class} is not supported.")
def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
img_str = ""
if image_logs is not None:
img_str = "You can find some example images below.\n"
for i, log in enumerate(image_logs):
images = log["images"]
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
validation_image.save(os.path.join(repo_folder, "image_control.png"))
img_str += f"prompt: {validation_prompt}\n"
images = [validation_image] + images
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
            img_str += f"![images_{i}](./images_{i}.png)\n"
yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- controlnet
inference: true
---
"""
model_card = f"""
# controlnet-{repo_id}
These are controlnet weights trained on {base_model} with new type of conditioning.
{img_str}
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--controlnet_model_name_or_path",
type=str,
default=None,
help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
" If not specified controlnet weights are initialized from unet.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help=(
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
" float32 precision."
),
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--output_dir",
type=str,
default="controlnet-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--cache_dir",
type=str,
default="/export/home/daifang/Diffusion/own_code/dataset",
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=(
"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
"See https://huggingface.co/docs/diffusers123/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
"instructions."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=("Max number of checkpoints to store."),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-6,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--lr_num_cycles",
type=int,
default=1,
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
)
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--set_grads_to_none",
action="store_true",
help=(
"Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
" behaviors, so disable this argument if it causes any problems. More info:"
" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
),
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--train_data_dir",
type=str,
default=None,
help=(
"A folder containing the training data. Folder contents must follow the structure described in"
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
),
)
##############################################################################################################
parser.add_argument(
"--image_column", type=str, default="image", help="The column of the dataset containing the target image."
)
parser.add_argument(
"--conditioning_nd_column",
type=str,
default="condition_nd",
help="The column of the dataset containing the controlnet conditioning image.",
)
parser.add_argument(
"--conditioning_bg_column",
type=str,
default="condition_bg",
help="The column of the dataset containing the controlnet conditioning image.",
)
parser.add_argument(
"--caption_column_nd",
type=str,
default="text_nd",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument(
"--caption_column_bg",
type=str,
default="text_nd",
help="The column of the dataset containing a caption or a list of captions.",
)
##############################################################################################################
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
help=(
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
),
)
parser.add_argument(
"--proportion_empty_prompts",
type=float,
default=0,
help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
)
parser.add_argument(
"--validation_prompt",
type=str,
default=None,
nargs="+",
help=(
"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
" Provide either a matching number of `--validation_image`s, a single `--validation_image`"
" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
),
)
parser.add_argument(
"--validation_image",
type=str,
default=None,
nargs="+",
help=(
"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
" `--validation_image` that will be used with all `--validation_prompt`s."
),
)
parser.add_argument(
"--num_validation_images",
type=int,
default=4,
help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
)
parser.add_argument(
"--validation_steps",
type=int,
default=100,
help=(
"Run validation every X steps. Validation consists of running the prompt"
" `args.validation_prompt` multiple times: `args.num_validation_images`"
" and logging the images."
),
)
parser.add_argument(
"--tracker_project_name",
type=str,
default="train_controlnet",
help=(
"The `project_name` argument passed to Accelerator.init_trackers for"
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
if args.dataset_name is not None and args.train_data_dir is not None:
raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
if args.validation_prompt is not None and args.validation_image is None:
raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
if args.validation_prompt is None and args.validation_image is not None:
raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
if (
args.validation_image is not None
and args.validation_prompt is not None
and len(args.validation_image) != 1
and len(args.validation_prompt) != 1
and len(args.validation_image) != len(args.validation_prompt)
):
raise ValueError(
"Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
" or the same number of `--validation_prompt`s and `--validation_image`s"
)
if args.resolution % 8 != 0:
raise ValueError(
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
)
return args
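# Example of the prompt/image pairing rules enforced above (paths hypothetical):
#   --validation_prompt "a malignant nodule" --validation_image a.png b.png c.png
# broadcasts the single prompt across all three images, while
#   --validation_prompt p1 p2 p3 --validation_image a.png
# broadcasts the single image across all three prompts; any other mismatch raises.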
def make_train_dataset(args, tokenizer, accelerator):
if args.dataset_name is not None:
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
)
else:
if args.train_data_dir is not None:
dataset = load_dataset(
args.train_data_dir,
cache_dir=args.cache_dir,
)
column_names = dataset["train"].column_names
##########################################################################################################################################################################
# Get the column names for input/target.
# target image
if args.image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
# condition nodule image
if args.conditioning_nd_column is None:
conditioning_nd_column = column_names[1]
logger.info(f"conditioning image column defaulting to {conditioning_nd_column}")
else:
conditioning_nd_column = args.conditioning_nd_column
if conditioning_nd_column not in column_names:
raise ValueError(
f"`--conditioning_nd_column` value '{args.conditioning_nd_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
# condition background image
if args.conditioning_bg_column is None:
conditioning_bg_column = column_names[2]
logger.info(f"conditioning bg column defaulting to {conditioning_bg_column}")
else:
conditioning_bg_column = args.conditioning_bg_column
if conditioning_bg_column not in column_names:
raise ValueError(
f"`--conditioning_bg_column` value '{args.conditioning_bg_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
# condition nodule text
if args.caption_column_nd is None:
caption_column_nd = column_names[3]
logger.info(f"caption column defaulting to {caption_column_nd}")
else:
caption_column_nd = args.caption_column_nd
if caption_column_nd not in column_names:
raise ValueError(
f"`--caption_column` value '{args.caption_column_nd}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
    # condition background text
if args.caption_column_bg is None:
caption_column_bg = column_names[4]
logger.info(f"caption column defaulting to {caption_column_bg}")
else:
caption_column_bg = args.caption_column_bg
if caption_column_bg not in column_names:
raise ValueError(
f"`--caption_column` value '{args.caption_column_bg}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
##########################################################################################################################################################################
def tokenize_captions(examples, caption_column, names, is_train=True):
captions = []
for caption in examples[caption_column]:
if random.random() < args.proportion_empty_prompts:
captions.append("")
elif isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
else:
raise ValueError(
f"Caption column `{caption_column_nd}` should contain either strings or lists of strings."
)
inputs = tokenizer(
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
def calculate_word_frequencies(phrases):
total_counts = Counter()
total_words = 0
for phrase in phrases:
words = phrase.replace(',', '').split()
total_counts.update(words)
total_words += len(words)
frequencies = {word: count / total_words for word, count in total_counts.items()}
return frequencies, total_words
def calculate_average_frequencies(phrases, word_frequencies):
average_frequencies = []
for phrase in phrases:
words = phrase.replace(',', '').split()
total_freq = sum(word_frequencies[word] for word in words)
avg_freq = total_freq / len(words) if words else 0
average_frequencies.append((phrase, avg_freq))
return average_frequencies
if names == 'nd':
word_frequencies, total_word_count = calculate_word_frequencies(captions)
weight_matrix = calculate_average_frequencies(captions, word_frequencies)
# Extract the values to replace
values = [desc[1] for desc in weight_matrix]
            # Overwrite the first token (the BOS id) of each prompt with the integer-scaled weight
for i in range(inputs.input_ids.shape[0]):
weight = int(values[i]*10**5)
inputs.input_ids[i][0] = weight
assert not torch.isnan(inputs.input_ids).any(), "inputs.input_ids contains NaN values"
return inputs.input_ids
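    # Worked example of the 'nd' frequency weighting above (numbers illustrative):
    # for captions ["solid nodule", "solid mass"] the word frequencies are
    # {'solid': 0.5, 'nodule': 0.25, 'mass': 0.25} over 4 words, so each phrase's
    # average frequency is (0.5 + 0.25) / 2 = 0.375; scaled by 1e5 and truncated,
    # the integer 37500 overwrites that prompt's BOS token and is recovered in the
    # training loop by dividing the first token id by 1e5.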
image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
conditioning_image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
]
)
def preprocess_train(examples):
images = [image.convert("RGB") for image in examples[image_column]]
images = [image_transforms(image) for image in images]
conditioning_nd = [Image.open(image).convert("RGB") for image in examples[conditioning_nd_column]]
conditioning_nd = [conditioning_image_transforms(image) for image in conditioning_nd]
conditioning_bg = [Image.open(image).convert("RGB") for image in examples[conditioning_bg_column]]
conditioning_bg = [conditioning_image_transforms(image) for image in conditioning_bg]
examples["pixel_values"] = images
examples["conditioning_pixel_values_nd"] = conditioning_nd
examples["conditioning_pixel_values_bg"] = conditioning_bg
examples["input_ids_nd"] = tokenize_captions(examples, caption_column = caption_column_nd, names = 'nd')
examples["input_ids_bg"] = tokenize_captions(examples, caption_column = caption_column_bg, names = 'bg')
return examples
with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
return train_dataset
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
conditioning_pixel_values_nd = torch.stack([example["conditioning_pixel_values_nd"] for example in examples])
conditioning_pixel_values_nd = conditioning_pixel_values_nd.to(memory_format=torch.contiguous_format).float()
conditioning_pixel_values_bg = torch.stack([example["conditioning_pixel_values_bg"] for example in examples])
conditioning_pixel_values_bg = conditioning_pixel_values_bg.to(memory_format=torch.contiguous_format).float()
input_ids_nd = torch.stack([example["input_ids_nd"] for example in examples])
input_ids_bg = torch.stack([example["input_ids_bg"] for example in examples])
return {
"pixel_values": pixel_values,
"conditioning_pixel_values_nd": conditioning_pixel_values_nd,
"conditioning_pixel_values_bg": conditioning_pixel_values_bg,
"input_ids_nd": input_ids_nd,
"input_ids_bg": input_ids_bg,
}
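# For reference, with --resolution 512 and per-device batch size B the collated
# batch contains (assuming RGB inputs and the CLIP tokenizer's 77-token limit):
#   pixel_values:                 (B, 3, 512, 512) float32 in [-1, 1]
#   conditioning_pixel_values_nd: (B, 3, 512, 512) float32 in [0, 1]
#   conditioning_pixel_values_bg: (B, 3, 512, 512) float32 in [0, 1]
#   input_ids_nd / input_ids_bg:  (B, 77) int64 token ids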
def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers_Tiger.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers_Tiger.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load the tokenizer
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
elif args.pretrained_model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
if args.controlnet_model_name_or_path:
logger.info("Loading existing controlnet weights")
controlnet_nd = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
controlnet_bg = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
else:
logger.info("Initializing controlnet weights from unet")
controlnet_nd = ControlNetModel.from_unet(unet)
controlnet_bg = ControlNetModel.from_unet(unet)
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        def save_model_hook(models, weights, output_dir):
            # Save each trained controlnet into its own subfolder and pop its
            # weights entry so accelerate does not serialize it a second time.
            sub_dirs = ["controlnet_nd", "controlnet_bg"]
            for i, model in enumerate(models):
                if weights:
                    weights.pop()
                model.save_pretrained(os.path.join(output_dir, sub_dirs[i]))
        def load_model_hook(models, input_dir):
            sub_dirs = ["controlnet_nd", "controlnet_bg"]
            while len(models) > 0:
                # pop models so that they are not loaded again
                model = models.pop()
                # load diffusers-style weights from the matching subfolder
                load_model = ControlNetModel.from_pretrained(input_dir, subfolder=sub_dirs[len(models)])
                model.register_to_config(**load_model.config)
                model.load_state_dict(load_model.state_dict())
                del load_model
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
controlnet_nd.requires_grad_(True).train()
controlnet_bg.requires_grad_(True).train()
if args.gradient_checkpointing:
controlnet_nd.enable_gradient_checkpointing()
controlnet_bg.enable_gradient_checkpointing()
# Check that all trainable models are in full precision
low_precision_error_string = (
" Please make sure to always have all model weights in full float32 precision when starting training - even if"
" doing mixed precision training, copy of the weights should still be float32."
)
if accelerator.unwrap_model(controlnet_nd).dtype != torch.float32:
raise ValueError(
f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet_nd).dtype}. {low_precision_error_string}"
)
if accelerator.unwrap_model(controlnet_bg).dtype != torch.float32:
raise ValueError(
f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet_bg).dtype}. {low_precision_error_string}"
)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
# Optimizer creation
params_to_optimize_nd = controlnet_nd.parameters()
params_to_optimize_bg = controlnet_bg.parameters()
optimizer_nd = optimizer_class(
params_to_optimize_nd,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
optimizer_bg = optimizer_class(
params_to_optimize_bg,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
train_dataset = make_train_dataset(args, tokenizer, accelerator)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer_nd,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles,
power=args.lr_power)
# Prepare everything with our `accelerator`.
controlnet_nd, controlnet_bg, optimizer_nd, optimizer_bg, train_dataloader, lr_scheduler = accelerator.prepare(
controlnet_nd, controlnet_bg, optimizer_nd, optimizer_bg, train_dataloader, lr_scheduler
)
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move vae, unet and text_encoder to device and cast to weight_dtype
vae.to(accelerator.device, dtype=weight_dtype)
unet.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
tracker_config = dict(vars(args))
# tensorboard cannot handle list types for config
tracker_config.pop("validation_prompt")
tracker_config.pop("validation_image")
accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
else:
initial_global_step = 0
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)
image_logs = None
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
# with accelerator.accumulate(controlnet_nd):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
            # The first token id of `input_ids_nd` carries the caption-frequency
            # weight written in `tokenize_captions` (scaled by 1e5). Read it back,
            # then restore the CLIP BOS token id (49406) before text encoding.
            weight_nd = batch["input_ids_nd"][:, 0]
            weight_nd = weight_nd / 10**5
            batch["input_ids_nd"][:, 0] = 49406
encoder_hidden_states_nd = text_encoder(batch["input_ids_nd"])[0]
encoder_hidden_states_bg = text_encoder(batch["input_ids_bg"])[0]
controlnet_image_nd = batch["conditioning_pixel_values_nd"].to(dtype=weight_dtype)
controlnet_image_bg = batch["conditioning_pixel_values_bg"].to(dtype=weight_dtype)
# print(weight_nd)
down_block_res_samples_nd, mid_block_res_sample_nd = controlnet_nd(
noisy_latents,
timesteps,
encoder_hidden_states=encoder_hidden_states_nd, # text
controlnet_cond=controlnet_image_nd,
return_dict=False,
weight=weight_nd)
down_block_res_samples_bg, mid_block_res_sample_bg = controlnet_bg(
noisy_latents,
timesteps,
encoder_hidden_states=encoder_hidden_states_bg, # text
controlnet_cond=controlnet_image_bg,
return_dict=False)
            # Fuse the controlnet down-block residuals: the low-resolution blocks
            # (index > 1) are merged with an AFF (attentional feature fusion) module,
            # while the first two are passed through unchanged.
            samples_nd_list, samples_bg_list = [], []
for number in range(len(down_block_res_samples_nd)):
                if number > 1:
sample = down_block_res_samples_nd[number]
samples_nd = torch.stack((down_block_res_samples_nd[number][0].to('cpu'), \
down_block_res_samples_nd[number][0].to('cpu')))
samples_bg = torch.stack((down_block_res_samples_bg[number][0].to('cpu'), \
down_block_res_samples_bg[number][0].to('cpu')))
channels = sample.shape[1]
model_fuse_down = fuse.AFF(channels=channels).to(device='cpu')
output = model_fuse_down(samples_nd, samples_bg)[0].unsqueeze(0)
samples_nd_list.append(output)
samples_bg_list.append(output)
else:
samples_nd_list.append(down_block_res_samples_nd[number])
samples_bg_list.append(down_block_res_samples_bg[number])
            # Mid-block residuals are fused by simple addition
            mid_block_res_sample = mid_block_res_sample_bg + mid_block_res_sample_nd
            # Predict the noise residual with the frozen UNet, once per branch
            model_pred_nd = unet(
noisy_latents,
timesteps,
encoder_hidden_states=encoder_hidden_states_nd.to('cuda'),
down_block_additional_residuals=[
sample.to(dtype=weight_dtype).to('cuda') for sample in samples_nd_list],
mid_block_additional_residual=mid_block_res_sample.to('cuda').to(dtype=weight_dtype),
).sample
model_pred_bg = unet(
noisy_latents,
timesteps,
encoder_hidden_states=encoder_hidden_states_bg.to('cuda'),
down_block_additional_residuals=[
sample.to(dtype=weight_dtype).to('cuda') for sample in samples_bg_list],
mid_block_additional_residual=mid_block_res_sample.to('cuda').to(dtype=weight_dtype),
).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction": # use
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss_nd = F.mse_loss(model_pred_nd.to('cuda').float(), target.float(), reduction="mean")
loss_bg = F.mse_loss(model_pred_bg.to('cuda').float(), target.float(), reduction="mean")
optimizer_nd.zero_grad(set_to_none=args.set_grads_to_none)
optimizer_bg.zero_grad(set_to_none=args.set_grads_to_none)
# h0, h1 = nvmlDeviceGetHandleByIndex(0), nvmlDeviceGetHandleByIndex(1)
# info0, info1 = nvmlDeviceGetMemoryInfo(h0), nvmlDeviceGetMemoryInfo(h1)
# print(f'0free : {info0.free} 1free : {info1.free}')
loss = loss_nd + loss_bg
accelerator.backward(loss)
# loss_nd.backward()
# loss_bg.backward()
# if accelerator.sync_gradients:
# params_to_clip_nd = controlnet_nd.parameters()
# accelerator.clip_grad_norm_(params_to_clip_nd, args.max_grad_norm)
# params_to_clip_bg = controlnet_bg.parameters()
# accelerator.clip_grad_norm_(params_to_clip_bg, args.max_grad_norm)
optimizer_nd.step()
optimizer_bg.step()
lr_scheduler.step()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
# if args.validation_prompt is not None :
# image_logs = log_validation(
# vae,
# text_encoder,
# tokenizer,
# unet,
# controlnet_nd,
# controlnet_bg,
# args,
# accelerator,
# weight_dtype,
# global_step,
# )
logs = {"loss": loss.detach().item()}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
    # Create the pipeline using the trained modules and save it.
# accelerator.wait_for_everyone()
if accelerator.is_main_process:
        controlnet_nd = accelerator.unwrap_model(controlnet_nd)
        controlnet_nd.save_pretrained(os.path.join(args.output_dir, "controlnet_nd"))
        controlnet_bg = accelerator.unwrap_model(controlnet_bg)
        controlnet_bg.save_pretrained(os.path.join(args.output_dir, "controlnet_bg"))
accelerator.end_training()
if __name__ == "__main__":
args = parse_args()
main(args)
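# Example launch (a sketch; model id, dataset path, and flag values are placeholders):
#   accelerate launch Fine-Training.py \
#     --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
#     --train_data_dir /path/to/imagefolder \
#     --resolution 512 --train_batch_size 4 \
#     --learning_rate 5e-6 --checkpointing_steps 500 \
#     --validation_prompt "a pulmonary nodule" --validation_image /path/to/cond.png \
#     --output_dir controlnet-model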