File size: 1,405 Bytes
2076935
 
 
 
 
a0583df
2076935
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411c19f
2076935
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from pathlib import Path

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import v2
from transformers import AutoImageProcessor, AutoModel


class TransformDino(v2.Transform):
    """Replace ``batch["features"]`` images with a DINOv2-style CLS embedding.

    The Hugging Face image processor and backbone named by ``model_name`` are
    loaded once at construction time. ``forward`` preprocesses the images, runs
    the model without gradient tracking, and stores ``last_hidden_state[:, 0]``
    (the first / CLS token) back into the batch.
    """

    def __init__(self, model_name="facebook/dinov2-base"):
        super().__init__()
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Pure feature extraction: be explicit about inference mode even though
        # from_pretrained already returns the model in eval mode.
        self.model.eval()

    @staticmethod
    def _already_scaled(images):
        """Return True if *images* are float tensors whose values lie in [0, 1].

        Accepts a single tensor or a list/tuple of tensors; anything else
        (PIL images, numpy arrays, empty containers) returns False so the
        processor keeps its default rescaling behaviour.
        """
        if isinstance(images, torch.Tensor):
            tensors = [images]
        elif isinstance(images, (list, tuple)):
            tensors = list(images)
        else:
            return False
        return bool(tensors) and all(
            isinstance(t, torch.Tensor)
            and torch.is_floating_point(t)
            and t.numel() > 0
            and t.min().item() >= 0.0
            and t.max().item() <= 1.0
            for t in tensors
        )

    def forward(self, batch):
        """Embed ``batch["features"]`` and store the CLS token back in place.

        Args:
            batch: mapping with a ``"features"`` entry holding image(s) accepted
                by the Hugging Face processor (PIL images, arrays or tensors).
                The mapping is mutated in place.

        Returns:
            The same ``batch`` object, whose ``"features"`` entry is now the
            ``last_hidden_state[:, 0]`` tensor produced by the model.
        """
        images = batch["features"]
        processor_kwargs = {"return_tensors": "pt"}
        # Float tensors already in [0, 1] (e.g. torchvision ToTensor output)
        # must not be rescaled by 1/255 a second time by the processor.
        if self._already_scaled(images):
            processor_kwargs["do_rescale"] = False
        model_inputs = self.processor(images=images, **processor_kwargs)
        # Follow the model if the transform was moved with .to(device).
        model_inputs = model_inputs.to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**model_inputs)
        # extract the cls token (first position of the token dimension)
        batch["features"] = outputs.last_hidden_state[:, 0]
        return batch


class ImageDataset(Dataset):
    """Map-style dataset yielding RGB image tensors listed in a metadata CSV.

    The CSV at ``metadata_path`` is read with pandas. ``__getitem__`` reads its
    ``filename`` column (resolved relative to ``images_root_path``) and its
    ``observation_id`` column.
    """

    def __init__(self, metadata_path, images_root_path):
        self.metadata_path = metadata_path
        self.metadata = pd.read_csv(metadata_path)
        self.images_root_path = images_root_path

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Load the image for metadata row *idx* (positional index).

        Returns:
            dict with ``"features"``: a float32 (C, H, W) tensor scaled to
            [0, 1], and ``"observation_id"``: the row's ``observation_id``.
        """
        row = self.metadata.iloc[idx]
        # Item access (not row.filename) cannot collide with Series attributes.
        image_path = Path(self.images_root_path) / row["filename"]
        # Close the file handle deterministically. convert() returns a new,
        # fully loaded image, so it stays valid after the file is closed.
        with Image.open(image_path) as pil_img:
            rgb = pil_img.convert("RGB")
        # v2.ToTensor is deprecated; pil_to_tensor + to_dtype(scale=True) gives
        # the same [0, 1] float scaling without building a transform per item.
        img = v2.functional.to_dtype(
            v2.functional.pil_to_tensor(rgb), torch.float32, scale=True
        )
        return {"features": img, "observation_id": row["observation_id"]}