# Spaces: Running on Zero
from abc import ABC, abstractmethod
from typing import Type

import torch
class DenoisingLoss(ABC):
    """Base class for denoising losses.

    Subclasses implement ``__call__`` and return a scalar loss tensor. All
    subclasses share the same signature so they can be swapped freely; each
    one uses only the arguments relevant to its parameterization.
    """

    @abstractmethod
    def __call__(
        self, x: torch.Tensor, x_pred: torch.Tensor,
        noise: torch.Tensor, noise_pred: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        timestep: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Compute the denoising loss.

        Input:
            - x: the clean data with shape [B, F, C, H, W]
            - x_pred: the predicted clean data with shape [B, F, C, H, W]
            - noise: the noise with shape [B, F, C, H, W]
            - noise_pred: the predicted noise with shape [B, F, C, H, W]
            - alphas_cumprod: the cumulative product of alphas (defining the
              noise schedule) with shape [T]
            - timestep: the current timestep with shape [B, F]
            - **kwargs: extra, loss-specific inputs

        Returns:
            The loss as a tensor.

        Raises:
            NotImplementedError: if a subclass delegates to this base
                implementation instead of computing a loss.
        """
        # Marked abstract so the base class cannot be instantiated and then
        # silently return None where a loss tensor is expected.
        raise NotImplementedError
class X0PredLoss(DenoisingLoss):
    """Unweighted mean squared error between ``x`` and ``x_pred``."""

    def __call__(
        self,
        x: torch.Tensor,
        x_pred: torch.Tensor,
        noise: torch.Tensor,
        noise_pred: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        timestep: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Return the mean over all elements of ``(x - x_pred) ** 2``.

        Only ``x`` and ``x_pred`` are used; the remaining arguments exist to
        keep the signature shared with the other losses.
        """
        residual = x - x_pred
        return residual.pow(2).mean()
class VPredLoss(DenoisingLoss):
    """Squared error on ``x`` weighted by ``1 / (1 - alphas_cumprod[timestep])``."""

    def __call__(
        self,
        x: torch.Tensor,
        x_pred: torch.Tensor,
        noise: torch.Tensor,
        noise_pred: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        timestep: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Return the mean of the timestep-weighted squared error on ``x``.

        ``noise`` and ``noise_pred`` are not used.
        """
        # Pick the schedule value for every entry of `timestep`, then append
        # three singleton axes so it broadcasts against `x`.
        alpha_bar = alphas_cumprod[timestep]
        broadcast_shape = tuple(timestep.shape) + (1, 1, 1)
        weights = 1 / (1 - alpha_bar.reshape(broadcast_shape))

        squared_error = (x - x_pred) ** 2
        return torch.mean(weights * squared_error)
class NoisePredLoss(DenoisingLoss):
    """Unweighted mean squared error between ``noise`` and ``noise_pred``."""

    def __call__(
        self,
        x: torch.Tensor,
        x_pred: torch.Tensor,
        noise: torch.Tensor,
        noise_pred: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        timestep: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Return the mean over all elements of ``(noise - noise_pred) ** 2``.

        Only ``noise`` and ``noise_pred`` are used.
        """
        residual = noise - noise_pred
        return residual.pow(2).mean()
class FlowPredLoss(DenoisingLoss):
    """Mean squared error between ``kwargs["flow_pred"]`` and ``noise - x``."""

    def __call__(
        self,
        x: torch.Tensor,
        x_pred: torch.Tensor,
        noise: torch.Tensor,
        noise_pred: torch.Tensor,
        alphas_cumprod: torch.Tensor,
        timestep: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Return the mean of ``(flow_pred - (noise - x)) ** 2``.

        The prediction is read from the ``flow_pred`` keyword argument; a
        ``KeyError`` is raised when it is not supplied.
        """
        # Fetch the prediction first so a missing keyword fails before any
        # tensor arithmetic runs.
        flow_pred = kwargs["flow_pred"]
        flow_target = noise - x
        return (flow_pred - flow_target).pow(2).mean()
# Registry mapping each loss-type name to the loss class that implements it.
NAME_TO_CLASS = dict(
    x0=X0PredLoss,
    v=VPredLoss,
    noise=NoisePredLoss,
    flow=FlowPredLoss,
)
def get_denoising_loss(loss_type: str) -> Type[DenoisingLoss]:
    """Look up the denoising-loss class registered under ``loss_type``.

    Args:
        loss_type: A key of ``NAME_TO_CLASS``.

    Returns:
        The loss class itself, not an instance; the caller instantiates it.

    Raises:
        KeyError: if ``loss_type`` is not a registered name. The message
            lists the valid names.
    """
    try:
        return NAME_TO_CLASS[loss_type]
    except KeyError:
        valid = ", ".join(sorted(NAME_TO_CLASS))
        # Still a KeyError so existing handlers keep working, but with a
        # message that tells the caller which names are accepted.
        raise KeyError(
            f"Unknown denoising loss type {loss_type!r}; expected one of: {valid}"
        ) from None