File size: 2,732 Bytes
4b8361a
 
 
 
 
557fb53
4b8361a
0030bc6
 
 
 
557fb53
 
4b8361a
 
 
 
 
 
 
 
557fb53
 
4b8361a
 
 
557fb53
 
 
4b8361a
 
 
0030bc6
 
 
557fb53
 
 
4b8361a
1c22425
 
 
0030bc6
557fb53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0030bc6
 
 
 
 
 
557fb53
0030bc6
 
557fb53
0030bc6
 
 
 
557fb53
 
 
 
 
 
 
 
 
 
 
 
e748bc2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch.nn as nn
import torch
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score


class LabelWeightedBCELoss(nn.Module):
    """
    Binary cross-entropy loss with a per-label weight.

    Every float in the final dimension of ``input`` is treated as an
    independent binary probability. The element-wise loss is multiplied by
    ``label_weights`` before the reduction is applied.

    Args:
        label_weights: Weights multiplied into the element-wise loss. Must
            broadcast against the input and target tensors.
        reduction: ``"mean"`` or ``"sum"``; selects the function applied to
            the weighted element-wise losses.

    Raises:
        ValueError: If ``reduction`` is not one of the supported names.
    """

    def __init__(self, label_weights: torch.Tensor, reduction="mean"):
        super().__init__()
        # Registered as a buffer so the weights follow the module through
        # .to()/.cuda()/.double(). Non-persistent keeps them out of
        # state_dict, so previously saved checkpoints still load unchanged.
        self.register_buffer(
            "label_weights", torch.as_tensor(label_weights), persistent=False
        )

        reducers = {"mean": torch.mean, "sum": torch.sum}
        if reduction not in reducers:
            # Fail at construction time instead of leaving self.reduction
            # unset and hitting an AttributeError on the first forward pass.
            raise ValueError(
                f"Unsupported reduction {reduction!r}; expected one of "
                f"{sorted(reducers)}"
            )
        self.reduction = reducers[reduction]

    def _log(self, x: torch.Tensor) -> torch.Tensor:
        """Return log(x) clamped from below at -100 so log(0) stays finite."""
        return torch.clamp_min(torch.log(x), -100)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the weighted binary cross-entropy.

        Args:
            input: Predicted probabilities (the log is taken directly, so
                values are expected to lie in [0, 1]).
            target: Target values with a shape that broadcasts with ``input``.

        Returns:
            The reduced (mean or sum) weighted loss.
        """
        losses = -self.label_weights * (
            target * self._log(input) + (1 - target) * self._log(1 - input)
        )
        return self.reduction(losses)


# TODO: Implement a one-hot encoding helper.


def calculate_metrics(
    pred, target, threshold=0.5, prefix="", multi_label=True
) -> dict[str, torch.Tensor]:
    """
    Compute macro precision, recall, F1 and accuracy for a batch.

    ``pred`` is detached, moved to CPU and passed through a softmax over
    dim 1 before being compared with ``target``.

    Args:
        pred: Model output tensor; softmax is applied along dim 1.
        target: Ground-truth tensor.
        threshold: Cut-off applied to the softmax output when ``multi_label``
            is true.
        prefix: String prepended to every key of the returned dict.
        multi_label: When true, predictions are the thresholded softmax
            output and ``target`` is used as-is. Otherwise both predictions
            and targets are reduced with ``argmax(1)``.

    Returns:
        Dict mapping ``prefix + name`` to a float32 scalar tensor for
        ``precision``, ``recall``, ``f1`` and ``accuracy``.
    """
    y_true = target.detach().cpu().numpy()
    # NOTE(review): softmax is applied even when multi_label is True. Softmax
    # rows sum to 1, so with the default 0.5 threshold at most one entry per
    # row can be positive -- confirm this is intended.
    probabilities = nn.functional.softmax(pred.detach().cpu(), dim=1).numpy()

    if multi_label:
        y_pred = (probabilities > threshold).astype(float)
    else:
        y_true = y_true.argmax(1)
        y_pred = probabilities.argmax(1)

    shared_kwargs = dict(
        y_true=y_true, y_pred=y_pred, zero_division=0, average="macro"
    )
    scores = {
        "precision": precision_score(**shared_kwargs),
        "recall": recall_score(**shared_kwargs),
        "f1": f1_score(**shared_kwargs),
        "accuracy": accuracy_score(y_true=y_true, y_pred=y_pred),
    }

    result = {}
    for name, value in scores.items():
        result[prefix + name] = torch.tensor(value, dtype=torch.float32)
    return result


class EarlyStopping:
    """
    Signal when a monitored value has stopped decreasing.

    Each call to ``step`` compares the new value with the previously stored
    one (not with the best value seen so far). A value that is not strictly
    lower counts as an increase; a lower value resets the counter.

    Args:
        patience: Number of consecutive non-decreasing steps tolerated before
            ``step`` returns ``True``.
    """

    def __init__(self, patience=0):
        self.patience = patience
        # Start at +inf so the first finite value always counts as a decrease.
        self.last_measure = np.inf
        self.consecutive_increase = 0

    def step(self, val) -> bool:
        """
        Record a new measurement and report whether training should stop.

        Args:
            val: Latest value of the monitored measure (lower is better).

        Returns:
            ``True`` once more than ``patience`` consecutive steps have
            failed to decrease the measure.
        """
        # `val < last` is False for NaN, so a NaN measurement counts as a
        # non-improvement instead of silently resetting the counter.
        if val < self.last_measure:
            self.consecutive_increase = 0
        else:
            self.consecutive_increase += 1

        # Never store NaN as the reference: every later comparison against it
        # would be False and early stopping would be disabled for good.
        if val == val:
            self.last_measure = val

        return self.patience < self.consecutive_increase


def get_id_label_mapping(labels: list[str]) -> tuple[dict, dict]:
    """
    Build the two lookup tables between labels and their string indices.

    Args:
        labels: Label names; a label's position in the list is its id.

    Returns:
        ``(id2label, label2id)``, where ids are the stringified positions
        (``"0"``, ``"1"``, ...). If a label is repeated, ``label2id`` keeps
        the id of its last occurrence.
    """
    id2label = {}
    label2id = {}
    for position, name in enumerate(labels):
        key = str(position)
        id2label[key] = name
        label2id[name] = key

    return id2label, label2id


def compute_hf_metrics(eval_pred):
    """
    Return the accuracy of the arg-max predictions in ``eval_pred``.

    Args:
        eval_pred: Object exposing ``predictions`` (scores, arg-maxed along
            axis 1) and ``label_ids`` (the reference labels).

    Returns:
        The accuracy score as computed by ``accuracy_score``.
    """
    # NOTE(review): this returns a bare score rather than a dict. If it is
    # passed as `compute_metrics` to a Hugging Face Trainer, a dict of named
    # metrics is presumably expected -- verify against the caller.
    predicted_ids = np.argmax(eval_pred.predictions, axis=1)
    reference_ids = eval_pred.label_ids
    return accuracy_score(y_true=reference_ids, y_pred=predicted_ids)