Yuliang committed on
Commit
162943d
·
1 Parent(s): d08f752
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +17 -0
  2. LICENSE +53 -0
  3. README.md +7 -8
  4. apps/ICON.py +762 -0
  5. apps/Normal.py +213 -0
  6. apps/__pycache__/app.cpython-38.pyc +0 -0
  7. apps/app.py +21 -0
  8. apps/infer.py +616 -0
  9. assets/garment_teaser.png +0 -0
  10. assets/intermediate_results.png +0 -0
  11. assets/teaser.gif +0 -0
  12. configs/icon-filter.yaml +25 -0
  13. configs/icon-nofilter.yaml +25 -0
  14. configs/pamir.yaml +24 -0
  15. configs/pifu.yaml +24 -0
  16. environment.yaml +16 -0
  17. examples/22097467bffc92d4a5c4246f7d4edb75.png +0 -0
  18. examples/44c0f84c957b6b9bdf77662af5bb7078.png +0 -0
  19. examples/5a6a25963db2f667441d5076972c207c.png +0 -0
  20. examples/8da7ceb94669c2f65cbd28022e1f9876.png +0 -0
  21. examples/923d65f767c85a42212cae13fba3750b.png +0 -0
  22. examples/959c4c726a69901ce71b93a9242ed900.png +0 -0
  23. examples/c9856a2bc31846d684cbb965457fad59.png +0 -0
  24. examples/e1e7622af7074a022f5d96dc16672517.png +0 -0
  25. examples/fb9d20fdb93750584390599478ecf86e.png +0 -0
  26. examples/segmentation/003883.jpg +0 -0
  27. examples/segmentation/003883.json +136 -0
  28. examples/segmentation/028009.jpg +0 -0
  29. examples/segmentation/028009.json +191 -0
  30. examples/slack_trial2-000150.png +0 -0
  31. fetch_data.sh +60 -0
  32. install.sh +16 -0
  33. lib/__init__.py +0 -0
  34. lib/common/__init__.py +0 -0
  35. lib/common/cloth_extraction.py +170 -0
  36. lib/common/config.py +218 -0
  37. lib/common/render.py +387 -0
  38. lib/common/render_utils.py +221 -0
  39. lib/common/seg3d_lossless.py +604 -0
  40. lib/common/seg3d_utils.py +392 -0
  41. lib/common/smpl_vert_segmentation.json +0 -0
  42. lib/common/train_util.py +597 -0
  43. lib/dataloader_demo.py +58 -0
  44. lib/dataset/Evaluator.py +264 -0
  45. lib/dataset/NormalDataset.py +212 -0
  46. lib/dataset/NormalModule.py +94 -0
  47. lib/dataset/PIFuDataModule.py +71 -0
  48. lib/dataset/PIFuDataset.py +662 -0
  49. lib/dataset/TestDataset.py +342 -0
  50. lib/dataset/__init__.py +0 -0
