File size: 10,925 Bytes
aab1f32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import sys
import gradio as gr
import os
import shutil
import json
import argparse
from PIL import Image
import subprocess
from sparseags.dust3r_utils import infer_dust3r
from run import main
import functools

sys.path[0] = sys.path[0] + '/dust3r'
from dust3r.model import AsymmetricCroCo3DStereo


def info_fn():
	"""Show a Gradio info toast once the preprocessing step has succeeded."""
	done_message = "Data preprocessing done!"
	gr.Info(done_message)


def get_select_index(evt: gr.SelectData):
	"""Switch the demo to the example clicked in the examples gallery.

	Updates the global ``args`` (number of views and category) for the chosen
	example and returns its preprocessed image paths twice: once for the file
	input widget and once for the preprocessed-image gallery.
	"""
	# Same order as the example folders loaded in the __main__ block.
	categories = ['toy', 'butter', 'robot', 'jordan', 'eagle']
	picked = evt.index
	example_paths = examples_full[picked][0]

	args.num_views = len(example_paths)
	args.category = categories[picked]

	return example_paths, example_paths


def check_img_input(control_image):
	"""Stop the Gradio event chain when no input image was uploaded or selected.

	Used as the first step of the reconstruction buttons. The chained
	``.success(...)`` handler only runs if this returns without raising.

	Args:
		control_image: value of the ``gr.File`` widget, presumably a list of
			file paths, or ``None`` when nothing was uploaded or selected.

	Raises:
		gr.Error: if there is no input image (``None`` or an empty list).
	"""
	# `not` rejects an empty selection as well as None; `is None` let [] through.
	if not control_image:
		raise gr.Error("Please select or upload an input image")


def preprocess(args, dust3r_model, image_block: list):
	"""Prepare user-uploaded images for reconstruction.

	Clears the ``data/demo/custom`` and ``output/demo/custom`` folders, saves
	each uploaded image into ``data/demo/custom``, runs ``process.py`` on it to
	crop and resize it, predicts initial camera poses with DUSt3R and writes
	them to ``data/demo/custom/cameras.json``.

	Args:
		args: parsed command-line namespace; ``num_views`` and ``category`` are
			overwritten so that the next reconstruction uses the custom data.
		dust3r_model: DUSt3R model handed to ``infer_dust3r``.
		image_block: file paths coming from the ``gr.File`` upload widget.

	Returns:
		list[str]: paths of the ``*_rgba.png`` images assumed to be written by
		``process.py`` into ``data/demo/custom/processed``; they are shown in
		the preview gallery.

	Raises:
		gr.Error: if no image was uploaded.
		subprocess.CalledProcessError: if ``process.py`` fails for an image.
	"""
	# Without this guard, an empty upload crashes with a TypeError when iterating None.
	if not image_block:
		raise gr.Error("Please select or upload an input image")

	data_dir = 'data/demo/custom'
	source_dir = os.path.join(data_dir, 'source')
	processed_dir = os.path.join(data_dir, 'processed')

	# Start from a clean slate so images/poses from a previous upload are not mixed in.
	for stale_dir in (data_dir, 'output/demo/custom'):
		if os.path.exists(stale_dir):
			shutil.rmtree(stale_dir)

	os.makedirs(source_dir, exist_ok=True)
	os.makedirs(processed_dir, exist_ok=True)

	file_names = []
	processed_image_block = []

	for file_path in image_block:
		file_name = os.path.basename(file_path)
		# NOTE(review): stem uses split('.')[0] (not os.path.splitext) so the names match
		# the outputs this code expects from process.py; confirm before changing.
		out_base = file_name.split('.')[0]
		saved_path = os.path.join(data_dir, file_name)

		# `with` closes the underlying file handle, which the bare Image.open leaked.
		with Image.open(file_path) as img_pil:
			try:
				img_pil.save(saved_path)
			except OSError:
				# PIL raises OSError when the mode is unsupported by the target format
				# (e.g. RGBA -> JPEG); fall back to RGB.
				img_pil.convert('RGB').save(saved_path)

		file_names.append(os.path.join(source_dir, out_base + '.png'))
		processed_image_block.append(os.path.join(processed_dir, out_base + '_rgba.png'))

		# Crop and resize the image. The argument list (no shell) keeps user-controlled
		# file names from being interpreted by the shell; check=True surfaces failures
		# instead of silently continuing with missing files.
		cmd = [sys.executable, 'process.py', saved_path]
		print(' '.join(cmd))
		subprocess.run(cmd, check=True)

	# Predict initial camera poses from DUSt3R.
	camera_data = infer_dust3r(dust3r_model, file_names)
	with open(os.path.join(data_dir, 'cameras.json'), "w") as f:
		json.dump(camera_data, f)

	args.num_views = len(file_names)
	args.category = "custom"

	return processed_image_block


