import random |
from collections import defaultdict |
import torch |
from torch.utils.data import Dataset |
import torchvision.transforms as transforms |
import os |
import pickle |
import numpy as np |
from PIL import Image |
from pathlib import Path |
def get_dataset_path(dataset_name, height, file_suffix, datasets_path):
    """Return the path of a dataset pickle inside ``datasets_path``.

    The file name is ``<dataset_name>-<height>.pickle``, or
    ``<dataset_name>-<height>-<file_suffix>.pickle`` when ``file_suffix`` is not None.
    """
    pickle_name = (
        f'{dataset_name}-{height}.pickle'
        if file_suffix is None
        else f'{dataset_name}-{height}-{file_suffix}.pickle'
    )
    return os.path.join(datasets_path, pickle_name)
def get_transform(grayscale=False, convert=True):
    """Compose the image preprocessing pipeline used by the datasets in this module.

    Args:
        grayscale: if true, start with ``transforms.Grayscale(1)`` and normalize a
            single channel; otherwise normalize three channels.
        convert: if true, append ``ToTensor`` and then ``Normalize`` with mean 0.5 and
            std 0.5 for every channel.

    Returns:
        A ``transforms.Compose`` over the selected steps (empty when both flags are false).
    """
    steps = [transforms.Grayscale(1)] if grayscale else []
    if convert:
        n_channels = 1 if grayscale else 3
        half = (0.5,) * n_channels  # (0.5,) or (0.5, 0.5, 0.5)
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(half, half))
    return transforms.Compose(steps)
class TextDataset:
    """Author-indexed handwriting dataset loaded from a pickle file.

    The pickle is indexed by split name (``'train'`` or ``'test'``); each split maps an
    author id to a list of sample dicts that carry at least ``'img'`` (an object with
    ``.convert('L')``, presumably a PIL image) and ``'label'`` (a str).  Each item picks
    one author and returns one random target word plus ``num_examples`` random style
    images from that same author.

    NOTE(review): ``pickle.load`` can execute arbitrary code stored in the file; only
    load trusted dataset files.
    """

    def __init__(self, base_path, collator_resolution, num_examples=15, target_transform=None,
                 min_virtual_size=0, validation=False, debug=False):
        """
        Args:
            base_path: path to the dataset pickle.
            collator_resolution: forwarded to ``TextCollator``.
            num_examples: number of style images sampled per item.
            target_transform: stored on the instance; not used by this class.
            min_virtual_size: lower bound for ``len(self)``; indices wrap over authors.
            validation: load the ``'test'`` split instead of ``'train'``.
            debug: if true, ``len(self)`` is fixed at 16.
        """
        self.NUM_EXAMPLES = num_examples
        self.debug = debug
        self.min_virtual_size = min_virtual_size

        subset = 'test' if validation else 'train'
        # Context manager so the file handle is closed once the data is loaded.
        with open(base_path, "rb") as file_to_store:
            self.IMG_DATA = pickle.load(file_to_store)[subset]
        self.IMG_DATA = dict(self.IMG_DATA)
        # Drop the author bucket keyed by the string 'None'.
        self.IMG_DATA.pop('None', None)

        # Flatten labels with a generator; ``sum(lists, [])`` is quadratic in the number of authors.
        self.alphabet = ''.join(sorted(set(
            ''.join(d['label'] for samples in self.IMG_DATA.values() for d in samples))))
        self.author_id = list(self.IMG_DATA.keys())
        self.transform = get_transform(grayscale=True)
        self.target_transform = target_transform
        self.collate_fn = TextCollator(collator_resolution)

    def __len__(self):
        if self.debug:
            return 16
        return max(len(self.author_id), self.min_virtual_size)

    @property
    def num_writers(self):
        """Number of distinct authors."""
        return len(self.author_id)

    def __getitem__(self, index):
        """Return a training item for the author selected by ``index`` (modulo #authors).

        Raises:
            IndexError: if the dataset has no authors.
        """
        if not self.author_id:
            raise IndexError('TextDataset has no authors')
        # Indices beyond the number of authors wrap around (see ``min_virtual_size``).
        index = index % len(self.author_id)
        author_id = self.author_id[index]
        author_samples = self.IMG_DATA[author_id]
        # Kept as an attribute for compatibility with any code that reads it.
        self.IMG_DATA_AUTHOR = author_samples

        # Style indices are drawn (with replacement) before the target word, as before,
        # so the random stream is consumed in the same order.
        random_idxs = random.choices(range(len(author_samples)), k=self.NUM_EXAMPLES)
        word_data = random.choice(author_samples)
        real_img = self.transform(word_data['img'].convert('L'))
        real_labels = word_data['label'].encode()

        style_samples = [author_samples[idx] for idx in random_idxs]
        imgs = [np.array(sample['img'].convert('L')) for sample in style_samples]
        slabels = [sample['label'].encode() for sample in style_samples]

        max_width = 192
        imgs_pad = []
        imgs_wids = []
        for img in imgs:
            img_height, img_width = img.shape[0], img.shape[1]
            # White (255) canvas of fixed width; wider images are truncated to max_width.
            output_img = np.full((img_height, max_width), 255.0, dtype='float32')
            output_img[:, :img_width] = img[:, :max_width]
            imgs_pad.append(self.transform(Image.fromarray(output_img.astype(np.uint8))))
            # NOTE(review): the width is not clamped to max_width, so it may exceed the
            # padded width for very wide images — confirm consumers expect this.
            imgs_wids.append(img_width)
        imgs_pad = torch.cat(imgs_pad, 0)

        item = {
            'simg': imgs_pad,
            'swids': imgs_wids,
            'img': real_img,
            'label': real_labels,
            'img_path': 'img_path',
            'idx': 'indexes',
            'wcl': index,
            'slabels': slabels,
            'author_id': author_id
        }
        return item

    def get_stats(self):
        """Return ``{char: 1 / relative_frequency}`` over every character of every label."""
        char_counts = defaultdict(int)
        total = 0
        for samples in self.IMG_DATA.values():
            for data in samples:
                for char in data['label']:
                    char_counts[char] += 1
                    total += 1
        return {k: 1.0 / (v / total) for k, v in char_counts.items()}
