aiavatar2 / scripts /train_stage1.py
Spanicin's picture
Upload 7 files
e863356 verified
# pylint: disable=E1101,C0415,W0718,R0801
# scripts/train_stage1.py
"""
This is the main training script for stage 1 of the project.
It imports necessary packages, defines necessary classes and functions, and trains the model using the provided configuration.
The script includes the following classes and functions:
1. Net: A PyTorch model that takes noisy latents, timesteps, reference image latents, face embeddings,
and face masks as input and returns the model prediction of the denoising UNet.
2. get_noise_scheduler: A function that creates the training and validation noise schedulers from the configuration.
3. log_validation: A function that generates and saves validation images using the given VAE,
network, scheduler, accelerator, width, height, and configuration.
4. train_stage1_process: A function that runs the stage 1 training process using the given configuration.
5. load_config: A function that loads the configuration from a .yaml or .py file.
The script also includes the necessary imports and a brief description of the purpose of the file.
"""
import argparse
import copy
import logging
import math
import os
import random
import warnings
from datetime import datetime
import cv2
import diffusers
import mlflow
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from insightface.app import FaceAnalysis
from omegaconf import OmegaConf
from PIL import Image
from torch import nn
from tqdm.auto import tqdm
from hallo.animate.face_animate_static import StaticPipeline
from hallo.datasets.mask_image import FaceMaskDataset
from hallo.models.face_locator import FaceLocator
from hallo.models.image_proj import ImageProjModel
from hallo.models.mutual_self_attention import ReferenceAttentionControl
from hallo.models.unet_2d_condition import UNet2DConditionModel
from hallo.models.unet_3d import UNet3DConditionModel
from hallo.utils.util import (compute_snr, delete_additional_ckpt,
import_filename, init_output_dir,
load_checkpoint, move_final_checkpoint,
save_checkpoint, seed_everything)
# Silence every Python warning raised in this process.
warnings.filterwarnings("ignore")
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.10.0.dev0")
# Module-level logger created through accelerate's get_logger.
logger = get_logger(__name__, log_level="INFO")
class Net(nn.Module):
    """
    Container module that bundles the stage-1 sub-networks so they can be
    prepared and optimized as a single ``nn.Module``.

    Args:
        reference_unet (UNet2DConditionModel): UNet that is run on the reference image latents.
        denoising_unet (UNet3DConditionModel): UNet that produces the model prediction.
        face_locator (FaceLocator): Module that turns the face mask into a conditioning feature.
        reference_control_writer (ReferenceAttentionControl): Control object attached to the reference UNet.
        reference_control_reader (ReferenceAttentionControl): Control object attached to the denoising UNet.
        imageproj (ImageProjModel): Projection applied to the face embedding.
    """

    def __init__(
        self,
        reference_unet: UNet2DConditionModel,
        denoising_unet: UNet3DConditionModel,
        face_locator: FaceLocator,
        reference_control_writer: ReferenceAttentionControl,
        reference_control_reader: ReferenceAttentionControl,
        imageproj: ImageProjModel,
    ):
        super().__init__()
        self.reference_unet = reference_unet
        self.denoising_unet = denoising_unet
        self.face_locator = face_locator
        self.reference_control_writer = reference_control_writer
        self.reference_control_reader = reference_control_reader
        self.imageproj = imageproj

    def forward(
        self,
        noisy_latents,
        timesteps,
        ref_image_latents,
        face_emb,
        face_mask,
        uncond_fwd: bool = False,
    ):
        """
        Run one forward pass and return the denoising UNet's prediction.

        Args:
            noisy_latents (torch.Tensor): Noisy latents fed to the denoising UNet.
            timesteps (torch.Tensor): Diffusion timesteps for ``noisy_latents``.
            ref_image_latents (torch.Tensor): Latents of the reference image.
            face_emb (torch.Tensor): Face embedding; projected by ``imageproj`` before use.
            face_mask (torch.Tensor): Face mask passed through ``face_locator``.
            uncond_fwd (bool, optional): When True the reference UNet pass and the
                reader update are skipped. Defaults to False.

        Returns:
            torch.Tensor: The ``.sample`` attribute of the denoising UNet output.
        """
        face_emb = self.imageproj(face_emb)
        # Follow the device of the latents rather than a hard-coded "cuda" so the
        # mask always lands on the same device as the rest of the inputs.
        face_mask = face_mask.to(device=noisy_latents.device)
        face_mask_feature = self.face_locator(face_mask)

        if not uncond_fwd:
            # The reference pass always uses timestep 0; its return value is
            # discarded, the call is made for its effect on the writer.
            ref_timesteps = torch.zeros_like(timesteps)
            self.reference_unet(
                ref_image_latents,
                ref_timesteps,
                encoder_hidden_states=face_emb,
                return_dict=False,
            )
            # Hand the writer's state to the reader (exact semantics live in
            # ReferenceAttentionControl.update).
            self.reference_control_reader.update(self.reference_control_writer)

        model_pred = self.denoising_unet(
            noisy_latents,
            timesteps,
            mask_cond_fea=face_mask_feature,
            encoder_hidden_states=face_emb,
        ).sample
        return model_pred
