import io
import os
from time import time
from typing import Any, Dict, List, Tuple

import numpy as np
import tables
import torch
import torchvision
import torchvision.transforms.v2.functional as TF
from PIL import Image

from unik3d.datasets.base_dataset import BaseDataset
from unik3d.utils import is_main_process
from unik3d.utils.camera import BatchCamera, Pinhole

""" | |
Awful class for legacy reasons, we assume only pinhole cameras | |
And we "fake" sequences by setting sequence_fields to [(0, 0)] and cam2w as eye(4) | |
""" | |
class ImageDataset(BaseDataset):
    def __init__(
        self,
        image_shape: Tuple[int, int],
        split_file: str,
        test_mode: bool,
        normalize: bool,
        augmentations_db: Dict[str, Any],
        shape_constraints: Dict[str, Any],
        resize_method: str,
        mini: float,
        benchmark: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(
            image_shape=image_shape,
            split_file=split_file,
            test_mode=test_mode,
            benchmark=benchmark,
            normalize=normalize,
            augmentations_db=augmentations_db,
            shape_constraints=shape_constraints,
            resize_method=resize_method,
            mini=mini,
            **kwargs,
        )
        self.mapper = self.get_mapper()

    def get_single_item(self, idx, sample=None, mapper=None):
        sample = self.dataset[idx] if sample is None else sample
        mapper = self.mapper if mapper is None else mapper

        # single-image dataset: fake a one-element sequence keyed by (0, 0)
        results = {
            (0, 0): dict(
                gt_fields=set(),
                image_fields=set(),
                mask_fields=set(),
                camera_fields=set(),
            )
        }
        results = self.pre_pipeline(results)
        results["sequence_fields"] = [(0, 0)]

        chunk_idx = (
            int(sample[self.mapper["chunk_idx"]]) if "chunk_idx" in self.mapper else 0
        )
        h5_path = os.path.join(self.data_root, self.hdf5_paths[chunk_idx])
        with tables.File(
            h5_path,
            mode="r",
            libver="latest",
            swmr=True,
        ) as h5file_chunk:
            for key_mapper, idx_mapper in mapper.items():
                if "image" not in key_mapper and "depth" not in key_mapper:
                    continue
                value = sample[idx_mapper]
                results[(0, 0)][key_mapper] = value
                name = key_mapper.replace("_filename", "")
                value_root = "/" + value

                if "image" in key_mapper:
                    results[(0, 0)]["filename"] = value
                    file = h5file_chunk.get_node(value_root).read()
                    image = (
                        torchvision.io.decode_image(torch.from_numpy(file))
                        .to(torch.uint8)
                        .squeeze()
                    )
                    results[(0, 0)]["image_fields"].add(name)
                    results[(0, 0)]["image_ori_shape"] = image.shape[-2:]
                    results[(0, 0)][name] = image[None, ...]

                    # collect camera information for the given image
                    name = name.replace("image_", "")
                    results[(0, 0)]["camera_fields"].update({"camera", "cam2w"})
                    K = self.get_intrinsics(idx, value)
                    if K is None:
                        # no intrinsics available: fall back to a generic pinhole
                        # with focal 0.7 * width and a centered principal point
                        K = torch.eye(3)
                        K[0, 0] = K[1, 1] = 0.7 * self.image_shape[1]
                        K[0, 2] = 0.5 * self.image_shape[1]
                        K[1, 2] = 0.5 * self.image_shape[0]
                    camera = Pinhole(K=K[None, ...].clone())
                    results[(0, 0)]["camera"] = BatchCamera.from_camera(camera)
                    results[(0, 0)]["cam2w"] = self.get_extrinsics(idx, value)[
                        None, ...
                    ]

                elif "depth" in key_mapper:
                    # start = time()
                    file = h5file_chunk.get_node(value_root).read()
                    depth = Image.open(io.BytesIO(file))
                    depth = TF.pil_to_tensor(depth).squeeze().to(torch.float32)
                    if depth.ndim == 3:
                        # depth packed into three 8-bit channels: recombine into one map
                        depth = depth[2] + depth[1] * 255 + depth[0] * 255 * 255
                    results[(0, 0)]["gt_fields"].add(name)
                    results[(0, 0)]["depth_ori_shape"] = depth.shape
                    depth = (
                        depth.view(1, 1, *depth.shape).contiguous() / self.depth_scale
                    )
                    results[(0, 0)][name] = depth

        results = self.preprocess(results)
        if not self.test_mode:
            results = self.augment(results)
        results = self.postprocess(results)
        return results

    def preprocess(self, results):
        results = self.replicate(results)
        for i, seq in enumerate(results["sequence_fields"]):
            self.resizer.ctx = None
            results[seq] = self.resizer(results[seq])
            num_pts = torch.count_nonzero(results[seq]["depth"] > 0)
            if num_pts < 50:
                raise IndexError(f"Too few points in depth map ({num_pts})")
            for key in results[seq].get("image_fields", ["image"]):
                results[seq][key] = results[seq][key].to(torch.float32) / 255

        # propagate the field sets shared across the (single-element) sequence
        for key in ["image_fields", "gt_fields", "mask_fields", "camera_fields"]:
            if key in results[(0, 0)]:
                results[key] = results[(0, 0)][key]
        results = self.pack_batch(results)
        return results

    def postprocess(self, results):
        # normalize only after augmentation: color augs may expect [0, 255] inputs
        for key in results.get("image_fields", ["image"]):
            results[key] = TF.normalize(results[key], **self.normalization_stats)
        results = self.filler(results)
        results = self.unpack_batch(results)
        results = self.masker(results)
        results = self.collecter(results)
        return results

    def __getitem__(self, idx):
        try:
            if isinstance(idx, (list, tuple)):
                results = [self.get_single_item(i) for i in idx]
            else:
                results = self.get_single_item(idx)
        except Exception as e:
            # on any failure, log the error and retry with a random sample
            print(f"Error loading sequence {idx} for {self.__class__.__name__}: {e}")
            idx = np.random.randint(0, len(self.dataset))
            results = self[idx]
        return results

    def get_intrinsics(self, idx, image_name):
        # 1000 is a sentinel larger than any sample length: no K column available
        idx_sample = self.mapper.get("K", 1000)
        sample = self.dataset[idx]
        if idx_sample >= len(sample):
            return None
        return sample[idx_sample]

    def get_extrinsics(self, idx, image_name):
        # fall back to the identity pose when no cam2w column is available
        idx_sample = self.mapper.get("cam2w", 1000)
        sample = self.dataset[idx]
        if idx_sample >= len(sample):
            return torch.eye(4)
        return sample[idx_sample]

    def get_mapper(self):
        # column indices into each split-file sample
        return {
            "image_filename": 0,
            "depth_filename": 1,
            "K": 2,
        }
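
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module).
# ImageDataset reads attributes it never sets itself (data_root, hdf5_paths,
# depth_scale), so it is presumably meant to be subclassed by a concrete
# dataset that provides them. The subclass name, file names, and constructor
# values below are assumptions made purely for illustration.
if __name__ == "__main__":

    class ToyImageDataset(ImageDataset):  # hypothetical subclass
        hdf5_paths = ["toy_dataset.hdf5"]  # assumed single HDF5 chunk
        depth_scale = 256.0  # assumed depth encoding (stored value / 256 = meters)

        def get_mapper(self):
            # column indices of the split file: image path, depth path, K
            return {"image_filename": 0, "depth_filename": 1, "K": 2}

    # Constructor arguments mirror __init__ above; data_root is assumed to be
    # consumed by BaseDataset via **kwargs.
    dataset = ToyImageDataset(
        image_shape=(480, 640),  # (height, width), assumed
        split_file="splits/toy_train.txt",  # hypothetical split file
        test_mode=False,
        normalize=True,
        augmentations_db={},
        shape_constraints={},
        resize_method="hard",  # assumed resize mode
        mini=1.0,  # assumed: use the full split
        data_root="/data/toy_dataset",  # hypothetical HDF5 location
    )
    # __getitem__ accepts a single index or a list of indices; on failure it
    # retries with a random sample, so it returns a dict (or a list of dicts).
    sample = dataset[0]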