import os
from glob import glob
import base64
import subprocess
from io import BytesIO

import cv2
import albumentations
import numpy as np
import pandas as pd
from PIL import Image
from torchvision import transforms
from torch.utils.data.dataloader import _get_distributed_settings

import clip

from ldm.data.base import Txt2ImgIterableBaseDataset
class LAIONIterableBaseDataset(Txt2ImgIterableBaseDataset):
    '''
    Load the LAION dataset as an IterableDataset.

    Two on-disk layouts are supported:
      * img_txt_same_file=True:  parquet files whose rows contain both the
        encoded image ("jpg") and its caption ("caption"/"TEXT").
      * img_txt_same_file=False: tsv shards with base64-encoded images plus a
        separate parquet file per record holding the captions, optionally
        downloaded on the fly from Azure blob storage via azcopy.
    '''
    def __init__(self, img_folder, caption_folder=None, img_txt_same_file=False,
                 blob_folder=None, sas_token=None,
                 max_num_records=128, max_num_tsv_per_record=182, tsv_patch_size=10, start_tsv_idx=None,
                 do_azcopy=False,
                 remove_data_from_cluster=False,
                 size=256,
                 first_stage_key="jpg", cond_stage_key="txt",
                 clip_model=None, preprocess=None,
                 do_flip=False, min_crop_f=0.5, max_crop_f=1., flip_p=0.5, random_crop=True):
        assert size
        super().__init__(size=size)
        self.img_folder = img_folder
        self.caption_folder = caption_folder
        self.img_txt_same_file = img_txt_same_file
        if not self.img_txt_same_file:
            # blob info
            self.blob_folder = blob_folder
            self.sas_token = sas_token
            self.image_blob_name = os.path.basename(img_folder)
            self.caption_blob_name = os.path.basename(caption_folder)
            self.remove_data_from_cluster = remove_data_from_cluster if do_azcopy else False
            self.do_azcopy = do_azcopy
            self.max_num_tsv_per_record = max_num_tsv_per_record
            self.tsv_patch_size = tsv_patch_size
            # by default, start prefetching the next patch of tsv shards
            # halfway through the current one
            self.start_tsv_idx = int(self.tsv_patch_size / 2) if start_tsv_idx is None else start_tsv_idx
            if self.start_tsv_idx >= self.tsv_patch_size:
                raise ValueError("start_tsv_idx must be smaller than tsv_patch_size")
            if self.caption_folder:
                if self.do_azcopy:
                    self.valid_ids = [
                        os.path.join(img_folder, "output_part-" + "{:0>5d}".format(i)) for i in range(max_num_records)
                    ]
                else:
                    self.valid_ids = [folder.rstrip("/") for folder in glob(img_folder + "/*/")]
                self.num_records = len(self.valid_ids)
                if not self.num_records:
                    raise ValueError("zero data records, please check the data path")
                self.sample_ids = self.valid_ids
                self.max_num = self.num_records * 100000 * self.max_num_tsv_per_record
            else:
                raise ValueError("a caption folder must be provided when img_txt_same_file is False")
        else:
            parquet_paths = []
            for root, _, files in os.walk(os.path.abspath(img_folder)):
                for file in files:
                    if file.endswith(".parquet"):
                        parquet_paths.append(os.path.join(root, file))
            self.valid_ids = parquet_paths
            self.sample_ids = self.valid_ids
            self.num_records = len(self.valid_ids)
            self.max_num = self.num_records * 1000
        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        self.preprocess = preprocess
        if self.preprocess is None and clip_model is not None:  # e.g. "ViT-L/14"
            _, self.preprocess = clip.load(clip_model)
        self.do_flip = do_flip
        if self.do_flip:
            self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert max_crop_f <= 1.
        self.center_crop = not random_crop
        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
    def __iter__(self):
        if self.img_txt_same_file:
            return self.parquet_iter()
        else:
            return self.parquet_tsv_iter()
    def parquet_iter(self):
        print("this shard on rank {}: {} parquet files".format(_get_distributed_settings()[1], len(self.sample_ids)))
        while True:  # loop forever so the iterable dataset never exhausts
            for parquet_path in self.sample_ids:
                df = pd.read_parquet(parquet_path)
                for file_idx in range(len(df)):
                    img_code = df.jpg.iloc[file_idx]
                    if not img_code:
                        continue
                    try:
                        image = self.generate_img(img_code)
                    except Exception:
                        continue
                    if image is None:
                        continue
                    # the caption column name differs between parquet dumps
                    try:
                        text = df.caption.iloc[file_idx]
                    except AttributeError:
                        try:
                            text = df.TEXT.iloc[file_idx]
                        except AttributeError:
                            continue
                    if text is None:
                        continue
                    example = {
                        self.first_stage_key: image,
                        self.cond_stage_key: text,
                    }
                    yield example
                del df
            print("finished a full pass over the dataset, starting the next epoch")
    def parquet_tsv_iter(self):
        print("this shard on rank {}: {} records".format(_get_distributed_settings()[1], len(self.sample_ids)))
        while True:  # loop forever so the iterable dataset never exhausts
            for subfolder in self.sample_ids:
                parquet_name = os.path.basename(subfolder).split("output_")[1]
                caption_path = os.path.join(
                    self.caption_folder,
                    parquet_name + ".parquet"
                )
                if self.do_azcopy:
                    tsv_paths = [
                        os.path.join(subfolder, "{:0>6d}.tsv".format(i)) for i in range(self.max_num_tsv_per_record)
                    ]
                    # download the captions plus the first patch of tsv shards
                    self.download_data(caption_path, tsv_paths[:self.tsv_patch_size], subfolder, parquet_name, first_part=True)
                    download_time = 1
                else:
                    tsv_paths = glob(subfolder + "/*.tsv")
                par_data = pd.read_parquet(caption_path)  # faster than csv
                for tsv_idx, image_path in enumerate(tsv_paths):
                    print("start opening {}".format(image_path))
                    with open(image_path, "r") as f:
                        lines = f.readlines()
                    print("successfully opened and read {}".format(image_path))
                    if self.remove_data_from_cluster:
                        self.remove_data(image_path)
                    # partway through the current patch, kick off a
                    # non-blocking download of the next patch of tsv shards
                    if self.do_azcopy and tsv_idx == self.start_tsv_idx + (download_time - 1) * self.tsv_patch_size:
                        self.download_data(
                            caption_path,
                            tsv_paths[self.tsv_patch_size * download_time: self.tsv_patch_size * (download_time + 1)],
                            subfolder, parquet_name,
                            first_part=False
                        )
                        download_time += 1
                        print("download time: {}".format(download_time))
                    for line_ in lines:
                        # each tsv line is "<row index>\t<base64-encoded image>"
                        row_idx, img_code = [str_.strip() for str_ in line_.split("\t")]
                        try:
                            image = self.generate_img(base64.b64decode(img_code))
                            if image is None:
                                continue
                        except Exception:
                            continue
                        example = dict()
                        example[self.first_stage_key] = image
                        text = par_data.iloc[int(row_idx)].TEXT
                        example[self.cond_stage_key] = text
                        example["data"] = "\t".join([
                            parquet_name,
                            row_idx,
                            img_code,
                            text
                        ])
                        yield example
                del par_data
                if self.remove_data_from_cluster:
                    self.remove_data(caption_path)
            print("finished a full pass over the dataset, starting the next epoch")
    def generate_img(self, img_code):
        image = Image.open(BytesIO(img_code))
        if self.preprocess:
            # e.g. the CLIP preprocessing transform
            return self.preprocess(image)
        else:
            image = image.convert("RGB")
            image = np.array(image).astype(np.uint8)
            # skip images smaller than the target resolution
            if image.shape[0] < self.size or image.shape[1] < self.size:
                return None
            # crop a random-sized square
            min_side_len = min(image.shape[:2])
            crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
            crop_side_len = int(crop_side_len)
            if self.center_crop:
                cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
            else:
                cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
            image = cropper(image=image)["image"]
            # rescale so the short side matches self.size
            image = self.image_rescaler(image=image)["image"]
            # flip
            if self.do_flip:
                image = self.flip(Image.fromarray(image))
                image = np.array(image).astype(np.uint8)
            return (image / 127.5 - 1.0).astype(np.float32)
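    # Note: without a preprocess transform, generate_img returns an HWC
    # float32 array in [-1, 1], matching the value-range convention used by
    # the other datasets in ldm.data; with a CLIP preprocess it returns
    # whatever that transform produces (a CHW tensor for clip.load's default).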
    def check_and_download(self, caption_path, tsv_paths, subfolder, parquet_name):
        if not os.path.exists(caption_path):
            try:
                os.makedirs(self.caption_folder, exist_ok=True)
                self.azcopy_from_blob(
                    self.caption_blob_name,
                    parquet_name + ".parquet",
                    self.caption_folder,
                )
            except Exception:
                raise ValueError("failed to download caption file from blob!")
        if not len(tsv_paths):
            try:
                os.makedirs(self.img_folder, exist_ok=True)
                self.azcopy_from_blob(
                    self.image_blob_name,
                    os.path.basename(subfolder),
                    self.img_folder,
                )
                return glob(subfolder + "/*.tsv")
            except Exception:
                raise ValueError("failed to download image tsv files from blob!")
        return tsv_paths
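    # NOTE: check_and_download is retained for reference; in the original
    # code its call site in parquet_tsv_iter was commented out, and the
    # iterator calls download_data instead.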
    def download_data(self, caption_path, tsv_paths, subfolder, parquet_name, first_part=True):
        if not os.path.exists(caption_path) and first_part:
            try:
                os.makedirs(self.caption_folder, exist_ok=True)
                self.azcopy_from_blob(
                    self.caption_blob_name,
                    parquet_name + ".parquet",
                    self.caption_folder,
                    first_part=first_part,
                )
            except Exception:
                raise ValueError("failed to download caption file from blob!")
        os.makedirs(subfolder, exist_ok=True)
        for tsv_path in tsv_paths:
            if not os.path.exists(tsv_path):
                try:
                    self.azcopy_from_blob(
                        self.image_blob_name,
                        os.path.join(os.path.basename(subfolder), os.path.basename(tsv_path)),
                        subfolder,
                        first_part=first_part,
                    )
                except Exception:
                    raise ValueError("failed to download image tsv file from blob to {}!".format(tsv_path))
    def azcopy_from_blob(self, subfolder="laion-5b", name="output_part-00005", destination="/scratch", first_part=True):
        if self.blob_folder is None:
            raise ValueError("The blob storage URL for the LAION data is not provided!")
        if self.sas_token is None:
            raise ValueError("The SAS token for the LAION data is not provided!")
        file = self.blob_folder + "/" + subfolder + "/" + name
        command_list = ["sudo", "azcopy", "cp", file + self.sas_token, destination, "--recursive"]
        if first_part:
            # block until the first patch has been downloaded
            subprocess.call(command_list)
            print("azcopy {} finished!".format(file))
        else:
            # later patches are fetched in the background while the current
            # patch is being consumed
            subprocess.Popen(command_list)
            print("start downloading {}".format(file))
    def remove_data(self, file="/scratch/output_part-00005"):
        subprocess.call(["sudo", "rm", "-rf", file])
        print("removed {} from the cluster successfully!".format(file))