def get_noise_scheduler(cfg: argparse.Namespace):
    """
    Build the two DDIM schedulers used by this script.

    Args:
        cfg (omegaconf.dictconfig.DictConfig): Configuration object; reads
            ``noise_scheduler_kwargs`` and ``enable_zero_snr``.

    Returns:
        tuple: ``(train_noise_scheduler, val_noise_scheduler)``. The validation
        scheduler uses the configured kwargs as-is; the training scheduler uses
        the same kwargs with ``beta_schedule`` forced to ``"scaled_linear"``.
    """
    base_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
    if cfg.enable_zero_snr:
        # Zero-terminal-SNR settings are applied to both schedulers.
        base_kwargs["rescale_betas_zero_snr"] = True
        base_kwargs["timestep_spacing"] = "trailing"
        base_kwargs["prediction_type"] = "v_prediction"

    val_noise_scheduler = DDIMScheduler(**base_kwargs)

    # Only the training scheduler overrides the beta schedule.
    train_kwargs = {**base_kwargs, "beta_schedule": "scaled_linear"}
    train_noise_scheduler = DDIMScheduler(**train_kwargs)

    return train_noise_scheduler, val_noise_scheduler
def log_validation(
    vae,
    net,
    scheduler,
    accelerator,
    width,
    height,
    imageproj,
    cfg,
    save_dir,
    global_step,
    face_analysis_model_path,
):
    """
    Generate validation images and save them as side-by-side comparison strips.

    For every (reference image, mask image) pair in ``cfg.ref_image_paths`` /
    ``cfg.mask_image_paths`` the largest detected face is embedded, the static
    pipeline is run, and a canvas of [reference | mask | result] is written to
    ``save_dir``. Pairs in which no face is detected are skipped with a warning.

    Args:
        vae (nn.Module): Variational Autoencoder model.
        net (Net): Main model; it is unwrapped and deep-copied before use.
        scheduler (diffusers.SchedulerMixin): Noise scheduler for the pipeline.
        accelerator (accelerate.Accelerator): Accelerator used to unwrap ``net``.
        width (int): Width passed to the pipeline.
        height (int): Height passed to the pipeline.
        imageproj (nn.Module): Image projection model.
        cfg (omegaconf.dictconfig.DictConfig): Configuration object.
        save_dir (str): Directory path to save the result images.
        global_step (int): Global step number, used in the output file names.
        face_analysis_model_path (str): Root directory handed to ``FaceAnalysis``.

    Returns:
        list: The comparison canvases (PIL images) that were saved, in order.
    """
    logger.info("Running validation... ")

    ori_net = accelerator.unwrap_model(net)
    ori_net = copy.deepcopy(ori_net)
    reference_unet = ori_net.reference_unet
    denoising_unet = ori_net.denoising_unet
    face_locator = ori_net.face_locator

    # Use a dedicated generator: torch.manual_seed(42) would reseed the global
    # RNG on every validation run and thereby disturb the training randomness.
    generator = torch.Generator().manual_seed(42)

    image_enc = FaceAnalysis(
        name="",
        root=face_analysis_model_path,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    image_enc.prepare(ctx_id=0, det_size=(640, 640))

    pipe = StaticPipeline(
        vae=vae,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        face_locator=face_locator,
        scheduler=scheduler,
        imageproj=imageproj,
    )

    pil_images = []
    for ref_image_path, mask_image_path in zip(cfg.ref_image_paths, cfg.mask_image_paths):
        mask_name = os.path.splitext(os.path.basename(mask_image_path))[0]
        ref_name = os.path.splitext(os.path.basename(ref_image_path))[0]
        ref_image_pil = Image.open(ref_image_path).convert("RGB")
        mask_image_pil = Image.open(mask_image_path).convert("RGB")

        # Prepare face embeds (the detector is fed a BGR array).
        faces = image_enc.get(
            cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR))
        if not faces:
            # Without a face there is no embedding to condition on; skip the
            # pair instead of aborting the whole training run.
            logger.warning(
                f"No face detected in {ref_image_path}; skipping this validation pair."
            )
            continue
        # Only use the face with the largest bounding-box area.
        largest_face = max(
            faces,
            key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]),
        )
        face_emb = torch.tensor(largest_face['embedding'])
        face_emb = face_emb.to(imageproj.device, imageproj.dtype)

        # Positional arguments after the images: width, height, 20, 3.5 and the
        # face embedding (presumably 20 = inference steps and 3.5 = guidance
        # scale; confirm against StaticPipeline.__call__).
        image = pipe(
            ref_image_pil,
            mask_image_pil,
            width,
            height,
            20,
            3.5,
            face_emb,
            generator=generator,
        ).images
        # Take the first sample / first frame and move channels last.
        image = image[0, :, 0].permute(1, 2, 0).cpu().numpy()
        res_image_pil = Image.fromarray((image * 255).astype(np.uint8))

        # Save ref_image, mask_image and the generated image side by side.
        w, h = res_image_pil.size
        canvas = Image.new("RGB", (w * 3, h), "white")
        ref_image_pil = ref_image_pil.resize((w, h))
        mask_image_pil = mask_image_pil.resize((w, h))
        canvas.paste(ref_image_pil, (0, 0))
        canvas.paste(mask_image_pil, (w, 0))
        canvas.paste(res_image_pil, (w * 2, 0))

        out_file = os.path.join(
            save_dir, f"{global_step:06d}-{ref_name}_{mask_name}.jpg"
        )
        canvas.save(out_file)
        pil_images.append(canvas)

    # Drop the deep-copied network and the pipeline before releasing cached GPU memory.
    del pipe
    del ori_net
    torch.cuda.empty_cache()

    return pil_images
