import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision.transforms import v2

from .data import ImageDataset, TransformDino


class InferenceDataModel(pl.LightningDataModule):
    """Lightning data module that serves transformed batches for prediction.

    ``setup`` wraps an ``ImageDataset`` in an unshuffled ``DataLoader``;
    ``predict_dataloader`` iterates that loader and passes every batch
    through a ``TransformDino`` transform before yielding it.

    Args:
        metadata_path: Forwarded as the first argument to ``ImageDataset``.
        images_root_path: Forwarded as the second argument to
            ``ImageDataset``.
        batch_size: Batch size handed to the ``DataLoader``.
        num_workers: Number of ``DataLoader`` worker processes.
        model_name: Identifier handed to ``TransformDino``
            (presumably a pretrained checkpoint name -- confirm against
            ``TransformDino``).
    """

    def __init__(
        self,
        metadata_path,
        images_root_path,
        batch_size: int = 32,
        num_workers: int = 4,
        model_name: str = "facebook/dinov2-base",
    ):
        super().__init__()
        self.metadata_path = metadata_path
        self.images_root_path = images_root_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.model_name = model_name
        # Created in setup(); None lets predict_dataloader() detect a
        # missing setup() call and fail with a clear message.
        self.dataloader = None

    def setup(self, stage=None):
        """Build the ``DataLoader`` over the image dataset.

        ``stage`` is accepted for hook compatibility but not used: the same
        loader is built regardless of its value.
        """
        dataset = ImageDataset(self.metadata_path, self.images_root_path)
        self.dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,  # keep dataset order for inference
            num_workers=self.num_workers,
        )

    def predict_dataloader(self):
        """Yield each batch of the loader after applying the transform.

        This is a generator: the transform is constructed when iteration
        starts, and it is applied to whole batches as they come out of the
        ``DataLoader``.

        Raises:
            RuntimeError: If iteration starts before ``setup()`` has run.
        """
        if self.dataloader is None:
            raise RuntimeError(
                "setup() must be called before predict_dataloader()"
            )
        transform = v2.Compose([TransformDino(self.model_name)])
        for batch in self.dataloader:
            yield transform(batch)