class TextCollator(object):
    """Collate dataset item dicts into one batch dict.

    Target images (``item['img']``, indexed as ``[C, H, W]``) are right-padded with
    ones to the widest image in the batch; style images are stacked as-is.
    """

    def __init__(self, resolution):
        # Stored for API compatibility; ``__call__`` does not use it.
        self.resolution = resolution

    def __call__(self, batch):
        """Build the batch dict.

        Args:
            batch: list of item dicts, or a list of lists of item dicts (flattened one level).

        Returns:
            Dict with ``img`` (float32 ``[B, C, H, max_W]``), ``img_path``, ``idx``,
            ``simg`` (stacked), ``swids`` and ``wcl`` (float tensors), plus ``label``,
            ``slabels`` (numpy array) and ``z`` (stacked) when the first item has them.
        """
        if isinstance(batch[0], list):
            # Linear flatten; ``sum(batch, [])`` is quadratic in the number of groups.
            batch = [entry for group in batch for entry in group]

        img_path = [item['img_path'] for item in batch]
        width = [item['img'].shape[2] for item in batch]
        indexes = [item['idx'] for item in batch]
        simgs = torch.stack([item['simg'] for item in batch], 0)
        wcls = torch.Tensor([item['wcl'] for item in batch])
        swids = torch.Tensor([item['swids'] for item in batch])

        first_img = batch[0]['img']
        imgs = torch.ones([len(batch), first_img.shape[0], first_img.shape[1], max(width)],
                          dtype=torch.float32)
        for idx, item in enumerate(batch):
            try:
                imgs[idx, :, :, 0:item['img'].shape[2]] = item['img']
            except RuntimeError as err:
                # Best effort: an image whose channel/height differs from the first item
                # is left blank (all ones) rather than aborting the whole batch.
                print(f'TextCollator: skipping item {idx} with shape {tuple(item["img"].shape)} '
                      f'for batch of shape {tuple(imgs.shape)}: {err}')

        # Named ``out`` so it does not shadow the loop variable ``item`` above.
        out = {'img': imgs, 'img_path': img_path, 'idx': indexes, 'simg': simgs, 'swids': swids,
               'wcl': wcls}
        if 'label' in batch[0]:
            out['label'] = [item['label'] for item in batch]
        if 'slabels' in batch[0]:
            out['slabels'] = np.array([item['slabels'] for item in batch])
        if 'z' in batch[0]:
            out['z'] = torch.stack([item['z'] for item in batch])
        return out
