# custom_early_stopping.py
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class MultiMetricEarlyStopping(Callback):
    """Early stopping on two metrics at once: training stops only after
    *both* monitored metrics have failed to improve for `patience` epochs.

    Subclasses the plain Callback rather than EarlyStopping: the parent's
    own validation hook runs a single-metric check and raises a RuntimeError
    when constructed with monitor=None, so it cannot be reused here.
    """

    def __init__(self, monitor_mood, monitor_va, patience, min_delta, mode="min"):
        super().__init__()
        self.monitor_mood = monitor_mood
        self.monitor_va = monitor_va
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        # Initialize tracking variables
        self.wait_mood = 0
        self.wait_va = 0
        self.stopped_epoch = 0
        self.best_mood = float("inf") if mode == "min" else -float("inf")
        self.best_va = float("inf") if mode == "min" else -float("inf")

    def _check_stop(self, current, best, wait):
        # On an improvement of more than min_delta, adopt the new best value
        # and reset the wait counter; otherwise keep the best and wait longer.
        if self.mode == "min" and current < best - self.min_delta:
            return current, 0
        elif self.mode == "max" and current > best + self.min_delta:
            return current, 0
        else:
            return best, wait + 1

    def on_validation_epoch_end(self, trainer, pl_module):
        # Skip the sanity-check pass so it does not count against patience.
        if trainer.sanity_checking:
            return
        logs = trainer.callback_metrics
        if self.monitor_mood not in logs or self.monitor_va not in logs:
            raise RuntimeError(
                f"Metrics {self.monitor_mood} or {self.monitor_va} not available."
            )
        # Get current values for the monitored metrics
        current_mood = logs[self.monitor_mood].item()
        current_va = logs[self.monitor_va].item()
        # Update the (best, wait) state for each metric independently
        self.best_mood, self.wait_mood = self._check_stop(current_mood, self.best_mood, self.wait_mood)
        self.best_va, self.wait_va = self._check_stop(current_va, self.best_va, self.wait_va)
        # Stop only once patience is exceeded for *both* metrics
        if self.wait_mood > self.patience and self.wait_va > self.patience:
            self.stopped_epoch = trainer.current_epoch
            trainer.should_stop = True


# # custom_early_stopping.py
# import pytorch_lightning as pl
# from pytorch_lightning.callbacks.early_stopping import EarlyStopping
#
# class MultiMetricEarlyStopping(EarlyStopping):
#     def __init__(self, monitor_mood: str, monitor_va: str, patience: int = 10, min_delta: float = 0.0, mode: str = "min"):
#         super().__init__(monitor=None, patience=patience, min_delta=min_delta, mode=mode)
#         self.monitor_mood = monitor_mood
#         self.monitor_va = monitor_va
#         self.wait_mood = 0
#         self.wait_va = 0
#         self.best_mood_score = None
#         self.best_va_score = None
#         self.patience = patience
#         self.stopped_epoch = 0
#
#     def on_validation_end(self, trainer, pl_module):
#         current_mood = trainer.callback_metrics.get(self.monitor_mood)
#         current_va = trainer.callback_metrics.get(self.monitor_va)
#         # Check if current_mood improved
#         if self.best_mood_score is None or self._compare(current_mood, self.best_mood_score):
#             self.best_mood_score = current_mood
#             self.wait_mood = 0
#         else:
#             self.wait_mood += 1
#         # Check if current_va improved
#         if self.best_va_score is None or self._compare(current_va, self.best_va_score):
#             self.best_va_score = current_va
#             self.wait_va = 0
#         else:
#             self.wait_va += 1
#         # If both metrics are stagnant for patience epochs, stop training
#         if self.wait_mood >= self.patience and self.wait_va >= self.patience:
#             self.stopped_epoch = trainer.current_epoch
#             trainer.should_stop = True
#
#     def _compare(self, current, best):
#         if self.mode == "min":
#             return current < best - self.min_delta
#         else:
#             return current > best + self.min_delta
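

# Example usage -- a minimal sketch, not part of the original file. The metric
# names "val_mood_loss" and "val_va_loss" below are placeholders: use whatever
# keys your LightningModule logs with self.log(..., on_epoch=True) during
# validation_step.
if __name__ == "__main__":
    early_stop = MultiMetricEarlyStopping(
        monitor_mood="val_mood_loss",  # placeholder metric key
        monitor_va="val_va_loss",      # placeholder metric key
        patience=10,
        min_delta=1e-4,
        mode="min",
    )
    trainer = pl.Trainer(max_epochs=100, callbacks=[early_stop])
    # trainer.fit(model)  # `model` is a hypothetical LightningModule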