from typing import List

import pytorch_lightning as pl
import torch
from sacred import Ingredient
from torch import nn
from torchvision import transforms
from transformers import AdamW, AutoImageProcessor, AutoModel, BitImageProcessor

# NOTE(review): List, AdamW, AutoImageProcessor and BitImageProcessor are not
# used in this module; transformers.AdamW is deprecated upstream in favour of
# torch.optim.AdamW (which configure_optimizers already uses). Kept in case
# other code imports them from here — confirm before removing.

siglip_ingredient = Ingredient("siglip", save_git_info=False)


# pylint: disable=unused-variable
@siglip_ingredient.config
def config():
    # Hugging Face checkpoint the vision backbone is loaded from.
    hf_ckpt = "google/siglip-base-patch16-224"
    # Lightning checkpoint restored by get_siglip; "" means build a fresh model.
    model_path = "./models/siglip.ckpt"
    # Not read by any code in this module (may be used elsewhere — confirm).
    ckpt = ""
    # Learning rate passed to torch.optim.AdamW in configure_optimizers.
    learning_rate = 1e-5


class SiglipClassifier(pl.LightningModule):
    """Image classifier: SigLIP vision tower followed by a linear head.

    The backbone comes from the Hugging Face checkpoint named by the
    ``siglip.hf_ckpt`` config entry. Its pooled output is mapped to
    ``num_classes`` logits and trained with cross-entropy loss.
    """

    @siglip_ingredient.capture
    def __init__(self, hf_ckpt: str, num_classes: int = 3):
        """Build the backbone and the classification head.

        Args:
            hf_ckpt: Hugging Face model id or path; injected from the sacred
                config when not passed explicitly.
            num_classes: Number of output logits. Defaults to 3, the value
                that was previously hard-coded.
        """
        super().__init__()
        self.vision_model = AutoModel.from_pretrained(hf_ckpt).base_model.vision_model
        # Size the head from the backbone rather than hard-coding 768, so
        # checkpoints with a different hidden size also work.
        hidden_size = self.vision_model.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for a batch of preprocessed images.

        ``x`` is presumably a (batch, 3, H, W) float tensor as produced by
        get_siglip_preprocessor — TODO confirm. The returned tensor has
        ``num_classes`` entries along its last dimension.
        """
        features = self.vision_model(x).pooler_output
        return self.classifier(features)

    def _compute_loss(self, batch):
        """Unpack an ``(images, labels)`` batch and return its cross-entropy loss."""
        images, labels = batch
        logits = self.forward(images)
        return self.criterion(logits, labels)

    # pylint: disable=arguments-differ
    def training_step(self, batch):
        """Compute the training loss for one batch and log it as ``loss``."""
        loss = self._compute_loss(batch)
        # Lightning accepts tensors directly; this avoids an explicit
        # .item() device-to-host sync inside the step.
        self.log("loss", loss, prog_bar=True)
        return loss

    # pylint: disable=unused-argument
    def validation_step(self, batch, batch_id):
        """Compute the validation loss for one batch and log it as ``val_loss``."""
        loss = self._compute_loss(batch)
        # The key must name the validation stage: callbacks that monitor
        # "val_loss" read it from here.
        self.log("val_loss", loss)
        return loss

    # pylint: disable=unused-argument
    def test_step(self, batch, batch_id):
        """Compute the test loss for one batch and log it as ``test_loss``."""
        loss = self._compute_loss(batch)
        self.log("test_loss", loss)
        return loss

    # pylint: disable=arguments-differ
    @siglip_ingredient.capture
    def configure_optimizers(self, learning_rate):
        """Return an AdamW optimizer over all parameters.

        ``learning_rate`` is injected from the sacred config.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=learning_rate)
        return optimizer


@siglip_ingredient.capture
def get_siglip(model_path: str):
    """Return a SiglipClassifier, restored from ``model_path`` when it is non-empty.

    An empty ``model_path`` builds a new model: the backbone has pretrained
    Hugging Face weights and the head is freshly initialised.
    """
    if model_path:
        return SiglipClassifier.load_from_checkpoint(model_path)
    return SiglipClassifier()


# pylint: disable=unused-argument
@siglip_ingredient.capture
def get_siglip_preprocessor(hf_ckpt: str):
    """Return the torchvision transform that prepares images for the model.

    Each image is resized to 224x224 and converted to a tensor. Every
    channel is then normalised with mean 0.5 and std 0.5.

    ``hf_ckpt`` is injected by sacred but is not used; the values above are
    hard-coded. NOTE(review): they presumably mirror the default SigLIP
    base-224 processor. Confirm they still match if ``hf_ckpt`` changes.
    """
    return transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )