File size: 1,793 Bytes
bd65e34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import torch.nn.functional as F
import lightning as L
import torchmetrics
import timm


class LightningWrapper(L.LightningModule):
    def __init__(self, timm_model: str, num_classes: int, learning_rate: float = 1e-3):
        super().__init__()
        self.timm_model = timm_model
        self.lr = learning_rate
        self.model = timm.create_model(
            self.timm_model, pretrained=True, num_classes=num_classes
        )
        self.save_hyperparameters(ignore=["model"])

        metrics = torchmetrics.MetricCollection(
            {
                "accuracy": torchmetrics.Accuracy(
                    task="multiclass", num_classes=self.model.num_classes
                )
            }
        )

        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")

    def get_transforms(self, is_training: bool):
        data_config = timm.data.resolve_model_data_config(self.timm_model)
        return timm.data.create_transform(**data_config, is_training=is_training)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss)
        self.train_metrics(logits, y)
        self.log_dict(self.train_metrics, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("val_loss", loss)
        self.val_metrics(logits, y)
        self.log_dict(self.val_metrics, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer