File size: 691 Bytes
cec5823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from typing import Optional
from rstor.properties import LOSS_MSE


def compute_loss(
    predic: torch.Tensor,
    target: torch.Tensor,
    mode: Optional[str] = LOSS_MSE
) -> torch.Tensor:
    """
    Compute loss based on the predicted and target values.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        mode (Optional[str], optional): mode of loss computation.
            Only LOSS_MSE is currently supported. None falls back to
            the default LOSS_MSE.

    Returns:
        torch.Tensor: The computed loss (scalar tensor). For LOSS_MSE this is
        ``torch.nn.functional.mse_loss`` with its default ``reduction="mean"``,
        i.e. the squared error averaged over all elements.

    Raises:
        AssertionError: if ``mode`` is not a supported loss mode.
    """
    if mode is None:
        # The annotation allows None: treat it as "use the default loss".
        mode = LOSS_MSE
    if mode == LOSS_MSE:
        return torch.nn.functional.mse_loss(predic, target)
    # Explicit raise instead of a bare `assert`: asserts are stripped under
    # `python -O`, which previously let an unsupported mode fall through to an
    # unbound `loss` variable. AssertionError is kept for backward compatibility
    # with callers that catch it.
    raise AssertionError(f"Mode {mode} not supported")