geyongtao committed on
Commit
ec2cf52
·
verified ·
1 Parent(s): 2c3438e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -0
app.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import spaces
4
+ import torch
5
+ from gradio_rerun import Rerun
6
+ import rerun as rr
7
+ import rerun.blueprint as rrb
8
+ from pathlib import Path
9
+ import uuid
10
+
11
+ from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
12
+ from mini_dust3r.model import AsymmetricCroCo3DStereo
13
+ from mini_dust3r.utils.misc import (
14
+ fill_default_args,
15
+ freeze_all_params,
16
+ is_symmetrized,
17
+ interleave,
18
+ transpose_to_landscape,
19
+ )
20
+
21
+ from .head import Cat_MLP_LocalFeatures_DPT_Pts3d
22
+
23
# Device used for the model and for inference. torch device strings are
# lowercase ("cpu", "cuda", ...); an upper-case "CPU" is rejected when it is
# passed to `.to(...)` / `torch.device(...)`, so the fallback must be "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
+
25
+ # model = AsymmetricCroCo3DStereo.from_pretrained(
26
+ # "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
27
+ # ).to(DEVICE)
28
+
29
+
30
+
31
+ from .linear_head import LinearPts3d
32
+ from .dpt_head import create_dpt_head
33
+
34
def head_factory(head_type, output_mode, net, has_conf=False):
    """Build a prediction head for the decoder.

    Supported combinations:

    * ``head_type='linear'``, ``output_mode='pts3d'`` -> ``LinearPts3d``
    * ``head_type='dpt'``, ``output_mode='pts3d'`` -> ``create_dpt_head``
    * ``head_type='catmlp+dpt'``, ``output_mode='pts3d+desc<N>'`` ->
      ``Cat_MLP_LocalFeatures_DPT_Pts3d`` where ``<N>`` is parsed as the
      local feature dimension.

    Raises:
        NotImplementedError: for any other ``head_type`` / ``output_mode``
            combination.
    """
    if head_type == 'linear' and output_mode == 'pts3d':
        return LinearPts3d(net, has_conf)

    if head_type == 'dpt' and output_mode == 'pts3d':
        return create_dpt_head(net, has_conf=has_conf)

    if not (head_type == 'catmlp+dpt' and output_mode.startswith('pts3d+desc')):
        raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")

    # Everything after the 10-character 'pts3d+desc' prefix is the descriptor size.
    local_feat_dim = int(output_mode[10:])

    depth = net.dec_depth
    assert depth > 9

    feature_dim = 256
    enc_dim = net.enc_embed_dim
    dec_dim = net.dec_embed_dim

    # NOTE(review): `postprocess` is not defined or imported anywhere in the
    # visible file, so this call raises NameError as written - confirm where
    # it is supposed to come from.
    return Cat_MLP_LocalFeatures_DPT_Pts3d(
        net,
        local_feat_dim=local_feat_dim,
        has_conf=has_conf,
        num_channels=3 + has_conf,  # 3 point channels, plus one when has_conf is set
        feature_dim=feature_dim,
        last_dim=feature_dim // 2,
        # Hook the first block and the blocks at 1/2, 3/4 and full decoder depth.
        hooks_idx=[0, depth * 2 // 4, depth * 3 // 4, depth],
        dim_tokens=[enc_dim, dec_dim, dec_dim, dec_dim],
        postprocess=postprocess,
        depth_mode=net.depth_mode,
        conf_mode=net.conf_mode,
        head_type='regression',
    )
62
+
63
+
64
class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
    """``AsymmetricCroCo3DStereo`` variant whose heads are built by ``head_factory``.

    Adds three configuration attributes (``desc_mode``, ``two_confs``,
    ``desc_conf_mode``) and overrides head construction and checkpoint
    loading.
    """

    def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs):
        """Store the descriptor settings, then delegate to the base constructor.

        Args:
            desc_mode: stored as ``self.desc_mode``. Note that the default
                ``('norm')`` is the plain string ``'norm'``, not a 1-tuple.
            two_confs: stored as ``self.two_confs``.
            desc_conf_mode: stored as ``self.desc_conf_mode``; when ``None``
                it is replaced by ``conf_mode`` in ``set_downstream_head``.
            **kwargs: forwarded unchanged to the base class constructor.
        """
        # Assigned before super().__init__() - presumably because the base
        # constructor calls set_downstream_head(), which reads
        # self.desc_conf_mode (TODO confirm against the base class).
        self.desc_mode = desc_mode
        self.two_confs = two_confs
        self.desc_conf_mode = desc_conf_mode
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        """Load a model from a local checkpoint file or via the base class loader.

        If ``pretrained_model_name_or_path`` is an existing file it is loaded
        with ``load_model`` on the CPU; otherwise the call is forwarded to the
        base class ``from_pretrained`` together with ``**kw``.
        """
        # The module never imports `os`, so the original `os.path.isfile`
        # raised NameError on every call; import it locally.
        import os

        if os.path.isfile(pretrained_model_name_or_path):
            # NOTE(review): `load_model` is not defined or imported in the
            # visible file; this branch still needs that helper to be
            # brought into scope - confirm its intended source.
            return load_model(pretrained_model_name_or_path, device='cpu')
        return super().from_pretrained(pretrained_model_name_or_path, **kw)

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
        """Record the head configuration and build both downstream heads.

        Args:
            output_mode: passed to ``head_factory`` and stored on the model.
            head_type: passed to ``head_factory`` and stored on the model.
            landscape_only: forwarded as ``activate`` to
                ``transpose_to_landscape``.
            depth_mode: stored as ``self.depth_mode``.
            conf_mode: stored as ``self.conf_mode``; its truthiness decides
                whether the heads are built with ``has_conf``.
            patch_size: both entries of ``img_size`` must be multiples of it.
            img_size: two-element size checked against ``patch_size``.
            **kw: accepted and ignored.
        """
        assert img_size[0] % patch_size == 0 and img_size[
            1] % patch_size == 0, f'{img_size=} must be multiple of {patch_size=}'
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        if self.desc_conf_mode is None:
            self.desc_conf_mode = conf_mode
        # allocate heads (head_factory reads depth_mode / conf_mode from self)
        self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        # magic wrapper
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)
93
+
94
+
95
+
96
# Load the pretrained MASt3R weights once at import time and move them to
# DEVICE; `predict` reuses this module-level instance for every request.
model = AsymmetricMASt3R.from_pretrained(
    "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
).to(DEVICE)
98
+
99
+
100
def create_blueprint(image_name_list: list[str], log_path: Path) -> rrb.Blueprint:
    """Build the Rerun viewer layout for a reconstruction.

    Args:
        image_name_list: the input images; only its length is used.
        log_path: entity path used as the origin of the 3D view and as the
            prefix of each ``camera_<i>/pinhole/`` 2D view.

    Returns:
        A blueprint with collapsed panels containing a 3D view and, when
        there are at most four images, one 2D view per image stacked in a
        column to its right (3:1 width split).
    """
    num_images = len(image_name_list)
    scene_view = rrb.Spatial3DView(origin=f"{log_path}")

    # dont show 2d views if there are more than 4 images as to not clutter the view
    if num_images > 4:
        return rrb.Blueprint(
            rrb.Horizontal(scene_view),
            collapse_panels=True,
        )

    camera_views = [
        rrb.Spatial2DView(
            origin=f"{log_path}/camera_{cam_idx}/pinhole/",
            contents=["+ $origin/**"],
        )
        for cam_idx in range(num_images)
    ]
    layout = rrb.Horizontal(
        contents=[scene_view, rrb.Vertical(contents=camera_views)],
        column_shares=[3, 1],
    )
    return rrb.Blueprint(layout, collapse_panels=True)
