Stable-X's picture
Upload folder using huggingface_hub
e4bf056 verified
raw
history blame
3.58 kB
import os
import cv2
import numpy as np
import os.path as osp
from collections import deque
from dust3r.utils.image import imread_cv2
from .base_many_view_dataset import BaseManyViewDataset
class Demo(BaseManyViewDataset):
    """Serve the images found in one folder as a multi-view sequence.

    Every regular image file directly inside ``ROOT`` is a candidate frame.
    Two optional side-car files, named after the image without its extension,
    refine a frame:

    * ``<stem>.npz`` with ``camera_pose`` and ``camera_intrinsics`` arrays
      (and ``maximum_depth`` when a depth map is supplied);
    * ``<stem>_depth.png`` whose values are divided by 65535 and multiplied
      by ``maximum_depth``.

    When a side-car is missing, placeholder values are used instead
    (identity pose, fixed-focal intrinsics centred on the image, all-ones
    depth).
    """

    # File suffixes accepted as input images (matched case-insensitively).
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.heic')
    # Focal length, in pixels, of the placeholder intrinsics.
    _PSEUDO_FOCAL = 525
    # Divisor applied to depth PNG values (uint16 full scale).
    _DEPTH_PNG_MAX = 65535

    def __init__(self, num_seq=1, num_frames=5,
                 min_thresh=10, max_thresh=100,
                 full_video=True, kf_every=1,
                 *args, ROOT, **kwargs):
        """Store the sampling configuration and initialise the base dataset.

        Args:
            num_seq: Value reported by ``len(self)``.
            num_frames, min_thresh, max_thresh, kf_every: Stored on the
                instance; not read in this class (presumably consumed by
                the base class's frame sampling -- confirm there).
            full_video: Forwarded to ``sample_frame_idx`` in ``_get_views``.
            ROOT: Keyword-only path of the folder holding the images.
            *args, **kwargs: Passed through to ``BaseManyViewDataset``.
        """
        # ROOT is assigned before the base initialiser runs, as in the
        # original ordering (the base class may rely on it being set).
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.full_video = full_video
        self.kf_every = kf_every

    def __len__(self):
        """Return the configured number of sequences."""
        return self.num_seq

    @staticmethod
    def _read_rgb(impath):
        """Load the image at ``impath`` as an array."""
        # Decide on the real file extension: a substring test on the whole
        # path would also fire for a directory that merely contains "heic".
        if osp.splitext(impath)[1].lower() == '.heic':
            # NOTE(review): Pillow needs a HEIF plugin registered elsewhere
            # to decode these files -- confirm the runtime provides one.
            from PIL import Image
            # The context manager closes the underlying file once the pixel
            # data has been copied into the array.
            with Image.open(impath) as pil_image:
                if pil_image.mode != 'RGB':
                    pil_image = pil_image.convert('RGB')
                return np.array(pil_image)
        return imread_cv2(impath)

    def _read_camera(self, meta_data_path, image_shape):
        """Return ``(camera_pose, intrinsics, max_depth)`` for one frame.

        Values come from the ``.npz`` side-car when it exists; otherwise an
        identity pose and placeholder intrinsics centred on the image are
        returned. ``max_depth`` is ``None`` whenever the side-car is missing
        or has no ``maximum_depth`` entry.
        """
        if osp.exists(meta_data_path):
            # Copy what is needed while the archive is open so that its file
            # handle is released per frame instead of leaking.
            with np.load(meta_data_path) as input_metadata:
                camera_pose = input_metadata['camera_pose'].astype(np.float32)
                intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)
                max_depth = None
                if 'maximum_depth' in input_metadata.files:
                    max_depth = np.nan_to_num(input_metadata['maximum_depth'])
            return camera_pose, intrinsics, max_depth

        focal = self._PSEUDO_FOCAL
        cx, cy = image_shape[1] // 2, image_shape[0] // 2
        intrinsics = np.array([[focal, 0, cx], [0, focal, cy], [0, 0, 1]],
                              dtype=np.float32)
        # pseudo camera pose
        camera_pose = np.eye(4).astype(np.float32)
        return camera_pose, intrinsics, None

    def _get_views(self, idx, resolution, rng):
        """Load every sampled frame of ``ROOT`` and return the view dicts.

        Args:
            idx: Sequence index; unused, every index reads the same folder.
            resolution: Forwarded to ``_crop_resize_if_necessary``.
            rng: Random generator forwarded to ``sample_frame_idx`` and
                ``_crop_resize_if_necessary``.

        Returns:
            A list with one dict per frame holding ``img``, ``depthmap``,
            ``camera_pose``, ``camera_intrinsics``, ``dataset``, ``label``
            and ``instance``.

        Raises:
            FileNotFoundError: If a selected image no longer exists on disk.
        """
        # Files with "depth" in their name are side-cars, not input frames.
        frame_names = [
            name for name in sorted(os.listdir(self.ROOT))
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
            and 'depth' not in name.lower()
        ]
        frame_names = self.sample_frame_idx(frame_names, rng, full_video=self.full_video)

        views = []
        for im_idx in frame_names:
            impath = osp.join(self.ROOT, im_idx)
            if not osp.exists(impath):
                raise FileNotFoundError(f"Image not found: {impath}")
            print(f'Loading image: {impath}')
            rgb_image = self._read_rgb(impath)

            # splitext removes only the final extension; splitting on the
            # first '.' would truncate paths such as "./data/img.jpg" or
            # "/home/user.name/img.jpg".
            stem = osp.splitext(impath)[0]
            depth_path = stem + '_depth.png'
            meta_data_path = stem + '.npz'

            camera_pose, intrinsics, max_depth = self._read_camera(
                meta_data_path, rgb_image.shape)

            has_depth_file = osp.exists(depth_path)
            if has_depth_file and max_depth is not None:
                depthmap = imread_cv2(depth_path, cv2.IMREAD_UNCHANGED)
                depthmap = (depthmap.astype(np.float32) / self._DEPTH_PNG_MAX) * max_depth
                # resize rgb to the same size as depth
                rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))
            else:
                if has_depth_file:
                    # Without this frame's own maximum_depth the PNG cannot
                    # be scaled; never borrow the value of an earlier frame.
                    print(f'Ignoring depth map without maximum_depth metadata: {depth_path}')
                depthmap = np.ones((rgb_image.shape[0], rgb_image.shape[1])).astype(np.float32)

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset='demo',
                label=osp.join(self.ROOT, im_idx),
                instance=osp.split(impath)[1],
            ))
        return views