P-DFD / dataset /abstract_dataset.py
mrneuralnet's picture
Initial commit
982865f
import cv2
import torch
import numpy as np
from torchvision.datasets import VisionDataset
import albumentations
from albumentations import Compose
from albumentations.pytorch.transforms import ToTensorV2
class AbstractDataset(VisionDataset):
def __init__(self, cfg, seed=2022, transforms=None, transform=None, target_transform=None):
super(AbstractDataset, self).__init__(cfg['root'], transforms=transforms,
transform=transform, target_transform=target_transform)
# fix for re-production
np.random.seed(seed)
self.images = list()
self.targets = list()
self.split = cfg['split']
if self.transforms is None:
self.transforms = Compose(
[getattr(albumentations, _['name'])(**_['params']) for _ in cfg['transforms']] +
[ToTensorV2()]
)
def __len__(self):
return len(self.images)
def __getitem__(self, index):
path = self.images[index]
tgt = self.targets[index]
return path, tgt
def load_item(self, items):
images = list()
for item in items:
img = cv2.imread(item)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
image = self.transforms(image=img)['image']
images.append(image)
return torch.stack(images, dim=0)