|
import torch |
|
from abc import abstractmethod |
|
import os |
|
from PIL import Image |
|
import json |
|
|
|
class DataManager: |
|
@abstractmethod |
|
def get_data_loader(self, data_file, aug): |
|
pass |
|
|
|
|
|
class SimpleDataset: |
|
def __init__(self, data_file, transform): |
|
with open(data_file, 'r') as f: |
|
self.meta = json.load(f) |
|
self.transform = transform |
|
|
|
|
|
|
|
def __getitem__(self,i): |
|
image_path = os.path.join(self.meta['image_names'][i]) |
|
img = Image.open(image_path).convert('RGB') |
|
img = self.transform(img) |
|
target = self.target_transform(self.meta['image_labels'][i]) |
|
return img, target |
|
|
|
def __len__(self): |
|
return len(self.meta['image_names']) |
|
|
|
|
|
class SimpleDataManager(DataManager): |
|
def __init__(self, dataset, batch_size): |
|
super(SimpleDataManager, self).__init__() |
|
self.batch_size = batch_size |
|
self.dataset = dataset |
|
|
|
def get_data_loader(self): |
|
dataset = self.dataset |
|
data_loader_params = dict(batch_size = self.batch_size, shuffle = True, num_workers = 12, pin_memory = True) |
|
data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) |
|
|
|
return data_loader |