|
from torch.utils import data |
|
from PIL import Image |
|
import os |
|
|
|
|
|
class Dataset(data.Dataset):
    """Map-style PyTorch dataset over the image files below a directory.

    Every file whose extension is in ``IMG_EXTENSIONS`` (compared
    case-insensitively) anywhere under ``path`` becomes one sample. A sample
    is the image converted to RGB, or whatever ``transform`` returns for it
    when a transform is given. No labels are produced.
    """

    # Accepted file extensions: lower-case, including the leading dot.
    # Subclasses may override this to accept other formats.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, path, transform=None):
        """Initialization.

        Args:
            path: Root directory searched recursively for image files.
            transform: Optional callable applied to each loaded RGB image.
        """
        self.file_names = self.get_filenames(path)
        self.transform = transform

    def __len__(self):
        """Return the total number of samples (image files found)."""
        return len(self.file_names)

    def __getitem__(self, index):
        """Load the ``index``-th image as RGB and apply ``transform`` if set."""
        # The context manager closes the underlying file once the pixels are
        # decoded. ``convert`` returns a new, fully loaded image that does not
        # depend on the open file handle.
        with Image.open(self.file_names[index]) as src:
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

    def get_filenames(self, data_path, extensions=None):
        """Return the sorted paths of all image files below ``data_path``.

        Args:
            data_path: Root directory walked recursively with ``os.walk``.
                A directory that does not exist yields an empty list.
            extensions: Optional iterable of accepted extensions, each with a
                leading dot (e.g. ``'.png'``). Defaults to ``IMG_EXTENSIONS``.
                Matching is case-insensitive and uses the real file suffix.

        Returns:
            list[str]: Paths built with ``os.path.join`` onto ``data_path``,
            sorted so that sample indices are stable across runs.
        """
        if extensions is None:
            extensions = self.IMG_EXTENSIONS
        wanted = {ext.lower() for ext in extensions}
        images = []
        for dirpath, _subdirs, files in os.walk(data_path):
            for name in files:
                # Compare the actual suffix. The former substring test also
                # matched names like "jpg_notes.txt" and missed "IMG.JPG".
                if os.path.splitext(name)[1].lower() not in wanted:
                    continue
                filename = os.path.join(dirpath, name)
                if os.path.isfile(filename):
                    images.append(filename)
        # os.walk yields entries in filesystem-dependent order; sort them.
        images.sort()
        return images
|
|