import copy

import numpy as np
import torch
from pytorch_lightning.callbacks import *
from torch.cuda import amp  # needed for the amp.autocast contexts used below
from torch.optim.optimizer import Optimizer

from transformers import PreTrainedModel

from .DiffAEConfig import DiffAEConfig
from .DiffAE_support import *


class DiffAE(PreTrainedModel):
    config_class = DiffAEConfig

    def __init__(self, config):
        super().__init__(config)

        conf = ukbb_autoenc(n_latents=config.latent_dim)
        conf.__dict__.update(**vars(config))  # override the template defaults with the supplied config

        if config.test_with_TEval:
            conf.T_inv = conf.T_eval
            conf.T_step = conf.T_eval

        conf.fp16 = config.ampmode not in ["32", "32-true"]

        conf.refresh_values()
        conf.make_model_conf()

        self.config = config
        self.conf = conf

        self.net = conf.make_model_conf().make_model()
        self.ema_net = copy.deepcopy(self.net)
        self.ema_net.requires_grad_(False)
        self.ema_net.eval()

        # report the parameter count in millions (1e6, not 1024**2, so "M" is accurate)
        model_size = sum(param.data.nelement() for param in self.net.parameters())
        print('Model params: %.2f M' % (model_size / 1e6))

        self.sampler = conf.make_diffusion_conf().make_sampler()
        self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()

        # this timestep sampler is shared by the image and the latent diffusion
        self.T_sampler = conf.make_T_sampler()

        if conf.train_mode.use_latent_net():
            self.latent_sampler = conf.make_latent_diffusion_conf().make_sampler()
            self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf().make_sampler()
        else:
            self.latent_sampler = None
            self.eval_latent_sampler = None

        # fixed noise buffer, so repeated sampling stays consistent
        self.register_buffer('x_T', torch.randn(conf.sample_size, conf.in_channels, *conf.input_shape))

        if conf.pretrain is not None:
            print(f'loading pretrain ... {conf.pretrain.name}')
            state = torch.load(conf.pretrain.path, map_location='cpu')
            print('step:', state['global_step'])
            self.load_state_dict(state['state_dict'], strict=False)

        if conf.latent_infer_path is not None:
            print('loading latent stats ...')
            state = torch.load(conf.latent_infer_path)
            self.conds = state['conds']
            self.register_buffer('conds_mean', state['conds_mean'][None, :])
            self.register_buffer('conds_std', state['conds_std'][None, :])
        else:
            self.conds_mean = None
            self.conds_std = None

    def normalise(self, cond):
        cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(self.device)
        return cond

    def denormalise(self, cond):
        cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to(self.device)
        return cond

    def sample(self, N, device, T=None, T_latent=None):
        if T is None:
            sampler = self.eval_sampler
            latent_sampler = self.latent_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()
            latent_sampler = self.conf._make_latent_diffusion_conf(T_latent).make_sampler()

        noise = torch.randn(N,
                            self.conf.in_channels,
                            *self.conf.input_shape,
                            device=device)
        pred_img = render_uncondition(
            self.conf,
            self.ema_net,
            noise,
            sampler=sampler,
            latent_sampler=latent_sampler,
            conds_mean=self.conds_mean,
            conds_std=self.conds_std,
        )
        pred_img = (pred_img + 1) / 2  # map from [-1, 1] to [0, 1]
        return pred_img

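    # A minimal usage sketch for unconditional sampling (a hedged example, not
    # part of the class: the checkpoint path is hypothetical and the timestep
    # counts are illustrative):
    #
    #   model = DiffAE.from_pretrained("some/checkpoint")
    #   imgs = model.sample(N=4, device="cuda", T=20, T_latent=20)
    #   # imgs: (4, conf.in_channels, *conf.input_shape), values in [0, 1]
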
    def render(self, noise, cond=None, T=None, use_ema=True):
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()

        if cond is not None:
            pred_img = render_condition(self.conf,
                                        self.ema_net if use_ema else self.net,
                                        noise,
                                        sampler=sampler,
                                        cond=cond)
        else:
            pred_img = render_uncondition(self.conf,
                                          self.ema_net if use_ema else self.net,
                                          noise,
                                          sampler=sampler,
                                          latent_sampler=None)
        pred_img = (pred_img + 1) / 2
        return pred_img

    def encode(self, x, use_ema=True):
        assert self.conf.model_type.has_autoenc()
        net = self.ema_net if use_ema else self.net
        return net.encoder.forward(x)

    def encode_stochastic(self, x, cond, T=None, use_ema=True):
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()
        out = sampler.ddim_reverse_sample_loop(self.ema_net if use_ema else self.net,
                                               x,
                                               model_kwargs={'cond': cond})
        return out['sample']

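    # A hedged sketch of the DiffAE autoencoding roundtrip these two encoders
    # enable (it mirrors inference_pass below; `x` is a batch already scaled
    # to [-1, 1], and the timestep counts are illustrative):
    #
    #   z_sem = model.encode(x)                          # semantic latent
    #   x_T = model.encode_stochastic(x, z_sem, T=250)   # stochastic latent via DDIM inversion
    #   recon = model.render(x_T, cond=z_sem, T=20)      # reconstruction in [0, 1]
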
    def forward(self, x_start=None, noise=None, ema_model: bool = False):
        with amp.autocast(False):
            model = self.ema_net if ema_model else self.net
            return self.eval_sampler.sample(
                model=model,
                noise=noise,
                x_start=x_start,
                shape=noise.shape if noise is not None else x_start.shape,
            )

    def is_last_accum(self, batch_idx):
        """
        Is this the last micro-batch of a gradient-accumulation cycle, i.e.
        will the optimizer perform a "step" on this iteration? Relevant when
        accum_batches > 1.
        """
        return (batch_idx + 1) % self.conf.accum_batches == 0

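    # Worked example: with accum_batches = 4, batch_idx 0..2 only accumulate
    # gradients, while batch_idx 3, 7, 11, ... return True here and trigger
    # the optimizer step (and the EMA update in on_train_batch_end).
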
    def training_step(self, batch, batch_idx):
        """
        Given an input, calculate the loss function.
        No optimisation is performed at this stage.
        """
        with amp.autocast(False):
            if self.conf.train_mode.require_dataset_infer():
                # the batch already contains the pre-inferred semantic latents
                cond = batch[0]
                if self.conf.latent_znormalize:
                    cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(self.device)
                batch_size = cond.shape[0]
            else:
                imgs, idxs = batch['inp']['data'], batch_idx
                x_start = imgs
                batch_size = imgs.shape[0]

            if self.conf.train_mode == TrainMode.diffusion:
                # main training mode: sample t uniformly and train the UNet
                t, weight = self.T_sampler.sample(len(x_start), x_start.device)
                losses = self.sampler.training_losses(model=self.net,
                                                      x_start=x_start,
                                                      t=t)
            elif self.conf.train_mode.is_latent_diffusion():
                # training the latent variables only
                t, weight = self.T_sampler.sample(len(cond), cond.device)
                latent_losses = self.latent_sampler.training_losses(
                    model=self.net.latent_net, x_start=cond, t=t)
                losses = {
                    'latent': latent_losses['loss'],
                    'loss': latent_losses['loss']
                }
            else:
                raise NotImplementedError()

            loss = losses['loss'].mean()
            loss_dict = {"train_loss": loss}
            for key in ['vae', 'latent', 'mmd', 'chamfer', 'arg_cnt']:
                if key in losses:
                    loss_dict[f'train_{key}'] = losses[key].mean()
            # batch_size comes from whichever branch produced the inputs:
            # indexing batch['inp'] would fail in the latent-diffusion branch,
            # where the batch is a tuple of latents.
            self.log_dict(loss_dict, on_step=True, on_epoch=True, reduce_fx="mean",
                          sync_dist=True, batch_size=batch_size)

        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx: int) -> None:
        """
        After each training step: update the EMA copy of the network once the
        gradient-accumulation cycle has completed.
        """
        if self.is_last_accum(batch_idx):
            if self.conf.train_mode == TrainMode.latent_diffusion:
                # only the latent net is trained, so only it needs the EMA update
                ema(self.net.latent_net, self.ema_net.latent_net,
                    self.conf.ema_decay)
            else:
                ema(self.net, self.ema_net, self.conf.ema_decay)

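    # `ema` is imported from DiffAE_support. As an assumption about its
    # behaviour (the usual convention for EMA weight updates, not verified
    # here), it performs an in-place exponential moving average, roughly:
    #
    #   def ema(source, target, decay):
    #       for s, t in zip(source.parameters(), target.parameters()):
    #           t.data.mul_(decay).add_(s.data, alpha=1 - decay)
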
    def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
        if self.conf.grad_clip > 0:
            params = [
                p for group in optimizer.param_groups for p in group['params']
            ]
            torch.nn.utils.clip_grad_norm_(params, max_norm=self.conf.grad_clip)

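    # clip_grad_norm_ computes the global L2 norm over all listed parameters'
    # gradients and, if it exceeds max_norm, rescales every gradient by
    # max_norm / total_norm, so clipping preserves the gradient direction.
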
    def validation_step(self, batch, batch_idx):
        _, prediction_ema = self.inference_pass(batch['inp']['data'], T_inv=self.conf.T_eval,
                                                T_step=self.conf.T_eval, use_ema=True)
        _, prediction_base = self.inference_pass(batch['inp']['data'], T_inv=self.conf.T_eval,
                                                 T_step=self.conf.T_eval, use_ema=False)

        inp = batch['inp']['data'].cpu()
        inp = (inp + 1) / 2

        _, val_ssim_ema = self._eval_prediction(inp, prediction_ema)
        _, val_ssim_base = self._eval_prediction(inp, prediction_base)

        self.log_dict({"val_ssim_ema": val_ssim_ema, "val_ssim_base": val_ssim_base,
                       "val_loss": -val_ssim_ema},
                      on_step=True, on_epoch=True, reduce_fx="mean", sync_dist=True,
                      batch_size=batch['inp']['data'].shape[0])
        self.img_logger("val_ema", batch_idx, inp, prediction_ema)
        self.img_logger("val_base", batch_idx, inp, prediction_base)

    def _eval_prediction(self, inp, prediction):
        prediction = prediction.detach().cpu()
        # NumPy has no bfloat16, so upcast half-precision outputs before converting
        if prediction.dtype in {torch.bfloat16, torch.float16}:
            prediction = prediction.to(dtype=torch.float32)
        prediction = prediction.numpy()
        if self.config.grey2RGB in [0, 2]:
            # keep only the middle channel when inputs were grey-to-RGB converted
            inp = inp[:, 1, ...].unsqueeze(1)
            prediction = np.expand_dims(prediction[:, 1, ...], axis=1)
        val_ssim = getSSIM(inp.numpy(), prediction, data_range=1)
        return prediction, val_ssim

    def inference_pass(self, inp, T_inv, T_step, use_ema=True):
        semantic_latent = self.encode(inp, use_ema=use_ema)
        if self.config.test_emb_only:
            return semantic_latent, None
        # forward use_ema here as well, so the base/EMA comparison in
        # validation_step actually compares the two networks end to end
        stochastic_latent = self.encode_stochastic(inp, semantic_latent, T=T_inv, use_ema=use_ema)
        prediction = self.render(stochastic_latent, semantic_latent, T=T_step, use_ema=use_ema)
        return semantic_latent, prediction

    def test_step(self, batch, batch_idx):
        emb, recon = self.inference_pass(batch['inp']['data'], T_inv=self.conf.T_inv,
                                         T_step=self.conf.T_step, use_ema=self.config.test_ema)

        emb = emb.detach().cpu()
        if emb.dtype in {torch.bfloat16, torch.float16}:
            emb = emb.to(dtype=torch.float32)
        return emb.numpy(), recon

    def predict_step(self, batch, batch_idx):
        emb = self.encode(batch['inp']['data']).detach().cpu()
        if emb.dtype in {torch.bfloat16, torch.float16}:
            emb = emb.to(dtype=torch.float32)
        return emb.numpy()

    def configure_optimizers(self):
        if self.conf.optimizer == OptimizerType.adam:
            optim = torch.optim.Adam(self.net.parameters(),
                                     lr=self.conf.lr,
                                     weight_decay=self.conf.weight_decay)
        elif self.conf.optimizer == OptimizerType.adamw:
            optim = torch.optim.AdamW(self.net.parameters(),
                                      lr=self.conf.lr,
                                      weight_decay=self.conf.weight_decay)
        else:
            raise NotImplementedError()
        out = {'optimizer': optim}
        if self.conf.warmup > 0:
            sched = torch.optim.lr_scheduler.LambdaLR(optim,
                                                      lr_lambda=WarmupLR(self.conf.warmup))
            out['lr_scheduler'] = {
                'scheduler': sched,
                'interval': 'step',
            }
        return out

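    # WarmupLR comes from DiffAE_support. A plausible sketch (an assumption,
    # not the verified implementation): a callable returning a linear warm-up
    # factor that LambdaLR multiplies into the base learning rate each step:
    #
    #   class WarmupLR:
    #       def __init__(self, warmup):
    #           self.warmup = warmup
    #       def __call__(self, step):
    #           return min(step, self.warmup) / self.warmup
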
    def split_tensor(self, x):
        """
        Extract the slice of the batch dimension that belongs to the current
        "worker" (rank).

        Args:
            x: (n, c)

        Returns:
            x: (n_local, c)
        """
        n = len(x)
        rank = self.global_rank
        world_size = get_world_size()

        per_rank = n // world_size
        return x[rank * per_rank:(rank + 1) * per_rank]
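
    # Worked example: with n = 8 and world_size = 4, per_rank = 2, so rank 0
    # receives x[0:2], rank 1 x[2:4], and so on. Note the integer division:
    # if n is not divisible by world_size, the trailing n % world_size rows
    # are dropped on every rank.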