add examples models
This commit adds three new example inputs (color image, person mask, object mask, plus body keypoints for 002446), an object-category dropdown wired to fine-tuned checkpoints reloaded from the xiexh20/HDM-models hub repo, and an output log box in the demo UI.

Files changed:
- app.py (+51 −26)
- demo.py (+9 −2)
- examples/002446/k1.color.jpg (added, binary)
- examples/002446/k1.color.json (+79 −0)
- examples/002446/k1.obj_rend_mask.png (added, binary)
- examples/002446/k1.person_mask.png (added, binary)
- examples/053431/k1.color.jpg (added, binary)
- examples/053431/k1.obj_rend_mask.png (added, binary)
- examples/053431/k1.person_mask.png (added, binary)
- examples/158107/k1.color.jpg (added, binary)
- examples/158107/k1.obj_rend_mask.png (added, binary)
- examples/158107/k1.person_mask.png (added, binary)
app.py (CHANGED):

```diff
@@ -22,6 +22,7 @@ import imageio
 import gradio as gr
 import plotly.graph_objs as go
 import training_utils
+import traceback
 
 from configs.structured import ProjectConfig
 from demo import DemoRunner
@@ -91,7 +92,7 @@ def plot_points(colors, coords):
     return fig
 
 
-def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, std_coverage, input_seed):
+def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, std_coverage, input_seed, input_cls):
     """
     given user input, run inference
     :param runner:
@@ -101,26 +102,38 @@ def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, s
     :param mask_obj: (h, w, 3), np array
     :param std_coverage: float value, used to estimate camera translation
     :param input_seed: random seed
+    :param input_cls: the object category of the input image
     :return: path to the 3D reconstruction, and an interactive 3D figure for visualizing the point cloud
     """
-    … (2 removed lines, not expanded in the source diff view)
+    log = ""
+    try:
+        # Set random seed
+        training_utils.set_seed(int(input_seed))
 
-    … (1 removed line, not expanded in the source diff view)
+        data = DemoDataset([], (cfg.dataset.image_size, cfg.dataset.image_size),
                            std_coverage)
-    … (13 removed lines, not expanded in the source diff view)
+        batch = data.image2batch(rgb, mask_hum, mask_obj)
+
+        if input_cls != 'general':
+            log += f"Reloading fine-tuned checkpoint of category {input_cls}\n"
+            runner.reload_checkpoint(input_cls)
+
+        out_stage1, out_stage2 = runner.forward_batch(batch, cfg)
+        points = out_stage2.points_packed().cpu().numpy()
+        colors = out_stage2.features_packed().cpu().numpy()
+        fig = plot_points(colors, points)
+        # save tmp point cloud
+        outdir = './results'
+        os.makedirs(outdir, exist_ok=True)
+        trimesh.PointCloud(points, colors).export(outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage2_{input_cls}.ply")
+        trimesh.PointCloud(out_stage1.points_packed().cpu().numpy(),
+                           out_stage1.features_packed().cpu().numpy()).export(
+            outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage1_{input_cls}.ply")
+        log += 'Successfully reconstructed the image.'
+    except Exception as e:
+        log = traceback.format_exc()
+
+    return fig, outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage2_{input_cls}.ply", log
 
 
 @hydra.main(config_path='configs', config_name='configs', version_base='1.1')
@@ -129,6 +142,8 @@ def main(cfg: ProjectConfig):
     runner = DemoRunner(cfg)
 
     # runner = None # without model initialization, it shows one line of thumbnail
+    # TODO: add instructions on how to get masks
+    # TODO: add instructions on how to use the demo, input output, example outputs etc.
 
     # Setup interface
     demo = gr.Blocks(title="HDM Interaction Reconstruction Demo")
@@ -147,33 +162,43 @@ def main(cfg: ProjectConfig):
            # TODO: add hint for this value here
            input_std = gr.Number(label='Gaussian std coverage', value=3.5)
            input_seed = gr.Number(label='Random seed', value=42)
+           # TODO: add description outside label
+           input_cls = gr.Dropdown(label='Object category (we have fine tuned the model for specific categories, '
+                                         'reconstructing with these model should lead to better result '
+                                         'for specific categories.) ',
+                                   choices=['general', 'backpack', 'ball', 'bottle', 'box',
+                                            'chair', 'skateboard', 'suitcase', 'table'],
+                                   value='general')
        # Output visualization
        with gr.Row():
            pc_plot = gr.Plot(label="Reconstructed point cloud")
            out_pc_download = gr.File(label="3D reconstruction for download")  # this allows downloading
+       with gr.Row():
+           out_log = gr.TextArea(label='Output log')
+
 
        gr.HTML("""<br/>""")
        # Control
        with gr.Row():
            button_recon = gr.Button("Start Reconstruction", interactive=True, variant='secondary')
        button_recon.click(fn=partial(inference, runner, cfg),
-                          inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed],
-                          outputs=[pc_plot, out_pc_download])
+                          inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed, input_cls],
+                          outputs=[pc_plot, out_pc_download, out_log])
        gr.HTML("""<br/>""")
        # Example input
        example_dir = cfg.run.code_dir_abs+"/examples"
        rgb, ps, obj = 'k1.color.jpg', 'k1.person_mask.png', 'k1.obj_rend_mask.png'
        example_images = gr.Examples([
-           [f"{example_dir}/017450/{rgb}", f"{example_dir}/017450/{ps}", f"{example_dir}/017450/{obj}", 3.0, 42],
-           [f"{example_dir}/002446/{rgb}", f"{example_dir}/002446/{ps}", f"{example_dir}/002446/{obj}", 3.0, 42],
-           [f"{example_dir}/053431/{rgb}", f"{example_dir}/053431/{ps}", f"{example_dir}/053431/{obj}", 3.8, 42],
-           [f"{example_dir}/158107/{rgb}", f"{example_dir}/158107/{ps}", f"{example_dir}/158107/{obj}", 3.8, 42],
+           [f"{example_dir}/017450/{rgb}", f"{example_dir}/017450/{ps}", f"{example_dir}/017450/{obj}", 3.0, 42, 'skateboard'],
+           [f"{example_dir}/002446/{rgb}", f"{example_dir}/002446/{ps}", f"{example_dir}/002446/{obj}", 3.0, 42, 'ball'],
+           [f"{example_dir}/053431/{rgb}", f"{example_dir}/053431/{ps}", f"{example_dir}/053431/{obj}", 3.8, 42, 'chair'],
+           [f"{example_dir}/158107/{rgb}", f"{example_dir}/158107/{ps}", f"{example_dir}/158107/{obj}", 3.8, 42, 'chair'],
 
-       ], inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed],)
+       ], inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed, input_cls],)
 
    # demo.launch(share=True)
    # Enabling queue for runtime>60s, see: https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062
-   demo.queue(… (truncated in the source view)
+   demo.queue().launch(share=True)
 
 if __name__ == '__main__':
-    main()
+    main()
```
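One caveat in the new `inference` body: the final `return` references `fig` and `outdir`, which are only assigned inside the `try` block, so a failure early in the pipeline (e.g. a malformed mask) would raise `UnboundLocalError` at the `return` instead of surfacing the captured traceback in the new log box. A minimal sketch of the same try/log pattern with the outputs pre-bound (the helper name and placeholder body are hypothetical, not part of this commit):

```python
import os
import traceback

def inference_sketch(std_coverage, input_seed, input_cls):
    """Hypothetical skeleton of the committed pattern: returns (figure, ply_path, log)."""
    log, fig = "", None  # pre-bind outputs so the except branch can still return them
    out_file = os.path.join('./results',
                            f"pred_std{std_coverage}_seed{input_seed}_stage2_{input_cls}.ply")
    try:
        # ... seeding, batching, two-stage reconstruction, PLY export go here ...
        log += 'Successfully reconstructed the image.'
    except Exception:
        log = traceback.format_exc()  # shown in the gr.TextArea output log
    return fig, out_file, log
```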
demo.py (CHANGED):

```diff
@@ -65,8 +65,8 @@ class DemoRunner:
         self.rend_size = cfg.dataset.image_size
         self.device = 'cuda'
 
-    def load_checkpoint(self, ckpt_file1, model_stage1):
-        checkpoint = torch.load(ckpt_file1, map_location=… (truncated in the source view)
+    def load_checkpoint(self, ckpt_file1, model_stage1, device='cpu'):
+        checkpoint = torch.load(ckpt_file1, map_location=device)
         state_dict, key = checkpoint['model'], 'model'
         if any(k.startswith('module.') for k in state_dict.keys()):
             state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
@@ -78,6 +78,13 @@ class DemoRunner:
         if len(unexpected_keys):
             print(f' - Unexpected_keys: {unexpected_keys}')
 
+    def reload_checkpoint(self, cat_name):
+        "load checkpoint of models fine tuned on specific categories"
+        ckpt_file1 = hf_hub_download("xiexh20/HDM-models", f'{self.cfg.run.stage1_name}-{cat_name}.pth')
+        self.load_checkpoint(ckpt_file1, self.model_stage1, device=self.device)
+        ckpt_file2 = hf_hub_download("xiexh20/HDM-models", f'{self.cfg.run.stage2_name}-{cat_name}.pth')
+        self.load_checkpoint(ckpt_file2, self.model_stage2, device=self.device)
+
     @torch.no_grad()
     def run(self):
         "simply run the demo on given images, and save the results"
```
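The new `reload_checkpoint` resolves checkpoint files through `huggingface_hub`'s `hf_hub_download`, which caches downloads locally, so switching categories in the demo fetches each fine-tuned checkpoint at most once. A stand-alone sketch of the same lookup (the concrete stage names below are assumptions; in the code they come from `cfg.run.stage1_name` / `cfg.run.stage2_name`):

```python
from huggingface_hub import hf_hub_download

# Checkpoints are published as '<stage_name>-<category>.pth' in the model repo.
stage1_name, stage2_name = "hdm-stage1", "hdm-stage2"  # assumed values, not from the commit
category = "chair"  # one of the dropdown choices

ckpt1 = hf_hub_download("xiexh20/HDM-models", f"{stage1_name}-{category}.pth")
ckpt2 = hf_hub_download("xiexh20/HDM-models", f"{stage2_name}-{category}.pth")
print(ckpt1, ckpt2)  # paths into the local HF cache; repeated calls reuse the cache
```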
examples/002446/k1.color.jpg (ADDED, binary)

examples/002446/k1.color.json (ADDED):

```diff
@@ -0,0 +1,79 @@
+{
+    "body_joints": [
+        362.91015625,
+        159.39576721191406,
+        0.9023686647415161,
+        373.57745361328125,
+        180.60316467285156,
+        0.8592674136161804,
+        333.528564453125,
+        179.45702362060547,
+        0.7867028713226318,
+        278.2209167480469,
+        207.63121032714844,
+        0.8840203285217285,
+        228.78005981445312,
+        234.69793701171875,
+        0.8324164152145386,
+        417.08209228515625,
+        181.77294921875,
+        0.7164953947067261,
+        477.138427734375,
+        199.3846893310547,
+        0.7733086347579956,
+        539.4710083007812,
+        219.44891357421875,
+        0.8321817517280579,
+        401.8182678222656,
+        288.8574676513672,
+        0.61277836561203,
+        382.9984436035156,
+        294.7460632324219,
+        0.5884051322937012,
+        388.8341979980469,
+        377.1164245605469,
+        0.8282020092010498,
+        488.86529541015625,
+        404.145751953125,
+        0.6257187724113464,
+        420.6218566894531,
+        282.9443664550781,
+        0.5774698257446289,
+        455.9610290527344,
+        361.8221130371094,
+        0.8058001399040222,
+        557.13916015625,
+        339.43017578125,
+        0.69627445936203,
+        352.3575134277344,
+        151.14682006835938,
+        0.9335765242576599,
+        371.185791015625,
+        146.48798370361328,
+        0.8626495003700256,
+        342.9620666503906,
+        150.00089263916016,
+        0.0641486719250679,
+        390.03204345703125,
+        135.8568878173828,
+        0.8869808316230774,
+        595.938720703125,
+        338.2825012207031,
+        0.25365617871284485,
+        594.7731323242188,
+        334.75506591796875,
+        0.23056654632091522,
+        561.8401489257812,
+        331.20794677734375,
+        0.29395991563796997,
+        484.1672058105469,
+        435.9705810546875,
+        0.6335450410842896,
+        479.44921875,
+        433.6032409667969,
+        0.5307492017745972,
+        501.7928466796875,
+        398.28533935546875,
+        0.5881072878837585
+    ]
+}
```
examples/002446/k1.obj_rend_mask.png (ADDED, binary)
examples/002446/k1.person_mask.png (ADDED, binary)
examples/053431/k1.color.jpg (ADDED, binary)
examples/053431/k1.obj_rend_mask.png (ADDED, binary)
examples/053431/k1.person_mask.png (ADDED, binary)
examples/158107/k1.color.jpg (ADDED, binary)
examples/158107/k1.obj_rend_mask.png (ADDED, binary)
examples/158107/k1.person_mask.png (ADDED, binary)