File size: 2,424 Bytes
d7aea57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from typing import List

import pytorch_lightning as pl
import torch
from sacred import Ingredient
from torch import nn
from torchvision import transforms

from transformers import AdamW, AutoImageProcessor, AutoModel, BitImageProcessor

# Sacred ingredient named "siglip". The config function and the
# `.capture`-decorated callables in this module are registered on it.
siglip_ingredient = Ingredient("siglip", save_git_info=False)


# pylint: disable=unused-variable
@siglip_ingredient.config
def config():
    """Default configuration values for the "siglip" ingredient.

    The local variable names are the configuration keys, so they must not
    be renamed without updating the captured functions that consume them.
    """
    # Hugging Face checkpoint identifier passed to `AutoModel.from_pretrained`.
    hf_ckpt = "google/siglip-base-patch16-224"
    # Lightning checkpoint path consumed by `get_siglip`; an empty value makes
    # `get_siglip` build a fresh classifier instead of loading one.
    model_path = "./models/siglip.ckpt"
    # NOTE(review): not referenced anywhere in the visible part of this module
    # - confirm whether it is still needed.
    ckpt = ""

    # Learning rate handed to the optimizer in `configure_optimizers`.
    learning_rate = 1e-5


class SiglipClassifier(pl.LightningModule):
    """Image classifier built from a pretrained vision encoder and a linear head.

    The encoder's pooled output is fed to a 3-way linear layer and the model
    is trained with cross-entropy loss.
    """

    @siglip_ingredient.capture
    def __init__(self, hf_ckpt: str):
        """Load the vision encoder and create the classification head.

        Args:
            hf_ckpt: Checkpoint identifier given to
                `AutoModel.from_pretrained`; supplied by the "siglip"
                ingredient configuration when not passed explicitly.
        """
        super().__init__()

        # Only the vision tower of the pretrained model is kept.
        self.vision_model = AutoModel.from_pretrained(hf_ckpt).base_model.vision_model
        # NOTE(review): 768 presumably matches the pooled feature width of the
        # default checkpoint - confirm before switching `hf_ckpt`.
        self.classifier = nn.Linear(768, 3)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for a batch of preprocessed images `x`."""
        features = self.vision_model(x).pooler_output
        return self.classifier(features)

    def _compute_loss(self, batch):
        """Unpack an `(images, labels)` batch and return its cross-entropy loss."""
        images, labels = batch
        logits = self.forward(images)
        return self.criterion(logits, labels)

    # pylint: disable=arguments-differ
    def training_step(self, batch):
        """Compute the training loss for `batch` and log it as "loss"."""
        loss = self._compute_loss(batch)

        self.log("loss", loss.item(), prog_bar=True)

        return loss

    # pylint: disable=unused-argument
    def validation_step(self, batch, batch_id):
        """Compute the validation loss for `batch` and log it as "val_loss"."""
        loss = self._compute_loss(batch)
        # Previously logged under "test_loss"; the names were swapped with
        # `test_step`, which mislabelled both metrics.
        self.log("val_loss", loss)

        return loss

    # pylint: disable=unused-argument
    def test_step(self, batch, batch_id):
        """Compute the test loss for `batch` and log it as "test_loss"."""
        loss = self._compute_loss(batch)
        # Previously logged under "val_loss"; see `validation_step`.
        self.log("test_loss", loss)

        return loss

    # pylint: disable=arguments-differ
    @siglip_ingredient.capture
    def configure_optimizers(self, learning_rate):
        """Return an AdamW optimizer over all parameters.

        Args:
            learning_rate: Step size for the optimizer; supplied by the
                "siglip" ingredient configuration when not passed explicitly.
        """
        return torch.optim.AdamW(self.parameters(), lr=learning_rate)


@siglip_ingredient.capture
def get_siglip(model_path: str):
    """Return a `SiglipClassifier`, restored from `model_path` when one is given.

    Args:
        model_path: Lightning checkpoint path. A falsy value (e.g. an empty
            string) yields a freshly constructed classifier instead.
    """
    if not model_path:
        # No checkpoint requested: build the model from its configured defaults.
        return SiglipClassifier()

    return SiglipClassifier.load_from_checkpoint(model_path)


@siglip_ingredient.capture
def get_siglip_preprocessor(hf_ckpt: str):
    """Return the image preprocessing pipeline used with the classifier.

    The pipeline resizes to 224x224, converts to a tensor, and normalizes
    each of the three channels with mean 0.5 and std 0.5.

    Args:
        hf_ckpt: Accepted for interface compatibility; not read by this
            function.
    """
    resize = transforms.Resize((224, 224))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    return transforms.Compose([resize, to_tensor, normalize])