from typing import Dict, Literal, Tuple, Optional, Any

import json
import math
import os
import pdb
import random
import sys
from pathlib import Path

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from einops import rearrange, repeat
from PIL import Image, ImageOps
from icecream import ic

from normal_utils import (
    worldNormal2camNormal,
    plot_grid_images,
    img2normal,
    norm_normalize,
    deg2rad,
)


def shift_list(lst, n):
    """Cyclically rotate ``lst`` to the right by ``n`` positions.

    ``n`` is wrapped modulo ``len(lst)`` so negative or over-length shifts
    are valid; an empty list is returned unchanged instead of raising
    ``ZeroDivisionError``.
    """
    length = len(lst)
    if length == 0:
        return list(lst)
    n = n % length  # keep the shift within [0, length)
    return lst[-n:] + lst[:-n]


class ObjaverseDataset(Dataset):
    """Multi-view Objaverse dataset.

    Each sample pairs one condition (input) view of an object with the
    color images and camera-relative normal maps of a fixed set of views,
    selected by azimuth offsets relative to the condition view.
    """

    def __init__(self,
        root_dir: str,
        azi_interval: float,
        random_views: int,
        predict_relative_views: list,
        bg_color: Any,
        object_list: str,
        prompt_embeds_path: str,
        img_wh: Tuple[int, int],
        validation: bool = False,
        num_validation_samples: int = 64,
        num_samples: Optional[int] = None,
        invalid_list: Optional[str] = None,
        trans_norm_system: bool = True,  # if True, transform all normal maps into the camera system of the front view
        side_views_rate: float = 0.,
        read_normal: bool = True,
        read_color: bool = False,
        read_depth: bool = False,
        mix_color_normal: bool = False,
        random_view_and_domain: bool = False,
        load_cache: bool = False,
        exten: str = '.png',
        elevation_list: Optional[str] = None,
        ) -> None:
        """Create a dataset from a folder of rendered images.

        Args:
            root_dir: dataset root; one sub-folder per object containing
                ``camera/``, ``image/``, ``normal/`` (and optionally
                ``image_relit/``) sub-folders.
            azi_interval: azimuth step (degrees) between fixed views.
            random_views: number of extra jittered views per fixed view.
            predict_relative_views: azimuth-index offsets (relative to the
                condition view) of the views to predict.
            bg_color: background fill policy, see :meth:`get_bg_color`.
            object_list: iterable of JSON files listing object ids; when
                ``None``, every folder under ``root_dir`` is used.
            prompt_embeds_path: folder holding ``normal_embeds.pt`` and
                ``clr_embeds.pt`` prompt embeddings.
            img_wh: output (width, height) of all loaded maps.
            validation: take samples from the validation split.
            num_validation_samples: size of the held-out validation split.
            num_samples: optional cap on the number of objects.
            invalid_list: iterable of ``.txt``/JSON files of object ids to
                exclude.
            elevation_list: optional text file restricting the object set.
        """
        self.root_dir = root_dir
        self.fixed_views = int(360 // azi_interval)
        self.bg_color = bg_color
        self.validation = validation
        self.num_samples = num_samples
        self.trans_norm_system = trans_norm_system
        self.invalid_list = invalid_list
        self.img_wh = img_wh
        self.read_normal = read_normal
        self.read_color = read_color
        self.read_depth = read_depth
        self.mix_color_normal = mix_color_normal  # mix load color and normal maps
        self.random_view_and_domain = random_view_and_domain  # load normal or rgb of a single view
        self.random_views = random_views
        self.load_cache = load_cache
        # every fixed view carries (random_views) jittered companions
        self.total_views = int(self.fixed_views * (self.random_views + 1))
        self.predict_relative_views = predict_relative_views
        self.pred_view_nums = len(self.predict_relative_views)
        self.exten = exten
        self.side_views_rate = side_views_rate

        ic(self.total_views)
        ic(self.fixed_views)
        ic(self.predict_relative_views)

        # Collect object ids either from JSON list files or from the folder itself.
        self.objects = []
        if object_list is not None:
            for dataset_list in object_list:
                with open(dataset_list, 'r') as f:
                    objects = json.load(f)
                self.objects.extend(objects)
        else:
            self.objects = os.listdir(self.root_dir)

        # Load the fixed (elevation-0) camera poses from the first object;
        # all objects are assumed to share the same camera rig.
        self.trans_cv2gl_mat = np.linalg.inv(np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]))
        self.fix_cam_poses = []
        camera_path = os.path.join(self.root_dir, self.objects[0], 'camera')
        for vid in range(0, self.total_views, self.random_views + 1):
            cam_info = np.load(f'{camera_path}/{vid:03d}.npy', allow_pickle=True).item()
            assert cam_info['camera'] == 'ortho', 'Only support predict ortho camera !!!'
            self.fix_cam_poses.append(cam_info['extrinsic'])
        random.shuffle(self.objects)

        # Gather ids to exclude; .txt files hold one id per line, anything
        # else is parsed as JSON.
        invalid_objects = []
        if self.invalid_list is not None:
            for invalid_list in self.invalid_list:
                if invalid_list[-4:] == '.txt':
                    with open(invalid_list, 'r') as f:
                        sub_invalid = f.readlines()
                    invalid_objects.extend([o.strip() for o in sub_invalid])
                else:
                    with open(invalid_list) as f:
                        invalid_objects.extend(json.load(f))
        self.invalid_objects = invalid_objects
        ic(len(self.invalid_objects))

        if elevation_list:
            with open(elevation_list, 'r') as f:
                ele_list = [o.strip() for o in f.readlines()]
            self.objects = set(ele_list) & set(self.objects)

        self.all_objects = set(self.objects) - (set(self.invalid_objects) & set(self.objects))
        self.all_objects = list(self.all_objects)

        # Split: the last num_validation_samples objects form the validation set.
        if not validation:
            self.all_objects = self.all_objects[:-num_validation_samples]
        else:
            self.all_objects = self.all_objects[-num_validation_samples:]
        if num_samples is not None:
            self.all_objects = self.all_objects[:num_samples]
        ic(len(self.all_objects))
        print("loading ", len(self.all_objects), " objects in the dataset")

        self.normal_prompt_embedding = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
        self.color_prompt_embedding = torch.load(f'{prompt_embeds_path}/clr_embeds.pt')

        # Preload one known-good object as a fallback sample for __getitem__.
        # NOTE(review): __getitem_mix__ is not defined in this part of the
        # file — presumably defined further down; confirm it exists before
        # enabling mix_color_normal.
        if self.mix_color_normal:
            self.backup_data = self.__getitem_mix__(0, '8609cf7e67bf413487a7d94c73aeaa3e')
        else:
            self.backup_data = self.__getitem_norm__(0, '8609cf7e67bf413487a7d94c73aeaa3e')

    def trans_cv2gl(self, rt):
        """Convert a 3x4 [R|t] extrinsic from OpenCV to OpenGL convention."""
        r, t = rt[:3, :3], rt[:3, -1]
        r = np.matmul(self.trans_cv2gl_mat, r)
        t = np.matmul(self.trans_cv2gl_mat, t)
        return np.concatenate([r, t[:, None]], axis=-1)

    def get_bg_color(self):
        """Resolve ``self.bg_color`` into an RGB float array in [0, 1].

        Accepts the strings 'white' / 'black' / 'gray' / 'random' /
        'three_choices', or a single float used for all three channels.

        Raises:
            NotImplementedError: for any unrecognized policy.
        """
        if self.bg_color == 'white':
            bg_color = np.array([1., 1., 1.], dtype=np.float32)
        elif self.bg_color == 'black':
            bg_color = np.array([0., 0., 0.], dtype=np.float32)
        elif self.bg_color == 'gray':
            bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
        elif self.bg_color == 'random':
            bg_color = np.random.rand(3)
        elif self.bg_color == 'three_choices':
            white = np.array([1., 1., 1.], dtype=np.float32)
            black = np.array([0., 0., 0.], dtype=np.float32)
            gray = np.array([0.5, 0.5, 0.5], dtype=np.float32)
            bg_color = random.choice([white, black, gray])
        elif isinstance(self.bg_color, float):
            bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
        else:
            raise NotImplementedError
        return bg_color

    def load_image(self, img_path, bg_color, alpha=None, return_type='np'):
        """Load an RGB(A) image, composite it over ``bg_color``.

        Returns:
            (img, alpha): image in [0, 1] of shape (H, W, 3) and the alpha
            mask of shape (H, W, 1), as numpy arrays or torch tensors
            depending on ``return_type``.
        """
        # not using cv2 as it may load in uint16 format; PIL always returns uint8
        rgba = np.array(Image.open(img_path).resize(self.img_wh))
        rgba = rgba.astype(np.float32) / 255.  # [0, 1]
        img = rgba[..., :3]
        if alpha is None:
            assert rgba.shape[-1] == 4
            alpha = rgba[..., 3:4]
        assert alpha.sum() > 1e-8, 'w/o foreground'
        img = img[..., :3] * alpha + bg_color * (1 - alpha)
        if return_type == "np":
            pass
        elif return_type == "pt":
            img = torch.from_numpy(img)
            alpha = torch.from_numpy(alpha)
        else:
            raise NotImplementedError
        return img, alpha

    def load_depth(self, img_path, bg_color, alpha, input_type='png', return_type='np'):
        """Load a 16-bit depth map, clip it, and composite over ``bg_color``.

        Depth values above 0.4 (normalized) are treated as background and
        zeroed, then the remaining range is rescaled to [0, 1] and
        replicated to 3 channels.
        """
        img = np.array(Image.open(img_path).resize(self.img_wh))
        img = img.astype(np.float32) / 65535.  # [0, 1]
        # anything beyond the clip distance is background
        img[img > 0.4] = 0
        img = img / 0.4
        assert img.ndim == 2  # depth
        img = np.stack([img] * 3, axis=-1)
        if alpha.shape[-1] != 1:
            alpha = alpha[:, :, None]
        img = img[..., :3] * alpha + bg_color * (1 - alpha)
        if return_type == "np":
            pass
        elif return_type == "pt":
            img = torch.from_numpy(img)
        else:
            raise NotImplementedError
        return img

    def load_normal(self, img_path, bg_color, alpha, RT_w2c_cond=None, return_type='np'):
        """Load a world-space normal map and express it in the condition
        camera frame (OpenGL convention), composited over ``bg_color``.

        Args:
            RT_w2c_cond: 3x4 (or larger) world-to-camera extrinsic of the
                condition view; only the rotation part is used.
        """
        normal_np = np.array(Image.open(img_path).resize(self.img_wh))[:, :, :3]
        assert np.var(normal_np) > 1e-8, 'pure normal'
        normal_cv = img2normal(normal_np)
        # Rotate world-space normals into the condition camera frame.
        normal_relative = norm_normalize(worldNormal2camNormal(RT_w2c_cond[:3, :3], normal_cv))
        # OpenCV -> OpenGL camera convention: flip the y and z axes.
        # (The original code did this through an in-place write on an alias
        # of the same array; the effect is identical.)
        normal_relative[..., 1:] = -normal_relative[..., 1:]
        img = (normal_relative * 0.5 + 0.5).astype(np.float32)  # [-1, 1] -> [0, 1]
        if alpha.shape[-1] != 1:
            alpha = alpha[:, :, None]
        img = img[..., :3] * alpha + bg_color * (1 - alpha)
        if return_type == "np":
            pass
        elif return_type == "pt":
            img = torch.from_numpy(img)
        else:
            raise NotImplementedError
        return img

    def __len__(self):
        return len(self.all_objects)

    def __getitem_norm__(self, index, debug_object=None):
        """Build one sample: a condition image repeated per predicted view,
        plus the color and normal maps of the predicted views.
        """
        bg_color = self.get_bg_color()
        if debug_object is not None:
            object_name = debug_object
        else:
            object_name = self.all_objects[index % len(self.all_objects)]

        # Pick the condition view: always the front view (index 12) for
        # validation, otherwise occasionally a side/back view.
        if self.validation:
            cond_ele0_idx = 12
        else:
            rand = random.random()
            if rand < self.side_views_rate:  # 0.1
                cond_ele0_idx = random.choice([8, 0])
            elif rand < 3 * self.side_views_rate:  # 0.3
                cond_ele0_idx = random.choice([10, 14])
            else:
                cond_ele0_idx = 12  # front view
        cond_random_idx = random.choice(range(self.random_views + 1))

        # condition info
        cond_ele0_vid = cond_ele0_idx * (self.random_views + 1)
        cond_vid = cond_ele0_vid + cond_random_idx
        cond_ele0_w2c = self.fix_cam_poses[cond_ele0_idx]
        cond_info = np.load(f'{self.root_dir}/{object_name}/camera/{cond_vid:03d}.npy',
                            allow_pickle=True).item()
        cond_type = cond_info['camera']
        focal_len = cond_info['focal']
        cond_eles = np.array([deg2rad(cond_info['elevation'])])

        # The same condition image is fed once per predicted view.
        img_tensors_in = [
            self.load_image(f"{self.root_dir}/{object_name}/image/{cond_vid:03d}{self.exten}",
                            bg_color, return_type='pt')[0].permute(2, 0, 1)
        ] * self.pred_view_nums

        # output info: view ids of the predicted views, wrapped around the ring
        pred_vids = [(cond_ele0_vid + i * (self.random_views + 1)) % self.total_views
                     for i in self.predict_relative_views]

        img_tensors_out = []
        normal_tensors_out = []
        for i, vid in enumerate(pred_vids):
            try:
                img_tensor, alpha_ = self.load_image(
                    f"{self.root_dir}/{object_name}/image/{vid:03d}{self.exten}",
                    bg_color, return_type='pt')
            except Exception:
                # fall back to the relit rendering when the plain one is missing/broken
                img_tensor, alpha_ = self.load_image(
                    f"{self.root_dir}/{object_name}/image_relit/{vid:03d}{self.exten}",
                    bg_color, return_type='pt')
            img_tensor = img_tensor.permute(2, 0, 1)  # (3, H, W)
            img_tensors_out.append(img_tensor)

            normal_tensor = self.load_normal(
                f"{self.root_dir}/{object_name}/normal/{vid:03d}{self.exten}",
                bg_color, alpha_.numpy(),
                RT_w2c_cond=cond_ele0_w2c[:3, :],
                return_type="pt").permute(2, 0, 1)
            normal_tensors_out.append(normal_tensor)

        img_tensors_in = torch.stack(img_tensors_in, dim=0).float()        # (Nv, 3, H, W)
        img_tensors_out = torch.stack(img_tensors_out, dim=0).float()      # (Nv, 3, H, W)
        normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float()  # (Nv, 3, H, W)

        elevations_cond = torch.as_tensor(cond_eles).float()
        # ortho cameras carry a zero focal embedding; perspective uses 24/focal
        if cond_type == 'ortho':
            focal_embed = torch.tensor([0.])
        else:
            focal_embed = torch.tensor([24. / focal_len])

        if not self.load_cache:
            return {
                'elevations_cond': elevations_cond,
                'focal_cond': focal_embed,
                'id': object_name,
                'vid': cond_vid,
                'imgs_in': img_tensors_in,
                'imgs_out': img_tensors_out,
                'normals_out': normal_tensors_out,
                'normal_prompt_embeddings': self.normal_prompt_embedding,
                'color_prompt_embeddings': self.color_prompt_embedding,
            }

    def __getitem__(self, index):
        # Any failure while loading an object falls back to the preloaded
        # backup sample so a bad object cannot kill a training run.
        try:
            return self.__getitem_norm__(index)
        except Exception:
            print("load error ", self.all_objects[index % len(self.all_objects)])
            return self.backup_data