def run_single_reconstruction(image_block: list):
	"""Run a single reconstruction pass with the outlier-handling loop turned off.

	``image_block`` is not read here; it is accepted because the Gradio event
	passes the file widget as input. The function configures the global ``args``,
	calls ``main`` and returns the path of the round-0 mesh.
	"""
	args.enable_loop = False
	main(args)

	# Read the category after main() in case the run updated it.
	category = args.category
	return 'output/demo/{0}/round_0/{0}.glb'.format(category)


def run_full_reconstruction(image_block: list):
	"""Run the full method, with the loop-based outlier removal and correction.

	``image_block`` is not read here. After ``main`` finishes, the returned mesh
	path depends on which camera file the run left in the output folder:
	``cameras_final_recovered.json`` is checked first, then
	``cameras_final_init.json``; if neither exists, the round-1 mesh is used.
	"""
	args.enable_loop = True
	main(args)

	out_dir = f'output/demo/{args.category}'
	# (marker file, sub-folder holding the corresponding mesh), in priority order
	stages = (
		('cameras_final_recovered.json', 'check_recovered_poses'),
		('cameras_final_init.json', 'reconsider_init_poses'),
	)
	for marker, mesh_subdir in stages:
		if os.path.exists(f'{out_dir}/{marker}'):
			return f'{out_dir}/{mesh_subdir}/{args.category}.glb'

	return f'{out_dir}/round_1/{args.category}.glb'


