File size: 1,611 Bytes
2a13495 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import torch
from torch.nn.modules.loss import _Loss
class MCCLoss(_Loss):
    """Matthews Correlation Coefficient (MCC) loss for binary segmentation.

    The confusion-matrix entries (TP, TN, FP, FN) are computed as "soft"
    counts directly from the prediction tensor, so the loss is differentiable.
    Counts are summed over the whole batch, giving a single MCC value per
    call; the returned loss is ``1 - mcc``.

    Reference:
        https://github.com/kakumarabhishek/MCC-Loss
    """

    def __init__(self, eps: float = 1e-5):
        """Compute Matthews Correlation Coefficient Loss for image segmentation task.

        It only supports binary mode.

        Args:
            eps (float): Small epsilon added to every confusion-matrix entry
                (TP, TN, FP, FN) to handle situations where all the samples
                in the dataset belong to one class, which would otherwise
                make the MCC denominator zero.
        """
        super().__init__()
        self.eps = eps

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Compute MCC loss.

        No activation is applied here: ``y_pred`` is used directly as the soft
        foreground score and ``1 - y_pred`` as the soft background score, so
        pass probabilities (e.g. sigmoid outputs), not raw logits.

        Args:
            y_pred (torch.Tensor): model prediction of shape (N, H, W) or (N, 1, H, W).
            y_true (torch.Tensor): ground truth labels of shape (N, H, W) or (N, 1, H, W).
                Non-floating labels (bool / integer masks) are converted to
                ``y_pred``'s dtype before use.

        Returns:
            torch.Tensor: 0-dim loss value (1 - mcc).
        """
        bs = y_true.shape[0]

        # ``1 - y_true`` is not defined for bool tensors, so bring non-floating
        # labels into the prediction's dtype first. Floating-point labels are
        # left untouched to keep the original type-promotion behaviour.
        if not y_true.is_floating_point():
            y_true = y_true.to(dtype=y_pred.dtype)

        # ``reshape`` (unlike ``view``) also accepts non-contiguous tensors,
        # e.g. ones produced by ``permute`` / ``transpose``.
        y_true = y_true.reshape(bs, 1, -1)
        y_pred = y_pred.reshape(bs, 1, -1)

        # Soft confusion-matrix entries, summed over the whole batch (no dim
        # argument). With predictions in [0, 1], eps keeps every factor of
        # the denominator strictly positive.
        tp = torch.sum(y_pred * y_true) + self.eps
        tn = torch.sum((1 - y_pred) * (1 - y_true)) + self.eps
        fp = torch.sum(y_pred * (1 - y_true)) + self.eps
        fn = torch.sum((1 - y_pred) * y_true) + self.eps

        numerator = tp * tn - fp * fn
        denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))

        # Both terms are already 0-dim tensors; no further reduction needed.
        mcc = numerator / denominator
        return 1.0 - mcc
|