balthou's picture
initiate demo
cec5823
raw
history blame contribute delete
691 Bytes
import torch
from typing import Optional
from rstor.properties import LOSS_MSE
def compute_loss(
predic: torch.Tensor,
target: torch.Tensor,
mode: Optional[str] = LOSS_MSE
) -> torch.Tensor:
"""
Compute loss based on the predicted and true values.
Args:
predic (torch.Tensor): [N, C, H, W] predicted values
target (torch.Tensor): [N, C, H, W] target values.
mode (Optional[str], optional): mode of loss computation.
Returns:
torch.Tensor: The computed loss.
"""
assert mode in [LOSS_MSE], f"Mode {mode} not supported"
if mode == LOSS_MSE:
loss = torch.nn.functional.mse_loss(predic, target)
return loss