import sys sys.path.insert(1, '.') from typing import Dict import webdataset as wds import numpy as np from omegaconf import DictConfig, ListConfig import torch from torch.utils.data import Dataset from pathlib import Path import json from PIL import Image from torchvision import transforms import torchvision from einops import rearrange from ldm.util import instantiate_from_config from datasets import load_dataset import pytorch_lightning as pl import copy import csv import cv2 import random import matplotlib.pyplot as plt from torch.utils.data import DataLoader import json import os, sys import webdataset as wds import math from torch.utils.data.distributed import DistributedSampler import glob import pickle from ldm.data.objaverse_rendered import get_rendered_objaverse_list_v0 from ldm.data.decoder import ObjaverseDataDecoder, ObjaverseDecoerWDS, nodesplitter from loguru import logger from torch import distributed as dist from tqdm import tqdm from multiprocessing.pool import ThreadPool # Some hacky things to make experimentation easier def make_transform_multi_folder_data(paths, caption_files=None, **kwargs): ds = make_multi_folder_data(paths, caption_files, **kwargs) return TransformDataset(ds) def make_nfp_data(base_path): dirs = list(Path(base_path).glob("*/")) print(f"Found {len(dirs)} folders") print(dirs) tforms = [transforms.Resize(512), transforms.CenterCrop(512)] datasets = [NfpDataset(x, image_transforms=copy.copy(tforms), default_caption="A view from a train window") for x in dirs] return torch.utils.data.ConcatDataset(datasets) class VideoDataset(Dataset): def __init__(self, root_dir, image_transforms, caption_file, offset=8, n=2): self.root_dir = Path(root_dir) self.caption_file = caption_file self.n = n ext = "mp4" self.paths = sorted(list(self.root_dir.rglob(f"*.{ext}"))) self.offset = offset if isinstance(image_transforms, ListConfig): image_transforms = [instantiate_from_config(tt) for tt in image_transforms] image_transforms.extend([transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) image_transforms = transforms.Compose(image_transforms) self.tform = image_transforms with open(self.caption_file) as f: reader = csv.reader(f) rows = [row for row in reader] self.captions = dict(rows) def __len__(self): return len(self.paths) def __getitem__(self, index): for i in range(10): try: return self._load_sample(index) except Exception: # Not really good enough but... print("uh oh") def _load_sample(self, index): n = self.n filename = self.paths[index] min_frame = 2*self.offset + 2 vid = cv2.VideoCapture(str(filename)) max_frames = int(vid.get(cv2.CAP_PROP_FRAME_COUNT)) curr_frame_n = random.randint(min_frame, max_frames) vid.set(cv2.CAP_PROP_POS_FRAMES,curr_frame_n) _, curr_frame = vid.read() prev_frames = [] for i in range(n): prev_frame_n = curr_frame_n - (i+1)*self.offset vid.set(cv2.CAP_PROP_POS_FRAMES,prev_frame_n) _, prev_frame = vid.read() prev_frame = self.tform(Image.fromarray(prev_frame[...,::-1])) prev_frames.append(prev_frame) vid.release() caption = self.captions[filename.name] data = { "image": self.tform(Image.fromarray(curr_frame[...,::-1])), "prev": torch.cat(prev_frames, dim=-1), "txt": caption } return data # end hacky things def make_tranforms(image_transforms): # if isinstance(image_transforms, ListConfig): # image_transforms = [instantiate_from_config(tt) for tt in image_transforms] image_transforms = [] image_transforms.extend([transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) image_transforms = transforms.Compose(image_transforms) return image_transforms def make_multi_folder_data(paths, caption_files=None, **kwargs): """Make a concat dataset from multiple folders Don't suport captions yet If paths is a list, that's ok, if it's a Dict interpret it as: k=folder v=n_times to repeat that """ list_of_paths = [] if isinstance(paths, (Dict, DictConfig)): assert caption_files is None, \ "Caption files not yet supported for repeats" for folder_path, repeats in paths.items(): list_of_paths.extend([folder_path]*repeats) paths = list_of_paths if caption_files is not None: datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)] else: datasets = [FolderData(p, **kwargs) for p in paths] return torch.utils.data.ConcatDataset(datasets) class NfpDataset(Dataset): def __init__(self, root_dir, image_transforms=[], ext="jpg", default_caption="", ) -> None: """assume sequential frames and a deterministic transform""" self.root_dir = Path(root_dir) self.default_caption = default_caption self.paths = sorted(list(self.root_dir.rglob(f"*.{ext}"))) self.tform = make_tranforms(image_transforms) def __len__(self): return len(self.paths) - 1 def __getitem__(self, index): prev = self.paths[index] curr = self.paths[index+1] data = {} data["image"] = self._load_im(curr) data["prev"] = self._load_im(prev) data["txt"] = self.default_caption return data def _load_im(self, filename): im = Image.open(filename).convert("RGB") return self.tform(im) class ObjaverseDataModuleFromConfig(pl.LightningDataModule): def __init__(self, root_dir, batch_size, train=None, validation=None, test=None, num_workers=4, objaverse_data_list=None, ext="png", target_name="albedo", use_wds=True, tar_config=None, **kwargs): super().__init__(self) self.root_dir = root_dir self.batch_size = batch_size self.num_workers = num_workers self.kwargs = kwargs self.tar_config = tar_config self.use_wds = use_wds if train is not None: dataset_config = train if validation is not None: dataset_config = validation image_transforms = [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))] image_transforms = torchvision.transforms.Compose(image_transforms) self.image_transforms = { "size": dataset_config.image_transforms.size, "totensor": image_transforms } self.target_name = target_name self.objaverse_data_list = objaverse_data_list self.ext = ext def naive_setup(self): # get object data list if self.objaverse_data_list is None or \ self.objaverse_data_list["image_list_cache_path"] == "None": # This is too slow.. self.paths = sorted(list(Path(self.root_dir).rglob(f"*{self.target_name}*.{self.ext}"))) if len(self.paths) == 0: # colmap format self.paths = sorted(list(Path(self.root_dir).rglob(f"*images_train/*.*"))) else: self.paths = get_rendered_objaverse_list_v0(self.root_dir, self.target_name, self.ext, **self.objaverse_data_list) random.shuffle(self.paths) # train val split total_objects = len(self.paths) self.paths_val = self.paths[math.floor(total_objects / 100. * 99.):] # used last 1% as validation self.paths_train = self.paths[:math.floor(total_objects / 100. * 99.)] # used first 99% as training if self.rank == 0: print('============= length of dataset %d =============' % len(self.paths)) print('============= length of training dataset %d =============' % len(self.paths_train)) print('============= length of Validation dataset %d =============' % len(self.paths_val)) # Split into each GPU self.paths_train = self._get_local_split(self.paths_train, self.world_size, self.rank) logger.info( f"[rank {self.rank}]: {len(self.paths_train)} images assigned." ) def _get_tar_length(self, tar_list, img_per_obj): dataset_size = 0 for _name in tar_list: num_obj = int(_name.rsplit("_num_")[1].rsplit(".")[0]) dataset_size += num_obj * img_per_obj return dataset_size def webdataset_setup(self, list_dir, tar_dir, img_per_obj, max_tars=None): # read data list and calculate size tar_name_list = sorted(os.listdir(list_dir)) if not max_tars is None: # for debugging on small scale data tar_name_list = tar_name_list[:max_tars] total_tars = len(tar_name_list) # random shuffle random.shuffle(tar_name_list) print(f"Rank {self.rank} shuffle: {tar_name_list}") # train test split self.test_tars = tar_name_list[math.floor(total_tars / 100. * 99.):] # make sure each node has one tar if len(self.test_tars) < self.world_size: self.test_tars += [self.test_tars[0]]*(self.world_size-len(self.test_tars)) self.train_tars = tar_name_list[:math.floor(total_tars / 100. * 99.)] # training tar truncation total_workers = self.num_workers * self.world_size num_tars_train = (len(self.train_tars) // total_workers) * total_workers if num_tars_train != len(self.train_tars): print(f"[WARNING] Total train tars: {len(self.train_tars)}, truncated: {len(self.train_tars)-num_tars_train}, remainnig: {num_tars_train}, total workers: {total_workers}") self.test_length = self._get_tar_length(self.test_tars, img_per_obj) self.train_length = self._get_tar_length(self.train_tars, img_per_obj) # name replace test_tars = [_name.rsplit("_num")[0]+".tar" for _name in self.test_tars] self.test_tars = [os.path.join(tar_dir, _name) for _name in test_tars] train_tars = [_name.rsplit("_num")[0]+".tar" for _name in self.train_tars] self.train_tars = [os.path.join(tar_dir, _name) for _name in train_tars] if self.rank == 0: print('============= length of dataset %d =============' % (self.test_length+self.train_length)) print('============= length of training dataset %d =============' % (self.train_length)) print('============= length of Validation dataset %d =============' % (self.test_length)) def setup(self, stage=None): try: self.world_size = dist.get_world_size() self.rank = dist.get_rank() except: self.world_size = 1 self.rank = 0 if self.rank == 0: print("#### Data ####") if self.use_wds: self.webdataset_setup(**self.tar_config) else: self.naive_setup() def _get_local_split(self, items: list, world_size: int, rank: int, seed: int = 6): """The local rank only loads a split of the dataset.""" n_items = len(items) items_permute = np.random.RandomState(seed).permutation(items) if n_items % world_size == 0: padded_items = items_permute else: padding = np.random.RandomState(seed).choice( items, world_size - (n_items % world_size), replace=True ) padded_items = np.concatenate([items_permute, padding]) assert ( len(padded_items) % world_size == 0 ), f"len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}" n_per_rank = len(padded_items) // world_size local_items = padded_items[n_per_rank * rank : n_per_rank * (rank + 1)] return local_items def train_dataloader(self): if self.use_wds: loader = self.train_dataloader_wds() else: loader = self.train_dataloader_naive() return loader def val_dataloader(self): if self.use_wds: loader = self.val_dataloader_wds() else: loader = self.val_dataloader_naive() return loader def train_dataloader_naive(self): dataset = ObjaverseData(root_dir=self.root_dir, \ image_transforms=self.image_transforms, image_list = self.paths_train, target_name=self.target_name, **self.kwargs) return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True) def val_dataloader_naive(self): dataset = ObjaverseData(root_dir=self.root_dir, \ image_transforms=self.image_transforms, image_list = self.paths_val, target_name=self.target_name, **self.kwargs) return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) def train_dataloader_wds(self): decoder = ObjaverseDecoerWDS(root_dir=self.root_dir, \ image_transforms=self.image_transforms, image_list = None, target_name=self.target_name, **self.kwargs) worker_batch = self.batch_size epoch_length = self.train_length // worker_batch // self.num_workers // self.world_size dataset = (wds.WebDataset(self.train_tars, shardshuffle=min(1000, len(self.train_tars)), nodesplitter=wds.shardlists.split_by_node) .shuffle(5000, initial=1000) .map(decoder.process_sample) # .map(decoder.dict2tuple) .batched(worker_batch, partial=False) # .map(decoder.tuple2dict) .map(decoder.batch_reordering) .with_epoch(epoch_length) .with_length(epoch_length) ) loader = (wds.WebLoader(dataset, batch_size=None, num_workers=self.num_workers, shuffle=False) # .unbatched() # .shuffle(1000) # .batched(self.batch_size) # .map(decoder.tuple2dict) ) print(f"# Training loader length for single worker {epoch_length} with {self.num_workers} workers") return loader def val_dataloader_wds(self): decoder = ObjaverseDecoerWDS(root_dir=self.root_dir, \ image_transforms=self.image_transforms, image_list = None, target_name=self.target_name, **self.kwargs) # adjust worker number, as test has much much fewer tars val_workers = min(self.num_workers, len(self.test_tars) // self.world_size) epoch_length = max(self.test_length // self.batch_size // val_workers // self.world_size, 1) dataset = (wds.WebDataset(self.test_tars, shardshuffle=min(1000, len(self.test_tars)), handler=wds.ignore_and_continue, nodesplitter=wds.shardlists.split_by_node) .shuffle(1000) .map(decoder.process_sample) # .map(decoder.dict2tuple) .batched(self.batch_size, partial=False) .with_epoch(epoch_length) .with_length(epoch_length) ) loader = (wds.WebLoader(dataset, batch_size=None, num_workers=val_workers, shuffle=False) .unbatched() .shuffle(1000) .batched(self.batch_size) # .map(decoder.tuple2dict) .map(decoder.batch_reordering) ) print(f"# Validation loader length for single worker {epoch_length} with {val_workers} workers") return loader def test_dataloader(self): # testing will use all given data return wds.WebLoader(ObjaverseData(root_dir=self.root_dir, test=True, image_transforms=self.image_transforms, image_list = self.paths, target_name=self.target_name, **self.kwargs), batch_size=32, num_workers=self.num_workers, shuffle=False, ) class ObjaverseData(ObjaverseDataDecoder, Dataset): def __init__(self, root_dir='.objaverse/hf-objaverse-v1/views', image_list=None, threads=64, **kargs ) -> None: """Create a dataset from blender rendering results. If you pass in a root directory it will be searched for images ending in ext (ext can be a list) """ self.paths = image_list self.root_dir = Path(root_dir) ObjaverseDataDecoder.__init__(self, **kargs) # pre-load data print(f"Data pre loading start with {threads}...") self.all_target_im = np.zeros((len(self.paths), self.img_size, self.img_size, 3), dtype=np.uint8) + 0 self.all_cond_im = np.zeros((len(self.paths), self.img_size, self.img_size, 3), dtype=np.uint8) + 0 self.all_filename = ["empty"] * len(self.paths) if self.condition_name == "normal": self.all_normal_img = np.zeros((len(self.paths), self.img_size, self.img_size, 3), dtype=np.uint8) + 0 self.all_crop_idx = np.zeros((len(self.paths), 6), dtype=int) + 0 print("Array allocated..") def parallel_load(index): pbar.update(1) self.preload_item(index) pbar = tqdm(total=len(self.paths)) with ThreadPool(threads) as pool: pool.map(parallel_load, range(len(self.paths))) pool.close() pool.join() print("Data pre loading done...") def __len__(self): return len(self.paths) def load_mask(self, mask_filename, cond_im): # auto image file extention glob_files = glob.glob(mask_filename.rsplit(".", 1)[0] + ".*") if len(glob_files) == 0: print("Warning: no mask image find") img_mask = np.ones_like(cond_im) if cond_im.shape[-1] == 4: print("Use image mask") img_mask = img_mask * cond_im[:, :, -1:] elif len(glob_files) == 1: img_mask = np.array(self.normalized_read(glob_files[0])) else: raise NotImplementedError("Too many mask images found! {}") return img_mask def preload_item(self, index): path = self.paths[index] filename = os.path.join(path) filename, condition_filename, \ mask_filename, normal_condition_filename, filename_targets = self.path_parsing(filename) # get file streams if filename_targets is None: filename_read = filename else: filename_read = filename_targets # image reading target_im, cond_im, normal_img = self.read_images(filename_read, condition_filename, normal_condition_filename) # mask reading img_mask = self.load_mask(mask_filename, cond_im) # post processing target_im, cond_im, normal_img, crop_idx = self.image_post_processing(img_mask, target_im, cond_im, normal_img) if self.test: # crop out valid_mask self.all_crop_idx[index] = crop_idx # put results self.all_target_im[index] = target_im self.all_cond_im[index] = cond_im self.all_filename[index] = filename if self.condition_name == "normal": self.all_normal_img[index] = normal_img def get_camera(self, input_filename): camera_file = input_filename.replace(f'{self.target_name}0001', \ 'camera').rsplit(".")[0] + ".pkl" cam_dir, cam_name = camera_file.rsplit("/", 1) cam_name = f"{cam_name:>15}" camera_file = os.path.join(cam_dir, cam_name) cam = pickle.load(open(camera_file, 'rb')) return cam def __getitem__(self, index): target_im = self.process_im(self.all_target_im[index]) cond_img = self.process_im(self.all_cond_im[index]) filename = self.all_filename[index] normal_img = self.process_im(self.all_normal_img[index]) \ if self.condition_name == "normal" \ else None sample = self.parse_item(target_im, cond_img, normal_img, filename) if self.test: sample["crop_idx"] = self.all_crop_idx[index] return sample if __name__ == "__main__": import pyhocon class DictAsMember(dict): def __getattr__(self, name): value = self[name] if isinstance(value, dict): value = DictAsMember(value) return value def ConfigAsMember(config): config_dict = DictAsMember(config) for key in config_dict.keys(): if isinstance(config_dict[key], pyhocon.config_tree.ConfigTree): config_dict[key] = ConfigAsMember(config_dict[key]) return config_dict train_config = DictAsMember({ "validation": False, "image_transforms": {"size": 256} }) val_config = DictAsMember({ "validation": True, "image_transforms": {"size": 256} }) objaverse_data_list = DictAsMember({ "image_list_cache_path": "image_lists/half_400000_image_list.npz", }) data_module = ObjaverseDataModuleFromConfig(root_dir='/mnt/volumes/perception/hujunkang/codes/renders/material-diffusion/data/objaverse_rendering', batch_size=4, train=train_config, validation=val_config, test=None, num_workers=1, objaverse_data_list=objaverse_data_list, ext="png", target_name="albedo", use_wds=False, tar_config=None) data_module.setup() train_dataloader_naive = data_module.train_dataloader_naive()