import json
import logging
import os
import random
from typing import Any, Callable, Dict, List, Tuple

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)


__all__ = [
    "Asset3dGenDataset",
]


class Asset3dGenDataset(Dataset):
    def __init__(
        self,
        index_file: str,
        target_hw: Tuple[int, int],
        transform: Callable = None,
        control_transform: Callable = None,
        max_train_samples: int = None,
        sub_idxs: List[List[int]] = None,
        seed: int = 79,
    ) -> None:
        if not os.path.exists(index_file):
            raise FileNotFoundError(f"Index file {index_file} not found.")

        self.index_file = index_file
        self.target_hw = target_hw
        self.transform = transform
        self.control_transform = control_transform
        self.max_train_samples = max_train_samples
        self.meta_info = self.prepare_data_index(index_file)
        self.data_list = sorted(self.meta_info.keys())
        self.sub_idxs = sub_idxs  # Grid of view indices, e.g. [[0, 1, 2], [3, 4, 5], ...]
        self.image_num = 6  # Hardcoded for now: rendered views per asset.
        random.seed(seed)
        logger.info(f"Trainset: {len(self)} asset3d instances.")

    def __len__(self) -> int:
        return len(self.meta_info)

    def prepare_data_index(self, index_file: str) -> Dict[str, Any]:
        with open(index_file, "r") as fin:
            meta_info = json.load(fin)

        meta_info_filtered = dict()
        for idx, uid in enumerate(meta_info):
            # Keep only assets whose preprocessing finished successfully.
            if "status" not in meta_info[uid]:
                continue
            if meta_info[uid]["status"] != "success":
                continue
            if self.max_train_samples and idx >= self.max_train_samples:
                break

            meta_info_filtered[uid] = meta_info[uid]

        logger.info(
            f"Loaded {len(meta_info)} assets, kept "
            f"{len(meta_info_filtered)} valid ones."
        )

        return meta_info_filtered

    def fetch_sample_images(
        self,
        uid: str,
        attrs: List[str],
        sub_index: int = None,
        transform: Callable = None,
    ) -> torch.Tensor:
        sample = self.meta_info[uid]
        images = []
        for attr in attrs:
            item = sample[attr]
            if sub_index is not None:
                item = item[sub_index]
            mode = "L" if attr == "image_mask" else "RGB"
            image = Image.open(item).convert(mode)
            if transform is not None:
                image = transform(image)
            if len(image.shape) == 2:
                # Add a leading channel dim so single-channel tensors can be
                # concatenated with (c, h, w) tensors along dim 0.
                image = image[None, ...]
            images.append(image)

        images = torch.cat(images, dim=0)

        return images

    def fetch_sample_grid_images(
        self,
        uid: str,
        attrs: List[str],
        sub_idxs: List[List[int]],
        transform: Callable = None,
    ) -> torch.Tensor:
        assert transform is not None

        grid_image = []
        for row_idxs in sub_idxs:
            row_image = []
            for row_idx in row_idxs:
                image = self.fetch_sample_images(
                    uid, attrs, row_idx, transform
                )
                row_image.append(image)
            row_image = torch.cat(row_image, dim=2)  # (c, h, w)
            grid_image.append(row_image)

        grid_image = torch.cat(grid_image, dim=1)

        return grid_image

    def compute_text_embeddings(
        self, embed_path: str, original_size: Tuple[int, int]
    ) -> Dict[str, torch.Tensor]:
        data_dict = torch.load(embed_path)
        prompt_embeds = data_dict["prompt_embeds"][0]
        add_text_embeds = data_dict["pooled_prompt_embeds"][0]
        # Needs changing if random crop is used: set to the crop's top-left
        # [y1, x1]; for center crop it stays (0, 0).
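        # SDXL-style micro-conditioning: `add_time_ids` packs
        # (orig_h, orig_w, crop_top, crop_left, target_h, target_w); the SDXL
        # UNet consumes it as `time_ids` together with the pooled text
        # embedding passed as `text_embeds`.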
        crops_coords_top_left = (0, 0)
        add_time_ids = list(
            original_size + crops_coords_top_left + self.target_hw
        )
        add_time_ids = torch.tensor([add_time_ids])
        # add_time_ids = add_time_ids.repeat((len(add_text_embeds), 1))
        unet_added_cond_kwargs = {
            "text_embeds": add_text_embeds,
            "time_ids": add_time_ids,
        }

        return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}

    def visualize_item(
        self,
        control: torch.Tensor,
        color: torch.Tensor,
        save_dir: str = None,
    ) -> Tuple[Image.Image, ...]:
        to_pil = transforms.ToPILImage()

        color = (color + 1) / 2  # Undo Normalize([0.5], [0.5]) -> [0, 1].
        color_pil = to_pil(color)
        normal_pil = to_pil(control[0:3])
        position_pil = to_pil(control[3:6])
        mask_pil = to_pil(control[6:])

        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)
            color_pil.save(f"{save_dir}/rgb.jpg")
            normal_pil.save(f"{save_dir}/normal.jpg")
            position_pil.save(f"{save_dir}/position.jpg")
            mask_pil.save(f"{save_dir}/mask.jpg")
            logger.info(f"Visualization in {save_dir}")

        return normal_pil, position_pil, mask_pil, color_pil

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        uid = self.data_list[index]
        sub_idxs = self.sub_idxs
        if sub_idxs is None:
            # Fall back to a single random view when no grid layout is given.
            sub_idxs = [[random.randint(0, self.image_num - 1)]]

        input_image = self.fetch_sample_grid_images(
            uid,
            attrs=["image_view_normal", "image_position", "image_mask"],
            sub_idxs=sub_idxs,
            transform=self.control_transform,
        )
        assert input_image.shape[1:] == self.target_hw

        output_image = self.fetch_sample_grid_images(
            uid,
            attrs=["image_color"],
            sub_idxs=sub_idxs,
            transform=self.transform,
        )

        sample = self.meta_info[uid]
        text_feats = self.compute_text_embeddings(
            sample["text_feat"], tuple(sample["image_hw"])
        )

        data = dict(
            pixel_values=output_image,
            conditioning_pixel_values=input_image,
            prompt_embeds=text_feats["prompt_embeds"],
            text_embeds=text_feats["text_embeds"],
            time_ids=text_feats["time_ids"],
        )

        return data


if __name__ == "__main__":
    index_file = "/horizon-bucket/robot_lab/users/xinjie.wang/datasets/objaverse/v1.0/statistics_1.0_gobjaverse_filter/view6s_v4/meta_ac2e0ddea8909db26d102c8465b5bcb2.json"  # noqa
    target_hw = (512, 512)
    transform_list = [
        transforms.Resize(
            target_hw, interpolation=transforms.InterpolationMode.BILINEAR
        ),
        transforms.CenterCrop(target_hw),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
    image_transform = transforms.Compose(transform_list)
    # Control images skip normalization, keeping values in [0, 1].
    control_transform = transforms.Compose(transform_list[:-1])
    sub_idxs = [[0, 1, 2], [3, 4, 5]]  # None
    if sub_idxs is not None:
        # A grid of views multiplies the target resolution accordingly.
        target_hw = (
            target_hw[0] * len(sub_idxs),
            target_hw[1] * len(sub_idxs[0]),
        )
    dataset = Asset3dGenDataset(
        index_file,
        target_hw,
        image_transform,
        control_transform,
        sub_idxs=sub_idxs,
    )
    data = dataset[0]
    dataset.visualize_item(
        data["conditioning_pixel_values"],
        data["pixel_values"],
        save_dir="./",
    )
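
    # A minimal usage sketch: wrap the dataset in a standard DataLoader and
    # inspect one batch. batch_size and num_workers are illustrative choices,
    # not values from the original training setup.
    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    logger.info(
        f"pixel_values {tuple(batch['pixel_values'].shape)}, "
        f"conditioning {tuple(batch['conditioning_pixel_values'].shape)}, "
        f"time_ids {tuple(batch['time_ids'].shape)}"
    )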