Spaces:
Runtime error
Runtime error
File size: 1,399 Bytes
8044721 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
#!/usr/bin/env python3
# coding=utf-8
import torch
import torch.nn.functional as F
def masked_sum(loss, mask, label_weight=1, eps=1e-8, reduction=True):
    """Zero out masked loss entries and optionally reduce them to a scalar.

    Args:
        loss: Tensor of element-wise loss values.
        mask: Tensor handed to ``loss.masked_fill``; positions where it is
            true are excluded (their loss is set to ``0.0``). ``None``
            disables masking.
        label_weight: Scalar or tensor multiplied with the keep-indicator
            ``1 - mask`` to build the normaliser. The loss values themselves
            are not rescaled here.
        eps: Added to the normaliser so that a fully masked input does not
            divide by zero.
        reduction: If true, return a single reduced value; otherwise return
            the element-wise loss.

    Returns:
        * ``mask`` given, ``reduction`` true: ``loss.sum()`` over the kept
          entries divided by ``((1 - mask) * label_weight).sum() + eps``.
        * ``mask`` given, ``reduction`` false: the loss with masked entries
          set to zero.
        * ``mask`` is ``None``: ``loss.mean()`` if ``reduction`` else ``loss``
          unchanged.
    """
    if mask is None:
        # NOTE(review): label_weight and eps are not used on this path, so a
        # weighted loss is normalised by the element count rather than by the
        # summed weights -- confirm this asymmetry is intended.
        return loss.mean() if reduction else loss

    loss = loss.masked_fill(mask, 0.0)
    if not reduction:
        return loss

    # Weighted count of the entries that survived masking.
    normaliser = ((1 - mask.long()) * label_weight).sum() + eps
    return loss.sum() / normaliser
def cross_entropy(log_prob, target, mask, focal=False, label_weight=None, reduction=True):
    """Masked negative log-likelihood of ``target``, optionally focal-weighted.

    Args:
        log_prob: Tensor whose last dimension is indexed by class. It is
            expected to hold log-probabilities: it is exponentiated to obtain
            the probability used by the focal term.
        target: Integer index tensor with one dimension fewer than
            ``log_prob``; a trailing axis is added so it can be used with
            ``gather`` along the last dimension.
        mask: Forwarded to ``masked_sum``; true positions are excluded.
            May be ``None``.
        focal: If true, scale each element by ``(1 - p_target) ** 2`` where
            ``p_target`` is the probability assigned to the target class.
        label_weight: Optional scalar/tensor multiplied into the per-element
            loss and also forwarded to ``masked_sum`` as the normaliser
            weight.
        reduction: Forwarded to ``masked_sum``.

    Returns:
        Whatever ``masked_sum`` returns for the computed per-element loss.
    """
    target = target.unsqueeze(-1)
    target_log_prob = log_prob.gather(-1, target).squeeze(-1)

    if focal:
        # exp(log p) of the target class; gathering before or after exp()
        # selects the same elements.
        target_prob = log_prob.exp().gather(-1, target).squeeze(-1)
        focal_coeff = (1.0 - target_prob) ** 2
    else:
        focal_coeff = 1.0

    loss = -focal_coeff * target_log_prob

    if label_weight is None:
        return masked_sum(loss, mask, reduction=reduction)

    loss = loss * label_weight
    return masked_sum(loss, mask, label_weight=label_weight, reduction=reduction)
def binary_cross_entropy(logits, target, mask, focal=False, reduction=True):
    """Masked binary cross-entropy on logits, optionally focal-weighted.

    Args:
        logits: Tensor of raw scores passed to
            ``F.binary_cross_entropy_with_logits``.
        target: Tensor of targets for the same call; it is also used
            arithmetically in the focal term, so it must be a floating-point
            tensor compatible with ``logits``.
        mask: Forwarded to ``masked_sum``; true positions are excluded.
            May be ``None``.
        focal: If true, scale each element by ``(1 - p_true) ** 2`` where
            ``p_true = target * sigmoid(logits) + (1 - target) * (1 - sigmoid(logits))``.
        reduction: Forwarded to ``masked_sum``.

    Returns:
        Whatever ``masked_sum`` returns for the computed per-element loss.
    """
    if focal:
        prob = logits.sigmoid()
        # Probability the model assigns to the label that is actually set.
        true_label_prob = target * prob + (1.0 - target) * (1.0 - prob)
        focal_coeff = (1.0 - true_label_prob) ** 2
    else:
        focal_coeff = 1.0

    elementwise = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    loss = focal_coeff * elementwise
    return masked_sum(loss, mask, reduction=reduction)
|