|
from torch.utils.data import Dataset |
|
import os |
|
from torchvision.datasets.folder import default_loader |
|
import torchvision.transforms as T |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
|
|
|
|
class CommonDataset(Dataset):
    """Paired image/label dataset backed by two parallel sequences of file paths.

    Item ``idx`` is built from the image at ``images_path[idx]``, loaded with
    torchvision's ``default_loader``, and the label at ``labels_path[idx]``,
    loaded with PIL and converted to ``'L'`` mode. Each is then passed through
    its own transform.

    Args:
        images_path: Indexable, sized sequence of image file paths.
        labels_path: Indexable, sized sequence of label file paths, aligned
            1:1 by index with ``images_path``.
        x_transform: Callable applied to the loaded image, or ``None`` to
            return the loaded image unchanged.
        y_transform: Callable applied to the loaded label, or ``None`` to
            return the loaded label unchanged.

    Raises:
        ValueError: If ``images_path`` and ``labels_path`` differ in length.
    """

    def __init__(self, images_path, labels_path,
                 x_transform=None, y_transform=None):
        if len(images_path) != len(labels_path):
            # Pairs are matched purely by index, so a length mismatch means
            # the data cannot be aligned. Fail here rather than with an
            # IndexError mid-epoch, or by silently ignoring trailing labels.
            raise ValueError(
                "images_path and labels_path must have the same length "
                f"(got {len(images_path)} and {len(labels_path)})"
            )
        self.imgs_path = images_path
        self.labels_path = labels_path
        self.x_transform = x_transform
        self.y_transform = y_transform

    def __len__(self):
        """Return the number of image/label pairs."""
        return len(self.imgs_path)

    def __getitem__(self, idx):
        """Load and transform the image/label pair at position ``idx``.

        Returns:
            Tuple ``(x, y)`` holding the (optionally transformed) image and
            the (optionally transformed) single-channel label.
        """
        x = default_loader(self.imgs_path[idx])

        # Image.open is lazy and may keep the file handle open after
        # conversion for some formats. The context manager guarantees the
        # label file is closed. convert() returns a new image that stays
        # valid after the source is closed.
        with Image.open(self.labels_path[idx]) as label_img:
            y = label_img.convert('L')

        if self.x_transform is not None:
            x = self.x_transform(x)
        if self.y_transform is not None:
            y = self.y_transform(y)

        return x, y
|
|