Spaces:
Runtime error
Runtime error
from typing import Dict | |
import numpy as np | |
from omegaconf import DictConfig, ListConfig | |
import torch | |
from torch.utils.data import Dataset | |
from pathlib import Path | |
import json | |
from PIL import Image | |
from torchvision import transforms | |
from einops import rearrange | |
from ldm.util import instantiate_from_config | |
# from datasets import load_dataset | |
import os | |
from collections import defaultdict | |
from glob import glob | |
import re | |
from bisect import bisect_left, bisect_right | |
import albumentations, cv2 | |
import time | |
class SynWhiteBoardDataset(Dataset):
    """Synthetic whiteboard text-image dataset backed by offset-indexed TSV captions.

    Each sample pairs an image under ``img_folder/corpus_type`` with a caption
    built from one record of a TSV file in ``caption_folder``.
    ``tsv_info_file`` is a JSON object mapping each TSV file name to the list
    of offsets (passed to ``seek``) at which that file's records start.
    Samples are numbered consecutively across the TSV files in the JSON's key
    order.

    Every TSV record has five tab-separated fields:
    text, font file name, row arrangement ("<first>_<second>"), an ``align``
    field (read but unused here) and the image file name.
    """

    # Upper bound on how many random replacement indices __getitem__ tries when
    # a sample's image cannot be loaded (replaces unbounded recursion).
    max_resample_attempts = 1000

    def __init__(self,
                 img_folder,
                 caption_folder,
                 tsv_info_file,
                 corpus_type="all_4gram",
                 image_transforms=None,
                 first_stage_key="jpg",
                 cond_stage_key="txt",
                 postprocess=None,
                 ext="png",
                 img_class="whiteboard",
                 caption_type="regular",  # "simple" or "regular" or "full"
                 lower_case=False,
                 max_num=None,
                 image_size=512,
                 do_padding=True,
                 explict_arrangement=False,
                 ) -> None:
        """
        Args:
            img_folder: root folder; images are read from ``img_folder/corpus_type``.
            caption_folder: folder containing the TSV caption files.
            tsv_info_file: JSON file mapping TSV file names to record offsets.
            corpus_type: sub-folder of ``img_folder`` holding the images.
            image_transforms: transforms applied before the built-in tensor
                conversion. An omegaconf ``ListConfig`` is instantiated entry
                by entry. ``None`` (the default) means no extra transforms.
                The caller's list is never modified.
            first_stage_key: output dict key for the image.
            cond_stage_key: output dict key for the caption.
            postprocess: optional callable applied to each sample dict. An
                omegaconf ``DictConfig`` is instantiated first.
            ext: stored as ``self.ext``; not used by this class when loading.
            img_class: noun used in captions ("A <img_class> that says ...").
            caption_type: "simple", "regular" or "full" ("full" also describes
                the row arrangement of the words).
            lower_case: lower-case the font weight/width/style in captions.
            max_num: optional cap on the dataset length (clamped to the number
                of indexed records).
            image_size: longest side after resizing, and minimum padded size.
            do_padding: resize and pad with white borders before transforms.
            explict_arrangement: in "full" captions, quote the words of each row.

        Raises:
            ValueError: if ``tsv_info_file`` lists no TSV files.
        """
        self.root_dir = os.path.join(Path(img_folder), corpus_type)
        self.caption_folder = caption_folder
        assert os.path.exists(self.caption_folder) and os.path.exists(tsv_info_file), (
            "missing caption folder or TSV info file: {}, {}".format(caption_folder, tsv_info_file)
        )
        with open(tsv_info_file, "r", encoding="utf-8") as f:
            tsv_info_dict = json.load(f)
        if not tsv_info_dict:
            raise ValueError("TSV info file {} lists no TSV files".format(tsv_info_file))

        # Cumulative record counts: rank_list[r] is the number of samples in the
        # first r + 1 TSV files, so bisecting it maps a global index to a file.
        rank_list = []
        total_num = 0
        for offsets in tsv_info_dict.values():
            total_num += len(offsets)
            rank_list.append(total_num)
        self.rank_list = rank_list
        # File names in the same order as rank_list, so a rank always selects the
        # file its counts came from, regardless of how the JSON keys are ordered.
        self.tsv_names = list(tsv_info_dict)
        # Clamp so no index can point past the last indexed record.
        self.total_num = total_num if max_num is None else min(max_num, total_num)
        self.tsv_info_dict = tsv_info_dict
        self.corpus_type = corpus_type
        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key

        if isinstance(postprocess, DictConfig):
            postprocess = instantiate_from_config(postprocess)
        self.postprocess = postprocess

        # Always build a fresh list so neither the caller's list nor a shared
        # default accumulates ToTensor/Lambda across instances.
        if image_transforms is None:
            image_transforms = []
        elif isinstance(image_transforms, ListConfig):
            image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
        else:
            image_transforms = list(image_transforms)
        image_transforms.extend([
            transforms.ToTensor(),
            # A named function (not a lambda) keeps the dataset picklable for
            # DataLoader workers started with the "spawn" method.
            transforms.Lambda(self._to_signed_hwc),
        ])
        self.tform = transforms.Compose(image_transforms)

        self.ext = ext
        # TSV names end in "_<num_rank>.<ext>"; parse the count with int(), not eval().
        self.num_rank = int(self.tsv_names[0].split("_")[-1].split(".")[0])
        self.img_class = img_class
        self.caption_type = caption_type
        self.lower_case = lower_case
        self.do_padding = do_padding
        self.image_rescaler = albumentations.LongestMaxSize(max_size=image_size, interpolation=cv2.INTER_AREA)
        self.image_size = image_size
        self.pad = albumentations.PadIfNeeded(min_height=self.image_size, min_width=self.image_size,
                                              border_mode=cv2.BORDER_CONSTANT, value=(255, 255, 255),
                                              )
        self.explict_arrangement = explict_arrangement

    @staticmethod
    def _to_signed_hwc(x):
        """Rescale a ToTensor output with ``x * 2 - 1`` and move channels last."""
        return rearrange(x * 2. - 1., 'c h w -> h w c')

    def __len__(self):
        """Return the number of samples (capped by ``max_num`` when given)."""
        return self.total_num

    def __getitem__(self, index):
        """Return ``{first_stage_key: image, cond_stage_key: caption}`` for ``index``.

        If the sample's image cannot be opened or decoded, a uniformly random
        index is tried instead, up to ``max_resample_attempts`` times. The
        optional ``postprocess`` is applied to the returned dict.

        Raises:
            RuntimeError: if no loadable sample is found within the attempt limit.
        """
        for _ in range(self.max_resample_attempts):
            data = self._load_sample(index)
            if data is not None:
                break
            index = np.random.choice(len(self))
        else:
            raise RuntimeError(
                "no loadable image found after {} attempts".format(self.max_resample_attempts)
            )
        if self.postprocess is not None:
            data = self.postprocess(data)
        return data

    def _load_sample(self, index):
        """Build the sample dict for ``index``, or return None if its image is unreadable."""
        txt_content, font_file, arrange_, _align, imagename = self._read_caption_fields(index)
        im = self._load_image(imagename)
        if im is None:
            return None
        return {
            self.first_stage_key: im,
            self.cond_stage_key: self._compose_caption(txt_content, font_file, arrange_),
        }

    def _read_caption_fields(self, index):
        """Seek to the TSV record of global ``index`` and return its five fields."""
        rank = bisect_right(self.rank_list, index)
        index_in_tsv = index - (self.rank_list[rank - 1] if rank > 0 else 0)
        tsv_name = self.tsv_names[rank]
        with open(os.path.join(self.caption_folder, tsv_name), "r", encoding="utf-8") as f:
            f.seek(self.tsv_info_dict[tsv_name][index_in_tsv])
            line = f.readline().strip()
        fields = line.split("\t")
        assert len(fields) == 5, (
            "expected 5 tab-separated fields in {}, got {!r}".format(tsv_name, line)
        )
        return fields

    def _load_image(self, imagename):
        """Open, decode and process ``root_dir/imagename``; return None if it is unreadable."""
        filename = os.path.join(self.root_dir, imagename)
        # Deliberately broad: any missing/corrupt file makes the caller resample,
        # while BaseException (e.g. KeyboardInterrupt) still propagates.
        try:
            im = Image.open(filename)
        except Exception:
            return None
        with im:  # close the file handle once processing is done
            try:
                # Image.open is lazy; decode now so truncated files are skipped too.
                im.load()
            except Exception:
                return None
            return self.process_im(im)

    @staticmethod
    def parse_font_attributes(font_file):
        """Split a font file name into caption parts.

        Bracketed segments are removed and the extension is dropped. The stem
        is split on "-" into a family name and an optional type suffix. The
        suffix's capitalised words are sorted into weight, width ("Condensed")
        and style (Italic/Oblique/Cursive/Book, "Regular" or "VF"). Without a
        suffix, "Italic"/"Bold" are taken from the family name itself.

        Returns:
            ``(font_name, font_weight, font_width, font_style)``; missing parts
            are empty strings.
        """
        font_weight = ""
        font_style = ""
        font_width = ""
        font_file = re.sub(r"\[.*?\]", "", font_file)  # remove [...] segments
        font_list = os.path.splitext(font_file)[0].split("-")
        if len(font_list) > 2:
            print("font file name outlier: {}".format(font_file))
            # Everything before the last "-" is the family name.
            font_list = ["-".join(font_list[:-1]), font_list[-1]]

        if len(font_list) == 2:
            font_name, font_type = font_list
            if font_type == "VF":
                font_style = "VF"
            else:
                tokens = re.findall("[A-Z][a-z]*", font_type)
                if "Regular" in tokens:
                    font_weight = "Regular"
                    font_style = "Regular"
                else:
                    # Only the first matching style word is used.
                    for style in ("Italic", "Oblique", "Cursive", "Book"):
                        if style in tokens:
                            font_style = style
                            tokens.remove(style)
                            break
                    if "Condensed" in tokens:
                        font_width = "Condensed"
                        tokens.remove("Condensed")
                    # Whatever remains (e.g. "Semi Bold") describes the weight.
                    if tokens:
                        font_weight = " ".join(tokens)
        else:
            # str.split always yields at least one element, so this is the 1-part case.
            font_name = font_list[0]
            if "Italic" in font_name:
                font_name = font_name.replace("Italic", "")
                font_style = "Italic"
            if "Bold" in font_name:
                font_name = font_name.replace("Bold", "")
                font_weight = "Bold"

        # A width embedded in the family name overrides one from the suffix.
        if "Condensed" in font_name:
            if "Extra" in font_name or "Semi" in font_name or "Ultra" in font_name:
                words = re.findall("[A-Z][a-z]*", font_name)
                font_width = " ".join(words[-2:])
                font_name = "".join(words[:-2])
            else:
                # replace(), not rstrip(): rstrip("Condensed") strips a character
                # set and would also eat trailing letters of the family name.
                font_name = font_name.replace("Condensed", "")
                font_width = "Condensed"
        return font_name, font_weight, font_width, font_style

    def _compose_caption(self, txt_content, font_file, arrange_):
        """Build the caption text for one record according to ``caption_type``."""
        if self.caption_type == "simple":
            return 'A {} that says {}'.format(self.img_class, txt_content)

        font_name, font_weight, font_width, font_style = self.parse_font_attributes(font_file)
        caption = 'A {} that says {} written in the font of {}'.format(
            self.img_class, txt_content, font_name
        )
        # The first present attribute is introduced with "with", later ones with "and".
        n_added = 0
        for value, noun in ((font_weight, "stroke weight"),
                            (font_width, "font width"),
                            (font_style, "font style")):
            if value == "":
                continue
            if self.lower_case:
                value = value.lower()
            caption += " {} {} {}".format("with" if n_added == 0 else "and", value, noun)
            n_added += 1

        if self.caption_type == "full":
            words = txt_content.strip('"').split(" ")
            # arrange_ is "<words in row 1>_<words in row 2>"; parsed with int(), not eval().
            frn, srn = (int(n) for n in arrange_.split("_"))
            assert frn + srn == len(words), (
                "arrangement {!r} does not match {} words".format(arrange_, len(words))
            )
            if frn == 0 or srn == 0:
                caption += '. All the words are written in the same row.'
            elif self.explict_arrangement:
                caption += '. "{}" is written in the first row while "{}" is in the second row.'.format(
                    ' '.join(words[:frn]),
                    ' '.join(words[frn:])
                )
            else:
                caption += '. The first {} written in the first row while the {} in the second row.'.format(
                    "{} words are".format(frn) if frn > 1 else "word is",
                    "other {} words are".format(srn) if srn > 1 else "last word is",
                )
        return caption

    def process_im(self, im):
        """Convert a PIL image to RGB, optionally resize/pad it, then apply ``self.tform``."""
        im = im.convert("RGB")
        if self.do_padding:
            im = self.padding_image(im)
        return self.tform(im)

    def padding_image(self, im):
        """Resize so the longest side is ``image_size``, then pad with white to at least that size.

        Returns a ``uint8`` numpy array rather than a PIL image.
        """
        im = np.array(im).astype(np.uint8)
        im_rescaled = self.image_rescaler(image=im)["image"]
        return self.pad(image=im_rescaled)["image"]