Spaces:
Sleeping
Sleeping
# custom_early_stopping.py | |
import pytorch_lightning as pl | |
from pytorch_lightning.callbacks.early_stopping import EarlyStopping | |
class MultiMetricEarlyStopping(EarlyStopping):
    """Stop training only when two validation metrics have both stopped improving.

    Each metric keeps its own best value and its own count of consecutive
    validation checks without a sufficient improvement. Training stops once
    *both* counts have reached ``patience``; as long as either metric keeps
    improving, training continues.

    Args:
        monitor_mood: Key in ``trainer.callback_metrics`` for the first metric.
        monitor_va: Key in ``trainer.callback_metrics`` for the second metric.
        patience: Number of consecutive checks without improvement after which
            a metric counts as stagnant (same meaning as Lightning's
            ``EarlyStopping(patience=...)``).
        min_delta: Minimum change over the best value that counts as an
            improvement; expected to be non-negative.
        mode: ``"min"`` if lower is better, ``"max"`` if higher is better.
            Applies to both metrics.
    """

    def __init__(self, monitor_mood, monitor_va, patience, min_delta, mode="min"):
        # There is no single metric for the parent class to watch. Its inherited
        # hooks still run a single-metric check that looks up ``monitor=None`` in
        # the logged metrics; ``strict=False`` makes that lookup a silent no-op
        # instead of a RuntimeError. All stopping logic lives in this subclass.
        super().__init__(
            monitor=None,
            patience=patience,
            min_delta=min_delta,
            mode=mode,
            strict=False,
        )
        self.monitor_mood = monitor_mood
        self.monitor_va = monitor_va
        self.patience = patience
        # Re-assign the raw value: Lightning's EarlyStopping stores a
        # sign-adjusted min_delta, whereas _check_stop applies the sign itself.
        self.min_delta = min_delta
        self.mode = mode
        # Per-metric tracking state, starting from the worst possible value.
        self.wait_mood = 0
        self.wait_va = 0
        worst = float("inf") if mode == "min" else -float("inf")
        self.best_mood = worst
        self.best_va = worst

    def _check_stop(self, current, best, wait):
        """Fold one new observation into a metric's ``(best, wait)`` pair.

        Returns ``(current, 0)`` when ``current`` beats ``best`` by more than
        ``min_delta`` in the configured direction, otherwise ``(best, wait + 1)``.
        """
        if self.mode == "min" and current < best - self.min_delta:
            return current, 0
        if self.mode == "max" and current > best + self.min_delta:
            return current, 0
        return best, wait + 1

    def on_validation_epoch_end(self, trainer, pl_module):
        """Update both metrics after a validation epoch; stop if both are stagnant.

        Raises:
            RuntimeError: If either monitored metric is missing from
                ``trainer.callback_metrics``.
        """
        # The sanity-check pass runs only a few validation batches before
        # training starts; it must not seed the best values or count as a check.
        if trainer.sanity_checking:
            return

        logs = trainer.callback_metrics
        if self.monitor_mood not in logs or self.monitor_va not in logs:
            raise RuntimeError(f"Metrics {self.monitor_mood} or {self.monitor_va} not available.")

        # float() accepts both one-element tensors and plain Python numbers.
        current_mood = float(logs[self.monitor_mood])
        current_va = float(logs[self.monitor_va])

        self.best_mood, self.wait_mood = self._check_stop(current_mood, self.best_mood, self.wait_mood)
        self.best_va, self.wait_va = self._check_stop(current_va, self.best_va, self.wait_va)

        # Stop after `patience` checks without improvement (>=, as in Lightning's
        # EarlyStopping). At least one stagnant check is required, so patience=0
        # never stops on an epoch where both metrics just improved.
        threshold = max(self.patience, 1)
        if self.wait_mood >= threshold and self.wait_va >= threshold:
            self.stopped_epoch = trainer.current_epoch
            trainer.should_stop = True

    def state_dict(self):
        """Return the parent's checkpoint state plus both metrics' tracking values."""
        state = super().state_dict()
        state.update(
            wait_mood=self.wait_mood,
            wait_va=self.wait_va,
            best_mood=self.best_mood,
            best_va=self.best_va,
        )
        return state

    def load_state_dict(self, state_dict):
        """Restore checkpoint state.

        Keys missing from checkpoints written before these fields existed keep
        their current values.
        """
        super().load_state_dict(state_dict)
        self.wait_mood = state_dict.get("wait_mood", self.wait_mood)
        self.wait_va = state_dict.get("wait_va", self.wait_va)
        self.best_mood = state_dict.get("best_mood", self.best_mood)
        self.best_va = state_dict.get("best_va", self.best_va)
# # custom_early_stopping.py | |
# import pytorch_lightning as pl | |
# from pytorch_lightning.callbacks.early_stopping import EarlyStopping | |
# class MultiMetricEarlyStopping(EarlyStopping): | |
# def __init__(self, monitor_mood: str, monitor_va: str, patience: int = 10, min_delta: float = 0.0, mode: str = "min"): | |
# super().__init__(monitor=None, patience=patience, min_delta=min_delta, mode=mode) | |
# self.monitor_mood = monitor_mood | |
# self.monitor_va = monitor_va | |
# self.wait_mood = 0 | |
# self.wait_va = 0 | |
# self.best_mood_score = None | |
# self.best_va_score = None | |
# self.patience = patience | |
# self.stopped_epoch = 0 | |
# def on_validation_end(self, trainer, pl_module): | |
# current_mood = trainer.callback_metrics.get(self.monitor_mood) | |
# current_va = trainer.callback_metrics.get(self.monitor_va) | |
# # Check if current_mood improved | |
# if self.best_mood_score is None or self._compare(current_mood, self.best_mood_score): | |
# self.best_mood_score = current_mood | |
# self.wait_mood = 0 | |
# else: | |
# self.wait_mood += 1 | |
# # Check if current_va improved | |
# if self.best_va_score is None or self._compare(current_va, self.best_va_score): | |
# self.best_va_score = current_va | |
# self.wait_va = 0 | |
# else: | |
# self.wait_va += 1 | |
# # If both metrics are stagnant for patience epochs, stop training | |
# if self.wait_mood >= self.patience and self.wait_va >= self.patience: | |
# self.stopped_epoch = trainer.current_epoch | |
# trainer.should_stop = True | |
# def _compare(self, current, best): | |
# if self.mode == "min": | |
# return current < best - self.min_delta | |
# else: | |
# return current > best + self.min_delta |