import sys
sys.path.append('.')
import os
import torch
import numpy as np
import os.path as osp
import glob
import PIL.Image
import torchvision.transforms as tvf
from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
from dust3r.utils.image import imread_cv2, crop_img
from dust3r.utils.misc import get_stride_distribution
np.random.seed(125)
torch.multiprocessing.set_sharing_strategy('file_system')
TAG_FLOAT = 202021.25
ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
ToTensor = tvf.ToTensor()
def depth_read(filename):
    """ Read a Sintel .dpt depth map, return it as a (height, width) float32 numpy array. """
    with open(filename, 'rb') as f:
        check = np.fromfile(f, dtype=np.float32, count=1)[0]
        assert check == TAG_FLOAT, ' depth_read:: Wrong tag in depth file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT, check)
        width = np.fromfile(f, dtype=np.int32, count=1)[0]
        height = np.fromfile(f, dtype=np.int32, count=1)[0]
        size = width * height
        assert width > 0 and height > 0 and 1 < size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width, height)
        depth = np.fromfile(f, dtype=np.float32, count=-1).reshape((height, width))
    return depth
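
# On-disk layout of a Sintel .dpt file, matching the reads above:
#   float32 tag                     -- magic number TAG_FLOAT (202021.25)
#   int32   width
#   int32   height
#   float32 depth[height * width]   -- row-major depth values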
def cam_read(filename):
    """ Read a Sintel .cam camera file, return an (M, N) tuple.

    M is the 3x3 intrinsic matrix and N the 3x4 extrinsic matrix, so that
        x = M * N * X,
    with x a point in homogeneous image pixel coordinates and X a point in
    homogeneous world coordinates.
    """
    with open(filename, 'rb') as f:
        check = np.fromfile(f, dtype=np.float32, count=1)[0]
        assert check == TAG_FLOAT, ' cam_read:: Wrong tag in cam file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT, check)
        M = np.fromfile(f, dtype='float64', count=9).reshape((3, 3))
        N = np.fromfile(f, dtype='float64', count=12).reshape((3, 4))
    return M, N
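
# Illustrative sketch (not used by the dataset itself): how the M and N returned
# by cam_read() project a 3-D world point to pixel coordinates. The helper name
# `project_to_pixel` is hypothetical, added here only for documentation.
def project_to_pixel(M, N, X_world):
    X_h = np.append(np.asarray(X_world, dtype=np.float64), 1.0)  # homogenize: (4,)
    x = M @ (N @ X_h)  # extrinsics map world -> camera, intrinsics map camera -> image plane
    return x[:2] / x[2]  # dehomogenize to (u, v) pixel coordinates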
class SintelDUSt3R(BaseStereoViewDataset):
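    """Stereo-view dataset over MPI Sintel: yields pairs of frames together with
    depth maps, intrinsics, camera-to-world poses, and (optionally) dynamic-object
    masks, for DUSt3R-style training."""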
def __init__(self,
dataset_location='data/sintel/training',
dset='clean',
use_augs=False,
S=2,
strides=[7],
clip_step=2,
quick=False,
verbose=False,
dist_type=None,
                 clip_step_last_skip=0,
load_dynamic_mask=True,
*args,
**kwargs
):
print('loading sintel dataset...')
super().__init__(*args, **kwargs)
self.dataset_label = 'sintel'
self.split = dset
        self.S = S  # clip length (number of frames per sample), not the stride
self.verbose = verbose
self.load_dynamic_mask = load_dynamic_mask
self.use_augs = use_augs
self.dset = dset
self.rgb_paths = []
self.depth_paths = []
self.traj_paths = []
self.annotation_paths = []
self.dynamic_mask_paths = []
self.full_idxs = []
self.sample_stride = []
self.strides = strides
self.subdirs = []
self.sequences = []
self.subdirs.append(os.path.join(dataset_location, dset))
for subdir in self.subdirs:
for seq in glob.glob(os.path.join(subdir, "*/")):
self.sequences.append(seq)
self.sequences = sorted(self.sequences)
if self.verbose:
print(self.sequences)
print('found %d unique videos in %s (dset=%s)' % (len(self.sequences), dataset_location, dset))
## load trajectories
print('loading trajectories...')
if quick:
self.sequences = self.sequences[1:2]
for seq in self.sequences:
if self.verbose:
print('seq', seq)
rgb_path = seq
depth_path = seq.replace(dset,'depth')
caminfo_path = seq.replace(dset,'camdata_left')
dynamic_mask_path = seq.replace(dset,'dynamic_label_perfect')
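            # For each stride s, slide a window over the sequence: the clip starting
            # at frame ii covers frames ii, ii+s, ..., ii+(S-1)*s (Sintel frames are
            # 1-indexed), and consecutive window starts are clip_step frames apart.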
for stride in strides:
for ii in range(1,len(os.listdir(rgb_path))-self.S*max(stride,clip_step_last_skip)+1, clip_step):
full_idx = ii + np.arange(self.S)*stride
self.rgb_paths.append([os.path.join(rgb_path, 'frame_%04d.png' % idx) for idx in full_idx])
self.depth_paths.append([os.path.join(depth_path, 'frame_%04d.dpt' % idx) for idx in full_idx])
self.annotation_paths.append([os.path.join(caminfo_path, 'frame_%04d.cam' % idx) for idx in full_idx])
self.dynamic_mask_paths.append([os.path.join(dynamic_mask_path, 'frame_%04d.png' % idx) for idx in full_idx])
self.full_idxs.append(full_idx)
self.sample_stride.append(stride)
if self.verbose:
sys.stdout.write('.')
sys.stdout.flush()
self.stride_counts = {}
self.stride_idxs = {}
for stride in strides:
self.stride_counts[stride] = 0
self.stride_idxs[stride] = []
for i, stride in enumerate(self.sample_stride):
self.stride_counts[stride] += 1
self.stride_idxs[stride].append(i)
print('stride counts:', self.stride_counts)
if len(strides) > 1 and dist_type is not None:
self._resample_clips(strides, dist_type)
print('collected %d clips of length %d in %s (dset=%s)' % (
len(self.rgb_paths), self.S, dataset_location, dset))
def _resample_clips(self, strides, dist_type):
# Get distribution of strides, and sample based on that
dist = get_stride_distribution(strides, dist_type=dist_type)
dist = dist / np.max(dist)
max_num_clips = self.stride_counts[strides[np.argmax(dist)]]
num_clips_each_stride = [min(self.stride_counts[stride], int(dist[i]*max_num_clips)) for i, stride in enumerate(strides)]
print('resampled_num_clips_each_stride:', num_clips_each_stride)
resampled_idxs = []
for i, stride in enumerate(strides):
resampled_idxs += np.random.choice(self.stride_idxs[stride], num_clips_each_stride[i], replace=False).tolist()
self.rgb_paths = [self.rgb_paths[i] for i in resampled_idxs]
self.depth_paths = [self.depth_paths[i] for i in resampled_idxs]
self.annotation_paths = [self.annotation_paths[i] for i in resampled_idxs]
self.dynamic_mask_paths = [self.dynamic_mask_paths[i] for i in resampled_idxs]
self.full_idxs = [self.full_idxs[i] for i in resampled_idxs]
self.sample_stride = [self.sample_stride[i] for i in resampled_idxs]
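
    # Worked example (assuming get_stride_distribution returns one non-negative
    # weight per stride): with strides=[1, 2, 4] and weights [1.0, 0.5, 0.25],
    # the stride with the largest weight keeps all of its clips, while each other
    # stride is subsampled without replacement to at most its normalized weight
    # times that maximum count.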
def __len__(self):
return len(self.rgb_paths)
def _get_views(self, index, resolution, rng):
rgb_paths = self.rgb_paths[index]
depth_paths = self.depth_paths[index]
full_idx = self.full_idxs[index]
annotations_paths = self.annotation_paths[index]
dynamic_mask_paths = self.dynamic_mask_paths[index]
views = []
        for i in range(2):  # a sample is a stereo pair: use the first two frames of the clip
impath = rgb_paths[i]
depthpath = depth_paths[i]
dynamic_mask_path = dynamic_mask_paths[i]
# load camera params
intrinsics, extrinsics = cam_read(annotations_paths[i])
intrinsics, extrinsics = np.array(intrinsics, dtype=np.float32), np.array(extrinsics, dtype=np.float32)
R = extrinsics[:3,:3]
t = extrinsics[:3,3]
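            # Sintel stores world-to-camera extrinsics [R | t]; invert them to get
            # the camera-to-world pose expected downstream: R^T and -R^T @ t.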
camera_pose = np.eye(4, dtype=np.float32)
camera_pose[:3,:3] = R.T
camera_pose[:3,3] = -R.T @ t
# load image and depth
rgb_image = imread_cv2(impath)
depthmap = depth_read(depthpath)
# load dynamic mask
if dynamic_mask_path is not None and os.path.exists(dynamic_mask_path):
dynamic_mask = PIL.Image.open(dynamic_mask_path).convert('L')
dynamic_mask = ToTensor(dynamic_mask).sum(0).numpy()
_, dynamic_mask, _ = self._crop_resize_if_necessary(
rgb_image, dynamic_mask, intrinsics, resolution, rng=rng, info=impath)
dynamic_mask = dynamic_mask > 0.5
assert not np.all(dynamic_mask), f"Dynamic mask is all True for {impath}"
            else:
                # no per-pixel label file available: fall back to an all-True mask
                dynamic_mask = np.ones((resolution[1], resolution[0]), dtype=bool)
rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)
            view = dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset=self.dataset_label,
                label=rgb_paths[i].split('/')[-2],
                instance=osp.split(rgb_paths[i])[1],
                full_idx=full_idx,
            )
            if self.load_dynamic_mask:
                view['dynamic_mask'] = dynamic_mask
            views.append(view)
return views
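
# Each view returned by _get_views() is a dict with keys: img, depthmap,
# camera_pose (4x4 camera-to-world), camera_intrinsics (3x3), dataset, label
# (sequence name), instance (frame filename), full_idx, and, when
# load_dynamic_mask is set, dynamic_mask (boolean H x W array).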
if __name__ == "__main__":
    import random

    from dust3r.viz import SceneViz, auto_cam_size
    from dust3r.utils.image import rgb
use_augs = False
S = 2
strides = [1]
clip_step = 1
quick = False # Set to True for quick testing
def visualize_scene(idx):
views = dataset[idx]
assert len(views) == 2
viz = SceneViz()
poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]
cam_size = max(auto_cam_size(poses), 0.25)
for view_idx in [0, 1]:
pts3d = views[view_idx]['pts3d']
valid_mask = views[view_idx]['valid_mask']
colors = rgb(views[view_idx]['img'])
viz.add_pointcloud(pts3d, colors, valid_mask)
viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],
focal=views[view_idx]['camera_intrinsics'][0, 0],
color=(255, 0, 0),
image=colors,
cam_size=cam_size)
        os.makedirs("./tmp", exist_ok=True)  # make sure the output directory exists
        path = f"./tmp/sintel_scene_{idx}.glb"
        return viz.save_glb(path)
dataset = SintelDUSt3R(
use_augs=use_augs,
S=S,
strides=strides,
clip_step=clip_step,
quick=quick,
verbose=False,
resolution=(512,224),
        seed=777,
clip_step_last_skip=0,
aug_crop=16)
idx = random.randint(0, len(dataset)-1)
visualize_scene(idx)