Spaces:
Build error
Build error
File size: 1,387 Bytes
1865436 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms
from src.ss.datasets_signboard_detection.dataset import PoIDataset
import src.ss.datasets_signboard_detection.utils as utils
class POIDataModule(pl.LightningDataModule):
    """Lightning data module that serves ``PoIDataset`` for prediction.

    Only the predict stage is wired up. ``setup`` builds a ``PoIDataset`` over
    ``data_path`` with a ``ToTensor`` transform. ``predict_dataloader`` wraps
    that dataset in a non-shuffled ``DataLoader``.

    Args:
        data_path: Path passed straight through to ``PoIDataset``.
        train_batch_size: Stored on the instance; no dataloader in this class
            uses it.
        test_batch_size: Batch size of the prediction dataloader.
        seed: Stored on the instance; not used within this class.
        num_workers: Number of worker processes for the prediction dataloader.
            Defaults to 16, the value that was previously hard-coded.
    """

    def __init__(self,
                 data_path: str,
                 train_batch_size=8,
                 test_batch_size=8,
                 seed=28,
                 num_workers=16):
        super().__init__()
        self.data_path = data_path
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.seed = seed
        self.num_workers = num_workers
        # Initialise eagerly so the guard in predict_dataloader works even when
        # setup() was never run for the predict stage. Previously the attribute
        # did not exist in that case and an AttributeError was raised instead.
        self.test_dataset = None

    def prepare_data(self):
        """Intentionally a no-op: this module has no preparation step."""
        pass

    def setup(self, stage="fit"):
        """Build the prediction dataset when ``stage`` is ``"predict"`` or ``None``.

        Any other stage is currently a no-op.
        """
        if stage not in ("predict", None):
            return
        test_transform = transforms.Compose([transforms.ToTensor()])
        self.test_dataset = PoIDataset(self.data_path,
                                       transforms=test_transform)

    def predict_dataloader(self):
        """Return the prediction ``DataLoader``.

        Returns ``None`` if ``setup`` has not yet built the prediction dataset.
        Sample order is preserved (``shuffle=False``), and batches are assembled
        with ``utils.collate_fn``.
        """
        if self.test_dataset is None:
            return None
        return DataLoader(self.test_dataset,
                          batch_size=self.test_batch_size,
                          shuffle=False,
                          num_workers=self.num_workers,
                          collate_fn=utils.collate_fn)
def _get_name(filepath):
images = filepath
return images
|