"""
This is the main training script for stage 1 of the project.

It imports the necessary packages, defines the required classes and functions, and trains the model
using the provided configuration.

The script includes the following classes and functions:

1. Net: A PyTorch model that takes noisy latents, timesteps, reference image latents, face embeddings,
   and face masks as input and returns the denoised latents.
2. get_noise_scheduler: A function that builds the training and validation noise schedulers from the
   given configuration.
3. log_validation: A function that logs validation results using the given VAE, image encoder,
   network, scheduler, accelerator, width, height, and configuration.
4. train_stage1_process: A function that runs stage 1 of training using the given configuration.

The script also includes the necessary imports and a brief description of the purpose of the file.
"""

import argparse
import copy
import logging
import math
import os
import random
import warnings
from datetime import datetime

import cv2
import diffusers
import mlflow
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from insightface.app import FaceAnalysis
from omegaconf import OmegaConf
from PIL import Image
from torch import nn
from tqdm.auto import tqdm

from hallo.animate.face_animate_static import StaticPipeline
from hallo.datasets.mask_image import FaceMaskDataset
from hallo.models.face_locator import FaceLocator
from hallo.models.image_proj import ImageProjModel
from hallo.models.mutual_self_attention import ReferenceAttentionControl
from hallo.models.unet_2d_condition import UNet2DConditionModel
from hallo.models.unet_3d import UNet3DConditionModel
from hallo.utils.util import (compute_snr, delete_additional_ckpt,
                              import_filename, init_output_dir,
                              load_checkpoint, move_final_checkpoint,
                              save_checkpoint, seed_everything)

warnings.filterwarnings("ignore")

# Raises an error if the minimal required diffusers version is not installed.
check_min_version("0.10.0.dev0")

logger = get_logger(__name__, log_level="INFO")


class Net(nn.Module):
    """
    The Net class defines a neural network model that combines a reference UNet2DConditionModel,
    a denoising UNet3DConditionModel, a face locator, and other components to animate a face in a static image.

    Args:
        reference_unet (UNet2DConditionModel): The reference UNet2DConditionModel used for face animation.
        denoising_unet (UNet3DConditionModel): The denoising UNet3DConditionModel used for face animation.
        face_locator (FaceLocator): The face locator model used for face animation.
        reference_control_writer: The reference control writer component.
        reference_control_reader: The reference control reader component.
        imageproj: The image projection model.

    Forward method:
        noisy_latents (torch.Tensor): The noisy latents tensor.
        timesteps (torch.Tensor): The timesteps tensor.
        ref_image_latents (torch.Tensor): The reference image latents tensor.
        face_emb (torch.Tensor): The face embeddings tensor.
        face_mask (torch.Tensor): The face mask tensor.
        uncond_fwd (bool): A flag indicating whether to perform an unconditional forward pass.

    Returns:
        torch.Tensor: The output tensor of the neural network model.
    """

    def __init__(
        self,
        reference_unet: UNet2DConditionModel,
        denoising_unet: UNet3DConditionModel,
        face_locator: FaceLocator,
        reference_control_writer: ReferenceAttentionControl,
        reference_control_reader: ReferenceAttentionControl,
        imageproj: ImageProjModel,
    ):
        super().__init__()
        self.reference_unet = reference_unet
        self.denoising_unet = denoising_unet
        self.face_locator = face_locator
        self.reference_control_writer = reference_control_writer
        self.reference_control_reader = reference_control_reader
        self.imageproj = imageproj

    def forward(
        self,
        noisy_latents,
        timesteps,
        ref_image_latents,
        face_emb,
        face_mask,
        uncond_fwd: bool = False,
    ):
        """
        Forward pass of the model.

        Args:
            self (Net): The model instance.
            noisy_latents (torch.Tensor): Noisy latents.
            timesteps (torch.Tensor): Timesteps.
            ref_image_latents (torch.Tensor): Reference image latents.
            face_emb (torch.Tensor): Face embedding.
            face_mask (torch.Tensor): Face mask.
            uncond_fwd (bool, optional): Unconditional forward pass. Defaults to False.

        Returns:
            torch.Tensor: Model prediction.
        """
        # Project the face embedding into the cross-attention space and encode the face mask.
        face_emb = self.imageproj(face_emb)
        face_mask = face_mask.to(device="cuda")
        face_mask_feature = self.face_locator(face_mask)

        if not uncond_fwd:
            # Run the reference UNet once at timestep 0 so the writer caches its attention
            # features, then copy them into the denoising UNet's reader.
            ref_timesteps = torch.zeros_like(timesteps)
            self.reference_unet(
                ref_image_latents,
                ref_timesteps,
                encoder_hidden_states=face_emb,
                return_dict=False,
            )
            self.reference_control_reader.update(self.reference_control_writer)

        model_pred = self.denoising_unet(
            noisy_latents,
            timesteps,
            mask_cond_fea=face_mask_feature,
            encoder_hidden_states=face_emb,
        ).sample

        return model_pred


def get_noise_scheduler(cfg: argparse.Namespace):
    """
    Create the noise schedulers for training and validation.

    Args:
        cfg (omegaconf.dictconfig.DictConfig): Configuration object.

    Returns:
        A tuple of (train noise scheduler, validation noise scheduler).
    """
    sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
    if cfg.enable_zero_snr:
        # Zero terminal SNR training requires rescaled betas, "trailing" timestep
        # spacing, and v-prediction targets.
        sched_kwargs.update(
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            prediction_type="v_prediction",
        )
    val_noise_scheduler = DDIMScheduler(**sched_kwargs)
    sched_kwargs.update({"beta_schedule": "scaled_linear"})
    train_noise_scheduler = DDIMScheduler(**sched_kwargs)

    return train_noise_scheduler, val_noise_scheduler


def log_validation(
    vae,
    net,
    scheduler,
    accelerator,
    width,
    height,
    imageproj,
    cfg,
    save_dir,
    global_step,
    face_analysis_model_path,
):
    """
    Log validation generation images.

    Args:
        vae (nn.Module): Variational Autoencoder model.
        net (Net): Main model.
        scheduler (diffusers.SchedulerMixin): Noise scheduler.
        accelerator (accelerate.Accelerator): Accelerator for training.
        width (int): Width of the input images.
        height (int): Height of the input images.
        imageproj (nn.Module): Image projection model.
        cfg (omegaconf.dictconfig.DictConfig): Configuration object.
        save_dir (str): Directory path to save the logged results.
        global_step (int): Global step number.
        face_analysis_model_path (str): Path to the insightface FaceAnalysis model directory.

    Returns:
        list: The composited validation images.
    """
    logger.info("Running validation... ")

    ori_net = accelerator.unwrap_model(net)
    ori_net = copy.deepcopy(ori_net)
    reference_unet = ori_net.reference_unet
    denoising_unet = ori_net.denoising_unet
    face_locator = ori_net.face_locator

    generator = torch.manual_seed(42)
    image_enc = FaceAnalysis(
        name="",
        root=face_analysis_model_path,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    image_enc.prepare(ctx_id=0, det_size=(640, 640))

    pipe = StaticPipeline(
        vae=vae,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        face_locator=face_locator,
        scheduler=scheduler,
        imageproj=imageproj,
    )

    pil_images = []
    for ref_image_path, mask_image_path in zip(cfg.ref_image_paths, cfg.mask_image_paths):
        mask_name = os.path.splitext(
            os.path.basename(mask_image_path))[0]
        ref_name = os.path.splitext(
            os.path.basename(ref_image_path))[0]
        ref_image_pil = Image.open(ref_image_path).convert("RGB")
        mask_image_pil = Image.open(mask_image_path).convert("RGB")

        # Keep the largest detected face and use its embedding as the identity condition.
        face_info = image_enc.get(
            cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR))
        face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (
            x['bbox'][3] - x['bbox'][1]))[-1]
        face_emb = torch.tensor(face_info['embedding'])
        face_emb = face_emb.to(
            imageproj.device, imageproj.dtype)

        image = pipe(
            ref_image_pil,
            mask_image_pil,
            width,
            height,
            20,
            3.5,
            face_emb,
            generator=generator,
        ).images
        image = image[0, :, 0].permute(1, 2, 0).cpu().numpy()
        res_image_pil = Image.fromarray((image * 255).astype(np.uint8))

        # Composite reference image, face mask, and generated result side by side.
        w, h = res_image_pil.size
        canvas = Image.new("RGB", (w * 3, h), "white")
        ref_image_pil = ref_image_pil.resize((w, h))
        mask_image_pil = mask_image_pil.resize((w, h))
        canvas.paste(ref_image_pil, (0, 0))
        canvas.paste(mask_image_pil, (w, 0))
        canvas.paste(res_image_pil, (w * 2, 0))

        out_file = os.path.join(
            save_dir, f"{global_step:06d}-{ref_name}_{mask_name}.jpg"
        )
        canvas.save(out_file)
        # Collect the composited canvas so the function's return value is populated.
        pil_images.append(canvas)

    del pipe
    del ori_net
    torch.cuda.empty_cache()

    return pil_images


def train_stage1_process(cfg: argparse.Namespace) -> None:
    """
    Trains the model using the given configuration (cfg).

    Args:
        cfg (dict): The configuration dictionary containing the parameters for training.

    Notes:
        - This function trains the model using the given configuration.
        - It initializes the necessary components for training, such as the pipeline, optimizer, and scheduler.
        - The training progress is logged and tracked using the accelerator.
        - The trained model is saved after the training is completed.
    """
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
        mixed_precision=cfg.solver.mixed_precision,
        log_with="mlflow",
        project_dir="./mlruns",
        kwargs_handlers=[kwargs],
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if cfg.seed is not None:
        seed_everything(cfg.seed)

    exp_name = cfg.exp_name
    save_dir = f"{cfg.output_dir}/{exp_name}"
    checkpoint_dir = os.path.join(save_dir, "checkpoints")
    module_dir = os.path.join(save_dir, "modules")
    validation_dir = os.path.join(save_dir, "validation")

    if accelerator.is_main_process:
        init_output_dir([save_dir, checkpoint_dir, module_dir, validation_dir])

    accelerator.wait_for_everyone()

    if cfg.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif cfg.weight_dtype == "bf16":
        weight_dtype = torch.bfloat16
    elif cfg.weight_dtype == "fp32":
        weight_dtype = torch.float32
    else:
        raise ValueError(
            f"Unsupported weight dtype: {cfg.weight_dtype} during training"
        )

    # Load the pretrained components; all sub-modules are kept on GPU in the configured weight dtype.
    vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
        "cuda", dtype=weight_dtype
    )
    reference_unet = UNet2DConditionModel.from_pretrained(
        cfg.base_model_path,
        subfolder="unet",
    ).to(device="cuda", dtype=weight_dtype)
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        cfg.base_model_path,
        "",
        subfolder="unet",
        unet_additional_kwargs={
            "use_motion_module": False,
            "unet_use_temporal_attention": False,
        },
        use_landmark=False
    ).to(device="cuda", dtype=weight_dtype)
    imageproj = ImageProjModel(
        cross_attention_dim=denoising_unet.config.cross_attention_dim,
        clip_embeddings_dim=512,
        clip_extra_context_tokens=4,
    ).to(device="cuda", dtype=weight_dtype)

    if cfg.face_locator_pretrained:
        face_locator = FaceLocator(
            conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)
        ).to(device="cuda", dtype=weight_dtype)
        miss, _ = face_locator.load_state_dict(
            torch.load(cfg.face_state_dict_path, map_location="cpu"), strict=False)
        logger.info(f"Missing keys for face locator: {len(miss)}")
    else:
        face_locator = FaceLocator(
            conditioning_embedding_channels=320,
        ).to(device="cuda", dtype=weight_dtype)

    # Only the two UNets, the image projector, and the face locator are trained; the VAE stays frozen.
    vae.requires_grad_(False)
    denoising_unet.requires_grad_(True)
    reference_unet.requires_grad_(True)
    imageproj.requires_grad_(True)
    face_locator.requires_grad_(True)

    reference_control_writer = ReferenceAttentionControl(
        reference_unet,
        do_classifier_free_guidance=False,
        mode="write",
        fusion_blocks="full",
    )
    reference_control_reader = ReferenceAttentionControl(
        denoising_unet,
        do_classifier_free_guidance=False,
        mode="read",
        fusion_blocks="full",
    )
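
    # How the reference conditioning works (see Net.forward): the "write" controller
    # caches the reference UNet's attention hidden states during its single forward
    # pass, and the "read" controller injects those cached states into the matching
    # attention blocks of the denoising UNet via
    # reference_control_reader.update(reference_control_writer).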

    net = Net(
        reference_unet,
        denoising_unet,
        face_locator,
        reference_control_writer,
        reference_control_reader,
        imageproj,
    ).to(dtype=weight_dtype)

    train_noise_scheduler, val_noise_scheduler = get_noise_scheduler(cfg)

    if cfg.solver.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            reference_unet.enable_xformers_memory_efficient_attention()
            denoising_unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    if cfg.solver.gradient_checkpointing:
        reference_unet.enable_gradient_checkpointing()
        denoising_unet.enable_gradient_checkpointing()

    if cfg.solver.scale_lr:
        learning_rate = (
            cfg.solver.learning_rate
            * cfg.solver.gradient_accumulation_steps
            * cfg.data.train_bs
            * accelerator.num_processes
        )
    else:
        learning_rate = cfg.solver.learning_rate
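
    # Linear LR scaling: effective_lr = base_lr * grad_accum_steps * per_device_batch * num_processes.
    # For example (hypothetical numbers, not the project defaults): base_lr=1e-5,
    # grad_accum_steps=1, train_bs=4 and 8 processes give an effective LR of 3.2e-4.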

    if cfg.solver.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError as exc:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            ) from exc

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    trainable_params = list(
        filter(lambda p: p.requires_grad, net.parameters()))
    optimizer = optimizer_cls(
        trainable_params,
        lr=learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )

    lr_scheduler = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.solver.lr_warmup_steps
        * cfg.solver.gradient_accumulation_steps,
        num_training_steps=cfg.solver.max_train_steps
        * cfg.solver.gradient_accumulation_steps,
    )

    train_dataset = FaceMaskDataset(
        img_size=(cfg.data.train_width, cfg.data.train_height),
        data_meta_paths=cfg.data.meta_paths,
        sample_margin=cfg.data.sample_margin,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4
    )

    (
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    )

    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / cfg.solver.gradient_accumulation_steps
    )
    num_train_epochs = math.ceil(
        cfg.solver.max_train_steps / num_update_steps_per_epoch
    )

    if accelerator.is_main_process:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        accelerator.init_trackers(
            cfg.exp_name,
            init_kwargs={"mlflow": {"run_name": run_time}},
        )

        mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml")

        logger.info(f"save config to {save_dir}")
        OmegaConf.save(
            cfg, os.path.join(save_dir, "config.yaml")
        )

    total_batch_size = (
        cfg.data.train_bs
        * accelerator.num_processes
        * cfg.solver.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}")
    logger.info(
        f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}"
    )
    logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}")
    global_step = 0
    first_epoch = 0

    if cfg.resume_from_checkpoint:
        logger.info(f"Loading checkpoint from {checkpoint_dir}")
        global_step = load_checkpoint(cfg, checkpoint_dir, accelerator)
        first_epoch = global_step // num_update_steps_per_epoch

    progress_bar = tqdm(
        range(global_step, cfg.solver.max_train_steps),
        disable=not accelerator.is_main_process,
    )
    progress_bar.set_description("Steps")
    net.train()
    for _ in range(first_epoch, num_train_epochs):
        train_loss = 0.0
        for _, batch in enumerate(train_dataloader):
            with accelerator.accumulate(net):
                pixel_values = batch["img"].to(weight_dtype)
                with torch.no_grad():
                    latents = vae.encode(pixel_values).latent_dist.sample()
                    latents = latents.unsqueeze(2)
                    latents = latents * 0.18215
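
                # The latents are scaled by 0.18215, the standard Stable Diffusion VAE
                # scaling factor, so their variance roughly matches the unit-variance
                # noise added by the scheduler; unsqueeze(2) adds the singleton frame
                # dimension expected by the 3D denoising UNet.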

                noise = torch.randn_like(latents)
                if cfg.noise_offset > 0.0:
                    noise += cfg.noise_offset * torch.randn(
                        (noise.shape[0], noise.shape[1], 1, 1, 1),
                        device=noise.device,
                    )
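
                # Noise offset trick: adding a small per-(batch, channel) constant to the
                # noise lets the model learn overall brightness shifts that plain
                # zero-mean noise cannot express.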

                bsz = latents.shape[0]

                timesteps = torch.randint(
                    0,
                    train_noise_scheduler.num_train_timesteps,
                    (bsz,),
                    device=latents.device,
                )
                timesteps = timesteps.long()
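
                # Each sample gets an independent timestep drawn uniformly from
                # [0, num_train_timesteps), as in standard DDPM training.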

                face_mask_img = batch["tgt_mask"]
                face_mask_img = face_mask_img.unsqueeze(2)
                face_mask_img = face_mask_img.to(weight_dtype)

                uncond_fwd = random.random() < cfg.uncond_ratio
                face_emb_list = []
                ref_image_list = []
                for _, (ref_img, face_emb) in enumerate(
                    zip(batch["ref_img"], batch["face_emb"])
                ):
                    if uncond_fwd:
                        face_emb_list.append(torch.zeros_like(face_emb))
                    else:
                        face_emb_list.append(face_emb)
                    ref_image_list.append(ref_img)
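
                # Classifier-free guidance style dropout: with probability cfg.uncond_ratio
                # the face embeddings are replaced by zeros (and the reference UNet pass is
                # skipped inside Net.forward), so the model also learns an unconditional branch.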

                with torch.no_grad():
                    ref_img = torch.stack(ref_image_list, dim=0).to(
                        dtype=vae.dtype, device=vae.device
                    )
                    ref_image_latents = vae.encode(
                        ref_img
                    ).latent_dist.sample()
                    ref_image_latents = ref_image_latents * 0.18215

                face_emb = torch.stack(face_emb_list, dim=0).to(
                    dtype=imageproj.dtype, device=imageproj.device
                )

                noisy_latents = train_noise_scheduler.add_noise(
                    latents, noise, timesteps
                )
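
                # The regression target depends on the scheduler's parameterization:
                # for "epsilon" the model predicts the added noise, while for
                # "v_prediction" it predicts v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0,
                # which is what get_velocity() computes.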

                if train_noise_scheduler.prediction_type == "epsilon":
                    target = noise
                elif train_noise_scheduler.prediction_type == "v_prediction":
                    target = train_noise_scheduler.get_velocity(
                        latents, noise, timesteps
                    )
                else:
                    raise ValueError(
                        f"Unknown prediction type {train_noise_scheduler.prediction_type}"
                    )
                model_pred = net(
                    noisy_latents,
                    timesteps,
                    ref_image_latents,
                    face_emb,
                    face_mask_img,
                    uncond_fwd,
                )
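
                # Loss weighting: with snr_gamma == 0 this is a plain MSE loss; otherwise
                # the Min-SNR-gamma scheme is used, weighting each sample by
                # min(SNR(t), gamma) / SNR(t) (with SNR + 1 in place of SNR for
                # v-prediction), which down-weights low-noise timesteps and stabilizes training.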

                if cfg.snr_gamma == 0:
                    loss = F.mse_loss(
                        model_pred.float(), target.float(), reduction="mean"
                    )
                else:
                    snr = compute_snr(train_noise_scheduler, timesteps)
                    if train_noise_scheduler.config.prediction_type == "v_prediction":
                        snr = snr + 1
                    mse_loss_weights = (
                        torch.stack(
                            [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
                        ).min(dim=1)[0]
                        / snr
                    )
                    loss = F.mse_loss(
                        model_pred.float(), target.float(), reduction="none"
                    )
                    loss = (
                        loss.mean(dim=list(range(1, len(loss.shape))))
                        * mse_loss_weights
                    )
                    loss = loss.mean()

                avg_loss = accelerator.gather(
                    loss.repeat(cfg.data.train_bs)).mean()
                train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        trainable_params,
                        cfg.solver.max_grad_norm,
                    )
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                reference_control_reader.clear()
                reference_control_writer.clear()
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0
                if global_step % cfg.checkpointing_steps == 0 or global_step == cfg.solver.max_train_steps:
                    accelerator.wait_for_everyone()
                    save_path = os.path.join(
                        checkpoint_dir, f"checkpoint-{global_step}")
                    if accelerator.is_main_process:
                        delete_additional_ckpt(checkpoint_dir, 3)
                    accelerator.save_state(save_path)
                    accelerator.wait_for_everyone()
                    unwrap_net = accelerator.unwrap_model(net)
                    if accelerator.is_main_process:
                        save_checkpoint(
                            unwrap_net.reference_unet,
                            module_dir,
                            "reference_unet",
                            global_step,
                            total_limit=3,
                        )
                        save_checkpoint(
                            unwrap_net.imageproj,
                            module_dir,
                            "imageproj",
                            global_step,
                            total_limit=3,
                        )
                        save_checkpoint(
                            unwrap_net.denoising_unet,
                            module_dir,
                            "denoising_unet",
                            global_step,
                            total_limit=3,
                        )
                        save_checkpoint(
                            unwrap_net.face_locator,
                            module_dir,
                            "face_locator",
                            global_step,
                            total_limit=3,
                        )

                if global_step % cfg.val.validation_steps == 0 or global_step == 1:
                    if accelerator.is_main_process:
                        generator = torch.Generator(device=accelerator.device)
                        generator.manual_seed(cfg.seed)
                        log_validation(
                            vae=vae,
                            net=net,
                            scheduler=val_noise_scheduler,
                            accelerator=accelerator,
                            width=cfg.data.train_width,
                            height=cfg.data.train_height,
                            imageproj=imageproj,
                            cfg=cfg,
                            save_dir=validation_dir,
                            global_step=global_step,
                            face_analysis_model_path=cfg.face_analysis_model_path
                        )

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            if global_step >= cfg.solver.max_train_steps:
                # Promote the latest per-module checkpoints to their final location
                # before leaving the training loop.
                if accelerator.is_main_process:
                    move_final_checkpoint(save_dir, module_dir, "reference_unet")
                    move_final_checkpoint(save_dir, module_dir, "imageproj")
                    move_final_checkpoint(save_dir, module_dir, "denoising_unet")
                    move_final_checkpoint(save_dir, module_dir, "face_locator")
                break

    accelerator.wait_for_everyone()
    accelerator.end_training()


def load_config(config_path: str) -> dict:
    """
    Loads the configuration file.

    Args:
        config_path (str): Path to the configuration file.

    Returns:
        dict: The configuration dictionary.
    """
    if config_path.endswith(".yaml"):
        return OmegaConf.load(config_path)
    if config_path.endswith(".py"):
        return import_filename(config_path).cfg
    raise ValueError("Unsupported format for config file")
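
# Example invocation (illustrative): run this script under `accelerate launch`
# (or plain `python`) with `--config ./configs/train/stage1.yaml`, which is also
# the parser default below.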


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str,
                        default="./configs/train/stage1.yaml")
    args = parser.parse_args()

    try:
        config = load_config(args.config)
        train_stage1_process(config)
    except Exception as e:
        logging.error("Failed to execute the training process: %s", e)