File size: 11,157 Bytes
26ce2a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
import argparse
import copy
import json
import os
import random
import shutil
import sys

import numpy as np
import torch

from sparseags.mesh_utils.mesh_renderer import Renderer
from sparseags.render_utils.util import render_and_compare, align_to_mesh
from sparseags.visual_utils import vis_output


def seed_everything(seed):
    """Seed every RNG source (hashing, Python, NumPy, PyTorch CPU/GPU) for reproducibility.

    Args:
        seed (int): value applied to all random number generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)  # Python's builtin RNG was previously left unseeded
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune convolution kernels, which makes
    # results vary between runs; it must be off when determinism is requested.
    torch.backends.cudnn.benchmark = False


def do_reconstruction(cur_dir, out_dir, args, force_all_views=False):
    """Run the two-stage reconstruction pipeline and export a turntable video.

    Each step is skipped when its output file already exists, so the function
    can safely be re-invoked on a partially completed round directory.

    Args:
        cur_dir: working directory for this round (holds cameras.json and outputs).
        out_dir: top-level output directory (unused here, kept for a uniform call signature).
        args: parsed CLI namespace (category, config, num_pts, mesh_format).
        force_all_views: when True, both stages use every view regardless of flags.
    """
    stage1_mesh = os.path.join(cur_dir, f'{args.category}_mesh.{args.mesh_format}')
    if not os.path.exists(stage1_mesh):
        stage1_cmd = ' '.join([
            'python sparseags/main_stage1.py',
            f'--config configs/{args.config}',
            f'camera_path={cur_dir}/cameras.json',
            'outdir=./',
            f'save_path={cur_dir}/{args.category}',
            'opt_cam=1 sh_degree=3',
            f'num_pts={args.num_pts}',
            f'all_views={force_all_views}',
            f'mesh_format={args.mesh_format}',
        ])
        os.system(stage1_cmd)

    stage2_mesh = os.path.join(cur_dir, f'{args.category}.{args.mesh_format}')
    if not os.path.exists(stage2_mesh):
        stage2_cmd = ' '.join([
            'python sparseags/main_stage2.py',
            f'--config configs/{args.config}',
            f'camera_path={cur_dir}/cameras_updated.json',
            'outdir=./',
            f'save_path={cur_dir}/{args.category}',
            f'all_views={force_all_views}',
            f'mesh_format={args.mesh_format}',
        ])
        os.system(stage2_cmd)

    # Render the final stage-2 mesh into an orbit video via kiui.
    video_path = os.path.join(cur_dir, f'{args.category}.mp4')
    if not os.path.exists(video_path):
        render_cmd = ' '.join([
            f'python -m kiui.render {stage2_mesh}',
            f'--save_video {cur_dir}/{args.category}.mp4',
            '--wogui',
            '--elevation -30',
        ])
        os.system(render_cmd)


def read_updated_cameras(cur_dir):
    """Load the optimized cameras written by the reconstruction stage.

    Returns the parsed contents of ``<cur_dir>/cameras_updated.json``.
    """
    with open(os.path.join(cur_dir, 'cameras_updated.json'), 'r') as f:
        return json.load(f)


def save_updated_cameras(camera_data, cur_dir):
    """Write ``camera_data`` to ``<cur_dir>/cameras.json`` (input for the next round)."""
    target = os.path.join(cur_dir, 'cameras.json')
    with open(target, 'w') as f:
        json.dump(camera_data, f, indent=4)


def save_metrics(metrics_track, out_dir):
    """Dump the per-round LPIPS bookkeeping to ``<out_dir>/metrics.json``."""
    target = os.path.join(out_dir, 'metrics.json')
    with open(target, 'w') as f:
        json.dump(metrics_track, f, indent=4)


def main(args):
    """Outlier-aware sparse-view reconstruction driver.

    Loop phase: reconstruct the object, score each input view with LPIPS,
    and each round flag the worst-scoring inlier view as an outlier; stop
    when quality no longer improves or the round budget is exhausted.
    Recovery phase: re-pose the flagged outlier views via render-and-compare
    against the best mesh, then pick between the optimized initial poses and
    the recovered poses based on which reconstruction scores better.
    """
    CATEGORY = args.category
    NUM_VIEWS = args.num_views
    src_dir = os.path.join('data/demo', CATEGORY)
    out_dir = os.path.join(args.output, CATEGORY)
    os.makedirs(out_dir, exist_ok=True)
    source_camera_path = os.path.join(src_dir, "cameras.json")

    print(f'======== processing {src_dir} ========')
    if not os.path.exists(source_camera_path):
        print(f'{source_camera_path} is missing!')
        sys.exit()

    stop = 0  # the flag for stopping the loop
    stop_with_lower_quality = 0  # the flag for stopping as the reconstruction quality drops
    stop_with_full_iters = 0  # the flag for stopping as we reach the maximum loop number

    cnt = 0  # current round index; round_{cnt} subdirectories are created per iteration
    # pretty empirical choice of maximum iteration number
    if NUM_VIEWS <= 5:
        MAX_CNT = 1
    elif NUM_VIEWS < 8:
        MAX_CNT = NUM_VIEWS - 4
    else:
        MAX_CNT = 4

    # With the loop disabled we run a single reconstruction round and exit.
    if not args.enable_loop:
        MAX_CNT = 0

    THRESHOLD_LPIPS = 0.05  # minimum LPIPS improvement required to keep looping
    metrics_track = {}  # round index -> [mean_lpips, mean_lpips_without_worst_view]
    cur_dir = os.path.join(out_dir, f'round_{cnt}')
    os.makedirs(cur_dir, exist_ok=True)
    shutil.copy2(source_camera_path, os.path.join(cur_dir, 'cameras.json'))

    while not stop:
        metrics_track[cnt] = []

        do_reconstruction(cur_dir, out_dir, args)
        camera_data = read_updated_cameras(cur_dir)

        # Per-view LPIPS/MSE between renders of the reconstruction and the inputs.
        lpips_losses, mse_losses = vis_output(
            camera_data, 
            mesh_path=os.path.join(cur_dir, f'{CATEGORY}.{args.mesh_format}'), 
            save_path=os.path.join(cur_dir, 'vis.png'), 
            num_views=NUM_VIEWS
        )
        # flag == 1 marks a view currently treated as an inlier.
        flags = np.array([int(v["flag"]) for k, v in camera_data.items()])

        # Means over inlier views only (outliers are masked out by flags).
        mean_lpips = np.sum(lpips_losses * flags) / flags.sum()
        mean_mse = np.sum(mse_losses * flags) / flags.sum()
        metrics_track[cnt].append(mean_lpips)

        # 0: go to the next iter
        # 1-MAX_CNT: stop if no improvement compared to the last iter
        if cnt != 0 and args.enable_loop:
            # mean_lpips_wo_max was computed at the end of the previous
            # iteration: last round's mean LPIPS with its worst view excluded.
            last_lpips = mean_lpips_wo_max
            diff_lpips = abs(last_lpips - mean_lpips)
            if mean_lpips > last_lpips or diff_lpips < THRESHOLD_LPIPS: 
                stop_with_lower_quality = 1

            if cnt >= MAX_CNT:
                stop_with_full_iters = 1

        if stop_with_full_iters or stop_with_lower_quality:
            stop = 1
            # If quality dropped, roll back to the previous round's cameras.
            cnt_to_stop = cnt - 1 if stop_with_lower_quality else cnt
            camera_path_to_be_copied = os.path.join(out_dir, f'round_{cnt_to_stop}', 'cameras_updated.json')
            shutil.copy2(camera_path_to_be_copied, os.path.join(out_dir, 'cameras_outlier_removal.json'))
            save_metrics(metrics_track, out_dir)

        elif not args.enable_loop:
            # Single-round mode: no outlier search, no recovery phase below.
            stop = 1
            cnt_to_stop = 0
            save_metrics(metrics_track, out_dir)

        # should not stop, go to the next round
        else:
            # Find the inlier view with the highest (worst) LPIPS.
            max_lpips_value = -float('inf')
            max_index = -1

            for i in range(NUM_VIEWS):
                if flags[i] == 1 and lpips_losses[i] > max_lpips_value:
                    max_lpips_value = lpips_losses[i]
                    max_index = i

            # Tentatively demote that view to an outlier for the next round.
            flags[max_index] = 0
            mean_lpips_wo_max = np.sum(lpips_losses * flags) / flags.sum()
            metrics_track[cnt].append(mean_lpips_wo_max)

            assert camera_data[list(camera_data.keys())[max_index]]["flag"] == 1
            camera_data[list(camera_data.keys())[max_index]]["flag"] = 0

            cnt += 1
            cur_dir = os.path.join(out_dir, f'round_{cnt}')
            os.makedirs(cur_dir, exist_ok=True)

            # copy-paste the camera poses for next iter
            save_updated_cameras(camera_data, cur_dir)

    if cnt_to_stop == 0:
        # Round 0 was best (or looping disabled): no outliers to recover.
        pass

    else:
        """If we identified outliers, do render-and-compare to correct them"""
        camera_path_outlier_removal = os.path.join(out_dir, 'cameras_outlier_removal.json')
        assert os.path.exists(camera_path_outlier_removal)
        with open(camera_path_outlier_removal, 'r') as f:
            camera_data_outlier_removal = json.load(f)

        # Re-estimate outlier poses by rendering the best mesh and matching views.
        camera_path_render_and_compare = os.path.join(out_dir, 'cameras_render_and_compare.json')
        if not os.path.exists(camera_path_render_and_compare):
            mesh_path = os.path.join(out_dir, f'round_{cnt_to_stop}', f'{CATEGORY}.{args.mesh_format}')
            camera_data_render_and_compare = render_and_compare(copy.deepcopy(camera_data_outlier_removal), mesh_path, out_dir, num_views=NUM_VIEWS)

            with open(camera_path_render_and_compare, 'w') as f:
                json.dump(camera_data_render_and_compare, f, indent=4)

        # (1) check the recovered cameras from render-and-compare: we do reconstruction to align everything together
        cur_dir = os.path.join(out_dir, f'check_recovered_poses')
        os.makedirs(cur_dir, exist_ok=True)
        shutil.copy2(camera_path_render_and_compare, os.path.join(cur_dir, 'cameras.json'))
        do_reconstruction(cur_dir, out_dir, args, force_all_views=True)

        camera_data = read_updated_cameras(cur_dir)

        lpips_losses, mse_losses = vis_output(
            camera_data, 
            mesh_path=os.path.join(cur_dir, f'{CATEGORY}.{args.mesh_format}'), 
            save_path=os.path.join(cur_dir, 'vis.png'), 
            num_views=args.num_views
        )

        # (2) re-consider initial cameras: we fix the updated inliers while aligning the outliers to the 3D from inliers
        cur_dir = os.path.join(out_dir, f'reconsider_init_poses')
        os.makedirs(cur_dir, exist_ok=True)
        mesh_path = os.path.join(out_dir, f'round_{cnt_to_stop}', f'{CATEGORY}.{args.mesh_format}')
        camera_data_aligned = align_to_mesh(camera_data_outlier_removal, mesh_path, cur_dir, num_views=NUM_VIEWS)
        save_updated_cameras(camera_data_aligned, cur_dir)

        # we do reconstruction to align everything together
        do_reconstruction(cur_dir, out_dir, args, force_all_views=True)
        camera_data_init = read_updated_cameras(cur_dir)

        lpips_losses_init, mse_losses_init = vis_output(
            camera_data_init, 
            mesh_path=os.path.join(cur_dir, f'{CATEGORY}.{args.mesh_format}'), 
            save_path=os.path.join(cur_dir, 'vis.png'), 
            num_views=NUM_VIEWS
        )

        # Decide which camera set to keep: compare mean LPIPS of the two
        # candidate reconstructions, then (if recovered looks better overall)
        # require every former outlier view to individually improve on both
        # LPIPS and MSE before accepting the recovered poses.
        flags_sum = np.array([int(v["flag"]) for k, v in camera_data_outlier_removal.items()]).sum()
        cnt_valid_cameras = 0
        keep_init_poses = False
        if lpips_losses.mean() > lpips_losses_init.mean():
            keep_init_poses = True  # Keep optimized initial poses

        else:
            for idx, (k, v) in enumerate(camera_data_outlier_removal.items()):
                if int(v["flag"]) == 1:
                    continue

                if lpips_losses[idx] < lpips_losses_init[idx] and mse_losses[idx] < mse_losses_init[idx]:
                    cnt_valid_cameras += 1
                    # camera_data_init_updated[k] = camera_data[k] # replace the initial poses with recovered poses

            # Accept recovered poses only if ALL outlier views improved.
            if cnt_valid_cameras + flags_sum == NUM_VIEWS:
                keep_init_poses = False
            else:
                keep_init_poses = True

        # NOTE(review): the final file is written as cameras_final_init.json or
        # cameras_final_recovered.json — cameras_final.json itself is never
        # created; confirm downstream consumers expect the suffixed names.
        output_path = os.path.join(out_dir, 'cameras_final.json') 
        if keep_init_poses:
            print("Keep the (optimized) initial camera poses.")
            with open(output_path.replace(".json", "_init.json"), 'w') as f:
                json.dump(camera_data_init, f, indent=4)
        else:
            print("Replace the initial cameras with the recovered ones!")
            with open(output_path.replace(".json", "_recovered.json"), 'w') as f:
                json.dump(camera_data, f, indent=4)


if __name__ == "__main__":

    # Fix all RNG seeds before anything else so runs are reproducible.
    seed_everything(0)
    parser = argparse.ArgumentParser()
    parser.add_argument('--output', default='output/demo', type=str, help='Directory where obj files will be saved')
    # Fixed: help text was a copy-paste of --output's; category is the
    # input scene name, i.e. the subdirectory of data/demo to process.
    parser.add_argument('--category', default='jordan', type=str, help='Name of the object category (subdirectory of data/demo)')
    parser.add_argument('--num_pts', default=25000, type=int, help='Number of points at initialization')
    parser.add_argument('--num_views', default=8, type=int, help='Number of input images')
    parser.add_argument('--mesh_format', default='obj', type=str, help='Format of output mesh')
    parser.add_argument('--enable_loop', action='store_true', help='Enable the loop-based strategy to detect and correct outliers')
    parser.add_argument('--config', default='navi.yaml', type=str, help='Path to config file')
    args = parser.parse_args()

    main(args)