File size: 1,405 Bytes
2076935 a0583df 2076935 411c19f 2076935 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
from pathlib import Path
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import v2
from transformers import AutoImageProcessor, AutoModel
class TransformDino(v2.Transform):
    """Replace ``batch["features"]`` images with DINOv2 CLS-token embeddings.

    A Hugging Face image processor prepares the images and the backbone runs
    under ``torch.no_grad()``. The first token of ``last_hidden_state`` (the
    CLS token) is kept as each image's feature vector.

    Args:
        model_name: Hugging Face hub id or local path used for both the image
            processor and the backbone.
        do_rescale: Whether the processor should multiply pixel values by its
            rescale factor. ``None`` (default) decides per batch:
            floating-point tensors are assumed to already lie in ``[0, 1]``
            (e.g. ``ToTensor`` output) and are not rescaled a second time.
            Any other input uses the processor's configured default.
            ``True``/``False`` force the choice.
    """

    def __init__(self, model_name="facebook/dinov2-base", do_rescale=None):
        super().__init__()
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.do_rescale = do_rescale

    def _rescale_kwargs(self, images):
        """Return the ``do_rescale`` override to pass to the processor, if any."""
        if self.do_rescale is not None:
            return {"do_rescale": self.do_rescale}
        if isinstance(images, (list, tuple)) and images:
            first = images[0]
        else:
            first = images
        if torch.is_tensor(first) and first.is_floating_point():
            # Float tensors are presumably already scaled to [0, 1]. Rescaling
            # them again by 1/255 would push every pixel to ~0 before
            # normalization.
            return {"do_rescale": False}
        return {}

    def forward(self, batch):
        """Embed ``batch["features"]`` and return ``batch``.

        ``batch["features"]`` is handed to the image processor and replaced
        by the CLS-token embedding of each image. The input dict is mutated
        in place and also returned.
        """
        images = batch["features"]
        model_inputs = self.processor(
            images=images, return_tensors="pt", **self._rescale_kwargs(images)
        )
        # The processor always produces CPU tensors. Follow the backbone in
        # case this module was moved to another device with .to(...).
        model_inputs = model_inputs.to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**model_inputs)
        last_hidden_states = outputs.last_hidden_state
        # Keep only the first token (the CLS token) of every sequence.
        batch["features"] = last_hidden_states[:, 0]
        return batch
class ImageDataset(Dataset):
    """Map-style dataset of RGB images listed in a metadata CSV.

    Each CSV row must provide a ``filename`` column, resolved relative to
    ``images_root_path``, and an ``observation_id`` column.

    Args:
        metadata_path: Path to the CSV file (read with ``pandas.read_csv``).
        images_root_path: Directory the ``filename`` entries are relative to.
    """

    def __init__(self, metadata_path, images_root_path):
        self.metadata_path = metadata_path
        self.metadata = pd.read_csv(metadata_path)
        self.images_root_path = images_root_path

    def __len__(self):
        """Return the number of metadata rows (one sample per row)."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load the image for metadata row ``idx``.

        Returns:
            A dict with ``"features"`` (a float32 channels-first image tensor
            with values in ``[0, 1]``) and the row's ``"observation_id"``.
        """
        row = self.metadata.iloc[idx]
        image_path = Path(self.images_root_path) / row["filename"]
        # Close the file handle deterministically. convert() returns an
        # independent in-memory image, so the handle is not needed afterwards.
        with Image.open(image_path) as handle:
            img = handle.convert("RGB")
        # Non-deprecated equivalent of v2.ToTensor(): uint8 PIL image ->
        # channels-first tensor, then cast to float32 and scaled by 1/255.
        # Assumes torchvision >= 0.16 (to_dtype with scale=) — TODO confirm.
        pixels = v2.functional.pil_to_tensor(img)
        features = v2.functional.to_dtype(pixels, torch.float32, scale=True)
        return {"features": features, "observation_id": row["observation_id"]}
|