Spaces:
Running
Running
import torch | |
from typing import Optional | |
from rstor.properties import LOSS_MSE | |
def compute_loss(
    predic: torch.Tensor,
    target: torch.Tensor,
    mode: Optional[str] = LOSS_MSE
) -> torch.Tensor:
    """
    Compute loss based on the predicted and true values.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        mode (Optional[str], optional): mode of loss computation.
            Only ``LOSS_MSE`` is currently supported. Defaults to ``LOSS_MSE``.

    Returns:
        torch.Tensor: The computed loss. For ``LOSS_MSE`` this is
        ``torch.nn.functional.mse_loss`` with its default
        ``reduction="mean"``, i.e. a scalar tensor averaged over all elements.

    Raises:
        ValueError: If ``mode`` is not a supported loss mode.
    """
    if mode == LOSS_MSE:
        return torch.nn.functional.mse_loss(predic, target)
    # Explicit raise rather than `assert`: assertions are removed under
    # `python -O`, which previously let an unsupported mode fall through and
    # crash later with an UnboundLocalError on the never-assigned result.
    raise ValueError(f"Mode {mode} not supported")