# nathbotbol's picture
# Upload folder using huggingface_hub
# d7aea57 verified
from typing import List
import pytorch_lightning as pl
import torch
from sacred import Ingredient
from torch import nn
from torchvision import transforms
from transformers import AdamW, AutoImageProcessor, AutoModel, BitImageProcessor
# Sacred ingredient named "siglip". `config()` below registers its config
# scope, and functions decorated with `@siglip_ingredient.capture` receive
# arguments they are not given explicitly (hf_ckpt, model_path,
# learning_rate) from that config. Git info is not saved (save_git_info=False).
siglip_ingredient = Ingredient("siglip", save_git_info=False)
# The sacred config function below only assigns locals (they become config
# entries), which pylint would otherwise report as unused variables.
# pylint: disable=unused-variable
@siglip_ingredient.config
def config():
    # Sacred config scope for the "siglip" ingredient: every local assigned
    # here becomes a config entry. Do not add helper locals, or they will
    # show up as config keys too.

    # Model id/path passed to AutoModel.from_pretrained in SiglipClassifier.
    hf_ckpt = "google/siglip-base-patch16-224"
    # Lightning checkpoint restored by get_siglip(); an empty value builds a fresh model.
    model_path = "./models/siglip.ckpt"
    # NOTE(review): not read anywhere in this file; confirm whether it is still used.
    ckpt = ""
    # Learning rate for the AdamW optimizer built in configure_optimizers().
    learning_rate = 1e-5
class SiglipClassifier(pl.LightningModule):
    """Image classifier: a SigLIP vision encoder followed by a linear head.

    The vision tower of a Hugging Face checkpoint produces a pooled feature
    vector per image, and a single ``nn.Linear`` maps it to class logits.
    Training, validation and test steps all use cross-entropy loss.
    """

    @siglip_ingredient.capture
    def __init__(self, hf_ckpt: str, num_classes: int = 3):
        """Load the vision encoder and create the classification head.

        Args:
            hf_ckpt: Model id/path passed to ``AutoModel.from_pretrained``.
                Injected from the ``siglip`` ingredient config when omitted.
            num_classes: Number of output logits. Defaults to 3, the value
                that was previously hard-coded.
        """
        super().__init__()
        # Keep only the vision tower of the pretrained model.
        self.vision_model = AutoModel.from_pretrained(hf_ckpt).base_model.vision_model
        # Size the head from the encoder's own config instead of hard-coding
        # 768. That matches the default base checkpoint, and reading it here
        # also supports checkpoints with a different hidden size.
        hidden_size = self.vision_model.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for a batch of preprocessed images ``x``."""
        features = self.vision_model(x).pooler_output
        return self.classifier(features)

    def _compute_loss(self, batch):
        """Unpack an ``(images, labels)`` batch and return its cross-entropy loss."""
        images, labels = batch
        logits = self(images)
        return self.criterion(logits, labels)

    # pylint: disable=arguments-differ
    def training_step(self, batch):
        """Compute, log (as ``loss``) and return the training loss for one batch."""
        loss = self._compute_loss(batch)
        # Log the tensor directly: calling ``.item()`` forces a device-to-host
        # sync on every training step.
        self.log("loss", loss, prog_bar=True)
        return loss

    # pylint: disable=unused-argument
    def validation_step(self, batch, batch_id):
        """Compute, log (as ``val_loss``) and return the validation loss."""
        loss = self._compute_loss(batch)
        # NOTE: this key and test_step's key used to be swapped. Callbacks
        # that monitored the old name must now use "val_loss".
        self.log("val_loss", loss)
        return loss

    # pylint: disable=unused-argument
    def test_step(self, batch, batch_id):
        """Compute, log (as ``test_loss``) and return the test loss."""
        loss = self._compute_loss(batch)
        self.log("test_loss", loss)
        return loss

    # pylint: disable=arguments-differ
    @siglip_ingredient.capture
    def configure_optimizers(self, learning_rate):
        """Return an AdamW optimizer over all model parameters.

        ``learning_rate`` is injected from the ``siglip`` ingredient config.
        """
        return torch.optim.AdamW(self.parameters(), lr=learning_rate)
@siglip_ingredient.capture
def get_siglip(model_path: str):
    """Build the SigLIP classifier, restoring weights when a path is configured.

    ``model_path`` is injected from the ``siglip`` ingredient config when not
    passed explicitly. A falsy path (e.g. ``""``) yields a freshly initialised
    ``SiglipClassifier``. Otherwise the model is loaded with Lightning's
    ``load_from_checkpoint``.
    """
    if not model_path:
        return SiglipClassifier()
    return SiglipClassifier.load_from_checkpoint(model_path)
@siglip_ingredient.capture
def get_siglip_preprocessor(hf_ckpt: str):
    """Return the torchvision transform that prepares images for the model.

    The pipeline has three steps:
    1. Resize to 224x224.
    2. Convert to a tensor in [0, 1].
    3. Normalize each channel with mean 0.5 and std 0.5, mapping values to
       [-1, 1].

    NOTE(review): ``hf_ckpt`` is captured from the ``siglip`` config but not
    used. The constants are hard-coded and presumably match the default
    ``google/siglip-base-patch16-224`` checkpoint; confirm them if the
    checkpoint changes.
    """
    channel_mean = [0.5] * 3
    channel_std = [0.5] * 3
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(steps)