waidhoferj's picture
updates
0030bc6
raw
history blame
2.25 kB
import torch.nn as nn
import torch
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
class LabelWeightedBCELoss(nn.Module):
    """Binary cross-entropy loss with a separate weight per label.

    Each float in the final dimension of ``input`` is treated as its own
    binary probability, and the element-wise loss is multiplied by
    ``label_weights`` before being reduced.

    Args:
        label_weights: Weights multiplied into the element-wise loss; they
            must broadcast against ``input`` and ``target``.
        reduction: ``"mean"`` or ``"sum"``; how the weighted element-wise
            losses are collapsed into the returned tensor.

    Raises:
        ValueError: If ``reduction`` is not one of the supported names.
    """

    def __init__(self, label_weights: torch.Tensor, reduction="mean"):
        super().__init__()

        plain_tensor = isinstance(label_weights, torch.Tensor) and not isinstance(
            label_weights, nn.Parameter
        )
        if plain_tensor:
            # A non-persistent buffer follows the module through .to()/.cuda()
            # yet stays out of state_dict, so existing checkpoints are unaffected.
            self.register_buffer("label_weights", label_weights, persistent=False)
        else:
            # Parameters and non-tensor weights keep the plain assignment.
            self.label_weights = label_weights

        reducers = {"mean": torch.mean, "sum": torch.sum}
        if reduction not in reducers:
            # Fail at construction time instead of leaving self.reduction unset
            # and hitting an AttributeError on the first forward pass.
            raise ValueError(
                f"Unsupported reduction {reduction!r}; expected one of {sorted(reducers)}"
            )
        self.reduction = reducers[reduction]

    def _log(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``log(x)`` clamped from below at -100.

        The clamp keeps ``log(0) == -inf`` from reaching the loss value.
        """
        return torch.clamp_min(torch.log(x), -100)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the reduced, label-weighted binary cross-entropy.

        Args:
            input: Predicted probabilities; the logarithm is taken of both
                ``input`` and ``1 - input``.
            target: Target values combined element-wise with ``input``.

        Returns:
            The weighted element-wise losses reduced by the configured
            ``reduction``.
        """
        positive_term = target * self._log(input)
        negative_term = (1 - target) * self._log(1 - input)
        losses = -self.label_weights * (positive_term + negative_term)
        return self.reduction(losses)
# TODO: Implement one-hot encoding.
def calculate_metrics(pred, target, threshold=0.5, prefix="", multi_label=True) -> dict[str, torch.Tensor]:
    """Score predictions against targets with precision, recall, F1 and accuracy.

    Args:
        pred: Prediction tensor; detached and moved to CPU before scoring.
        target: Target tensor; detached and moved to CPU before scoring.
        threshold: In multi-label mode, predictions strictly greater than
            this value count as positive.
        prefix: String prepended to every key of the returned dict.
        multi_label: When true, targets are used as-is and predictions are
            thresholded. Otherwise both are reduced with ``argmax(1)``.

    Returns:
        Dict with the keys ``precision``, ``recall``, ``f1`` and ``accuracy``
        (each prefixed with ``prefix``), holding float32 scalar tensors.
        Precision, recall and F1 are macro-averaged with ``zero_division=0``.
    """
    target_np = target.detach().cpu().numpy()
    pred_np = pred.detach().cpu().numpy()

    if multi_label:
        y_true = target_np
        y_pred = (pred_np > threshold).astype(float)
    else:
        # Single-label mode: compare the index of the largest entry along axis 1.
        y_true = target_np.argmax(1)
        y_pred = pred_np.argmax(1)

    macro_kwargs = {
        "y_true": y_true,
        "y_pred": y_pred,
        "zero_division": 0,
        "average": "macro",
    }
    scores = {
        "precision": precision_score(**macro_kwargs),
        "recall": recall_score(**macro_kwargs),
        "f1": f1_score(**macro_kwargs),
        "accuracy": accuracy_score(y_true=y_true, y_pred=y_pred),
    }

    named_tensors = {}
    for name, score in scores.items():
        named_tensors[prefix + name] = torch.tensor(score, dtype=torch.float32)
    return named_tensors
class EarlyStopping:
    """Signal a stop once a measure fails to decrease for too many steps in a row.

    Each value is compared with the value from the previous ``step`` call
    (initially ``np.inf``), not with the best value seen so far.
    """

    def __init__(self, patience=0):
        # Number of consecutive non-decreasing steps tolerated before stopping.
        self.patience = patience
        # Value passed to the previous step() call.
        self.last_measure = np.inf
        # Length of the current run of non-decreasing steps.
        self.consecutive_increase = 0

    def step(self, val) -> bool:
        """Record ``val`` and report whether training should stop.

        A value equal to or greater than the previous one extends the current
        run; any other value resets it to zero.

        Returns:
            True when the run length exceeds ``patience``.
        """
        did_not_decrease = self.last_measure <= val
        self.last_measure = val
        self.consecutive_increase = (
            self.consecutive_increase + 1 if did_not_decrease else 0
        )
        return self.consecutive_increase > self.patience