P-DFD / dataset /abstract_dataset.py
mrneuralnet's picture
Initial commit
982865f
raw
history blame
1.39 kB
import cv2
import torch
import numpy as np
from torchvision.datasets import VisionDataset
import albumentations
from albumentations import Compose
from albumentations.pytorch.transforms import ToTensorV2
class AbstractDataset(VisionDataset):
    """Base dataset that keeps parallel lists of image paths and targets.

    The constructor only creates empty ``images`` and ``targets`` lists;
    nothing in this class fills them (presumably subclasses do -- confirm
    against the concrete datasets).  Indexing returns the stored path and
    target as-is; decoding and augmentation happen in :meth:`load_item`.
    """

    def __init__(self, cfg, seed=2022, transforms=None, transform=None, target_transform=None):
        """Set up the dataset from a configuration mapping.

        Args:
            cfg: Mapping read for ``'root'`` and ``'split'``.  When the base
                class ends up with no ``transforms``, ``'transforms'`` is read
                as well: an iterable of entries with a ``'name'`` (attribute
                looked up on the ``albumentations`` module) and ``'params'``
                (keyword arguments for that attribute).
            seed: Value passed to ``np.random.seed``.
            transforms: Forwarded to ``VisionDataset``.
            transform: Forwarded to ``VisionDataset``.
            target_transform: Forwarded to ``VisionDataset``.
        """
        super().__init__(
            cfg['root'],
            transforms=transforms,
            transform=transform,
            target_transform=target_transform,
        )
        # Fix the seed for reproducibility.  NOTE: this seeds NumPy's global
        # RNG, so it affects every other user of np.random in the process.
        np.random.seed(seed)
        self.images = []
        self.targets = []
        self.split = cfg['split']
        if self.transforms is None:
            # Build the pipeline described in cfg; ToTensorV2 is always the
            # final step.
            configured = [
                getattr(albumentations, spec['name'])(**spec['params'])
                for spec in cfg['transforms']
            ]
            self.transforms = Compose(configured + [ToTensorV2()])

    def __len__(self) -> int:
        """Return the number of stored image paths."""
        return len(self.images)

    def __getitem__(self, index):
        """Return the ``(path, target)`` pair stored at *index*.

        The image is not loaded here; pass paths to :meth:`load_item` for that.
        """
        path = self.images[index]
        tgt = self.targets[index]
        return path, tgt

    def load_item(self, items):
        """Read every path in *items*, transform it and stack the results.

        Each image is read with OpenCV, converted from BGR to RGB, passed to
        ``self.transforms(image=...)`` and the ``'image'`` entry of the result
        is collected.

        Args:
            items: Iterable of image paths accepted by ``cv2.imread``.

        Returns:
            The transformed images stacked along a new leading dimension.

        Raises:
            OSError: If OpenCV cannot read one of the paths (missing file or
                undecodable content).
        """
        images = []
        for item in items:
            img = cv2.imread(item)
            if img is None:
                # cv2.imread reports failure by returning None rather than
                # raising; fail here with the offending path instead of
                # letting cvtColor die with an opaque assertion.
                raise OSError("cannot read image: {!r}".format(item))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            images.append(self.transforms(image=img)['image'])
        # NOTE(review): torch.stack needs a non-empty list of equally shaped
        # tensors, so *items* must be non-empty and the transform pipeline
        # should yield a fixed output size -- confirm against the configs.
        return torch.stack(images, dim=0)