131
+
132
+
133
@spaces.GPU
def predict(image_name_list: list[str] | str):
    """Run inference on the given image(s) and return the path of the saved recording.

    Args:
        image_name_list: a single image path or a list of image paths.

    Returns:
        POSIX path (``str``) of the ``.rrd`` file written under
        ``/tmp/gradio/``.

    Raises:
        gr.Error: if the input is neither a list nor a string.
    """
    # check if is list or string and if not raise error
    if not isinstance(image_name_list, (list, str)):
        raise gr.Error(
            f"Input must be a list of strings or a string, got: {type(image_name_list)}"
        )

    # Each request gets its own recording id, which doubles as the file name.
    run_id = str(uuid.uuid4())
    rrd_file = Path(f"/tmp/gradio/{run_id}.rrd").as_posix()
    rr.init(run_id)
    log_path = Path("world")

    # A single path is wrapped so the rest of the function only sees a list.
    image_paths = [image_name_list] if isinstance(image_name_list, str) else image_name_list

    result: OptimizedResult = inferece_dust3r(
        image_dir_or_list=image_paths,
        model=model,
        device=DEVICE,
        batch_size=1,
    )

    rr.send_blueprint(create_blueprint(image_paths, log_path))

    rr.set_time_sequence("sequence", 0)
    log_optimized_result(result, log_path)
    rr.save(rrd_file)
    return rrd_file
162
+
163
+
164
# Stretch the Gradio container to the full page width.
_CONTAINER_CSS = """.gradio-container {margin: 0 !important; min-width: 100%};"""

with gr.Blocks(css=_CONTAINER_CSS, title="Mini-DUSt3R Demo") as demo:
    # Page header: title, short description and a link to the project.
    for header_html in (
        '<h2 style="text-align: center;">Mini-DUSt3R Demo</h2>',
        '<p style="text-align: center;">Unofficial DUSt3R demo using the mini-dust3r pip package</p>',
        '<p style="text-align: center;">More info <a href="https://github.com/pablovela5620/mini-dust3r">here</a></p>',
    ):
        gr.HTML(header_html)

    # Tab 1: one uploaded image -> `predict` -> embedded Rerun viewer.
    with gr.Tab(label="Single Image"):
        with gr.Column():
            single_image = gr.Image(type="filepath", height=300)
            run_btn_single = gr.Button("Run")
            rerun_viewer_single = Rerun(height=900)
            run_btn_single.click(
                fn=predict,
                inputs=[single_image],
                outputs=[rerun_viewer_single],
            )

            # Bundled example images; results are cached lazily on first use.
            example_single_dir = Path("examples/single_image")
            example_single_files = sorted(example_single_dir.glob("*.png"))

            examples_single = gr.Examples(
                examples=example_single_files,
                inputs=[single_image],
                outputs=[rerun_viewer_single],
                fn=predict,
                cache_examples="lazy",
            )

    # Tab 2: several uploaded files -> `predict` -> embedded Rerun viewer.
    with gr.Tab(label="Multi Image"):
        with gr.Column():
            multi_files = gr.File(file_count="multiple")
            run_btn_multi = gr.Button("Run")
            rerun_viewer_multi = Rerun(height=900)
            run_btn_multi.click(
                fn=predict,
                inputs=[multi_files],
                outputs=[rerun_viewer_multi],
            )


demo.launch()