StableRecon / spann3r /datasets /blendedmvs.py
Stable-X's picture
Upload folder using huggingface_hub
e4bf056 verified
raw
history blame
6.92 kB
import os
import cv2
import numpy as np
import os.path as osp
from collections import deque
from dust3r.utils.image import imread_cv2
from .base_many_view_dataset import BaseManyViewDataset
class BlendMVS(BaseManyViewDataset):
    """Multi-view dataset reader for BlendedMVS-style scene folders.

    For every scene id ``<s>`` the following layout under ``ROOT`` is read:

    * ``<s>/blended_images/<name>.jpg``       -- colour images
    * ``<s>/rendered_depth_maps/<name>.pfm``  -- depth maps
    * ``<s>/cams/<name>_cam.txt``             -- per-image camera file
    * ``<s>/cams/pair.txt``                   -- view-pairing file

    One dataset item is a list of view dicts produced by ``_get_views``.
    """

    def __init__(self, num_seq=100, num_frames=5,
                 min_thresh=10, max_thresh=100,
                 test_id=None, full_video=False,
                 kf_every=1, *args, ROOT, **kwargs):
        """Store the sampling configuration and build the scene list.

        Args:
            num_seq: Number of dataset items generated per scene
                (``len(self) == len(scene_list) * num_seq``).
            num_frames: Number of views per item when ``full_video`` is False.
            min_thresh: Stored on the instance; not used by this class.
            max_thresh: Stored on the instance; not used by this class.
            test_id: A scene id or a list of scene ids. When given, the
                split file is ignored and only these scenes are used.
            full_video: If True, an item contains every ``kf_every``-th image
                of the scene (sorted by file name) instead of a random sample.
            kf_every: Stride applied to the sorted image list in
                ``full_video`` mode.
            ROOT: Dataset root directory (keyword-only).
            *args, **kwargs: Forwarded to the base-class constructor.
        """
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.test_id = test_id
        self.full_video = full_video
        self.kf_every = kf_every

        # load all scenes
        self.load_all_scenes(ROOT)

    def __len__(self):
        """Return the number of items: ``num_seq`` per scene."""
        return len(self.scene_list) * self.num_seq

    def sample_pairs(self, pairs_path, rng, max_trials=10):
        """Randomly pick a reference image and ``num_frames - 1`` paired views.

        The pair file is read as: first line = number of images, followed by
        two lines per image -- the reference image index, then a line of the
        form ``<count> <id_0> <x_0> <id_1> <x_1> ...`` of which only the count
        and the ids are used here.

        Args:
            pairs_path: Path to the scene's ``pair.txt``.
            rng: Random generator providing ``choice``.
            max_trials: Maximum number of reference images to try.

        Returns:
            A list of ``num_frames`` image file names (``'%08d.jpg'``), with
            the reference image first or -- with probability 1/2 -- last
            (the list is reversed), or ``None`` if no tried reference image
            lists more than ``num_frames - 1`` paired views.
        """
        # Context manager so the file handle is released immediately.
        with open(pairs_path) as pair_file:
            cluster_lines = pair_file.read().splitlines()
        image_num = int(cluster_lines[0])

        for _ in range(max_trials):
            sample_idx = rng.choice(image_num)
            ref_idx = int(cluster_lines[2 * sample_idx + 1])
            cluster_info = cluster_lines[2 * sample_idx + 2].split()
            total_view_num = int(cluster_info[0])

            if total_view_num > self.num_frames - 1:
                list_idx = ['{:08d}.jpg'.format(ref_idx)]
                sample_cidx = rng.choice(total_view_num, self.num_frames - 1, replace=False)
                for cidx in sample_cidx:
                    # Ids sit at the odd positions of the cluster line.
                    list_idx.append('{:08d}.jpg'.format(int(cluster_info[2 * cidx + 1])))

                # Randomly flip the order of the sequence.
                if rng.choice([True, False]):
                    list_idx.reverse()

                return list_idx

        return None

    def load_all_scenes(self, base_dir):
        """Populate ``self.scene_list``.

        Without ``test_id`` the scene ids are read, one per line, from
        ``<base_dir>/<split>_list.txt`` (``self.split`` is expected to be set
        by the base class -- it is not assigned in this class). With
        ``test_id`` the given id(s) are used directly.

        Raises:
            FileNotFoundError: If the split file does not exist.
        """
        if self.test_id is None:
            meta_split = osp.join(base_dir, f'{self.split}_list.txt')
            if not osp.exists(meta_split):
                raise FileNotFoundError(f"Split file {meta_split} not found")

            with open(meta_split) as f:
                # Skip blank lines (e.g. a trailing empty line) so they do
                # not turn into empty scene ids.
                self.scene_list = [line for line in f.read().splitlines() if line.strip()]

            print(f"Found {len(self.scene_list)} scenes in split {self.split}")
        else:
            if isinstance(self.test_id, list):
                self.scene_list = self.test_id
            else:
                self.scene_list = [self.test_id]

            print(f"Test_id: {self.test_id}")

    def load_cam_mvsnet(self, f, interval_scale=1):
        """Read the 4x4 and 3x3 matrices from a camera text file.

        Args:
            f: An *open* text file object positioned at the start of the
                camera file. Both matrices are parsed from the same handle:
                the second ``np.loadtxt`` call continues from where the first
                one stopped, so passing a path instead would not work.
            interval_scale: Unused; kept for interface compatibility.

        Returns:
            ``(K, RT)``: the 3x3 matrix and the 4x4 matrix as float32 arrays,
            exactly as stored in the file (``_get_views`` inverts ``RT`` to
            obtain the camera pose).
        """
        # One header line, then the four rows of the 4x4 matrix.
        RT = np.loadtxt(f, skiprows=1, max_rows=4, dtype=np.float32)
        assert RT.shape == (4, 4)

        # Two more lines are skipped before the three rows of the 3x3 matrix.
        K = np.loadtxt(f, skiprows=2, max_rows=3, dtype=np.float32)
        assert K.shape == (3, 3)

        return K, RT

    def _resample_views(self, resolution, rng):
        """Return the views of a uniformly drawn replacement item."""
        # `Generator.integers` excludes `high`, so `len(self)` (not
        # `len(self) - 1`) is needed for every index to be reachable and for
        # a single-item dataset not to raise.
        new_idx = rng.integers(0, len(self))
        return self._get_views(new_idx, resolution, rng)

    def _get_views(self, idx, resolution, rng, attempts=0):
        """Load the views of item ``idx``.

        Args:
            idx: Item index; the scene is ``scene_list[idx // num_seq]``.
            resolution: Target resolution forwarded to
                ``_crop_resize_if_necessary``.
            rng: Random generator providing ``choice`` and ``integers``.
            attempts: Number of retries already spent on this ``idx``.

        Returns:
            A list of dicts with keys ``img``, ``depthmap``, ``camera_pose``
            (inverse of the 4x4 matrix from the camera file),
            ``camera_intrinsics``, ``dataset``, ``label`` and ``instance``.

        A different random item is returned instead when no view set can be
        sampled, when a principal point lies within one fifth of the image
        size from a border, or when the per-view maximum depths differ too
        much. An invalid frame (no positive depth or a non-finite pose) is
        skipped in ``full_video`` mode; otherwise the same item is retried
        up to five times before falling back to a random item.
        """
        scene_id = self.scene_list[idx // self.num_seq]
        scene_dir = osp.join(self.ROOT, scene_id)
        image_path = osp.join(scene_dir, 'blended_images')
        depth_path = osp.join(scene_dir, 'rendered_depth_maps')
        cam_path = osp.join(scene_dir, 'cams')
        pairs_path = osp.join(cam_path, 'pair.txt')

        if not self.full_video:
            img_idxs = self.sample_pairs(pairs_path, rng)
        else:
            img_idxs = sorted(os.listdir(image_path))
            img_idxs = img_idxs[::self.kf_every]

        if img_idxs is None:
            return self._resample_views(resolution, rng)

        imgs_idxs = deque(img_idxs)
        views = []

        # Running statistics of the per-view maximum depth (valid views only).
        max_depth_min = 1e8
        max_depth_max = 0.0
        max_depth_first = None

        while len(imgs_idxs) > 0:
            im_idx = imgs_idxs.popleft()
            impath = osp.join(image_path, im_idx)
            depthpath = osp.join(depth_path, im_idx.replace('.jpg', '.pfm'))
            campath = osp.join(cam_path, im_idx.replace('.jpg', '_cam.txt'))

            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
            depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0)

            # Close the camera file as soon as both matrices are parsed.
            with open(campath, 'r') as cam_file:
                cur_intrinsics, camera_pose = self.load_cam_mvsnet(cam_file)
            intrinsics = cur_intrinsics[:3, :3]
            camera_pose = np.linalg.inv(camera_pose)

            # Reject items whose principal point is too close to a border.
            H, W = rgb_image.shape[:2]
            cx, cy = intrinsics[:2, 2].round().astype(int)
            min_margin_x = min(cx, W - cx)
            min_margin_y = min(cy, H - cy)
            if min_margin_x <= W / 5 or min_margin_y <= H / 5:
                return self._resample_views(resolution, rng)

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)

            # Validate the frame *before* touching the depth statistics, so a
            # skipped frame (e.g. an all-zero depth map) cannot zero out the
            # denominators of the range check below.
            num_valid = (depthmap > 0.0).sum()
            if num_valid == 0 or (not np.isfinite(camera_pose).all()):
                if self.full_video:
                    print(f"Warning: No valid depthmap found for {impath}")
                    continue
                if attempts >= 5:
                    return self._resample_views(resolution, rng)
                return self._get_views(idx, resolution, rng, attempts + 1)

            input_depth_max = depthmap.max()
            if input_depth_max > max_depth_max:
                max_depth_max = input_depth_max
            if input_depth_max < max_depth_min:
                max_depth_min = input_depth_max
            if max_depth_first is None:
                max_depth_first = input_depth_max

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset='blendmvs',
                label=osp.join(scene_id, im_idx),
                instance=osp.split(impath)[1],
            ))

            # Both ratios can only grow as more views are added, so checking
            # after every view gives the same verdict as checking once at the
            # end while stopping as early as possible.
            if max_depth_max / max_depth_min > 100. or max_depth_max / max_depth_first > 10.:
                print(f"Warning: Depthmap range too large: {max_depth_max} {max_depth_min} {max_depth_first}")
                return self._resample_views(resolution, rng)

        return views