import os
import os.path as osp
from collections import deque

import cv2
import numpy as np

from dust3r.utils.image import imread_cv2

from .base_many_view_dataset import BaseManyViewDataset

class BlendMVS(BaseManyViewDataset):
    """Multi-view dataset loader for BlendedMVS scenes: blended images, rendered
    depth maps and MVSNet-style camera files."""

    def __init__(self, num_seq=100, num_frames=5,
                 min_thresh=10, max_thresh=100,
                 test_id=None, full_video=False,
                 kf_every=1, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)

        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.test_id = test_id
        self.full_video = full_video
        self.kf_every = kf_every

        # load all scenes
        self.load_all_scenes(ROOT)

    def __len__(self):
        # each scene contributes num_seq sampled sequences
        return len(self.scene_list) * self.num_seq

    def sample_pairs(self, pairs_path, rng, max_trials=10):
        """Sample a reference view plus (num_frames - 1) source views from the
        MVSNet-style pair.txt file. Returns a list of image filenames, or None
        if no cluster with enough source views was found within max_trials."""
        cluster_lines = open(pairs_path).read().splitlines()
        image_num = int(cluster_lines[0])

        trials = 0
        while trials < max_trials:
            trials += 1
            sample_idx = rng.choice(image_num)
            ref_idx = int(cluster_lines[2 * sample_idx + 1])
            cluster_info = cluster_lines[2 * sample_idx + 2].split()
            total_view_num = int(cluster_info[0])

            if total_view_num > self.num_frames - 1:
                list_idx = ['{:08d}.jpg'.format(ref_idx)]
                sample_cidx = rng.choice(total_view_num, self.num_frames - 1, replace=False)
                for cidx in sample_cidx:
                    list_idx.append('{:08d}.jpg'.format(int(cluster_info[2 * cidx + 1])))
                # randomly reverse the ordering so the reference view is not always first
                if rng.choice([True, False]):
                    list_idx.reverse()
                return list_idx

        return None
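
    # The pair.txt parsed above follows the standard MVSNet cluster layout
    # (inferred from the indexing in sample_pairs):
    #   line 0:        N                                   number of images in the scene
    #   line 2*i + 1:  ref_idx                             reference image index of cluster i
    #   line 2*i + 2:  M src_0 score_0 src_1 score_1 ...   M source views with matching scores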

    def load_all_scenes(self, base_dir):
        if self.test_id is None:
            meta_split = osp.join(base_dir, f'{self.split}_list.txt')
            if not osp.exists(meta_split):
                raise FileNotFoundError(f"Split file {meta_split} not found")
            with open(meta_split) as f:
                self.scene_list = f.read().splitlines()
            print(f"Found {len(self.scene_list)} scenes in split {self.split}")
        else:
            if isinstance(self.test_id, list):
                self.scene_list = self.test_id
            else:
                self.scene_list = [self.test_id]
            print(f"Test_id: {self.test_id}")

    def load_cam_mvsnet(self, f, interval_scale=1):
        """Read an MVSNet-style camera txt file from an open file handle and
        return (K, RT), where RT is the 4x4 world-to-camera extrinsic matrix
        (interval_scale is accepted for API compatibility but unused here)."""
        RT = np.loadtxt(f, skiprows=1, max_rows=4, dtype=np.float32)
        assert RT.shape == (4, 4)

        K = np.loadtxt(f, skiprows=2, max_rows=3, dtype=np.float32)
        assert K.shape == (3, 3)

        return K, RT
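
    # Layout assumed by the two np.loadtxt calls above (standard MVSNet cam.txt):
    #   extrinsic
    #   <4 x 4 world-to-camera matrix>
    #   (blank line)
    #   intrinsic
    #   <3 x 3 K matrix>
    #   (blank line)
    #   depth_min depth_interval          (not read here)
    # Since `f` is an open file handle, the second call's skiprows=2 skips the
    # blank line and the 'intrinsic' header relative to the current position.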

    def _get_views(self, idx, resolution, rng, attempts=0):
        scene_id = self.scene_list[idx // self.num_seq]

        image_path = osp.join(self.ROOT, scene_id, 'blended_images')
        depth_path = osp.join(self.ROOT, scene_id, 'rendered_depth_maps')
        cam_path = osp.join(self.ROOT, scene_id, 'cams')
        pairs_path = osp.join(self.ROOT, scene_id, 'cams', 'pair.txt')

        if not self.full_video:
            img_idxs = self.sample_pairs(pairs_path, rng)
        else:
            img_idxs = sorted(os.listdir(image_path))
            img_idxs = img_idxs[::self.kf_every]

        # sample_pairs returns None when no cluster has enough source views;
        # fall back to a different random index
        if img_idxs is None:
            new_idx = rng.integers(0, self.__len__() - 1)
            return self._get_views(new_idx, resolution, rng)

        imgs_idxs = deque(img_idxs)
        views = []

        max_depth_min = 1e8
        max_depth_max = 0.0
        max_depth_first = None

        while len(imgs_idxs) > 0:
            im_idx = imgs_idxs.popleft()
            impath = osp.join(image_path, im_idx)
            depthpath = osp.join(depth_path, im_idx.replace('.jpg', '.pfm'))
            campath = osp.join(cam_path, im_idx.replace('.jpg', '_cam.txt'))

            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
            depthmap = np.nan_to_num(depthmap.astype(np.float32), nan=0.0)

            cur_intrinsics, camera_pose = self.load_cam_mvsnet(open(campath, 'r'))
            intrinsics = cur_intrinsics[:3, :3]
            camera_pose = np.linalg.inv(camera_pose)  # world2cam -> cam2world

            # skip sequences whose principal point lies too close to the image border,
            # since cropping around it would leave too little image content
            H, W = rgb_image.shape[:2]
            cx, cy = intrinsics[:2, 2].round().astype(int)
            min_margin_x = min(cx, W - cx)
            min_margin_y = min(cy, H - cy)
            if min_margin_x <= W / 5 or min_margin_y <= H / 5:
                new_idx = rng.integers(0, self.__len__() - 1)
                return self._get_views(new_idx, resolution, rng)

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)

            # track per-view maximum depth so badly scaled sequences can be rejected below
            input_depth_max = depthmap.max()
            if input_depth_max > max_depth_max:
                max_depth_max = input_depth_max
            if input_depth_max < max_depth_min:
                max_depth_min = input_depth_max
            if max_depth_first is None:
                max_depth_first = input_depth_max

            num_valid = (depthmap > 0.0).sum()
            if num_valid == 0 or (not np.isfinite(camera_pose).all()):
                if self.full_video:
                    print(f"Warning: No valid depthmap found for {impath}")
                    continue
                else:
                    if attempts >= 5:
                        new_idx = rng.integers(0, self.__len__() - 1)
                        return self._get_views(new_idx, resolution, rng)
                    return self._get_views(idx, resolution, rng, attempts + 1)

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset='blendmvs',
                label=osp.join(scene_id, im_idx),
                instance=osp.split(impath)[1],
            ))

        # reject sequences whose per-view max depths vary too much, which usually
        # indicates inconsistent scale or broken rendered depth maps
        if max_depth_max / max_depth_min > 100. or max_depth_max / max_depth_first > 10.:
            print(f"Warning: Depthmap range too large: {max_depth_max} {max_depth_min} {max_depth_first}")
            new_idx = rng.integers(0, self.__len__() - 1)
            return self._get_views(new_idx, resolution, rng)

        return views
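

# Minimal usage sketch (hypothetical ROOT path and kwargs; `split` and `resolution`
# are assumed to be handled by BaseManyViewDataset, which is expected to call
# _get_views from its __getitem__):
#   dataset = BlendMVS(ROOT='data/blendedmvs', split='train',
#                      num_seq=100, num_frames=5, resolution=(512, 384))
#   views = dataset[0]   # list of num_frames dicts with img / depthmap / pose / intrinsics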