Spaces:
Running
Running
File size: 691 Bytes
cec5823 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import torch
from typing import Optional
from rstor.properties import LOSS_MSE
def compute_loss(
    predic: torch.Tensor,
    target: torch.Tensor,
    mode: Optional[str] = LOSS_MSE
) -> torch.Tensor:
    """Compute the loss between predicted and target values.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        mode (Optional[str], optional): name of the loss to compute.
            Only ``LOSS_MSE`` is supported. Defaults to ``LOSS_MSE``.

    Returns:
        torch.Tensor: The computed loss (``torch.nn.functional.mse_loss``
        called with its default arguments).

    Raises:
        AssertionError: If ``mode`` is not a supported loss name.
    """
    if mode == LOSS_MSE:
        return torch.nn.functional.mse_loss(predic, target)
    # Raise explicitly instead of using an `assert` statement: asserts are
    # stripped under `python -O`, which would let an unsupported mode fall
    # through to an unbound result. The exception type and message are kept
    # so existing callers observe the same failure.
    raise AssertionError(f"Mode {mode} not supported")
|