File size: 6,897 Bytes
5ed9923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import json
import logging
import os
import sys

import cv2
import numpy as np

# Add dust3r to the sys.path
sys.path.append('src/dust3r_src')
from data.data import crop_resize_if_necessary, DUST3RSplattingDataset, DUST3RSplattingTestDataset
from src.mast3r_src.dust3r.dust3r.utils.image import imread_cv2

logger = logging.getLogger(__name__)


class ScanNetPPData:
    """Index of ScanNet++ DSLR frames (colour, depth, pose, intrinsics) for one split.

    Construction reads only text/JSON metadata; images are decoded lazily by
    :meth:`get_view`. Files read under ``root``::

        raw/splits/nvs_sem_train.txt | nvs_sem_val.txt     scene ids, one per line
        raw/data/<scene>/dslr/train_test_lists.json       frame names ("train" list)
        raw/data/<scene>/dslr/nerfstudio/transforms.json  per-frame "is_bad" flags
        processed/<scene>/dslr/nerfstudio/transforms_undistorted.json
                                                           poses + fl_x/fl_y/cx/cy
        processed/<scene>/dslr/undistorted_images/<frame>
        processed/<scene>/dslr/undistorted_depths/<frame with ".JPG" -> ".png">

    Attributes:
        sequences: scene ids kept after filtering, in split-file order.
        color_paths: scene id -> list of colour image paths (one per kept frame).
        depth_paths: scene id -> list of depth PNG paths, parallel to color_paths.
        c2ws: scene id -> list of 4x4 float32 camera-to-world matrices.
        intrinsics: scene id -> 4x4 float32 matrix with fx, fy, cx, cy filled in.
    """

    # Per stage: (split file under raw/splits, scenes dropped because they contain
    # frames with no valid depth). "val" and "test" share the validation split.
    _STAGE_SPLITS = {
        "train": ("nvs_sem_train.txt", ('303745abc7',)),
        "val": ("nvs_sem_val.txt", ('cc5237fd77',)),
        "test": ("nvs_sem_val.txt", ('cc5237fd77',)),
    }

    # diag(1, -1, -1, 1). Conjugating a pose by it (P @ c2w @ P.T) negates the y and z
    # axes of both the camera and the world frame -- presumably converting nerfstudio's
    # OpenGL-style poses to the convention used downstream (TODO confirm).
    _AXIS_FLIP = np.array([
        [1, 0, 0, 0],
        [0, -1, 0, 0],
        [0, 0, -1, 0],
        [0, 0, 0, 1]]
    ).astype(np.float32)

    def __init__(self, root, stage):
        """Read the split's scene list and index every scene's usable train frames.

        Args:
            root: dataset root directory (layout documented on the class).
            stage: "train" selects the training split; "val" and "test" both
                select the validation split.

        Raises:
            ValueError: if ``stage`` is not "train", "val" or "test".
            OSError, KeyError, json.JSONDecodeError: if a metadata file is
                missing, malformed, or lacks an expected field.
        """
        self.root = root
        self.stage = stage
        # Raw depth-PNG values are divided by this in get_view (presumably mm -> m).
        self.png_depth_scale = 1000.0

        # Per-scene data keyed by scene id; filled in below.
        self.color_paths = {}
        self.depth_paths = {}
        self.intrinsics = {}
        self.c2ws = {}

        try:
            split_name, bad_scenes = self._STAGE_SPLITS[stage]
        except KeyError:
            # Reject unknown stages up front instead of failing later on an unset split file.
            raise ValueError(
                f"Unknown stage {stage!r}; expected one of {sorted(self._STAGE_SPLITS)}"
            ) from None

        sequence_file = os.path.join(self.root, "raw", "splits", split_name)
        with open(sequence_file, "r") as f:
            # Blank lines are not scene ids; skipping them avoids opening bogus paths.
            self.sequences = [line.strip() for line in f.read().splitlines() if line.strip()]

        # Remove scenes that have frames with no valid depths
        logger.info("Removing scenes that have frames with no valid depths: %s", list(bad_scenes))
        self.sequences = [s for s in self.sequences if s not in bad_scenes]

        scenes_with_no_good_frames = set()
        for sequence in self.sequences:
            frames = self._load_sequence(sequence)
            if frames is None:
                logger.info("No good frames for sequence: %s", sequence)
                scenes_with_no_good_frames.add(sequence)
                continue
            color_paths, depth_paths, c2ws, K = frames
            self.color_paths[sequence] = color_paths
            self.depth_paths[sequence] = depth_paths
            self.c2ws[sequence] = c2ws
            self.intrinsics[sequence] = K

        # Remove scenes with no good frames
        self.sequences = [s for s in self.sequences if s not in scenes_with_no_good_frames]

    def _load_sequence(self, sequence):
        """Collect paths, poses and intrinsics for one scene's train frames not flagged is_bad.

        Returns:
            ``(color_paths, depth_paths, c2ws, K)``, or ``None`` if every train
            frame is flagged ``is_bad``.
        """
        input_raw_folder = os.path.join(self.root, 'raw', 'data', sequence)
        input_processed_folder = os.path.join(self.root, 'processed', sequence)

        # Train/test frame lists.
        with open(os.path.join(input_raw_folder, "dslr", "train_test_lists.json"), "r") as f:
            train_test_list = json.load(f)

        # Undistorted camera poses and intrinsics.
        with open(f"{input_processed_folder}/dslr/nerfstudio/transforms_undistorted.json", "r") as f:
            cams_metadata = json.load(f)

        # The raw nerfstudio transforms carry the per-frame "is_bad" flag.
        with open(f"{input_raw_folder}/dslr/nerfstudio/transforms.json", "r") as f:
            nerfstudio_transforms = json.load(f)

        # Index both frame lists by file name: transforms_undistorted.json does not
        # store frames in the same order as train_test_lists.json.
        frame_metadata_by_name = {frame["file_path"]: frame for frame in cams_metadata["frames"]}
        nerfstudio_frame_by_name = {frame["file_path"]: frame for frame in nerfstudio_transforms["frames"]}

        color_paths, depth_paths, c2ws = [], [], []
        for name in train_test_list["train"]:
            if nerfstudio_frame_by_name[name]["is_bad"]:
                continue
            color_paths.append(f"{input_processed_folder}/dslr/undistorted_images/{name}")
            depth_paths.append(
                f"{input_processed_folder}/dslr/undistorted_depths/{name.replace('.JPG', '.png')}"
            )
            c2w = np.array(frame_metadata_by_name[name]["transform_matrix"], dtype=np.float32)
            c2ws.append(self._AXIS_FLIP @ c2w @ self._AXIS_FLIP.T)

        if not color_paths:
            return None

        # One intrinsics matrix per scene, stored as a 4x4 identity with fx, fy, cx, cy set.
        K = np.eye(4, dtype=np.float32)
        K[0, 0] = cams_metadata["fl_x"]
        K[1, 1] = cams_metadata["fl_y"]
        K[0, 2] = cams_metadata["cx"]
        K[1, 2] = cams_metadata["cy"]
        return color_paths, depth_paths, c2ws, K

    def get_view(self, sequence, view_idx, resolution):
        """Load one frame from disk and package it as a view dict.

        Args:
            sequence: scene id from ``self.sequences``.
            view_idx: index into that scene's frame lists.
            resolution: target resolution, forwarded to ``crop_resize_if_necessary``.

        Returns:
            dict holding the (possibly cropped/resized) image, the depth map
            scaled by ``png_depth_scale``, the camera-to-world pose, the adjusted
            intrinsics and bookkeeping fields.
        """
        # RGB Image
        rgb_image = imread_cv2(self.color_paths[sequence][view_idx])

        # Depth PNG read without conversion, then scaled to float depth.
        depthmap = imread_cv2(self.depth_paths[sequence][view_idx], cv2.IMREAD_UNCHANGED)
        depthmap = depthmap.astype(np.float32) / self.png_depth_scale

        c2w = self.c2ws[sequence][view_idx]
        intrinsics = self.intrinsics[sequence]

        rgb_image, depthmap, intrinsics = crop_resize_if_necessary(
            rgb_image, depthmap, intrinsics, resolution
        )

        return {
            'original_img': rgb_image,
            'depthmap': depthmap,
            'camera_pose': c2w,
            'camera_intrinsics': intrinsics,
            'dataset': 'scannet++',
            'label': f"scannet++/{sequence}",
            'instance': f'{view_idx}',
            'is_metric_scale': True,
            # True where depth is missing (<= 0); the key name is presumably shared
            # with other datasets' view dicts -- confirm.
            'sky_mask': depthmap <= 0.0,
        }


