# GlyphControl -- ldm/data/laion.py
# (cleaned up: this header was HuggingFace file-viewer chrome, not Python source)
import os #yaml, pickle, shutil, tarfile,
from glob import glob
import cv2
import albumentations
import PIL
import numpy as np
import torchvision.transforms.functional as TF
# from omegaconf import OmegaConf
from functools import partial
from PIL import Image
from torch.utils.data import Dataset #, Subset
import pandas as pd
from torchvision import transforms
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
from skimage import io
from tqdm import tqdm
import base64
from io import BytesIO
from ldm.data.base import Txt2ImgIterableBaseDataset
import multiprocessing as mp
from bisect import bisect_left, bisect_right
import omegaconf
import time
import json
from torch.utils.data.dataloader import _get_distributed_settings
class LAIONBase(Dataset):
    """Map-style dataset over LAION parquet shards.

    Each record in ``self.data_info`` is ``(img_path, text_path, row_idx)``.
    ``__getitem__`` performs the following ops in order:
      1. crops a square of side ``s`` from the image (random or center crop),
         where ``s = c * min_side_len`` with ``c`` sampled uniformly from
         ``(min_crop_f, max_crop_f)``
      2. rescales the crop so its smallest side equals ``size``
         (cv2 area interpolation)

    :param img_folder: folder whose sub-directories hold ``*.parquet`` shards;
        a ``data_info.json`` index is cached here
    :param caption_folder: unused in the parquet path (captions come from the
        same parquet file as the images)
    :param recollect_data_info: force rebuilding ``data_info.json``
    :param first_stage_key: key for the image tensor in each example
    :param cond_stage_key: key for the caption in each example
    :param do_flip: apply random horizontal flip with probability ``flip_p``
    :param size: target size after cropping (required)
    :param degradation: unused (kept for config compatibility)
    :param downscale_f: unused (kept for config compatibility)
    :param min_crop_f: lower bound of the relative crop size
    :param max_crop_f: upper bound of the relative crop size (<= 1)
    :param random_crop: random crop if True, else center crop
    :param max_items: keep only the first ``max_items`` index records
        (default 5000 preserves the previous hard-coded cap)
    """
    def __init__(self, img_folder, caption_folder=None,
                 recollect_data_info=False,
                 first_stage_key="jpg", cond_stage_key="txt", do_flip=False,
                 size=None, degradation=None,
                 downscale_f=4, min_crop_f=0.5, max_crop_f=1., flip_p=0.5,
                 random_crop=True, max_items=5000):
        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        # everything before the first "/laion" path component is the data
        # root; stored "/scratch" paths are remapped onto it in __getitem__
        self.root = img_folder.split("/laion")[0]
        self.images = []
        self.texts = []
        self.paired_data = []
        self.parquet_info = {}
        data_info_file = os.path.join(img_folder, "data_info.json")
        collect_data_info = True
        if not recollect_data_info:
            try:
                with open(data_info_file, "r") as f:
                    self.data_info = json.loads(f.read())
                collect_data_info = False
            except Exception:
                print(
                    "fail to load data info from {}".format(data_info_file)
                )
        if collect_data_info:
            print(
                "start to collect data info to {}".format(data_info_file)
            )
            self.data_info = []
            self.load_data_par(img_folder)
            with open(data_info_file, "w") as f:
                f.write(json.dumps(self.data_info))
        # truncate to a manageable subset of the index
        self.data_info = self.data_info[:max_items]
        self.indices = range(self.__len__())
        self.do_flip = do_flip
        if self.do_flip:
            self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        assert size
        self.size = size
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert(max_crop_f <= 1.)
        self.center_crop = not random_crop
        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)

    def __len__(self):
        return len(self.data_info)

    def load_from_origin_data(self, folder):
        """Fetch every image referenced by the parquet files in ``folder``
        (columns URL / TEXT) and cache them as ``images.npy`` /
        ``prompts.txt`` under ``./data/<basename(folder)>``; later calls
        reload from that cache."""
        valid_num = 0
        store_processed_folder = os.path.join("./data", os.path.basename(folder))
        if not os.path.exists(store_processed_folder):
            os.makedirs(store_processed_folder)
        text_list = []
        image_list = []
        image_path = os.path.join(store_processed_folder, "images.npy")
        text_path = os.path.join(store_processed_folder, "prompts.txt")
        if not os.path.exists(image_path) or not os.path.exists(text_path):
            for file_name in glob(folder + "/*"):
                if file_name.endswith(".parquet"):
                    data = pd.read_parquet(file_name)
                else:
                    continue
                for idx, row in tqdm(data.iterrows()):
                    try:
                        img = row.URL  # assumes the df has URL / TEXT columns
                        txt = row.TEXT
                        image = io.imread(img)  # downloads the remote image
                    except Exception:
                        continue  # unreachable URL / broken image: skip
                    if len(image.shape) == 2:
                        # grayscale: promote to 3-channel RGB
                        image = Image.fromarray(image)
                        image = image.convert("RGB")
                        image = np.array(image).astype(np.uint8)
                    image_list.append(image)
                    text_list.append(txt)
                    valid_num += 1
                del data
            self.images = np.array(image_list)
            self.texts = text_list
            with open(text_path, "w") as f:
                f.writelines([text + "\n" for text in text_list])
            np.save(image_path, self.images)
        else:
            self.images = np.load(image_path, allow_pickle=True)
            with open(text_path, "r") as f:
                self.texts = [line.rstrip() for line in f.readlines()]

    def load_from_tsv(self, image_path, original_data):
        """Read tab-separated (index, base64-image) lines from ``image_path``
        and pair each decoded image with ``original_data.iloc[index].TEXT``."""
        idx_list = []
        with open(image_path, "r") as f:
            for line_ in tqdm(f.readlines()):
                list_ = line_.split("\t")
                img = list_[1]
                idx = int(list_[0])
                idx_list.append(idx)
                code_ = base64.b64decode(img)
                image = Image.open(BytesIO(code_)).convert("RGB")
                image = np.array(image).astype(np.uint8)
                text = original_data.iloc[idx].TEXT
                self.images.append(image)
                self.texts.append(text)

    def load_data(self, img_folder, caption_folder):
        """For every sub-directory ``output_<name>`` of ``img_folder``, read
        its tsv image shards in parallel and pair them with the caption
        parquet ``caption_folder/<name>.parquet``."""
        for subfolder in glob(img_folder + "/*"):
            if os.path.isdir(subfolder):
                base = os.path.basename(subfolder)
                # strip the literal "output_" prefix; str.lstrip("output_")
                # would strip the character set {o,u,t,p,_} and mangle names
                # such as "output_part-00000" -> "art-00000"
                if base.startswith("output_"):
                    base = base[len("output_"):]
                caption_path = os.path.join(
                    caption_folder,
                    base + ".parquet"
                )
                par_data = pd.read_parquet(caption_path)  # faster than tsv
                imgstr_list = []
                tsv_paths = glob(subfolder + "/*.tsv")
                def merge_(results):
                    # map_async invokes the callback once with the list of
                    # per-file results; flatten to a single list of lines
                    for lines in results:
                        imgstr_list.extend(lines)
                def load_(path):
                    with open(path, "r") as f:
                        return f.readlines()
                # NOTE(review): local functions are not picklable under the
                # "spawn" start method and map_async errors are never
                # surfaced (.get() is not called) -- confirm this path is
                # actually exercised before relying on it
                p = mp.Pool(30)
                p.map_async(load_, tsv_paths, callback=merge_)
                p.close()
                p.join()
                self.paired_data.append([imgstr_list, par_data])
                del par_data

    def load_from_parquet(self, parquet_path):
        """Return ``[(img_path, text_path, row_idx), ...]`` for every row of
        the parquet file whose ``jpg`` column is non-null; image and caption
        both live in the same file, so both path slots get ``parquet_path``."""
        df = pd.read_parquet(parquet_path)
        print(parquet_path + " is successfully loaded")
        valid_inds = list(df[df.jpg.notnull()].index)
        info_lists = list(zip(
            [parquet_path] * len(valid_inds),  # image path
            [parquet_path] * len(valid_inds),  # text path
            valid_inds
        ))
        return info_lists

    def merge_data(self, items):
        """Pool callback: flatten the per-file info lists into data_info."""
        for item in items:
            self.data_info.extend(item)

    def load_data_par(self, folder):
        """Index every ``*.parquet`` shard in the immediate sub-directories
        of ``folder`` in parallel (batches of 20 files, 20 workers each)."""
        parquet_paths = glob(folder + "/*/*.parquet")
        bs = 20
        iterables = [
            parquet_paths[i:i + bs] for i in range(0, len(parquet_paths), bs)
        ]
        for iterable_ in tqdm(iterables):
            # a fresh pool per batch bounds the number of open files / memory
            p = mp.Pool(20)
            p.map_async(self.load_from_parquet, iterable_, callback=self.merge_data)
            p.close()
            p.join()

    def __getitem__(self, i):
        """Load, decode and augment sample ``i``.

        Unreadable or undersized records are replaced by a uniformly random
        resample of another index.
        """
        example = dict()
        index_ = self.indices[i]
        imgfile_name, textfile_name, file_idx = self.data_info[index_]
        # remap paths recorded on the original "/scratch" mount to self.root
        imgfile_name = imgfile_name.replace("/scratch", self.root)
        textfile_name = textfile_name.replace("/scratch", self.root)
        pre_t = time.time()
        if imgfile_name.endswith(".parquet"):
            df = pd.read_parquet(imgfile_name)
            img = df.jpg.iloc[file_idx]
        elif imgfile_name.endswith(".tsv"):
            with open(imgfile_name, "r") as f:
                line_ = f.readlines()[file_idx]
            file_idx, img = line_.split("\t")
            img = base64.b64decode(img)
            file_idx = int(file_idx)
        try:
            image = Image.open(BytesIO(img)).convert("RGB")
            image = np.array(image).astype(np.uint8)
        except Exception:
            # corrupt bytes: retry with a random record
            return self.__getitem__(np.random.randint(0, len(self.indices)))
        if image.shape[0] < self.size or image.shape[1] < self.size:
            # too small to crop at self.size: retry with a random record
            return self.__getitem__(np.random.randint(0, len(self.indices)))
        # random crop with side length in (min_crop_f, max_crop_f) * min side
        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
        crop_side_len = int(crop_side_len)
        if self.center_crop:
            self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
        else:
            self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
        image = self.cropper(image=image)["image"]
        # rescale so the smallest side equals self.size
        image = self.image_rescaler(image=image)["image"]
        if self.do_flip:
            image = self.flip(Image.fromarray(image))
            image = np.array(image).astype(np.uint8)
        example[self.first_stage_key] = (image/127.5 - 1.0).astype(np.float32)
        pre_t = time.time()
        if imgfile_name != textfile_name:
            if textfile_name.endswith(".parquet"):
                df = pd.read_parquet(textfile_name)
            else:
                print(
                    "the format {} of the text file is not supported".format(
                        os.path.splitext(imgfile_name)[1]
                    )
                )
                raise ValueError(
                    "unsupported text file format: {}".format(textfile_name)
                )
        # the caption column name differs between shards (TEXT vs caption)
        try:
            text = df.TEXT.iloc[file_idx]
        except Exception:
            try:
                text = df.caption.iloc[file_idx]
            except Exception:
                raise ValueError(
                    "neither TEXT nor caption column found in {}".format(textfile_name)
                )
        example[self.cond_stage_key] = text
        return example