.gitignore ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data/*/*
2
+ data/thuman*
3
+ !data/tbfo.ttf
4
+ __pycache__
5
+ debug/
6
+ log/
7
+ results/*
8
+ .vscode
9
+ !.gitignore
10
+ force_push.sh
11
+ .idea
12
+ smplx/
13
+ human_det/
14
+ kaolin/
15
+ neural_voxelization_layer/
16
+ pytorch3d/
17
+ force_push.sh
LICENSE ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ License
2
+
3
+ Software Copyright License for non-commercial scientific research purposes
4
+ Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the ICON model, data and software, (the "Data & Software"), including 3D meshes, images, videos, textures, software, scripts, and animations. By downloading and/or using the Data & Software (including downloading, cloning, installing, and any other use of the corresponding github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Data & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License
5
+
6
+ Ownership / Licensees
7
+ The Software and the associated materials has been developed at the Max Planck Institute for Intelligent Systems (hereinafter "MPI"). Any copyright or patent right is owned by and proprietary material of the Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”) hereinafter the “Licensor”.
8
+
9
+ License Grant
10
+ Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right:
11
+
12
+ • To install the Model & Software on computers owned, leased or otherwise controlled by you and/or your organization;
13
+ • To use the Model & Software for the sole purpose of performing peaceful non-commercial scientific research, non-commercial education, or non-commercial artistic projects;
14
+ • To modify, adapt, translate or create derivative works based upon the Model & Software.
15
+
16
+ Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding analyses in peer-reviewed scientific research. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission.
17
+
18
+ The Data & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Software to train methods/algorithms/neural networks/etc. for commercial, pornographic, military, surveillance, or defamatory use of any kind. By downloading the Data & Software, you agree not to reverse engineer it.
19
+
20
+ No Distribution
21
+ The Data & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only.
22
+
23
+ Disclaimer of Representations and Warranties
24
+ You expressly acknowledge and agree that the Data & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Data & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE DATA & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Data & Software, (ii) that the use of the Data & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Data & Software will not cause any damage of any kind to you or a third party.
25
+
26
+ Limitation of Liability
27
+ Because this Data & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage.
28
+ Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded.
29
+ Patent claims generated through the usage of the Data & Software cannot be directed towards the copyright holders.
30
+ The Data & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Data & Software and is not responsible for any problems such modifications cause.
31
+
32
+ No Maintenance Services
33
+ You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Data & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Data & Software at any time.
34
+
35
+ Defects of the Data & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication.
36
+
37
+ Publications using the Model & Software
38
+ You acknowledge that the Data & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Data & Software.
39
+
40
+ Citation:
41
+
42
+ @inproceedings{xiu2022icon,
43
+ title={{ICON}: {I}mplicit {C}lothed humans {O}btained from {N}ormals},
44
+ author={Xiu, Yuliang and Yang, Jinlong and Tzionas, Dimitrios and Black, Michael J.},
45
+ booktitle={IEEE/CVF Conf.~on Computer Vision and Pattern Recognition (CVPR)},
46
+ month = jun,
47
+ year={2022}
48
+ }
49
+
50
+ Commercial licensing opportunities
51
+ For commercial uses of the Model & Software, please send email to [email protected]
52
+
53
+ This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention.
README.md CHANGED
@@ -1,13 +1,12 @@
1
  ---
2
  title: ICON
3
- emoji: 🌖
 
4
  colorFrom: indigo
5
- colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 3.1.1
8
- app_file: app.py
9
- pinned: false
10
- license: other
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: ICON
3
+ metaTitle: "Image2Human by Yuliang Xiu"
4
+ emoji: 🤼
5
  colorFrom: indigo
6
+ colorTo: yellow
7
  sdk: gradio
8
  sdk_version: 3.1.1
9
+ app_file: ./apps/app.py
10
+ pinned: true
11
+ python_version: 3.8
12
+ ---
 
 
apps/ICON.py ADDED
@@ -0,0 +1,762 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: [email protected]
16
+
17
+ from lib.common.seg3d_lossless import Seg3dLossless
18
+ from lib.dataset.Evaluator import Evaluator
19
+ from lib.net import HGPIFuNet
20
+ from lib.common.train_util import *
21
+ from lib.renderer.gl.init_gl import initialize_GL_context
22
+ from lib.common.render import Render
23
+ from lib.dataset.mesh_util import SMPLX, update_mesh_shape_prior_losses, get_visibility
24
+ import warnings
25
+ import logging
26
+ import torch
27
+ import smplx
28
+ import numpy as np
29
+ from torch import nn
30
+ from skimage.transform import resize
31
+ import pytorch_lightning as pl
32
+
33
+ torch.backends.cudnn.benchmark = True
34
+
35
+ logging.getLogger("lightning").setLevel(logging.ERROR)
36
+
37
+ warnings.filterwarnings("ignore")
38
+
39
+
40
class ICON(pl.LightningModule):
    """Lightning module wrapping ``HGPIFuNet`` for training, evaluation and inference.

    Besides the usual Lightning hooks it owns:

    * ``reconEngine`` (``Seg3dLossless``) used to query the network on a
      coarse-to-fine grid and export a mesh,
    * ``render`` (``Render``) used by the body / cloth refinement loops,
    * ``evaluator`` (``Evaluator``) used for accuracy metrics and, at test
      time, chamfer / P2S / normal-consistency metrics.

    Several helpers used below (``osp``, ``os``, ``Image``, ``colored``,
    ``query_func``, ``projection``, ``clean_mesh``, ``accumulate``,
    ``export_cfg``, ``tf_log_convert``, ``bar_log_convert``, ``batch_mean``)
    are not imported explicitly in this file; presumably they come from
    ``from lib.common.train_util import *`` -- TODO confirm.
    """

    def __init__(self, cfg):
        """Build the network, reconstruction engine, renderer and SMPL helpers.

        Args:
            cfg: experiment configuration node. The attributes read here are
                ``batch_size``, ``lr_G``, ``sdf``, ``mcube_res``,
                ``clean_mesh``, ``projection_mode``, ``gpus``, ``test_gpus``
                and ``net.{prior_type, in_geo, in_nml, smpl_dim}``.
        """
        super(ICON, self).__init__()

        self.cfg = cfg
        self.batch_size = self.cfg.batch_size
        self.lr_G = self.cfg.lr_G

        self.use_sdf = cfg.sdf
        self.prior_type = cfg.net.prior_type
        self.mcube_res = cfg.mcube_res
        self.clean_mesh_flag = cfg.clean_mesh

        # SDF regression uses a smooth-L1 loss, occupancy regression uses MSE.
        self.netG = HGPIFuNet(
            self.cfg,
            self.cfg.projection_mode,
            error_term=nn.SmoothL1Loss() if self.use_sdf else nn.MSELoss(),
        )

        # TODO: replace the renderer from opengl to pytorch3d
        self.evaluator = Evaluator(
            device=torch.device(f"cuda:{self.cfg.gpus[0]}"))

        # Coarse-to-fine grid resolutions: powers of two from 2**5 up to
        # mcube_res, each plus one (e.g. mcube_res=256 -> 33, 65, 129, 257).
        self.resolutions = (
            np.logspace(
                start=5,
                stop=np.log2(self.mcube_res),
                base=2,
                num=int(np.log2(self.mcube_res) - 4),
                endpoint=True,
            )
            + 1.0
        )
        self.resolutions = self.resolutions.astype(np.int16).tolist()

        # Batch keys forwarded to the network for each body-prior type.
        self.icon_keys = ["smpl_verts", "smpl_faces", "smpl_vis", "smpl_cmap"]
        self.pamir_keys = ["voxel_verts",
                           "voxel_faces", "pad_v_num", "pad_f_num"]

        self.reconEngine = Seg3dLossless(
            query_func=query_func,
            b_min=[[-1.0, 1.0, -1.0]],
            b_max=[[1.0, -1.0, 1.0]],
            resolutions=self.resolutions,
            align_corners=True,
            balance_value=0.50,
            device=torch.device(f"cuda:{self.cfg.test_gpus[0]}"),
            visualize=False,
            debug=False,
            use_cuda_impl=False,
            faster=True,
        )

        self.render = Render(
            size=512, device=torch.device(f"cuda:{self.cfg.test_gpus[0]}")
        )
        self.smpl_data = SMPLX()

        # Factory for a parametric body model; kept as an attribute so that
        # existing callers can keep using ``self.get_smpl_model(...)``.
        self.get_smpl_model = lambda smpl_type, gender, age, v_template: smplx.create(
            self.smpl_data.model_dir,
            kid_template_path=osp.join(
                osp.realpath(self.smpl_data.model_dir),
                f"{smpl_type}/{smpl_type}_kid_template.npy",
            ),
            model_type=smpl_type,
            gender=gender,
            age=age,
            v_template=v_template,
            use_face_contour=False,
            ext="pkl",
        )

        # cfg.net.in_geo / in_nml are sequences of (name, channel_dim) pairs.
        self.in_geo = [item[0] for item in cfg.net.in_geo]
        self.in_nml = [item[0] for item in cfg.net.in_nml]
        self.in_geo_dim = [item[1] for item in cfg.net.in_geo]
        self.in_total = self.in_geo + self.in_nml
        self.smpl_dim = cfg.net.smpl_dim

        self.export_dir = None
        self.result_eval = {}

    def get_progress_bar_dict(self):
        """Return the default progress-bar items without the ``v_num`` entry."""
        tqdm_dict = super().get_progress_bar_dict()
        tqdm_dict.pop("v_num", None)
        return tqdm_dict

    # ------------------------------------------------------------------
    # Shared input-gathering helpers
    # ------------------------------------------------------------------
    def _add_prior_inputs(self, in_tensor_dict, batch):
        """Copy the body-prior entries required by ``self.prior_type`` from ``batch``.

        ``"icon"`` copies ``self.icon_keys``, ``"pamir"`` copies
        ``self.pamir_keys``; any other prior type copies nothing. The
        dictionary is updated in place.
        """
        if self.prior_type == "icon":
            prior_keys = self.icon_keys
        elif self.prior_type == "pamir":
            prior_keys = self.pamir_keys
        else:
            prior_keys = []

        for key in prior_keys:
            in_tensor_dict[key] = batch[key]

    def _build_supervised_inputs(self, batch):
        """Assemble the network input dict used by training and validation steps.

        Contains the permuted query samples, the calibration, the labels
        (with an extra dim inserted at position 1), every input named in
        ``self.in_total`` and the prior-specific entries.
        """
        in_tensor_dict = {
            "sample": batch["samples_geo"].permute(0, 2, 1),
            "calib": batch["calib"],
            "label": batch["labels_geo"].unsqueeze(1),
        }

        for name in self.in_total:
            in_tensor_dict[name] = batch[name]

        self._add_prior_inputs(in_tensor_dict, batch)

        return in_tensor_dict

    # ------------------------------------------------------------------
    # Training related
    # ------------------------------------------------------------------
    def configure_optimizers(self):
        """Create the optimizer and LR scheduler for ``netG``.

        The implicit-function regressor is always optimized; the image
        filter and the voxel encoder are added when ``cfg.net.use_filter``
        is set / the prior type is ``"pamir"`` respectively.

        Returns:
            ``([optimizer], [scheduler])`` in the form Lightning expects.

        Raises:
            NotImplementedError: if ``cfg.optim`` is not one of
                ``"Adadelta"``, ``"Adam"`` or ``"RMSprop"``.
        """
        # set optimizer
        weight_decay = self.cfg.weight_decay
        momentum = self.cfg.momentum

        optim_params_G = [
            {"params": self.netG.if_regressor.parameters(), "lr": self.lr_G}
        ]

        if self.cfg.net.use_filter:
            optim_params_G.append(
                {"params": self.netG.F_filter.parameters(), "lr": self.lr_G}
            )

        if self.cfg.net.prior_type == "pamir":
            optim_params_G.append(
                {"params": self.netG.ve.parameters(), "lr": self.lr_G}
            )

        if self.cfg.optim == "Adadelta":
            optimizer_G = torch.optim.Adadelta(
                optim_params_G, lr=self.lr_G, weight_decay=weight_decay
            )
        elif self.cfg.optim == "Adam":
            optimizer_G = torch.optim.Adam(
                optim_params_G, lr=self.lr_G, weight_decay=weight_decay
            )
        elif self.cfg.optim == "RMSprop":
            optimizer_G = torch.optim.RMSprop(
                optim_params_G,
                lr=self.lr_G,
                weight_decay=weight_decay,
                momentum=momentum,
            )
        else:
            raise NotImplementedError(
                f"Unsupported optimizer: {self.cfg.optim}")

        # set scheduler
        scheduler_G = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_G, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )

        return [optimizer_G], [scheduler_G]

    def training_step(self, batch, batch_idx):
        """Run one optimization step and collect training metrics.

        Returns:
            dict with the scalar metrics (``loss``, ``acc``, ``iou``,
            ``prec``, ``recall``) plus the ``log`` and ``progress_bar``
            dictionaries.
        """
        if not self.cfg.fast_dev:
            export_cfg(self.logger, self.cfg)

        self.netG.train()

        in_tensor_dict = self._build_supervised_inputs(batch)

        preds_G, error_G = self.netG(in_tensor_dict)

        acc, iou, prec, recall = self.evaluator.calc_acc(
            preds_G.flatten(),
            in_tensor_dict["label"].flatten(),
            0.5,
            use_sdf=self.cfg.sdf,
        )

        # metrics processing
        metrics_log = {
            "train_loss": error_G.item(),
            "train_acc": acc.item(),
            "train_iou": iou.item(),
            "train_prec": prec.item(),
            "train_recall": recall.item(),
        }

        tf_log = tf_log_convert(metrics_log)
        bar_log = bar_log_convert(metrics_log)

        # Periodically log a visualization. render_func slices the input
        # dict down to its first sample in place, which is harmless here
        # because the loss has already been computed.
        if batch_idx % int(self.cfg.freq_show_train) == 0:
            with torch.no_grad():
                self.render_func(in_tensor_dict, dataset="train")

        metrics_return = {
            k.replace("train_", ""): torch.tensor(v) for k, v in metrics_log.items()
        }

        # "loss" must stay the differentiable tensor, so it overrides the
        # detached copy created above.
        metrics_return.update(
            {"loss": error_G, "log": tf_log, "progress_bar": bar_log})

        return metrics_return

    def training_epoch_end(self, outputs):
        """Average the per-step training metrics over the epoch."""
        if [] in outputs:
            outputs = outputs[0]

        # metrics processing
        metrics_log = {
            "train_avgloss": batch_mean(outputs, "loss"),
            "train_avgiou": batch_mean(outputs, "iou"),
            "train_avgprec": batch_mean(outputs, "prec"),
            "train_avgrecall": batch_mean(outputs, "recall"),
            "train_avgacc": batch_mean(outputs, "acc"),
        }

        tf_log = tf_log_convert(metrics_log)

        return {"log": tf_log}

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and return its loss / accuracy metrics."""
        self.netG.eval()
        self.netG.training = False

        in_tensor_dict = self._build_supervised_inputs(batch)

        preds_G, error_G = self.netG(in_tensor_dict)

        acc, iou, prec, recall = self.evaluator.calc_acc(
            preds_G.flatten(),
            in_tensor_dict["label"].flatten(),
            0.5,
            use_sdf=self.cfg.sdf,
        )

        if batch_idx % int(self.cfg.freq_show_val) == 0:
            with torch.no_grad():
                self.render_func(in_tensor_dict, dataset="val", idx=batch_idx)

        metrics_return = {
            "val_loss": error_G,
            "val_acc": acc,
            "val_iou": iou,
            "val_prec": prec,
            "val_recall": recall,
        }

        return metrics_return

    def validation_epoch_end(self, outputs):
        """Average the per-step validation metrics over the epoch."""
        # metrics processing
        metrics_log = {
            "val_avgloss": batch_mean(outputs, "val_loss"),
            "val_avgacc": batch_mean(outputs, "val_acc"),
            "val_avgiou": batch_mean(outputs, "val_iou"),
            "val_avgprec": batch_mean(outputs, "val_prec"),
            "val_avgrecall": batch_mean(outputs, "val_recall"),
        }

        tf_log = tf_log_convert(metrics_log)

        return {"log": tf_log}

    def compute_vis_cmap(self, smpl_type, smpl_verts, smpl_faces):
        """Compute per-vertex visibility and the colour map for a body mesh.

        Args:
            smpl_type: ``"smpl"`` triggers an index remap through
                ``smpl2smplx``; any other value uses identity indices.
            smpl_verts: vertex tensor whose dim 1 is split into 2 + 1
                columns (xy and z).
            smpl_faces: face indices of the body mesh.

        Returns:
            dict with batched ``smpl_vis``, ``smpl_cmap`` (both moved to
            ``self.device``) and ``smpl_verts``.
        """
        (xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=1)
        # NOTE(review): z is negated here but not in test_step -- confirm
        # which sign convention get_visibility expects.
        smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())
        if smpl_type == "smpl":
            smplx_ind = self.smpl_data.smpl2smplx(np.arange(smpl_vis.shape[0]))
        else:
            smplx_ind = np.arange(smpl_vis.shape[0])
        smpl_cmap = self.smpl_data.get_smpl_mat(smplx_ind)

        return {
            "smpl_vis": smpl_vis.unsqueeze(0).to(self.device),
            "smpl_cmap": smpl_cmap.unsqueeze(0).to(self.device),
            "smpl_verts": smpl_verts.unsqueeze(0),
        }

    @torch.enable_grad()
    def optim_body(self, in_tensor_dict, batch):
        """Refine the body parameters so the rendered body matches the predicted normals.

        Runs 50 SGD iterations over pose, translation, shape and global
        orientation. The loss is the mean absolute difference between
        rendered body normals and network-predicted normals (front and
        back) plus a silhouette term.

        Gradients are force-enabled by the decorator because the caller
        (``test_step``) invokes this inside ``torch.no_grad()``.

        Returns:
            ``(features, inter, in_tensor_dict)`` where ``in_tensor_dict``
            now holds the optimized body mesh and its visibility / colour map.
        """
        smpl_model = self.get_smpl_model(
            batch["type"][0], batch["gender"][0], batch["age"][0], None
        ).to(self.device)
        # np.int was removed in NumPy 1.24; use an explicit 64-bit integer
        # (the tensor is converted to long right after anyway).
        in_tensor_dict["smpl_faces"] = (
            torch.tensor(smpl_model.faces.astype(np.int64))
            .long()
            .unsqueeze(0)
            .to(self.device)
        )

        # The optimizer and variables
        optimed_pose = torch.tensor(
            batch["body_pose"][0], device=self.device, requires_grad=True
        )  # [1,23,3,3]
        optimed_trans = torch.tensor(
            batch["transl"][0], device=self.device, requires_grad=True
        )  # [3]
        optimed_betas = torch.tensor(
            batch["betas"][0], device=self.device, requires_grad=True
        )  # [1,10]
        optimed_orient = torch.tensor(
            batch["global_orient"][0], device=self.device, requires_grad=True
        )  # [1,1,3,3]

        optimizer_smpl = torch.optim.SGD(
            [optimed_pose, optimed_trans, optimed_betas, optimed_orient],
            lr=1e-3,
            momentum=0.9,
        )
        scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_smpl, mode="min", factor=0.5, verbose=0, min_lr=1e-5, patience=5
        )
        loop_smpl = range(50)
        for i in loop_smpl:

            optimizer_smpl.zero_grad()

            smpl_out = smpl_model(
                betas=optimed_betas,
                body_pose=optimed_pose,
                global_orient=optimed_orient,
                transl=optimed_trans,
                return_verts=True,
            )

            smpl_verts = smpl_out.vertices[0] * 100.0
            smpl_verts = projection(
                smpl_verts, batch["calib"][0], format="tensor")
            smpl_verts[:, 1] *= -1
            # render optimized mesh (normal, T_normal, image [-1,1])
            self.render.load_meshes(
                smpl_verts, in_tensor_dict["smpl_faces"])
            (
                in_tensor_dict["T_normal_F"],
                in_tensor_dict["T_normal_B"],
            ) = self.render.get_rgb_image()

            T_mask_F, T_mask_B = self.render.get_silhouette_image()

            # The predicted normals are treated as a fixed target: no
            # gradient flows through the normal network.
            with torch.no_grad():
                (
                    in_tensor_dict["normal_F"],
                    in_tensor_dict["normal_B"],
                ) = self.netG.normal_filter(in_tensor_dict)

            diff_F_smpl = torch.abs(
                in_tensor_dict["T_normal_F"] - in_tensor_dict["normal_F"]
            )
            diff_B_smpl = torch.abs(
                in_tensor_dict["T_normal_B"] - in_tensor_dict["normal_B"]
            )
            loss = (diff_F_smpl + diff_B_smpl).mean()

            # silhouette loss: a pixel counts as foreground when the
            # rescaled predicted normal differs from the grey background.
            smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)[0]
            gt_arr = torch.cat(
                [in_tensor_dict["normal_F"][0], in_tensor_dict["normal_B"][0]], dim=2
            ).permute(1, 2, 0)
            gt_arr = ((gt_arr + 1.0) * 0.5).to(self.device)
            bg_color = (
                torch.Tensor([0.5, 0.5, 0.5]).unsqueeze(
                    0).unsqueeze(0).to(self.device)
            )
            gt_arr = ((gt_arr - bg_color).sum(dim=-1) != 0.0).float()
            loss += torch.abs(smpl_arr - gt_arr).mean()

            loss.backward(retain_graph=True)
            optimizer_smpl.step()
            scheduler_smpl.step(loss)
            in_tensor_dict["smpl_verts"] = smpl_verts.unsqueeze(0)

        in_tensor_dict.update(
            self.compute_vis_cmap(
                batch["type"][0],
                in_tensor_dict["smpl_verts"][0],
                in_tensor_dict["smpl_faces"][0],
            )
        )

        features, inter = self.netG.filter(in_tensor_dict, return_inter=True)

        return features, inter, in_tensor_dict

    @torch.enable_grad()
    def optim_cloth(self, verts_pr, faces_pr, inter):
        """Refine reconstructed vertices so rendered normals match the predicted ones.

        Optimizes a per-vertex offset for 100 SGD iterations using a
        weighted sum of the cloth-normal loss, the mesh shape priors filled
        in by ``update_mesh_shape_prior_losses`` and a penalty on the 30
        largest offset components.

        Args:
            verts_pr: predicted vertices in reconstruction-grid
                coordinates. Modified **in place** and returned.
            faces_pr: predicted faces.
            inter: predicted normal maps; the first three channels are
                compared with the front render, the rest with the back render.

        Returns:
            The refined ``verts_pr`` back in grid coordinates.
        """
        half_extent = (self.resolutions[-1] - 1) / 2.0

        # grid coordinates -> normalized [-1, 1] coordinates
        verts_pr -= half_extent
        verts_pr /= half_extent

        losses = {
            "cloth": {"weight": 5.0, "value": 0.0},
            "edge": {"weight": 100.0, "value": 0.0},
            "normal": {"weight": 0.2, "value": 0.0},
            "laplacian": {"weight": 100.0, "value": 0.0},
            "smpl": {"weight": 1.0, "value": 0.0},
            "deform": {"weight": 20.0, "value": 0.0},
        }

        deform_verts = torch.full(
            verts_pr.shape, 0.0, device=self.device, requires_grad=True
        )
        optimizer_cloth = torch.optim.SGD(
            [deform_verts], lr=1e-1, momentum=0.9)
        scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_cloth, mode="min", factor=0.1, verbose=0, min_lr=1e-3, patience=5
        )

        # cloth optimization
        loop_cloth = range(100)

        for i in loop_cloth:

            optimizer_cloth.zero_grad()

            self.render.load_meshes(
                verts_pr.unsqueeze(0).to(self.device),
                faces_pr.unsqueeze(0).to(self.device).long(),
                deform_verts,
            )
            P_normal_F, P_normal_B = self.render.get_rgb_image()

            update_mesh_shape_prior_losses(self.render.mesh, losses)
            diff_F_cloth = torch.abs(P_normal_F[0] - inter[:3])
            diff_B_cloth = torch.abs(P_normal_B[0] - inter[3:])
            losses["cloth"]["value"] = (diff_F_cloth + diff_B_cloth).mean()
            losses["deform"]["value"] = torch.topk(
                torch.abs(deform_verts.flatten()), 30
            )[0].mean()

            # Weighted sum of the losses (the "smpl" term is not used here)
            cloth_loss = torch.tensor(0.0, device=self.device)
            for name, term in losses.items():
                if name != "smpl":
                    cloth_loss += term["value"] * term["weight"]

            cloth_loss.backward(retain_graph=True)
            optimizer_cloth.step()
            scheduler_cloth.step(cloth_loss)

        # Suppress outliers: the 30 largest offset components are replaced
        # by the mean offset before the deformation is applied.
        deform_verts = deform_verts.flatten().detach()
        deform_verts[torch.topk(torch.abs(deform_verts), 30)[
            1]] = deform_verts.mean()
        deform_verts = deform_verts.view(-1, 3).cpu()

        # apply the offsets, then normalized coordinates -> grid coordinates
        verts_pr += deform_verts
        verts_pr *= half_extent
        verts_pr += half_extent

        return verts_pr

    def test_step(self, batch, batch_idx):
        """Reconstruct one test sample, export intermediate images and compute metrics.

        Expected batch keys (from the original author's note):
        dataset, subject, rotation, scale, calib, normal_F, normal_B, image,
        T_normal_F, T_normal_B, z-trans, verts, faces, samples_geo,
        labels_geo, smpl_verts, smpl_faces, smpl_vis, smpl_cmap, pts_signs,
        type, gender, age, body_pose, global_orient, betas, transl.

        Returns:
            dict with ``chamfer``, ``p2s`` and ``NC`` (normal consistency).
        """
        if self.evaluator._normal_render is None:
            self.evaluator.init_gl()

        self.netG.eval()
        self.netG.training = False
        in_tensor_dict = {}

        # export paths
        mesh_name = batch["subject"][0]
        mesh_rot = batch["rotation"][0].item()
        ckpt_dir = self.cfg.name

        for kid, key in enumerate(self.cfg.dataset.noise_type):
            ckpt_dir += f"_{key}_{self.cfg.dataset.noise_scale[kid]}"

        if self.cfg.optim_cloth:
            ckpt_dir += "_optim_cloth"
        if self.cfg.optim_body:
            ckpt_dir += "_optim_body"

        self.export_dir = osp.join(self.cfg.results_path, ckpt_dir, mesh_name)
        os.makedirs(self.export_dir, exist_ok=True)

        for name in self.in_total:
            if name in batch.keys():
                in_tensor_dict.update({name: batch[name]})

        # update the new T_normal_F/B
        in_tensor_dict.update(
            self.evaluator.render_normal(
                batch["smpl_verts"], batch["smpl_faces"])
        )

        # update the new smpl_vis
        (xy, z) = batch["smpl_verts"][0].split([2, 1], dim=1)
        smpl_vis = get_visibility(
            xy,
            z,
            torch.as_tensor(self.smpl_data.faces).type_as(
                batch["smpl_verts"]).long(),
        )
        in_tensor_dict.update({"smpl_vis": smpl_vis.unsqueeze(0)})

        # NOTE(review): for prior_type == "icon" the call below copies
        # batch["smpl_vis"] and therefore overwrites the visibility that
        # was just recomputed above -- confirm this is intended.
        self._add_prior_inputs(in_tensor_dict, batch)

        with torch.no_grad():
            if self.cfg.optim_body:
                # optim_body re-enables gradients internally.
                features, inter, in_tensor_dict = self.optim_body(
                    in_tensor_dict, batch)
            else:
                features, inter = self.netG.filter(
                    in_tensor_dict, return_inter=True)
            sdf = self.reconEngine(
                opt=self.cfg, netG=self.netG, features=features, proj_matrix=None
            )

        # save inter results (tensors in [-1, 1] are rescaled to [0, 1])
        image = (
            in_tensor_dict["image"][0].permute(
                1, 2, 0).detach().cpu().numpy() + 1.0
        ) * 0.5
        smpl_F = (
            in_tensor_dict["T_normal_F"][0].permute(
                1, 2, 0).detach().cpu().numpy()
            + 1.0
        ) * 0.5
        smpl_B = (
            in_tensor_dict["T_normal_B"][0].permute(
                1, 2, 0).detach().cpu().numpy()
            + 1.0
        ) * 0.5
        image_inter = np.concatenate(
            self.tensor2image(512, inter[0]) + [smpl_F, smpl_B, image], axis=1
        )
        Image.fromarray((image_inter * 255.0).astype(np.uint8)).save(
            osp.join(self.export_dir, f"{mesh_rot}_inter.png")
        )

        verts_pr, faces_pr = self.reconEngine.export_mesh(sdf)

        if self.clean_mesh_flag:
            verts_pr, faces_pr = clean_mesh(verts_pr, faces_pr)

        if self.cfg.optim_cloth:
            verts_pr = self.optim_cloth(verts_pr, faces_pr, inter[0].detach())

        verts_gt = batch["verts"][0]
        faces_gt = batch["faces"][0]

        self.result_eval.update(
            {
                "verts_gt": verts_gt,
                "faces_gt": faces_gt,
                "verts_pr": verts_pr,
                "faces_pr": faces_pr,
                "recon_size": (self.resolutions[-1] - 1.0),
                "calib": batch["calib"][0],
            }
        )

        self.evaluator.set_mesh(self.result_eval, scale_factor=1.0)
        self.evaluator.space_transfer()

        chamfer, p2s = self.evaluator.calculate_chamfer_p2s(
            sampled_points=1000)
        normal_consist = self.evaluator.calculate_normal_consist(
            save_demo_img=osp.join(self.export_dir, f"{mesh_rot}_nc.png")
        )

        test_log = {"chamfer": chamfer, "p2s": p2s, "NC": normal_consist}

        return test_log

    def test_epoch_end(self, outputs):
        """Aggregate the per-sample test metrics, log them and save them to disk.

        The aggregated results are written to ``test_results.npy`` in the
        parent of the last export directory.
        """
        accu_outputs = accumulate(
            outputs,
            rot_num=3,
            split={
                "thuman2": (0, 5),
            },
        )

        print(colored(self.cfg.name, "green"))
        print(colored(self.cfg.dataset.noise_scale, "green"))

        self.logger.experiment.add_hparams(
            hparam_dict={"lr_G": self.lr_G, "bsize": self.batch_size},
            metric_dict=accu_outputs,
        )

        np.save(
            osp.join(self.export_dir, "../test_results.npy"),
            accu_outputs,
            allow_pickle=True,
        )

        return accu_outputs

    def tensor2image(self, height, inter):
        """Split a stacked channel tensor into one square image per geometry input.

        Channels are consumed in the order of ``self.in_geo_dim``. Each
        chunk is rescaled from [-1, 1] to [0, 1], moved to channel-last
        layout, tiled ``int(3 / dim)`` times along the channel axis and
        resized to ``height x height``.

        Returns:
            list of numpy images, one per entry of ``self.in_geo_dim``.
        """
        images = []
        for dim in self.in_geo_dim:
            img = resize(
                np.tile(
                    ((inter[:dim].cpu().numpy() + 1.0) /
                     2.0).transpose(1, 2, 0),
                    (1, 1, int(3 / dim)),
                ),
                (height, height),
                anti_aliasing=True,
            )

            images.append(img)
            # drop the channels that were just consumed
            inter = inter[dim:]

        return images

    def render_func(self, in_tensor_dict, dataset="title", idx=0):
        """Reconstruct the first sample of a batch and log a visualization.

        NOTE: ``in_tensor_dict`` is sliced down to its first sample **in
        place**, and ``netG`` is left in eval mode afterwards.

        Args:
            in_tensor_dict: network inputs for a batch.
            dataset: tag used in the logged image name (``"train"`` uses
                ``global_step`` directly as the step id).
            idx: offset added to ``global_step`` for non-training datasets.
        """
        for name in in_tensor_dict.keys():
            in_tensor_dict[name] = in_tensor_dict[name][0:1]

        self.netG.eval()
        features, inter = self.netG.filter(in_tensor_dict, return_inter=True)
        sdf = self.reconEngine(
            opt=self.cfg, netG=self.netG, features=features, proj_matrix=None
        )

        if sdf is not None:
            render = self.reconEngine.display(sdf)

            # reverse the channel order and flip the image vertically
            image_pred = np.flip(render[:, :, ::-1], axis=0)
            height = image_pred.shape[0]

            image_gt = resize(
                ((in_tensor_dict["image"].cpu().numpy()[0] + 1.0) / 2.0).transpose(
                    1, 2, 0
                ),
                (height, height),
                anti_aliasing=True,
            )
            image_inter = self.tensor2image(height, inter[0])
            image = np.concatenate(
                [image_pred, image_gt] + image_inter, axis=1)

            step_id = self.global_step if dataset == "train" else self.global_step + idx
            self.logger.experiment.add_image(
                tag=f"Occupancy-{dataset}/{step_id}",
                img_tensor=image.transpose(2, 0, 1),
                global_step=step_id,
            )

    def test_single(self, batch):
        """Reconstruct a mesh for a single sample without computing metrics.

        Returns:
            ``(verts_pr, faces_pr, inter)`` where the vertices have been
            mapped from grid coordinates to normalized [-1, 1] coordinates.
        """
        self.netG.eval()
        self.netG.training = False
        in_tensor_dict = {}

        for name in self.in_total:
            if name in batch.keys():
                in_tensor_dict.update({name: batch[name]})

        self._add_prior_inputs(in_tensor_dict, batch)

        features, inter = self.netG.filter(in_tensor_dict, return_inter=True)
        sdf = self.reconEngine(
            opt=self.cfg, netG=self.netG, features=features, proj_matrix=None
        )

        verts_pr, faces_pr = self.reconEngine.export_mesh(sdf)

        if self.clean_mesh_flag:
            verts_pr, faces_pr = clean_mesh(verts_pr, faces_pr)

        verts_pr -= (self.resolutions[-1] - 1) / 2.0
        verts_pr /= (self.resolutions[-1] - 1) / 2.0

        return verts_pr, faces_pr, inter
apps/Normal.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib.net import NormalNet
2
+ from lib.common.train_util import *
3
+ import logging
4
+ import torch
5
+ import numpy as np
6
+ from torch import nn
7
+ from skimage.transform import resize
8
+ import pytorch_lightning as pl
9
+
10
+ torch.backends.cudnn.benchmark = True
11
+
12
+ logging.getLogger("lightning").setLevel(logging.ERROR)
13
+
14
+
15
class Normal(pl.LightningModule):
    """Lightning module that trains the front/back normal networks of ``NormalNet``.

    Two Adam optimizers (one for ``netG.netF``, one for ``netG.netB``) are
    stepped manually inside ``training_step``; losses and image grids are
    logged through ``self.logger``.
    """

    def __init__(self, cfg):
        """Store the config and build the network.

        Args:
            cfg: configuration node; this class reads ``batch_size``, ``lr_N``,
                ``weight_decay``, ``schedule``, ``gamma``, ``freq_show_train``
                and ``net.in_nml`` from it.
        """
        super(Normal, self).__init__()
        self.cfg = cfg
        self.batch_size = self.cfg.batch_size
        self.lr_N = self.cfg.lr_N

        # filled in by configure_optimizers(); read in training_epoch_end()
        self.schedulers = []

        self.netG = NormalNet(self.cfg, error_term=nn.SmoothL1Loss())

        # names of the input tensors (first element of every in_nml entry)
        self.in_nml = [item[0] for item in cfg.net.in_nml]

    def get_progress_bar_dict(self):
        """Return the progress-bar dict without the ``v_num`` entry."""
        tqdm_dict = super().get_progress_bar_dict()
        if "v_num" in tqdm_dict:
            del tqdm_dict["v_num"]
        return tqdm_dict

    # Training related
    def configure_optimizers(self):
        """Create one Adam optimizer and one MultiStepLR scheduler per sub-network.

        Returns:
            ([optimizer_N_F, optimizer_N_B], [scheduler_N_F, scheduler_N_B])
        """

        # set optimizer
        weight_decay = self.cfg.weight_decay

        optim_params_N_F = [
            {"params": self.netG.netF.parameters(), "lr": self.lr_N}]
        optim_params_N_B = [
            {"params": self.netG.netB.parameters(), "lr": self.lr_N}]

        optimizer_N_F = torch.optim.Adam(
            optim_params_N_F, lr=self.lr_N, weight_decay=weight_decay
        )

        optimizer_N_B = torch.optim.Adam(
            optim_params_N_B, lr=self.lr_N, weight_decay=weight_decay
        )

        scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )

        scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )

        self.schedulers = [scheduler_N_F, scheduler_N_B]
        optims = [optimizer_N_F, optimizer_N_B]

        return optims, self.schedulers

    def render_func(self, render_tensor):
        """Tile the first sample of every tensor in ``render_tensor`` side by side.

        Each tensor is mapped with ``(x + 1) / 2``, moved to channel-last
        layout, resized to ``height x height`` (``height`` is taken from
        ``render_tensor["image"].shape[2]``) and concatenated along the width.

        Returns:
            numpy array holding the concatenated images.
        """

        height = render_tensor["image"].shape[2]
        result_list = []

        for name in render_tensor.keys():
            result_list.append(
                resize(
                    ((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(
                        1, 2, 0
                    ),
                    (height, height),
                    anti_aliasing=True,
                )
            )
        result_array = np.concatenate(result_list, axis=1)

        return result_array

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one manual optimisation step for both normal networks.

        Returns:
            dict with the summed loss, the per-network losses and the
            converted logging dictionaries.
        """

        export_cfg(self.logger, self.cfg)

        # retrieve the data
        in_tensor = {}
        for name in self.in_nml:
            in_tensor[name] = batch[name]

        FB_tensor = {"normal_F": batch["normal_F"],
                     "normal_B": batch["normal_B"]}

        self.netG.train()

        preds_F, preds_B = self.netG(in_tensor)
        error_NF, error_NB = self.netG.get_norm_error(
            preds_F, preds_B, FB_tensor)

        (opt_nf, opt_nb) = self.optimizers()

        opt_nf.zero_grad()
        opt_nb.zero_grad()

        self.manual_backward(error_NF, opt_nf)
        self.manual_backward(error_NB, opt_nb)

        opt_nf.step()
        opt_nb.step()

        # periodically log a grid of inputs and predictions
        if batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0:

            self.netG.eval()
            with torch.no_grad():
                nmlF, nmlB = self.netG(in_tensor)
                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
                result_array = self.render_func(in_tensor)

            self.logger.experiment.add_image(
                tag=f"Normal-train/{self.global_step}",
                img_tensor=result_array.transpose(2, 0, 1),
                global_step=self.global_step,
            )

        # metrics processing
        metrics_log = {
            "train_loss-NF": error_NF.item(),
            "train_loss-NB": error_NB.item(),
        }

        tf_log = tf_log_convert(metrics_log)
        bar_log = bar_log_convert(metrics_log)

        return {
            "loss": error_NF + error_NB,
            "loss-NF": error_NF,
            "loss-NB": error_NB,
            "log": tf_log,
            "progress_bar": bar_log,
        }

    def training_epoch_end(self, outputs):
        """Average the per-step training losses and attach the current learning rates."""

        # NOTE(review): presumably unwraps a per-optimizer nesting of the
        # outputs; confirm against the Lightning version in use.
        if [] in outputs:
            outputs = outputs[0]

        # metrics processing
        metrics_log = {
            "train_avgloss": batch_mean(outputs, "loss"),
            "train_avgloss-NF": batch_mean(outputs, "loss-NF"),
            "train_avgloss-NB": batch_mean(outputs, "loss-NB"),
        }

        tf_log = tf_log_convert(metrics_log)

        tf_log["lr-NF"] = self.schedulers[0].get_last_lr()[0]
        tf_log["lr-NB"] = self.schedulers[1].get_last_lr()[0]

        return {"log": tf_log}

    def validation_step(self, batch, batch_idx):
        """Compute the validation losses for one batch and periodically log images.

        Returns:
            dict with ``val_loss``, ``val_loss-NF`` and ``val_loss-NB``.
        """

        # retrieve the data
        in_tensor = {}
        for name in self.in_nml:
            in_tensor[name] = batch[name]

        FB_tensor = {"normal_F": batch["normal_F"],
                     "normal_B": batch["normal_B"]}

        # Validation must run in inference mode: train() would keep
        # training-only layer behaviour active while evaluating.
        self.netG.eval()

        preds_F, preds_B = self.netG(in_tensor)
        error_NF, error_NB = self.netG.get_norm_error(
            preds_F, preds_B, FB_tensor)

        # log the very first batch and then every freq_show_train batches
        if (batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0) or (
            batch_idx == 0
        ):

            with torch.no_grad():
                nmlF, nmlB = self.netG(in_tensor)
                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
                result_array = self.render_func(in_tensor)

            self.logger.experiment.add_image(
                tag=f"Normal-val/{self.global_step}",
                img_tensor=result_array.transpose(2, 0, 1),
                global_step=self.global_step,
            )

        return {
            "val_loss": error_NF + error_NB,
            "val_loss-NF": error_NF,
            "val_loss-NB": error_NB,
        }

    def validation_epoch_end(self, outputs):
        """Average the per-step validation losses into a logging dict."""

        # metrics processing
        metrics_log = {
            "val_avgloss": batch_mean(outputs, "val_loss"),
            "val_avgloss-NF": batch_mean(outputs, "val_loss-NF"),
            "val_avgloss-NB": batch_mean(outputs, "val_loss-NB"),
        }

        tf_log = tf_log_convert(metrics_log)

        return {"log": tf_log}
apps/__pycache__/app.cpython-38.pyc ADDED
Binary file (555 Bytes). View file
 
apps/app.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# install

import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# os.system() reports a failing command through its return status rather than
# by raising, so the status has to be checked explicitly. The install step
# stays best-effort: a failure is reported and the demo is still launched.
install_status = os.system("bash install.sh")
if install_status != 0:
    print(f"install.sh exited with status {install_status}")


# running

import gradio as gr


def image_classifier(inp):
    """Placeholder classifier: ignores ``inp`` and returns fixed label scores."""
    return {'cat': 0.3, 'dog': 0.7}


demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")
# NOTE(review): the login credentials are hard-coded in source; consider
# reading them from the environment or a secret store instead.
demo.launch(auth=("[email protected]", "icon_2022"),
            auth_message="Register at icon.is.tue.mpg.de/download to get the username and password.")
apps/infer.py ADDED
@@ -0,0 +1,616 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: [email protected]
16
+
17
+ import logging
18
+ from lib.common.render import query_color, image2vid
19
+ from lib.common.config import cfg
20
+ from lib.common.cloth_extraction import extract_cloth
21
+ from lib.dataset.mesh_util import (
22
+ load_checkpoint,
23
+ update_mesh_shape_prior_losses,
24
+ get_optim_grid_image,
25
+ blend_rgb_norm,
26
+ unwrap,
27
+ remesh,
28
+ tensor2variable,
29
+ normal_loss
30
+ )
31
+
32
+ from lib.dataset.TestDataset import TestDataset
33
+ from lib.net.local_affine import LocalAffine
34
+ from pytorch3d.structures import Meshes
35
+ from apps.ICON import ICON
36
+
37
+ import os
38
+ from termcolor import colored
39
+ import argparse
40
+ import numpy as np
41
+ from PIL import Image
42
+ import trimesh
43
+ import pickle
44
+ import numpy as np
45
+
46
+ import torch
47
+ torch.backends.cudnn.benchmark = True
48
+
49
+ logging.getLogger("trimesh").setLevel(logging.ERROR)
50
+
51
+
52
if __name__ == "__main__":
    # Command-line entry point. For every image in the input directory:
    #   1. fit the SMPL body to the predicted normals and silhouette,
    #   2. reconstruct the clothed mesh,
    #   3. optionally refine the clothed mesh with a local-affine deformation,
    #   4. export meshes, images, optional video and optional garments.

    # loading cfg file
    parser = argparse.ArgumentParser()

    parser.add_argument("-gpu", "--gpu_device", type=int, default=0)
    parser.add_argument("-colab", action="store_true")
    parser.add_argument("-loop_smpl", "--loop_smpl", type=int, default=100)
    parser.add_argument("-patience", "--patience", type=int, default=5)
    parser.add_argument("-vis_freq", "--vis_freq", type=int, default=10)
    parser.add_argument("-loop_cloth", "--loop_cloth", type=int, default=200)
    parser.add_argument("-hps_type", "--hps_type", type=str, default="pymaf")
    parser.add_argument("-export_video", action="store_true")
    parser.add_argument("-in_dir", "--in_dir", type=str, default="./examples")
    parser.add_argument("-out_dir", "--out_dir",
                        type=str, default="./results")
    parser.add_argument('-seg_dir', '--seg_dir', type=str, default=None)
    parser.add_argument(
        "-cfg", "--config", type=str, default="./configs/icon-filter.yaml"
    )

    args = parser.parse_args()

    # cfg read and merge
    cfg.merge_from_file(args.config)
    cfg.merge_from_file("./lib/pymaf/configs/pymaf_config.yaml")

    # overrides applied on top of the merged files (key, value, key, value, ...)
    cfg_show_list = [
        "test_gpus",
        [args.gpu_device],
        "mcube_res",
        256,
        "clean_mesh",
        True,
    ]

    cfg.merge_from_list(cfg_show_list)
    cfg.freeze()

    # NOTE(review): the visible devices are hard-coded to "0,1" and set after
    # torch has been imported; confirm this interacts as intended with
    # --gpu_device values other than 0 or 1.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    device = torch.device(f"cuda:{args.gpu_device}")

    # pick the tqdm flavour matching the environment
    if args.colab:
        print(colored("colab environment...", "red"))
        from tqdm.notebook import tqdm
    else:
        print(colored("normal environment...", "red"))
        from tqdm import tqdm

    # load model and dataloader
    model = ICON(cfg)
    model = load_checkpoint(model, cfg)

    dataset_param = {
        'image_dir': args.in_dir,
        'seg_dir': args.seg_dir,
        'has_det': True,            # w/ or w/o detection
        'hps_type': args.hps_type   # pymaf/pare/pixie
    }

    # the check is done on the config *path*, not on cfg.net.prior_type
    if args.hps_type == "pixie" and "pamir" in args.config:
        print(colored("PIXIE isn't compatible with PaMIR, thus switch to PyMAF", "red"))
        dataset_param["hps_type"] = "pymaf"

    dataset = TestDataset(dataset_param, device)

    print(colored(f"Dataset Size: {len(dataset)}", "green"))

    pbar = tqdm(dataset)

    for data in pbar:

        pbar.set_description(f"{data['name']}")

        in_tensor = {"smpl_faces": data["smpl_faces"], "image": data["image"]}

        # The optimizer and variables
        optimed_pose = torch.tensor(
            data["body_pose"], device=device, requires_grad=True
        )  # [1,23,3,3]
        optimed_trans = torch.tensor(
            data["trans"], device=device, requires_grad=True
        )  # [3]
        optimed_betas = torch.tensor(
            data["betas"], device=device, requires_grad=True
        )  # [1,10]
        optimed_orient = torch.tensor(
            data["global_orient"], device=device, requires_grad=True
        )  # [1,1,3,3]

        optimizer_smpl = torch.optim.SGD(
            [optimed_pose, optimed_trans, optimed_betas, optimed_orient],
            lr=1e-3,
            momentum=0.9,
        )
        scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_smpl,
            mode="min",
            factor=0.5,
            verbose=0,
            min_lr=1e-5,
            patience=args.patience,
        )

        # loss terms shared by both optimisation stages; "normal" and
        # "silhouette" drive the body fitting, the rest the cloth refinement
        losses = {
            "cloth": {"weight": 1e1, "value": 0.0},         # Cloth: Normal_recon - Normal_pred
            "stiffness": {"weight": 1e5, "value": 0.0},     # Cloth: [RT]_v1 - [RT]_v2 (v1-edge-v2)
            "rigid": {"weight": 1e5, "value": 0.0},         # Cloth: det(R) = 1
            "edge": {"weight": 0, "value": 0.0},            # Cloth: edge length
            "nc": {"weight": 0, "value": 0.0},              # Cloth: normal consistency
            "laplacian": {"weight": 1e2, "value": 0.0},     # Cloth: laplacian smooth
            "normal": {"weight": 1e0, "value": 0.0},        # Body: Normal_pred - Normal_smpl
            "silhouette": {"weight": 1e1, "value": 0.0},    # Body: Silhouette_pred - Silhouette_smpl
        }

        # smpl optimization

        # a single pass is run for the "pifu" prior
        loop_smpl = tqdm(
            range(args.loop_smpl if cfg.net.prior_type != "pifu" else 1))

        per_data_lst = []

        for i in loop_smpl:

            per_loop_lst = []

            optimizer_smpl.zero_grad()

            # the pixie body model has a different call signature and return value
            if dataset_param["hps_type"] != "pixie":
                smpl_out = dataset.smpl_model(
                    betas=optimed_betas,
                    body_pose=optimed_pose,
                    global_orient=optimed_orient,
                    pose2rot=False,
                )

                smpl_verts = ((smpl_out.vertices) +
                              optimed_trans) * data["scale"]
            else:
                smpl_verts, _, _ = dataset.smpl_model(
                    shape_params=optimed_betas,
                    expression_params=tensor2variable(data["exp"], device),
                    body_pose=optimed_pose,
                    global_pose=optimed_orient,
                    jaw_pose=tensor2variable(data["jaw_pose"], device),
                    left_hand_pose=tensor2variable(
                        data["left_hand_pose"], device),
                    right_hand_pose=tensor2variable(
                        data["right_hand_pose"], device),
                )

                smpl_verts = (smpl_verts + optimed_trans) * data["scale"]

            # render optimized mesh (normal, T_normal, image [-1,1])
            in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
                smpl_verts *
                torch.tensor([1.0, -1.0, -1.0]
                             ).to(device), in_tensor["smpl_faces"]
            )
            T_mask_F, T_mask_B = dataset.render.get_silhouette_image()

            # predicted normals act as fixed targets (no gradient through the network)
            with torch.no_grad():
                in_tensor["normal_F"], in_tensor["normal_B"] = model.netG.normal_filter(
                    in_tensor
                )

            # absolute differences are only used for visualisation below
            diff_F_smpl = torch.abs(
                in_tensor["T_normal_F"] - in_tensor["normal_F"])
            diff_B_smpl = torch.abs(
                in_tensor["T_normal_B"] - in_tensor["normal_B"])

            loss_F_smpl = normal_loss(
                in_tensor["T_normal_F"], in_tensor["normal_F"])
            loss_B_smpl = normal_loss(
                in_tensor["T_normal_B"], in_tensor["normal_B"])

            losses["normal"]["value"] = (loss_F_smpl + loss_B_smpl).mean()

            # silhouette loss
            smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)[0]
            gt_arr = torch.cat(
                [in_tensor["normal_F"][0], in_tensor["normal_B"][0]], dim=2
            ).permute(1, 2, 0)
            gt_arr = ((gt_arr + 1.0) * 0.5).to(device)
            # pixels whose channels sum to the same value as (0.5, 0.5, 0.5)
            # are treated as background
            bg_color = (
                torch.Tensor([0.5, 0.5, 0.5]).unsqueeze(
                    0).unsqueeze(0).to(device)
            )
            gt_arr = ((gt_arr - bg_color).sum(dim=-1) != 0.0).float()
            diff_S = torch.abs(smpl_arr - gt_arr)
            losses["silhouette"]["value"] = diff_S.mean()

            # Weighted sum of the losses
            smpl_loss = 0.0
            pbar_desc = "Body Fitting --- "
            for k in ["normal", "silhouette"]:
                pbar_desc += f"{k}: {losses[k]['value'] * losses[k]['weight']:.3f} | "
                smpl_loss += losses[k]["value"] * losses[k]["weight"]
            pbar_desc += f"Total: {smpl_loss:.3f}"
            loop_smpl.set_description(pbar_desc)

            if i % args.vis_freq == 0:

                # NOTE(review): the 512 split assumes each of the front/back
                # halves of diff_S is 512 pixels wide -- TODO confirm.
                per_loop_lst.extend(
                    [
                        in_tensor["image"],
                        in_tensor["T_normal_F"],
                        in_tensor["normal_F"],
                        diff_F_smpl / 2.0,
                        diff_S[:, :512].unsqueeze(
                            0).unsqueeze(0).repeat(1, 3, 1, 1),
                    ]
                )
                per_loop_lst.extend(
                    [
                        in_tensor["image"],
                        in_tensor["T_normal_B"],
                        in_tensor["normal_B"],
                        diff_B_smpl / 2.0,
                        diff_S[:, 512:].unsqueeze(
                            0).unsqueeze(0).repeat(1, 3, 1, 1),
                    ]
                )
                per_data_lst.append(
                    get_optim_grid_image(
                        per_loop_lst, None, nrow=5, type="smpl")
                )

            smpl_loss.backward()
            optimizer_smpl.step()
            scheduler_smpl.step(smpl_loss)
            # vertices computed before this step's update, with z flipped
            in_tensor["smpl_verts"] = smpl_verts * \
                torch.tensor([1.0, 1.0, -1.0]).to(device)

        # visualize the optimization process
        # 1. SMPL Fitting
        # 2. Clothes Refinement

        os.makedirs(os.path.join(args.out_dir, cfg.name,
                                 "refinement"), exist_ok=True)

        # visualize the final results in self-rotation mode
        os.makedirs(os.path.join(args.out_dir, cfg.name, "vid"), exist_ok=True)

        # final results rendered as image
        # 1. Render the final fitted SMPL (xxx_smpl.png)
        # 2. Render the final reconstructed clothed human (xxx_cloth.png)
        # 3. Blend the original image with predicted cloth normal (xxx_overlap.png)

        os.makedirs(os.path.join(args.out_dir, cfg.name, "png"), exist_ok=True)

        # final reconstruction meshes
        # 1. SMPL mesh (xxx_smpl.obj)
        # 2. SMPL params (xxx_smpl.npy)
        # 3. clothed mesh (xxx_recon.obj)
        # 4. remeshed clothed mesh (xxx_remesh.obj)
        # 5. refined clothed mesh (xxx_refine.obj)

        os.makedirs(os.path.join(args.out_dir, cfg.name, "obj"), exist_ok=True)

        if cfg.net.prior_type != "pifu":

            # NOTE(review): indexing per_data_lst[0] assumes the fitting loop
            # ran at least once (iteration 0 always records a frame).
            per_data_lst[0].save(
                os.path.join(
                    args.out_dir, cfg.name, f"refinement/{data['name']}_smpl.gif"
                ),
                save_all=True,
                append_images=per_data_lst[1:],
                duration=500,
                loop=0,
            )

            # a video is only written when every iteration was recorded
            if args.vis_freq == 1:
                image2vid(
                    per_data_lst,
                    os.path.join(
                        args.out_dir, cfg.name, f"refinement/{data['name']}_smpl.avi"
                    ),
                )

            per_data_lst[-1].save(
                os.path.join(args.out_dir, cfg.name,
                             f"png/{data['name']}_smpl.png")
            )

        # map the predicted front normal from [-1, 1] to uint8 [0, 255]
        norm_pred = (
            ((in_tensor["normal_F"][0].permute(1, 2, 0) + 1.0) * 255.0 / 2.0)
            .detach()
            .cpu()
            .numpy()
            .astype(np.uint8)
        )

        norm_orig = unwrap(norm_pred, data)
        mask_orig = unwrap(
            np.repeat(
                data["mask"].permute(1, 2, 0).detach().cpu().numpy(), 3, axis=2
            ).astype(np.uint8),
            data,
        )
        rgb_norm = blend_rgb_norm(data["ori_image"], norm_orig, mask_orig)

        # original image and blended normal image side by side
        Image.fromarray(
            np.concatenate(
                [data["ori_image"].astype(np.uint8), rgb_norm], axis=1)
        ).save(os.path.join(args.out_dir, cfg.name, f"png/{data['name']}_overlap.png"))

        smpl_obj = trimesh.Trimesh(
            in_tensor["smpl_verts"].detach().cpu()[0] *
            torch.tensor([1.0, -1.0, 1.0]),
            in_tensor['smpl_faces'].detach().cpu()[0],
            process=False,
            maintains_order=True
        )
        smpl_obj.export(
            f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl.obj")

        # NOTE(review): the tensors are stored as-is (still on `device` and
        # with requires_grad=True) inside a pickled dict; loading the file
        # therefore needs allow_pickle=True and a compatible torch setup.
        smpl_info = {'betas': optimed_betas,
                     'pose': optimed_pose,
                     'orient': optimed_orient,
                     'trans': optimed_trans}

        np.save(
            f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl.npy", smpl_info, allow_pickle=True)

        # ------------------------------------------------------------------------------------------------------------------

        # cloth optimization

        per_data_lst = []

        # cloth recon
        in_tensor.update(
            dataset.compute_vis_cmap(
                in_tensor["smpl_verts"][0], in_tensor["smpl_faces"][0]
            )
        )

        if cfg.net.prior_type == "pamir":
            in_tensor.update(
                dataset.compute_voxel_verts(
                    optimed_pose,
                    optimed_orient,
                    optimed_betas,
                    optimed_trans,
                    data["scale"],
                )
            )

        with torch.no_grad():
            verts_pr, faces_pr, _ = model.test_single(in_tensor)

        recon_obj = trimesh.Trimesh(
            verts_pr, faces_pr, process=False, maintains_order=True
        )
        recon_obj.export(
            os.path.join(args.out_dir, cfg.name,
                         f"obj/{data['name']}_recon.obj")
        )

        # Isotropic Explicit Remeshing for better geometry topology
        verts_refine, faces_refine = remesh(os.path.join(args.out_dir, cfg.name,
                                                         f"obj/{data['name']}_recon.obj"), 0.5, device)

        # define local_affine deform verts
        mesh_pr = Meshes(verts_refine, faces_refine).to(device)
        local_affine_model = LocalAffine(
            mesh_pr.verts_padded().shape[1], mesh_pr.verts_padded().shape[0], mesh_pr.edges_packed()).to(device)
        optimizer_cloth = torch.optim.Adam(
            [{'params': local_affine_model.parameters()}], lr=1e-4, amsgrad=True)

        scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_cloth,
            mode="min",
            factor=0.1,
            verbose=0,
            min_lr=1e-5,
            patience=args.patience,
        )

        # first frame: views of the reconstruction before any refinement
        with torch.no_grad():
            per_loop_lst = []
            rotate_recon_lst = dataset.render.get_rgb_image(cam_ids=[
                0, 1, 2, 3])
            per_loop_lst.extend(rotate_recon_lst)
            per_data_lst.append(get_optim_grid_image(
                per_loop_lst, None, type="cloth"))

        # stays None when cloth refinement is disabled (--loop_cloth 0)
        final = None

        if args.loop_cloth > 0:

            loop_cloth = tqdm(range(args.loop_cloth))

            for i in loop_cloth:

                per_loop_lst = []

                optimizer_cloth.zero_grad()

                deformed_verts, stiffness, rigid = local_affine_model(
                    verts_refine.to(device), return_stiff=True)
                mesh_pr = mesh_pr.update_padded(deformed_verts)

                # losses for laplacian, edge, normal consistency
                update_mesh_shape_prior_losses(mesh_pr, losses)

                in_tensor["P_normal_F"], in_tensor["P_normal_B"] = dataset.render_normal(
                    mesh_pr.verts_padded(), mesh_pr.faces_padded())

                diff_F_cloth = torch.abs(
                    in_tensor["P_normal_F"] - in_tensor["normal_F"])
                diff_B_cloth = torch.abs(
                    in_tensor["P_normal_B"] - in_tensor["normal_B"])

                losses["cloth"]["value"] = (diff_F_cloth + diff_B_cloth).mean()
                losses["stiffness"]["value"] = torch.mean(stiffness)
                losses["rigid"]["value"] = torch.mean(rigid)

                # Weighted sum of the losses
                cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
                pbar_desc = "Cloth Refinement --- "

                # body-fitting terms and zero-weight terms are skipped
                for k in losses.keys():
                    if k not in ["normal", "silhouette"] and losses[k]["weight"] > 0.0:
                        cloth_loss = cloth_loss + \
                            losses[k]["value"] * losses[k]["weight"]
                        pbar_desc += f"{k}:{losses[k]['value']* losses[k]['weight']:.5f} | "

                pbar_desc += f"Total: {cloth_loss:.5f}"
                loop_cloth.set_description(pbar_desc)

                # update params
                # NOTE(review): retain_graph=True keeps the graph alive after
                # backward; confirm it is actually required here.
                cloth_loss.backward(retain_graph=True)
                optimizer_cloth.step()
                scheduler_cloth.step(cloth_loss)

                # for vis
                with torch.no_grad():
                    if i % args.vis_freq == 0:

                        rotate_recon_lst = dataset.render.get_rgb_image(cam_ids=[
                            0, 1, 2, 3])

                        per_loop_lst.extend(
                            [
                                in_tensor["image"],
                                in_tensor["P_normal_F"],
                                in_tensor["normal_F"],
                                diff_F_cloth / 2.0,
                            ]
                        )
                        per_loop_lst.extend(
                            [
                                in_tensor["image"],
                                in_tensor["P_normal_B"],
                                in_tensor["normal_B"],
                                diff_B_cloth / 2.0,
                            ]
                        )
                        per_loop_lst.extend(rotate_recon_lst)
                        per_data_lst.append(
                            get_optim_grid_image(
                                per_loop_lst, None, type="cloth")
                        )

            # gif for optimization
            # frame 0 is the pre-refinement render, so the gif starts at frame 1
            per_data_lst[1].save(
                os.path.join(
                    args.out_dir, cfg.name, f"refinement/{data['name']}_cloth.gif"
                ),
                save_all=True,
                append_images=per_data_lst[2:],
                duration=500,
                loop=0,
            )

            if args.vis_freq == 1:
                image2vid(
                    per_data_lst,
                    os.path.join(
                        args.out_dir, cfg.name, f"refinement/{data['name']}_cloth.avi"
                    ),
                )

            final = trimesh.Trimesh(
                mesh_pr.verts_packed().detach().squeeze(0).cpu(),
                mesh_pr.faces_packed().detach().squeeze(0).cpu(),
                process=False, maintains_order=True
            )
            final_colors = query_color(
                mesh_pr.verts_packed().detach().squeeze(0).cpu(),
                mesh_pr.faces_packed().detach().squeeze(0).cpu(),
                in_tensor["image"],
                device=device,
            )
            final.visual.vertex_colors = final_colors
            final.export(
                f"{args.out_dir}/{cfg.name}/obj/{data['name']}_refine.obj")

        # always export visualized png regardless of the cloth refinement
        per_data_lst[-1].save(
            os.path.join(args.out_dir, cfg.name,
                         f"png/{data['name']}_cloth.png")
        )

        # always export visualized video regardless of the cloth refinement
        if args.export_video:
            if final is not None:
                verts_lst = [verts_pr, final.vertices]
                faces_lst = [faces_pr, final.faces]
            else:
                verts_lst = [verts_pr]
                faces_lst = [faces_pr]

            # self-rotated video
            dataset.render.load_meshes(
                verts_lst, faces_lst)
            dataset.render.get_rendered_video(
                [data["ori_image"], rgb_norm],
                os.path.join(args.out_dir, cfg.name,
                             f"vid/{data['name']}_cloth.mp4"),
            )

        # garment extraction from deepfashion images
        if not (args.seg_dir is None):
            # prefer the refined mesh when refinement was run
            if final is not None:
                recon_obj = final.copy()

            os.makedirs(os.path.join(
                args.out_dir, cfg.name, "clothes"), exist_ok=True)
            os.makedirs(os.path.join(args.out_dir, cfg.name,
                                     "clothes", "info"), exist_ok=True)
            for seg in data['segmentations']:
                # These matrices work for PyMaf, not sure about the other hps type
                K = np.array([[1.0000, 0.0000, 0.0000, 0.0000],
                              [0.0000, 1.0000, 0.0000, 0.0000],
                              [0.0000, 0.0000, -0.5000, 0.0000],
                              [-0.0000, -0.0000, 0.5000, 1.0000]]).T

                R = np.array([[-1., 0., 0.],
                              [0., 1., 0.],
                              [0., 0., -1.]])

                t = np.array([[-0., -0., 100.]])
                clothing_obj = extract_cloth(recon_obj, seg, K, R, t, smpl_obj)
                if clothing_obj is not None:
                    cloth_type = seg['type'].replace(' ', '_')
                    cloth_info = {
                        'betas': optimed_betas,
                        'body_pose': optimed_pose,
                        'global_orient': optimed_orient,
                        'pose2rot': False,
                        'clothing_type': cloth_type,
                    }

                    file_id = f"{data['name']}_{cloth_type}"
                    with open(os.path.join(args.out_dir, cfg.name, "clothes", "info", f"{file_id}_info.pkl"), 'wb') as fp:
                        pickle.dump(cloth_info, fp)

                    clothing_obj.export(os.path.join(
                        args.out_dir, cfg.name, "clothes", f"{file_id}.obj"))
                else:
                    print(
                        f"Unable to extract clothing of type {seg['type']} from image {data['name']}")
assets/garment_teaser.png ADDED
assets/intermediate_results.png ADDED
assets/teaser.gif ADDED
configs/icon-filter.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: icon-filter
2
+ ckpt_dir: "./data/ckpt/"
3
+ resume_path: "./data/ckpt/icon-filter.ckpt"
4
+ normal_path: "./data/ckpt/normal.ckpt"
5
+
6
+ test_mode: True
7
+ batch_size: 1
8
+
9
+ net:
10
+ mlp_dim: [256, 512, 256, 128, 1]
11
+ res_layers: [2,3,4]
12
+ num_stack: 2
13
+ prior_type: "icon" # icon/pamir/pifu
14
+ use_filter: True
15
+ in_geo: (('normal_F',3), ('normal_B',3))
16
+ in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
17
+ smpl_feats: ['sdf', 'norm', 'vis', 'cmap']
18
+ gtype: 'HGPIFuNet'
19
+ norm_mlp: 'batch'
20
+ hourglass_dim: 6
21
+ smpl_dim: 7
22
+
23
+ # user defined
24
+ mcube_res: 512 # occupancy field resolution, higher --> more details
25
+ clean_mesh: False # if True, will remove floating pieces
configs/icon-nofilter.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: icon-nofilter
2
+ ckpt_dir: "./data/ckpt/"
3
+ resume_path: "./data/ckpt/icon-nofilter.ckpt"
4
+ normal_path: "./data/ckpt/normal.ckpt"
5
+
6
+ test_mode: True
7
+ batch_size: 1
8
+
9
+ net:
10
+ mlp_dim: [256, 512, 256, 128, 1]
11
+ res_layers: [2,3,4]
12
+ num_stack: 2
13
+ prior_type: "icon" # icon/pamir/pifu
14
+ use_filter: False
15
+ in_geo: (('normal_F',3), ('normal_B',3))
16
+ in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
17
+ smpl_feats: ['sdf', 'norm', 'vis', 'cmap']
18
+ gtype: 'HGPIFuNet'
19
+ norm_mlp: 'batch'
20
+ hourglass_dim: 6
21
+ smpl_dim: 7
22
+
23
+ # user defined
24
+ mcube_res: 512 # occupancy field resolution, higher --> more details
25
+ clean_mesh: False # if True, will remove floating pieces
configs/pamir.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: pamir
2
+ ckpt_dir: "./data/ckpt/"
3
+ resume_path: "./data/ckpt/pamir.ckpt"
4
+ normal_path: "./data/ckpt/normal.ckpt"
5
+
6
+ test_mode: True
7
+ batch_size: 1
8
+
9
+ net:
10
+ mlp_dim: [256, 512, 256, 128, 1]
11
+ res_layers: [2,3,4]
12
+ num_stack: 2
13
+ prior_type: "pamir" # icon/pamir/pifu
14
+ use_filter: True
15
+ in_geo: (('image',3), ('normal_F',3), ('normal_B',3))
16
+ in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
17
+ gtype: 'HGPIFuNet'
18
+ norm_mlp: 'batch'
19
+ hourglass_dim: 6
20
+ voxel_dim: 7
21
+
22
+ # user defined
23
+ mcube_res: 512 # occupancy field resolution, higher --> more details
24
+ clean_mesh: False # if True, will remove floating pieces
configs/pifu.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: pifu
2
+ ckpt_dir: "./data/ckpt/"
3
+ resume_path: "./data/ckpt/pifu.ckpt"
4
+ normal_path: "./data/ckpt/normal.ckpt"
5
+
6
+ test_mode: True
7
+ batch_size: 1
8
+
9
+ net:
10
+ mlp_dim: [256, 512, 256, 128, 1]
11
+ res_layers: [2,3,4]
12
+ num_stack: 2
13
+ prior_type: "pifu" # icon/pamir/pifu
14
+ use_filter: True
15
+ in_geo: (('image',3), ('normal_F',3), ('normal_B',3))
16
+ in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
17
+ gtype: 'HGPIFuNet'
18
+ norm_mlp: 'batch'
19
+ hourglass_dim: 12
20
+
21
+
22
+ # user defined
23
+ mcube_res: 512 # occupancy field resolution, higher --> more details
24
+ clean_mesh: False # if True, will remove floating pieces
environment.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: icon
2
+ channels:
3
+ - pytorch-lts
4
+ - nvidia
5
+ - conda-forge
6
+ - fvcore
7
+ - iopath
8
+ - bottler
9
+ - defaults
10
+ dependencies:
11
+ - pytorch
12
+ - torchvision
13
+ - fvcore
14
+ - iopath
15
+ - nvidiacub
16
+ - pyembree
examples/22097467bffc92d4a5c4246f7d4edb75.png ADDED
examples/44c0f84c957b6b9bdf77662af5bb7078.png ADDED
examples/5a6a25963db2f667441d5076972c207c.png ADDED
examples/8da7ceb94669c2f65cbd28022e1f9876.png ADDED
examples/923d65f767c85a42212cae13fba3750b.png ADDED
examples/959c4c726a69901ce71b93a9242ed900.png ADDED
examples/c9856a2bc31846d684cbb965457fad59.png ADDED
examples/e1e7622af7074a022f5d96dc16672517.png ADDED
examples/fb9d20fdb93750584390599478ecf86e.png ADDED
examples/segmentation/003883.jpg ADDED
examples/segmentation/003883.json ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "item2": {
3
+ "segmentation": [
4
+ [
5
+ 232.29572649572654, 34.447388414055126, 237.0364672364673,
6
+ 40.57084520417861, 244.9377018043686, 47.089363722697165,
7
+ 252.04881291547974, 49.65726495726508, 262.5179487179489,
8
+ 51.43504273504287, 269.233998100665, 50.447388414055204,
9
+ 277.5446343779678, 49.12725546058881, 285.64339981006657,
10
+ 46.16429249762584, 294.9273504273506, 41.22602089268754,
11
+ 299.9377967711301, 36.514245014245084, 304.67853751187084,
12
+ 30.588319088319132, 306.0612535612536, 25.65004748338083,
13
+ 307.64150047483383, 23.477207977207982, 311.19705603038943,
14
+ 24.859924026590704, 317.12298195631536, 28.020417853751216,
15
+ 323.04890788224134, 29.008072174738874, 331.34520417853764,
16
+ 30.193257359924065, 339.4439696106365, 34.7364672364673,
17
+ 346.75261158594515, 39.279677113010536, 350.11063627730323,
18
+ 44.61301044634389, 355.00541310541314, 61.422317188983875,
19
+ 358.9560303893638, 77.6198480531815, 362.1165242165243,
20
+ 90.26182336182353, 364.88195631528976, 103.29886039886063,
21
+ 367.6473884140552, 118.11367521367552, 369.42516619183294,
22
+ 129.37293447293484, 369.2324786324788, 132.60550807217476,
23
+ 365.6769230769232, 134.77834757834762, 359.15840455840464,
24
+ 138.3339031339032, 353.43000949667623, 140.70427350427357,
25
+ 351.4547008547009, 141.4943969610637, 351.25716999050337,
26
+ 138.5314339981007, 351.05963912630585, 136.75365622032294,
27
+ 345.7263057929725, 137.34624881291552, 337.8250712250712,
28
+ 139.51908831908838, 331.5040835707502, 141.09933523266864,
29
+ 324.7880341880341, 143.66723646723653, 322.2201329534662,
30
+ 146.43266856600198, 322.2201329534662, 151.5684710351378,
31
+ 323.0102564102563, 160.6548907882243, 324.95185185185176,
32
+ 173.44615384615395, 325.34691358024685, 190.23627730294416,
33
+ 325.93950617283946, 205.64368471035164, 325.93950617283946,
34
+ 215.71775878442577, 325.93950617283946, 220.06343779677147,
35
+ 322.7790123456789, 223.22393162393197, 315.0753086419752,
36
+ 228.55726495726532, 309.34691358024673, 230.53257359924066,
37
+ 290.1866096866098, 230.87929724596398, 263.91500474833805,
38
+ 229.6941120607788, 236.45821462488112, 229.29905033238373,
39
+ 218.48290598290572, 226.73114909781583, 202.65650522317188,
40
+ 224.82811016144353, 197.71823361823357, 221.07502374169044,
41
+ 195.15033238366567, 214.55650522317188, 195.74292497625825,
42
+ 200.53181386514711, 197.125641025641, 180.5811965811964,
43
+ 197.33285849952523, 164.68736942070285, 198.51804368471042,
44
+ 154.21823361823365, 198.51804368471042, 138.61329534662863,
45
+ 193.5797720797721, 136.4404558404558, 185.08594491927823,
46
+ 133.08243114909774, 177.77730294396957, 128.73675213675205,
47
+ 174.41927825261152, 128.53922127255453, 173.82668566001894,
48
+ 133.2799620132953, 174.02421652421646, 136.24292497625825,
49
+ 172.83903133903127, 137.03304843304838, 167.11063627730283,
50
+ 134.86020892687554, 159.9995251661917, 130.51452991452985,
51
+ 159.01187084520404, 129.1318138651471, 159.60446343779662,
52
+ 123.60094966761622, 162.6012345679013, 111.57578347578357,
53
+ 165.95925925925934, 98.53874643874646, 170.30493827160504,
54
+ 82.7362773029439, 173.92307692307693, 70.05584045584048,
55
+ 177.08357075023744, 54.84596391263053, 180.58129154795822,
56
+ 41.73190883190885, 183.14919278252614, 34.423266856600165,
57
+ 188.51623931623936, 30.279962013295354, 195.6273504273505,
58
+ 25.539221272554588, 201.75080721747398, 22.971320037986676,
59
+ 211.23228869895553, 22.37872744539408, 221.10883190883212,
60
+ 20.996011396011355, 224.8619183285852, 20.996011396011355,
61
+ 226.04710351377042, 23.56391263057927, 229.01006647673339,
62
+ 30.279962013295354
63
+ ]
64
+ ],
65
+ "category_id": 1,
66
+ "category_name": "short sleeve top"
67
+ },
68
+ "item1": {
69
+ "segmentation": [
70
+ [
71
+ 201.51804815682925, 224.7401022799914, 218.41555508203712,
72
+ 227.23317707223518, 236.42109524824218, 228.89522693373104,
73
+ 256.91971020669104, 229.44924355422967, 280.188408267633,
74
+ 230.2802684849776, 296.53189857234224, 230.2802684849776,
75
+ 313.7064138077994, 229.72625186447897, 315.32667803111013,
76
+ 236.8076070743661, 317.8197528233539, 240.96273172810572,
77
+ 318.65077775410185, 246.2258896228426, 321.4208608565949,
78
+ 253.15109737907534, 322.8059024078415, 265.0624547197956,
79
+ 324.74496057958663, 273.6497123375242, 325.9612827615598,
80
+ 284.4076070743661, 325.40726614106114, 299.9200724483274,
81
+ 324.29923290006394, 316.8175793735353, 322.0831664180694,
82
+ 325.9588536117625, 320.16803750266354, 336.5366716386107,
83
+ 316.0129128489239, 344.01589601534204, 315.18188791817596,
84
+ 357.86631152780745, 312.4118048156829, 368.1156190070319,
85
+ 308.5336884721926, 378.64193479650567, 306.31762199019806,
86
+ 385.29013424248905, 305.76360536969946, 398.3095248242066,
87
+ 305.48659705945016, 409.6668655444283, 304.94393777967184,
88
+ 419.3418708715109, 302.7278712976774, 427.0981035584915,
89
+ 301.3428297464308, 433.74630300447495, 301.3428297464308,
90
+ 445.3806520349459, 300.5118048156829, 461.72414233965515,
91
+ 299.89735776688684, 467.352311953974, 297.9582995951417,
92
+ 477.60161943319844, 295.1882164926486, 491.7290432559132,
93
+ 293.52616663115276, 497.2692094608994, 291.8641167696569,
94
+ 503.36339228638417, 291.3101001491583, 510.8426166631155,
95
+ 289.37104197741314, 513.8897080758579, 287.4433411463882,
96
+ 519.2043682079693, 283.0112081823993, 519.7583848284679,
97
+ 275.5319838056679, 519.4813765182186, 270.26882591093107,
98
+ 518.096334966972, 265.8366929469421, 513.6642020029831,
99
+ 263.62062646494763, 509.78608565949276, 264.7286597059449,
100
+ 498.9827615597697, 265.2826763264435, 478.76115491157015,
101
+ 266.1137012571914, 467.1268058810992, 266.1137012571914,
102
+ 454.6614319198803, 264.17464308544623, 441.64204133816276,
103
+ 263.06660984444903, 424.19051779245626, 261.5834221180482,
104
+ 407.2581504368212, 259.92137225655233, 396.45482633709815,
105
+ 257.1512891540592, 380.1113360323889, 257.42829746430857,
106
+ 359.05870445344146, 256.8742808438099, 338.56008949499255,
107
+ 256.8742808438099, 321.3855742595354, 254.10419774131685,
108
+ 320.5545493287875, 251.05710632857443, 326.6487321542723,
109
+ 249.39505646707858, 339.1141061154912, 249.11804815682927,
110
+ 356.28862135094835, 248.28702322608135, 372.3551033454083,
111
+ 245.23993181333896, 387.59056040912026, 243.5766673769444,
112
+ 409.1404219049649, 241.91461751544855, 424.92989558917554,
113
+ 240.52957596420202, 440.4423609631369, 238.86752610270617,
114
+ 455.40080971659955, 238.86752610270617, 470.91327509056083,
115
+ 238.31350948220754, 486.42574046452216, 238.81966759002768,
116
+ 501.19639889196685, 239.6506925207756, 511.168698060942,
117
+ 236.0495844875346, 515.6008310249309, 229.40138504155118,
118
+ 519.4789473684212, 221.6451523545705, 520.3099722991692,
119
+ 216.65900277008296, 517.2628808864267, 213.33490304709125,
120
+ 509.50664819944615, 208.3487534626037, 491.50110803324105,
121
+ 205.8556786703599, 475.1576177285318, 203.63961218836545,
122
+ 460.75318559556774, 203.63961218836545, 443.3016620498613,
123
+ 203.63961218836545, 421.9720221606645, 200.59252077562303,
124
+ 415.60083102493036, 197.5052844662264, 406.9847858512679,
125
+ 195.28921798423193, 392.0263370978052, 193.35015981248677,
126
+ 370.97370551885774, 190.857085020243, 343.82689111442545,
127
+ 187.8099936075006, 322.77425953547794, 187.0028979330919,
128
+ 309.89237161730256, 186.17187300234397, 291.33281483059886,
129
+ 188.11093117408916, 266.67907521841033, 191.15802258683155,
130
+ 250.3355849137011, 196.69818879181773, 234.82311953973982
131
+ ]
132
+ ],
133
+ "category_id": 8,
134
+ "category_name": "trousers"
135
+ }
136
+ }
examples/segmentation/028009.jpg ADDED
examples/segmentation/028009.json ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "item2": {
3
+ "segmentation": [
4
+ [
5
+ 314.7474747474744, 204.84848484848482, 328.9696969696967,
6
+ 209.7373737373737, 342.74747474747454, 211.95959595959593,
7
+ 360.0808080808079, 211.07070707070704, 375.19191919191906,
8
+ 210.18181818181816, 384.5252525252524, 207.07070707070704,
9
+ 390.30303030303025, 204.84848484848482, 396.080808080808,
10
+ 201.29292929292924, 402.3030303030303, 204.40404040404036,
11
+ 412.969696969697, 203.9595959595959, 425.8585858585859,
12
+ 206.18181818181813, 434.3030303030304, 211.95959595959593,
13
+ 439.63636363636374, 223.0707070707071, 444.0808080808082,
14
+ 234.18181818181824, 448.52525252525265, 250.62626262626276,
15
+ 449.41414141414157, 260.848484848485, 452.08080808080825,
16
+ 279.0707070707073, 456.08080808080825, 300.84848484848516,
17
+ 457.858585858586, 308.40404040404076, 460.5252525252526,
18
+ 315.7575757575756, 460.96969696969705, 329.97979797979787,
19
+ 460.5252525252526, 345.9797979797979, 456.969696969697,
20
+ 363.75757575757575, 453.41414141414145, 373.5353535353536,
21
+ 450.3030303030303, 385.97979797979804, 447.1919191919192,
22
+ 393.9797979797981, 443.6363636363636, 401.9797979797981,
23
+ 438.3030303030303, 403.7575757575759, 433.85858585858585,
24
+ 401.09090909090924, 430.7474747474747, 393.0909090909092,
25
+ 426.7474747474747, 383.3131313131314, 424.9696969696969,
26
+ 374.8686868686869, 424.9696969696969, 369.0909090909091,
27
+ 423.63636363636357, 363.3131313131313, 423.63636363636357,
28
+ 359.3131313131313, 423.63636363636357, 352.6464646464646,
29
+ 420.9696969696969, 350.86868686868684, 422.74747474747466,
30
+ 345.53535353535347, 422.74747474747466, 340.64646464646455,
31
+ 422.74747474747466, 332.2020202020201, 421.8585858585858,
32
+ 321.53535353535335, 418.74747474747466, 313.0909090909089,
33
+ 416.5252525252524, 306.4242424242422, 412.9696969696969,
34
+ 314.8686868686867, 410.3030303030302, 320.20202020202004,
35
+ 411.6363636363635, 327.3131313131312, 414.74747474747466,
36
+ 336.2020202020201, 418.74747474747466, 351.7575757575757,
37
+ 420.9696969696969, 365.0909090909091, 423.1919191919191,
38
+ 377.0909090909091, 423.1919191919191, 385.0909090909092,
39
+ 424.5252525252525, 398.42424242424255, 396.0808080808079,
40
+ 398.42424242424255, 374.7474747474745, 400.6464646464648,
41
+ 354.7474747474744, 400.6464646464648, 331.6363636363632,
42
+ 400.6464646464648, 313.41414141414094, 400.6464646464648,
43
+ 305.4141414141409, 399.3131313131314, 297.4141414141409,
44
+ 396.6464646464648, 284.525252525252, 396.2020202020203,
45
+ 282.8686868686866, 391.59595959595964, 282.42424242424215,
46
+ 373.81818181818176, 282.42424242424215, 358.26262626262616,
47
+ 281.09090909090884, 334.70707070707056, 281.5353535353533,
48
+ 313.37373737373713, 283.31313131313107, 297.3737373737371,
49
+ 282.8686868686866, 283.1515151515148, 280.6464646464644,
50
+ 266.7070707070703, 271.313131313131, 253.3737373737369,
51
+ 264.6464646464643, 246.70707070707022, 257.5353535353532,
52
+ 239.59595959595907, 249.9797979797976, 228.9292929292924,
53
+ 242.42424242424204, 220.92929292929236, 233.17171717171723,
54
+ 209.01010101010093, 225.1717171717172, 194.78787878787867,
55
+ 222.06060606060606, 185.4545454545453, 224.2828282828283,
56
+ 179.6767676767675, 230.0606060606061, 171.67676767676747,
57
+ 232.72727272727278, 169.89898989898967, 243.83838383838392,
58
+ 167.67676767676744, 256.2828282828284, 165.4545454545452,
59
+ 274.06060606060623, 165.4545454545452, 291.8383838383841,
60
+ 167.67676767676744, 302.5050505050508, 168.1212121212119,
61
+ 310.94949494949526, 177.0101010101008, 314.0606060606064,
62
+ 181.45454545454527, 314.94949494949526, 187.2323232323231,
63
+ 312.7272727272731, 193.01010101010087, 307.8383838383842,
64
+ 191.2323232323231, 302.94949494949526, 193.45454545454533,
65
+ 292.727272727273, 193.45454545454533, 290.50505050505075,
66
+ 195.67676767676755, 287.39393939393966, 197.45454545454533,
67
+ 285.61616161616183, 197.45454545454533, 283.3939393939396,
68
+ 193.89898989898978, 278.94949494949515, 197.45454545454533,
69
+ 274.94949494949515, 199.67676767676755, 279.83838383838406,
70
+ 201.45454545454535, 286.50505050505075, 201.45454545454535,
71
+ 291.8383838383841, 201.8989898989898, 296.2828282828286,
72
+ 202.7878787878787, 303.3939393939397, 202.34343434343424
73
+ ]
74
+ ],
75
+ "category_id": 2,
76
+ "category_name": "long sleeve top"
77
+ },
78
+ "item1": {
79
+ "segmentation": [
80
+ [
81
+ 346.9494949494949, 660.6868686868687, 397.6161616161618,
82
+ 661.5757575757576, 398.06060606060623, 674.0202020202021,
83
+ 398.94949494949515, 691.3535353535356, 397.6161616161618,
84
+ 710.0202020202022, 395.838383838384, 726.0202020202023,
85
+ 393.1717171717173, 742.0202020202023, 346.9494949494949,
86
+ 738.9090909090912, 346.50505050505046, 724.2424242424245,
87
+ 347.3939393939394, 713.5757575757578, 348.72727272727275,
88
+ 706.0202020202022, 349.17171717171715, 686.0202020202022,
89
+ 348.72727272727275, 675.7979797979799, 347.3939393939394,
90
+ 667.7979797979799
91
+ ],
92
+ [
93
+ 283.71717171717165, 396.68686868686876, 289.9393939393939,
94
+ 396.68686868686876, 303.27272727272725, 397.1313131313132,
95
+ 312.16161616161617, 399.7979797979799, 334.3838383838385,
96
+ 400.68686868686876, 351.7171717171719, 400.68686868686876,
97
+ 361.93939393939417, 401.5757575757577, 376.60606060606085,
98
+ 401.5757575757577, 390.82828282828314, 398.46464646464653,
99
+ 410.3838383838388, 397.5757575757577, 425.0505050505055,
100
+ 394.46464646464653, 431.71717171717216, 422.9090909090911,
101
+ 434.38383838383885, 447.79797979798, 430.38383838383885,
102
+ 478.0202020202024, 423.2727272727277, 507.79797979798025,
103
+ 418.3838383838388, 530.0202020202025, 411.8787878787878,
104
+ 557.3333333333333, 403.43434343434336, 590.6666666666666,
105
+ 400.7676767676767, 611.5555555555557, 399.8787878787878,
106
+ 619.1111111111112, 399.8787878787878, 630.6666666666669,
107
+ 398.10101010101, 635.1111111111113, 399.43434343434336,
108
+ 641.7777777777779, 399.43434343434336, 656.4444444444447,
109
+ 398.10101010101, 662.666666666667, 347.4343434343432, 662.666666666667,
110
+ 346.1010101010098, 637.7777777777779, 347.4343434343432,
111
+ 610.6666666666667, 349.21212121212096, 576.4444444444445,
112
+ 350.98989898989873, 556.4444444444443, 349.6565656565654,
113
+ 541.3333333333331, 348.32323232323205, 535.9999999999998,
114
+ 348.32323232323205, 523.5555555555553, 349.21212121212096,
115
+ 505.33333333333303, 342.5454545454543, 511.5555555555553,
116
+ 338.9898989898987, 516.8888888888887, 334.5454545454542,
117
+ 523.5555555555553, 325.6565656565653, 543.111111111111,
118
+ 319.87878787878753, 556.4444444444443, 314.1010101010097,
119
+ 568.4444444444443, 307.8787878787875, 583.1111111111111,
120
+ 300.3232323232319, 608.0000000000001, 298.10101010100965,
121
+ 617.7777777777778, 298.5454545454541, 624.0000000000001,
122
+ 295.43434343434296, 628.0000000000001, 293.2121212121208,
123
+ 628.0000000000001, 293.6565656565652, 632.4444444444446,
124
+ 291.43434343434296, 638.6666666666669, 290.54545454545405,
125
+ 644.4444444444447, 292.3232323232319, 648.8888888888891,
126
+ 303.8787878787875, 667.1111111111114, 313.65656565656525,
127
+ 684.0000000000003, 319.87878787878753, 700.8888888888893,
128
+ 322.54545454545416, 712.8888888888894, 324.323232323232,
129
+ 720.0000000000005, 327.87878787878753, 731.5555555555561,
130
+ 330.9898989898987, 738.6666666666672, 331.87878787878753,
131
+ 743.1111111111117, 334.5454545454542, 745.7777777777783,
132
+ 336.3232323232325, 749.1313131313133, 338.54545454545473,
133
+ 754.0202020202022, 338.54545454545473, 757.5757575757577,
134
+ 341.6565656565658, 760.6868686868688, 344.76767676767696,
135
+ 767.3535353535356, 345.2121212121214, 770.9090909090911,
136
+ 346.9898989898992, 754.0202020202022, 347.43434343434365,
137
+ 738.909090909091, 393.2121212121216, 740.6868686868687,
138
+ 389.65656565656604, 764.6868686868688, 386.5454545454549,
139
+ 784.2424242424245, 384.3232323232327, 806.9090909090912,
140
+ 382.54545454545485, 812.686868686869, 381.13131313131316,
141
+ 818.7070707070708, 378.020202020202, 828.4848484848485,
142
+ 375.35353535353534, 839.5959595959597, 374.9090909090909,
143
+ 854.2626262626264, 373.1313131313131, 856.9292929292931,
144
+ 376.24242424242425, 864.9292929292931, 372.24242424242425,
145
+ 874.2626262626264, 366.4646464646464, 880.9292929292932,
146
+ 357.13131313131305, 872.9292929292932, 345.13131313131305,
147
+ 868.0404040404043, 337.131313131313, 867.1515151515154,
148
+ 337.131313131313, 856.0404040404042, 338.4646464646463,
149
+ 850.7070707070709, 336.2424242424241, 846.2626262626264,
150
+ 335.3535353535352, 841.3737373737375, 338.4646464646463,
151
+ 827.5959595959597, 342.0202020202019, 815.5959595959596,
152
+ 344.6868686868686, 809.3737373737374, 344.6868686868686,
153
+ 796.4848484848484, 344.6868686868686, 786.7070707070707,
154
+ 346.0202020202019, 779.151515151515, 344.24242424242414,
155
+ 776.0404040404039, 343.3535353535352, 786.2626262626262,
156
+ 342.0202020202019, 796.0404040404039, 338.90909090909076,
157
+ 801.8181818181818, 333.57575757575745, 809.3737373737374,
158
+ 326.02020202020185, 813.8181818181819, 320.242424242424,
159
+ 812.4848484848485, 318.02020202020185, 810.7070707070707,
160
+ 317.13131313131294, 807.1515151515151, 315.79797979797956,
161
+ 803.5959595959596, 313.57575757575734, 799.5959595959596,
162
+ 311.3535353535351, 793.8181818181818, 306.90909090909065,
163
+ 791.1515151515151, 305.57575757575734, 787.5959595959595,
164
+ 304.242424242424, 782.7070707070706, 302.02020202020174,
165
+ 776.4848484848484, 298.90909090909065, 773.8181818181816,
166
+ 294.90909090909065, 771.151515151515, 290.34343434343435,
167
+ 758.909090909091, 284.5656565656566, 742.020202020202,
168
+ 278.78787878787875, 729.5757575757575, 270.3434343434343,
169
+ 713.131313131313, 257.8989898989898, 689.1313131313129,
170
+ 247.2323232323231, 669.1313131313128, 239.23232323232307,
171
+ 657.5757575757573, 233.89898989898973, 642.9090909090905,
172
+ 233.0101010101008, 634.0202020202016, 233.45454545454527,
173
+ 630.0202020202016, 235.23232323232304, 611.7979797979793,
174
+ 241.93939393939402, 583.0707070707073, 245.93939393939405,
175
+ 567.5151515151516, 251.2727272727274, 540.4040404040404,
176
+ 256.1616161616163, 518.6262626262626, 260.60606060606074,
177
+ 501.2929292929292, 263.7171717171719, 493.7373737373736,
178
+ 268.16161616161634, 481.73737373737356, 270.38383838383857,
179
+ 469.73737373737356, 272.6060606060608, 462.18181818181796,
180
+ 276.1616161616164, 457.7373737373735, 276.1616161616164,
181
+ 454.1818181818179, 277.05050505050525, 450.1818181818179,
182
+ 278.828282828283, 433.292929292929, 278.3838383838386,
183
+ 419.0707070707067, 278.828282828283, 417.29292929292893,
184
+ 281.0505050505053, 414.1818181818178, 281.93939393939417,
185
+ 404.8484848484844, 283.71717171717194, 401.2929292929289
186
+ ]
187
+ ],
188
+ "category_id": 8,
189
+ "category_name": "trousers"
190
+ }
191
+ }
examples/slack_trial2-000150.png ADDED
fetch_data.sh ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Downloads the SMPL, SMPLify and ICON assets into ./data using the
# credentials of an account registered at https://icon.is.tue.mpg.de/.
# Optionally also downloads the SMPL-X / SMIL models used for training.

# Percent-encode "$1" so it can be embedded in the wget POST body.
# Characters in [a-zA-Z0-9.~-] are passed through, everything else is
# emitted as %XX. Returns 1 when called with an empty argument.
urle () {
    [[ "${1}" ]] || return 1
    local LANG=C i x
    for (( i = 0; i < ${#1}; i++ )); do
        x="${1:i:1}"
        if [[ "${x}" == [a-zA-Z0-9.~-] ]]; then
            echo -n "${x}"
        else
            printf '%%%02X' "'${x}"
        fi
    done
    echo
}

mkdir -p data/smpl_related/models

# username and password input
echo -e "\nYou need to register at https://icon.is.tue.mpg.de/, according to Installation Instruction."
# -r keeps backslashes literal; -s stops the password from being echoed.
read -r -p "Username (ICON):" username
read -r -s -p "Password (ICON):" password
echo
# Quoted so that spaces / glob characters in the credentials survive.
username=$(urle "$username")
password=$(urle "$password")

# SMPL (Male, Female)
echo -e "\nDownloading SMPL..."
wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smpl&sfile=SMPL_python_v.1.0.0.zip&resume=1' -O './data/smpl_related/models/SMPL_python_v.1.0.0.zip' --no-check-certificate --continue
unzip data/smpl_related/models/SMPL_python_v.1.0.0.zip -d data/smpl_related/models
mv data/smpl_related/models/smpl/models/basicModel_f_lbs_10_207_0_v1.0.0.pkl data/smpl_related/models/smpl/SMPL_FEMALE.pkl
mv data/smpl_related/models/smpl/models/basicmodel_m_lbs_10_207_0_v1.0.0.pkl data/smpl_related/models/smpl/SMPL_MALE.pkl
# Abort if the directory is missing: the rm -rf below must never run elsewhere.
cd data/smpl_related/models || exit 1
rm -rf *.zip __MACOSX smpl/models smpl/smpl_webuser
cd ../../.. || exit 1

# SMPL (Neutral, from SMPLIFY)
echo -e "\nDownloading SMPLify..."
wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smplify&sfile=mpips_smplify_public_v2.zip&resume=1' -O './data/smpl_related/models/mpips_smplify_public_v2.zip' --no-check-certificate --continue
unzip data/smpl_related/models/mpips_smplify_public_v2.zip -d data/smpl_related/models
mv data/smpl_related/models/smplify_public/code/models/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl data/smpl_related/models/smpl/SMPL_NEUTRAL.pkl
cd data/smpl_related/models || exit 1
rm -rf *.zip smplify_public
cd ../../.. || exit 1

# ICON
echo -e "\nDownloading ICON..."
wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=icon&sfile=icon_data.zip&resume=1' -O './data/icon_data.zip' --no-check-certificate --continue
cd data || exit 1
unzip icon_data.zip
mv smpl_data smpl_related/
rm -f icon_data.zip
cd .. || exit 1

# Download the optional training-only models.
#   $1: url-encoded username, $2: url-encoded password
function download_for_training () {

    # SMPL-X (optional)
    echo -e "\nDownloading SMPL-X..."
    wget --post-data "username=$1&password=$2" 'https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=models_smplx_v1_1.zip&resume=1' -O './data/smpl_related/models/models_smplx_v1_1.zip' --no-check-certificate --continue
    unzip data/smpl_related/models/models_smplx_v1_1.zip -d data/smpl_related
    rm -f data/smpl_related/models/models_smplx_v1_1.zip

    # SMIL (optional)
    echo -e "\nDownloading SMIL..."
    wget --post-data "username=$1&password=$2" 'https://download.is.tue.mpg.de/download.php?domain=agora&sfile=smpl_kid_template.npy&resume=1' -O './data/smpl_related/models/smpl/smpl_kid_template.npy' --no-check-certificate --continue
    wget --post-data "username=$1&password=$2" 'https://download.is.tue.mpg.de/download.php?domain=agora&sfile=smplx_kid_template.npy&resume=1' -O './data/smpl_related/models/smplx/smplx_kid_template.npy' --no-check-certificate --continue
}


read -r -p "(optional) Download models used for training (y/n)?" choice
case "$choice" in
    y|Y ) download_for_training "$username" "$password";;
    n|N ) echo "Great job! Try the demo for now!";;
    * ) echo "Invalid input! Please use y|Y or n|N";;
esac
install.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # # conda installation
2
+ # wget https://repo.anaconda.com/miniconda/Miniconda3-py38_4.10.3-Linux-x86_64.sh
3
+ # chmod +x Miniconda3-py38_4.10.3-Linux-x86_64.sh
4
+ # bash Miniconda3-py38_4.10.3-Linux-x86_64.sh -b -f -p /home/user/.local
5
+ # rm Miniconda3-py38_4.10.3-Linux-x86_64.sh
6
+ # conda config --env --set always_yes true
7
+ # conda update -n base -c defaults conda -y
8
+
9
+ # # conda environment setup
10
+ # conda env create -f environment.yaml
11
+ # conda init bash
12
+ # source /home/user/.bashrc
13
+ # source activate icon
14
+ nvidia-smi
15
+ pip install torch==1.8.2 torchvision==0.9.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cu111
16
+ pip install -r requirement.txt
lib/__init__.py ADDED
File without changes
lib/common/__init__.py ADDED
File without changes
lib/common/cloth_extraction.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import json
3
+ import os
4
+ import itertools
5
+ import trimesh
6
+ from matplotlib.path import Path
7
+ from collections import Counter
8
+ from sklearn.neighbors import KNeighborsClassifier
9
+
10
+
11
def load_segmentation(path, shape):
    """
    Load the polygons stored in a segmentation json file
    Arguments:
        path: path to the segmentation json file
        shape: not used by this function; kept so existing callers keep working
    Returns:
        Returns a list with one dictionary per top-level entry whose key starts
        with 'item' (all other keys are ignored). Each dictionary holds:
            'type': the entry's 'category_name'
            'type_id': the entry's 'category_id'
            'coordinates': list with one (N, 2) array of (x, y) rows per polygon
    """
    with open(path, encoding='utf-8') as json_file:
        annotations = json.load(json_file)

    segmentations = []
    for key, val in annotations.items():
        if not key.startswith('item'):
            continue

        # Each item can have multiple polygons, each stored flat as
        # [x1, y1, x2, y2, ...]; fold every polygon into rows of (x, y).
        coordinates = [np.asarray(polygon).reshape(-1, 2)
                       for polygon in val['segmentation']]

        segmentations.append({
            'type': val['category_name'],
            'type_id': val['category_id'],
            'coordinates': coordinates,
        })

    return segmentations
43
+
44
+
45
def smpl_to_recon_labels(recon, smpl, k=1):
    """
    Get the bodypart labels for the recon object by using the labels from the corresponding smpl object
    Arguments:
        recon: mesh object exposing .vertices (fully clothed model)
        smpl: mesh object exposing .vertices (smpl model)
        k: number of nearest neighbours to use
    Returns:
        Returns a dictionary containing the bodypart and the corresponding indices
    """
    # The json next to this module maps each bodypart name to the smpl vertex
    # indices that belong to it.
    seg_path = os.path.join(os.path.dirname(__file__), 'smpl_vert_segmentation.json')
    with open(seg_path) as seg_file:
        smpl_vert_segmentation = json.load(seg_file)

    # One label per smpl vertex; object dtype because the labels are strings.
    y = np.array([None] * smpl.vertices.shape[0])
    for key, val in smpl_vert_segmentation.items():
        y[val] = key

    # Transfer the labels to the recon vertices via their k nearest smpl
    # vertices (previously k was silently ignored and 1 was always used).
    classifier = KNeighborsClassifier(n_neighbors=k)
    classifier.fit(smpl.vertices, y)
    y_pred = classifier.predict(recon.vertices)

    return {key: list(np.flatnonzero(y_pred == key).astype(int))
            for key in smpl_vert_segmentation}
73
+
74
+
75
def extract_cloth(recon, segmentation, K, R, t, smpl=None):
    """
    Extract a portion of a mesh using 2d segmentation coordinates
    Arguments:
        recon: fully clothed mesh (its .vertices and .faces are used)
        segmentation: dictionary with the keys
            'coord_normalized': list of 2D polygons (presumably in NDC, as the
                previous docstring stated - confirm against the caller)
            'type_id': garment category id, see
                https://github.com/switchablenorms/DeepFashion2
        K: intrinsic matrix of the projection (only K[:3, :3] is used)
        R: rotation matrix of the projection
        t: translation vector of the projection
        smpl: smpl mesh used to label the bodyparts of recon
    Returns:
        Returns a submesh using the segmentation coordinates, or None when no
        smpl mesh is given or when no face of recon is selected
    """
    # NOTE(review): the nesting of the original body could not be recovered
    # from the diff view; this version returns None without an smpl mesh and
    # None when nothing is selected - confirm this matches what callers expect.
    if smpl is None:
        return None

    vertices = recon.vertices
    faces = recon.faces
    num_verts = vertices.shape[0]

    extrinsic = np.zeros((3, 4))
    extrinsic[:3, :3] = R
    extrinsic[:, 3] = t
    P = K[:3, :3] @ extrinsic
    P_inv = np.linalg.pinv(P)

    # Each segmentation can contain multiple polygons
    # We need to check them separately
    seg_mask = np.zeros(num_verts, dtype=bool)
    for polygon in segmentation['coord_normalized']:
        coords_h = np.hstack((polygon, np.ones((len(polygon), 1))))
        # Apply the inverse projection on homogeneous 2D coordinates to get the corresponding 3d coordinates
        XYZ = coords_h @ P_inv.T
        XYZ = XYZ[:, :3] / XYZ[:, 3, None]

        # A vertex is selected when its (x, y) lies inside the polygon
        seg_mask |= Path(XYZ[:, :2]).contains_points(vertices[:, :2])

    recon_labels = smpl_to_recon_labels(recon, smpl)
    body_parts_to_remove = ['rightHand', 'leftToeBase', 'leftFoot', 'rightFoot', 'head',
                            'leftHandIndex1', 'rightHandIndex1', 'rightToeBase', 'leftHand']
    garment_type = segmentation['type_id']

    # Remove additional bodyparts that are most likely not part of the segmentation but might intersect (e.g. hand in front of torso)
    # https://github.com/switchablenorms/DeepFashion2
    # Short sleeve clothes
    if garment_type in (1, 3, 10):
        body_parts_to_remove += ['leftForeArm', 'rightForeArm']
    # No sleeves at all or lower body clothes
    elif garment_type in (5, 6, 12, 13, 8, 9):
        body_parts_to_remove += ['leftForeArm',
                                 'rightForeArm', 'leftArm', 'rightArm']
    # Shorts
    elif garment_type == 7:
        body_parts_to_remove += ['leftLeg', 'rightLeg',
                                 'leftForeArm', 'rightForeArm', 'leftArm', 'rightArm']

    label_mask = np.zeros(num_verts, dtype=bool)
    for part in body_parts_to_remove:
        label_mask[recon_labels[part]] = True

    # Keep the vertices inside the segmentation that do not belong to a
    # bodypart to remove. The previous `list(seg_mask) and list(label_mask)`
    # was not an element-wise AND; it only produced this result by accident.
    all_indices = np.flatnonzero(seg_mask & ~label_mask)

    # A face is kept as soon as one of its three vertices is kept
    face_mask = np.isin(faces, all_indices).any(axis=1)
    if not face_mask.any():
        return None

    mesh = trimesh.Trimesh(vertices, faces)
    mesh.update_faces(face_mask)
    mesh.remove_unreferenced_vertices()

    return mesh
lib/common/config.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ from yacs.config import CfgNode as CN
19
+ import os
20
+
21
# Root node holding the default configuration. It is created with
# new_allowed=True; the sub-nodes `net` and `dataset` below are created with
# the plain CN() constructor.
_C = CN(new_allowed=True)

# needed by trainer
_C.name = 'default'
_C.gpus = [0]
_C.test_gpus = [1]
_C.root = "./data/"
_C.ckpt_dir = './data/ckpt/'
_C.resume_path = ''
_C.normal_path = ''
_C.corr_path = ''
_C.results_path = './data/results/'
_C.projection_mode = 'orthogonal'
_C.num_views = 1
_C.sdf = False
_C.sdf_clip = 5.0

# optimizer-related defaults (presumably lr_G / lr_C / lr_N are per-network
# learning rates -- TODO confirm against the trainer)
_C.lr_G = 1e-3
_C.lr_C = 1e-3
_C.lr_N = 2e-4
_C.weight_decay = 0.0
_C.momentum = 0.0
_C.optim = 'RMSprop'
_C.schedule = [5, 10, 15]
_C.gamma = 0.1

# run-mode switches and mesh-extraction options
_C.overfit = False
_C.resume = False
_C.test_mode = False
_C.test_uv = False
_C.draw_geo_thres = 0.60
_C.num_sanity_val_steps = 2
_C.fast_dev = 0
_C.get_fit = False
_C.agora = False
_C.optim_cloth = False
_C.optim_body = False
_C.mcube_res = 256
_C.clean_mesh = True
_C.remesh = False

_C.batch_size = 4
_C.num_threads = 8

# epoch count and logging / evaluation frequencies
_C.num_epoch = 10
_C.freq_plot = 0.01
_C.freq_show_train = 0.1
_C.freq_show_val = 0.2
_C.freq_eval = 0.5
_C.accu_grad_batch = 4

_C.test_items = ['sv', 'mv', 'mv-fusion', 'hybrid', 'dc-pred', 'gt']

# ---- network defaults -------------------------------------------------
_C.net = CN()
_C.net.gtype = 'HGPIFuNet'
_C.net.ctype = 'resnet18'
_C.net.classifierIMF = 'MultiSegClassifier'
_C.net.netIMF = 'resnet18'
_C.net.norm = 'group'
_C.net.norm_mlp = 'group'
_C.net.norm_color = 'group'
_C.net.hg_down = 'ave_pool'
_C.net.num_views = 1

# kernel_size, stride, dilation, padding

_C.net.conv1 = [7, 2, 1, 3]
_C.net.conv3x3 = [3, 1, 1, 1]

_C.net.num_stack = 4
_C.net.num_hourglass = 2
_C.net.hourglass_dim = 256
_C.net.voxel_dim = 32
_C.net.resnet_dim = 120
_C.net.mlp_dim = [320, 1024, 512, 256, 128, 1]
_C.net.mlp_dim_knn = [320, 1024, 512, 256, 128, 3]
_C.net.mlp_dim_color = [513, 1024, 512, 256, 128, 3]
_C.net.mlp_dim_multiseg = [1088, 2048, 1024, 500]
_C.net.res_layers = [2, 3, 4]
_C.net.filter_dim = 256
_C.net.smpl_dim = 3

_C.net.cly_dim = 3
_C.net.soft_dim = 64
_C.net.z_size = 200.0
_C.net.N_freqs = 10
_C.net.geo_w = 0.1
_C.net.norm_w = 0.1
_C.net.dc_w = 0.1
_C.net.C_cat_to_G = False

_C.net.skip_hourglass = True
_C.net.use_tanh = True
_C.net.soft_onehot = True
_C.net.no_residual = True
_C.net.use_attention = False

_C.net.prior_type = "sdf"
_C.net.smpl_feats = ['sdf', 'cmap', 'norm', 'vis']
_C.net.use_filter = True
_C.net.use_cc = False
_C.net.use_PE = False
_C.net.use_IGR = False
_C.net.in_geo = ()
_C.net.in_nml = ()

# ---- dataset defaults -------------------------------------------------
_C.dataset = CN()
_C.dataset.root = ''
_C.dataset.set_splits = [0.95, 0.04]
# `types` and `scales` both have five entries; presumably scales[i] applies
# to types[i] -- TODO confirm against the dataset loader.
_C.dataset.types = [
    "3dpeople", "axyz", "renderpeople", "renderpeople_p27", "humanalloy"
]
_C.dataset.scales = [1.0, 100.0, 1.0, 1.0, 100.0 / 39.37]
_C.dataset.rp_type = "pifu900"
_C.dataset.th_type = 'train'
_C.dataset.input_size = 512
_C.dataset.rotation_num = 3
_C.dataset.num_precomp = 10  # Number of segmentation classifiers
_C.dataset.num_multiseg = 500  # Number of categories per classifier
_C.dataset.num_knn = 10  # for loss/error
_C.dataset.num_knn_dis = 20  # for accuracy
_C.dataset.num_verts_max = 20000
_C.dataset.zray_type = False
_C.dataset.online_smpl = False
# `noise_type` and `noise_scale` both have three entries; presumably they
# are matched by position -- TODO confirm.
_C.dataset.noise_type = ['z-trans', 'pose', 'beta']
_C.dataset.noise_scale = [0.0, 0.0, 0.0]
_C.dataset.num_sample_geo = 10000
_C.dataset.num_sample_color = 0
_C.dataset.num_sample_seg = 0
_C.dataset.num_sample_knn = 10000

_C.dataset.sigma_geo = 5.0
_C.dataset.sigma_color = 0.10
_C.dataset.sigma_seg = 0.10
_C.dataset.thickness_threshold = 20.0
_C.dataset.ray_sample_num = 2
_C.dataset.semantic_p = False
_C.dataset.remove_outlier = False

_C.dataset.train_bsize = 1.0
_C.dataset.val_bsize = 1.0
_C.dataset.test_bsize = 1.0
163
+
164
+
165
def get_cfg_defaults():
    """Return a fresh copy of the default configuration.

    The module-level ``_C`` node is cloned, so changes made to the returned
    node do not alter the shared defaults (the "local variable" use pattern).
    """
    defaults = _C.clone()
    return defaults
170
+
171
+
172
# Alternatively, provide a way to import the defaults as
# a global singleton:
# NOTE: `cfg` is bound to the very same object as `_C` (no clone), so any
# in-place merge applied to `_C` is visible through `cfg` too.
cfg = _C  # users can `from config import cfg`
175
+
176
+ # cfg = get_cfg_defaults()
177
+ # cfg.merge_from_file('./configs/example.yaml')
178
+
179
+ # # Now override from a list (opts could come from the command line)
180
+ # opts = ['dataset.root', './data/XXXX', 'learning_rate', '1e-2']
181
+ # cfg.merge_from_list(opts)
182
+
183
+
184
def update_cfg(cfg_file):
    """Merge a config file into the shared defaults and return them.

    The merge is applied in place to the module-level ``_C`` node; the value
    returned is ``_C`` itself, not a clone.

    Args:
        cfg_file: path handed to ``merge_from_file``.

    Returns:
        The module-level ``_C`` node after the merge.
    """
    shared_cfg = _C
    shared_cfg.merge_from_file(cfg_file)
    return shared_cfg
189
+
190
+
191
def parse_args(args):
    """Build the configuration selected by the parsed command-line arguments.

    Args:
        args: object exposing a ``cfg_file`` attribute (path or ``None``).

    Returns:
        The shared config merged with ``args.cfg_file`` when a file is given
        (see ``update_cfg``), otherwise a clone of the defaults (see
        ``get_cfg_defaults``). ``args.misc`` is not consulted here.
    """
    if args.cfg_file is not None:
        cfg = update_cfg(args.cfg_file)
    else:
        cfg = get_cfg_defaults()

    return cfg
202
+
203
+
204
def parse_args_extend(args):
    """Build the configuration, honouring resume mode.

    In resume mode the config stored as ``cfg.yaml`` inside ``args.log_dir``
    is merged into the shared config, followed by the overrides in
    ``args.misc`` (if any). Otherwise the work is delegated to
    ``parse_args``.

    Args:
        args: object exposing ``resume``, ``log_dir``, ``misc`` and (for the
            non-resume path) ``cfg_file``.

    Returns:
        The resulting config node. Earlier versions returned ``None`` and
        relied purely on the in-place update of the shared config; callers
        that ignore the return value are unaffected.

    Raises:
        ValueError: if resume mode is requested but ``args.log_dir`` does
            not exist.
    """
    if args.resume:
        if not os.path.exists(args.log_dir):
            raise ValueError(
                'Experiment are set to resume mode, but log directory does not exist.'
            )

        # load log's cfg
        cfg_file = os.path.join(args.log_dir, 'cfg.yaml')
        cfg = update_cfg(cfg_file)

        if args.misc is not None:
            cfg.merge_from_list(args.misc)
    else:
        cfg = parse_args(args)

    return cfg
lib/common/render.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: [email protected]
16
+
17
+ from pytorch3d.renderer import (
18
+ BlendParams,
19
+ blending,
20
+ look_at_view_transform,
21
+ FoVOrthographicCameras,
22
+ PointLights,
23
+ RasterizationSettings,
24
+ PointsRasterizationSettings,
25
+ PointsRenderer,
26
+ AlphaCompositor,
27
+ PointsRasterizer,
28
+ MeshRenderer,
29
+ MeshRasterizer,
30
+ SoftPhongShader,
31
+ SoftSilhouetteShader,
32
+ TexturesVertex,
33
+ )
34
+ from pytorch3d.renderer.mesh import TexturesVertex
35
+ from pytorch3d.structures import Meshes
36
+ from lib.dataset.mesh_util import SMPLX, get_visibility
37
+
38
+ import lib.common.render_utils as util
39
+ import torch
40
+ import numpy as np
41
+ from PIL import Image
42
+ from tqdm import tqdm
43
+ import os
44
+ import cv2
45
+ import math
46
+ from termcolor import colored
47
+
48
+
49
def image2vid(images, vid_path):
    """Write a sequence of images to an XVID-encoded video at 30 fps.

    Args:
        images: non-empty sequence of images exposing ``.size`` as
            ``(width, height)`` and convertible with ``np.array`` (e.g. PIL
            images); frames are converted from RGB to BGR before writing.
            The frame size of the video is taken from ``images[0]``.
        vid_path: output path passed to ``cv2.VideoWriter``.
    """
    w, h = images[0].size
    videodims = (w, h)
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    video = cv2.VideoWriter(vid_path, fourcc, 30, videodims)
    try:
        for image in images:
            video.write(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
    finally:
        # Always release the writer, even if a frame fails to convert.
        video.release()
58
+
59
+
60
def query_color(verts, faces, image, device):
    """Look up a colour for every vertex from an image.

    The first two vertex coordinates (with the second one negated) are used
    as sampling locations for ``grid_sample``; vertices reported as not
    visible by ``get_visibility`` receive a colour derived from their vertex
    normal instead.

    Args:
        verts: vertex tensor of shape [N, 3] (split into xy and z on dim 1).
        faces: face index tensor of shape [M, 3].
        image: image tensor given to ``grid_sample``; only the first batch
            element of the sampled result is used. The sampled values are
            mapped with ``(x + 1) * 0.5 * 255`` -- presumably the image is in
            [-1, 1]; confirm against the caller.
        device: device the vertices and faces are moved to.

    Returns:
        torch.Tensor: per-vertex colours, detached and moved to the CPU.
    """

    verts = verts.float().to(device)
    faces = faces.long().to(device)

    xy, z = verts.split([2, 1], dim=1)
    visibility = get_visibility(xy, z, faces[:, [0, 2, 1]]).flatten()

    # sampling grid of shape [1, N, 1, 2], with the second coordinate flipped
    uv = xy.unsqueeze(0).unsqueeze(2)
    uv = uv * torch.tensor([1.0, -1.0]).type_as(uv)

    sampled = torch.nn.functional.grid_sample(image, uv, align_corners=True)
    colors = (sampled[0, :, :, 0].permute(1, 0) + 1.0) * 0.5 * 255.0

    # fall back to normal-based colours where the vertex is not visible
    hidden = visibility == 0.0
    normals = Meshes(verts.unsqueeze(0),
                     faces.unsqueeze(0)).verts_normals_padded().squeeze(0)
    colors[hidden] = ((normals + 1.0) * 0.5 * 255.0)[hidden]

    return colors.detach().cpu()
85
+
86
+
87
class cleanShader(torch.nn.Module):
    """Shader that blends sampled texels without any lighting.

    ``forward`` samples the mesh textures for the rasterized fragments and
    combines them with ``blending.softmax_rgb_blend``.
    """

    def __init__(self, device="cpu", cameras=None, blend_params=None):
        """Store the default cameras and blend parameters.

        Args:
            device: accepted for signature compatibility; not used here.
            cameras: default cameras, may be overridden per ``forward`` call.
            blend_params: blend parameters; a default ``BlendParams()`` is
                created when ``None``.
        """
        super().__init__()
        self.cameras = cameras
        self.blend_params = blend_params if blend_params is not None else BlendParams()

    def forward(self, fragments, meshes, **kwargs):
        """Return the blended image for ``fragments`` of ``meshes``.

        Keyword Args:
            cameras: overrides the cameras given at construction.
            blend_params: overrides the blend parameters given at
                construction.

        Raises:
            ValueError: if no cameras were given at construction or in the
                call.
        """
        cameras = kwargs.get("cameras", self.cameras)
        if cameras is None:
            msg = ("Cameras must be specified either at initialization "
                   "or in the forward pass of cleanShader")
            raise ValueError(msg)

        # get renderer output
        blend_params = kwargs.get("blend_params", self.blend_params)
        texels = meshes.sample_textures(fragments)
        images = blending.softmax_rgb_blend(
            texels, fragments, blend_params, znear=-256, zfar=256
        )

        return images
109
+
110
+
111
class Render:
    """PyTorch3D-based renderer for meshes seen from fixed orthographic views.

    Typical use: call ``load_meshes`` to register one or more meshes (this
    also sets up four camera positions around the y axis), then query
    ``get_rgb_image``, ``get_depth_map``, ``get_silhouette_image`` or
    ``get_rendered_video``.
    """

    def __init__(self, size=512, device=torch.device("cuda:0")):
        """Create the renderer state.

        Args:
            size: side length (pixels) used for the rasterization settings.
            device: torch device on which meshes and cameras are placed.
        """
        self.device = device
        self.mesh_y_center = 100.0
        self.dis = 100.0
        self.scale = 1.0
        self.size = size
        self.cam_pos = [(0, 100, 100)]

        self.mesh = None
        self.deform_mesh = None
        self.pcd = None
        self.renderer = None
        self.meshRas = None
        self.type = None
        self.knn = None
        self.knn_inverse = None

        self.smpl_seg = None
        self.smpl_cmap = None

        self.smplx = SMPLX()

        self.uv_rasterizer = util.Pytorch3dRasterizer(self.size)

    def get_camera(self, cam_id):
        """Return an orthographic camera at ``self.cam_pos[cam_id]``.

        The camera looks at ``(0, self.mesh_y_center, 0)`` with +y as the up
        direction and is scaled uniformly by ``self.scale``.
        """
        R, T = look_at_view_transform(
            eye=[self.cam_pos[cam_id]],
            at=((0, self.mesh_y_center, 0),),
            up=((0, 1, 0),),
        )

        camera = FoVOrthographicCameras(
            device=self.device,
            R=R,
            T=T,
            znear=100.0,
            zfar=-100.0,
            max_y=100.0,
            min_y=-100.0,
            max_x=100.0,
            min_x=-100.0,
            scale_xyz=(self.scale * np.ones(3),),
        )

        return camera

    def init_renderer(self, camera, type="clean_mesh", bg="gray"):
        """(Re)build ``self.renderer`` for the given camera and render type.

        Args:
            camera: camera used by the rasterizer and shader.
            type: one of "ori_mesh", "clean_mesh", "silhouette" or
                "pointcloud". Any type containing "mesh" also rebuilds
                ``self.meshRas``.
            bg: background colour, "black", "white" or "gray". It is only
                consumed by the "ori_mesh" and "clean_mesh" types; other
                values leave the blend parameters undefined for those types.
        """
        if "mesh" in type:

            # rasterizer
            self.raster_settings_mesh = RasterizationSettings(
                image_size=self.size,
                blur_radius=np.log(1.0 / 1e-4) * 1e-7,
                faces_per_pixel=30,
            )
            self.meshRas = MeshRasterizer(
                cameras=camera, raster_settings=self.raster_settings_mesh
            )

        if bg == "black":
            blendparam = BlendParams(1e-4, 1e-4, (0.0, 0.0, 0.0))
        elif bg == "white":
            blendparam = BlendParams(1e-4, 1e-8, (1.0, 1.0, 1.0))
        elif bg == "gray":
            blendparam = BlendParams(1e-4, 1e-8, (0.5, 0.5, 0.5))

        if type == "ori_mesh":

            lights = PointLights(
                device=self.device,
                ambient_color=((0.8, 0.8, 0.8),),
                diffuse_color=((0.2, 0.2, 0.2),),
                specular_color=((0.0, 0.0, 0.0),),
                location=[[0.0, 200.0, 0.0]],
            )

            self.renderer = MeshRenderer(
                rasterizer=self.meshRas,
                shader=SoftPhongShader(
                    device=self.device,
                    cameras=camera,
                    lights=lights,
                    blend_params=blendparam,
                ),
            )

        if type == "silhouette":
            self.raster_settings_silhouette = RasterizationSettings(
                image_size=self.size,
                blur_radius=np.log(1.0 / 1e-4 - 1.0) * 5e-5,
                faces_per_pixel=50,
                cull_backfaces=True,
            )

            self.silhouetteRas = MeshRasterizer(
                cameras=camera, raster_settings=self.raster_settings_silhouette
            )
            self.renderer = MeshRenderer(
                rasterizer=self.silhouetteRas, shader=SoftSilhouetteShader()
            )

        if type == "pointcloud":
            self.raster_settings_pcd = PointsRasterizationSettings(
                image_size=self.size, radius=0.006, points_per_pixel=10
            )

            self.pcdRas = PointsRasterizer(
                cameras=camera, raster_settings=self.raster_settings_pcd
            )
            self.renderer = PointsRenderer(
                rasterizer=self.pcdRas,
                compositor=AlphaCompositor(background_color=(0, 0, 0)),
            )

        if type == "clean_mesh":

            self.renderer = MeshRenderer(
                rasterizer=self.meshRas,
                shader=cleanShader(
                    device=self.device, cameras=camera, blend_params=blendparam
                ),
            )

    def VF2Mesh(self, verts, faces):
        """Build a ``Meshes`` object textured with its own vertex normals.

        Args:
            verts: vertices as a tensor or array-like; 2-D input gets a
                leading batch dimension and is cast to float.
            faces: faces as a tensor or array-like; 2-D input gets a leading
                batch dimension and is cast to long.

        Returns:
            A ``Meshes`` instance on ``self.device`` whose vertex texture is
            ``(normal + 1) * 0.5``.
        """
        if not torch.is_tensor(verts):
            verts = torch.tensor(verts)
        if not torch.is_tensor(faces):
            faces = torch.tensor(faces)

        if verts.ndimension() == 2:
            verts = verts.unsqueeze(0).float()
        if faces.ndimension() == 2:
            faces = faces.unsqueeze(0).long()

        verts = verts.to(self.device)
        faces = faces.to(self.device)

        mesh = Meshes(verts, faces).to(self.device)

        mesh.textures = TexturesVertex(
            verts_features=(mesh.verts_normals_padded() + 1.0) * 0.5
        )

        return mesh

    def load_meshes(self, verts, faces):
        """Load one or several meshes into the renderer.

        Also resets the camera setup: scale 100, y-centre 0 and four camera
        positions (indices 0..3) placed on the +z, +x, -z and -x axes.

        Args:
            verts: vertices of one mesh, or a list with one entry per mesh.
            faces: faces of one mesh, or a list matching ``verts``.
        """
        # camera setting
        self.scale = 100.0
        self.mesh_y_center = 0.0

        self.cam_pos = [
            (0, self.mesh_y_center, 100.0),
            (100.0, self.mesh_y_center, 0),
            (0, self.mesh_y_center, -100.0),
            (-100.0, self.mesh_y_center, 0),
        ]

        self.type = "color"

        if isinstance(verts, list):
            self.meshes = []
            for V, F in zip(verts, faces):
                self.meshes.append(self.VF2Mesh(V, F))
        else:
            self.meshes = [self.VF2Mesh(verts, faces)]

    def get_depth_map(self, cam_ids=[0, 2]):
        """Return the nearest-face z-buffer of the first mesh per camera.

        The map of camera 2 is flipped left-right. ``cam_ids`` is only read,
        never modified.
        """
        depth_maps = []
        for cam_id in cam_ids:
            self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray")
            fragments = self.meshRas(self.meshes[0])
            depth_map = fragments.zbuf[..., 0].squeeze(0)
            if cam_id == 2:
                depth_map = torch.fliplr(depth_map)
            depth_maps.append(depth_map)

        return depth_maps

    def get_rgb_image(self, cam_ids=[0, 2]):
        """Render the first mesh from the selected cameras.

        Each image has shape [1, 3, H, W] and is mapped with
        ``(x - 0.5) * 2``. When exactly two cameras are requested, the image
        of camera 2 is flipped along its last (width) dimension.
        ``cam_ids`` is only read, never modified.
        """
        images = []
        for cam_id in range(len(self.cam_pos)):
            if cam_id in cam_ids:
                self.init_renderer(self.get_camera(
                    cam_id), "clean_mesh", "gray")
                rendered_img = (
                    self.renderer(self.meshes[0])[
                        0:1, :, :, :3].permute(0, 3, 1, 2)
                    - 0.5
                ) * 2.0
                if cam_id == 2 and len(cam_ids) == 2:
                    rendered_img = torch.flip(rendered_img, dims=[3])
                images.append(rendered_img)

        return images

    def get_rendered_video(self, images, save_path):
        """Export a 360-frame turntable video of all loaded meshes.

        Every frame is the horizontal concatenation of the given ``images``
        (resized so that their height equals ``self.size``) followed by one
        render per loaded mesh. Note that ``self.cam_pos`` is replaced by the
        360 turntable positions and is not restored afterwards.

        Args:
            images: non-empty list of HxWx3 uint8 arrays shown next to the
                renders; the target size is derived from ``images[0]``.
            save_path: output video path (mp4v codec, 30 fps).
        """
        self.cam_pos = []
        for angle in range(360):
            self.cam_pos.append(
                (
                    100.0 * math.cos(np.pi / 180 * angle),
                    self.mesh_y_center,
                    100.0 * math.sin(np.pi / 180 * angle),
                )
            )

        old_shape = np.array(images[0].shape[:2])
        # `np.int` was removed from NumPy; the builtin `int` is equivalent.
        new_shape = np.around(
            (self.size / old_shape[0]) * old_shape).astype(int)

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        video = cv2.VideoWriter(
            save_path, fourcc, 30, (self.size * len(self.meshes) +
                                    new_shape[1] * len(images), self.size)
        )

        # The resized (and RGB->BGR reordered) inputs are identical for every
        # frame, so compute them once instead of once per camera.
        static_imgs = [
            np.array(Image.fromarray(img).resize(new_shape[::-1])).astype(np.uint8)[
                :, :, [2, 1, 0]
            ]
            for img in images
        ]

        pbar = tqdm(range(len(self.cam_pos)))
        pbar.set_description(colored(f"exporting video {os.path.basename(save_path)}...", "blue"))
        try:
            for cam_id in pbar:
                self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray")

                img_lst = list(static_imgs)

                for mesh in self.meshes:
                    rendered_img = (
                        (self.renderer(mesh)[0, :, :, :3] * 255.0)
                        .detach()
                        .cpu()
                        .numpy()
                        .astype(np.uint8)
                    )

                    img_lst.append(rendered_img)
                final_img = np.concatenate(img_lst, axis=1)
                video.write(final_img)
        finally:
            # Release the writer even if rendering fails midway.
            video.release()

    def get_silhouette_image(self, cam_ids=[0, 2]):
        """Render the alpha channel of the first mesh from selected cameras.

        Each result has shape [1, H, W]. When exactly two cameras are
        requested, the image of camera 2 is flipped along its last (width)
        dimension. ``cam_ids`` is only read, never modified.
        """
        images = []
        for cam_id in range(len(self.cam_pos)):
            if cam_id in cam_ids:
                self.init_renderer(self.get_camera(cam_id), "silhouette")
                rendered_img = self.renderer(self.meshes[0])[0:1, :, :, 3]
                if cam_id == 2 and len(cam_ids) == 2:
                    rendered_img = torch.flip(rendered_img, dims=[2])
                images.append(rendered_img)

        return images
lib/common/render_utils.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ import torch
19
+ from torch import nn
20
+ import trimesh
21
+ import math
22
+ from typing import NewType
23
+ from pytorch3d.structures import Meshes
24
+ from pytorch3d.renderer.mesh import rasterize_meshes
25
+
26
+ Tensor = NewType('Tensor', torch.Tensor)
27
+
28
+
29
def solid_angles(points: Tensor,
                 triangles: Tensor,
                 thresh: float = 1e-8) -> Tensor:
    ''' Signed solid angle subtended by each triangle at each query point

        Implements the formula from:
        The Solid Angle of a Plane Triangle
        A. VAN OOSTEROM AND J. STRACKEE
        IEEE TRANSACTIONS ON BIOMEDICAL ENGINEERING,
        VOL. BME-30, NO. 2, FEBRUARY 1983
        Parameters
        -----------
            points: BxQx3
                Tensor of input query points
            triangles: BxFx3x3
                Target triangles
            thresh: float
                float threshold (accepted but not used by the computation)
        Returns
        -------
            solid_angles: BxQxF
                A tensor containing the solid angle between all query points
                and input triangles
    '''
    # Triangle corners relative to every query point: BxQxFx3x3
    rel = triangles[:, None] - points[:, :, None, None]

    # Length of each relative corner vector: BxQxFx3
    lengths = torch.norm(rel, dim=-1)

    corner0 = rel[:, :, :, 0]
    corner1 = rel[:, :, :, 1]
    corner2 = rel[:, :, :, 2]

    # Scalar triple product corner0 . (corner1 x corner2): BxQxF
    numerator = (corner0 * torch.cross(corner1, corner2, dim=-1)).sum(dim=-1)

    dot01 = (corner0 * corner1).sum(dim=-1)
    dot12 = (corner1 * corner2).sum(dim=-1)
    dot02 = (corner0 * corner2).sum(dim=-1)
    # drop the large intermediates as early as possible
    del rel, corner0, corner1, corner2

    denominator = (lengths.prod(dim=-1) + dot01 * lengths[:, :, :, 2] +
                   dot02 * lengths[:, :, :, 1] + dot12 * lengths[:, :, :, 0])
    del dot01, dot12, dot02, lengths

    # BxQxF
    half_angle = torch.atan2(numerator, denominator)
    del numerator, denominator

    torch.cuda.empty_cache()

    return 2 * half_angle
82
+
83
+
84
def winding_numbers(points: Tensor,
                    triangles: Tensor,
                    thresh: float = 1e-8) -> Tensor:
    ''' Generalized winding number of each query point w.r.t. a triangle set

        References:
        Robust inside-outside segmentation using generalized winding numbers
        Alec Jacobson,
        Ladislav Kavan,
        Olga Sorkine-Hornung
        Fast Winding Numbers for Soups and Clouds SIGGRAPH 2018
        Gavin Barill
        NEIL G. Dickson
        Ryan Schmidt
        David I.W. Levin
        and Alec Jacobson
        Parameters
        -----------
            points: BxQx3
                Tensor of input query points
            triangles: BxFx3x3
                Target triangles
            thresh: float
                float threshold, forwarded to solid_angles
        Returns
        -------
            winding_numbers: BxQ
                A tensor containing the Generalized winding numbers
    '''
    # Sum the per-triangle solid angles over the face dimension and
    # normalize by the full sphere (4 * pi).
    per_triangle = solid_angles(points, triangles, thresh=thresh)
    total_angle = per_triangle.sum(dim=-1)
    return 1 / (4 * math.pi) * total_angle
115
+
116
+
117
def batch_contains(verts, faces, points):
    """Inside/outside test of query points against a batch of meshes.

    All inputs are detached and moved to the CPU; each batch element is
    tested with ``trimesh.Trimesh(...).contains``.

    :param verts: batched mesh vertices, first dimension is the batch
    :param faces: batched mesh faces, first dimension is the batch
    :param points: batched query points, indexed as [batch, point, ...]
    :return: CPU tensor [batch size, number of points] holding +1.0 where
             ``contains`` reported True and -1.0 otherwise
    """

    batch_size = verts.shape[0]
    num_points = points.shape[1]

    verts_cpu = verts.detach().cpu()
    faces_cpu = faces.detach().cpu()
    points_cpu = points.detach().cpu()

    inside = torch.zeros(batch_size, num_points)
    for b in range(batch_size):
        mesh = trimesh.Trimesh(verts_cpu[b], faces_cpu[b])
        inside[b] = torch.as_tensor(mesh.contains(points_cpu[b]))

    # map {0, 1} -> {-1, +1}
    return 2.0 * (inside - 0.5)
132
+
133
+
134
def dict2obj(d):
    """Recursively turn a dict into an object with attribute access.

    Nested dicts are converted as well; any value that is not a dict
    (including lists, whose elements are left untouched) is returned as is.
    """
    if isinstance(d, dict):

        class C(object):
            pass

        obj = C()
        # write straight into the instance dict so any key is accepted
        obj.__dict__.update({key: dict2obj(value) for key, value in d.items()})
        return obj

    return d
147
+
148
+
149
def face_vertices(vertices, faces):
    """
    Gather the corner coordinates of every face.

    :param vertices: [batch size, number of vertices, 3]
    :param faces: [batch size, number of faces, 3]
    :return: [batch size, number of faces, 3, 3]
    """

    num_verts = vertices.shape[1]
    batch_size = faces.shape[0]
    device = vertices.device

    # Shift the indices of batch element b by b * num_verts so that they
    # address rows of the flattened vertex table built below.
    batch_offsets = torch.arange(batch_size, dtype=torch.int32).to(device) * num_verts
    flat_faces = faces + batch_offsets[:, None, None]
    flat_vertices = vertices.reshape((batch_size * num_verts, vertices.shape[-1]))

    return flat_vertices[flat_faces.long()]
164
+
165
+
166
class Pytorch3dRasterizer(nn.Module):
    """ Borrowed from https://github.com/facebookresearch/pytorch3d
    Notice:
        x,y,z are in image space, normalized
        can only render squared image now

    Rasterizes a batch of meshes with fixed settings and interpolates
    per-face-corner attributes over the covered pixels.
    """

    def __init__(self, image_size=224):
        """
        use fixed raster_settings for rendering faces

        :param image_size: forwarded to rasterize_meshes as image_size
        """
        super().__init__()
        raster_settings = {
            'image_size': image_size,
            'blur_radius': 0.0,
            'faces_per_pixel': 1,
            'bin_size': None,
            'max_faces_per_bin': None,
            'perspective_correct': True,
            # NOTE(review): stored here, but forward() does not pass
            # cull_backfaces on to rasterize_meshes.
            'cull_backfaces': True,
        }
        # expose the settings through attribute access
        raster_settings = dict2obj(raster_settings)
        self.raster_settings = raster_settings

    def forward(self, vertices, faces, attributes=None):
        """
        :param vertices: batched vertices; x and y are negated on a copy
                         before rasterization, the input is not modified
        :param faces: batched face indices
        :param attributes: per-face-corner attributes, viewed below as
                           [batch * faces, 3, D]. Despite the None default
                           it is required: its shape is read unconditionally.
        :return: [N, D + 1, H, W]; the D interpolated attribute channels
                 followed by one visibility channel (1 where a face covers
                 the pixel, 0 elsewhere)
        """
        fixed_vertices = vertices.clone()
        # flip x and y on the copy before rasterizing
        fixed_vertices[..., :2] = -fixed_vertices[..., :2]
        meshes_screen = Meshes(verts=fixed_vertices.float(),
                               faces=faces.long())
        raster_settings = self.raster_settings
        pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
            meshes_screen,
            image_size=raster_settings.image_size,
            blur_radius=raster_settings.blur_radius,
            faces_per_pixel=raster_settings.faces_per_pixel,
            bin_size=raster_settings.bin_size,
            max_faces_per_bin=raster_settings.max_faces_per_bin,
            perspective_correct=raster_settings.perspective_correct,
        )
        # pix_to_face is -1 for pixels not covered by any face
        vismask = (pix_to_face > -1).float()
        D = attributes.shape[-1]
        attributes = attributes.clone()
        # merge the batch and face dimensions so that a single index
        # addresses one face of the whole batch
        attributes = attributes.view(attributes.shape[0] * attributes.shape[1],
                                     3, attributes.shape[-1])
        N, H, W, K, _ = bary_coords.shape
        mask = pix_to_face == -1
        pix_to_face = pix_to_face.clone()
        # point empty pixels at face 0 so gather gets a valid index; their
        # values are zeroed again below
        pix_to_face[mask] = 0
        idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
        pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
        # barycentric interpolation over the three corners of each face
        pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
        pixel_vals[mask] = 0  # Replace masked values in output.
        # keep the first face per pixel and move channels first: [N, D, H, W]
        pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
        pixel_vals = torch.cat(
            [pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
        return pixel_vals
lib/common/seg3d_lossless.py ADDED
@@ -0,0 +1,604 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+
19
+ from .seg3d_utils import (
20
+ create_grid3D,
21
+ plot_mask3D,
22
+ SmoothConv3D,
23
+ )
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import numpy as np
28
+ import torch.nn.functional as F
29
+ import mcubes
30
+ from kaolin.ops.conversions import voxelgrids_to_trianglemeshes
31
+ import logging
32
+
33
+ logging.getLogger("lightning").setLevel(logging.ERROR)
34
+
35
+
36
+ class Seg3dLossless(nn.Module):
37
+ def __init__(self,
38
+ query_func,
39
+ b_min,
40
+ b_max,
41
+ resolutions,
42
+ channels=1,
43
+ balance_value=0.5,
44
+ align_corners=False,
45
+ visualize=False,
46
+ debug=False,
47
+ use_cuda_impl=False,
48
+ faster=False,
49
+ use_shadow=False,
50
+ **kwargs):
51
+ """
52
+ align_corners: same with how you process gt. (grid_sample / interpolate)
53
+ """
54
+ super().__init__()
55
+ self.query_func = query_func
56
+ self.register_buffer(
57
+ 'b_min',
58
+ torch.tensor(b_min).float().unsqueeze(1)) # [bz, 1, 3]
59
+ self.register_buffer(
60
+ 'b_max',
61
+ torch.tensor(b_max).float().unsqueeze(1)) # [bz, 1, 3]
62
+
63
+ # ti.init(arch=ti.cuda)
64
+ # self.mciso_taichi = MCISO(dim=3, N=resolutions[-1]-1)
65
+
66
+ if type(resolutions[0]) is int:
67
+ resolutions = torch.tensor([(res, res, res)
68
+ for res in resolutions])
69
+ else:
70
+ resolutions = torch.tensor(resolutions)
71
+ self.register_buffer('resolutions', resolutions)
72
+ self.batchsize = self.b_min.size(0)
73
+ assert self.batchsize == 1
74
+ self.balance_value = balance_value
75
+ self.channels = channels
76
+ assert self.channels == 1
77
+ self.align_corners = align_corners
78
+ self.visualize = visualize
79
+ self.debug = debug
80
+ self.use_cuda_impl = use_cuda_impl
81
+ self.faster = faster
82
+ self.use_shadow = use_shadow
83
+
84
+ for resolution in resolutions:
85
+ assert resolution[0] % 2 == 1 and resolution[1] % 2 == 1, \
86
+ f"resolution {resolution} need to be odd becuase of align_corner."
87
+
88
+ # init first resolution
89
+ init_coords = create_grid3D(0,
90
+ resolutions[-1] - 1,
91
+ steps=resolutions[0]) # [N, 3]
92
+ init_coords = init_coords.unsqueeze(0).repeat(self.batchsize, 1,
93
+ 1) # [bz, N, 3]
94
+ self.register_buffer('init_coords', init_coords)
95
+
96
+ # some useful tensors
97
+ calculated = torch.zeros(
98
+ (self.resolutions[-1][2], self.resolutions[-1][1],
99
+ self.resolutions[-1][0]),
100
+ dtype=torch.bool)
101
+ self.register_buffer('calculated', calculated)
102
+
103
+ gird8_offsets = torch.stack(
104
+ torch.meshgrid([
105
+ torch.tensor([-1, 0, 1]),
106
+ torch.tensor([-1, 0, 1]),
107
+ torch.tensor([-1, 0, 1])
108
+ ])).int().view(3, -1).t() # [27, 3]
109
+ self.register_buffer('gird8_offsets', gird8_offsets)
110
+
111
+ # smooth convs
112
+ self.smooth_conv3x3 = SmoothConv3D(in_channels=1,
113
+ out_channels=1,
114
+ kernel_size=3)
115
+ self.smooth_conv5x5 = SmoothConv3D(in_channels=1,
116
+ out_channels=1,
117
+ kernel_size=5)
118
+ self.smooth_conv7x7 = SmoothConv3D(in_channels=1,
119
+ out_channels=1,
120
+ kernel_size=7)
121
+ self.smooth_conv9x9 = SmoothConv3D(in_channels=1,
122
+ out_channels=1,
123
+ kernel_size=9)
124
+
125
def batch_eval(self, coords, **kwargs):
    """
    Query the occupancy function at a set of voxel positions.

    coords: voxel indices given on the grid of the last resolution
    **kwargs: passed through untouched to self.query_func

    Returns what query_func produced (a list result is stacked first);
    the result must be 3-dimensional, i.e. [bz, C, N].
    """
    voxels = coords.detach().float()
    finest = self.resolutions[-1]

    # bring voxel indices into the unit cube ...
    if self.align_corners:
        unit = voxels / (finest - 1)
    else:
        half_cell = (1.0 / finest.float()) / 2
        unit = voxels / finest + half_cell
    # ... and from there into the box spanned by [b_min, b_max]
    points = unit * (self.b_max - self.b_min) + self.b_min

    result = self.query_func(**kwargs, points=points)
    if type(result) is list:
        result = torch.stack(result)  # [bz, C, N]
    assert len(result.size()) == 3, \
        "query_func should return a occupancy with shape of [bz, C, N]"
    return result
145
+
146
def forward(self, **kwargs):
    """
    Entry point: run the speed-oriented variant when self.faster is set,
    otherwise the variant with conflict checking. All keyword arguments
    are handed on unchanged.
    """
    run = self._forward_faster if self.faster else self._forward
    return run(**kwargs)
151
+
152
def _forward_faster(self, **kwargs):
    """
    Coarse-to-fine occupancy reconstruction, speed-oriented variant.

    In faster mode, we make following changes to exchange accuracy for speed:
    1. no conflict checking: 4.88 fps -> 6.56 fps
    2. smooth_conv9x9 ~ smooth_conv3x3 for different resolution
    3. last step no examine

    **kwargs: forwarded to self.batch_eval (and from there to the query
        function).

    Returns occupancys[0, 0] after the last resolution has been processed,
    or None when the first (coarsest) evaluation contains no value above
    0.5.
    """
    final_W = self.resolutions[-1][0]
    final_H = self.resolutions[-1][1]
    final_D = self.resolutions[-1][2]

    for resolution in self.resolutions:
        W, H, D = resolution
        # spacing of this level's voxels measured on the final grid
        stride = (self.resolutions[-1] - 1) / (resolution - 1)

        # first step
        if torch.equal(resolution, self.resolutions[0]):
            coords = self.init_coords.clone()  # torch.long
            occupancys = self.batch_eval(coords, **kwargs)
            occupancys = occupancys.view(self.batchsize, self.channels, D,
                                         H, W)
            # NOTE(review): this emptiness test uses a literal 0.5 while the
            # rest of the method thresholds at self.balance_value - confirm
            # whether that is intended.
            if (occupancys > 0.5).sum() == 0:
                # return F.interpolate(
                #     occupancys, size=(final_D, final_H, final_W),
                #     mode="linear", align_corners=True)
                return None

            if self.visualize:
                self.plot(occupancys, coords, final_D, final_H, final_W)

            with torch.no_grad():
                # positions already evaluated, in this level's own indices
                coords_accum = coords / stride

        # last step
        elif torch.equal(resolution, self.resolutions[-1]):

            with torch.no_grad():
                # here true is correct!
                valid = F.interpolate(
                    (occupancys > self.balance_value).float(),
                    size=(D, H, W),
                    mode="trilinear",
                    align_corners=True)

            # here true is correct!
            occupancys = F.interpolate(occupancys.float(),
                                       size=(D, H, W),
                                       mode="trilinear",
                                       align_corners=True)

            # NOTE(review): nothing after this branch reads is_boundary, so
            # the last level is only upsampled ("last step no examine").
            # is_boundary = (valid > 0.0) & (valid < 1.0)
            is_boundary = valid == 0.5

        # next steps
        else:
            # re-express the evaluated positions on this level's grid;
            # NOTE(review): the factor 2 presumably relies on each level
            # doubling the previous level's cell count - confirm.
            coords_accum *= 2

            with torch.no_grad():
                # here true is correct!
                valid = F.interpolate(
                    (occupancys > self.balance_value).float(),
                    size=(D, H, W),
                    mode="trilinear",
                    align_corners=True)

            # here true is correct!
            occupancys = F.interpolate(occupancys.float(),
                                       size=(D, H, W),
                                       mode="trilinear",
                                       align_corners=True)

            # voxels whose upsampled inside/outside mask is neither 0 nor 1
            is_boundary = (valid > 0.0) & (valid < 1.0)

            with torch.no_grad():
                # widen the boundary band; the kernel used depends on the level
                if torch.equal(resolution, self.resolutions[1]):
                    is_boundary = (self.smooth_conv9x9(is_boundary.float())
                                   > 0)[0, 0]
                elif torch.equal(resolution, self.resolutions[2]):
                    is_boundary = (self.smooth_conv7x7(is_boundary.float())
                                   > 0)[0, 0]
                else:
                    is_boundary = (self.smooth_conv3x3(is_boundary.float())
                                   > 0)[0, 0]

                # positions that were evaluated before are not queried again
                coords_accum = coords_accum.long()
                is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
                            coords_accum[0, :, 0]] = False
                # permute so that nonzero() yields (x, y, z) triples
                point_coords = is_boundary.permute(
                    2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
                # flat index into the D*H*W volume
                point_indices = (point_coords[:, :, 2] * H * W +
                                 point_coords[:, :, 1] * W +
                                 point_coords[:, :, 0])

                R, C, D, H, W = occupancys.shape

                # inferred value
                coords = point_coords * stride

            if coords.size(1) == 0:
                continue
            occupancys_topk = self.batch_eval(coords, **kwargs)

            # put mask point predictions to the right places on the upsampled grid.
            R, C, D, H, W = occupancys.shape
            point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
            occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
                2, point_indices, occupancys_topk).view(R, C, D, H, W))

            with torch.no_grad():
                # remember the newly evaluated positions for the next level
                voxels = coords / stride
                coords_accum = torch.cat([voxels, coords_accum],
                                         dim=1).unique(dim=1)

    return occupancys[0, 0]
266
+
267
+ def _forward(self, **kwargs):
268
+ """
269
+ output occupancy field would be:
270
+ (bz, C, res, res)
271
+ """
272
+ final_W = self.resolutions[-1][0]
273
+ final_H = self.resolutions[-1][1]
274
+ final_D = self.resolutions[-1][2]
275
+
276
+ calculated = self.calculated.clone()
277
+
278
+ for resolution in self.resolutions:
279
+ W, H, D = resolution
280
+ stride = (self.resolutions[-1] - 1) / (resolution - 1)
281
+
282
+ if self.visualize:
283
+ this_stage_coords = []
284
+
285
+ # first step
286
+ if torch.equal(resolution, self.resolutions[0]):
287
+ coords = self.init_coords.clone() # torch.long
288
+ occupancys = self.batch_eval(coords, **kwargs)
289
+ occupancys = occupancys.view(self.batchsize, self.channels, D,
290
+ H, W)
291
+
292
+ if self.visualize:
293
+ self.plot(occupancys, coords, final_D, final_H, final_W)
294
+
295
+ with torch.no_grad():
296
+ coords_accum = coords / stride
297
+ calculated[coords[0, :, 2], coords[0, :, 1],
298
+ coords[0, :, 0]] = True
299
+
300
+ # next steps
301
+ else:
302
+ coords_accum *= 2
303
+
304
+ with torch.no_grad():
305
+ # here true is correct!
306
+ valid = F.interpolate(
307
+ (occupancys > self.balance_value).float(),
308
+ size=(D, H, W),
309
+ mode="trilinear",
310
+ align_corners=True)
311
+
312
+ # here true is correct!
313
+ occupancys = F.interpolate(occupancys.float(),
314
+ size=(D, H, W),
315
+ mode="trilinear",
316
+ align_corners=True)
317
+
318
+ is_boundary = (valid > 0.0) & (valid < 1.0)
319
+
320
+ with torch.no_grad():
321
+ # TODO
322
+ if self.use_shadow and torch.equal(resolution,
323
+ self.resolutions[-1]):
324
+ # larger z means smaller depth here
325
+ depth_res = resolution[2].item()
326
+ depth_index = torch.linspace(0,
327
+ depth_res - 1,
328
+ steps=depth_res).type_as(
329
+ occupancys.device)
330
+ depth_index_max = torch.max(
331
+ (occupancys > self.balance_value) *
332
+ (depth_index + 1),
333
+ dim=-1,
334
+ keepdim=True)[0] - 1
335
+ shadow = depth_index < depth_index_max
336
+ is_boundary[shadow] = False
337
+ is_boundary = is_boundary[0, 0]
338
+ else:
339
+ is_boundary = (self.smooth_conv3x3(is_boundary.float())
340
+ > 0)[0, 0]
341
+ # is_boundary = is_boundary[0, 0]
342
+
343
+ is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
344
+ coords_accum[0, :, 0]] = False
345
+ point_coords = is_boundary.permute(
346
+ 2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
347
+ point_indices = (point_coords[:, :, 2] * H * W +
348
+ point_coords[:, :, 1] * W +
349
+ point_coords[:, :, 0])
350
+
351
+ R, C, D, H, W = occupancys.shape
352
+ # interpolated value
353
+ occupancys_interp = torch.gather(
354
+ occupancys.reshape(R, C, D * H * W), 2,
355
+ point_indices.unsqueeze(1))
356
+
357
+ # inferred value
358
+ coords = point_coords * stride
359
+
360
+ if coords.size(1) == 0:
361
+ continue
362
+ occupancys_topk = self.batch_eval(coords, **kwargs)
363
+ if self.visualize:
364
+ this_stage_coords.append(coords)
365
+
366
+ # put mask point predictions to the right places on the upsampled grid.
367
+ R, C, D, H, W = occupancys.shape
368
+ point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
369
+ occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
370
+ 2, point_indices, occupancys_topk).view(R, C, D, H, W))
371
+
372
+ with torch.no_grad():
373
+ # conflicts
374
+ conflicts = ((occupancys_interp - self.balance_value) *
375
+ (occupancys_topk - self.balance_value) < 0)[0,
376
+ 0]
377
+
378
+ if self.visualize:
379
+ self.plot(occupancys, coords, final_D, final_H,
380
+ final_W)
381
+
382
+ voxels = coords / stride
383
+ coords_accum = torch.cat([voxels, coords_accum],
384
+ dim=1).unique(dim=1)
385
+ calculated[coords[0, :, 2], coords[0, :, 1],
386
+ coords[0, :, 0]] = True
387
+
388
+ while conflicts.sum() > 0:
389
+ if self.use_shadow and torch.equal(resolution,
390
+ self.resolutions[-1]):
391
+ break
392
+
393
+ with torch.no_grad():
394
+ conflicts_coords = coords[0, conflicts, :]
395
+
396
+ if self.debug:
397
+ self.plot(occupancys,
398
+ conflicts_coords.unsqueeze(0),
399
+ final_D,
400
+ final_H,
401
+ final_W,
402
+ title='conflicts')
403
+
404
+ conflicts_boundary = (conflicts_coords.int() +
405
+ self.gird8_offsets.unsqueeze(1) *
406
+ stride.int()).reshape(
407
+ -1, 3).long().unique(dim=0)
408
+ conflicts_boundary[:, 0] = (
409
+ conflicts_boundary[:, 0].clamp(
410
+ 0,
411
+ calculated.size(2) - 1))
412
+ conflicts_boundary[:, 1] = (
413
+ conflicts_boundary[:, 1].clamp(
414
+ 0,
415
+ calculated.size(1) - 1))
416
+ conflicts_boundary[:, 2] = (
417
+ conflicts_boundary[:, 2].clamp(
418
+ 0,
419
+ calculated.size(0) - 1))
420
+
421
+ coords = conflicts_boundary[calculated[
422
+ conflicts_boundary[:, 2], conflicts_boundary[:, 1],
423
+ conflicts_boundary[:, 0]] == False]
424
+
425
+ if self.debug:
426
+ self.plot(occupancys,
427
+ coords.unsqueeze(0),
428
+ final_D,
429
+ final_H,
430
+ final_W,
431
+ title='coords')
432
+
433
+ coords = coords.unsqueeze(0)
434
+ point_coords = coords / stride
435
+ point_indices = (point_coords[:, :, 2] * H * W +
436
+ point_coords[:, :, 1] * W +
437
+ point_coords[:, :, 0])
438
+
439
+ R, C, D, H, W = occupancys.shape
440
+ # interpolated value
441
+ occupancys_interp = torch.gather(
442
+ occupancys.reshape(R, C, D * H * W), 2,
443
+ point_indices.unsqueeze(1))
444
+
445
+ # inferred value
446
+ coords = point_coords * stride
447
+
448
+ if coords.size(1) == 0:
449
+ break
450
+ occupancys_topk = self.batch_eval(coords, **kwargs)
451
+ if self.visualize:
452
+ this_stage_coords.append(coords)
453
+
454
+ with torch.no_grad():
455
+ # conflicts
456
+ conflicts = ((occupancys_interp - self.balance_value) *
457
+ (occupancys_topk - self.balance_value) <
458
+ 0)[0, 0]
459
+
460
+ # put mask point predictions to the right places on the upsampled grid.
461
+ point_indices = point_indices.unsqueeze(1).expand(
462
+ -1, C, -1)
463
+ occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
464
+ 2, point_indices, occupancys_topk).view(R, C, D, H, W))
465
+
466
+ with torch.no_grad():
467
+ voxels = coords / stride
468
+ coords_accum = torch.cat([voxels, coords_accum],
469
+ dim=1).unique(dim=1)
470
+ calculated[coords[0, :, 2], coords[0, :, 1],
471
+ coords[0, :, 0]] = True
472
+
473
+ if self.visualize:
474
+ this_stage_coords = torch.cat(this_stage_coords, dim=1)
475
+ self.plot(occupancys, this_stage_coords, final_D, final_H,
476
+ final_W)
477
+
478
+ return occupancys[0, 0]
479
+
480
def plot(self,
         occupancys,
         coords,
         final_D,
         final_H,
         final_W,
         title='',
         **kwargs):
    """
    Show an occupancy field together with a set of query points.

    The field is first upsampled trilinearly to (final_D, final_H, final_W);
    the first batch element / channel and the x, y, z columns of coords[0]
    are moved to the CPU and handed to plot_mask3D. Extra keyword arguments
    are passed on to plot_mask3D.
    """
    upsampled = F.interpolate(occupancys.float(),
                              size=(final_D, final_H, final_W),
                              mode="trilinear",
                              align_corners=True)  # here true is correct!
    xs, ys, zs = [coords[0, :, axis].to("cpu") for axis in range(3)]

    plot_mask3D(upsampled[0, 0].to("cpu"), title, (xs, ys, zs), **kwargs)
497
+
498
def find_vertices(self, sdf, direction="front"):
    '''
    Find, for every (x, y) column of the volume, the first cell above 0.5
    as seen from the requested side, and estimate a normal there.

    - direction: "front" | "back" | "left" | "right"

    Returns X, Y (long column indices), Z (interpolated position along the
    viewing axis, clamped to [0, resolution]) and norm (finite-difference
    normals scaled to unit length, one row per kept cell).
    '''

    def descending(volume):
        # indices size(2)-1, ..., 0
        return torch.arange(volume.size(2) - 1, -1, -1).long()

    resolution = sdf.size(2)

    # rotate / mirror the volume so that the wanted side faces the viewer;
    # "front" leaves it untouched
    if direction == "left":
        sdf = sdf.permute(2, 1, 0)
    elif direction == "back":
        sdf = sdf[descending(sdf), :, :]
    elif direction == "right":
        sdf = sdf[:, :, descending(sdf)].permute(2, 1, 0)

    sdf_all = sdf[descending(sdf), :, :].permute(2, 1, 0)

    # shadow
    occupied = sdf_all > 0.5
    falling = torch.linspace(resolution, 1, steps=resolution).to(sdf.device)
    rising = torch.linspace(0, resolution - 1,
                            steps=resolution).to(sdf.device)
    grad_v = occupied * falling
    grad_c = torch.ones_like(sdf_all) * rising
    _, first_hit = grad_v.max(dim=2)
    shadow = grad_c > first_hit.view(resolution, resolution, 1)
    keep = occupied & (~shadow)

    p1 = keep.nonzero(as_tuple=False).t()  # [3, N]

    def stepped_back(axis):
        # same points, moved two cells back along `axis`
        moved = p1.clone()
        moved[axis, :] = (moved[axis, :] - 2).clamp(0, resolution)
        return moved

    def sample(points):
        return sdf_all[points[0, :], points[1, :], points[2, :]]

    p2 = stepped_back(2)  # z
    p3 = stepped_back(1)  # y
    p4 = stepped_back(0)  # x

    v1 = sample(p1)
    v2 = sample(p2)
    v3 = sample(p3)
    v4 = sample(p4)

    X = p1[0, :].long()  # [N,]
    Y = p1[1, :].long()  # [N,]
    # position of the 0.5 crossing between the two samples along z
    Z = p2[2, :].float() * (0.5 - v1) / (v2 - v1) + \
        p1[2, :].float() * (v2 - 0.5) / (v2 - v1)  # [N,]
    Z = Z.clamp(0, resolution)

    # normal
    norm = torch.stack([v4 - v1, v3 - v1, v2 - v1], dim=1)
    norm = norm / torch.norm(norm, p=2, dim=1, keepdim=True)

    return X, Y, Z, norm
557
+
558
def render_normal(self, resolution, X, Y, Z, norm):
    """
    Paint normals into a white square image.

    Each normal is mapped from [-1, 1] to a colour in [0, 1] (values outside
    are clamped) and written at pixel (Y, X). Z is accepted but not used.
    Returns a float32 tensor of shape (1, 3, resolution, resolution) on the
    device of norm.
    """
    canvas = torch.ones((1, 3, resolution, resolution),
                        dtype=torch.float32).to(norm.device)
    rgb = ((norm + 1) / 2.0).clamp(0, 1)
    canvas[0, :, Y, X] = rgb.t()
    return canvas
565
+
566
def display(self, sdf):
    """
    Render the normal maps of the front, left, right and back views (in that
    order), join them side by side along the image width and return the
    result as a uint8 array in height x width x channel layout.
    """
    # render
    side = self.resolutions[-1, -1]
    views = []
    for direction in ("front", "left", "right", "back"):
        X, Y, Z, norm = self.find_vertices(sdf, direction=direction)
        views.append(self.render_normal(side, X, Y, Z, norm))

    strip = torch.cat(views, axis=3)
    strip = strip.detach().cpu().numpy()[0].transpose(1, 2, 0) * 255.0

    return np.uint8(strip)
582
+
583
def export_mesh(self, occupancys):
    """
    Extract a triangle mesh from an occupancy volume.

    The first slice along every axis is dropped before meshing. Volumes
    whose first dimension exceeds 256 are meshed on the CPU with
    mcubes.marching_cubes at the iso level self.balance_value; smaller ones
    go through voxelgrids_to_trianglemeshes.

    Returns (verts, faces): vertex columns are reordered with [2, 1, 0] and
    the face index columns with [0, 2, 1]; on the CPU path faces are int64.
    """
    final = occupancys[1:, 1:, 1:].contiguous()

    if final.shape[0] > 256:
        # for voxelgrid larger than 256^3, the required GPU memory will be > 9GB
        # thus we use CPU marching_cube to avoid "CUDA out of memory"
        occu_arr = final.detach().cpu().numpy()  # non-smooth surface
        # occu_arr = mcubes.smooth(final.detach().cpu().numpy())  # smooth surface
        vertices, triangles = mcubes.marching_cubes(
            occu_arr, self.balance_value)
        verts = torch.as_tensor(vertices[:, [2, 1, 0]])
        # np.int64 rather than the `np.long` alias: the alias is missing from
        # NumPy 1.24-1.26 and is the platform C long on NumPy 2.x, while
        # torch.long is always 64 bit.
        faces = torch.as_tensor(triangles.astype(np.int64),
                                dtype=torch.long)[:, [0, 2, 1]]
    else:
        torch.cuda.empty_cache()
        # NOTE(review): this path does not receive self.balance_value -
        # confirm which threshold voxelgrids_to_trianglemeshes applies.
        vertices, triangles = voxelgrids_to_trianglemeshes(
            final.unsqueeze(0))
        verts = vertices[0][:, [2, 1, 0]].cpu()
        faces = triangles[0][:, [0, 2, 1]].cpu()

    return verts, faces
lib/common/seg3d_utils.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import matplotlib.pyplot as plt
22
+
23
+
24
def plot_mask2D(mask,
                title="",
                point_coords=None,
                figsize=10,
                point_marker_size=5):
    '''
    Display a 2D mask and, optionally, a set of points drawn on top of it.

    Args:
        mask (Tensor): mask prediction of shape HxW
        title (str): text placed in front of the resolution in the plot title
        point_coords ((Tensor, Tensor)): x and y point coordinates
        figsize (int): edge length of the (square) figure
        point_marker_size (int): marker size for points
    '''

    H, W = mask.shape
    heading = title + ", " if title else title

    plt.figure(figsize=(figsize, figsize))
    plt.title("{}resolution {}x{}".format(heading, H, W), fontsize=30)
    plt.ylabel(H, fontsize=30)
    plt.xlabel(W, fontsize=30)
    # no tick marks on either axis
    plt.xticks([], [])
    plt.yticks([], [])
    plt.imshow(mask.detach(),
               interpolation="nearest",
               cmap=plt.get_cmap('gray'))

    if point_coords is not None:
        xs, ys = point_coords[0], point_coords[1]
        plt.scatter(x=xs,
                    y=ys,
                    color="red",
                    s=point_marker_size,
                    clip_on=True)

    # pixel-centred limits, y axis pointing downwards
    plt.xlim(-0.5, W - 0.5)
    plt.ylim(H - 0.5, -0.5)
    plt.show()
62
+
63
+
64
def plot_mask3D(mask=None,
                title="",
                point_coords=None,
                figsize=1500,
                point_marker_size=8,
                interactive=True):
    '''
    Display the 0.5 iso-surface of a 3D mask and, optionally, a point set.

    Args:
        mask (Tensor): mask prediction of shape DxHxW
        title (str): title for the plot
        point_coords ((Tensor, Tensor, Tensor)): x and y and z point coordinates
        figsize (int): size of the figure to plot
        point_marker_size (int): marker size for points
        interactive (bool): passed on to the plotter's show()
    '''
    # heavy / optional visualisation dependencies are imported lazily
    import trimesh
    import vtkplotter
    from skimage import measure

    vp = vtkplotter.Plotter(title=title, size=(figsize, figsize))
    vis_list = []

    if mask is not None:
        mask = mask.detach().to("cpu").numpy()
        mask = mask.transpose(2, 1, 0)

        # marching cube to find surface
        # `marching_cubes_lewiner` no longer exists in recent scikit-image
        # releases; there `marching_cubes` provides the Lewiner algorithm.
        marching_cubes = getattr(measure, "marching_cubes_lewiner", None)
        if marching_cubes is None:
            marching_cubes = measure.marching_cubes
        verts, faces, normals, values = marching_cubes(
            mask, 0.5, gradient_direction='ascent')

        # create a mesh
        mesh = trimesh.Trimesh(verts, faces)
        mesh.visual.face_colors = [200, 200, 250, 100]
        vis_list.append(mesh)

    if point_coords is not None:
        # (x, y, z) tensors -> one [N, 3] array
        point_coords = torch.stack(point_coords, 1).to("cpu").numpy()

        pc = vtkplotter.Points(point_coords, r=point_marker_size, c='red')
        vis_list.append(pc)

    vp.show(*vis_list,
            bg="white",
            axes=1,
            interactive=interactive,
            azimuth=30,
            elevation=30)
120
+
121
+
122
def create_grid3D(min, max, steps):
    """
    Build the integer coordinates of a regular 3D grid.

    min, max, steps may each be a single int (used for all three axes) or
    an (x, y, z) triple. Per axis, `steps` evenly spaced values from min to
    max are generated and converted to long.

    Returns a [N, 3] tensor of (x, y, z) rows, x varying fastest and z
    slowest.
    """

    def as_triple(value):
        return (value, value, value) if type(value) is int else value

    lo = as_triple(min)  # (x, y, z)
    hi = as_triple(max)  # (x, y, z)
    count = as_triple(steps)  # (x, y, z)

    xs, ys, zs = [
        torch.linspace(lo[axis], hi[axis], count[axis]).long()
        for axis in range(3)
    ]
    grid_z, grid_y, grid_x = torch.meshgrid([zs, ys, xs])
    # [3, steps_z, steps_y, steps_x]
    coords = torch.stack([grid_x, grid_y, grid_z])
    return coords.view(3, -1).t()  # [N, 3]
137
+
138
+
139
def create_grid2D(min, max, steps):
    """
    Build the integer coordinates of a regular 2D grid.

    min, max, steps may each be a single int (used for both axes) or an
    (x, y) pair. Per axis, `steps` evenly spaced values from min to max are
    generated and converted to long.

    Returns a [N, 2] tensor of (x, y) rows, x varying fastest.
    """

    def as_pair(value):
        return (value, value) if type(value) is int else value

    lo = as_pair(min)  # (x, y)
    hi = as_pair(max)  # (x, y)
    count = as_pair(steps)  # (x, y)

    xs, ys = [
        torch.linspace(lo[axis], hi[axis], count[axis]).long()
        for axis in range(2)
    ]
    grid_y, grid_x = torch.meshgrid([ys, xs])
    coords = torch.stack([grid_x, grid_y])  # [2, steps_y, steps_x]
    return coords.view(2, -1).t()  # [N, 2]
152
+
153
+
154
class SmoothConv2D(nn.Module):
    """
    Fixed (non-trainable) 2D box filter.

    Every kernel entry equals 1 / kernel_size**2, so each output channel is
    the sum over all input channels of the local mean. The padding keeps
    the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
        self.padding = (kernel_size - 1) // 2

        # F.conv2d expects the weight as (out_channels, in_channels, kH, kW);
        # with the two channel sizes swapped the module only worked when
        # in_channels == out_channels.
        weight = torch.ones(
            (out_channels, in_channels, kernel_size, kernel_size),
            dtype=torch.float32) / (kernel_size**2)
        # a buffer: follows .to()/.cuda() but is not an optimised parameter
        self.register_buffer('weight', weight)

    def forward(self, input):
        return F.conv2d(input, self.weight, padding=self.padding)
167
+
168
+
169
class SmoothConv3D(nn.Module):
    """
    Fixed (non-trainable) 3D box filter.

    Every kernel entry equals 1 / kernel_size**3, so each output channel is
    the sum over all input channels of the local mean. The padding keeps
    the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
        self.padding = (kernel_size - 1) // 2

        # F.conv3d expects the weight as
        # (out_channels, in_channels, kD, kH, kW); with the two channel sizes
        # swapped the module only worked when in_channels == out_channels.
        weight = torch.ones(
            (out_channels, in_channels, kernel_size, kernel_size, kernel_size),
            dtype=torch.float32) / (kernel_size**3)
        # a buffer: follows .to()/.cuda() but is not an optimised parameter
        self.register_buffer('weight', weight)

    def forward(self, input):
        return F.conv3d(input, self.weight, padding=self.padding)
182
+
183
+
184
def build_smooth_conv3D(in_channels=1,
                        out_channels=1,
                        kernel_size=3,
                        padding=1):
    """
    Create a torch.nn.Conv3d initialised as a box filter.

    All weights are set to 1 / kernel_size**3 and the bias to zero. Note
    that `padding` is taken as given and is not derived from kernel_size.
    """
    smooth_conv = torch.nn.Conv3d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  padding=padding)
    # Conv3d stores its weight as (out_channels, in_channels, kD, kH, kW);
    # the replacement has to use the same layout, otherwise the layer breaks
    # as soon as in_channels != out_channels.
    smooth_conv.weight.data = torch.ones(
        (out_channels, in_channels, kernel_size, kernel_size, kernel_size),
        dtype=torch.float32) / (kernel_size**3)
    smooth_conv.bias.data = torch.zeros(out_channels)
    return smooth_conv
197
+
198
+
199
def build_smooth_conv2D(in_channels=1,
                        out_channels=1,
                        kernel_size=3,
                        padding=1):
    """
    Create a torch.nn.Conv2d initialised as a box filter.

    All weights are set to 1 / kernel_size**2 and the bias to zero. Note
    that `padding` is taken as given and is not derived from kernel_size.
    """
    smooth_conv = torch.nn.Conv2d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  padding=padding)
    # Conv2d stores its weight as (out_channels, in_channels, kH, kW); the
    # replacement has to use the same layout, otherwise the layer breaks as
    # soon as in_channels != out_channels.
    smooth_conv.weight.data = torch.ones(
        (out_channels, in_channels, kernel_size, kernel_size),
        dtype=torch.float32) / (kernel_size**2)
    smooth_conv.bias.data = torch.zeros(out_channels)
    return smooth_conv
212
+
213
+
214
def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points,
                                         **kwargs):
    """
    Find `num_points` most uncertain points from `uncertainty_map` grid.
    Args:
        uncertainty_map (Tensor): A tensor of shape (N, 1, D, H, W) that contains uncertainty
            values for a set of points on a regular D x H x W grid.
        num_points (int): The number of points P to select (capped at D x H x W).
        **kwargs: accepted and ignored.
    Returns:
        point_indices (Tensor): A tensor of shape (N, P) that contains indices from
            [0, D x H x W) of the most uncertain points.
        point_coords (Tensor): A float tensor of shape (N, P, 3) that contains the (x, y, z)
            grid coordinates (not normalized) of those points.
    """
    R, _, D, H, W = uncertainty_map.shape

    num_points = min(D * H * W, num_points)
    # only the positions of the top-k scores are needed, not the scores
    _, point_indices = torch.topk(uncertainty_map.view(R, D * H * W),
                                  k=num_points,
                                  dim=1)
    point_coords = torch.zeros(R,
                               num_points,
                               3,
                               dtype=torch.float,
                               device=uncertainty_map.device)
    # undo the flattening: index = z * H * W + y * W + x
    point_coords[:, :, 0] = (point_indices % W).to(torch.float)  # x
    point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float)  # y
    point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float)  # z
    return point_indices, point_coords
252
+
253
+
254
def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points,
                                                clip_min):
    """
    Pick the most uncertain grid points among those scoring at least `clip_min`.
    Args:
        uncertainty_map (Tensor): A tensor of shape (1, 1, D, H, W) with one uncertainty
            value per point of a regular D x H x W grid (only batch size 1 is supported).
        num_points (int): Upper bound P on the number of selected points.
        clip_min: Points whose value is below this threshold are never selected.
    Returns:
        point_indices (Tensor): A tensor of shape (1, P) with flat indices from
            [0, D x H x W), highest uncertainty first.
        point_coords (Tensor): A float tensor of shape (1, P, 3) holding the (x, y, z)
            grid coordinates of those points.
    """
    R, _, D, H, W = uncertainty_map.shape

    assert R == 1, "batchsize > 1 is not implemented!"
    flat = uncertainty_map.view(D * H * W)
    # restrict the top-k search to the entries that pass the threshold
    candidates = (flat >= clip_min).nonzero().squeeze(1)
    num_points = min(num_points, candidates.size(0))
    _, order = torch.topk(flat[candidates], k=num_points, dim=0)
    point_indices = candidates[order].unsqueeze(0)

    # undo the flattening: index = z * H * W + y * W + x
    xs = point_indices % W
    ys = point_indices % (H * W) // W
    zs = point_indices // (H * W)
    point_coords = torch.stack([xs, ys, zs], dim=-1).to(torch.float)
    return point_indices, point_coords
295
+
296
+
297
def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points,
                                         **kwargs):
    """
    Pick the `num_points` most uncertain points of a 2D grid.
    Args:
        uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) with one uncertainty
            value per point of a regular H x W grid.
        num_points (int): The number of points P to select (capped at H x W).
        **kwargs: accepted and ignored.
    Returns:
        point_indices (Tensor): A tensor of shape (N, P) with flat indices from
            [0, H x W), highest uncertainty first.
        point_coords (Tensor): A long tensor of shape (N, P, 2) holding the (x, y)
            grid coordinates of those points.
    """
    R, _, H, W = uncertainty_map.shape

    num_points = min(H * W, num_points)
    _, point_indices = torch.topk(uncertainty_map.view(R, H * W),
                                  k=num_points,
                                  dim=1)
    # undo the flattening: index = y * W + x
    xs = point_indices % W
    ys = point_indices // W
    point_coords = torch.stack([xs, ys], dim=2).to(torch.long)
    return point_indices, point_coords
330
+
331
+
332
def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points,
                                                clip_min):
    """
    Pick the most uncertain grid points among those scoring at least `clip_min`.
    Args:
        uncertainty_map (Tensor): A tensor of shape (1, 1, H, W) with one uncertainty
            value per point of a regular H x W grid (only batch size 1 is supported).
        num_points (int): Upper bound P on the number of selected points.
        clip_min: Points whose value is below this threshold are never selected.
    Returns:
        point_indices (Tensor): A tensor of shape (1, P) with flat indices from
            [0, H x W), highest uncertainty first.
        point_coords (Tensor): A long tensor of shape (1, P, 2) holding the (x, y)
            grid coordinates of those points.
    """
    R, _, H, W = uncertainty_map.shape

    assert R == 1, "batchsize > 1 is not implemented!"
    flat = uncertainty_map.view(H * W)
    # restrict the top-k search to the entries that pass the threshold
    candidates = (flat >= clip_min).nonzero().squeeze(1)
    num_points = min(num_points, candidates.size(0))
    _, order = torch.topk(flat[candidates], k=num_points, dim=0)
    point_indices = candidates[order].unsqueeze(0)

    # undo the flattening: index = y * W + x
    xs = point_indices % W
    ys = point_indices // W
    point_coords = torch.stack([xs, ys], dim=-1).to(torch.long)
    return point_indices, point_coords
370
+
371
+
372
def calculate_uncertainty(logits, classes=None, balance_value=0.5):
    """
    Score each location by how close its prediction is to `balance_value`:
    the score is the negated absolute distance, so the closest (most
    uncertain) locations get the highest score.
    Args:
        logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) for class-specific or
            class-agnostic predictions.
        classes (list): Needed only when C > 1: for each of the R entries, the channel
            whose prediction is scored.
        balance_value (float): The value predictions are compared against.
    Returns:
        scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.
    """
    if logits.shape[1] != 1:
        # pick, per entry, the channel named in `classes`
        rows = torch.arange(logits.shape[0], device=logits.device)
        logits = logits[rows, classes].unsqueeze(1)
    return -torch.abs(logits - balance_value)
lib/common/smpl_vert_segmentation.json ADDED
The diff for this file is too large to render. See raw diff
 
lib/common/train_util.py ADDED
@@ -0,0 +1,597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ import yaml
19
+ import os.path as osp
20
+ import torch
21
+ import numpy as np
22
+ import torch.nn.functional as F
23
+ from ..dataset.mesh_util import *
24
+ from ..net.geometry import orthogonal
25
+ from pytorch3d.renderer.mesh import rasterize_meshes
26
+ from .render_utils import Pytorch3dRasterizer
27
+ from pytorch3d.structures import Meshes
28
+ import cv2
29
+ from PIL import Image
30
+ from tqdm import tqdm
31
+ import os
32
+ from termcolor import colored
33
+
34
+
35
def reshape_sample_tensor(sample_tensor, num_views):
    """Repeat every batch entry `num_views` times along the batch dimension.

    With `num_views == 1` the input is returned untouched. Otherwise a new
    axis is inserted after the batch axis, tiled `num_views` times and folded
    back into the batch axis, so copies of the same entry are adjacent.
    """
    if num_views == 1:
        return sample_tensor
    tiled = sample_tensor.unsqueeze(dim=1).repeat(1, num_views, 1, 1)
    merged_batch = tiled.shape[0] * tiled.shape[1]
    return tiled.view(merged_batch, tiled.shape[2], tiled.shape[3])
45
+
46
+
47
def gen_mesh_eval(opt, net, cuda, data, resolution=None):
    """Reconstruct a mesh for evaluation and return `(verts, faces)`.

    `opt.resolution` is used when `resolution` is None. The image in
    `data['img']` is passed through `net.filter`, then
    `reconstruction_faster` is run (octree disabled) inside the bounds
    `data['b_min']` / `data['b_max']`. Any exception raised by the
    reconstruction is printed and `(None, None)` is returned instead.
    """
    if resolution is None:
        resolution = opt.resolution
    images = data['img'].to(device=cuda)
    calibs = data['calib'].to(device=cuda)

    net.filter(images)

    bbox_min = data['b_min']
    bbox_max = data['b_max']
    try:
        verts, faces, _, _ = reconstruction_faster(net,
                                                   cuda,
                                                   calibs,
                                                   resolution,
                                                   bbox_min,
                                                   bbox_max,
                                                   use_octree=False)
    except Exception as err:
        # best-effort: report the failure and hand back an empty result
        print(err)
        print('Can not create marching cubes at this time.')
        return None, None
    return verts, faces
70
+
71
+
72
def gen_mesh(opt, net, cuda, data, save_path, resolution=None):
    """Reconstruct a mesh, colour it from the first input view and save it.

    Steps:
      1. Save the input views, concatenated side by side, next to
         `save_path` (last four characters of `save_path` replaced by
         '.png').
      2. Run `reconstruction_faster` with `net` inside `data['b_min']` /
         `data['b_max']`.
      3. Project the vertices with `net.projection` and sample the first
         image at the resulting xy positions to obtain per-vertex colours.
      4. Write the coloured mesh with `save_obj_mesh_with_color`.

    Args:
        opt: options object; `opt.resolution` is used when `resolution` is
            None.
        net: network exposing `filter`, `projection` and `index`.
            NOTE(review): `index` is assumed to be available on `net`; the
            previous code called it on an undefined name `netG` — confirm.
        cuda: device the tensors are moved to.
        data: dict with keys 'img', 'calib', 'b_min', 'b_max'.
        save_path: output mesh path.
        resolution: optional override of `opt.resolution`.

    Returns:
        `(verts, faces, color)`, or `(None, None, None)` if any step raised
        (the exception is printed).
    """
    resolution = opt.resolution if resolution is None else resolution
    image_tensor = data['img'].to(device=cuda)
    calib_tensor = data['calib'].to(device=cuda)

    net.filter(image_tensor)

    b_min = data['b_min']
    b_max = data['b_max']
    try:
        save_img_path = save_path[:-4] + '.png'
        # CHW in [-1, 1] -> HWC in [0, 255]; the original flipped the channel
        # axis twice, which cancels out, so no flip is applied here.
        save_img_list = [
            (np.transpose(view.detach().cpu().numpy(),
                          (1, 2, 0)) * 0.5 + 0.5) * 255.0
            for view in image_tensor
        ]
        save_img = np.concatenate(save_img_list, axis=1)
        Image.fromarray(np.uint8(save_img)).save(save_img_path)

        verts, faces, _, _ = reconstruction_faster(net, cuda, calib_tensor,
                                                   resolution, b_min, b_max)
        verts_tensor = torch.from_numpy(
            verts.T).unsqueeze(0).to(device=cuda).float()
        xyz_tensor = net.projection(verts_tensor, calib_tensor[:1])
        uv = xyz_tensor[:, :2, :]
        # Fixed: `netG` is not defined in this function; the network passed
        # in is `net`. The NameError used to be swallowed by the handler
        # below, so the function could never return a mesh.
        color = net.index(image_tensor[:1], uv).detach().cpu().numpy()[0].T
        color = color * 0.5 + 0.5
        save_obj_mesh_with_color(save_path, verts, faces, color)
    except Exception as e:
        print(e)
        print('Can not create marching cubes at this time.')
        verts, faces, color = None, None, None
    return verts, faces, color
106
+
107
+
108
def gen_mesh_color(opt, netG, netC, cuda, data, save_path, use_octree=True):
    """Reconstruct a mesh with `netG`, colour it with `netC` and save it.

    The input views are saved side by side as a '.png' next to `save_path`,
    the surface is reconstructed with `reconstruction_faster`, and per-vertex
    colours are predicted by querying `netC` in chunks of 10000 vertices.

    Args:
        opt: options object; reads `opt.resolution` and `opt.num_views`.
        netG: geometry network (`filter`, `get_im_feat`).
        netC: colour network (`filter`, `attach`, `query`, `get_preds`).
        cuda: device the tensors are moved to.
        data: dict with keys 'img', 'calib', 'b_min', 'b_max'.
        save_path: output mesh path.
        use_octree: forwarded to `reconstruction_faster`.

    Returns:
        `(verts, faces, color)`, or `(None, None, None)` if any step raised
        (the exception is printed).
    """
    image_tensor = data['img'].to(device=cuda)
    calib_tensor = data['calib'].to(device=cuda)

    netG.filter(image_tensor)
    netC.filter(image_tensor)
    netC.attach(netG.get_im_feat())

    b_min = data['b_min']
    b_max = data['b_max']
    try:
        save_img_path = save_path[:-4] + '.png'
        # CHW in [-1, 1] -> HWC in [0, 255]; the original flipped the channel
        # axis twice, which cancels out, so no flip is applied here.
        save_img_list = [
            (np.transpose(view.detach().cpu().numpy(),
                          (1, 2, 0)) * 0.5 + 0.5) * 255.0
            for view in image_tensor
        ]
        save_img = np.concatenate(save_img_list, axis=1)
        Image.fromarray(np.uint8(save_img)).save(save_img_path)

        verts, faces, _, _ = reconstruction_faster(netG,
                                                   cuda,
                                                   calib_tensor,
                                                   opt.resolution,
                                                   b_min,
                                                   b_max,
                                                   use_octree=use_octree)

        # Now getting colors
        verts_tensor = torch.from_numpy(
            verts.T).unsqueeze(0).to(device=cuda).float()
        verts_tensor = reshape_sample_tensor(verts_tensor, opt.num_views)
        color = np.zeros(verts.shape)
        num_verts = len(color)
        interval = 10000
        # Fixed: the previous loop ran `num_verts // interval` times and
        # closed the final chunk with `right = -1`, so the last vertex was
        # never coloured and meshes with fewer than `interval` vertices got
        # no colour at all. Step over the vertices in fixed-size chunks.
        for left in range(0, num_verts, interval):
            right = min(left + interval, num_verts)
            netC.query(verts_tensor[:, :, left:right], calib_tensor)
            rgb = netC.get_preds()[0].detach().cpu().numpy() * 0.5 + 0.5
            color[left:right] = rgb.T

        save_obj_mesh_with_color(save_path, verts, faces, color)
    except Exception as e:
        print(e)
        print('Can not create marching cubes at this time.')
        verts, faces, color = None, None, None
    return verts, faces, color
158
+
159
+
160
def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
    """Decay the learning rate by `gamma` when `epoch` is a schedule entry.

    When `epoch` is in `schedule`, every parameter group of `optimizer` is
    set to `lr * gamma` and that value is returned; otherwise nothing is
    modified and `lr` is returned unchanged.
    """
    if epoch not in schedule:
        return lr
    lr *= gamma
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr
167
+
168
+
169
def compute_acc(pred, gt, thresh=0.5):
    '''
    Binarise `pred` and `gt` at `thresh` and compare the two masks.

    return:
        IOU, precision, and recall
        (a denominator that counts zero elements is replaced by 1)
    '''

    def _count_or_one(mask):
        # number of set entries, with 0 replaced by 1 to avoid dividing by 0
        total = mask.sum().float()
        return 1 if total == 0 else total

    with torch.no_grad():
        pred_mask = pred > thresh
        gt_mask = gt > thresh

        true_pos = (pred_mask & gt_mask).sum().float()

        iou = true_pos / _count_or_one(pred_mask | gt_mask)
        precision = true_pos / _count_or_one(pred_mask)
        recall = true_pos / _count_or_one(gt_mask)
        return iou, precision, recall
193
+
194
+
195
+ # def calc_metrics(opt, net, cuda, dataset, num_tests,
196
+ # resolution=128, sampled_points=1000, use_kaolin=True):
197
+ # if num_tests > len(dataset):
198
+ # num_tests = len(dataset)
199
+ # with torch.no_grad():
200
+ # chamfer_arr, p2s_arr = [], []
201
+ # for idx in tqdm(range(num_tests)):
202
+ # data = dataset[idx * len(dataset) // num_tests]
203
+
204
+ # verts, faces = gen_mesh_eval(opt, net, cuda, data, resolution)
205
+ # if verts is None:
206
+ # continue
207
+
208
+ # mesh_gt = trimesh.load(data['mesh_path'])
209
+ # mesh_gt = mesh_gt.split(only_watertight=False)
210
+ # comp_num = [mesh.vertices.shape[0] for mesh in mesh_gt]
211
+ # mesh_gt = mesh_gt[comp_num.index(max(comp_num))]
212
+
213
+ # mesh_pred = trimesh.Trimesh(verts, faces)
214
+
215
+ # gt_surface_pts, _ = trimesh.sample.sample_surface_even(
216
+ # mesh_gt, sampled_points)
217
+ # pred_surface_pts, _ = trimesh.sample.sample_surface_even(
218
+ # mesh_pred, sampled_points)
219
+
220
+ # if use_kaolin and has_kaolin:
221
+ # kal_mesh_gt = kal.rep.TriangleMesh.from_tensors(
222
+ # torch.tensor(mesh_gt.vertices).float().to(device=cuda),
223
+ # torch.tensor(mesh_gt.faces).long().to(device=cuda))
224
+ # kal_mesh_pred = kal.rep.TriangleMesh.from_tensors(
225
+ # torch.tensor(mesh_pred.vertices).float().to(device=cuda),
226
+ # torch.tensor(mesh_pred.faces).long().to(device=cuda))
227
+
228
+ # kal_distance_0 = kal.metrics.mesh.point_to_surface(
229
+ # torch.tensor(pred_surface_pts).float().to(device=cuda), kal_mesh_gt)
230
+ # kal_distance_1 = kal.metrics.mesh.point_to_surface(
231
+ # torch.tensor(gt_surface_pts).float().to(device=cuda), kal_mesh_pred)
232
+
233
+ # dist_gt_pred = torch.sqrt(kal_distance_0).cpu().numpy()
234
+ # dist_pred_gt = torch.sqrt(kal_distance_1).cpu().numpy()
235
+ # else:
236
+ # try:
237
+ # _, dist_pred_gt, _ = trimesh.proximity.closest_point(mesh_pred, gt_surface_pts)
238
+ # _, dist_gt_pred, _ = trimesh.proximity.closest_point(mesh_gt, pred_surface_pts)
239
+ # except Exception as e:
240
+ # print (e)
241
+ # continue
242
+
243
+ # chamfer_dist = 0.5 * (dist_pred_gt.mean() + dist_gt_pred.mean())
244
+ # p2s_dist = dist_pred_gt.mean()
245
+
246
+ # chamfer_arr.append(chamfer_dist)
247
+ # p2s_arr.append(p2s_dist)
248
+
249
+ # return np.average(chamfer_arr), np.average(p2s_arr)
250
+
251
+
252
def calc_error(opt, net, cuda, dataset, num_tests):
    """Average the network error, IOU, precision and recall over a dataset.

    At most `len(dataset)` items are evaluated; they are picked at evenly
    spaced indices. Each item is run through `net.forward` (without
    gradients) and the prediction is compared with the labels through
    `compute_acc`.

    Returns:
        tuple of four averages: (error, IOU, precision, recall)
    """
    num_tests = min(num_tests, len(dataset))

    errors, ious, precisions, recalls = [], [], [], []
    with torch.no_grad():
        for idx in tqdm(range(num_tests)):
            item = dataset[idx * len(dataset) // num_tests]

            # move one sample to the target device
            images = item['img'].to(device=cuda)
            calibs = item['calib'].to(device=cuda)
            samples = item['samples'].to(device=cuda).unsqueeze(0)
            if opt.num_views > 1:
                samples = reshape_sample_tensor(samples, opt.num_views)
            labels = item['labels'].to(device=cuda).unsqueeze(0)

            res, error = net.forward(images, samples, calibs, labels=labels)
            IOU, prec, recall = compute_acc(res, labels)

            errors.append(error.item())
            ious.append(IOU.item())
            precisions.append(prec.item())
            recalls.append(recall.item())

    return (np.average(errors), np.average(ious), np.average(precisions),
            np.average(recalls))
285
+
286
+
287
def calc_error_color(opt, netG, netC, cuda, dataset, num_tests):
    """Average the colour-network error over evenly spaced dataset items.

    At most `len(dataset)` items are evaluated. For each one, `netG.filter`
    is run on the image and its features (`netG.get_im_feat()`) are handed to
    `netC.forward` together with the colour samples; the second value
    returned by `netC.forward` is collected and averaged.
    """
    num_tests = min(num_tests, len(dataset))

    color_errors = []
    with torch.no_grad():
        for idx in tqdm(range(num_tests)):
            item = dataset[idx * len(dataset) // num_tests]

            # move one sample to the target device
            images = item['img'].to(device=cuda)
            calibs = item['calib'].to(device=cuda)
            color_samples = item['color_samples'].to(
                device=cuda).unsqueeze(0)
            if opt.num_views > 1:
                color_samples = reshape_sample_tensor(color_samples,
                                                      opt.num_views)
            rgbs = item['rgbs'].to(device=cuda).unsqueeze(0)

            netG.filter(images)
            _, errorC = netC.forward(images,
                                     netG.get_im_feat(),
                                     color_samples,
                                     calibs,
                                     labels=rgbs)

            color_errors.append(errorC.item())

    return np.average(color_errors)
319
+
320
+
321
+ # pytorch lightning training related functions
322
+
323
+
324
def query_func(opt, netG, features, points, proj_matrix=None):
    '''
        - points: size of (bz, N, 3), bz must be 1
        - proj_matrix: size of (bz, 4, 4)
    return: size of (bz, 1, N)

    The points are repeated `opt.num_views` times, optionally transformed by
    `orthogonal` with `proj_matrix`, and queried through `netG.query` using
    an identity calibration. If the query returns a list, its first element
    is returned.
    '''
    assert len(points) == 1
    # [num_views, N, 3] -> [num_views, 3, N]
    samples = points.repeat(opt.num_views, 1, 1).permute(0, 2, 1)

    # view specific query
    if proj_matrix is not None:
        samples = orthogonal(samples, proj_matrix)

    identity_calib = torch.eye(4).float().unsqueeze(0).type_as(samples)

    preds = netG.query(features=features,
                       points=samples,
                       calibs=identity_calib,
                       regressor=netG.if_regressor)

    return preds[0] if type(preds) is list else preds
349
+
350
+
351
def isin(ar1, ar2):
    """Return, for every element of `ar1`, whether it equals any of `ar2`."""
    # a trailing axis on ar1 lets it broadcast against all values of ar2
    pairwise_equal = ar1[..., None] == ar2
    return pairwise_equal.any(-1)
353
+
354
+
355
def in1d(ar1, ar2):
    """Membership test of `ar1` in `ar2` through a boolean lookup table.

    A table with one slot per value up to the largest entry of both inputs is
    built, the slots listed in `ar2` are switched on, and the table is then
    indexed with `ar1`.
    """
    table_size = max(ar1.max(), ar2.max()) + 1
    lookup = ar2.new_zeros((table_size, ), dtype=torch.bool)
    lookup[ar2.unique()] = True
    return lookup[ar1]
359
+
360
+
361
def get_visibility(xy, z, faces):
    """get the visibility of vertices

    The mesh is rasterized at 4096 x 4096 with the settings of
    `Pytorch3dRasterizer`; a vertex counts as visible when it belongs to at
    least one face that was hit by a pixel.

    Args:
        xy (torch.tensor): [N,2]
        z (torch.tensor): [N,1]
        faces (torch.tensor): [M,3] vertex indices

    Returns:
        torch.tensor: [N,1] float mask, 1.0 for visible vertices, else 0.0
    """

    # depth is negated and all coordinates are remapped as (v + 1) / 2
    xyz = torch.cat((xy, -z), dim=1)
    xyz = (xyz + 1.0) / 2.0
    faces = faces.long()

    rasterizer = Pytorch3dRasterizer(image_size=2**12)
    meshes_screen = Meshes(verts=xyz[None, ...], faces=faces[None, ...])
    raster_settings = rasterizer.raster_settings

    pix_to_face, _zbuf, _bary_coords, _dists = rasterize_meshes(
        meshes_screen,
        image_size=raster_settings.image_size,
        blur_radius=raster_settings.blur_radius,
        faces_per_pixel=raster_settings.faces_per_pixel,
        bin_size=raster_settings.bin_size,
        max_faces_per_bin=raster_settings.max_faces_per_bin,
        perspective_correct=raster_settings.perspective_correct,
        cull_backfaces=raster_settings.cull_backfaces,
    )

    # Fixed: pixels that hit no face are reported as -1. Indexing `faces`
    # with -1 selected the last face, so its vertices were always flagged as
    # visible. Drop the background index before the lookup.
    hit_faces = torch.unique(pix_to_face)
    hit_faces = hit_faces[hit_faces >= 0]

    vis_vertices_id = torch.unique(faces[hit_faces, :])
    vis_mask = torch.zeros(size=(z.shape[0], 1))
    # the mask is allocated on the default device; bring the indices along
    vis_mask[vis_vertices_id.to(vis_mask.device)] = 1.0

    return vis_mask
398
+
399
+
400
def batch_mean(res, key):
    """Mean of `key` over a (possibly nested) sequence of dicts.

    Dict entries contribute `entry[key]`; any other entry is treated as a
    nested sequence and contributes its own (recursively computed) mean.
    """
    collected = []
    for entry in res:
        if isinstance(entry, dict):
            collected.append(entry[key])
        else:
            # recurse into the next nesting level
            collected.append(batch_mean(entry, key))
    return torch.stack(collected).mean()
405
+
406
+
407
def tf_log_convert(log_dict):
    """Return a copy of `log_dict` with every "_" in the keys replaced by "/".

    The input dict is not modified and values are passed through unchanged.

    Fixed: the previous implementation wrote the renamed key and then deleted
    the original key from the same dict, so a key without any underscore
    (renamed to itself) was removed from the result.
    """
    return {key.replace("_", "/"): value for key, value in log_dict.items()}
414
+
415
+
416
def bar_log_convert(log_dict, name=None, rot=None):
    """Build a compact, coloured dict of log values.

    `name[0]` / `rot[0]` are stored under 'name' / 'rot' when given. For each
    log entry the first matching keyword in the key is abbreviated
    (loss->L, acc->A, iou->I, prec->P, recall->R) and selects the colour
    (red for loss, green for the others, yellow when nothing matches). The
    second "_"-separated part of the resulting key is used as the label, so
    keys are expected to contain an underscore. Values are formatted with
    three decimals, or in scientific notation when the key contains 'lr'.
    """
    from decimal import Decimal

    # (keyword, abbreviation, colour) - checked in order, first hit wins
    abbreviations = (
        ("loss", "L", "red"),
        ("acc", "A", "green"),
        ("iou", "I", "green"),
        ("prec", "P", "green"),
        ("recall", "R", "green"),
    )

    new_log_dict = {}

    if name is not None:
        new_log_dict['name'] = name[0]
    if rot is not None:
        new_log_dict['rot'] = rot[0]

    for raw_key, value in log_dict.items():
        key, color = raw_key, "yellow"
        for keyword, short, tint in abbreviations:
            if keyword in raw_key:
                key, color = raw_key.replace(keyword, short), tint
                break

        if 'lr' in key:
            text = f"{Decimal(str(value)):.1E}"
        else:
            text = f"{value:.3f}"
        new_log_dict[colored(key.split("_")[1], color)] = colored(text, color)

    if 'loss' in new_log_dict.keys():
        del new_log_dict['loss']

    return new_log_dict
456
+
457
+
458
def accumulate(outputs, rot_num, split):
    """Average every metric of `outputs` per dataset.

    `split[dataset]` holds a (start, end) pair; the outputs with index in
    `[start * rot_num, end * rot_num)` belong to that dataset. The metric
    names are taken from `outputs[0]`. The result maps
    "hparam/<dataset>-<metric>" to the mean value, is printed in green and
    returned.
    """
    hparam_log_dict = {}
    metric_names = outputs[0].keys()

    for dataset in split.keys():
        first, last = split[dataset][0], split[dataset][1]
        for metric in metric_names:
            keyword = f"hparam/{dataset}-{metric}"
            running = hparam_log_dict.get(keyword, 0)
            for idx in range(first * rot_num, last * rot_num):
                running += outputs[idx][metric]
            running /= (last - first) * rot_num
            hparam_log_dict[keyword] = running

    print(colored(hparam_log_dict, "green"))

    return hparam_log_dict
479
+
480
+
481
def calc_error_N(outputs, targets):
    """calculate the error of normal (IGR)

    Only the third component of each target row is compared with the negated
    prediction, and only rows whose target component is non-zero contribute.

    Args:
        outputs (torch.tensor): flattened to one value per row after
            `permute(0, 2, 1)` - presumably [B, 1, N], TODO confirm
        targets (torch.tensor): [B, N, 3]

    Returns:
        torch.tensor: error of valid normals on the surface
    """
    pred = -outputs.permute(0, 2, 1).reshape(-1, 1)
    target_z = targets.reshape(-1, 3)[:, 2:3]
    valid = target_z.sum(dim=1).abs() > 0.0

    # eikonal term (computed, but weighted by 0.0 in the result)
    eikonal = ((pred[valid].norm(2, dim=-1) - 1)**2).mean()
    # normal term
    difference = (pred - target_z)[valid]
    normal_term = difference.abs().norm(2, dim=1).mean()

    return eikonal * 0.0 + normal_term
506
+
507
+
508
def calc_knn_acc(preds, carn_verts, labels, pick_num):
    """calculate knn accuracy

    For every predicted point the `pick_num` nearest rows of `carn_verts`
    are looked up and compared with the first `pick_num` label indices; a
    point counts as correct when both index sets share at least one value.

    Args:
        preds (torch.tensor): [B, 3, N]
        carn_verts (torch.tensor): [SMPLX_V_num, 3]
        labels (torch.tensor): [B, N_knn, N]

    Returns:
        torch.tensor: fraction of points counted as correct
    """
    num_label_knn = labels.shape[1]
    points = preds.permute(0, 2, 1).reshape(-1, 3)
    picked_labels = labels.permute(0, 2, 1).reshape(
        -1, num_label_knn)[:, :pick_num]  # [BxN, pick_num]

    distances = torch.cdist(points, carn_verts, p=2)  # [BxN, SMPL_V_num]
    nearest = distances.topk(k=pick_num, dim=1, largest=False)[1]

    # after sorting, a shared index shows up as two equal neighbours
    merged = torch.sort(torch.cat((nearest, picked_labels), dim=1))[0]
    shared_count = torch.zeros_like(merged)[:, 0]
    for col in range(pick_num * 2 - 1):
        shared_count += merged[:, col] == merged[:, col + 1]

    return (shared_count > 0).sum() / len(shared_count)
530
+
531
+
532
def calc_acc_seg(output, target, num_multiseg):
    """Accuracy of `output` against `target` via pytorch-lightning.

    `output` is reshaped to rows of `num_multiseg` values and `target` is
    flattened; both are moved to the CPU before the metric is evaluated.
    """
    from pytorch_lightning.metrics import Accuracy

    metric = Accuracy()
    scores = output.reshape(-1, num_multiseg).cpu()
    expected = target.flatten().cpu()
    return metric(scores, expected)
536
+
537
+
538
def add_watermark(imgs, titles):
    """Write titles onto the images and stack them into one array.

    Image `i` receives `titles[i + 1]` at position (350, 50); the first image
    additionally receives `str(titles[0][0])` at position (800, 50). The
    images are then concatenated along axis 0 and the result is transposed
    with `(2, 0, 1)`.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    title_pos = (350, 50)
    extra_pos = (800, 50)
    font_scale = 1
    font_color = (1.0, 1.0, 1.0)
    line_type = 2

    for idx, img in enumerate(imgs):
        cv2.putText(img, titles[idx + 1], title_pos, font, font_scale,
                    font_color, line_type)

        if idx == 0:
            cv2.putText(img, str(titles[idx][0]), extra_pos, font,
                        font_scale, font_color, line_type)

    stacked = np.concatenate(imgs, axis=0)
    return stacked.transpose(2, 0, 1)
562
+
563
+
564
def make_test_gif(img_dir):
    """Write an "out.gif" for every `<img_dir>/<dataset>/<subject>` folder.

    Within a subject folder, every file whose name does not end in 'obj' or
    'gif' is opened as an image (in sorted order); the first one becomes the
    base frame and the others are appended. Each frame is shown for 500 ms
    and the animation loops forever. Nothing happens when `img_dir` is None
    or empty.

    Fixed: a subject folder without any image file used to crash on
    `None.save`; such folders are now skipped. `== None` was replaced by the
    identity test `is None`.
    """
    if img_dir is None or len(os.listdir(img_dir)) == 0:
        return

    for dataset in os.listdir(img_dir):
        for subject in sorted(os.listdir(osp.join(img_dir, dataset))):
            subject_dir = os.path.join(img_dir, dataset, subject)
            img_lst = []
            im1 = None
            for file in sorted(os.listdir(subject_dir)):
                if file[-3:] not in ['obj', 'gif']:
                    img_path = os.path.join(subject_dir, file)
                    if im1 is None:
                        im1 = Image.open(img_path)
                    else:
                        img_lst.append(Image.open(img_path))

            if im1 is None:
                # no frame found for this subject - nothing to animate
                continue

            gif_path = os.path.join(subject_dir, "out.gif")
            print(gif_path)
            im1.save(gif_path,
                     save_all=True,
                     append_images=img_lst,
                     duration=500,
                     loop=0)
587
+
588
+
589
def export_cfg(logger, cfg):
    """Dump `cfg` as yaml into the logger's version folder, once.

    The target is `<save_dir>/<name>/version_<version>/cfg.yaml`; an already
    existing file is left untouched.
    """
    target = osp.join(logger.save_dir, logger.name,
                      f"version_{logger.version}", "cfg.yaml")

    if osp.exists(target):
        return

    os.makedirs(osp.dirname(target), exist_ok=True)
    with open(target, "w+") as handle:
        _ = yaml.dump(cfg, handle)
lib/dataloader_demo.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from lib.common.config import get_cfg_defaults
3
+ from lib.dataset.PIFuDataset import PIFuDataset
4
+
5
if __name__ == '__main__':
    # Small command-line helper to inspect the training split of PIFuDataset.
    # Fixed: every option carried the same copy-pasted help text
    # ('vis sampler 3D'); each option now describes what it does below.

    parser = argparse.ArgumentParser()
    parser.add_argument('-v',
                        '--show',
                        action='store_true',
                        help='visualize the 3D sampling of the first sample')
    parser.add_argument('-s',
                        '--speed',
                        action='store_true',
                        help='iterate over the dataset to check loading speed')
    parser.add_argument('-l',
                        '--list',
                        action='store_true',
                        help='print the keys (and shapes) of the first sample')
    # NOTE(review): the commit lists its configs directly under ./configs/
    # (e.g. configs/icon-filter.yaml); confirm this default path exists.
    parser.add_argument('-c',
                        '--config',
                        default='./configs/train/icon-filter.yaml',
                        help='path to the yaml config file')
    args_c = parser.parse_args()

    args = get_cfg_defaults()
    args.merge_from_file(args_c.config)

    dataset = PIFuDataset(args, split='train', vis=args_c.show)
    print(f"Number of subjects :{len(dataset.subject_list)}")
    data_dict = dataset[0]

    if args_c.list:
        # print the shape for array-like entries, the value otherwise
        for k in data_dict.keys():
            if not hasattr(data_dict[k], "shape"):
                print(f"{k}: {data_dict[k]}")
            else:
                print(f"{k}: {data_dict[k].shape}")

    if args_c.show:
        # for item in dataset:
        item = dataset[0]
        dataset.visualize_sampling3D(item, mode='occ')

    if args_c.speed:
        # original: 2 it/s
        # smpl online compute: 2 it/s
        # normal online compute: 1.5 it/s
        from tqdm import tqdm
        for item in tqdm(dataset):
            # only the entries whose key contains 'voxel' are reported
            for k in item.keys():
                if 'voxel' in k:
                    if not hasattr(item[k], "shape"):
                        print(f"{k}: {item[k]}")
                    else:
                        print(f"{k}: {item[k].shape}")
            print("--------------------")
lib/dataset/Evaluator.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ from lib.renderer.gl.normal_render import NormalRender
19
+ from lib.dataset.mesh_util import projection
20
+ from lib.common.render import Render
21
+ from PIL import Image
22
+ import os
23
+ import numpy as np
24
+ import torch
25
+ from torch import nn
26
+ import trimesh
27
+ import os.path as osp
28
+ from PIL import Image
29
+
30
+
31
class Evaluator:
    """Compare a reconstructed mesh with its ground truth.

    Offers normal-map reprojection error, chamfer / point-to-surface
    distances and occupancy accuracy. The normal renderer is shared by all
    instances and has to be created once through `Evaluator.init_gl()`.
    Mesh data (`verts_pr`, `faces_pr`, `verts_gt`, `faces_gt`, `calib`,
    `recon_size`, ...) is attached to the instance through `set_mesh`.
    """

    # shared renderer, created by init_gl()
    _normal_render = None

    @staticmethod
    def init_gl():
        """Create the shared 512 x 512 normal renderer."""
        Evaluator._normal_render = NormalRender(width=512, height=512)

    def __init__(self, device):
        self.device = device
        self.render = Render(size=512, device=self.device)
        self.error_term = nn.MSELoss()

        self.offset = 0.0
        self.scale_factor = None

    def set_mesh(self, result_dict, scale_factor=1.0, offset=0.0):
        """Store every entry of `result_dict` as an attribute of `self`.

        Tensor values are converted to numpy arrays first; note that this
        conversion is written back into `result_dict` itself.
        """

        for key in result_dict.keys():
            if torch.is_tensor(result_dict[key]):
                result_dict[key] = result_dict[key].detach().cpu().numpy()

        for k, v in result_dict.items():
            setattr(self, k, v)

        self.scale_factor = scale_factor
        self.offset = offset

    def _render_normal(self, mesh, deg, norms=None):
        """Render the normal image of `mesh` rotated by `deg` degrees.

        `mesh.vertex_normals` is used when `norms` is None. The vertices are
        multiplied by `self.scale_factor` and `self.offset` is written into
        the model matrix at [1, 3].
        """
        view_mat = np.identity(4)
        rz = deg / 180.0 * np.pi
        model_mat = np.identity(4)
        model_mat[:3, :3] = self._normal_render.euler_to_rot_mat(0, rz, 0)
        model_mat[1, 3] = self.offset
        view_mat[2, 2] *= -1

        self._normal_render.set_matrices(view_mat, model_mat)
        if norms is None:
            norms = mesh.vertex_normals
        self._normal_render.set_normal_mesh(self.scale_factor * mesh.vertices,
                                            mesh.faces, norms, mesh.faces)
        self._normal_render.draw()
        normal_img = self._normal_render.get_color()
        return normal_img

    def render_mesh_list(self, mesh_lst):
        """Render every mesh at 0/90/180/270 degrees into one image grid.

        One row per mesh, one column per angle. Resets `self.offset` to 0.0
        and `self.scale_factor` to 1.0.
        """

        self.offset = 0.0
        self.scale_factor = 1.0

        full_list = []
        for mesh in mesh_lst:
            row_lst = [
                self._render_normal(mesh, deg)
                for deg in np.arange(0, 360, 90)
            ]
            full_list.append(np.concatenate(row_lst, axis=1))

        res_array = np.concatenate(full_list, axis=0)

        return res_array

    def _get_reproj_normal_error(self, deg):
        """Return the normal error at `deg` and the two rendered images.

        The error is the squared difference of the first three channels,
        summed over the channels and averaged over all pixels.
        """

        tgt_normal = self._render_normal(self.tgt_mesh, deg)
        src_normal = self._render_normal(self.src_mesh, deg)
        error = (((src_normal[:, :, :3] -
                   tgt_normal[:, :, :3])**2).sum(axis=2).mean(axis=(0, 1)))

        return error, [src_normal, tgt_normal]

    def render_normal(self, verts, faces):
        """Render front and back normal maps of the first mesh in the batch.

        Returns a dict with "T_normal_F" and "T_normal_B", each a float
        tensor of shape [1, 3, H, W] on `self.device`, rescaled as
        2 * (x - 0.5). Both maps are masked with the fourth channel of the
        front rendering. Sets `self.scale_factor` to 1.0.
        """

        verts = verts[0].detach().cpu().numpy()
        faces = faces[0].detach().cpu().numpy()

        mesh_F = trimesh.Trimesh(verts * np.array([1.0, -1.0, 1.0]), faces)
        mesh_B = trimesh.Trimesh(verts * np.array([1.0, -1.0, -1.0]), faces)

        self.scale_factor = 1.0

        normal_F = self._render_normal(mesh_F, 0)
        normal_B = self._render_normal(mesh_B,
                                       0,
                                       norms=mesh_B.vertex_normals *
                                       np.array([-1.0, -1.0, 1.0]))

        mask = normal_F[:, :, 3:4]
        normal_F = (torch.as_tensor(2.0 * (normal_F - 0.5) * mask).permute(
            2, 0, 1)[:3, :, :].float().unsqueeze(0).to(self.device))
        normal_B = (torch.as_tensor(2.0 * (normal_B - 0.5) * mask).permute(
            2, 0, 1)[:3, :, :].float().unsqueeze(0).to(self.device))

        return {"T_normal_F": normal_F, "T_normal_B": normal_B}

    def calculate_normal_consist(
        self,
        frontal=True,
        back=True,
        left=True,
        right=True,
        save_demo_img=None,
        return_demo=False,
    ):
        """Sum the normal reprojection error over the selected views.

        Views: frontal = 0, back = 180, left = 90, right = 270 degrees.

        Args:
            frontal, back, left, right: which views to include.
            save_demo_img: if not None, a visualization is saved at the
                given path (e.g. "./test.png").
            return_demo: return the visualization array instead of the error.

        Returns:
            the summed error (not divided by the number of views), the
            visualization array when `return_demo` is set, or -1 when
            `init_gl()` has not been called.
        """

        # reproj error
        if self._normal_render is None:
            print(
                "In order to use normal render, "
                "you have to call init_gl() before initialing any evaluator objects."
            )
            return -1

        total_error = 0
        demo_list = []

        # same order as before: frontal, back, left, right
        views = ((frontal, 0), (back, 180), (left, 90), (right, 270))
        for enabled, deg in views:
            if not enabled:
                continue
            error, normal_lst = self._get_reproj_normal_error(deg)
            total_error += error
            demo_list.append(np.concatenate(normal_lst, axis=0))

        if save_demo_img is not None:
            res_array = np.concatenate(demo_list, axis=1)
            res_img = Image.fromarray((res_array * 255).astype(np.uint8))
            res_img.save(save_demo_img)

        if return_demo:
            res_array = np.concatenate(demo_list, axis=1)
            return res_array
        else:
            return total_error

    def space_transfer(self):
        """Bring prediction and ground truth into a common space.

        `verts_pr` is shifted and scaled in place by `recon_size / 2`
        (mapping [0, recon_size] onto [-1, 1]); `verts_gt` is passed through
        `projection` with `self.calib` and its second coordinate is negated.
        The results are stored as `self.src_mesh` (prediction) and
        `self.tgt_mesh` (ground truth).
        """

        # convert from GT to SDF
        self.verts_pr -= self.recon_size / 2.0
        self.verts_pr /= self.recon_size / 2.0

        self.verts_gt = projection(self.verts_gt, self.calib)
        self.verts_gt[:, 1] *= -1

        self.tgt_mesh = trimesh.Trimesh(self.verts_gt, self.faces_gt)
        self.src_mesh = trimesh.Trimesh(self.verts_pr, self.faces_pr)

        # (self.tgt_mesh+self.src_mesh).show()

    def export_mesh(self, dir, name):
        """Export both meshes into `<dir>/<name>_gt_pr.obj`.

        The ground truth is coloured [255, 0, 0] and the prediction
        [0, 255, 0].
        """
        self.tgt_mesh.visual.vertex_colors = np.array([255, 0, 0])
        self.src_mesh.visual.vertex_colors = np.array([0, 255, 0])

        (self.tgt_mesh + self.src_mesh).export(
            osp.join(dir, f"{name}_gt_pr.obj"))

    def calculate_chamfer_p2s(self, sampled_points=1000):
        """calculate the geometry metrics [chamfer, p2s]

        Works on `self.tgt_mesh` and `self.src_mesh` (see `space_transfer`).
        Points are sampled on both surfaces and their distance to the other
        mesh is measured; NaN distances are replaced by 0 and the results are
        multiplied by 100.

        Args:
            sampled_points (int, optional): use smaller number for faster
                testing. Defaults to 1000.

        Returns:
            tuple: (chamfer, p2s) - chamfer is the mean of both directions,
            p2s the mean distance of the ground-truth samples to the
            predicted mesh.
        """

        gt_surface_pts, _ = trimesh.sample.sample_surface_even(
            self.tgt_mesh, sampled_points)
        pred_surface_pts, _ = trimesh.sample.sample_surface_even(
            self.src_mesh, sampled_points)

        _, dist_pred_gt, _ = trimesh.proximity.closest_point(
            self.src_mesh, gt_surface_pts)
        _, dist_gt_pred, _ = trimesh.proximity.closest_point(
            self.tgt_mesh, pred_surface_pts)

        dist_pred_gt[np.isnan(dist_pred_gt)] = 0
        dist_gt_pred[np.isnan(dist_gt_pred)] = 0
        chamfer_dist = 0.5 * (dist_pred_gt.mean() +
                              dist_gt_pred.mean()).item() * 100
        p2s_dist = dist_pred_gt.mean().item() * 100

        return chamfer_dist, p2s_dist

    def calc_acc(self, output, target, thres=0.5, use_sdf=False):
        """Accuracy, IOU, precision and recall of a thresholded prediction.

        `output` is binarised at `thres` (values equal to `thres` are kept
        as they are); `target` is binarised as well when `use_sdf` is set.

        Returns:
            tuple: (acc, iou, precision, recall) as tensors. Denominators
            are clamped to at least 1 to avoid dividing by zero.
        """

        with torch.no_grad():
            output = output.masked_fill(output < thres, 0.0)
            output = output.masked_fill(output > thres, 1.0)

            if use_sdf:
                target = target.masked_fill(target < thres, 0.0)
                target = target.masked_fill(target > thres, 1.0)

            acc = output.eq(target).float().mean()

            # iou, precision, recall
            output = output > thres
            target = target > thres

            union = output | target
            inter = output & target

            _max = torch.tensor(1.0).to(output.device)

            # Fixed: only the denominators are clamped. The true-positive
            # count used to be clamped to 1 as well, which reported a
            # non-zero IOU / precision / recall for disjoint masks.
            true_pos = inter.sum().float()
            union = max(union.sum().float(), _max)
            vol_pred = max(output.sum().float(), _max)
            vol_gt = max(target.sum().float(), _max)

            return acc, true_pos / union, true_pos / vol_pred, true_pos / vol_gt
lib/dataset/NormalDataset.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ import os.path as osp
19
+ import numpy as np
20
+ from PIL import Image
21
+ import torchvision.transforms as transforms
22
+
23
+
24
class NormalDataset():
    """Dataset of rendered views paired with front/back normal maps.

    Each item corresponds to one (subject, rotation) pair. Subjects are read
    from per-dataset split files under ``cfg.root``; when a split file is
    missing, train/val/test lists are generated from ``all.txt`` using the
    ratios in ``cfg.dataset.set_splits``.

    Args:
        cfg: project configuration node. The fields read here are ``root``,
            ``overfit``, ``dataset`` (``types``, ``input_size``,
            ``set_splits``, ``scales``, ``pifu``, ``rotation_num``) and
            ``net.in_nml``.
        split (str, optional): 'train', 'val' or 'test'. Defaults to 'train'.
    """

    def __init__(self, cfg, split='train'):

        self.split = split
        self.root = cfg.root
        self.overfit = cfg.overfit

        self.opt = cfg.dataset
        self.datasets = self.opt.types
        self.input_size = self.opt.input_size
        self.set_splits = self.opt.set_splits
        self.scales = self.opt.scales
        self.pifu = self.opt.pifu

        # input data types and dimensions: cfg.net.in_nml is a list of
        # (name, channel_num) pairs; the two normal maps are always appended
        self.in_nml = [item[0] for item in cfg.net.in_nml]
        self.in_nml_dim = [item[1] for item in cfg.net.in_nml]
        self.in_total = self.in_nml + ['normal_F', 'normal_B']
        self.in_total_dim = self.in_nml_dim + [3, 3]

        if self.split != 'train':
            self.rotations = range(0, 360, 120)
        else:
            # `np.int` was removed in NumPy 1.24; the builtin `int` is the
            # exact replacement.
            self.rotations = np.arange(
                0, 360, 360 / self.opt.rotation_num).astype(int)

        self.datasets_dict = {}
        for dataset_id, dataset in enumerate(self.datasets):
            dataset_dir = osp.join(self.root, dataset, "smplx")
            self.datasets_dict[dataset] = {
                "subjects":
                np.loadtxt(osp.join(self.root, dataset, "all.txt"),
                           dtype=str,
                           ndmin=1),
                "path":
                dataset_dir,
                "scale":
                self.scales[dataset_id]
            }

        self.subject_list = self.get_subject_list(split)

        # PIL to tensor, values mapped from [0, 1] to [-1, 1]
        self.image_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # PIL to tensor, values kept in [0, 1]
        self.mask_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.0, ), (1.0, ))
        ])

    @staticmethod
    def _read_str_list(path):
        """Read a one-entry-per-line text file into a sorted list of str.

        ``ndmin=1`` keeps a single-line file from collapsing into a 0-d
        array, whose ``tolist()`` would be a bare string.
        """
        return sorted(np.loadtxt(path, dtype=str, ndmin=1).tolist())

    def get_subject_list(self, split):
        """Collect the subjects of ``split`` over all configured datasets.

        Args:
            split (str): 'train', 'val' or 'test'.

        Returns:
            list[str]: subject identifiers, minus those listed in
            ``<root>/bug.txt``.
        """

        subject_list = []

        for dataset in self.datasets:

            if self.pifu:
                txt = osp.join(self.root, dataset, f'{split}_pifu.txt')
            else:
                txt = osp.join(self.root, dataset, f'{split}.txt')

            if osp.exists(txt):
                print(f"load from {txt}")
                loaded = self._read_str_list(txt)

                if self.pifu:
                    miss_pifu = set(
                        self._read_str_list(
                            osp.join(self.root, dataset, "miss_pifu.txt")))
                    # Filter/prefix only what was just loaded so subjects of
                    # previously visited datasets are not prefixed twice.
                    loaded = [
                        "renderpeople/" + subject for subject in loaded
                        if subject not in miss_pifu
                    ]

                subject_list += loaded

            else:
                train_txt = osp.join(self.root, dataset, 'train.txt')
                val_txt = osp.join(self.root, dataset, 'val.txt')
                test_txt = osp.join(self.root, dataset, 'test.txt')

                print(
                    f"generate lists of [train, val, test] \n {train_txt} \n {val_txt} \n {test_txt} \n"
                )

                split_txt = osp.join(self.root, dataset, f'{split}.txt')

                # set_splits[0] / set_splits[1] are the train / val ratios;
                # whatever remains becomes the test set
                subjects = self.datasets_dict[dataset]['subjects']
                train_split = int(len(subjects) * self.set_splits[0])
                val_split = int(
                    len(subjects) * self.set_splits[1]) + train_split

                chunks = (
                    (train_txt, subjects[:train_split]),
                    (val_txt, subjects[train_split:val_split]),
                    (test_txt, subjects[val_split:]),
                )
                for out_txt, chunk in chunks:
                    with open(out_txt, "w") as f:
                        f.write("\n".join(dataset + "/" + item
                                          for item in chunk))

                subject_list += self._read_str_list(split_txt)

        bug_list = set(self._read_str_list(osp.join(self.root, 'bug.txt')))

        return [
            subject for subject in subject_list if subject not in bug_list
        ]

    def __len__(self):
        """Number of (subject, rotation) pairs."""
        return len(self.subject_list) * len(self.rotations)

    def __getitem__(self, index):
        """Load the tensors of one (subject, rotation) pair.

        Returns:
            dict: 'dataset', 'subject', 'rotation' plus one image tensor per
            entry of ``self.in_total``. Path entries are removed before
            returning.
        """

        # only pick the first data if overfitting
        if self.overfit:
            index = 0

        rid = index % len(self.rotations)
        mid = index // len(self.rotations)

        rotation = self.rotations[rid]

        # choose specific test sets
        subject = self.subject_list[mid]

        # subjects are stored as "<dataset>/<name>"; renders live in the
        # sibling "<dataset>_12views" folder
        dataset_name, subject_name = subject.split("/")[:2]
        subject_render = "/".join([dataset_name + "_12views", subject_name])

        # setup paths
        data_dict = {
            'dataset':
            dataset_name,
            'subject':
            subject,
            'rotation':
            rotation,
            'image_path':
            osp.join(self.root, subject_render, 'render',
                     f'{rotation:03d}.png')
        }

        # image/normal/depth loader
        for name, channel in zip(self.in_total, self.in_total_dim):

            # 'image' already has its path (the 'render' folder); every
            # other input is read from a folder named after the input
            if name != 'image':
                data_dict.update({
                    f'{name}_path':
                    osp.join(self.root, subject_render, name,
                             f'{rotation:03d}.png')
                })
            data_dict.update({
                name:
                self.imagepath2tensor(data_dict[f'{name}_path'],
                                      channel,
                                      inv='depth_B' in name)
            })

        # paths are only needed for loading, not by the consumer
        path_keys = [
            key for key in data_dict.keys() if '_path' in key or '_dir' in key
        ]
        for key in path_keys:
            del data_dict[key]

        return data_dict

    def imagepath2tensor(self, path, channel=3, inv=False):
        """Load an image as a masked tensor in [-1, 1].

        The alpha channel is used as a mask that zeroes the background.

        Args:
            path (str): image file to open.
            channel (int, optional): number of leading channels to keep.
                Defaults to 3.
            inv (bool, optional): flip the sign of the result. Defaults to
                False.

        Returns:
            torch.Tensor: float tensor with ``channel`` channels.
        """

        rgba = Image.open(path).convert('RGBA')
        mask = rgba.split()[-1]
        image = rgba.convert('RGB')
        image = self.image_to_tensor(image)
        mask = self.mask_to_tensor(mask)
        image = (image * mask)[:channel]

        # (0.5 - inv) * 2.0 is +1 for inv=False and -1 for inv=True
        return (image * (0.5 - inv) * 2.0).float()
lib/dataset/NormalModule.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: [email protected]
17
+
18
+ import numpy as np
19
+ from torch.utils.data import DataLoader
20
+ from .NormalDataset import NormalDataset
21
+
22
+ # pytorch lightning related libs
23
+ import pytorch_lightning as pl
24
+
25
+
26
class NormalModule(pl.LightningDataModule):
    """Lightning data module that serves ``NormalDataset`` splits.

    Args:
        cfg: project configuration node. The fields read here are
            ``overfit``, ``batch_size`` and ``num_threads``; the whole node
            is also passed on to ``NormalDataset``.
    """

    def __init__(self, cfg):
        super(NormalModule, self).__init__()
        self.cfg = cfg
        self.overfit = self.cfg.overfit

        # overfitting runs use a batch size of one
        if self.overfit:
            self.batch_size = 1
        else:
            self.batch_size = self.cfg.batch_size

        # filled by setup() with the sizes of the train/val datasets
        self.data_size = {}

    def prepare_data(self):
        """Nothing to download or preprocess."""

        pass

    @staticmethod
    def worker_init_fn(worker_id):
        """Give every DataLoader worker its own NumPy seed.

        The seed is derived from the first word of the current NumPy RNG
        state plus the worker id. It is wrapped to 32 bits because
        ``np.random.seed`` rejects values above 2**32 - 1.
        """
        base_seed = int(np.random.get_state()[1][0])
        np.random.seed((base_seed + worker_id) % (2**32))

    def setup(self, stage=None):
        """Create the datasets needed by ``stage``.

        Args:
            stage (str, optional): 'fit', 'test' or None. None creates the
                datasets of both stages.
        """

        if stage == 'fit' or stage is None:
            self.train_dataset = NormalDataset(cfg=self.cfg, split="train")
            self.val_dataset = NormalDataset(cfg=self.cfg, split="val")
            self.data_size = {
                'train': len(self.train_dataset),
                'val': len(self.val_dataset)
            }

        if stage == 'test' or stage is None:
            self.test_dataset = NormalDataset(cfg=self.cfg, split="test")

    def train_dataloader(self):
        """Return the training loader (not shuffled when overfitting)."""

        train_data_loader = DataLoader(self.train_dataset,
                                       batch_size=self.batch_size,
                                       shuffle=not self.overfit,
                                       num_workers=self.cfg.num_threads,
                                       pin_memory=True,
                                       worker_init_fn=self.worker_init_fn)

        return train_data_loader

    def val_dataloader(self):
        """Return the validation loader.

        When overfitting, validation runs on the training dataset.
        """

        if self.overfit:
            current_dataset = self.train_dataset
        else:
            current_dataset = self.val_dataset

        val_data_loader = DataLoader(current_dataset,
                                     batch_size=self.batch_size,
                                     shuffle=False,
                                     num_workers=self.cfg.num_threads,
                                     pin_memory=True)

        return val_data_loader

    def test_dataloader(self):
        """Return the test loader (batch size 1, no shuffling)."""

        test_data_loader = DataLoader(self.test_dataset,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=self.cfg.num_threads,
                                      pin_memory=True)

        return test_data_loader
lib/dataset/PIFuDataModule.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from torch.utils.data import DataLoader
3
+ from .PIFuDataset import PIFuDataset
4
+ import pytorch_lightning as pl
5
+
6
+
7
class PIFuDataModule(pl.LightningDataModule):
    """Lightning data module that serves ``PIFuDataset`` splits.

    Args:
        cfg: project configuration node. The fields read here are
            ``overfit``, ``batch_size`` and ``num_threads``; the whole node
            is also passed on to ``PIFuDataset``.
    """

    def __init__(self, cfg):
        super(PIFuDataModule, self).__init__()
        self.cfg = cfg
        self.overfit = self.cfg.overfit

        # overfitting runs use a batch size of one
        if self.overfit:
            self.batch_size = 1
        else:
            self.batch_size = self.cfg.batch_size

        # filled by setup() with the sizes of the train/val datasets
        self.data_size = {}

    def prepare_data(self):
        """Nothing to download or preprocess."""

        pass

    @staticmethod
    def worker_init_fn(worker_id):
        """Give every DataLoader worker its own NumPy seed.

        The seed is derived from the first word of the current NumPy RNG
        state plus the worker id. It is wrapped to 32 bits because
        ``np.random.seed`` rejects values above 2**32 - 1.
        """
        base_seed = int(np.random.get_state()[1][0])
        np.random.seed((base_seed + worker_id) % (2**32))

    def setup(self, stage=None):
        """Create the datasets needed by ``stage``.

        Args:
            stage (str, optional): 'fit', 'test' or None. None creates the
                datasets of both stages, matching ``NormalModule.setup``.
        """

        if stage == 'fit' or stage is None:
            self.train_dataset = PIFuDataset(cfg=self.cfg, split="train")
            self.val_dataset = PIFuDataset(cfg=self.cfg, split="val")
            self.data_size = {'train': len(self.train_dataset),
                              'val': len(self.val_dataset)}

        if stage == 'test' or stage is None:
            self.test_dataset = PIFuDataset(cfg=self.cfg, split="test")

    def train_dataloader(self):
        """Return the shuffled training loader."""

        train_data_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size, shuffle=True,
            num_workers=self.cfg.num_threads, pin_memory=True,
            worker_init_fn=self.worker_init_fn)

        return train_data_loader

    def val_dataloader(self):
        """Return the validation loader (batch size 1).

        When overfitting, validation runs on the training dataset.
        """

        if self.overfit:
            current_dataset = self.train_dataset
        else:
            current_dataset = self.val_dataset

        val_data_loader = DataLoader(
            current_dataset,
            batch_size=1, shuffle=False,
            num_workers=self.cfg.num_threads, pin_memory=True,
            worker_init_fn=self.worker_init_fn)

        return val_data_loader

    def test_dataloader(self):
        """Return the test loader (batch size 1, no shuffling)."""

        test_data_loader = DataLoader(
            self.test_dataset,
            batch_size=1, shuffle=False,
            num_workers=self.cfg.num_threads, pin_memory=True)

        return test_data_loader
lib/dataset/PIFuDataset.py ADDED
@@ -0,0 +1,662 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib.renderer.mesh import load_fit_body
2
+ from lib.dataset.hoppeMesh import HoppeMesh
3
+ from lib.dataset.body_model import TetraSMPLModel
4
+ from lib.common.render import Render
5
+ from lib.dataset.mesh_util import SMPLX, projection, cal_sdf_batch, get_visibility
6
+ from lib.pare.pare.utils.geometry import rotation_matrix_to_angle_axis
7
+ from termcolor import colored
8
+ import os.path as osp
9
+ import numpy as np
10
+ from PIL import Image
11
+ import random
12
+ import trimesh
13
+ import torch
14
+ import vedo
15
+ from kaolin.ops.mesh import check_sign
16
+ import torchvision.transforms as transforms
17
+ from ipdb import set_trace
18
+
19
+
20
class PIFuDataset():
    """Dataset of rendered views, scans and fitted body models.

    Each item corresponds to one (subject, rotation) pair and bundles the
    input images, the camera calibration, query points sampled around the
    scan with their occupancy/SDF labels, and the (optionally perturbed)
    fitted body mesh.

    Args:
        cfg: project configuration node. The fields read here are ``root``,
            ``batch_size``, ``overfit``, ``num_threads``, ``sdf``,
            ``sdf_clip``, ``gpus``, ``dataset`` and ``net``.
        split (str, optional): 'train', 'val' or 'test'. Defaults to 'train'.
        vis (bool, optional): keep/compute the extra entries needed by
            ``visualize_sampling3D``. Defaults to False.
    """

    def __init__(self, cfg, split='train', vis=False):

        self.split = split
        self.root = cfg.root
        self.bsize = cfg.batch_size
        self.overfit = cfg.overfit

        # for debug, only used in visualize_sampling3D
        self.vis = vis

        self.opt = cfg.dataset
        self.datasets = self.opt.types
        self.input_size = self.opt.input_size
        self.scales = self.opt.scales
        self.workers = cfg.num_threads
        self.prior_type = cfg.net.prior_type

        self.noise_type = self.opt.noise_type
        self.noise_scale = self.opt.noise_scale

        # joints whose pose parameters receive noise; each joint owns three
        # consecutive entries of the flattened pose vector. The smplx
        # indices are shifted by one joint.
        # NOTE(review): presumably because the smplx body pose excludes the
        # root joint; confirm against the body model.
        noise_joints = [4, 5, 7, 8, 13, 14, 16, 17, 18, 19, 20, 21]

        self.noise_smpl_idx = []
        self.noise_smplx_idx = []

        for idx in noise_joints:
            self.noise_smpl_idx.extend(idx * 3 + k for k in range(3))
            self.noise_smplx_idx.extend((idx - 1) * 3 + k for k in range(3))

        self.use_sdf = cfg.sdf
        self.sdf_clip = cfg.sdf_clip

        # [(feat_name, channel_num),...]
        self.in_geo = [item[0] for item in cfg.net.in_geo]
        self.in_nml = [item[0] for item in cfg.net.in_nml]

        self.in_geo_dim = [item[1] for item in cfg.net.in_geo]
        self.in_nml_dim = [item[1] for item in cfg.net.in_nml]

        self.in_total = self.in_geo + self.in_nml
        self.in_total_dim = self.in_geo_dim + self.in_nml_dim

        if self.split == 'train':
            self.rotations = np.arange(
                0, 360, 360 / self.opt.rotation_num).astype(np.int32)
        else:
            self.rotations = range(0, 360, 120)

        self.datasets_dict = {}

        for dataset_id, dataset in enumerate(self.datasets):

            # reset all three per dataset so an unknown dataset gets None
            # rather than a NameError or the previous dataset's folder
            mesh_dir = None
            smplx_dir = None
            smpl_dir = None

            dataset_dir = osp.join(self.root, dataset)

            if dataset in ['thuman2']:
                mesh_dir = osp.join(dataset_dir, "scans")
                smplx_dir = osp.join(dataset_dir, "fits")
                smpl_dir = osp.join(dataset_dir, "smpl")

            self.datasets_dict[dataset] = {
                "subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"),
                                       dtype=str,
                                       ndmin=1),
                "smplx_dir": smplx_dir,
                "smpl_dir": smpl_dir,
                "mesh_dir": mesh_dir,
                "scale": self.scales[dataset_id]
            }

        self.subject_list = self.get_subject_list(split)
        self.smplx = SMPLX()

        # PIL to tensor, values mapped from [0, 1] to [-1, 1]
        self.image_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # PIL to tensor, values kept in [0, 1]
        self.mask_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.0, ), (1.0, ))
        ])

        self.device = torch.device(f"cuda:{cfg.gpus[0]}")
        self.render = Render(size=512, device=self.device)

    def render_normal(self, verts, faces):
        """Load a mesh into the renderer and return its rendered images."""

        # render optimized mesh (normal, T_normal, image [-1,1])
        self.render.load_meshes(verts, faces)
        return self.render.get_rgb_image()

    def get_subject_list(self, split):
        """Collect the subjects of ``split`` over all configured datasets.

        When a split file is missing, ``all.txt`` is split into
        train (first 500 entries), test (next 5) and val (the rest) files.
        For non-test splits the list is padded with its own leading entries
        up to a multiple of the batch size, then shuffled.

        Args:
            split (str): 'train', 'val' or 'test'.

        Returns:
            list[str]: subject identifiers of the form "<dataset>/<name>".
        """

        subject_list = []

        for dataset in self.datasets:

            dataset_dir = osp.join(self.root, dataset)
            split_txt = osp.join(dataset_dir, f'{split}.txt')

            if not osp.exists(split_txt):
                full_txt = osp.join(dataset_dir, 'all.txt')
                print(f"split {full_txt} into train/val/test")

                full_lst = np.loadtxt(full_txt, dtype=str, ndmin=1)
                full_lst = [dataset + "/" + item for item in full_lst]
                [train_lst, test_lst, val_lst] = np.split(
                    full_lst, [500, 500 + 5, ])

                # Build each output path explicitly: str.replace("all", ...)
                # would also rewrite any directory name containing "all".
                for name, lst in (("train", train_lst),
                                  ("test", test_lst),
                                  ("val", val_lst)):
                    np.savetxt(osp.join(dataset_dir, f"{name}.txt"),
                               lst,
                               fmt="%s")

            print(f"load from {split_txt}")
            # ndmin=1 keeps a single-line file from becoming a bare string
            subject_list += np.loadtxt(split_txt, dtype=str,
                                       ndmin=1).tolist()

        if self.split != 'test':
            # pad to a multiple of the batch size; nothing is added when the
            # length already is one
            pad_num = (-len(subject_list)) % self.bsize
            subject_list += subject_list[:pad_num]
            print(colored(f"total: {len(subject_list)}", "yellow"))
            random.shuffle(subject_list)

        # subject_list = ["thuman2/0008"]
        return subject_list

    def __len__(self):
        """Number of (subject, rotation) pairs."""
        return len(self.subject_list) * len(self.rotations)

    def __getitem__(self, index):
        """Load everything belonging to one (subject, rotation) pair.

        Returns:
            dict: metadata, calibration, image tensors, sampled points with
            labels and body-model entries. Path entries are removed, and
            mesh entries are dropped unless they are needed for testing or
            visualization.
        """

        # only pick the first data if overfitting
        if self.overfit:
            index = 0

        rid = index % len(self.rotations)
        mid = index // len(self.rotations)

        rotation = self.rotations[rid]
        dataset, subject = self.subject_list[mid].split("/")[:2]
        render_folder = "/".join([dataset +
                                  f"_{self.opt.rotation_num}views", subject])
        dataset_info = self.datasets_dict[dataset]

        # setup paths
        data_dict = {
            'dataset': dataset,
            'subject': subject,
            'rotation': rotation,
            'scale': dataset_info["scale"],
            'mesh_path': osp.join(dataset_info["mesh_dir"], f"{subject}/{subject}.obj"),
            'smplx_path': osp.join(dataset_info["smplx_dir"], f"{subject}/smplx_param.pkl"),
            'smpl_path': osp.join(dataset_info["smpl_dir"], f"{subject}.pkl"),
            'calib_path': osp.join(self.root, render_folder, 'calib', f'{rotation:03d}.txt'),
            'vis_path': osp.join(self.root, render_folder, 'vis', f'{rotation:03d}.pt'),
            'image_path': osp.join(self.root, render_folder, 'render', f'{rotation:03d}.png')
        }

        # load training data
        data_dict.update(self.load_calib(data_dict))

        # image/normal/depth loader
        for name, channel in zip(self.in_total, self.in_total_dim):

            # inputs without a predefined path are read from a folder named
            # after the input
            if f'{name}_path' not in data_dict.keys():
                data_dict.update({
                    f'{name}_path': osp.join(self.root, render_folder, name, f'{rotation:03d}.png')
                })

            # tensor update
            data_dict.update({
                name: self.imagepath2tensor(
                    data_dict[f'{name}_path'], channel, inv=False)
            })

        data_dict.update(self.load_mesh(data_dict))
        data_dict.update(self.get_sampling_geo(
            data_dict, is_valid=self.split == "val", is_sdf=self.use_sdf))
        data_dict.update(self.load_smpl(data_dict, self.vis))

        if self.prior_type == 'pamir':
            data_dict.update(self.load_smpl_voxel(data_dict))

        if (self.split != 'test') and (not self.vis):

            del data_dict['verts']
            del data_dict['faces']

        if not self.vis:
            del data_dict['mesh']

        # paths are only needed for loading, not by the consumer
        path_keys = [
            key for key in data_dict.keys() if '_path' in key or '_dir' in key
        ]
        for key in path_keys:
            del data_dict[key]

        return data_dict

    def imagepath2tensor(self, path, channel=3, inv=False):
        """Load an image as a masked tensor in [-1, 1].

        The alpha channel is used as a mask that zeroes the background.

        Args:
            path (str): image file to open.
            channel (int, optional): number of leading channels to keep.
                Defaults to 3.
            inv (bool, optional): flip the sign of the result. Defaults to
                False.

        Returns:
            torch.Tensor: float tensor with ``channel`` channels.
        """

        rgba = Image.open(path).convert('RGBA')
        mask = rgba.split()[-1]
        image = rgba.convert('RGB')
        image = self.image_to_tensor(image)
        mask = self.mask_to_tensor(mask)
        image = (image * mask)[:channel]

        # (0.5 - inv) * 2.0 is +1 for inv=False and -1 for inv=True
        return (image * (0.5 - inv) * 2.0).float()

    def load_calib(self, data_dict):
        """Read the calibration file and return ``intrinsic @ extrinsic``.

        The file holds the 4x4 extrinsic matrix in rows 0-3 and the 4x4
        intrinsic matrix in rows 4-7.

        Returns:
            dict: {'calib': float tensor [4, 4]}.
        """
        calib_data = np.loadtxt(data_dict['calib_path'], dtype=float)
        extrinsic = calib_data[:4, :4]
        intrinsic = calib_data[4:8, :4]
        calib_mat = np.matmul(intrinsic, extrinsic)
        calib_mat = torch.from_numpy(calib_mat).float()
        return {'calib': calib_mat}

    def load_mesh(self, data_dict):
        """Load the scan, scale its vertices and wrap it in a ``HoppeMesh``.

        Returns:
            dict: 'mesh' (HoppeMesh), 'verts' (float tensor) and 'faces'
            (long tensor).
        """
        mesh_path = data_dict['mesh_path']
        scale = data_dict['scale']

        mesh_ori = trimesh.load(mesh_path,
                                skip_materials=True,
                                process=False,
                                maintain_order=True)
        verts = mesh_ori.vertices * scale
        faces = mesh_ori.faces

        vert_normals = np.array(mesh_ori.vertex_normals)
        face_normals = np.array(mesh_ori.face_normals)

        mesh = HoppeMesh(verts, faces, vert_normals, face_normals)

        return {
            'mesh': mesh,
            'verts': torch.as_tensor(mesh.verts).float(),
            'faces': torch.as_tensor(mesh.faces).long()
        }

    @staticmethod
    def _noise_seed(subject, rotation):
        """Return a seed for (subject, rotation) that is stable across runs.

        The builtin ``hash`` of a str is salted per process, so it cannot be
        used to get the same noise for the same sample in every process.
        """
        import zlib

        key = f"{subject}_{rotation}".encode("utf-8")
        return zlib.crc32(key) % (10**8)

    def add_noise(self,
                  beta_num,
                  smpl_pose,
                  smpl_betas,
                  noise_type,
                  noise_scale,
                  type,
                  hashcode):
        """Perturb body pose/shape parameters with seeded uniform noise.

        The noise is drawn from a private generator seeded with
        ``hashcode``, so the global NumPy RNG is left untouched. The input
        arrays are not modified.

        Args:
            beta_num (int): number of shape coefficients to perturb.
            smpl_pose (np.ndarray): flattened pose vector.
            smpl_betas (np.ndarray): shape coefficients.
            noise_type (list[str]): may contain 'beta' and/or 'pose'.
            noise_scale (list[float]): scale per entry of ``noise_type``.
            type (str): 'smplx' selects the smplx pose indices and makes the
                results batched tensors; anything else selects the smpl
                indices and returns arrays.
            hashcode (int): seed of the noise generator.

        Returns:
            tuple: (pose, betas), perturbed where the matching scale is > 0.
        """

        rng = np.random.RandomState(hashcode)

        if type == 'smplx':
            noise_idx = self.noise_smplx_idx
        else:
            noise_idx = self.noise_smpl_idx

        if 'beta' in noise_type and noise_scale[noise_type.index("beta")] > 0.0:
            beta_scale = noise_scale[noise_type.index("beta")]
            smpl_betas = smpl_betas + (rng.rand(beta_num) -
                                       0.5) * 2.0 * beta_scale
            smpl_betas = smpl_betas.astype(np.float32)

        if 'pose' in noise_type and noise_scale[noise_type.index("pose")] > 0.0:
            pose_scale = noise_scale[noise_type.index("pose")]
            smpl_pose = smpl_pose.copy()
            smpl_pose[noise_idx] += (
                rng.rand(len(noise_idx)) -
                0.5) * 2.0 * np.pi * pose_scale
            smpl_pose = smpl_pose.astype(np.float32)

        if type == 'smplx':
            return torch.as_tensor(smpl_pose[None, ...]), torch.as_tensor(smpl_betas[None, ...])
        else:
            return smpl_pose, smpl_betas

    def compute_smpl_verts(self, data_dict, noise_type=None, noise_scale=None):
        """Build the (optionally perturbed) fitted SMPL-X body.

        Returns:
            tuple: (vertices of the fitted body, dict with 'type', 'gender',
            'body_pose' and 'betas').
        """

        dataset = data_dict['dataset']
        smplx_dict = {}

        # NOTE(review): allow_pickle=True unpickles the file; only load
        # parameter files from trusted sources.
        smplx_param = np.load(data_dict['smplx_path'], allow_pickle=True)
        smplx_pose = smplx_param["body_pose"]  # [1,63]
        smplx_betas = smplx_param["betas"]  # [1,10]
        smplx_pose, smplx_betas = self.add_noise(
            smplx_betas.shape[1],
            smplx_pose[0],
            smplx_betas[0],
            noise_type,
            noise_scale,
            type='smplx',
            hashcode=self._noise_seed(data_dict['subject'],
                                      data_dict['rotation']))

        smplx_out, _ = load_fit_body(fitted_path=data_dict['smplx_path'],
                                     scale=self.datasets_dict[dataset]['scale'],
                                     smpl_type='smplx',
                                     smpl_gender='male',
                                     noise_dict=dict(betas=smplx_betas, body_pose=smplx_pose))

        smplx_dict.update({"type": "smplx",
                           "gender": 'male',
                           "body_pose": torch.as_tensor(smplx_pose),
                           "betas": torch.as_tensor(smplx_betas)})

        return smplx_out.vertices, smplx_dict

    def compute_voxel_verts(self,
                            data_dict,
                            noise_type=None,
                            noise_scale=None):
        """Build the (optionally perturbed) tetrahedral SMPL body.

        Vertices and faces are zero-padded to the fixed sizes 8000 and
        25100.

        Returns:
            tuple: (verts, faces, pad_v_num, pad_f_num), the pad counts
            being the numbers of padded rows.
        """

        # NOTE(review): allow_pickle=True unpickles the files; only load
        # parameter files from trusted sources.
        smpl_param = np.load(data_dict['smpl_path'], allow_pickle=True)
        smplx_param = np.load(data_dict['smplx_path'], allow_pickle=True)

        smpl_pose = rotation_matrix_to_angle_axis(
            torch.as_tensor(smpl_param['full_pose'][0])).numpy()
        smpl_betas = smpl_param["betas"]

        smpl_path = osp.join(self.smplx.model_dir, "smpl/SMPL_MALE.pkl")
        tetra_path = osp.join(self.smplx.tedra_dir,
                              "tetra_male_adult_smpl.npz")

        smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')

        smpl_pose, smpl_betas = self.add_noise(
            smpl_model.beta_shape[0],
            smpl_pose.flatten(),
            smpl_betas[0],
            noise_type,
            noise_scale,
            type='smpl',
            hashcode=self._noise_seed(data_dict['subject'],
                                      data_dict['rotation']))

        smpl_model.set_params(pose=smpl_pose.reshape(-1, 3),
                              beta=smpl_betas,
                              trans=smpl_param["transl"])

        # apply the scale/translation stored with the smplx fit, then the
        # dataset scale
        verts = (np.concatenate([smpl_model.verts, smpl_model.verts_added],
                                axis=0) * smplx_param["scale"] + smplx_param["translation"]
                 ) * self.datasets_dict[data_dict['dataset']]['scale']
        # the tetrahedron file is 1-indexed
        faces = np.loadtxt(osp.join(self.smplx.tedra_dir, "tetrahedrons_male_adult.txt"),
                           dtype=np.int32) - 1

        pad_v_num = int(8000 - verts.shape[0])
        pad_f_num = int(25100 - faces.shape[0])

        verts = np.pad(verts, ((0, pad_v_num), (0, 0)),
                       mode='constant',
                       constant_values=0.0).astype(np.float32)
        faces = np.pad(faces, ((0, pad_f_num), (0, 0)),
                       mode='constant',
                       constant_values=0.0).astype(np.int32)

        return verts, faces, pad_v_num, pad_f_num

    def load_smpl(self, data_dict, vis=False):
        """Load the fitted body and the per-point inside/outside signs.

        Args:
            data_dict (dict): needs 'calib', 'samples_geo', 'vis_path' and
                the entries read by ``compute_smpl_verts``.
            vis (bool, optional): additionally render the body normals and
                compute the per-point 'smpl_feat'. Defaults to False.

        Returns:
            dict: 'smpl_verts', 'smpl_faces', 'smpl_vis', 'smpl_cmap',
            'pts_signs', the body parameters, and for ``vis`` also
            'T_normal_F', 'T_normal_B' and 'smpl_feat'.
        """

        smplx_verts, smplx_dict = self.compute_smpl_verts(
            data_dict, self.noise_type,
            self.noise_scale)  # compute using smpl model

        smplx_verts = projection(smplx_verts, data_dict['calib']).float()
        smplx_faces = torch.as_tensor(self.smplx.faces).long()
        smplx_vis = torch.load(data_dict['vis_path']).float()
        smplx_cmap = torch.as_tensor(
            np.load(self.smplx.cmap_vert_path)).float()

        # get smpl_signs
        query_points = projection(data_dict['samples_geo'],
                                  data_dict['calib']).float()

        # map the boolean check_sign result to -1 / +1
        pts_signs = 2.0 * (check_sign(smplx_verts.unsqueeze(0),
                                      smplx_faces,
                                      query_points.unsqueeze(0)).float() - 0.5).squeeze(0)

        return_dict = {
            'smpl_verts': smplx_verts,
            'smpl_faces': smplx_faces,
            'smpl_vis': smplx_vis,
            'smpl_cmap': smplx_cmap,
            'pts_signs': pts_signs
        }
        if smplx_dict is not None:
            return_dict.update(smplx_dict)

        if vis:

            (xy, z) = torch.as_tensor(smplx_verts).to(
                self.device).split([2, 1], dim=1)
            smplx_vis = get_visibility(xy, z, torch.as_tensor(
                smplx_faces).to(self.device).long())

            # the y axis is flipped for rendering
            T_normal_F, T_normal_B = self.render_normal(
                (smplx_verts*torch.tensor([1.0, -1.0, 1.0])).to(self.device),
                smplx_faces.to(self.device))

            return_dict.update({"T_normal_F": T_normal_F.squeeze(0),
                                "T_normal_B": T_normal_B.squeeze(0)})
            query_points = projection(data_dict['samples_geo'],
                                      data_dict['calib']).float()

            smplx_sdf, smplx_norm, smplx_cmap, smplx_vis = cal_sdf_batch(
                smplx_verts.unsqueeze(0).to(self.device),
                smplx_faces.unsqueeze(0).to(self.device),
                smplx_cmap.unsqueeze(0).to(self.device),
                smplx_vis.unsqueeze(0).to(self.device),
                query_points.unsqueeze(0).contiguous().to(self.device))

            # column layout: sdf-1 cmap-3 norm-3 vis-1
            return_dict.update({
                'smpl_feat':
                torch.cat(
                    (smplx_sdf[0].detach().cpu(),
                     smplx_cmap[0].detach().cpu(),
                     smplx_norm[0].detach().cpu(),
                     smplx_vis[0].detach().cpu()),
                    dim=1)
            })

        return return_dict

    def load_smpl_voxel(self, data_dict):
        """Load the tetrahedral SMPL body, projected and scaled by 0.5.

        Returns:
            dict: 'voxel_verts', 'voxel_faces', 'pad_v_num', 'pad_f_num'.
        """

        smpl_verts, smpl_faces, pad_v_num, pad_f_num = self.compute_voxel_verts(
            data_dict, self.noise_type,
            self.noise_scale)  # compute using smpl model
        smpl_verts = projection(smpl_verts, data_dict['calib'])

        smpl_verts *= 0.5

        return {
            'voxel_verts': smpl_verts,
            'voxel_faces': smpl_faces,
            'pad_v_num': pad_v_num,
            'pad_f_num': pad_f_num
        }

    def get_sampling_geo(self, data_dict, is_valid=False, is_sdf=False):
        """Sample query points around the scan and label them.

        Points are drawn near the surface (vertices displaced along their
        normals), uniformly in the [-1, 1] cube (mapped back with the
        inverse calibration) and, when ``opt.zray_type`` is set and
        ``is_valid`` is False, along the z direction of the near-surface
        points.

        Args:
            data_dict (dict): needs 'mesh' and 'calib'.
            is_valid (bool, optional): skip the z-ray samples. Defaults to
                False.
            is_sdf (bool, optional): label with clipped, normalized SDF
                values instead of binary occupancy. Defaults to False.

        Returns:
            dict: 'samples_geo' and 'labels_geo' float tensors.
        """

        mesh = data_dict['mesh']
        calib = data_dict['calib']
        num_geo = self.opt.num_sample_geo

        # Samples are around the true surface with an offset
        n_samples_surface = 4 * num_geo
        vert_ids = np.arange(mesh.verts.shape[0])
        # uniform probability over the vertices
        thickness_sample_ratio = np.ones_like(vert_ids).astype(np.float32)

        thickness_sample_ratio /= thickness_sample_ratio.sum()

        samples_surface_ids = np.random.choice(vert_ids,
                                               n_samples_surface,
                                               replace=True,
                                               p=thickness_sample_ratio)

        # exact surface points, used as zero-level samples in SDF mode
        samples_normal_ids = np.random.choice(vert_ids,
                                              num_geo // 2,
                                              replace=False,
                                              p=thickness_sample_ratio)

        surf_samples = mesh.verts[samples_normal_ids, :]

        samples_surface = mesh.verts[samples_surface_ids, :]

        # Sampling offsets are random noise with constant scale (15cm - 20cm)
        offset = np.random.normal(scale=self.opt.sigma_geo,
                                  size=(n_samples_surface, 1))
        samples_surface += mesh.vert_normals[samples_surface_ids, :] * offset

        # Uniform samples in [-1, 1]
        calib_inv = np.linalg.inv(calib)
        n_samples_space = num_geo // 4
        samples_space_img = 2.0 * np.random.rand(n_samples_space, 3) - 1.0
        samples_space = projection(samples_space_img, calib_inv)

        # z-ray direction samples
        if self.opt.zray_type and not is_valid:
            n_samples_rayz = self.opt.ray_sample_num
            samples_surface_cube = projection(samples_surface, calib)
            samples_surface_cube_repeat = np.repeat(samples_surface_cube,
                                                    n_samples_rayz,
                                                    axis=0)

            thickness_repeat = np.repeat(0.5 *
                                         np.ones_like(samples_surface_ids),
                                         n_samples_rayz,
                                         axis=0)

            noise_repeat = np.random.normal(scale=0.40,
                                            size=(n_samples_surface *
                                                  n_samples_rayz, ))
            # jitter only the last (z) coordinate
            samples_surface_cube_repeat[:,
                                        -1] += thickness_repeat * noise_repeat
            samples_surface_rayz = projection(samples_surface_cube_repeat,
                                              calib_inv)

            samples = np.concatenate(
                [samples_surface, samples_space, samples_surface_rayz], 0)
        else:
            samples = np.concatenate([samples_surface, samples_space], 0)

        np.random.shuffle(samples)

        # labels: in->1.0; out->0.0.
        if is_sdf:
            sdfs = mesh.get_sdf(samples)
            inside_samples = samples[sdfs < 0]
            outside_samples = samples[sdfs >= 0]

            inside_sdfs = sdfs[sdfs < 0]
            outside_sdfs = sdfs[sdfs >= 0]
        else:
            inside = mesh.contains(samples)
            inside_samples = samples[inside >= 0.5]
            outside_samples = samples[inside < 0.5]

        nin = inside_samples.shape[0]

        # keep at most half inside points; the outside points fill the rest
        if nin > num_geo // 2:
            inside_samples = inside_samples[:num_geo // 2]
            outside_samples = outside_samples[:num_geo // 2]
            if is_sdf:
                inside_sdfs = inside_sdfs[:num_geo // 2]
                outside_sdfs = outside_sdfs[:num_geo // 2]
        else:
            outside_samples = outside_samples[:(num_geo - nin)]
            if is_sdf:
                outside_sdfs = outside_sdfs[:(num_geo - nin)]

        if is_sdf:
            samples = np.concatenate(
                [inside_samples, outside_samples, surf_samples], 0)

            labels = np.concatenate([
                inside_sdfs, outside_sdfs, 0.0 * np.ones(surf_samples.shape[0])
            ])

            # convert sdf from [-14, 130] to [0, 1]
            # outside: 0, inside: 1
            # Note: Marching cubes is defined on occupancy space (inside=1.0, outside=0.0)

            labels = -labels.clip(min=-self.sdf_clip, max=self.sdf_clip)
            labels += self.sdf_clip
            labels /= (self.sdf_clip * 2)

        else:
            samples = np.concatenate([inside_samples, outside_samples])
            labels = np.concatenate([
                np.ones(inside_samples.shape[0]),
                np.zeros(outside_samples.shape[0])
            ])

        samples = torch.from_numpy(samples).float()
        labels = torch.from_numpy(labels).float()

        return {'samples_geo': samples, 'labels_geo': labels}

    def visualize_sampling3D(self, data_dict, mode='vis'):
        """Show the scan, body meshes, images and colored query points.

        Args:
            data_dict (dict): an item produced with ``vis=True``.
            mode (str, optional): per-point quantity used as color, one of
                'vis', 'sdf', 'normal', 'cmap', 'occ'. Defaults to 'vis'.
        """

        # create plot
        vp = vedo.Plotter(title="", size=(1500, 1500), axes=0, bg='white')
        vis_list = []

        assert mode in ['vis', 'sdf', 'normal', 'cmap', 'occ']

        # sdf-1 cmap-3 norm-3 vis-1
        if mode == 'vis':
            labels = data_dict['smpl_feat'][:, [-1]]  # visibility
            colors = np.concatenate([labels, labels, labels], axis=1)
        elif mode == 'occ':
            labels = data_dict['labels_geo'][..., None]  # occupancy
            colors = np.concatenate([labels, labels, labels], axis=1)
        elif mode == 'sdf':
            labels = data_dict['smpl_feat'][:, [0]]  # sdf
            # rescale to [0, 1] for display
            labels -= labels.min()
            labels /= labels.max()
            colors = np.concatenate([labels, labels, labels], axis=1)
        elif mode == 'normal':
            labels = data_dict['smpl_feat'][:, -4:-1]  # normal
            colors = (labels + 1.0) * 0.5
        elif mode == 'cmap':
            labels = data_dict['smpl_feat'][:, -7:-4]  # colormap
            colors = np.array(labels)

        points = projection(data_dict['samples_geo'], data_dict['calib'])
        verts = projection(data_dict['verts'], data_dict['calib'])
        points[:, 1] *= -1
        verts[:, 1] *= -1

        # create a mesh
        mesh = trimesh.Trimesh(verts, data_dict['faces'], process=True)
        mesh.visual.vertex_colors = [128.0, 128.0, 128.0, 255.0]
        vis_list.append(mesh)

        if 'voxel_verts' in data_dict.keys():
            print(colored("voxel verts", "green"))
            # undo the 0.5 scaling applied in load_smpl_voxel
            voxel_verts = data_dict['voxel_verts'] * 2.0
            voxel_faces = data_dict['voxel_faces']
            voxel_verts[:, 1] *= -1
            voxel = trimesh.Trimesh(
                voxel_verts, voxel_faces[:, [0, 2, 1]], process=False, maintain_order=True)
            voxel.visual.vertex_colors = [0.0, 128.0, 0.0, 255.0]
            vis_list.append(voxel)

        if 'smpl_verts' in data_dict.keys():
            print(colored("smpl verts", "green"))
            smplx_verts = data_dict['smpl_verts']
            smplx_faces = data_dict['smpl_faces']
            smplx_verts[:, 1] *= -1
            smplx = trimesh.Trimesh(
                smplx_verts, smplx_faces[:, [0, 2, 1]], process=False, maintain_order=True)
            smplx.visual.vertex_colors = [128.0, 128.0, 0.0, 255.0]
            vis_list.append(smplx)

        # create a picture
        img_pos = [1.0, 0.0, -1.0]
        for img_id, img_key in enumerate(['normal_F', 'image', 'T_normal_B']):
            # [-1, 1] CHW tensor -> [0, 255] HWC array
            image_arr = (data_dict[img_key].detach().cpu().permute(
                1, 2, 0).numpy() + 1.0) * 0.5 * 255.0
            image_dim = image_arr.shape[0]
            image = vedo.Picture(image_arr).scale(
                2.0 / image_dim).pos(-1.0, -1.0, img_pos[img_id])
            vis_list.append(image)

        # create a pointcloud
        pc = vedo.Points(points, r=15, c=np.float32(colors))
        vis_list.append(pc)

        vp.show(*vis_list, bg="white", axes=1.0, interactive=True)
lib/dataset/TestDataset.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: ps-license@tuebingen.mpg.de
17
+
18
+ import smplx
19
+ from lib.pymaf.utils.geometry import rotation_matrix_to_angle_axis, batch_rodrigues
20
+ from lib.pymaf.utils.imutils import process_image
21
+ from lib.pymaf.core import path_config
22
+ from lib.pymaf.models import pymaf_net
23
+ from lib.common.config import cfg
24
+ from lib.common.render import Render
25
+ from lib.dataset.body_model import TetraSMPLModel
26
+ from lib.dataset.mesh_util import get_visibility, SMPLX
27
+ import os.path as osp
28
+ import os
29
+ import torch
30
+ import glob
31
+ import numpy as np
32
+ import random
33
+ import human_det
34
+ from termcolor import colored
35
+ from PIL import ImageFile
36
+
37
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
38
+
39
+
40
class TestDataset():
    """Dataset over the images found in a directory.

    Each item bundles the processed input image with the body-model
    parameters (betas, pose, vertices, camera scale/translation) predicted
    by the human-pose-and-shape (HPS) network for that image.
    """

    def __init__(self, cfg, device):
        """Collect the input images and load the body model and HPS network.

        Args:
            cfg: mapping with the keys 'image_dir', 'has_det' and 'hps_type';
                'seg_dir' is optional and defaults to None.
            device: torch device on which models and returned tensors live.
        """

        random.seed(1993)

        self.image_dir = cfg['image_dir']
        # 'seg_dir' is optional: callers that have no segmentation files
        # (e.g. the demo at the bottom of this file) may simply omit it.
        self.seg_dir = cfg.get('seg_dir')
        self.has_det = cfg['has_det']
        self.hps_type = cfg['hps_type']
        self.smpl_type = 'smpl' if cfg['hps_type'] != 'pixie' else 'smplx'
        self.smpl_gender = 'neutral'

        self.device = device

        self.det = human_det.Detection() if self.has_det else None

        self.subject_list = self._list_images(self.image_dir)

        # smpl related
        self.smpl_data = SMPLX()

        self.get_smpl_model = lambda smpl_type, smpl_gender: smplx.create(
            model_path=self.smpl_data.model_dir,
            gender=smpl_gender,
            model_type=smpl_type,
            ext='npz')

        # Load SMPL model
        self.smpl_model = self.get_smpl_model(
            self.smpl_type, self.smpl_gender).to(self.device)
        self.faces = self.smpl_model.faces

        # NOTE(review): the PyMAF network is built regardless of `hps_type`;
        # the non-pymaf branches of `_parse_hps_output` expect differently
        # shaped predictions - confirm they are reachable in this build.
        self.hps = pymaf_net(path_config.SMPL_MEAN_PARAMS,
                             pretrained=True).to(self.device)
        self.hps.load_state_dict(
            torch.load(path_config.CHECKPOINT_FILE)['model'], strict=True)
        self.hps.eval()

        print(colored(f"Using {self.hps_type} as HPS Estimator\n", "green"))

        self.render = Render(size=512, device=device)

    @staticmethod
    def _list_images(image_dir):
        """Return the sorted paths in `image_dir` that have an image extension.

        The extension check is case-insensitive, so 'a.PNG' is accepted just
        like 'a.png'.
        """
        img_fmts = ('jpg', 'png', 'jpeg', 'bmp')
        return sorted(
            path for path in glob.glob(f"{image_dir}/*")
            if osp.splitext(path)[1][1:].lower() in img_fmts)

    def __len__(self):
        """Return the number of input images."""
        return len(self.subject_list)

    def compute_vis_cmap(self, smpl_verts, smpl_faces):
        """Compute per-vertex visibility and colour map for a body mesh.

        Args:
            smpl_verts: vertex positions; columns are split as (x, y | z).
            smpl_faces: triangle indices into `smpl_verts`.

        Returns:
            dict with 'smpl_vis', 'smpl_cmap' (both moved to `self.device`)
            and 'smpl_verts', each with a leading batch dimension of 1.
        """

        # Convert once so array-like input works for every use below.
        smpl_verts = torch.as_tensor(smpl_verts)

        (xy, z) = smpl_verts.split([2, 1], dim=1)
        smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())

        vert_ids = np.arange(smpl_vis.shape[0])
        if self.smpl_type == 'smpl':
            smplx_ind = self.smpl_data.smpl2smplx(vert_ids)
        else:
            smplx_ind = vert_ids
        smpl_cmap = self.smpl_data.get_smpl_mat(smplx_ind)

        return {
            'smpl_vis': smpl_vis.unsqueeze(0).to(self.device),
            'smpl_cmap': smpl_cmap.unsqueeze(0).to(self.device),
            'smpl_verts': smpl_verts.unsqueeze(0)
        }

    def compute_voxel_verts(self, body_pose, global_orient, betas, trans,
                            scale):
        """Build the padded tetrahedral body mesh for the given parameters.

        Args:
            body_pose, global_orient: batched rotation matrices; only the
                first batch element is used.
            betas: batched shape coefficients; only the first is used.
            trans: translation tensor added to the scaled vertices.
            scale: single-element tensor used as a uniform scale.

        Returns:
            dict with 'voxel_verts', 'voxel_faces', 'pad_v_num' and
            'pad_f_num' tensors on `self.device`, each with batch size 1.
        """

        smpl_path = osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl")
        tetra_path = osp.join(self.smpl_data.tedra_dir,
                              'tetra_neutral_adult_smpl.npz')
        smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')

        pose = torch.cat([global_orient[0], body_pose[0]], dim=0)
        smpl_model.set_params(rotation_matrix_to_angle_axis(pose),
                              beta=betas[0])

        verts = np.concatenate(
            [smpl_model.verts, smpl_model.verts_added],
            axis=0) * scale.item() + trans.detach().cpu().numpy()
        # presumably the file stores 1-based indices, hence the `- 1`
        # - TODO confirm against the data file
        faces = np.loadtxt(osp.join(self.smpl_data.tedra_dir,
                                    'tetrahedrons_neutral_adult.txt'),
                           dtype=np.int32) - 1

        # Pad to fixed sizes (8000 vertices / 25100 faces); the amount of
        # padding is returned so it can be stripped again downstream.
        pad_v_num = int(8000 - verts.shape[0])
        pad_f_num = int(25100 - faces.shape[0])

        verts = np.pad(verts, ((0, pad_v_num), (0, 0)),
                       mode='constant',
                       constant_values=0.0).astype(np.float32) * 0.5
        faces = np.pad(faces, ((0, pad_f_num), (0, 0)),
                       mode='constant',
                       constant_values=0.0).astype(np.int32)

        verts[:, 2] *= -1.0

        voxel_dict = {
            'voxel_verts':
            torch.from_numpy(verts).to(self.device).unsqueeze(0).float(),
            'voxel_faces':
            torch.from_numpy(faces).to(self.device).unsqueeze(0).long(),
            'pad_v_num':
            torch.tensor(pad_v_num).to(self.device).unsqueeze(0).long(),
            'pad_f_num':
            torch.tensor(pad_f_num).to(self.device).unsqueeze(0).long()
        }

        return voxel_dict

    def _parse_hps_output(self, preds_dict):
        """Translate raw HPS predictions into the common `data_dict` keys.

        Args:
            preds_dict: output of the HPS network; the keys that are read
                depend on `self.hps_type`.

        Returns:
            dict with 'betas', 'body_pose', 'global_orient', 'smpl_verts',
            'scale' and 'trans' (for 'pixie', additionally every entry of
            `preds_dict`).

        Raises:
            ValueError: if `self.hps_type` is not a supported estimator.
        """
        out = {}

        if self.hps_type == 'pymaf':
            output = preds_dict['smpl_out'][-1]
            scale, tranX, tranY = output['theta'][0, :3]
            out['betas'] = output['pred_shape']
            out['body_pose'] = output['rotmat'][:, 1:]
            out['global_orient'] = output['rotmat'][:, 0:1]
            out['smpl_verts'] = output['verts']

        elif self.hps_type == 'pare':
            out['body_pose'] = preds_dict['pred_pose'][:, 1:]
            out['global_orient'] = preds_dict['pred_pose'][:, 0:1]
            out['betas'] = preds_dict['pred_shape']
            out['smpl_verts'] = preds_dict['smpl_vertices']
            scale, tranX, tranY = preds_dict['pred_cam'][0, :3]

        elif self.hps_type == 'pixie':
            out.update(preds_dict)
            out['body_pose'] = preds_dict['body_pose']
            out['global_orient'] = preds_dict['global_pose']
            out['betas'] = preds_dict['shape']
            out['smpl_verts'] = preds_dict['vertices']
            scale, tranX, tranY = preds_dict['cam'][0, :3]

        elif self.hps_type == 'hybrik':
            out['body_pose'] = preds_dict['pred_theta_mats'][:, 1:]
            out['global_orient'] = preds_dict['pred_theta_mats'][:, [0]]
            out['betas'] = preds_dict['pred_shape']
            out['smpl_verts'] = preds_dict['pred_vertices']
            scale, tranX, tranY = preds_dict['pred_camera'][0, :3]
            scale = scale * 2

        elif self.hps_type == 'bev':
            out['betas'] = torch.from_numpy(preds_dict['smpl_betas'])[
                [0], :10].to(self.device).float()
            pred_thetas = batch_rodrigues(torch.from_numpy(
                preds_dict['smpl_thetas'][0]).reshape(-1, 3)).float()
            out['body_pose'] = pred_thetas[1:][None].to(self.device)
            out['global_orient'] = pred_thetas[[0]][None].to(self.device)
            out['smpl_verts'] = torch.from_numpy(
                preds_dict['verts'][[0]]).to(self.device).float()
            tranX = preds_dict['cam_trans'][0, 0]
            tranY = preds_dict['cam'][0, 1] + 0.28
            scale = preds_dict['cam'][0, 0] * 1.1

        else:
            # Without this guard the camera variables below would be unbound
            # and surface as a confusing UnboundLocalError.
            raise ValueError(f"Unsupported hps_type: {self.hps_type!r}")

        out['scale'] = scale
        out['trans'] = torch.tensor(
            [tranX, tranY, 0.0]).to(self.device).float()

        return out

    def __getitem__(self, index):
        """Process image `index` and run the HPS network on it.

        Returns:
            dict with the image data ('name', 'image', 'ori_image', 'mask',
            'uncrop_param', plus 'segmentations' when `seg_dir` is set),
            'smpl_faces', and the keys produced by `_parse_hps_output`.
        """

        img_path = self.subject_list[index]
        img_name = osp.splitext(osp.basename(img_path))[0]

        if self.seg_dir is None:
            img_icon, img_hps, img_ori, img_mask, uncrop_param = process_image(
                img_path, self.det, self.hps_type, 512, self.device)
        else:
            (img_icon, img_hps, img_ori, img_mask, uncrop_param,
             segmentations) = process_image(
                 img_path, self.det, self.hps_type, 512, self.device,
                 seg_path=os.path.join(self.seg_dir, f'{img_name}.json'))

        data_dict = {
            'name': img_name,
            'image': img_icon.to(self.device).unsqueeze(0),
            'ori_image': img_ori,
            'mask': img_mask,
            'uncrop_param': uncrop_param
        }
        if self.seg_dir is not None:
            data_dict['segmentations'] = segmentations

        with torch.no_grad():
            preds_dict = self.hps(img_hps)

        # Build the index tensor directly as int64 rather than going through
        # int16 and a float tensor, which would break for larger meshes.
        data_dict['smpl_faces'] = torch.as_tensor(
            self.faces.astype(np.int64)).unsqueeze(0).to(self.device)

        data_dict.update(self._parse_hps_output(preds_dict))

        # data_dict info (key-shape):
        # scale, tranX, tranY - tensor.float
        # betas - [1,10] / [1, 200]
        # body_pose - [1, 23, 3, 3] / [1, 21, 3, 3]
        # global_orient - [1, 1, 3, 3]
        # smpl_verts - [1, 6890, 3] / [1, 10475, 3]

        return data_dict

    def render_normal(self, verts, faces):
        """Load the mesh into the renderer and return its RGB renderings."""

        # render optimized mesh (normal, T_normal, image [-1,1])
        self.render.load_meshes(verts, faces)
        return self.render.get_rgb_image()

    def render_depth(self, verts, faces):
        """Load the mesh into the renderer and return depth maps for cameras 0 and 2."""

        self.render.load_meshes(verts, faces)
        return self.render.get_depth_map(cam_ids=[0, 2])

    def visualize_alignment(self, data):
        """Open an interactive vedo window showing the predicted body mesh
        together with the input image and the rendered front/back images.

        Args:
            data: an item as returned by `__getitem__`.
        """

        import vedo
        import trimesh

        if self.hps_type != 'pixie':
            smpl_out = self.smpl_model(betas=data['betas'],
                                       body_pose=data['body_pose'],
                                       global_orient=data['global_orient'],
                                       pose2rot=False)
            smpl_verts = (
                (smpl_out.vertices + data['trans']) *
                data['scale']).detach().cpu().numpy()[0]
        else:
            # NOTE(review): this keyword set / 3-tuple return differs from the
            # call above - confirm `self.smpl_model` supports it for 'pixie'.
            smpl_verts, _, _ = self.smpl_model(
                shape_params=data['betas'],
                expression_params=data['exp'],
                body_pose=data['body_pose'],
                global_pose=data['global_orient'],
                jaw_pose=data['jaw_pose'],
                left_hand_pose=data['left_hand_pose'],
                right_hand_pose=data['right_hand_pose'])

            smpl_verts = (
                (smpl_verts + data['trans']) *
                data['scale']).detach().cpu().numpy()[0]

        # Flip the y and z axes before rendering / display.
        smpl_verts *= np.array([1.0, -1.0, -1.0])
        faces = data['smpl_faces'][0].detach().cpu().numpy()

        image_P = data['image']
        image_F, image_B = self.render_normal(smpl_verts, faces)

        # create plot
        vp = vedo.Plotter(title="", size=(1500, 1500))
        vis_list = []

        # Map the images from [-1, 1] (channel-first) to [0, 255] (channel-last).
        image_F = (
            0.5 * (1.0 + image_F[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)
        image_B = (
            0.5 * (1.0 + image_B[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)
        image_P = (
            0.5 * (1.0 + image_P[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)

        # Input image blended 50/50 with the front rendering, then the front
        # and back renderings on their own, stacked along z.
        vis_list.append(vedo.Picture(image_P * 0.5 + image_F * 0.5).scale(
            2.0 / image_P.shape[0]).pos(-1.0, -1.0, 1.0))
        vis_list.append(vedo.Picture(image_F).scale(
            2.0 / image_F.shape[0]).pos(-1.0, -1.0, -0.5))
        vis_list.append(vedo.Picture(image_B).scale(
            2.0 / image_B.shape[0]).pos(-1.0, -1.0, -1.0))

        # create a mesh
        mesh = trimesh.Trimesh(smpl_verts, faces, process=False)
        mesh.visual.vertex_colors = [200, 200, 0]
        vis_list.append(mesh)

        vp.show(*vis_list, bg="white", axes=1, interactive=True)
317
+
318
+
319
# Demo entry point: visualise the body-mesh alignment for every image in
# ./examples. Requires a CUDA device and the project config files below.
if __name__ == '__main__':

    cfg.merge_from_file("./configs/icon-filter.yaml")
    cfg.merge_from_file('./lib/pymaf/configs/pymaf_config.yaml')

    cfg_show_list = [
        'test_gpus', ['0'], 'mcube_res', 512, 'clean_mesh', False
    ]

    cfg.merge_from_list(cfg_show_list)
    cfg.freeze()

    os.environ['CUDA_VISIBLE_DEVICES'] = "0"
    device = torch.device('cuda:0')

    dataset = TestDataset(
        {
            'image_dir': "./examples",
            # The constructor reads cfg['seg_dir']; None means "no
            # segmentation files", and omitting the key raised KeyError.
            'seg_dir': None,
            'has_det': True,  # w/ or w/o detection
            # TestDataset.__init__ only builds the PyMAF network, so the
            # demo requests the matching prediction format.
            'hps_type': 'pymaf'  # pymaf/pare/pixie/hybrik/bev
        }, device)

    for i in range(len(dataset)):
        dataset.visualize_alignment(dataset[i])
lib/dataset/__init__.py ADDED
File without changes