# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from tqdm.auto import tqdm import os, argparse, datetime, math import logging from omegaconf import OmegaConf import shutil from latentsync.data.syncnet_dataset import SyncNetDataset from latentsync.models.syncnet import SyncNet from latentsync.models.syncnet_wav2lip import SyncNetWav2Lip from latentsync.utils.util import gather_loss, plot_loss_chart from accelerate.utils import set_seed import torch from diffusers import AutoencoderKL from diffusers.utils.logging import get_logger from einops import rearrange import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data.distributed import DistributedSampler from latentsync.utils.util import init_dist, cosine_loss logger = get_logger(__name__) def main(config): # Initialize distributed training local_rank = init_dist() global_rank = dist.get_rank() num_processes = dist.get_world_size() is_main_process = global_rank == 0 seed = config.run.seed + global_rank set_seed(seed) # Logging folder folder_name = "train" + datetime.datetime.now().strftime(f"-%Y_%m_%d-%H:%M:%S") output_dir = os.path.join(config.data.train_output_dir, folder_name) # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) # Handle the output folder creation if is_main_process: os.makedirs(output_dir, exist_ok=True) os.makedirs(f"{output_dir}/checkpoints", exist_ok=True) os.makedirs(f"{output_dir}/loss_charts", exist_ok=True) shutil.copy(config.config_path, output_dir) device = torch.device(local_rank) if config.data.latent_space: vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16) vae.requires_grad_(False) vae.to(device) else: vae = None # Dataset and Dataloader setup train_dataset = SyncNetDataset(config.data.train_data_dir, config.data.train_fileslist, config) val_dataset = SyncNetDataset(config.data.val_data_dir, config.data.val_fileslist, config) train_distributed_sampler = DistributedSampler( train_dataset, num_replicas=num_processes, rank=global_rank, shuffle=True, seed=config.run.seed, ) # DataLoaders creation: train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=config.data.batch_size, shuffle=False, sampler=train_distributed_sampler, num_workers=config.data.num_workers, pin_memory=False, drop_last=True, worker_init_fn=train_dataset.worker_init_fn, ) num_samples_limit = 640 val_batch_size = min( num_samples_limit // config.data.num_frames, config.data.batch_size ) # limit batch size to avoid CUDA OOM val_dataloader = torch.utils.data.DataLoader( val_dataset, batch_size=val_batch_size, shuffle=False, num_workers=config.data.num_workers, pin_memory=False, drop_last=False, worker_init_fn=val_dataset.worker_init_fn, ) # Model syncnet = SyncNet(OmegaConf.to_container(config.model)).to(device) # syncnet = SyncNetWav2Lip().to(device) optimizer = torch.optim.AdamW( list(filter(lambda p: p.requires_grad, syncnet.parameters())), lr=config.optimizer.lr ) if config.ckpt.resume_ckpt_path != "": if is_main_process: logger.info(f"Load checkpoint from: {config.ckpt.resume_ckpt_path}") ckpt = torch.load(config.ckpt.resume_ckpt_path, map_location=device) syncnet.load_state_dict(ckpt["state_dict"]) global_step = ckpt["global_step"] train_step_list = ckpt["train_step_list"] train_loss_list = ckpt["train_loss_list"] val_step_list = ckpt["val_step_list"] val_loss_list = ckpt["val_loss_list"] else: global_step = 0 train_step_list = [] train_loss_list = [] val_step_list = [] val_loss_list = [] # DDP wrapper syncnet = DDP(syncnet, device_ids=[local_rank], output_device=local_rank) num_update_steps_per_epoch = math.ceil(len(train_dataloader)) num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch) # validation_steps = int(config.ckpt.save_ckpt_steps // 5) # validation_steps = 100 if is_main_process: logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {num_train_epochs}") logger.info(f" Instantaneous batch size per device = {config.data.batch_size}") logger.info(f" Total train batch size (w. parallel & distributed) = {config.data.batch_size * num_processes}") logger.info(f" Total optimization steps = {config.run.max_train_steps}") first_epoch = global_step // num_update_steps_per_epoch num_val_batches = config.data.num_val_samples // (num_processes * config.data.batch_size) # Only show the progress bar once on each machine. progress_bar = tqdm( range(0, config.run.max_train_steps), initial=global_step, desc="Steps", disable=not is_main_process ) # Support mixed-precision training scaler = torch.cuda.amp.GradScaler() if config.run.mixed_precision_training else None for epoch in range(first_epoch, num_train_epochs): train_dataloader.sampler.set_epoch(epoch) syncnet.train() for step, batch in enumerate(train_dataloader): ### >>>> Training >>>> ### frames = batch["frames"].to(device, dtype=torch.float16) audio_samples = batch["audio_samples"].to(device, dtype=torch.float16) y = batch["y"].to(device, dtype=torch.float32) if config.data.latent_space: max_batch_size = ( num_samples_limit // config.data.num_frames ) # due to the limited cuda memory, we split the input frames into parts if frames.shape[0] > max_batch_size: assert ( frames.shape[0] % max_batch_size == 0 ), f"max_batch_size {max_batch_size} should be divisible by batch_size {frames.shape[0]}" frames_part_results = [] for i in range(0, frames.shape[0], max_batch_size): frames_part = frames[i : i + max_batch_size] frames_part = rearrange(frames_part, "b f c h w -> (b f) c h w") with torch.no_grad(): frames_part = vae.encode(frames_part).latent_dist.sample() * 0.18215 frames_part_results.append(frames_part) frames = torch.cat(frames_part_results, dim=0) else: frames = rearrange(frames, "b f c h w -> (b f) c h w") with torch.no_grad(): frames = vae.encode(frames).latent_dist.sample() * 0.18215 frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=config.data.num_frames) else: frames = rearrange(frames, "b f c h w -> b (f c) h w") if config.data.lower_half: height = frames.shape[2] frames = frames[:, :, height // 2 :, :] # audio_embeds = wav2vec_encoder(audio_samples).last_hidden_state # Mixed-precision training with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.run.mixed_precision_training): vision_embeds, audio_embeds = syncnet(frames, audio_samples) loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), y).mean() optimizer.zero_grad() # Backpropagate if config.run.mixed_precision_training: scaler.scale(loss).backward() """ >>> gradient clipping >>> """ scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(syncnet.parameters(), config.optimizer.max_grad_norm) """ <<< gradient clipping <<< """ scaler.step(optimizer) scaler.update() else: loss.backward() """ >>> gradient clipping >>> """ torch.nn.utils.clip_grad_norm_(syncnet.parameters(), config.optimizer.max_grad_norm) """ <<< gradient clipping <<< """ optimizer.step() progress_bar.update(1) global_step += 1 global_average_loss = gather_loss(loss, device) train_step_list.append(global_step) train_loss_list.append(global_average_loss) if is_main_process and global_step % config.run.validation_steps == 0: logger.info(f"Validation at step {global_step}") val_loss = validation( val_dataloader, device, syncnet, cosine_loss, config.data.latent_space, config.data.lower_half, vae, num_val_batches, ) val_step_list.append(global_step) val_loss_list.append(val_loss) logger.info(f"Validation loss at step {global_step} is {val_loss:0.3f}") if is_main_process and global_step % config.ckpt.save_ckpt_steps == 0: checkpoint_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt") torch.save( { "state_dict": syncnet.module.state_dict(), # to unwrap DDP "global_step": global_step, "train_step_list": train_step_list, "train_loss_list": train_loss_list, "val_step_list": val_step_list, "val_loss_list": val_loss_list, }, checkpoint_save_path, ) logger.info(f"Saved checkpoint to {checkpoint_save_path}") plot_loss_chart( os.path.join(output_dir, f"loss_charts/loss_chart-{global_step}.png"), ("Train loss", train_step_list, train_loss_list), ("Val loss", val_step_list, val_loss_list), ) progress_bar.set_postfix({"step_loss": global_average_loss}) if global_step >= config.run.max_train_steps: break progress_bar.close() dist.destroy_process_group() @torch.no_grad() def validation(val_dataloader, device, syncnet, cosine_loss, latent_space, lower_half, vae, num_val_batches): syncnet.eval() losses = [] val_step = 0 while True: for step, batch in enumerate(val_dataloader): ### >>>> Validation >>>> ### frames = batch["frames"].to(device, dtype=torch.float16) audio_samples = batch["audio_samples"].to(device, dtype=torch.float16) y = batch["y"].to(device, dtype=torch.float32) if latent_space: num_frames = frames.shape[1] frames = rearrange(frames, "b f c h w -> (b f) c h w") frames = vae.encode(frames).latent_dist.sample() * 0.18215 frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=num_frames) else: frames = rearrange(frames, "b f c h w -> b (f c) h w") if lower_half: height = frames.shape[2] frames = frames[:, :, height // 2 :, :] with torch.autocast(device_type="cuda", dtype=torch.float16): vision_embeds, audio_embeds = syncnet(frames, audio_samples) loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), y).mean() losses.append(loss.item()) val_step += 1 if val_step > num_val_batches: syncnet.train() if len(losses) == 0: raise RuntimeError("No validation data") return sum(losses) / len(losses) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Code to train the expert lip-sync discriminator") parser.add_argument("--config_path", type=str, default="configs/syncnet/syncnet_16_vae.yaml") args = parser.parse_args() # Load a configuration file config = OmegaConf.load(args.config_path) config.config_path = args.config_path main(config)