Stable-X's picture
Upload folder using huggingface_hub
e4bf056 verified
raw
history blame
6.89 kB
import os
import cv2
import json
import itertools
import numpy as np
import os.path as osp
from collections import deque
from dust3r.utils.image import imread_cv2
from .base_many_view_dataset import BaseManyViewDataset
class Co3d(BaseManyViewDataset):
    """Multi-view dataset over pre-processed CO3D sequences.

    The data is read from ``ROOT`` using this layout (as accessed below):

    * ``ROOT/selected_seqs_<split>.json`` -- ``{class: {instance: [frame numbers]}}``
    * ``ROOT/<class>/<instance>/images/frameNNNNNN.jpg`` -- RGB image
    * ``ROOT/<class>/<instance>/images/frameNNNNNN.npz`` -- ``camera_pose``,
      ``camera_intrinsics`` and ``maximum_depth``
    * ``ROOT/<class>/<instance>/depths/frameNNNNNN.jpg.geometric.png`` -- depth
    * ``ROOT/<class>/<instance>/masks/frameNNNNNN.png`` -- object mask

    ``self.split``, ``self.sample_frames`` and ``self._crop_resize_if_necessary``
    are not defined here; they are presumably provided by
    ``BaseManyViewDataset`` -- confirm against that class.
    """

    # Frame positions used to build index combinations are drawn from
    # range(_COMB_POOL_SIZE); consecutive positions must be a multiple of
    # _COMB_STEP apart.
    _COMB_POOL_SIZE = 100
    _COMB_STEP = 5

    def __init__(self, mask_bg=True, use_comb=True,
                 scene_class=None, scene_id=None,
                 num_seq=100, num_frames=5,
                 min_thresh=10, max_thresh=100,
                 full_video=False, lb=0, ub=30,
                 kf_every=1, *args, ROOT, **kwargs):
        """Load the scene index and, if requested, pre-compute frame combinations.

        Args:
            mask_bg: ``True`` to always zero the depth outside the object mask,
                ``False`` to never do so, ``'rand'`` to decide per sample.
            use_comb: Pick frames from pre-computed index combinations
                (ignored when ``full_video`` is set).
            scene_class: Keep only this object class, if given.
            scene_id: Keep only this sequence instance, if given.
            num_seq: Samples per scene when combinations are not used;
                otherwise replaced by the number of combinations.
            num_frames: Number of frames per combination.
            min_thresh, max_thresh, kf_every: Stored on the instance and not
                used in this class (presumably consumed by ``sample_frames``
                in the base class -- TODO confirm).
            full_video: Stored on the instance; disables combinations.
            lb, ub: Exclusive lower / inclusive upper bound on the gap between
                consecutive frame positions of a combination.
            ROOT: Dataset root directory (keyword-only, required).
            *args, **kwargs: Forwarded to the base class.
        """
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        assert mask_bg in (True, False, 'rand')
        self.mask_bg = mask_bg
        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.full_video = full_video
        self.kf_every = kf_every
        self.use_comb = use_comb

        self.scenes, self.scene_list = self.load_scene(scene_class, scene_id)
        self.combinations, self.num_seq = self.get_combinations(use_comb, lb, ub)

        # Per scene: {resolution: [bool per frame]}, True once a frame turned
        # out to have no valid depth at that resolution.
        self.invalidate = {scene: {} for scene in self.scene_list}

    def get_combinations(self, use_comb, lb, ub):
        """Return ``(combinations, num_seq)`` for the current configuration.

        When combinations are enabled, ``combinations`` is the list of all
        strictly increasing ``num_frames``-tuples over ``range(100)`` whose
        consecutive gaps ``g`` satisfy ``lb < g <= ub`` and ``g % 5 == 0``,
        in lexicographic order, and ``num_seq`` is its length. Otherwise
        ``combinations`` is ``None`` and ``num_seq`` is ``self.num_seq``.

        The tuples are built directly, one position at a time, instead of
        filtering every ``itertools.combinations(range(100), num_frames)``
        tuple: the default of 5 frames would enumerate ~75 million tuples.
        """
        if not use_comb or self.full_video:
            return None, self.num_seq

        print('Using combinations')
        pool = self._COMB_POOL_SIZE
        steps = [gap for gap in range(self._COMB_STEP, pool, self._COMB_STEP)
                 if lb < gap <= ub]

        # Extending every prefix with ascending gaps keeps each level in
        # lexicographic order, matching the order of itertools.combinations.
        combinations = [()]
        for depth in range(self.num_frames):
            if depth == 0:
                combinations = [(start,) for start in range(pool)]
            else:
                combinations = [
                    combo + (combo[-1] + gap,)
                    for combo in combinations
                    for gap in steps
                    if combo[-1] + gap < pool
                ]

        num_seq = len(combinations)
        print('Number of sequences:', num_seq)
        return combinations, num_seq

    def load_scene(self, scene_class=None, scene_id=None):
        """Read the sequence index of the current split.

        Returns:
            ``(scenes, scene_list)`` where ``scenes`` maps
            ``(class, instance)`` to that sequence's list of frame numbers and
            ``scene_list`` is the list of those keys.
        """
        print('Loading scenes')
        with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f:
            per_class = json.load(f)

        if scene_class is not None:
            per_class = {cls: seqs for cls, seqs in per_class.items() if cls == scene_class}
        else:
            # drop classes without any selected sequence
            per_class = {cls: seqs for cls, seqs in per_class.items() if len(seqs) > 0}

        # flatten to (class, instance) -> frame numbers, optionally keeping a
        # single instance id
        scenes = {
            (cls, inst): frames
            for cls, seqs in per_class.items()
            for inst, frames in seqs.items()
            if scene_id is None or inst == scene_id
        }
        scene_list = list(scenes.keys())
        return scenes, scene_list

    def __len__(self):
        """Number of samples: scenes times sequences per scene."""
        return len(self.scene_list) * self.num_seq

    def _get_views(self, idx, resolution, rng, attempts=0):
        """Load the views of sample ``idx``.

        Args:
            idx: Sample index; ``idx // num_seq`` selects the scene and, with
                combinations enabled, ``idx % len(combinations)`` the frame
                combination.
            resolution: Target resolution forwarded to
                ``_crop_resize_if_necessary``; also used as a dict key.
            rng: ``numpy.random.Generator`` used for all random choices.
            attempts: Unused; kept for interface compatibility.

        Returns:
            A list of dicts with keys ``img``, ``depthmap``, ``camera_pose``,
            ``camera_intrinsics``, ``dataset``, ``label`` and ``instance``.
            If the maximum depths of the loaded views are too inconsistent,
            the views of another randomly drawn sample are returned instead.

        Raises:
            RuntimeError: If every frame of the scene has been flagged as
                having no valid depth at this resolution.
        """
        obj, instance = self.scene_list[idx // self.num_seq]
        image_pool = self.scenes[obj, instance]
        num_images = len(image_pool)

        if self.use_comb and not self.full_video:
            frame_idx = self.combinations[idx % len(self.combinations)]
            last = num_images - 1
            # jitter each position by an offset in [-4, 4], clamped to the pool
            imgs_idxs = [max(0, min(im_idx + rng.integers(-4, 5), last)) for im_idx in frame_idx]
        else:
            img_idx = range(0, num_images)
            imgs_idxs = self.sample_frames(img_idx, rng)

        scene_flags = self.invalidate[obj, instance]
        if resolution not in scene_flags:  # flag invalid images
            scene_flags[resolution] = [False] * num_images
        invalid = scene_flags[resolution]

        # Only draw from rng when the mode is 'rand', as the random stream
        # must not be consumed otherwise.
        if self.mask_bg == 'rand':
            mask_bg = bool(rng.choice(2))
        else:
            mask_bg = bool(self.mask_bg)

        pending = deque(imgs_idxs)
        max_depth_min = 1e8
        max_depth_max = 0.
        max_depth_first = None
        views = []

        while pending:
            im_idx = pending.popleft()

            if invalid[im_idx]:
                # walk away from the bad frame in a random direction until a
                # frame that is not flagged is found
                random_direction = 2 * rng.choice(2) - 1
                for offset in range(1, num_images):
                    tentative_im_idx = (im_idx + (random_direction * offset)) % num_images
                    if not invalid[tentative_im_idx]:
                        im_idx = tentative_im_idx
                        break
                else:
                    # Every frame is flagged: retrying would loop forever.
                    raise RuntimeError(
                        f'No valid frame left in scene {obj}/{instance} '
                        f'at resolution {resolution}')

            view_idx = image_pool[im_idx]
            frame_name = f'frame{view_idx:06d}'
            # Build each path from its components; substituting substrings in
            # the image path would also rewrite matches inside ROOT or the
            # class / instance names.
            impath = osp.join(self.ROOT, obj, instance, 'images', frame_name + '.jpg')
            metadata_path = osp.join(self.ROOT, obj, instance, 'images', frame_name + '.npz')
            depth_path = osp.join(self.ROOT, obj, instance, 'depths',
                                  frame_name + '.jpg.geometric.png')

            # load camera params; read everything needed once and close the archive
            with np.load(metadata_path) as input_metadata:
                camera_pose = input_metadata['camera_pose'].astype(np.float32)
                intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)
                maximum_depth = input_metadata['maximum_depth']

            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(depth_path, cv2.IMREAD_UNCHANGED)
            # rescale from the 16-bit range to [0, maximum_depth]; a NaN
            # maximum depth yields an all-zero (hence invalid) depth map
            depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(maximum_depth)

            if mask_bg:
                # load object mask
                maskpath = osp.join(self.ROOT, obj, instance, 'masks', frame_name + '.png')
                maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)
                maskmap = (maskmap / 255.0) > 0.1

                # update the depthmap with mask
                depthmap *= maskmap

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)

            num_valid = (depthmap > 0.0).sum()
            if num_valid == 0:
                # problem, invalidate image and retry
                invalid[im_idx] = True
                pending.appendleft(im_idx)
                continue

            # track the spread of the per-view maximum depth
            if maximum_depth > max_depth_max:
                max_depth_max = maximum_depth
            if maximum_depth < max_depth_min:
                max_depth_min = maximum_depth
            if max_depth_first is None:
                max_depth_first = maximum_depth

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset='Co3d_v2',
                label=osp.join(obj, instance),
                instance=osp.split(impath)[1]
            ))

        # Reject samples whose depth ranges are wildly inconsistent and draw
        # another one instead.
        if max_depth_max / max_depth_min > 100. or max_depth_max / max_depth_first > 10.:
            # Generator.integers excludes the upper bound, so len(self) makes
            # every index reachable and also works for a dataset of length 1.
            new_idx = rng.integers(0, len(self))
            return self._get_views(new_idx, resolution, rng)

        return views