if __name__ == "__main__":
	def _str2bool(value):
		"""Parse a CLI boolean such as 'true'/'false'/'1'/'0'.

		Without a converter, argparse keeps the raw string, and any non-empty
		string (including 'False') is truthy.
		"""
		return str(value).strip().lower() in ('1', 'true', 'yes', 'y', 'on')

	parser = argparse.ArgumentParser()
	parser.add_argument('--output', default='output/demo', type=str, help='Directory where obj files will be saved')
	parser.add_argument('--category', default='jordan', type=str, help='Name of the demo example to reconstruct (a sub-folder of data/demo, e.g. jordan or custom)')
	parser.add_argument('--num_pts', default=25000, type=int, help='Number of points at initialization')
	parser.add_argument('--num_views', default=8, type=int, help='Number of input images')
	parser.add_argument('--mesh_format', default='glb', type=str, help='Format of output mesh')
	# The demo buttons overwrite args.enable_loop before every run; this flag only sets the initial value.
	parser.add_argument('--enable_loop', default=True, type=_str2bool, help='Enable the loop-based strategy to detect and correct outliers')
	parser.add_argument('--config', default='navi.yaml', type=str, help='Path to config file')
	args = parser.parse_args()

	_TITLE = '''Sparse-view Pose Estimation and Reconstruction via Analysis by Generative Synthesis'''

	# NOTE(review): the second badge label (2309.16653) and the GitHub stars badge appear to
	# point at DreamGaussian rather than SparseAGS; confirm the intended paper id / repo.
	_DESCRIPTION = '''
	<div>
	<a style="display:inline-block" href="https://qitaozhao.github.io/SparseAGS"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
	<a style="display:inline-block; margin-left: .5em" href="https://openreview.net/pdf?id=wgpmDyJgsg"><img src="https://img.shields.io/badge/2309.16653-f9f7f7?logo="></a>
	<a style="display:inline-block; margin-left: .5em" href='https://github.com/dreamgaussian/dreamgaussian'><img src='https://img.shields.io/github/stars/dreamgaussian/dreamgaussian?style=social'/></a>
	</div>
	Given a set of unposed input images, SparseAGS jointly infers the corresponding camera poses and underlying 3D, allowing high-fidelity 3D inference in the wild. 
	'''
	# Implicit string concatenation avoids the stray tabs that a backslash continuation
	# inside the literal would embed in the text.
	_IMG_USER_GUIDE = (
		"Once you see the preprocessed images, you can click **Run Single 3D Reconstruction**. "
		"If the reconstructed 3D looks bad, you can try to click **Outlier Removal & Correction** "
		"to run the full method to deal with outlier camera poses."
	)

	# Load the images in the 'data/demo/<example>/processed' folders as examples.
	# The order must match the category list used by get_select_index.
	example_categories = ['toy', 'butter', 'robot', 'jordan', 'eagle']
	examples_full = []
	for example in example_categories:
		example_folder = os.path.join(os.path.dirname(__file__), 'data/demo', example, 'processed')
		examples = [os.path.join(example_folder, x) for x in sorted(os.listdir(example_folder)) if x.endswith('.png')]
		examples_full.append([examples])

	dust3r_model = AsymmetricCroCo3DStereo.from_pretrained('naver/DUSt3R_ViTLarge_BaseDecoder_224_linear').to('cuda')
	print("Loaded DUSt3R model!")

	# Bind the CLI args and the loaded model so the Gradio callback only receives the file list.
	preprocess = functools.partial(preprocess, args, dust3r_model)

	# Compose demo layout & data flow
	with gr.Blocks(title=_TITLE, theme=gr.themes.Soft()) as demo:
		with gr.Row():
			with gr.Column(scale=1):
				gr.Markdown('# ' + _TITLE)
		gr.Markdown(_DESCRIPTION)

		# Image-to-3D
		with gr.Row(variant='panel'):
			with gr.Column(scale=5):
				image_block = gr.File(file_count="multiple")

				preprocess_btn = gr.Button("Preprocess Images")

				gr.Markdown(
					"You have two options to run our model! (1) Upload your own images in the block above and then "
					"click **Preprocess Images** to initialize camera poses using DUSt3R; (2) Choose one of the "
					"preprocessed examples below (no need to click **Preprocess Images**).")

				gallery = gr.Gallery(
					value=[example[0][0] for example in examples_full], label="Examples", show_label=True, elem_id="gallery",
					columns=[5], rows=[1], object_fit="contain", height="256", preview=None, allow_preview=None)

				# Distinct elem_id: HTML ids must be unique and "gallery" is already used above.
				preprocessed_data = gr.Gallery(
					label="Preprocessed images", show_label=True, elem_id="preprocessed_gallery",
					columns=[4], rows=[2], object_fit="contain", height="256", preview=None, allow_preview=None)

				with gr.Row(variant='panel'):
					run_single_btn = gr.Button("Run Single 3D Reconstruction")
					outlier_detect_btn = gr.Button("Outlier Removal & Correction")
				img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)

			with gr.Column(scale=5):
				obj_single_recon = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model (Single Reconstruction)")
				obj_outlier_detect = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model (Full Method, w/ Outlier Removal & Correction)")

			# Select a preprocessed example
			gallery.select(get_select_index, None, outputs=[image_block, preprocessed_data])

			# Upload your own images and run preprocessing
			preprocess_btn.click(preprocess, inputs=[image_block], outputs=[preprocessed_data], queue=False, show_progress='full').success(info_fn, None, None)

			# Do single 3D reconstruction (only if an input is present)
			run_single_btn.click(check_img_input, inputs=[image_block], queue=False).success(
				run_single_reconstruction, inputs=[image_block], outputs=[obj_single_recon])

			# Do loop-based outlier removal & correction (only if an input is present)
			outlier_detect_btn.click(check_img_input, inputs=[image_block], queue=False).success(
				run_full_reconstruction, inputs=[image_block], outputs=[obj_outlier_detect])

	demo.queue().launch(share=True)