class LAIONTrain(LAIONBase):
    """Training split of :class:`LAIONBase`.

    On first use, draws a random permutation of the dataset, keeps the first
    ``ratio`` fraction as training indices, and persists the full split to
    ``train.txt`` / ``val.txt`` in ``store_folder`` so that later runs (and
    :class:`LAIONValidation`) reuse exactly the same partition.

    :param store_folder: directory holding the persisted split files
        (created if missing)
    :param ratio: fraction of samples assigned to training (default 0.7)
    """
    def __init__(self, store_folder, *args, ratio=0.7, **kwargs):
        super().__init__(*args, **kwargs)
        train_file = os.path.join(store_folder, "train.txt")
        val_file = os.path.join(store_folder, "val.txt")
        if not os.path.exists(train_file):
            # first run: the split folder may not exist yet
            os.makedirs(store_folder, exist_ok=True)
            rand_inds = np.random.permutation(self.__len__())
            split_at = int(len(rand_inds) * ratio)
            # plain ints for consistency with the reload branch below
            self.indices = [int(i) for i in rand_inds[:split_at]]
            lines = [str(i) + "\n" for i in rand_inds]
            with open(train_file, "w") as f:
                f.writelines(lines[:split_at])
            with open(val_file, "w") as f:
                f.writelines(lines[split_at:])
        else:
            with open(train_file, "r") as f:
                self.indices = [int(s.strip()) for s in f.readlines()]
