import copy

import numpy as np
import torch
from pytorch_lightning.callbacks import *
from torch.optim.optimizer import Optimizer
from transformers import PreTrainedModel

from .DiffAEConfig import DiffAEConfig
from .DiffAE_support import *


class DiffAE(PreTrainedModel):
    """Diffusion autoencoder (DiffAE) packaged as a Hugging Face ``PreTrainedModel``.

    Holds a trainable network (``net``), a frozen exponential-moving-average
    copy of it (``ema_net``) and the diffusion samplers built from the DiffAE
    config. Besides the inference helpers (``encode``, ``encode_stochastic``,
    ``render``, ``sample``), it implements Lightning-style hooks
    (``training_step``, ``validation_step``, ``test_step``, ``predict_step``,
    ``configure_optimizers``, ...).

    NOTE(review): ``log_dict``, ``global_rank`` and ``img_logger`` are used
    below but not defined in this class; presumably they are supplied by a
    wrapper or mixin — confirm.
    """

    config_class = DiffAEConfig

    def __init__(self, config):
        """Build networks and samplers from ``config``.

        Optionally loads a pretrained checkpoint (``conf.pretrain``) and
        pre-computed latent statistics (``conf.latent_infer_path``).
        """
        super().__init__(config)

        # Start from the UKBB autoencoder preset, then overlay every field of
        # the supplied DiffAE config on top of it.
        conf = ukbb_autoenc(n_latents=config.latent_dim)
        conf.__dict__.update(**vars(config))
        if config.test_with_TEval:
            conf.T_inv = conf.T_eval
            conf.T_step = conf.T_eval
        conf.fp16 = config.ampmode not in ["32", "32-true"]
        conf.refresh_values()
        # NOTE(review): the return value is discarded and the call is repeated
        # below; kept in case it has side effects on ``conf`` — confirm
        # before removing.
        conf.make_model_conf()

        self.config = config
        self.conf = conf

        self.net = conf.make_model_conf().make_model()
        # EMA copy: never receives gradients and stays in eval mode. It is
        # updated from ``net`` in ``on_train_batch_end``.
        self.ema_net = copy.deepcopy(self.net)
        self.ema_net.requires_grad_(False)
        self.ema_net.eval()

        model_size = sum(param.data.nelement() for param in self.net.parameters())
        print('Model params: %.2f M' % (model_size / 1024 / 1024))

        self.sampler = conf.make_diffusion_conf().make_sampler()
        self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()

        # this is shared for both model and latent
        self.T_sampler = conf.make_T_sampler()

        if conf.train_mode.use_latent_net():
            self.latent_sampler = conf.make_latent_diffusion_conf().make_sampler()
            self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf().make_sampler()
        else:
            self.latent_sampler = None
            self.eval_latent_sampler = None

        # initial variables for consistent sampling
        self.register_buffer(
            'x_T',
            torch.randn(conf.sample_size, conf.in_channels, *conf.input_shape),
        )

        if conf.pretrain is not None:
            print(f'loading pretrain ... {conf.pretrain.name}')
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state = torch.load(conf.pretrain.path, map_location='cpu')
            print('step:', state['global_step'])
            self.load_state_dict(state['state_dict'], strict=False)

        if conf.latent_infer_path is not None:
            print('loading latent stats ...')
            state = torch.load(conf.latent_infer_path)
            self.conds = state['conds']
            self.register_buffer('conds_mean', state['conds_mean'][None, :])
            self.register_buffer('conds_std', state['conds_std'][None, :])
        else:
            self.conds_mean = None
            self.conds_std = None

    def normalise(self, cond):
        """Z-normalise ``cond`` with the loaded ``conds_mean`` / ``conds_std``.

        Requires ``conf.latent_infer_path`` to have been set at construction;
        otherwise the statistics are ``None``.
        """
        cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(self.device)
        return cond

    def denormalise(self, cond):
        """Invert :meth:`normalise` using the loaded latent statistics."""
        cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to(self.device)
        return cond

    def sample(self, N, device, T=None, T_latent=None):
        """Generate ``N`` unconditional samples with the EMA network.

        Args:
            N: number of samples to draw.
            device: device on which the initial noise is created.
            T: number of diffusion steps; ``None`` uses the default eval sampler.
            T_latent: number of latent-diffusion steps, used only when ``T``
                is given.

        Returns:
            The rendered images mapped through ``(x + 1) / 2`` (presumably
            from [-1, 1] to [0, 1]).
        """
        if T is None:
            sampler = self.eval_sampler
            latent_sampler = self.latent_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()
            latent_sampler = self.conf._make_latent_diffusion_conf(T_latent).make_sampler()

        noise = torch.randn(N, self.conf.in_channels, *self.conf.input_shape, device=device)
        pred_img = render_uncondition(
            self.conf,
            self.ema_net,
            noise,
            sampler=sampler,
            latent_sampler=latent_sampler,
            conds_mean=self.conds_mean,
            conds_std=self.conds_std,
        )
        pred_img = (pred_img + 1) / 2
        return pred_img

    def render(self, noise, cond=None, T=None, use_ema=True):
        """Decode ``noise`` into images, optionally conditioned on ``cond``.

        Args:
            noise: starting noise / stochastic latent.
            cond: semantic latent; if ``None``, rendering is unconditional.
            T: number of diffusion steps; ``None`` uses the default eval sampler.
            use_ema: use ``ema_net`` (default) instead of ``net``.

        Returns:
            Rendered images mapped through ``(x + 1) / 2``.
        """
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()

        model = self.ema_net if use_ema else self.net
        if cond is not None:
            pred_img = render_condition(self.conf, model, noise, sampler=sampler, cond=cond)
        else:
            pred_img = render_uncondition(self.conf, model, noise, sampler=sampler, latent_sampler=None)
        pred_img = (pred_img + 1) / 2
        return pred_img

    def encode(self, x, use_ema=True):
        """Return the semantic latent of ``x`` from the (EMA) encoder."""
        assert self.conf.model_type.has_autoenc()
        return self.ema_net.encoder.forward(x) if use_ema else self.net.encoder.forward(x)

    def encode_stochastic(self, x, cond, T=None, use_ema=True):
        """Invert ``x`` to its stochastic latent via the sampler's DDIM reverse loop.

        Args:
            x: input images.
            cond: semantic latent of ``x`` (see :meth:`encode`).
            T: number of diffusion steps; ``None`` uses the default eval sampler.
            use_ema: use ``ema_net`` (default) instead of ``net``.
        """
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()
        out = sampler.ddim_reverse_sample_loop(
            self.ema_net if use_ema else self.net,
            x,
            model_kwargs={'cond': cond},
        )
        return out['sample']

    def forward(self, x_start=None, noise=None, ema_model: bool = False):
        """Run the eval sampler on ``noise`` / ``x_start`` with autocast disabled.

        The output shape is taken from ``noise`` if given, else from ``x_start``.
        """
        with amp.autocast(False):
            model = self.ema_net if ema_model else self.net
            return self.eval_sampler.sample(
                model=model,
                noise=noise,
                x_start=x_start,
                shape=noise.shape if noise is not None else x_start.shape,
            )

    def is_last_accum(self, batch_idx):
        """
        is it the last gradient accumulation loop?
        used with gradient_accum > 1 and to see if the optimizer will perform "step" in this iteration or not
        """
        return (batch_idx + 1) % self.conf.accum_batches == 0

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch (no optimisation here).

        In ``require_dataset_infer()`` modes the batch is indexed positionally
        and ``batch[0]`` holds pre-computed conds. Otherwise the images are
        read from ``batch['inp']['data']``.
        """
        with amp.autocast(False):
            if self.conf.train_mode.require_dataset_infer():
                # this mode has pre-calculated cond
                cond = batch[0]
                if self.conf.latent_znormalize:
                    cond = self.normalise(cond)
                batch_size = len(cond)
            else:
                x_start = batch['inp']['data']
                batch_size = len(x_start)

            if self.conf.train_mode == TrainMode.diffusion:
                # main training mode
                # with numpy seed we have the problem that the sampled t's are related!
                t, weight = self.T_sampler.sample(len(x_start), x_start.device)
                losses = self.sampler.training_losses(model=self.net, x_start=x_start, t=t)
            elif self.conf.train_mode.is_latent_diffusion():
                # training the latent variables: diffusion on the latent only
                t, weight = self.T_sampler.sample(len(cond), cond.device)
                latent_losses = self.latent_sampler.training_losses(
                    model=self.net.latent_net, x_start=cond, t=t)
                losses = {
                    'latent': latent_losses['loss'],
                    'loss': latent_losses['loss'],
                }
            else:
                raise NotImplementedError()

            loss = losses['loss'].mean()
            loss_dict = {"train_loss": loss}
            for key in ['vae', 'latent', 'mmd', 'chamfer', 'arg_cnt']:
                if key in losses:
                    loss_dict[f'train_{key}'] = losses[key].mean()
            # batch_size is taken from the tensor actually used so that the
            # pre-computed-cond branch (positional batch) does not index
            # batch['inp'].
            self.log_dict(loss_dict, on_step=True, on_epoch=True, reduce_fx="mean",
                          sync_dist=True, batch_size=batch_size)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx: int) -> None:
        """Update the EMA weights after the optimizer step of an accumulation cycle."""
        if self.is_last_accum(batch_idx):
            # only apply ema on the last gradient accumulation step,
            # if it is the iteration that has optimizer.step()
            if self.conf.train_mode == TrainMode.latent_diffusion:
                # it trains only the latent hence change only the latent
                ema(self.net.latent_net, self.ema_net.latent_net, self.conf.ema_decay)
            else:
                ema(self.net, self.ema_net, self.conf.ema_decay)

    def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
        """Clip the gradient norm of all optimized parameters when ``conf.grad_clip > 0``."""
        # fix the fp16 + clip grad norm problem with pytorch lightning;
        # this is the currently correct way to do it
        if self.conf.grad_clip > 0:
            params = [p for group in optimizer.param_groups for p in group['params']]
            torch.nn.utils.clip_grad_norm_(params, max_norm=self.conf.grad_clip)

    # Validation

    def validation_step(self, batch, batch_idx):
        """Reconstruct the batch with EMA and base nets; log SSIM and images."""
        data = batch['inp']['data']
        _, prediction_ema = self.inference_pass(
            data, T_inv=self.conf.T_eval, T_step=self.conf.T_eval, use_ema=True)
        _, prediction_base = self.inference_pass(
            data, T_inv=self.conf.T_eval, T_step=self.conf.T_eval, use_ema=False)

        inp = data.cpu()
        inp = (inp + 1) / 2

        _, val_ssim_ema = self._eval_prediction(inp, prediction_ema)
        _, val_ssim_base = self._eval_prediction(inp, prediction_base)

        self.log_dict(
            {"val_ssim_ema": val_ssim_ema, "val_ssim_base": val_ssim_base, "val_loss": -val_ssim_ema},
            on_step=True, on_epoch=True, reduce_fx="mean", sync_dist=True,
            batch_size=data.shape[0],
        )

        self.img_logger("val_ema", batch_idx, inp, prediction_ema)
        self.img_logger("val_base", batch_idx, inp, prediction_base)

    @staticmethod
    def _to_numpy(tensor):
        """Detach ``tensor``, move it to CPU and return it as a NumPy array.

        bfloat16 and float16 tensors are upcast to float32 first (NumPy has
        no bfloat16 dtype).
        """
        tensor = tensor.detach().cpu()
        if tensor.dtype in (torch.bfloat16, torch.float16):
            tensor = tensor.to(dtype=torch.float32)
        return tensor.numpy()

    def _eval_prediction(self, inp, prediction):
        """Return ``(prediction_as_numpy, ssim(inp, prediction))``.

        When ``config.grey2RGB`` is 0 or 2, only channel 1 of both arrays is
        compared.
        """
        prediction = self._to_numpy(prediction)
        if self.config.grey2RGB in [0, 2]:
            inp = inp[:, 1, ...].unsqueeze(1)
            prediction = np.expand_dims(prediction[:, 1, ...], axis=1)
        val_ssim = getSSIM(self._to_numpy(inp), prediction, data_range=1)
        return prediction, val_ssim

    def inference_pass(self, inp, T_inv, T_step, use_ema=True):
        """Encode ``inp`` and reconstruct it.

        Args:
            inp: input images.
            T_inv: diffusion steps for the stochastic inversion.
            T_step: diffusion steps for rendering.
            use_ema: use ``ema_net`` instead of ``net`` for every stage
                (semantic encoding, stochastic inversion and rendering).

        Returns:
            ``(semantic_latent, prediction)``; ``prediction`` is ``None`` when
            ``config.test_emb_only`` is set.
        """
        semantic_latent = self.encode(inp, use_ema=use_ema)
        if self.config.test_emb_only:
            return semantic_latent, None
        # use_ema must be forwarded here too, otherwise the "base" pass would
        # silently invert with the EMA network.
        stochastic_latent = self.encode_stochastic(inp, semantic_latent, T=T_inv, use_ema=use_ema)
        prediction = self.render(stochastic_latent, semantic_latent, T=T_step, use_ema=use_ema)
        return semantic_latent, prediction

    # Testing

    def test_step(self, batch, batch_idx):
        """Return ``(embedding_as_numpy, reconstruction)`` for the test batch."""
        emb, recon = self.inference_pass(
            batch['inp']['data'], T_inv=self.conf.T_inv, T_step=self.conf.T_step,
            use_ema=self.config.test_ema)
        return self._to_numpy(emb), recon

    # Prediction

    def predict_step(self, batch, batch_idx):
        """Return the (EMA) semantic embedding of the batch as a NumPy array."""
        return self._to_numpy(self.encode(batch['inp']['data']))

    def configure_optimizers(self):
        """Build the Adam/AdamW optimizer and an optional per-step warm-up scheduler."""
        if self.conf.optimizer == OptimizerType.adam:
            optim_cls = torch.optim.Adam
        elif self.conf.optimizer == OptimizerType.adamw:
            optim_cls = torch.optim.AdamW
        else:
            raise NotImplementedError()
        optim = optim_cls(self.net.parameters(), lr=self.conf.lr, weight_decay=self.conf.weight_decay)

        out = {'optimizer': optim}
        if self.conf.warmup > 0:
            sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=WarmupLR(self.conf.warmup))
            out['lr_scheduler'] = {
                'scheduler': sched,
                'interval': 'step',
            }
        return out

    def split_tensor(self, x):
        """
        extract the tensor for a corresponding "worker" in the batch dimension

        Args:
            x: (n, c)

        Returns: x: (n_local, c)
        """
        n = len(x)
        rank = self.global_rank
        world_size = get_world_size()
        per_rank = n // world_size
        return x[rank * per_rank:(rank + 1) * per_rank]