from __future__ import annotations
import random
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vwm.modules.diffusionmodules.util import fourier_filter
from vwm.modules.encoders.modules import GeneralConditioner
from vwm.util import append_dims, instantiate_from_config
from .denoiser import Denoiser
class StandardDiffusionLoss(nn.Module):
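    """Standard denoising diffusion training objective.

    Supports optional replacement of conditioning frames (kept noise-free and
    excluded from the loss) and an additional loss term that emphasizes
    frame-to-frame differences and high-frequency content.
    """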
def __init__(
self,
sigma_sampler_config: dict,
loss_weighting_config: dict,
loss_type: str = "l2",
use_additional_loss: bool = False,
offset_noise_level: float = 0.0,
additional_loss_weight: float = 0.0,
num_frames: int = 25,
replace_cond_frames: bool = False,
cond_frames_choices: Union[list, None] = None
):
super().__init__()
assert loss_type in ["l2", "l1"]
self.loss_type = loss_type
self.use_additional_loss = use_additional_loss
self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
self.loss_weighting = instantiate_from_config(loss_weighting_config)
self.offset_noise_level = offset_noise_level
self.additional_loss_weight = additional_loss_weight
self.num_frames = num_frames
self.replace_cond_frames = replace_cond_frames
self.cond_frames_choices = cond_frames_choices
def get_noised_input(
self,
sigmas_bc: torch.Tensor,
noise: torch.Tensor,
input: torch.Tensor
) -> torch.Tensor:
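        # EDM-style forward process: the noised sample is input + sigma * noise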
noised_input = input + noise * sigmas_bc
return noised_input
def forward(
self,
network: nn.Module,
denoiser: Denoiser,
conditioner: GeneralConditioner,
input: torch.Tensor,
batch: dict
) -> torch.Tensor:
cond = conditioner(batch)
return self._forward(network, denoiser, cond, input)
def _forward(
self,
network: nn.Module,
denoiser: Denoiser,
cond: dict,
input: torch.Tensor
):
sigmas = self.sigma_sampler(input.shape[0]).to(input)
cond_mask = torch.zeros_like(sigmas)
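        # optionally mark a random subset of frames per sequence as clean
        # conditioning frames (cond_mask == 1 means the frame stays noise-free)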
if self.replace_cond_frames:
cond_mask = rearrange(cond_mask, "(b t) -> b t", t=self.num_frames)
for each_cond_mask in cond_mask:
assert len(self.cond_frames_choices[-1]) < self.num_frames
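                # sample one frame-index set; later entries of
                # cond_frames_choices are exponentially more likely to be chosen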
weights = [2 ** n for n in range(len(self.cond_frames_choices))]
cond_indices = random.choices(self.cond_frames_choices, weights=weights, k=1)[0]
if cond_indices:
each_cond_mask[cond_indices] = 1
cond_mask = rearrange(cond_mask, "b t -> (b t)")
noise = torch.randn_like(input)
        if self.offset_noise_level > 0.0:  # offset noise: every pixel of a channel is shifted by one shared random value
offset_shape = (input.shape[0], input.shape[1])
# offset_shape = (input.shape[0] // self.num_frames, 1, input.shape[1])
rand_init = torch.randn(offset_shape, device=input.device)
# rand_init = repeat(rand_init, "b 1 c -> (b t) c", t=self.num_frames)
noise = noise + self.offset_noise_level * append_dims(rand_init, input.ndim)
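        # conditioning frames receive sigma 0, so they are left un-noised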
if self.replace_cond_frames:
sigmas_bc = append_dims((1 - cond_mask) * sigmas, input.ndim)
else:
sigmas_bc = append_dims(sigmas, input.ndim)
noised_input = self.get_noised_input(sigmas_bc, noise, input)
model_output = denoiser(network, noised_input, sigmas, cond, cond_mask)
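        # sigma-dependent loss weighting, broadcast to the input shape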
w = append_dims(self.loss_weighting(sigmas), input.ndim)
        if self.replace_cond_frames:  # exclude conditioning frames from the loss by copying them from the input
predict = model_output * append_dims(1 - cond_mask, input.ndim) + input * append_dims(cond_mask, input.ndim)
else:
predict = model_output
return self.get_loss(predict, input, w)
def get_loss(self, predict, target, w):
if self.loss_type == "l2":
if self.use_additional_loss:
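                # weight each pixel by how poorly the frame-to-frame change
                # (temporal difference) is predicted relative to the target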
predict_seq = rearrange(predict, "(b t) ... -> b t ...", t=self.num_frames)
target_seq = rearrange(target, "(b t) ... -> b t ...", t=self.num_frames)
bs = target.shape[0] // self.num_frames
aux_loss = ((target_seq[:, 1:] - target_seq[:, :-1]) - (predict_seq[:, 1:] - predict_seq[:, :-1])) ** 2
tmp_h, tmp_w = aux_loss.shape[-2], aux_loss.shape[-1]
aux_loss = rearrange(aux_loss, "b t c h w -> b (t h w) c", c=4)
aux_w = F.normalize(aux_loss, p=2)
aux_w = rearrange(aux_w, "b (t h w) c -> b t c h w", t=self.num_frames - 1, h=tmp_h, w=tmp_w)
aux_w = 1 + torch.cat((torch.zeros(bs, 1, *aux_w.shape[2:]).to(aux_w), aux_w), dim=1)
aux_w = rearrange(aux_w, "b t ... -> (b t) ...").reshape(target.shape[0], -1)
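                # extra penalty on the high-frequency components extracted by
                # fourier_filter (scale=0. zeroes the low-frequency band)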
predict_hf = fourier_filter(predict, scale=0.)
target_hf = fourier_filter(target, scale=0.)
hf_loss = torch.mean((w * (predict_hf - target_hf) ** 2).reshape(target.shape[0], -1), 1).mean()
return torch.mean(
(w * (predict - target) ** 2).reshape(target.shape[0], -1) * aux_w.detach(), 1
).mean() + self.additional_loss_weight * hf_loss
else:
return torch.mean(
(w * (predict - target) ** 2).reshape(target.shape[0], -1), 1
)
elif self.loss_type == "l1":
if self.use_additional_loss:
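                # mirrors the l2 branch above, using absolute differences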
predict_seq = rearrange(predict, "(b t) ... -> b t ...", t=self.num_frames)
target_seq = rearrange(target, "(b t) ... -> b t ...", t=self.num_frames)
bs = target.shape[0] // self.num_frames
aux_loss = ((target_seq[:, 1:] - target_seq[:, :-1]) - (predict_seq[:, 1:] - predict_seq[:, :-1])).abs()
tmp_h, tmp_w = aux_loss.shape[-2], aux_loss.shape[-1]
aux_loss = rearrange(aux_loss, "b t c h w -> b (t h w) c", c=4)
aux_w = F.normalize(aux_loss, p=1)
aux_w = rearrange(aux_w, "b (t h w) c -> b t c h w", t=self.num_frames - 1, h=tmp_h, w=tmp_w)
aux_w = 1 + torch.cat((torch.zeros(bs, 1, *aux_w.shape[2:]).to(aux_w), aux_w), dim=1)
aux_w = rearrange(aux_w, "b t ... -> (b t) ...").reshape(target.shape[0], -1)
predict_hf = fourier_filter(predict, scale=0.)
target_hf = fourier_filter(target, scale=0.)
hf_loss = torch.mean((w * (predict_hf - target_hf).abs()).reshape(target.shape[0], -1), 1).mean()
return torch.mean(
(w * (predict - target).abs()).reshape(target.shape[0], -1) * aux_w.detach(), 1
).mean() + self.additional_loss_weight * hf_loss
else:
return torch.mean(
(w * (predict - target).abs()).reshape(target.shape[0], -1), 1
)
else:
raise NotImplementedError(f"Unknown loss type {self.loss_type}")
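

# Usage sketch (illustrative only; the config targets below are assumptions
# based on the sgm-style module layout this file imports from, not verified):
#
#   loss_fn = StandardDiffusionLoss(
#       sigma_sampler_config={"target": "vwm.modules.diffusionmodules.sigma_sampling.EDMSampling"},
#       loss_weighting_config={"target": "vwm.modules.diffusionmodules.loss_weighting.EDMWeighting"},
#       loss_type="l2",
#       num_frames=25,
#   )
#   loss = loss_fn(network, denoiser, conditioner, input, batch)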