def get_scannet_dataset(root, stage, resolution, num_epochs_per_epoch=1,
                        coverage_dir='./data/scannetpp/coverage'):
    """Build the ScanNet++ training dataset for ``stage``.

    Args:
        root: ScanNet++ root directory, passed to ``ScanNetPPData``.
        stage: split name ("train", "val" or "test"), passed to ``ScanNetPPData``.
        resolution: forwarded to ``DUST3RSplattingDataset``.
        num_epochs_per_epoch: forwarded to ``DUST3RSplattingDataset``.
        coverage_dir: directory holding one ``<scene>.json`` coverage file per
            scene. The default is resolved relative to the current working
            directory.

    Returns:
        A ``DUST3RSplattingDataset`` over every scene kept by ``ScanNetPPData``.
    """
    data = ScanNetPPData(root, stage)

    coverage = {}
    for sequence in data.sequences:
        # Each file is a JSON object; only the entry keyed by its own scene id is used.
        with open(os.path.join(coverage_dir, f'{sequence}.json'), 'r') as f:
            coverage[sequence] = json.load(f)[sequence]

    return DUST3RSplattingDataset(
        data,
        coverage,
        resolution,
        num_epochs_per_epoch=num_epochs_per_epoch,
    )


def get_scannet_test_dataset(root, alpha, beta, resolution, use_every_n_sample=100,
                             samples_dir='data/scannetpp'):
    """Build the ScanNet++ evaluation dataset from a precomputed sample list.

    The validation split is always used. Samples are read from
    ``<samples_dir>/test_set_{alpha}_{beta}.json`` and subsampled by keeping
    every ``use_every_n_sample``-th entry, starting with the first.

    Args:
        root: ScanNet++ root directory, passed to ``ScanNetPPData``.
        alpha, beta: values that select the sample file by name.
        resolution: forwarded to ``DUST3RSplattingTestDataset``.
        use_every_n_sample: subsampling stride; must be >= 1.
        samples_dir: directory containing the sample files. The default is
            resolved relative to the current working directory.

    Returns:
        A ``DUST3RSplattingTestDataset`` over the selected samples.

    Raises:
        ValueError: if ``use_every_n_sample`` is less than 1 (a step of 0 cannot
            be used to slice, and a negative step would silently reverse the samples).
    """
    if use_every_n_sample < 1:
        raise ValueError(f"use_every_n_sample must be >= 1, got {use_every_n_sample!r}")

    data = ScanNetPPData(root, 'val')

    samples_file = os.path.join(samples_dir, f'test_set_{alpha}_{beta}.json')
    logger.info("Loading samples from: %s", samples_file)
    with open(samples_file, 'r') as f:
        samples = json.load(f)
    samples = samples[::use_every_n_sample]

    return DUST3RSplattingTestDataset(data, samples, resolution)