class CollectionTextDataset(Dataset):
    """Concatenation of several named datasets addressed by one flat index.

    Sub-datasets are visited in sorted-name order; a flat index is routed to the first
    dataset whose cumulative length exceeds it.
    """

    def __init__(self, datasets, datasets_path, dataset_class, file_suffix=None, height=32, **kwargs):
        """
        Args:
            datasets: comma-separated dataset names (e.g. ``'name1,name2'``); whitespace
                around names is stripped and empty entries are ignored.
            datasets_path: directory containing the pickles (see ``get_dataset_path``).
            dataset_class: called as ``dataset_class(path, **kwargs)`` for each name.
            file_suffix: forwarded to ``get_dataset_path``.
            height: forwarded to ``get_dataset_path``.
        """
        self.datasets = {}
        names = (name.strip() for name in datasets.split(','))
        for dataset_name in sorted(name for name in names if name):
            dataset_file = get_dataset_path(dataset_name, height, file_suffix, datasets_path)
            self.datasets[dataset_name] = dataset_class(dataset_file, **kwargs)
        self.alphabet = ''.join(sorted(set(''.join(d.alphabet for d in self.datasets.values()))))

    def __len__(self):
        return sum(len(d) for d in self.datasets.values())

    @property
    def num_writers(self):
        """Total number of writers across all sub-datasets."""
        return sum(d.num_writers for d in self.datasets.values())

    def _locate(self, index):
        """Map a flat index to ``(dataset_name, dataset, local_index)``.

        Raises:
            IndexError: if ``index`` is past the end of the last dataset.
        """
        local_index = index
        for dataset_name, dataset in self.datasets.items():
            if local_index < len(dataset):
                return dataset_name, dataset, local_index
            local_index -= len(dataset)
        raise IndexError(f'index {index} out of range for CollectionTextDataset of length {len(self)}')

    def __getitem__(self, index):
        _, dataset, local_index = self._locate(index)
        return dataset[local_index]

    def get_dataset(self, index):
        """Return the name of the sub-dataset that owns flat ``index``."""
        return self._locate(index)[0]

    def collate_fn(self, batch):
        """Collate with the collator of the dataset that owns flat index 0."""
        return self._locate(0)[1].collate_fn(batch)
class FidDataset(Dataset):
    """Dataset that visits every sample exactly once via a flat index over all authors.

    Each item pairs one sample with ``num_examples`` style images drawn (with
    replacement) from the same author, taken from ``style_dataset`` when given and
    from the main data otherwise.

    NOTE(review): ``pickle.load`` can execute arbitrary code stored in the file; only
    load trusted dataset files.
    """

    def __init__(self, base_path, collator_resolution, num_examples=15, target_transform=None,
                 mode='train', style_dataset=None):
        """
        Args:
            base_path: path to the dataset pickle (indexed by ``mode``).
            collator_resolution: forwarded to ``TextCollator``.
            num_examples: number of style images sampled per item.
            target_transform: stored on the instance; not used by this class.
            mode: split key to read from the pickle(s).
            style_dataset: optional path to a second pickle used as the style-image source.
        """
        self.NUM_EXAMPLES = num_examples
        with open(base_path, "rb") as f:
            self.IMG_DATA = pickle.load(f)[mode]
        # Drop the author bucket keyed by the string 'None'.
        self.IMG_DATA.pop('None', None)

        self.STYLE_IMG_DATA = None
        if style_dataset is not None:
            with open(style_dataset, "rb") as f:
                self.STYLE_IMG_DATA = pickle.load(f)[mode]
            self.STYLE_IMG_DATA.pop('None', None)

        # Flatten labels with a generator; ``sum(lists, [])`` is quadratic in the number of authors.
        self.alphabet = ''.join(sorted(set(
            ''.join(d['label'] for samples in self.IMG_DATA.values() for d in samples))))
        self.author_id = sorted(self.IMG_DATA.keys())
        self.transform = get_transform(grayscale=True)
        self.target_transform = target_transform
        self.dataset_size = sum(len(samples) for samples in self.IMG_DATA.values())
        self.collate_fn = TextCollator(collator_resolution)

    def __len__(self):
        return self.dataset_size

    @property
    def num_writers(self):
        """Number of distinct authors."""
        return len(self.author_id)

    def _locate(self, index):
        """Return ``(author_id, sample)`` for a flat index, walking authors in dict order.

        Raises:
            IndexError: if ``index`` is past the last sample (needed for the sequence protocol).
        """
        local_index = index
        for author_id, samples in self.IMG_DATA.items():
            if local_index < len(samples):
                return author_id, samples[local_index]
            local_index -= len(samples)
        raise IndexError(f'index {index} out of range for FidDataset of length {self.dataset_size}')

    def __getitem__(self, index):
        author_id, sample = self._locate(index)
        real_image = self.transform(sample['img'].convert('L'))
        real_label = sample['label'].encode()

        style_dataset = self.STYLE_IMG_DATA if self.STYLE_IMG_DATA is not None else self.IMG_DATA
        author_style_images = style_dataset[author_id]
        # Draw NUM_EXAMPLES style indices with replacement (the previous code referenced an
        # undefined NUM_SAMPLES and raised NameError on every item).
        random_idxs = np.random.choice(len(author_style_images), self.NUM_EXAMPLES, replace=True)
        style_images = [np.array(author_style_images[idx]['img'].convert('L')) for idx in random_idxs]

        max_width = 192
        imgs_pad = []
        imgs_wids = []
        for img in style_images:
            # Invert, zero-pad/truncate to max_width, invert back: padding ends up 255 (white).
            img = 255 - img
            img_height, img_width = img.shape[0], img.shape[1]
            out_img = np.zeros((img_height, max_width), dtype='float32')
            out_img[:, :img_width] = img[:, :max_width]
            img = 255 - out_img
            imgs_pad.append(self.transform(Image.fromarray(img.astype(np.uint8))))
            imgs_wids.append(img_width)
        imgs_pad = torch.cat(imgs_pad, 0)

        item = {
            'simg': imgs_pad,
            'swids': imgs_wids,
            'img': real_image,
            'label': real_label,
            'img_path': 'img_path',
            'idx': sample['img_id'] if 'img_id' in sample else sample['image_id'],
            'wcl': int(author_id)
        }
        return item
