Anthony Miyaguchi
Remove lightning dependency from submission
a0583df
from pathlib import Path
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import v2
from transformers import AutoImageProcessor, AutoModel
class TransformDino(v2.Transform):
    """Replace the images in a batch with their DINOv2 CLS-token embeddings.

    The transform loads a Hugging Face image processor and model by name and,
    when called, overwrites ``batch["features"]`` with the first token of the
    model's ``last_hidden_state`` for every image.

    Args:
        model_name: Hugging Face model identifier passed to
            ``AutoImageProcessor.from_pretrained`` and
            ``AutoModel.from_pretrained``.
        do_rescale: Controls the processor's pixel rescaling step.
            ``None`` (default) lets the transform decide: rescaling is
            skipped when the input is a floating-point tensor already in
            ``[0, 1]`` and left at the processor's default otherwise.
            ``True`` / ``False`` force the value unconditionally
            (``True`` reproduces the previous always-rescale behaviour).
    """

    def __init__(self, model_name="facebook/dinov2-base", *, do_rescale=None):
        super().__init__()
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.do_rescale = do_rescale

    @staticmethod
    def _is_unit_range_float(images) -> bool:
        """Return True when *images* are float tensors with values in [0, 1]."""
        if isinstance(images, (list, tuple)):
            if not images:
                return False
            # Inspecting the first image is enough to classify the batch.
            images = images[0]
        if not torch.is_tensor(images) or not torch.is_floating_point(images):
            return False
        if images.numel() == 0:
            return False
        return bool(images.min() >= 0) and bool(images.max() <= 1)

    def forward(self, batch):
        """Embed ``batch["features"]`` and return the same (mutated) dict.

        Args:
            batch: Mapping with a ``"features"`` entry holding the images to
                embed. Other keys are passed through untouched.

        Returns:
            The same ``batch`` object, with ``"features"`` replaced by the
            CLS-token embedding of each image.
        """
        images = batch["features"]

        # Image processors rescale pixel values by default and expect 0-255
        # input; per the transformers docs, images already in [0, 1] must be
        # passed with do_rescale=False or they get scaled down twice.
        processor_kwargs = {}
        if self.do_rescale is not None:
            processor_kwargs["do_rescale"] = self.do_rescale
        elif self._is_unit_range_float(images):
            processor_kwargs["do_rescale"] = False

        model_inputs = self.processor(
            images=images, return_tensors="pt", **processor_kwargs
        )
        # The processor returns CPU tensors; follow the model so the transform
        # keeps working after it has been moved with ``.to(device)``.
        model_inputs = model_inputs.to(self.model.device)

        with torch.no_grad():
            outputs = self.model(**model_inputs)

        # extract the cls token (first position of the sequence dimension)
        batch["features"] = outputs.last_hidden_state[:, 0]
        return batch
class ImageDataset(Dataset):
    """Dataset of RGB images listed in a metadata CSV file.

    The CSV is read once at construction time. Each item is looked up by
    positional index and must provide a ``filename`` column (path relative to
    ``images_root_path``) and an ``observation_id`` column.

    Args:
        metadata_path: Path to the CSV file describing the images.
        images_root_path: Directory that the ``filename`` values are
            resolved against.
    """

    def __init__(self, metadata_path, images_root_path):
        self.metadata_path = metadata_path
        self.metadata = pd.read_csv(metadata_path)
        self.images_root_path = images_root_path

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load the image for row *idx*.

        Returns:
            A dict with ``"features"`` (the image converted to RGB and passed
            through ``v2.ToTensor()``) and ``"observation_id"`` (the value of
            that column for the row).
        """
        row = self.metadata.iloc[idx]
        image_path = Path(self.images_root_path) / row.filename

        # Close the underlying file deterministically instead of relying on
        # garbage collection; ``convert`` returns an independent in-memory
        # image, so it stays valid after the file is closed.
        with Image.open(image_path) as raw_image:
            img = raw_image.convert("RGB")

        img = v2.ToTensor()(img)
        return {"features": img, "observation_id": row.observation_id}