import importlib
from abc import ABC, abstractmethod

import pytorch_lightning as pl
import torch
import torch.nn as nn

from models.utils import calculate_metrics


class TrainingEnvironment(pl.LightningModule):
    """Lightning module that wires a model and loss criterion to config-driven logging and optimization."""

    def __init__(
        self,
        model: nn.Module,
        criterion: nn.Module,
        config: dict,
        learning_rate=1e-4,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.learning_rate = config["training_environment"].get(
            "learning_rate", learning_rate
        )
        self.experiment_loggers = load_loggers(
            config["training_environment"].get("loggers", {})
        )
        self.config = config
        # CrossEntropyLoss implies single-label targets; any other criterion is treated as multi-label
        self.has_multi_label_predictions = (
            type(criterion).__name__ != "CrossEntropyLoss"
        )
        self.save_hyperparameters(
            {
                "model": type(model).__name__,
                "loss": type(criterion).__name__,
                "config": config,
                **kwargs,
            }
        )

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ) -> torch.Tensor:
        features, labels = batch
        outputs = self.model(features)
        loss = self.criterion(outputs, labels)
        metrics = calculate_metrics(
            outputs,
            labels,
            prefix="train/",
            multi_label=self.has_multi_label_predictions,
        )
        self.log_dict(metrics, prog_bar=True)
        experiment = self.logger.experiment
        for logger in self.experiment_loggers:
            logger.step(experiment, batch_index, features, labels)
        return loss

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int
    ):
        x, y = batch
        preds = self.model(x)
        metrics = calculate_metrics(
            preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
        )
        metrics["val/loss"] = self.criterion(preds, y)
        self.log_dict(metrics, prog_bar=True, sync_dist=True)

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int):
        x, y = batch
        preds = self.model(x)
        self.log_dict(
            calculate_metrics(
                preds, y, prefix="test/", multi_label=self.has_multi_label_predictions
            ),
            prog_bar=True,
        )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val/loss",
        }


class ExperimentLogger(ABC):
    """Per-training-step hook for logging extra artifacts to the experiment backend."""

    @abstractmethod
    def step(self, experiment, batch_index, x, label):
        pass


class SpectrogramLogger(ExperimentLogger):
    """Logs one randomly chosen, min-max normalized spectrogram from the batch every `frequency` steps."""

    def __init__(self, frequency=100) -> None:
        self.frequency = frequency
        self.counter = 0

    def step(self, experiment, batch_index, x, label):
        if self.counter == self.frequency:
            self.counter = 0
            # Pick a random sample from the batch and min-max normalize its first channel to [0, 1]
            img_index = torch.randint(0, len(x), (1,)).item()
            img = x[img_index][0]
            img = (img - img.min()) / (img.max() - img.min())
            experiment.add_image(
                f"batch: {batch_index}, element: {img_index}", img, 0, dataformats="HW"
            )
        self.counter += 1


def load_loggers(logger_config: dict) -> list[ExperimentLogger]:
    """Instantiates loggers from a mapping of dotted import paths to keyword-argument dicts."""
    loggers = []
    for logger_path, kwargs in logger_config.items():
        module_name, class_name = logger_path.rsplit(".", 1)
        module = importlib.import_module(module_name)
        Logger = getattr(module, class_name)
        loggers.append(Logger(**kwargs))
    return loggers
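

# Illustrative usage sketch (not part of the original module). The config shape is
# inferred from how TrainingEnvironment.__init__ and load_loggers read it: the
# "loggers" mapping goes from a dotted import path of an ExperimentLogger subclass
# to the keyword arguments passed to its constructor. The dotted path in the comment
# below is a placeholder; replace it with SpectrogramLogger's real module path.
if __name__ == "__main__":
    example_config = {
        "training_environment": {
            "learning_rate": 1e-3,
            # e.g. "my_project.training_environment.SpectrogramLogger": {"frequency": 50}
            "loggers": {},
        }
    }
    environment = TrainingEnvironment(
        model=nn.Linear(128, 10),
        criterion=nn.CrossEntropyLoss(),
        config=example_config,
    )
    print(environment.learning_rate)  # -> 0.001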