# self-forcing/utils/loss.py
# multimodalart's picture
# Upload 80 files
# 0fd2f06 verified
# raw
# history blame
# 2.47 kB
from abc import ABC, abstractmethod
import torch
class DenoisingLoss(ABC):
    """Common interface for the denoising training objectives in this module.

    Every concrete loss is a callable taking the same set of tensors, so the
    objectives can be swapped without changing the call site. Each subclass
    uses only the arguments its objective needs.
    """

    @abstractmethod
    def __call__(
        self,
        x: torch.Tensor,
        x_pred: torch.Tensor,
        noise: torch.Tensor,
        noise_pred: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        timestep: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Compute the loss for one batch.

        Args:
            x: Clean data, shape [B, F, C, H, W].
            x_pred: Predicted clean data, shape [B, F, C, H, W].
            noise: Noise that was added, shape [B, F, C, H, W].
            noise_pred: Predicted noise, shape [B, F, C, H, W].
            alphas_cumprod: Cumulative product of alphas (the noise
                schedule), shape [T].
            timestep: Current timestep per sample and frame, shape [B, F].
            **kwargs: Additional objective-specific inputs.

        Returns:
            Concrete subclasses return the loss tensor. This abstract stub
            has no body and therefore returns ``None`` if invoked through
            ``super()``.
        """
class X0PredLoss(DenoisingLoss):
    """Unweighted mean squared error between clean data and its prediction."""

    def __call__(
        self,
        x: torch.Tensor,
        x_pred: torch.Tensor,
        noise: torch.Tensor,
        noise_pred: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        timestep: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Return ``mean((x - x_pred) ** 2)`` over all elements.

        Only ``x`` and ``x_pred`` are used; the other arguments exist to
        satisfy the shared ``DenoisingLoss`` call signature.
        """
        residual = x - x_pred
        return residual.pow(2).mean()
class VPredLoss(DenoisingLoss):
    """Clean-data squared error weighted by ``1 / (1 - alphas_cumprod[t])``."""

    def __call__(
        self,
        x: torch.Tensor,
        x_pred: torch.Tensor,
        noise: torch.Tensor,
        noise_pred: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        timestep: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Return the timestep-weighted mean squared error of ``x_pred``.

        The schedule value for each entry of ``timestep`` is looked up in
        ``alphas_cumprod`` and given three trailing singleton axes so it
        broadcasts against ``x``. ``noise`` and ``noise_pred`` are unused.

        NOTE(review): the weight is infinite wherever the looked-up
        ``alphas_cumprod`` value equals 1 -- presumably the schedule never
        reaches that; confirm against the caller.
        """
        # One schedule value per timestep entry, shaped for broadcasting.
        broadcast_shape = (*timestep.shape, 1, 1, 1)
        alpha_bar = alphas_cumprod[timestep].reshape(broadcast_shape)
        weights = 1 / (1 - alpha_bar)

        squared_error = (x - x_pred) ** 2
        return (weights * squared_error).mean()
class NoisePredLoss(DenoisingLoss):
    """Unweighted mean squared error between the noise and its prediction."""

    def __call__(
        self,
        x: torch.Tensor,
        x_pred: torch.Tensor,
        noise: torch.Tensor,
        noise_pred: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        timestep: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Return ``mean((noise - noise_pred) ** 2)`` over all elements.

        Only ``noise`` and ``noise_pred`` are used; the other arguments
        exist to satisfy the shared ``DenoisingLoss`` call signature.
        """
        residual = noise - noise_pred
        return residual.pow(2).mean()
class FlowPredLoss(DenoisingLoss):
    """Mean squared error between a predicted flow and ``noise - x``."""

    def __call__(
        self,
        x: torch.Tensor,
        x_pred: torch.Tensor,
        noise: torch.Tensor,
        noise_pred: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        timestep: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Return ``mean((flow_pred - (noise - x)) ** 2)`` over all elements.

        The prediction is not a positional argument: it must be supplied as
        the keyword argument ``flow_pred``. Only ``x``, ``noise`` and
        ``flow_pred`` are used.

        Raises:
            KeyError: If ``flow_pred`` is not present in ``kwargs``.
        """
        # Look the prediction up first so a missing key fails before any math.
        flow_pred = kwargs["flow_pred"]
        flow_target = noise - x
        residual = flow_pred - flow_target
        return residual.pow(2).mean()
# Registry of the available objectives, keyed by the short name accepted by
# get_denoising_loss. Values are the loss classes themselves, not instances.
NAME_TO_CLASS = {
    "x0": X0PredLoss,
    "v": VPredLoss,
    "noise": NoisePredLoss,
    "flow": FlowPredLoss,
}
def get_denoising_loss(loss_type: str) -> "type[DenoisingLoss]":
    """Look up the denoising-loss class registered under ``loss_type``.

    ``NAME_TO_CLASS`` stores classes, so the value returned here is the class
    object itself, not an instance; the caller is responsible for
    instantiating it.

    Args:
        loss_type: A key of ``NAME_TO_CLASS``.

    Returns:
        The ``DenoisingLoss`` subclass registered under ``loss_type``.

    Raises:
        KeyError: If ``loss_type`` is not a registered name. The message
            lists the valid names.
    """
    try:
        return NAME_TO_CLASS[loss_type]
    except KeyError:
        known = ", ".join(repr(name) for name in NAME_TO_CLASS)
        # Same exception type as the plain dict lookup, so existing handlers
        # keep working, but with a message that says what would be accepted.
        raise KeyError(
            f"unknown denoising loss type {loss_type!r}; expected one of: {known}"
        ) from None