# File size: 472 Bytes
# 2680cbd
import torch
from torch import nn
class SiLogLoss(nn.Module):
    """Scale-invariant logarithmic (SiLog) loss.

    With ``d = log(target) - log(pred)`` taken over the elements selected by
    ``valid_mask``, the loss is::

        sqrt(mean(d ** 2) - lambd * mean(d) ** 2)

    Args:
        lambd: Weight of the squared-mean term. ``lambd = 0`` gives the RMSE
            of the log difference; ``lambd = 1`` gives the (population)
            standard deviation of ``d``, i.e. a fully scale-invariant loss.
            For ``0 <= lambd <= 1`` the radicand is non-negative in exact
            arithmetic; larger values can make it negative, in which case
            the loss is clamped to 0 (see ``forward``).
    """

    def __init__(self, lambd=0.5):
        super().__init__()
        self.lambd = lambd

    def forward(self, pred, target, valid_mask):
        """Compute the scalar SiLog loss over the masked elements.

        Args:
            pred: Predicted values, indexed with ``valid_mask``.
            target: Ground-truth values, indexed with ``valid_mask``.
            valid_mask: Index selecting the contributing elements. Presumably
                a boolean tensor shaped like ``pred``/``target`` — confirm
                against callers. NOTE(review): ``torch.log`` yields -inf/NaN
                for non-positive values, so the mask is expected to exclude
                them; nothing here checks that.

        Returns:
            A 0-dim tensor. If the mask selects no elements, returns 0 (still
            attached to the autograd graph) instead of NaN.
        """
        valid_mask = valid_mask.detach()
        diff_log = torch.log(target[valid_mask]) - torch.log(pred[valid_mask])

        if diff_log.numel() == 0:
            # mean() of an empty tensor is NaN, which would corrupt the
            # optimizer step. The empty sum is exactly 0 and keeps the graph
            # intact so backward() still works (producing zero gradients).
            return diff_log.sum()

        radicand = (torch.pow(diff_log, 2).mean()
                    - self.lambd * torch.pow(diff_log.mean(), 2))
        # Rounding can push the radicand slightly below zero (nearly constant
        # diffs with lambd close to 1), and lambd > 1 can make it truly
        # negative; sqrt would turn either into NaN, so clamp at 0.
        # NOTE: d/dx sqrt(x) is still infinite at exactly x == 0, so a
        # radicand of exactly 0 can still produce NaN gradients.
        loss = torch.sqrt(torch.clamp(radicand, min=0.0))
        return loss