File size: 2,254 Bytes
4b8361a
 
 
 
 
 
0030bc6
 
 
 
4b8361a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0030bc6
 
 
 
4b8361a
 
0030bc6
 
 
 
 
 
4b8361a
0030bc6
 
 
 
4b8361a
0030bc6
 
 
 
 
 
 
4b8361a
0030bc6
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch.nn as nn
import torch
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

class LabelWeightedBCELoss(nn.Module):
    """
    Binary Cross Entropy loss that assumes each float in the final dimension is a binary probability distribution.
    Allows for the weighing of each probability distribution wrt loss.
    """
    def __init__(self, label_weights:torch.Tensor, reduction="mean"):
        super().__init__()
        self.label_weights = label_weights

        match reduction:
            case "mean":
                self.reduction = torch.mean
            case "sum":
                self.reduction = torch.sum
    
    def _log(self,x:torch.Tensor) -> torch.Tensor:
        return torch.clamp_min(torch.log(x), -100)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        losses = -self.label_weights * (target * self._log(input) + (1-target) * self._log(1-input))
        return self.reduction(losses)


# TODO: Implement a one-hot encoding helper.


def calculate_metrics(pred, target, threshold=0.5, prefix="", multi_label=True) -> dict[str, torch.Tensor]:
    """
    Compute macro-averaged classification metrics for a batch of predictions.

    Args:
        pred: Model outputs as a torch tensor. In multi-label mode every entry is
            compared with ``threshold``; otherwise the argmax over dim 1 is taken
            as the predicted class (presumably shape (N, C) — confirm at callers).
        target: Ground truth as a torch tensor. Used as-is (an indicator matrix)
            in multi-label mode; otherwise reduced with argmax over dim 1, so
            presumably one-hot or probability rows.
        threshold: Decision threshold for multi-label mode (unused otherwise).
        prefix: String prepended to every key of the returned dict.
        multi_label: Chooses between the two modes described above.

    Returns:
        ``{prefix + name: float32 scalar tensor}`` for ``precision``, ``recall``
        and ``f1`` (``average="macro"``, ``zero_division=0``) plus ``accuracy``.
    """
    y_true_raw = target.detach().cpu().numpy()
    y_score_raw = pred.detach().cpu().numpy()

    if multi_label:
        y_true = y_true_raw
        # Binarise the scores into a 0.0/1.0 indicator matrix.
        y_pred = np.array(y_score_raw > threshold, dtype=float)
    else:
        y_true = y_true_raw.argmax(1)
        y_pred = y_score_raw.argmax(1)

    averaging = {"zero_division": 0, "average": "macro"}
    results = {}
    for name, scorer in (("precision", precision_score), ("recall", recall_score), ("f1", f1_score)):
        results[name] = scorer(y_true=y_true, y_pred=y_pred, **averaging)
    # accuracy_score takes no averaging options; for multi-label indicator input
    # scikit-learn reports subset (exact-match) accuracy.
    results["accuracy"] = accuracy_score(y_true=y_true, y_pred=y_pred)

    return {prefix + name: torch.tensor(value, dtype=torch.float32) for name, value in results.items()}

class EarlyStopping:
    """
    Signal when a monitored value has stopped decreasing.

    Each call to ``step`` compares the new value with the value passed on the
    *previous* call (not with the best value seen so far). A value greater than
    or equal to the previous one extends the current streak; a strictly smaller
    value resets it to zero. Stopping is signalled once the streak is longer
    than ``patience``.
    """

    def __init__(self, patience=0):
        # Number of consecutive non-decreasing steps tolerated before stopping.
        self.patience = patience
        # Starting from +inf means the first finite value never extends the streak.
        self.last_measure = np.inf
        self.consecutive_increase = 0

    def step(self, val) -> bool:
        """
        Record ``val`` and report whether training should stop.

        Returns:
            True once more than ``patience`` consecutive calls have each received
            a value >= the value from the call before. A NaN compares false and
            therefore resets the streak.
        """
        not_improved = self.last_measure <= val
        self.consecutive_increase = self.consecutive_increase + 1 if not_improved else 0
        self.last_measure = val
        return self.consecutive_increase > self.patience