File size: 925 Bytes
a0583df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from .data import ImageDataset, TransformDino


class InferenceDataModel(pl.LightningDataModule):
    """Lightning data module that serves images for prediction only.

    Wraps ``ImageDataset`` in a non-shuffled ``DataLoader`` and passes every
    collated batch through a ``TransformDino`` transform before yielding it
    to the predict loop.

    Args:
        metadata_path: Forwarded as the first argument to ``ImageDataset``.
        images_root_path: Forwarded as the second argument to
            ``ImageDataset``.
        batch_size: Number of samples per batch. Defaults to 32.
        num_workers: Worker processes used by the ``DataLoader``. Defaults
            to 4, the value that was previously hard-coded.
        model_name: Identifier forwarded to ``TransformDino``. Defaults to
            ``"facebook/dinov2-base"``, the value that was previously
            hard-coded.
    """

    def __init__(
        self,
        metadata_path,
        images_root_path,
        batch_size=32,
        num_workers=4,
        model_name="facebook/dinov2-base",
    ):
        super().__init__()
        self.metadata_path = metadata_path
        self.images_root_path = images_root_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.model_name = model_name
        # Built by ``setup``; stays ``None`` until then so that
        # ``predict_dataloader`` can detect a missing setup call.
        self.dataloader = None

    def setup(self, stage=None):
        """Build the ``DataLoader`` over ``ImageDataset``.

        Args:
            stage: Lightning stage name. Unused: the same loader is built
                for every stage.
        """
        self.dataloader = DataLoader(
            ImageDataset(self.metadata_path, self.images_root_path),
            batch_size=self.batch_size,
            # Keep dataset order so batches come out in a deterministic order.
            shuffle=False,
            num_workers=self.num_workers,
        )

    def predict_dataloader(self):
        """Yield transformed batches for prediction.

        Yields:
            Each batch from the underlying ``DataLoader`` after it has been
            passed through ``TransformDino(self.model_name)``.

        NOTE(review): the transform runs in the main process on batches that
        are already collated, so ``TransformDino`` presumably accepts a whole
        batch; confirm against ``.data``. Because this is a generator, it has
        no ``__len__``, so Lightning presumably cannot report a total batch
        count for the progress bar.
        """
        # Trainer normally calls ``setup`` first; build lazily when the module
        # is iterated directly (e.g. in a notebook or script).
        if self.dataloader is None:
            self.setup()
        transform = v2.Compose([TransformDino(self.model_name)])
        for batch in self.dataloader:
            yield transform(batch)