def train_stage1_process(cfg: argparse.Namespace) -> None:
    """
    Run stage-1 training with the given configuration.

    The function builds the accelerator, the models (VAE, reference UNet,
    denoising UNet, image projection, face locator), the optimizer, the LR
    scheduler and the data loader, optionally resumes from a checkpoint, and
    then trains until ``cfg.solver.max_train_steps`` optimizer steps are done.
    Checkpoints and validation images are written under
    ``{cfg.output_dir}/{cfg.exp_name}``.

    Args:
        cfg (omegaconf.dictconfig.DictConfig): Training configuration, as
            returned by ``load_config``.

    Raises:
        ValueError: If ``cfg.weight_dtype`` is not one of ``fp16``/``bf16``/``fp32``,
            if xformers is requested but unavailable, or if the scheduler's
            prediction type is neither ``epsilon`` nor ``v_prediction``.
        ImportError: If 8-bit Adam is requested but bitsandbytes is missing.
    """
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
        mixed_precision=cfg.solver.mixed_precision,
        log_with="mlflow",
        project_dir="./mlruns",
        kwargs_handlers=[kwargs],
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if cfg.seed is not None:
        seed_everything(cfg.seed)

    # create output dir for training
    exp_name = cfg.exp_name
    save_dir = f"{cfg.output_dir}/{exp_name}"
    checkpoint_dir = os.path.join(save_dir, "checkpoints")
    module_dir = os.path.join(save_dir, "modules")
    validation_dir = os.path.join(save_dir, "validation")
    if accelerator.is_main_process:
        init_output_dir([save_dir, checkpoint_dir, module_dir, validation_dir])
    accelerator.wait_for_everyone()

    # resolve the weight dtype
    weight_dtypes = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }
    if cfg.weight_dtype not in weight_dtypes:
        raise ValueError(
            f"Do not support weight dtype: {cfg.weight_dtype} during training"
        )
    weight_dtype = weight_dtypes[cfg.weight_dtype]

    # create model
    vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
        "cuda", dtype=weight_dtype
    )
    reference_unet = UNet2DConditionModel.from_pretrained(
        cfg.base_model_path,
        subfolder="unet",
    ).to(device="cuda", dtype=weight_dtype)
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        cfg.base_model_path,
        "",
        subfolder="unet",
        unet_additional_kwargs={
            "use_motion_module": False,
            "unet_use_temporal_attention": False,
        },
        use_landmark=False
    ).to(device="cuda", dtype=weight_dtype)
    imageproj = ImageProjModel(
        cross_attention_dim=denoising_unet.config.cross_attention_dim,
        clip_embeddings_dim=512,
        clip_extra_context_tokens=4,
    ).to(device="cuda", dtype=weight_dtype)

    if cfg.face_locator_pretrained:
        face_locator = FaceLocator(
            conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)
        ).to(device="cuda", dtype=weight_dtype)
        # load_state_dict needs a mapping, not a path: read the file first.
        # NOTE(review): assumes the file stores a plain state_dict - confirm.
        face_state_dict = cfg.face_state_dict_path
        if isinstance(face_state_dict, (str, os.PathLike)):
            face_state_dict = torch.load(face_state_dict, map_location="cpu")
        miss, _ = face_locator.load_state_dict(face_state_dict, strict=False)
        logger.info(f"Missing key for face locator: {len(miss)}")
    else:
        face_locator = FaceLocator(
            conditioning_embedding_channels=320,
        ).to(device="cuda", dtype=weight_dtype)

    # Freeze the VAE; everything else is trained.
    vae.requires_grad_(False)
    denoising_unet.requires_grad_(True)
    reference_unet.requires_grad_(True)
    imageproj.requires_grad_(True)
    face_locator.requires_grad_(True)

    reference_control_writer = ReferenceAttentionControl(
        reference_unet,
        do_classifier_free_guidance=False,
        mode="write",
        fusion_blocks="full",
    )
    reference_control_reader = ReferenceAttentionControl(
        denoising_unet,
        do_classifier_free_guidance=False,
        mode="read",
        fusion_blocks="full",
    )

    net = Net(
        reference_unet,
        denoising_unet,
        face_locator,
        reference_control_writer,
        reference_control_reader,
        imageproj,
    ).to(dtype=weight_dtype)

    # get noise scheduler
    train_noise_scheduler, val_noise_scheduler = get_noise_scheduler(cfg)
    # Read scheduler settings through `.config`, consistently for all uses.
    prediction_type = train_noise_scheduler.config.prediction_type
    num_train_timesteps = train_noise_scheduler.config.num_train_timesteps

    # memory / speed options
    if cfg.solver.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            reference_unet.enable_xformers_memory_efficient_attention()
            denoising_unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    if cfg.solver.gradient_checkpointing:
        reference_unet.enable_gradient_checkpointing()
        denoising_unet.enable_gradient_checkpointing()

    if cfg.solver.scale_lr:
        # Scale the base LR by the effective batch size.
        learning_rate = (
            cfg.solver.learning_rate
            * cfg.solver.gradient_accumulation_steps
            * cfg.data.train_bs
            * accelerator.num_processes
        )
    else:
        learning_rate = cfg.solver.learning_rate

    # Initialize the optimizer
    if cfg.solver.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError as exc:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            ) from exc

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    trainable_params = list(
        filter(lambda p: p.requires_grad, net.parameters()))
    optimizer = optimizer_cls(
        trainable_params,
        lr=learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )

    # init scheduler
    lr_scheduler = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.solver.lr_warmup_steps
        * cfg.solver.gradient_accumulation_steps,
        num_training_steps=cfg.solver.max_train_steps
        * cfg.solver.gradient_accumulation_steps,
    )

    # get data loader
    train_dataset = FaceMaskDataset(
        img_size=(cfg.data.train_width, cfg.data.train_height),
        data_meta_paths=cfg.data.meta_paths,
        sample_margin=cfg.data.sample_margin,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4
    )

    # Prepare everything with our `accelerator`.
    (
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / cfg.solver.gradient_accumulation_steps
    )
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(
        cfg.solver.max_train_steps / num_update_steps_per_epoch
    )

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        accelerator.init_trackers(
            cfg.exp_name,
            init_kwargs={"mlflow": {"run_name": run_time}},
        )
        # dump config file
        mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml")
        logger.info(f"save config to {save_dir}")
        OmegaConf.save(
            cfg, os.path.join(save_dir, "config.yaml")
        )

    # Train!
    total_batch_size = (
        cfg.data.train_bs
        * accelerator.num_processes
        * cfg.solver.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}")
    logger.info(
        f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}"
    )
    logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if cfg.resume_from_checkpoint:
        logger.info(f"Loading checkpoint from {checkpoint_dir}")
        global_step = load_checkpoint(cfg, checkpoint_dir, accelerator)
        first_epoch = global_step // num_update_steps_per_epoch

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(global_step, cfg.solver.max_train_steps),
        disable=not accelerator.is_main_process,
    )
    progress_bar.set_description("Steps")

    # Sub-modules that are exported individually, in this order.
    module_names = ("reference_unet", "imageproj", "denoising_unet", "face_locator")

    net.train()
    for _ in range(first_epoch, num_train_epochs):
        # The `break` at the end of the batch loop only leaves that loop; stop
        # here as well so a resumed run cannot train past max_train_steps in a
        # later epoch or finalize the modules twice.
        if global_step >= cfg.solver.max_train_steps:
            break

        train_loss = 0.0
        for batch in train_dataloader:
            with accelerator.accumulate(net):
                # Convert images to latent space
                pixel_values = batch["img"].to(weight_dtype)
                with torch.no_grad():
                    latents = vae.encode(pixel_values).latent_dist.sample()
                    latents = latents.unsqueeze(2)  # (b, c, 1, h, w)
                    # 0.18215 is presumably the Stable Diffusion VAE latent
                    # scaling factor - confirm against the VAE config.
                    latents = latents * 0.18215

                noise = torch.randn_like(latents)
                if cfg.noise_offset > 0.0:
                    # One extra random offset per (sample, channel).
                    noise += cfg.noise_offset * torch.randn(
                        (noise.shape[0], noise.shape[1], 1, 1, 1),
                        device=noise.device,
                    )

                bsz = latents.shape[0]
                # Sample a random timestep for each sample
                timesteps = torch.randint(
                    0,
                    num_train_timesteps,
                    (bsz,),
                    device=latents.device,
                )
                timesteps = timesteps.long()

                face_mask_img = batch["tgt_mask"]
                face_mask_img = face_mask_img.unsqueeze(2)
                face_mask_img = face_mask_img.to(weight_dtype)

                # With probability cfg.uncond_ratio the whole batch is trained
                # unconditionally: face embeddings are zeroed and Net skips the
                # reference pass.
                uncond_fwd = random.random() < cfg.uncond_ratio

                face_emb_list = []
                ref_image_list = []
                for ref_img, face_emb in zip(batch["ref_img"], batch["face_emb"]):
                    if uncond_fwd:
                        face_emb_list.append(torch.zeros_like(face_emb))
                    else:
                        face_emb_list.append(face_emb)
                    ref_image_list.append(ref_img)

                with torch.no_grad():
                    ref_img = torch.stack(ref_image_list, dim=0).to(
                        dtype=vae.dtype, device=vae.device
                    )
                    ref_image_latents = vae.encode(
                        ref_img
                    ).latent_dist.sample()
                    ref_image_latents = ref_image_latents * 0.18215

                face_emb = torch.stack(face_emb_list, dim=0).to(
                    dtype=imageproj.dtype, device=imageproj.device
                )

                # add noise
                noisy_latents = train_noise_scheduler.add_noise(
                    latents, noise, timesteps
                )

                # Get the target for loss depending on the prediction type
                if prediction_type == "epsilon":
                    target = noise
                elif prediction_type == "v_prediction":
                    target = train_noise_scheduler.get_velocity(
                        latents, noise, timesteps
                    )
                else:
                    raise ValueError(
                        f"Unknown prediction type {prediction_type}"
                    )

                model_pred = net(
                    noisy_latents,
                    timesteps,
                    ref_image_latents,
                    face_emb,
                    face_mask_img,
                    uncond_fwd,
                )

                if cfg.snr_gamma == 0:
                    loss = F.mse_loss(
                        model_pred.float(), target.float(), reduction="mean"
                    )
                else:
                    # SNR-weighted loss: per-sample weight min(snr, gamma) / snr.
                    snr = compute_snr(train_noise_scheduler, timesteps)
                    if prediction_type == "v_prediction":
                        # Velocity objective requires that we add one to SNR values before we divide by them.
                        snr = snr + 1
                    mse_loss_weights = (
                        torch.stack(
                            [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
                        ).min(dim=1)[0]
                        / snr
                    )
                    loss = F.mse_loss(
                        model_pred.float(), target.float(), reduction="none"
                    )
                    loss = (
                        loss.mean(dim=list(range(1, len(loss.shape))))
                        * mse_loss_weights
                    )
                    loss = loss.mean()

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(
                    loss.repeat(cfg.data.train_bs)).mean()
                train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        trainable_params,
                        cfg.solver.max_grad_norm,
                    )
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                # An optimizer step was taken: reset the attention controls and
                # do the per-step bookkeeping.
                reference_control_reader.clear()
                reference_control_writer.clear()
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % cfg.checkpointing_steps == 0 or global_step == cfg.solver.max_train_steps:
                    accelerator.wait_for_everyone()
                    save_path = os.path.join(
                        checkpoint_dir, f"checkpoint-{global_step}")
                    if accelerator.is_main_process:
                        delete_additional_ckpt(checkpoint_dir, 3)
                        accelerator.save_state(save_path)
                    accelerator.wait_for_everyone()

                    unwrap_net = accelerator.unwrap_model(net)
                    if accelerator.is_main_process:
                        for module_name in module_names:
                            save_checkpoint(
                                getattr(unwrap_net, module_name),
                                module_dir,
                                module_name,
                                global_step,
                                total_limit=3,
                            )

                if global_step % cfg.val.validation_steps == 0 or global_step == 1:
                    if accelerator.is_main_process:
                        log_validation(
                            vae=vae,
                            net=net,
                            scheduler=val_noise_scheduler,
                            accelerator=accelerator,
                            width=cfg.data.train_width,
                            height=cfg.data.train_height,
                            imageproj=imageproj,
                            cfg=cfg,
                            save_dir=validation_dir,
                            global_step=global_step,
                            face_analysis_model_path=cfg.face_analysis_model_path
                        )

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            if global_step >= cfg.solver.max_train_steps:
                # process final module weight for stage2
                if accelerator.is_main_process:
                    for module_name in module_names:
                        move_final_checkpoint(save_dir, module_dir, module_name)
                break

    accelerator.wait_for_everyone()
    accelerator.end_training()
def load_config(config_path: str) -> dict:
    """
    Load a training configuration from disk.

    Args:
        config_path (str): Path to the configuration file. Paths ending in
            ``.yaml`` are read with OmegaConf; paths ending in ``.py`` are
            imported and their ``cfg`` attribute is returned.

    Returns:
        dict: The loaded configuration object.

    Raises:
        ValueError: If the path ends in neither ``.yaml`` nor ``.py``.
    """
    is_yaml = config_path.endswith(".yaml")
    is_python = config_path.endswith(".py")

    # Reject unknown extensions up front.
    if not (is_yaml or is_python):
        raise ValueError("Unsupported format for config file")

    if is_yaml:
        return OmegaConf.load(config_path)
    return import_filename(config_path).cfg
if __name__ == "__main__":
    # Command-line entry point: read --config, load it and run stage 1 training.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str,
                        default="./configs/train/stage1.yaml")
    args = parser.parse_args()

    try:
        config = load_config(args.config)
        train_stage1_process(config)
    except Exception as e:
        # Top-level boundary: keep the broad catch, but record the traceback
        # and exit non-zero so launchers and schedulers can see the failure.
        logging.exception("Failed to execute the training process: %s", e)
        raise SystemExit(1) from e