import copy
import os

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import *
from torch.cuda import amp
from torch.optim.optimizer import Optimizer
from torch.utils.data.dataset import TensorDataset

from model.seq2seq import DiffusionPredictor
from config import *
from dist_utils import *
from renderer import *


# This part is modified from: https://github.com/phizaz/diffae/blob/master/experiment.py
class LitModel(pl.LightningModule):
    def __init__(self, conf: TrainConfig):
        super().__init__()
        assert conf.train_mode != TrainMode.manipulate
        if conf.seed is not None:
            pl.seed_everything(conf.seed)

        self.save_hyperparameters(conf.as_dict_jsonable())
        self.conf = conf

        self.model = DiffusionPredictor(conf)

        self.ema_model = copy.deepcopy(self.model)
        self.ema_model.requires_grad_(False)
        self.ema_model.eval()

        self.sampler = conf.make_diffusion_conf().make_sampler()
        self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()

        # this is shared for both model and latent
        self.T_sampler = conf.make_T_sampler()

        if conf.train_mode.use_latent_net():
            self.latent_sampler = conf.make_latent_diffusion_conf().make_sampler()
            self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf().make_sampler()
        else:
            self.latent_sampler = None
            self.eval_latent_sampler = None

        # initial variables for consistent sampling
        self.register_buffer(
            'x_T',
            torch.randn(conf.sample_size, 3, conf.img_size, conf.img_size))

    def render(self, start, motion_direction_start, audio_driven,
               face_location, face_scale, ypr_info, noisyT, step_T,
               control_flag):
        if step_T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(step_T).make_sampler()

        pred_img = render_condition(self.conf, self.ema_model, sampler, start,
                                    motion_direction_start, audio_driven,
                                    face_location, face_scale, ypr_info,
                                    noisyT, control_flag)
        return pred_img

    def forward(self, noise=None, x_start=None, ema_model: bool = False):
        with amp.autocast(False):
            # use the EMA weights only when explicitly requested
            if ema_model:
                model = self.ema_model
            else:
                model = self.model
            gen = self.eval_sampler.sample(model=model,
                                           noise=noise,
                                           x_start=x_start)
            return gen

    def setup(self, stage=None) -> None:
        """
        make datasets & seeding each worker separately
        """
        ##############################################
        # NEED TO SET THE SEED SEPARATELY HERE
        if self.conf.seed is not None:
            seed = self.conf.seed * get_world_size() + self.global_rank
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            print('local seed:', seed)
        ##############################################

        self.train_data = self.conf.make_dataset()
        print('train data:', len(self.train_data))
        self.val_data = self.train_data
        print('val data:', len(self.val_data))

    def _train_dataloader(self, drop_last=True):
        """
        really make the dataloader
        """
        # make sure to use the fraction of batch size
        # the batch size is global!
        conf = self.conf.clone()
        conf.batch_size = self.batch_size

        dataloader = conf.make_loader(self.train_data,
                                      shuffle=True,
                                      drop_last=drop_last)
        return dataloader

    def train_dataloader(self):
        """
        return the dataloader, if diffusion mode => return image dataset
        if latent mode => return the inferred latent dataset
        """
        print('on train dataloader start ...')
        if self.conf.train_mode.require_dataset_infer():
            if self.conds is None:
                # usually we load self.conds from a file
                # so we do not need to do this again!
                self.conds = self.infer_whole_dataset()
                # need to use float32! unless the mean & std will be off!
                # (1, c)
                self.conds_mean.data = self.conds.float().mean(dim=0,
                                                               keepdim=True)
                self.conds_std.data = self.conds.float().std(dim=0,
                                                             keepdim=True)

            print('mean:', self.conds_mean.mean(), 'std:',
                  self.conds_std.mean())

            # return the dataset with pre-calculated conds
            conf = self.conf.clone()
            conf.batch_size = self.batch_size
            data = TensorDataset(self.conds)
            return conf.make_loader(data, shuffle=True)
        else:
            return self._train_dataloader()

    @property
    def batch_size(self):
        """
        local batch size for each worker
        """
        ws = get_world_size()
        assert self.conf.batch_size % ws == 0
        return self.conf.batch_size // ws

    @property
    def num_samples(self):
        """
        (global) batch size * iterations
        """
        # batch size here is global!
        # global_step already takes into account the accum batches
        return self.global_step * self.conf.batch_size_effective

    def is_last_accum(self, batch_idx):
        """
        is it the last gradient accumulation loop?
        used with gradient_accum > 1 and to see if the optimizer will perform
        "step" in this iteration or not
        """
        return (batch_idx + 1) % self.conf.accum_batches == 0

    def training_step(self, batch, batch_idx):
        """
        given an input, calculate the loss function
        no optimization at this stage.
        """
        with amp.autocast(False):
            motion_start = batch['motion_start']  # torch.Size([B, 512])
            motion_direction = batch['motion_direction']  # torch.Size([B, 125, 20])
            audio_feats = batch['audio_feats'].float()  # torch.Size([B, 25, 250, 1024])
            face_location = batch['face_location'].float()  # torch.Size([B, 125])
            face_scale = batch['face_scale'].float()  # torch.Size([B, 125, 1])
            yaw_pitch_roll = batch['yaw_pitch_roll'].float()  # torch.Size([B, 125, 3])
            motion_direction_start = batch['motion_direction_start'].float()  # torch.Size([B, 20])

            if self.conf.train_mode == TrainMode.diffusion:
                """
                main training mode!!!
                """
                # with numpy seed we have the problem that the sample t's are related!
                t, weight = self.T_sampler.sample(len(motion_start),
                                                  motion_start.device)
                losses = self.sampler.training_losses(
                    model=self.model,
                    motion_direction_start=motion_direction_start,
                    motion_target=motion_direction,
                    motion_start=motion_start,
                    audio_feats=audio_feats,
                    face_location=face_location,
                    face_scale=face_scale,
                    yaw_pitch_roll=yaw_pitch_roll,
                    t=t)
            else:
                raise NotImplementedError()

            loss = losses['loss'].mean()
            # divide by accum batches to make the accumulated gradient exact!
            for key in losses.keys():
                losses[key] = self.all_gather(losses[key]).mean()

            if self.global_rank == 0:
                self.logger.experiment.add_scalar('loss', losses['loss'],
                                                  self.num_samples)
                for key in losses:
                    self.logger.experiment.add_scalar(
                        f'loss/{key}', losses[key], self.num_samples)

        return {'loss': loss}

    def on_train_batch_end(self, outputs, batch, batch_idx: int,
                           dataloader_idx: int) -> None:
        """
        after each training step ...
""" if self.is_last_accum(batch_idx): if self.conf.train_mode == TrainMode.latent_diffusion: # it trains only the latent hence change only the latent ema(self.model.latent_net, self.ema_model.latent_net, self.conf.ema_decay) else: ema(self.model, self.ema_model, self.conf.ema_decay) def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None: # fix the fp16 + clip grad norm problem with pytorch lightinng # this is the currently correct way to do it if self.conf.grad_clip > 0: # from trainer.params_grads import grads_norm, iter_opt_params params = [ p for group in optimizer.param_groups for p in group['params'] ] torch.nn.utils.clip_grad_norm_(params, max_norm=self.conf.grad_clip) def configure_optimizers(self): out = {} if self.conf.optimizer == OptimizerType.adam: optim = torch.optim.Adam(self.model.parameters(), lr=self.conf.lr, weight_decay=self.conf.weight_decay) elif self.conf.optimizer == OptimizerType.adamw: optim = torch.optim.AdamW(self.model.parameters(), lr=self.conf.lr, weight_decay=self.conf.weight_decay) else: raise NotImplementedError() out['optimizer'] = optim if self.conf.warmup > 0: sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=WarmupLR( self.conf.warmup)) out['lr_scheduler'] = { 'scheduler': sched, 'interval': 'step', } return out def split_tensor(self, x): """ extract the tensor for a corresponding "worker" in the batch dimension Args: x: (n, c) Returns: x: (n_local, c) """ n = len(x) rank = self.global_rank world_size = get_world_size() # print(f'rank: {rank}/{world_size}') per_rank = n // world_size return x[rank * per_rank:(rank + 1) * per_rank] def ema(source, target, decay): source_dict = source.state_dict() target_dict = target.state_dict() for key in source_dict.keys(): target_dict[key].data.copy_(target_dict[key].data * decay + source_dict[key].data * (1 - decay)) class WarmupLR: def __init__(self, warmup) -> None: self.warmup = warmup def __call__(self, step): return min(step, self.warmup) / self.warmup def is_time(num_samples, every, step_size): closest = (num_samples // every) * every return num_samples - closest < step_size def train(conf: TrainConfig, gpus, nodes=1, mode: str = 'train'): print('conf:', conf.name) # assert not (conf.fp16 and conf.grad_clip > 0 # ), 'pytorch lightning has bug with amp + gradient clipping' model = LitModel(conf) if not os.path.exists(conf.logdir): os.makedirs(conf.logdir) checkpoint = ModelCheckpoint(dirpath=f'{conf.logdir}', save_last=True, save_top_k=-1, every_n_epochs=10) checkpoint_path = f'{conf.logdir}/last.ckpt' print('ckpt path:', checkpoint_path) if os.path.exists(checkpoint_path): resume = checkpoint_path print('resume!') else: if conf.continue_from is not None: # continue from a checkpoint resume = conf.continue_from.pathcd else: resume = None tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir, name=None, version='') # from pytorch_lightning. 
    plugins = []
    if len(gpus) == 1 and nodes == 1:
        accelerator = None
    else:
        accelerator = 'ddp'
        from pytorch_lightning.plugins import DDPPlugin

        # important for working with gradient checkpoint
        plugins.append(DDPPlugin(find_unused_parameters=True))

    trainer = pl.Trainer(
        max_steps=conf.total_samples // conf.batch_size_effective,
        resume_from_checkpoint=resume,
        gpus=gpus,
        num_nodes=nodes,
        accelerator=accelerator,
        precision=16 if conf.fp16 else 32,
        callbacks=[
            checkpoint,
            LearningRateMonitor(),
        ],
        # clip in the model instead
        # gradient_clip_val=conf.grad_clip,
        replace_sampler_ddp=True,
        logger=tb_logger,
        accumulate_grad_batches=conf.accum_batches,
        plugins=plugins,
    )
    trainer.fit(model)
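

# ---------------------------------------------------------------------------
# Minimal usage sketch of the `train` entry point above. The factory name
# `make_train_config` and the field values are hypothetical placeholders;
# build the real TrainConfig with whatever presets `config.py` provides.
# ---------------------------------------------------------------------------
# if __name__ == '__main__':
#     conf = make_train_config()            # hypothetical: returns a TrainConfig
#     conf.logdir = 'checkpoints/example'   # ModelCheckpoint writes last.ckpt here
#     conf.fp16 = True                      # selects precision=16 in pl.Trainer
#     train(conf, gpus=[0])                 # single GPU; multiple GPUs take the 'ddp' path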