import json
import os
import shutil
from copy import deepcopy
from io import BytesIO

import lmdb
import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
from torch.utils.data import Dataset


def InfiniteSampler(n):
    """Yield indices in ``[0, n)`` forever, reshuffling after each pass.

    The cursor starts on the last slot of the initial permutation, so only
    one index from that first permutation is produced before the first
    reshuffle; every later permutation is consumed in full.

    Args:
        n: Number of items to sample from.

    Yields:
        One index per step, taken from the current random permutation.
    """
    i = n - 1
    order = np.random.permutation(n)
    while True:
        yield order[i]
        i += 1
        if i >= n:
            # Reseeds NumPy's *global* RNG before drawing the next
            # permutation, which discards any seed the caller set earlier.
            np.random.seed()
            order = np.random.permutation(n)
            i = 0


class InfiniteSamplerWrapper(data.sampler.Sampler):
    """Sampler that exposes ``InfiniteSampler`` through the Sampler API."""

    def __init__(self, data_source):
        # Only the size of the data source is needed, not its contents.
        self.num_samples = len(data_source)

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        # The stream never ends; report a large fixed length instead.
        return 2 ** 31


def copy_G_params(model):
    """Return an independent deep copy of the model's parameter tensors.

    Args:
        model: Object exposing ``parameters()``.

    Returns:
        A list with one copied ``.data`` tensor per parameter, in the order
        ``model.parameters()`` yields them.
    """
    return deepcopy([p.data for p in model.parameters()])


def load_params(model, new_param):
    """Copy the tensors in ``new_param`` into the model's parameters in place.

    Parameters and new values are paired positionally with ``zip``, so if the
    two sequences differ in length the surplus entries are ignored.

    Args:
        model: Object exposing ``parameters()``.
        new_param: Iterable of tensors, e.g. the result of ``copy_G_params``.
    """
    for p, new_p in zip(model.parameters(), new_param):
        p.data.copy_(new_p)


def get_dir(args):
    """Create the output folders for a run and snapshot code and arguments.

    Creates ``train_results/<args.name>/models`` and ``.../images``, copies
    every ``.py`` file from the current working directory into the run
    folder, and writes ``args.__dict__`` as JSON to ``args.txt`` there.

    Args:
        args: Object with a ``name`` attribute whose ``__dict__`` is
            JSON-serializable (presumably an ``argparse.Namespace``).

    Returns:
        Tuple ``(saved_model_folder, saved_image_folder)``.
    """
    task_name = 'train_results/' + args.name
    saved_model_folder = os.path.join(task_name, 'models')
    saved_image_folder = os.path.join(task_name, 'images')

    os.makedirs(saved_model_folder, exist_ok=True)
    os.makedirs(saved_image_folder, exist_ok=True)

    for f in os.listdir('.'):
        # Match the real extension (a substring test also catches ``.pyc`` or
        # ``x.py.bak``) and skip directories, which shutil.copy cannot copy.
        if f.endswith('.py') and os.path.isfile(f):
            shutil.copy(f, os.path.join(task_name, f))

    with open(os.path.join(task_name, 'args.txt'), 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    return saved_model_folder, saved_image_folder


class ImageFolder(Dataset):
    """Dataset over the image files directly inside one directory.

    Files are selected by extension (``.jpg``, ``.png``, ``.jpeg``, matched
    case-insensitively) and ordered by sorted file name. Sub-directories are
    not traversed.
    """

    _EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, root, transform=None):
        """
        Args:
            root: Directory containing the image files.
            transform: Optional callable applied to each loaded RGB image.
        """
        super().__init__()
        self.root = root
        self.frame = self._parse_frame()
        self.transform = transform

    def _parse_frame(self):
        """Return the sorted list of image paths found in ``self.root``."""
        return [
            os.path.join(self.root, name)
            for name in sorted(os.listdir(self.root))
            if name.lower().endswith(self._EXTENSIONS)
        ]

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply the transform, if any."""
        file = self.frame[idx]
        # convert() returns a new, fully loaded image, so the file handle can
        # be closed as soon as the conversion is done.
        with Image.open(file) as raw:
            img = raw.convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img


class MultiResolutionDataset(Dataset):
    """Dataset reading encoded images from an LMDB database.

    The database must hold a ``length`` entry with the item count, and one
    entry per image under the key ``<resolution>-<index zero-padded to 5>``.
    """

    def __init__(self, path, transform, resolution=256):
        """
        Args:
            path: Location of the LMDB environment.
            transform: Callable applied to every decoded image.
            resolution: Resolution prefix used when building image keys.

        Raises:
            IOError: If the environment cannot be opened or has no
                ``length`` entry.
        """
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            raw_length = txn.get('length'.encode('utf-8'))

        if raw_length is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise IOError('lmdb dataset has no "length" entry', path)

        self.length = int(raw_length.decode('utf-8'))
        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Decode image ``index`` at ``self.resolution`` and transform it.

        Raises:
            IOError: If the database has no entry for the computed key.
        """
        key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
        with self.env.begin(write=False) as txn:
            img_bytes = txn.get(key)

        if img_bytes is None:
            # Without this check BytesIO(None) yields an empty buffer and the
            # failure shows up as an unrelated image-decoding error.
            raise IOError(f'lmdb dataset has no entry for key {key.decode("utf-8")!r}')

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        img = self.transform(img)

        return img