# fewshot_random_subspace/dataloader/simple_datamanager.py
# Commit 96ed3dd ("Initial Commit") by darklord25.
import json
import os
from abc import ABC, abstractmethod

import torch
from PIL import Image
class DataManager(ABC):
    """Abstract interface for objects that build a data loader.

    Inheriting from ``ABC`` is what makes ``@abstractmethod`` take effect:
    without an ``ABCMeta`` metaclass the decorator is silently ignored, so
    this class (or a subclass that forgot to override the method) could be
    instantiated.
    """

    @abstractmethod
    def get_data_loader(self, data_file, aug):
        """Return a data loader.

        NOTE(review): ``data_file`` / ``aug`` are presumably an index file and
        an augmentation flag — confirm with callers. ``SimpleDataManager`` in
        this module overrides this method with a no-argument signature.
        """
class SimpleDataset:
    """Map-style dataset of ``(image, label)`` pairs described by a JSON file.

    The JSON file must hold an object with two parallel lists:
    ``image_names`` (paths passed directly to ``PIL.Image.open``) and
    ``image_labels`` (one label per image, in the same order).

    Args:
        data_file: Path to the JSON index file.
        transform: Callable applied to each RGB ``PIL.Image``; if ``None``
            the image is returned unchanged.
        target_transform: Callable applied to each label; if ``None`` (the
            default) the label is returned unchanged.
    """

    @staticmethod
    def _identity(x):
        # A named function (not a lambda) so the dataset stays picklable when
        # DataLoader workers are started with the "spawn" method.
        return x

    def __init__(self, data_file, transform, target_transform=None):
        # JSON is UTF-8 by spec; don't rely on the platform default encoding.
        with open(data_file, 'r', encoding='utf-8') as f:
            self.meta = json.load(f)
        self.transform = transform
        # This attribute used to be left unassigned, so every __getitem__
        # call raised AttributeError.
        self.target_transform = (
            target_transform if target_transform is not None else self._identity
        )

    def __getitem__(self, i):
        """Return ``(transformed image, transformed label)`` for index ``i``."""
        image_path = self.meta['image_names'][i]
        # Close the file handle right away; convert() yields a loaded copy
        # that no longer depends on the open file.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        target = self.target_transform(self.meta['image_labels'][i])
        return img, target

    def __len__(self):
        """Number of entries listed under ``image_names``."""
        return len(self.meta['image_names'])
class SimpleDataManager(DataManager):
    """Wrap an existing dataset in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: Dataset object handed directly to ``DataLoader``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data (default ``True``).
        num_workers: Number of loader worker processes (default ``12``).
        pin_memory: Whether batches are placed in pinned host memory
            (default ``True``).
    """

    def __init__(self, dataset, batch_size, shuffle=True, num_workers=12,
                 pin_memory=True):
        super().__init__()
        self.batch_size = batch_size
        self.dataset = dataset
        # Formerly hard-coded inside get_data_loader; the defaults are kept
        # so existing two-argument construction behaves the same.
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def get_data_loader(self):
        """Build and return a new ``DataLoader`` over ``self.dataset``.

        Unlike the base-class signature, this takes no ``data_file`` / ``aug``
        arguments: the dataset is supplied to the constructor instead.
        """
        return torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )