|
from pathlib import Path |
|
|
|
import pandas as pd |
|
import torch |
|
from PIL import Image |
|
from torch.utils.data import Dataset |
|
from torchvision.transforms import v2 |
|
from transformers import AutoImageProcessor, AutoModel |
|
|
|
|
|
class TransformDino(v2.Transform):
    """Replace ``batch["features"]`` images with DINOv2 token-0 embeddings.

    The Hugging Face image processor and backbone for ``model_name`` are loaded
    once at construction. Because ``v2.Transform`` is an ``nn.Module``, the
    backbone is registered as a submodule and therefore follows
    ``.to(device)`` calls on this transform.
    """

    def __init__(self, model_name="facebook/dinov2-base"):
        """Load the processor and backbone.

        Args:
            model_name: Hugging Face Hub id (or local path) passed to both
                ``AutoImageProcessor.from_pretrained`` and
                ``AutoModel.from_pretrained``.
        """
        super().__init__()
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    @staticmethod
    def _is_prescaled(images):
        """Return True when *images* are floating-point tensors.

        torchvision's convention is that float image tensors already hold values
        in [0, 1] (e.g. the output of ``v2.ToTensor``). Such inputs must not be
        multiplied by 1/255 a second time inside the HF processor. Only the
        first element of a list/tuple is inspected.
        """
        first = images[0] if isinstance(images, (list, tuple)) and images else images
        return isinstance(first, torch.Tensor) and first.is_floating_point()

    def forward(self, batch):
        """Embed ``batch["features"]`` in place and return the same dict.

        Args:
            batch: Mapping with a ``"features"`` entry holding image(s) in any
                form the HF processor accepts (PIL images, arrays, or tensors).

        Returns:
            The same ``batch`` object, with ``"features"`` replaced by
            ``last_hidden_state[:, 0]``. For ViT-style backbones such as
            DINOv2, index 0 is the CLS token. The tensor lives on the
            backbone's device.
        """
        images = batch["features"]
        # Avoid double rescaling of float tensors already in [0, 1]; otherwise
        # keep whatever the processor's own config specifies.
        processor_kwargs = {"do_rescale": False} if self._is_prescaled(images) else {}
        model_inputs = self.processor(images=images, return_tensors="pt", **processor_kwargs)
        # The processor always returns CPU tensors; move them next to the
        # weights so the transform still works after `.to(device)`.
        model_inputs = model_inputs.to(self.model.device)

        with torch.no_grad():
            outputs = self.model(**model_inputs)

        batch["features"] = outputs.last_hidden_state[:, 0]
        return batch
|
|
|
|
|
class ImageDataset(Dataset):
    """Map-style dataset pairing image tensors with their observation ids.

    Each row of the metadata CSV names an image file (``filename`` column,
    resolved relative to ``images_root_path``) and carries an
    ``observation_id`` that is passed through unchanged.
    """

    def __init__(self, metadata_path, images_root_path):
        """Load the metadata table.

        Args:
            metadata_path: CSV file readable by ``pandas.read_csv``. It must
                provide ``filename`` and ``observation_id`` columns, since both
                are read in ``__getitem__``.
            images_root_path: Directory that ``filename`` entries are relative to.
        """
        self.metadata_path = metadata_path
        self.metadata = pd.read_csv(metadata_path)
        self.images_root_path = images_root_path
        # Build the conversion once rather than on every item.
        # NOTE: v2.ToTensor is deprecated upstream; its documented replacement is
        # v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]).
        self._to_tensor = v2.ToTensor()

    def __len__(self):
        """Return the number of rows (images) in the metadata table."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load the image described by the metadata row at position ``idx``.

        Returns:
            A dict with ``"features"``, the RGB image converted by
            ``v2.ToTensor`` (a float tensor in [0, 1]), and
            ``"observation_id"``, the row's ``observation_id`` value.
        """
        row = self.metadata.iloc[idx]
        image_path = Path(self.images_root_path) / row.filename
        # Context manager releases the file handle promptly, even if decoding
        # or conversion raises. convert() returns a fully loaded new image, so
        # it stays valid after the file is closed.
        with Image.open(image_path) as raw:
            rgb = raw.convert("RGB")
        return {"features": self._to_tensor(rgb), "observation_id": row.observation_id}
|
|