class LAIONValidation(LAIONBase):
    """Validation split of :class:`LAIONBase`.

    Reads the held-out indices from ``val.txt`` in ``store_folder``; that file
    is written by :class:`LAIONTrain` on its first run, so the training split
    must be created first.

    :param store_folder: directory holding the persisted split files
    :raises ValueError: if ``val.txt`` does not exist yet
    """
    def __init__(self, store_folder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        val_file = os.path.join(store_folder, "val.txt")
        if not os.path.exists(val_file):
            # fail loudly with context instead of a bare ValueError
            raise ValueError(
                "validation split file {} does not exist; "
                "instantiate LAIONTrain first to create it".format(val_file)
            )
        with open(val_file, "r") as f:
            self.indices = [int(s.strip()) for s in f.readlines()]
class LAIONIterableBaseDataset(Txt2ImgIterableBaseDataset):
    '''
    Chainable IterableDataset over LAION text-image data.

    Two on-disk layouts are supported:
      * ``caption_folder`` given: images live as base64-encoded tsv shards in
        ``img_folder/output_<name>/*.tsv`` and captions in
        ``caption_folder/<name>.parquet``;
      * otherwise images and captions both come from ``*.parquet`` shards
        found recursively under ``img_folder``.

    Each yielded example maps ``first_stage_key`` to a float32 image in
    [-1, 1] and ``cond_stage_key`` to its caption.
    '''
    def __init__(self, img_folder, caption_folder=None, size=256,
                 first_stage_key="jpg", cond_stage_key="txt", do_flip=False,
                 min_crop_f=0.5, max_crop_f=1., flip_p=0.5,
                 random_crop=True):
        assert size
        super().__init__(size=size)
        self.caption_folder = caption_folder
        if self.caption_folder:
            self.valid_ids = glob(img_folder + "/*/")  # "output_<name>/" dirs
            self.origin_tsv_paths = {
                subfolder: glob(subfolder + "/*.tsv") for subfolder in self.valid_ids
            }
            # cumulative tsv count -> folder, kept for shard bookkeeping
            num = 0
            self.tsv_folder_idx = {}
            for key, value in self.origin_tsv_paths.items():
                num += len(value)
                self.tsv_folder_idx[num] = key
            self.num_records = len(self.valid_ids)
            self.sample_ids = self.valid_ids
            self.tsv_paths = self.origin_tsv_paths  # to be deprecated
            # rough upper bound on the sample count (~100k images per folder)
            self.max_num = self.num_records * 100000
        else:
            parquet_paths = []
            for root, _, files in os.walk(os.path.abspath(img_folder)):
                for file in files:
                    if file.endswith(".parquet"):
                        parquet_paths.append(os.path.join(root, file))
            # NOTE(review): hard cap of 170 shards, presumably a debugging
            # limit -- confirm before production use
            parquet_paths = parquet_paths[:170]
            self.valid_ids = parquet_paths
            self.sample_ids = self.valid_ids
            self.num_records = len(self.valid_ids)
            # rough upper bound on the sample count (~1k images per shard)
            self.max_num = self.num_records * 1000
        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        self.do_flip = do_flip
        if self.do_flip:
            self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert(max_crop_f <= 1.)
        self.center_crop = not random_crop
        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)

    def __iter__(self):
        # layout decides which generator backs the iterator
        if self.caption_folder:
            return self.parquet_tsv_iter()
        else:
            return self.parquet_iter()

    def parquet_iter(self):
        """Yield examples from parquet shards; cycles over the whole dataset
        forever (the trainer decides when to stop consuming)."""
        print("this shard on GPU {}: {}".format(_get_distributed_settings()[1], len(self.sample_ids)))
        idx = 0
        while idx >= 0:  # effectively an infinite loop
            for parquet_path in self.sample_ids:
                df = pd.read_parquet(parquet_path)
                for file_idx in range(len(df)):
                    img_code = df.jpg.iloc[file_idx]
                    if img_code:
                        try:
                            image = self.generate_img(img_code)
                        except Exception:
                            continue  # undecodable image bytes
                        if image is None:
                            continue  # image smaller than self.size
                        # caption column name differs between shards
                        try:
                            text = df.caption.iloc[file_idx]
                        except Exception:
                            try:
                                text = df.TEXT.iloc[file_idx]
                            except Exception:
                                continue
                        if text is None:
                            continue
                        example = {}
                        example[self.first_stage_key] = image
                        example[self.cond_stage_key] = text
                        yield example
                del df
            print("has gone over the whole dataset, need to start next round")
            idx += 1

    def parquet_tsv_iter(self):
        """Yield examples pairing base64 tsv images with captions from the
        matching parquet file (one pass over the data)."""
        for subfolder in self.sample_ids:
            # glob("*/") returns paths with a trailing "/", so normpath is
            # required for basename to yield the folder name; then strip the
            # literal "output_" prefix (str.lstrip strips a character set,
            # not a prefix, and would mangle the name)
            base = os.path.basename(os.path.normpath(subfolder))
            if base.startswith("output_"):
                base = base[len("output_"):]
            caption_path = os.path.join(self.caption_folder, base + ".parquet")
            par_data = pd.read_parquet(caption_path)  # faster than tsv captions
            for image_path in self.tsv_paths[subfolder]:
                with open(image_path, "r") as f:
                    for line_ in tqdm(f.readlines()):
                        idx, img_code = line_.split("\t")
                        try:
                            img_code = base64.b64decode(img_code)
                            image = self.generate_img(img_code)
                        except Exception:
                            continue  # undecodable image bytes
                        # generate_img returns None for undersized images;
                        # the original ``if not image`` raised on ndarrays
                        # (swallowed by the except), skipping every valid
                        # sample -- test identity instead
                        if image is None:
                            continue
                        example = dict()
                        example[self.first_stage_key] = image
                        idx = int(idx)
                        text = par_data.iloc[idx].TEXT
                        example[self.cond_stage_key] = text
                        yield example
            del par_data

    def generate_img(self, img_code):
        """Decode ``img_code`` (raw image bytes) into a float32 RGB array in
        [-1, 1], applying crop / rescale / optional flip.

        Returns ``None`` when either side of the decoded image is smaller
        than ``self.size``.
        """
        image = Image.open(BytesIO(img_code)).convert("RGB")
        image = np.array(image).astype(np.uint8)
        if image.shape[0] < self.size or image.shape[1] < self.size:
            return None
        # crop: side length sampled in (min_crop_f, max_crop_f) * min side
        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
        crop_side_len = int(crop_side_len)
        if self.center_crop:
            self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
        else:
            self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
        image = self.cropper(image=image)["image"]
        # rescale so the smallest side equals self.size
        image = self.image_rescaler(image=image)["image"]
        if self.do_flip:
            image = self.flip(Image.fromarray(image))
            image = np.array(image).astype(np.uint8)
        return (image/127.5 - 1.0).astype(np.float32)