import io
import os
from typing import Any, Dict, Tuple

import numpy as np
import tables
import torch
import torchvision
import torchvision.transforms.v2.functional as TF
from PIL import Image

from unik3d.datasets.base_dataset import BaseDataset
from unik3d.utils.camera import BatchCamera, Pinhole

"""
Legacy class: only pinhole cameras are assumed, and sequences are "faked"
by setting sequence_fields to [(0, 0)] and cam2w to eye(4).
"""


class ImageDataset(BaseDataset):
    def __init__(
        self,
        image_shape: Tuple[int, int],
        split_file: str,
        test_mode: bool,
        normalize: bool,
        augmentations_db: Dict[str, Any],
        shape_constraints: Dict[str, Any],
        resize_method: str,
        mini: float,
        benchmark: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(
            image_shape=image_shape,
            split_file=split_file,
            test_mode=test_mode,
            benchmark=benchmark,
            normalize=normalize,
            augmentations_db=augmentations_db,
            shape_constraints=shape_constraints,
            resize_method=resize_method,
            mini=mini,
            **kwargs,
        )
        self.mapper = self.get_mapper()

    def get_single_item(self, idx, sample=None, mapper=None):
        sample = self.dataset[idx] if sample is None else sample
        mapper = self.mapper if mapper is None else mapper

        # every sample is treated as a single-frame "sequence" keyed by (0, 0)
        results = {
            (0, 0): dict(
                gt_fields=set(),
                image_fields=set(),
                mask_fields=set(),
                camera_fields=set(),
            )
        }
        results = self.pre_pipeline(results)
        results["sequence_fields"] = [(0, 0)]

        chunk_idx = (
            int(sample[self.mapper["chunk_idx"]]) if "chunk_idx" in self.mapper else 0
        )
        h5_path = os.path.join(self.data_root, self.hdf5_paths[chunk_idx])

        # each chunk is an HDF5 file holding the encoded image/depth blobs
        with tables.File(
            h5_path,
            mode="r",
            libver="latest",
            swmr=True,
        ) as h5file_chunk:
            for key_mapper, idx_mapper in mapper.items():
                if "image" not in key_mapper and "depth" not in key_mapper:
                    continue
                value = sample[idx_mapper]
                results[(0, 0)][key_mapper] = value
                name = key_mapper.replace("_filename", "")
                value_root = "/" + value

                if "image" in key_mapper:
                    results[(0, 0)]["filename"] = value
                    file = h5file_chunk.get_node(value_root).read()
                    image = (
                        torchvision.io.decode_image(torch.from_numpy(file))
                        .to(torch.uint8)
                        .squeeze()
                    )
                    results[(0, 0)]["image_fields"].add(name)
                    results[(0, 0)]["image_ori_shape"] = image.shape[-2:]
                    results[(0, 0)][name] = image[None, ...]

                    # collect camera information for the given image
                    name = name.replace("image_", "")
                    results[(0, 0)]["camera_fields"].update({"camera", "cam2w"})
                    K = self.get_intrinsics(idx, value)
                    if K is None:
                        # no intrinsics available: fall back to a heuristic pinhole camera
                        K = torch.eye(3)
                        K[0, 0] = K[1, 1] = 0.7 * self.image_shape[1]
                        K[0, 2] = 0.5 * self.image_shape[1]
                        K[1, 2] = 0.5 * self.image_shape[0]
                    camera = Pinhole(K=K[None, ...].clone())
                    results[(0, 0)]["camera"] = BatchCamera.from_camera(camera)
                    results[(0, 0)]["cam2w"] = self.get_extrinsics(idx, value)[
                        None, ...
                    ]

                elif "depth" in key_mapper:
                    file = h5file_chunk.get_node(value_root).read()
                    depth = Image.open(io.BytesIO(file))
                    depth = TF.pil_to_tensor(depth).squeeze().to(torch.float32)
                    # depth stored across three channels: recombine into a single value
                    if depth.ndim == 3:
                        depth = depth[2] + depth[1] * 255 + depth[0] * 255 * 255
                    results[(0, 0)]["gt_fields"].add(name)
                    results[(0, 0)]["depth_ori_shape"] = depth.shape
                    depth = (
                        depth.view(1, 1, *depth.shape).contiguous() / self.depth_scale
                    )
                    results[(0, 0)][name] = depth

        results = self.preprocess(results)
        if not self.test_mode:
            results = self.augment(results)
        results = self.postprocess(results)
        return results

    def preprocess(self, results):
        results = self.replicate(results)
        for seq in results["sequence_fields"]:
            self.resizer.ctx = None
            results[seq] = self.resizer(results[seq])
            num_pts = torch.count_nonzero(results[seq]["depth"] > 0)
            if num_pts < 50:
                raise IndexError(f"Too few points in depth map ({num_pts})")
            for key in results[seq].get("image_fields", ["image"]):
                results[seq][key] = results[seq][key].to(torch.float32) / 255

        # update fields common in sequence
        for key in ["image_fields", "gt_fields", "mask_fields", "camera_fields"]:
            if key in results[(0, 0)]:
                results[key] = results[(0, 0)][key]
        results = self.pack_batch(results)
        return results

    def postprocess(self, results):
        # normalize here because color augmentation expects images in [0, 255]
        for key in results.get("image_fields", ["image"]):
            results[key] = TF.normalize(results[key], **self.normalization_stats)
        results = self.filler(results)
        results = self.unpack_batch(results)
        results = self.masker(results)
        results = self.collecter(results)
        return results

    def __getitem__(self, idx):
        try:
            if isinstance(idx, (list, tuple)):
                results = [self.get_single_item(i) for i in idx]
            else:
                results = self.get_single_item(idx)
        except Exception as e:
            print(f"Error loading sequence {idx} for {self.__class__.__name__}: {e}")
            # fall back to a random sample when loading fails
            idx = np.random.randint(0, len(self.dataset))
            results = self[idx]
        return results

    def get_intrinsics(self, idx, image_name):
        # 1000 acts as an out-of-range sentinel: samples without stored intrinsics return None
        idx_sample = self.mapper.get("K", 1000)
        sample = self.dataset[idx]
        if idx_sample >= len(sample):
            return None
        return sample[idx_sample]

    def get_extrinsics(self, idx, image_name):
        # 1000 acts as an out-of-range sentinel: samples without stored extrinsics use identity
        idx_sample = self.mapper.get("cam2w", 1000)
        sample = self.dataset[idx]
        if idx_sample >= len(sample):
            return torch.eye(4)
        return sample[idx_sample]

    def get_mapper(self):
        return {
            "image_filename": 0,
            "depth_filename": 1,
            "K": 2,
        }
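

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). All constructor arguments below
# are placeholders: valid values for split_file, augmentations_db,
# shape_constraints and resize_method depend on the concrete dataset and on
# what BaseDataset expects, and the class additionally relies on attributes
# such as data_root, hdf5_paths and depth_scale being set up by BaseDataset
# via **kwargs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dataset = ImageDataset(
        image_shape=(480, 640),      # (height, width) placeholder
        split_file="train.txt",      # placeholder split file name
        test_mode=False,
        normalize=True,
        augmentations_db={},         # placeholder: no augmentations configured
        shape_constraints={},        # placeholder, dataset dependent
        resize_method="hard",        # placeholder, must match the resizer setup
        mini=1.0,                    # placeholder subsampling factor
    )
    # __getitem__ decodes the image/depth blobs from the HDF5 chunk, builds a
    # pinhole camera and runs preprocess/augment/postprocess before returning.
    sample = dataset[0]
    print(list(sample.keys()))  # assumes the collector returns a dict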