# NOTE(review): stray UI status text ("Spaces:" / "Running") from a copy/paste
# or extraction artifact — commented out so the module parses; not code.
from typing import List

import pytorch_lightning as pl
import torch
from sacred import Ingredient
from torch import nn
from torchvision import transforms
from transformers import AdamW, AutoImageProcessor, AutoModel, BitImageProcessor

# Sacred ingredient that groups the SigLIP config and helper functions below.
siglip_ingredient = Ingredient("siglip", save_git_info=False)
# pylint: disable=unused-variable
# NOTE(review): the @siglip_ingredient.config decorator appears to have been
# lost in transit — without it this function is dead code, while the sacred
# pattern (Ingredient above + bare assignments + the unused-variable disable)
# clearly expects it. Restored here; confirm against the original file.
@siglip_ingredient.config
def config():
    """Default sacred configuration for the siglip ingredient."""
    hf_ckpt = "google/siglip-base-patch16-224"  # HuggingFace checkpoint id
    model_path = "./models/siglip.ckpt"         # default Lightning checkpoint
    ckpt = ""                                   # optional override checkpoint
    learning_rate = 1e-5
class SiglipClassifier(pl.LightningModule):
    """3-way linear classifier on top of a SigLIP vision encoder.

    Loads the vision tower of a SigLIP dual encoder from a HuggingFace
    checkpoint and maps its pooled output (768-dim for the base model)
    to 3 classes, trained with cross-entropy.
    """

    def __init__(self, hf_ckpt: str):
        super().__init__()
        # BUG FIX: persist constructor args so load_from_checkpoint() can
        # rebuild the model without the caller re-supplying `hf_ckpt`
        # (get_siglip() relies on this).
        self.save_hyperparameters()
        # Keep only the vision tower of the SigLIP model.
        self.vision_model = AutoModel.from_pretrained(hf_ckpt).base_model.vision_model
        # 768 = hidden size of siglip-base; 3 target classes.
        self.classifier = nn.Linear(768, 3)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        # pooler_output is the per-image summary vector; assumes x is a
        # (batch, 3, 224, 224) pixel tensor — TODO confirm with preprocessor.
        features = self.vision_model(x).pooler_output
        return self.classifier(features)

    # pylint: disable=arguments-differ
    def training_step(self, batch):
        images, labels = batch
        loss = self.criterion(self.forward(images), labels)
        self.log("loss", loss.item(), prog_bar=True)
        return loss

    # pylint: disable=unused-argument
    def validation_step(self, batch, batch_id):
        images, labels = batch
        loss = self.criterion(self.forward(images), labels)
        # BUG FIX: was logged under "test_loss" — the val/test keys were swapped.
        self.log("val_loss", loss)
        return loss

    # pylint: disable=unused-argument
    def test_step(self, batch, batch_id):
        images, labels = batch
        loss = self.criterion(self.forward(images), labels)
        # BUG FIX: was logged under "val_loss" — the val/test keys were swapped.
        self.log("test_loss", loss)
        return loss

    # pylint: disable=arguments-differ
    def configure_optimizers(self, learning_rate: float = 1e-5):
        # BUG FIX: Lightning invokes configure_optimizers() with no arguments,
        # so a required `learning_rate` parameter raised TypeError at train
        # start. A default (matching config()) keeps both call styles working.
        return torch.optim.AdamW(self.parameters(), lr=learning_rate)
def get_siglip(model_path: str, hf_ckpt: str = "google/siglip-base-patch16-224"):
    """Load a SiglipClassifier from a Lightning checkpoint or build a fresh one.

    Args:
        model_path: path to a Lightning ``.ckpt`` to restore; a falsy value
            ("" / None) builds an untrained model instead.
        hf_ckpt: HuggingFace checkpoint used when building a fresh model.

    Returns:
        A ``SiglipClassifier`` instance.
    """
    if model_path:
        return SiglipClassifier.load_from_checkpoint(model_path)
    # BUG FIX: the original called SiglipClassifier() without its required
    # `hf_ckpt` argument, which raised TypeError on the fresh-model path.
    return SiglipClassifier(hf_ckpt)
def get_siglip_preprocessor(hf_ckpt: str):
    """Return the eval-time image pipeline: resize to 224x224, convert to a
    tensor, and normalize every channel to [-1, 1] with mean/std 0.5.

    ``hf_ckpt`` is accepted for interface symmetry with the other helpers but
    is not consulted; the pipeline is hard-coded for 224x224 SigLIP inputs.
    """
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    return transforms.Compose(steps)