class FolderDataset:
    """Source of style images read from a directory (``.txt`` files and sub-directories are skipped)."""

    def __init__(self, folder_path, num_examples=15, word_lengths=None):
        """
        Args:
            folder_path: directory containing the style images.
            num_examples: number of images drawn per ``sample_style`` call.
            word_lengths: optional mapping from image file stem to a length; when given,
                each image is resized to width ``length * 16`` and height 32. Otherwise
                images are resized to height 32 keeping the aspect ratio.
        """
        folder_path = Path(folder_path)
        # Regular files only: a sub-directory would make Image.open fail during sampling.
        self.imgs = [p for p in folder_path.iterdir() if p.is_file() and p.suffix != '.txt']
        self.transform = get_transform(grayscale=True)
        self.num_examples = num_examples
        self.word_lengths = word_lengths

    def __len__(self):
        return len(self.imgs)

    @staticmethod
    def _load_gray(path):
        """Open ``path``, convert it to mode 'L', and close the underlying file."""
        with Image.open(path) as img:
            return img.convert('L')

    def sample_style(self):
        """Draw ``num_examples`` images and return them padded/truncated to 192 px wide.

        Returns:
            Dict with ``'simg'`` (transformed images concatenated along dim 0) and
            ``'swids'`` (each image's width after resizing, before padding/truncation).

        Raises:
            ValueError: if the folder contains no images.
        """
        if not self.imgs:
            raise ValueError('FolderDataset has no images to sample from')
        # Sample without replacement when possible; with fewer images than requested,
        # fall back to replacement instead of letting numpy raise.
        replace = len(self.imgs) < self.num_examples
        random_idxs = np.random.choice(len(self.imgs), self.num_examples, replace=replace)
        paths = [self.imgs[idx] for idx in random_idxs]
        image_names = [p.stem for p in paths]
        imgs = [self._load_gray(p) for p in paths]

        if self.word_lengths is None:
            imgs = [img.resize((img.size[0] * 32 // img.size[1], 32), Image.BILINEAR) for img in imgs]
        else:
            imgs = [img.resize((self.word_lengths[name] * 16, 32), Image.BILINEAR)
                    for img, name in zip(imgs, image_names)]
        imgs = [np.array(img) for img in imgs]

        max_width = 192
        imgs_pad = []
        imgs_wids = []
        for img in imgs:
            # Invert, zero-pad/truncate to max_width, invert back: padding ends up 255 (white).
            img = 255 - img
            img_height, img_width = img.shape[0], img.shape[1]
            out_img = np.zeros((img_height, max_width), dtype='float32')
            out_img[:, :img_width] = img[:, :max_width]
            img = 255 - out_img
            imgs_pad.append(self.transform(Image.fromarray(img.astype(np.uint8))))
            imgs_wids.append(img_width)
        imgs_pad = torch.cat(imgs_pad, 0)

        item = {
            'simg': imgs_pad,
            'swids': imgs_wids,
        }
        return item