Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
import torch | |
import numpy as np | |
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score | |
class LabelWeightedBCELoss(nn.Module): | |
""" | |
Binary Cross Entropy loss that assumes each float in the final dimension is a binary probability distribution. | |
Allows for the weighing of each probability distribution wrt loss. | |
""" | |
def __init__(self, label_weights:torch.Tensor, reduction="mean"): | |
super().__init__() | |
self.label_weights = label_weights | |
match reduction: | |
case "mean": | |
self.reduction = torch.mean | |
case "sum": | |
self.reduction = torch.sum | |
def _log(self,x:torch.Tensor) -> torch.Tensor: | |
return torch.clamp_min(torch.log(x), -100) | |
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
losses = -self.label_weights * (target * self._log(input) + (1-target) * self._log(1-input)) | |
return self.reduction(losses) | |
# TODO: Code a onehot | |
def calculate_metrics(pred, target, threshold=0.5, prefix="", multi_label=True) -> dict[str, torch.Tensor]:
    """Compute macro-averaged precision, recall and F1, plus accuracy.

    Args:
        pred: Tensor of model scores. In multi-label mode every entry is
            binarised with ``pred > threshold``; otherwise the predicted class
            is ``pred.argmax(1)``.
        target: Tensor of ground-truth labels with the same layout as ``pred``.
            In single-label mode it is reduced with ``argmax(1)``, so it is
            presumably one-hot encoded -- confirm with callers.
        threshold: Decision threshold; only used when ``multi_label`` is True.
        prefix: String prepended to every key of the returned dict.
        multi_label: Choose thresholded multi-label evaluation (True) or
            argmax single-label evaluation (False).

    Returns:
        Dict mapping ``prefix + name`` (name in precision, recall, f1, accuracy)
        to a float32 scalar tensor.
    """
    true_np = target.detach().cpu().numpy()
    score_np = pred.detach().cpu().numpy()

    if multi_label:
        y_true = true_np
        y_pred = np.array(score_np > threshold, dtype=float)
    else:
        y_true = true_np.argmax(1)
        y_pred = score_np.argmax(1)

    # Options shared by the averaged scorers; zero_division=0 scores
    # ill-defined per-class values as 0.
    averaged_kwargs = dict(y_true=y_true, y_pred=y_pred, zero_division=0, average="macro")

    scores = {
        "precision": precision_score(**averaged_kwargs),
        "recall": recall_score(**averaged_kwargs),
        "f1": f1_score(**averaged_kwargs),
        # accuracy_score takes neither an averaging mode nor zero_division.
        "accuracy": accuracy_score(y_true=y_true, y_pred=y_pred),
    }
    return {prefix + name: torch.tensor(score, dtype=torch.float32) for name, score in scores.items()}
class EarlyStopping:
    """Signal when a monitored value (lower is better) has stopped improving.

    Call ``step`` once per evaluation. It returns True once more than
    ``patience`` consecutive calls have failed to improve on the best value
    seen so far. Ties count as a failure to improve.
    """

    def __init__(self, patience=0):
        """
        Args:
            patience: Number of consecutive non-improving steps tolerated before
                ``step`` returns True. With the default of 0, the first
                non-improving step already triggers a stop.
        """
        self.patience = patience
        # Despite its name, this holds the best (lowest) value seen so far; it
        # is only updated when a step improves on it.
        self.last_measure = np.inf
        # Number of consecutive steps that did not improve on last_measure.
        self.consecutive_increase = 0

    def step(self, val) -> bool:
        """Record a new measurement and report whether to stop.

        Args:
            val: Latest monitored value; lower is better.

        Returns:
            True when the count of consecutive non-improving steps exceeds
            ``patience``.
        """
        # Written as ``not (val < best)`` instead of ``best <= val``: the two
        # agree for ordinary numbers, but a NaN must count as a non-improvement.
        # With ``best <= val`` a NaN compares False, resets the counter and
        # becomes the new best, after which stopping could never trigger.
        if not (val < self.last_measure):
            self.consecutive_increase += 1
        else:
            self.consecutive_increase = 0
            self.last_measure = val
        return self.patience < self.consecutive_increase