from __future__ import annotations import random from typing import Union import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from vwm.modules.diffusionmodules.util import fourier_filter from vwm.modules.encoders.modules import GeneralConditioner from vwm.util import append_dims, instantiate_from_config from .denoiser import Denoiser class StandardDiffusionLoss(nn.Module): def __init__( self, sigma_sampler_config: dict, loss_weighting_config: dict, loss_type: str = "l2", use_additional_loss: bool = False, offset_noise_level: float = 0.0, additional_loss_weight: float = 0.0, num_frames: int = 25, replace_cond_frames: bool = False, cond_frames_choices: Union[list, None] = None ): super().__init__() assert loss_type in ["l2", "l1"] self.loss_type = loss_type self.use_additional_loss = use_additional_loss self.sigma_sampler = instantiate_from_config(sigma_sampler_config) self.loss_weighting = instantiate_from_config(loss_weighting_config) self.offset_noise_level = offset_noise_level self.additional_loss_weight = additional_loss_weight self.num_frames = num_frames self.replace_cond_frames = replace_cond_frames self.cond_frames_choices = cond_frames_choices def get_noised_input( self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor ) -> torch.Tensor: noised_input = input + noise * sigmas_bc return noised_input def forward( self, network: nn.Module, denoiser: Denoiser, conditioner: GeneralConditioner, input: torch.Tensor, batch: dict ) -> torch.Tensor: cond = conditioner(batch) return self._forward(network, denoiser, cond, input) def _forward( self, network: nn.Module, denoiser: Denoiser, cond: dict, input: torch.Tensor ): sigmas = self.sigma_sampler(input.shape[0]).to(input) cond_mask = torch.zeros_like(sigmas) if self.replace_cond_frames: cond_mask = rearrange(cond_mask, "(b t) -> b t", t=self.num_frames) for each_cond_mask in cond_mask: assert len(self.cond_frames_choices[-1]) < self.num_frames weights = [2 ** n for n in range(len(self.cond_frames_choices))] cond_indices = random.choices(self.cond_frames_choices, weights=weights, k=1)[0] if cond_indices: each_cond_mask[cond_indices] = 1 cond_mask = rearrange(cond_mask, "b t -> (b t)") noise = torch.randn_like(input) if self.offset_noise_level > 0.0: # the entire channel is shifted together offset_shape = (input.shape[0], input.shape[1]) # offset_shape = (input.shape[0] // self.num_frames, 1, input.shape[1]) rand_init = torch.randn(offset_shape, device=input.device) # rand_init = repeat(rand_init, "b 1 c -> (b t) c", t=self.num_frames) noise = noise + self.offset_noise_level * append_dims(rand_init, input.ndim) if self.replace_cond_frames: sigmas_bc = append_dims((1 - cond_mask) * sigmas, input.ndim) else: sigmas_bc = append_dims(sigmas, input.ndim) noised_input = self.get_noised_input(sigmas_bc, noise, input) model_output = denoiser(network, noised_input, sigmas, cond, cond_mask) w = append_dims(self.loss_weighting(sigmas), input.ndim) if self.replace_cond_frames: # ignore mask predictions predict = model_output * append_dims(1 - cond_mask, input.ndim) + input * append_dims(cond_mask, input.ndim) else: predict = model_output return self.get_loss(predict, input, w) def get_loss(self, predict, target, w): if self.loss_type == "l2": if self.use_additional_loss: predict_seq = rearrange(predict, "(b t) ... -> b t ...", t=self.num_frames) target_seq = rearrange(target, "(b t) ... -> b t ...", t=self.num_frames) bs = target.shape[0] // self.num_frames aux_loss = ((target_seq[:, 1:] - target_seq[:, :-1]) - (predict_seq[:, 1:] - predict_seq[:, :-1])) ** 2 tmp_h, tmp_w = aux_loss.shape[-2], aux_loss.shape[-1] aux_loss = rearrange(aux_loss, "b t c h w -> b (t h w) c", c=4) aux_w = F.normalize(aux_loss, p=2) aux_w = rearrange(aux_w, "b (t h w) c -> b t c h w", t=self.num_frames - 1, h=tmp_h, w=tmp_w) aux_w = 1 + torch.cat((torch.zeros(bs, 1, *aux_w.shape[2:]).to(aux_w), aux_w), dim=1) aux_w = rearrange(aux_w, "b t ... -> (b t) ...").reshape(target.shape[0], -1) predict_hf = fourier_filter(predict, scale=0.) target_hf = fourier_filter(target, scale=0.) hf_loss = torch.mean((w * (predict_hf - target_hf) ** 2).reshape(target.shape[0], -1), 1).mean() return torch.mean( (w * (predict - target) ** 2).reshape(target.shape[0], -1) * aux_w.detach(), 1 ).mean() + self.additional_loss_weight * hf_loss else: return torch.mean( (w * (predict - target) ** 2).reshape(target.shape[0], -1), 1 ) elif self.loss_type == "l1": if self.use_additional_loss: predict_seq = rearrange(predict, "(b t) ... -> b t ...", t=self.num_frames) target_seq = rearrange(target, "(b t) ... -> b t ...", t=self.num_frames) bs = target.shape[0] // self.num_frames aux_loss = ((target_seq[:, 1:] - target_seq[:, :-1]) - (predict_seq[:, 1:] - predict_seq[:, :-1])).abs() tmp_h, tmp_w = aux_loss.shape[-2], aux_loss.shape[-1] aux_loss = rearrange(aux_loss, "b t c h w -> b (t h w) c", c=4) aux_w = F.normalize(aux_loss, p=1) aux_w = rearrange(aux_w, "b (t h w) c -> b t c h w", t=self.num_frames - 1, h=tmp_h, w=tmp_w) aux_w = 1 + torch.cat((torch.zeros(bs, 1, *aux_w.shape[2:]).to(aux_w), aux_w), dim=1) aux_w = rearrange(aux_w, "b t ... -> (b t) ...").reshape(target.shape[0], -1) predict_hf = fourier_filter(predict, scale=0.) target_hf = fourier_filter(target, scale=0.) hf_loss = torch.mean((w * (predict_hf - target_hf).abs()).reshape(target.shape[0], -1), 1).mean() return torch.mean( (w * (predict - target).abs()).reshape(target.shape[0], -1) * aux_w.detach(), 1 ).mean() + self.additional_loss_weight * hf_loss else: return torch.mean( (w * (predict - target).abs()).reshape(target.shape[0], -1), 1 ) else: raise NotImplementedError(f"Unknown loss type {self.loss_type}")