File size: 3,823 Bytes
6ad6801
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# custom_early_stopping.py

import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping


class MultiMetricEarlyStopping(EarlyStopping):
    """Early-stopping callback that watches two validation metrics at once.

    Each metric keeps its own best value and its own counter of consecutive
    validation epochs without improvement. An improvement in a metric resets
    that metric's counter only. Training is stopped once *both* counters have
    reached ``patience``.

    Args:
        monitor_mood: Key of the first metric in ``trainer.callback_metrics``.
        monitor_va: Key of the second metric in ``trainer.callback_metrics``.
        patience: Number of consecutive validation epochs without improvement
            that each metric must accumulate before training is stopped.
        min_delta: Minimum change in a metric that counts as an improvement.
        mode: ``"min"`` if lower values are better, ``"max"`` if higher values
            are better. The same mode is applied to both metrics.
    """

    def __init__(
        self,
        monitor_mood: str,
        monitor_va: str,
        patience: int,
        min_delta: float,
        mode: str = "min",
    ) -> None:
        # NOTE(review): the base class is constructed with monitor=None;
        # verify that the inherited EarlyStopping hooks do not try to look up
        # that key in the installed Lightning version.
        super().__init__(monitor=None, patience=patience, min_delta=min_delta, mode=mode)
        self.monitor_mood = monitor_mood
        self.monitor_va = monitor_va
        # Stored explicitly so the checks below use the caller's raw values,
        # independent of how the base class stores them.
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode

        # Per-metric tracking state: epochs without improvement and best value.
        initial_best = float('inf') if mode == "min" else -float('inf')
        self.wait_mood = 0
        self.wait_va = 0
        self.best_mood = initial_best
        self.best_va = initial_best

    def _check_stop(self, current: float, best: float, wait: int) -> "tuple[float, int]":
        """Update the best value and wait counter for a single metric.

        Args:
            current: Metric value from the latest validation epoch.
            best: Best value seen so far for this metric.
            wait: Consecutive validation epochs without improvement so far.

        Returns:
            ``(new_best, new_wait)``: ``(current, 0)`` if ``current`` beats
            ``best`` by more than ``min_delta`` in the configured direction,
            otherwise ``(best, wait + 1)``.
        """
        if self.mode == "min":
            improved = current < best - self.min_delta
        elif self.mode == "max":
            improved = current > best + self.min_delta
        else:
            # Unknown mode: never counts as an improvement.
            improved = False

        if improved:
            return current, 0
        return best, wait + 1

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Update both metrics' state and request a stop if both have stalled.

        Raises:
            RuntimeError: If either monitored key is missing from
                ``trainer.callback_metrics``.
        """
        # The sanity-check validation pass happens before training starts;
        # its metrics must not count as an epoch (and may be missing).
        if getattr(trainer, "sanity_checking", False):
            return

        logs = trainer.callback_metrics

        if self.monitor_mood not in logs or self.monitor_va not in logs:
            raise RuntimeError(f"Metrics {self.monitor_mood} or {self.monitor_va} not available.")

        # Get current values for the monitored metrics
        current_mood = logs[self.monitor_mood].item()
        current_va = logs[self.monitor_va].item()

        # Check stopping conditions for both metrics
        self.best_mood, self.wait_mood = self._check_stop(current_mood, self.best_mood, self.wait_mood)
        self.best_va, self.wait_va = self._check_stop(current_va, self.best_va, self.wait_va)

        # Stop once BOTH metrics have gone `patience` epochs without improving.
        # `>=` (not `>`) so that patience=N means exactly N stagnant epochs.
        if self.wait_mood >= self.patience and self.wait_va >= self.patience:
            self.stopped_epoch = trainer.current_epoch
            trainer.should_stop = True

# --- Alternative implementation below is commented out and is not executed ---
# # custom_early_stopping.py

# import pytorch_lightning as pl
# from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# class MultiMetricEarlyStopping(EarlyStopping):
#     def __init__(self, monitor_mood: str, monitor_va: str, patience: int = 10, min_delta: float = 0.0, mode: str = "min"):
#         super().__init__(monitor=None, patience=patience, min_delta=min_delta, mode=mode)
#         self.monitor_mood = monitor_mood
#         self.monitor_va = monitor_va
#         self.wait_mood = 0
#         self.wait_va = 0
#         self.best_mood_score = None
#         self.best_va_score = None
#         self.patience = patience
#         self.stopped_epoch = 0

#     def on_validation_end(self, trainer, pl_module):
#         current_mood = trainer.callback_metrics.get(self.monitor_mood)
#         current_va = trainer.callback_metrics.get(self.monitor_va)

#         # Check if current_mood improved
#         if self.best_mood_score is None or self._compare(current_mood, self.best_mood_score):
#             self.best_mood_score = current_mood
#             self.wait_mood = 0
#         else:
#             self.wait_mood += 1

#         # Check if current_va improved
#         if self.best_va_score is None or self._compare(current_va, self.best_va_score):
#             self.best_va_score = current_va
#             self.wait_va = 0
#         else:
#             self.wait_va += 1

#         # If both metrics are stagnant for patience epochs, stop training
#         if self.wait_mood >= self.patience and self.wait_va >= self.patience:
#             self.stopped_epoch = trainer.current_epoch
#             trainer.should_stop = True

#     def _compare(self, current, best):
#         if self.mode == "min":
#             return current < best - self.min_delta
#         else:
#             return current > best + self.min_delta