import sys

sys.path.append('.')

import os
import os.path as osp
import glob

import numpy as np
import PIL.Image
import torch
import torchvision.transforms as tvf

from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
from dust3r.utils.image import imread_cv2
from dust3r.utils.misc import get_stride_distribution

np.random.seed(125)
torch.multiprocessing.set_sharing_strategy('file_system')

# Magic number written at the start of Sintel .dpt and .cam files.
TAG_FLOAT = 202021.25

ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
ToTensor = tvf.ToTensor()


def depth_read(filename):
    """Read a Sintel .dpt depth file and return the depth map as a (H, W) float32 array."""
    with open(filename, 'rb') as f:
        check = np.fromfile(f, dtype=np.float32, count=1)[0]
        assert check == TAG_FLOAT, \
            'depth_read:: Wrong tag in depth file (should be: {0}, is: {1}). Big-endian machine?'.format(TAG_FLOAT, check)
        width = np.fromfile(f, dtype=np.int32, count=1)[0]
        height = np.fromfile(f, dtype=np.int32, count=1)[0]
        size = width * height
        assert width > 0 and height > 0 and 1 < size < 100000000, \
            'depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width, height)
        depth = np.fromfile(f, dtype=np.float32, count=-1).reshape((height, width))
    return depth
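

# Minimal usage sketch for depth_read (the path below is hypothetical): .dpt
# files store a float32 tag, an int32 width/height pair, then row-major
# float32 depths.
def _demo_depth_read(path='data/sintel/training/depth/alley_1/frame_0001.dpt'):
    depth = depth_read(path)
    print(depth.shape, float(depth.min()), float(depth.max()))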


def cam_read(filename):
    """Read a Sintel .cam file and return the (M, N) camera tuple.

    M is the 3x3 intrinsic matrix and N is the 3x4 extrinsic (world-to-camera)
    matrix, so that x = M @ N @ X, with x a point in homogeneous image pixel
    coordinates and X a point in homogeneous world coordinates.
    """
    with open(filename, 'rb') as f:
        check = np.fromfile(f, dtype=np.float32, count=1)[0]
        assert check == TAG_FLOAT, \
            'cam_read:: Wrong tag in cam file (should be: {0}, is: {1}). Big-endian machine?'.format(TAG_FLOAT, check)
        M = np.fromfile(f, dtype='float64', count=9).reshape((3, 3))
        N = np.fromfile(f, dtype='float64', count=12).reshape((3, 4))
    return M, N
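

# Sketch of the projection convention from the docstring above (the .cam path
# is hypothetical): a homogeneous world point X maps to pixels via
# x = M @ N @ X, followed by the perspective divide.
def _demo_cam_read(path='data/sintel/training/camdata_left/alley_1/frame_0001.cam'):
    K, RT = cam_read(path)
    X = np.array([0.0, 1.0, 5.0, 1.0])   # world point, homogeneous coordinates
    x = K @ RT @ X                       # homogeneous pixel coordinates
    u, v = x[0] / x[2], x[1] / x[2]      # perspective divide -> pixel (u, v)
    print(u, v)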


class SintelDUSt3R(BaseStereoViewDataset):
    def __init__(self,
                 dataset_location='data/sintel/training',
                 dset='clean',
                 use_augs=False,
                 S=2,
                 strides=[7],
                 clip_step=2,
                 quick=False,
                 verbose=False,
                 dist_type=None,
                 clip_step_last_skip=0,
                 load_dynamic_mask=True,
                 *args,
                 **kwargs
                 ):
        print('loading sintel dataset...')
        super().__init__(*args, **kwargs)
        self.dataset_label = 'sintel'
        self.split = dset
        self.S = S
        self.verbose = verbose
        self.load_dynamic_mask = load_dynamic_mask

        self.use_augs = use_augs
        self.dset = dset

        self.rgb_paths = []
        self.depth_paths = []
        self.traj_paths = []
        self.annotation_paths = []
        self.dynamic_mask_paths = []
        self.full_idxs = []
        self.sample_stride = []
        self.strides = strides

        self.subdirs = []
        self.sequences = []
        self.subdirs.append(os.path.join(dataset_location, dset))

        for subdir in self.subdirs:
            for seq in glob.glob(os.path.join(subdir, "*/")):
                self.sequences.append(seq)

        self.sequences = sorted(self.sequences)
        if self.verbose:
            print(self.sequences)
        print('found %d unique videos in %s (dset=%s)' % (len(self.sequences), dataset_location, dset))

        print('loading trajectories...')

        if quick:
            self.sequences = self.sequences[1:2]

        for seq in self.sequences:
            if self.verbose:
                print('seq', seq)

            rgb_path = seq
            depth_path = seq.replace(dset, 'depth')
            caminfo_path = seq.replace(dset, 'camdata_left')
            dynamic_mask_path = seq.replace(dset, 'dynamic_label_perfect')
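
            # Clip sampling sketch: with S=2, stride=7, clip_step=2 and a
            # 50-frame sequence, the loop below yields full_idx = [1, 8],
            # [3, 10], ..., [35, 42]: frame pairs 7 apart whose start frame
            # advances by 2.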
            for stride in strides:
                for ii in range(1, len(os.listdir(rgb_path)) - self.S * max(stride, clip_step_last_skip) + 1, clip_step):
                    full_idx = ii + np.arange(self.S) * stride
                    self.rgb_paths.append([os.path.join(rgb_path, 'frame_%04d.png' % idx) for idx in full_idx])
                    self.depth_paths.append([os.path.join(depth_path, 'frame_%04d.dpt' % idx) for idx in full_idx])
                    self.annotation_paths.append([os.path.join(caminfo_path, 'frame_%04d.cam' % idx) for idx in full_idx])
                    self.dynamic_mask_paths.append([os.path.join(dynamic_mask_path, 'frame_%04d.png' % idx) for idx in full_idx])
                    self.full_idxs.append(full_idx)
                    self.sample_stride.append(stride)
                if self.verbose:
                    sys.stdout.write('.')
                    sys.stdout.flush()

        self.stride_counts = {}
        self.stride_idxs = {}
        for stride in strides:
            self.stride_counts[stride] = 0
            self.stride_idxs[stride] = []
        for i, stride in enumerate(self.sample_stride):
            self.stride_counts[stride] += 1
            self.stride_idxs[stride].append(i)
        print('stride counts:', self.stride_counts)

        if len(strides) > 1 and dist_type is not None:
            self._resample_clips(strides, dist_type)

        print('collected %d clips of length %d in %s (dset=%s)' % (
            len(self.rgb_paths), self.S, dataset_location, dset))

    def _resample_clips(self, strides, dist_type):
        dist = get_stride_distribution(strides, dist_type=dist_type)
        dist = dist / np.max(dist)
        max_num_clips = self.stride_counts[strides[np.argmax(dist)]]
        num_clips_each_stride = [min(self.stride_counts[stride], int(dist[i] * max_num_clips)) for i, stride in enumerate(strides)]
        print('resampled_num_clips_each_stride:', num_clips_each_stride)
        resampled_idxs = []
        for i, stride in enumerate(strides):
            resampled_idxs += np.random.choice(self.stride_idxs[stride], num_clips_each_stride[i], replace=False).tolist()

        self.rgb_paths = [self.rgb_paths[i] for i in resampled_idxs]
        self.depth_paths = [self.depth_paths[i] for i in resampled_idxs]
        self.annotation_paths = [self.annotation_paths[i] for i in resampled_idxs]
        self.dynamic_mask_paths = [self.dynamic_mask_paths[i] for i in resampled_idxs]
        self.full_idxs = [self.full_idxs[i] for i in resampled_idxs]
        self.sample_stride = [self.sample_stride[i] for i in resampled_idxs]
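
    # Resampling sketch (illustrative numbers): with strides=[1, 2] and a
    # normalized dist of [0.5, 1.0], if stride 2 contributed 100 clips (the
    # argmax), stride 1 keeps at most int(0.5 * 100) = 50 clips, sampled
    # without replacement.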

    def __len__(self):
        return len(self.rgb_paths)

    def _get_views(self, index, resolution, rng):
        rgb_paths = self.rgb_paths[index]
        depth_paths = self.depth_paths[index]
        full_idx = self.full_idxs[index]
        annotations_paths = self.annotation_paths[index]
        dynamic_mask_paths = self.dynamic_mask_paths[index]

        views = []
        for i in range(2):
            impath = rgb_paths[i]
            depthpath = depth_paths[i]
            dynamic_mask_path = dynamic_mask_paths[i]

            # Sintel stores world-to-camera extrinsics; invert them to get the
            # camera-to-world pose expected downstream.
            intrinsics, extrinsics = cam_read(annotations_paths[i])
            intrinsics, extrinsics = np.array(intrinsics, dtype=np.float32), np.array(extrinsics, dtype=np.float32)
            R = extrinsics[:3, :3]
            t = extrinsics[:3, 3]
            camera_pose = np.eye(4, dtype=np.float32)
            camera_pose[:3, :3] = R.T
            camera_pose[:3, 3] = -R.T @ t

            rgb_image = imread_cv2(impath)
            depthmap = depth_read(depthpath)

            if dynamic_mask_path is not None and os.path.exists(dynamic_mask_path):
                dynamic_mask = PIL.Image.open(dynamic_mask_path).convert('L')
                dynamic_mask = ToTensor(dynamic_mask).sum(0).numpy()
                _, dynamic_mask, _ = self._crop_resize_if_necessary(
                    rgb_image, dynamic_mask, intrinsics, resolution, rng=rng, info=impath)
                dynamic_mask = dynamic_mask > 0.5
                assert not np.all(dynamic_mask), f"Dynamic mask is all True for {impath}"
            else:
                # No ground-truth label file: fall back to an all-True mask.
                dynamic_mask = np.ones((resolution[1], resolution[0]), dtype=bool)

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)

            view = dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset=self.dataset_label,
                label=rgb_paths[i].split('/')[-2],
                instance=osp.split(rgb_paths[i])[1],
                full_idx=full_idx,
            )
            if self.load_dynamic_mask:
                view['dynamic_mask'] = dynamic_mask
            views.append(view)
        return views
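

# Sanity-check sketch for the pose inversion in _get_views: composing the
# world-to-camera extrinsics [R | t] with the recovered camera-to-world pose
# [R.T | -R.T @ t] should give the identity. All values below are synthetic.
def _demo_pose_inversion():
    R = np.eye(3, dtype=np.float32)                  # synthetic rotation
    t = np.array([1.0, 2.0, 3.0], dtype=np.float32)  # synthetic translation
    w2c = np.eye(4, dtype=np.float32)
    w2c[:3, :3], w2c[:3, 3] = R, t
    c2w = np.eye(4, dtype=np.float32)
    c2w[:3, :3] = R.T
    c2w[:3, 3] = -R.T @ t
    assert np.allclose(w2c @ c2w, np.eye(4))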


if __name__ == "__main__":
    import random

    from dust3r.viz import SceneViz, auto_cam_size
    from dust3r.utils.image import rgb

    use_augs = False
    S = 2
    strides = [1]
    clip_step = 1
    quick = False

    def visualize_scene(idx):
        views = dataset[idx]
        assert len(views) == 2
        viz = SceneViz()
        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]
        cam_size = max(auto_cam_size(poses), 0.25)
        for view_idx in [0, 1]:
            pts3d = views[view_idx]['pts3d']
            valid_mask = views[view_idx]['valid_mask']
            colors = rgb(views[view_idx]['img'])
            viz.add_pointcloud(pts3d, colors, valid_mask)
            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],
                           focal=views[view_idx]['camera_intrinsics'][0, 0],
                           color=(255, 0, 0),
                           image=colors,
                           cam_size=cam_size)
        path = f"./tmp/sintel_scene_{idx}.glb"
        return viz.save_glb(path)

    dataset = SintelDUSt3R(
        use_augs=use_augs,
        S=S,
        strides=strides,
        clip_step=clip_step,
        quick=quick,
        verbose=False,
        resolution=(512, 224),
        seed=777,
        clip_step_last_skip=0,
        aug_crop=16)

    idx = random.randint(0, len(dataset) - 1)
    visualize_scene(idx)