import albumentations
import cv2
import numpy as np
import torch
from albumentations import Compose
from albumentations.pytorch.transforms import ToTensorV2
from torchvision.datasets import VisionDataset
class AbstractDataset(VisionDataset):
    """Base dataset holding image paths/targets, with deferred image decoding.

    Subclasses are expected to populate ``self.images`` (file paths) and
    ``self.targets`` (labels). ``__getitem__`` returns the *path* and the
    target; the actual read/augment/stack step happens in :meth:`load_item`.

    Args:
        cfg: mapping with at least ``'root'``, ``'split'`` and ``'transforms'``
            keys. ``cfg['transforms']`` is a list of
            ``{'name': <albumentations class>, 'params': <kwargs dict>}`` specs.
        seed: seed passed to ``np.random.seed`` for reproducibility.
            NOTE: this mutates NumPy's *global* RNG state as a side effect.
        transforms: optional albumentations pipeline; if ``None`` one is built
            from ``cfg['transforms']`` and terminated with ``ToTensorV2``.
        transform: standard ``VisionDataset`` hook, forwarded to the parent.
        target_transform: standard ``VisionDataset`` hook, forwarded to the parent.
    """

    def __init__(self, cfg, seed=2022, transforms=None, transform=None, target_transform=None):
        super(AbstractDataset, self).__init__(cfg['root'], transforms=transforms,
                                              transform=transform, target_transform=target_transform)
        # Fix the global NumPy RNG so augmentations are reproducible across runs.
        np.random.seed(seed)
        self.images = []
        self.targets = []
        self.split = cfg['split']
        if self.transforms is None:
            # Instantiate each configured albumentations op by name, e.g.
            # {'name': 'Resize', 'params': {'height': 256, 'width': 256}},
            # and always end the pipeline with conversion to a torch tensor.
            ops = [getattr(albumentations, spec['name'])(**spec['params'])
                   for spec in cfg['transforms']]
            self.transforms = Compose(ops + [ToTensorV2()])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # Return the file path (not a decoded image) plus its target;
        # decoding is deferred to load_item().
        return self.images[index], self.targets[index]

    def load_item(self, items):
        """Read, color-convert, transform and stack a batch of image files.

        Args:
            items: iterable of image file paths.

        Returns:
            ``torch.Tensor`` stacking the transformed images along dim 0
            (shape (len(items), C, H, W) after ``ToTensorV2``).

        Raises:
            FileNotFoundError: if OpenCV cannot read one of the paths.
        """
        images = []
        for path in items:
            img = cv2.imread(path)
            if img is None:
                # cv2.imread signals failure by returning None; fail loudly
                # here instead of crashing later inside cvtColor.
                raise FileNotFoundError(f'cv2.imread failed to read image: {path}')
            # OpenCV decodes to BGR; the rest of the pipeline expects RGB.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            images.append(self.transforms(image=img)['image'])
        return torch.stack(images, dim=0)