Spaces:
Runtime error
Runtime error
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +17 -0
- LICENSE +53 -0
- README.md +7 -8
- apps/ICON.py +762 -0
- apps/Normal.py +213 -0
- apps/__pycache__/app.cpython-38.pyc +0 -0
- apps/app.py +21 -0
- apps/infer.py +616 -0
- assets/garment_teaser.png +0 -0
- assets/intermediate_results.png +0 -0
- assets/teaser.gif +0 -0
- configs/icon-filter.yaml +25 -0
- configs/icon-nofilter.yaml +25 -0
- configs/pamir.yaml +24 -0
- configs/pifu.yaml +24 -0
- environment.yaml +16 -0
- examples/22097467bffc92d4a5c4246f7d4edb75.png +0 -0
- examples/44c0f84c957b6b9bdf77662af5bb7078.png +0 -0
- examples/5a6a25963db2f667441d5076972c207c.png +0 -0
- examples/8da7ceb94669c2f65cbd28022e1f9876.png +0 -0
- examples/923d65f767c85a42212cae13fba3750b.png +0 -0
- examples/959c4c726a69901ce71b93a9242ed900.png +0 -0
- examples/c9856a2bc31846d684cbb965457fad59.png +0 -0
- examples/e1e7622af7074a022f5d96dc16672517.png +0 -0
- examples/fb9d20fdb93750584390599478ecf86e.png +0 -0
- examples/segmentation/003883.jpg +0 -0
- examples/segmentation/003883.json +136 -0
- examples/segmentation/028009.jpg +0 -0
- examples/segmentation/028009.json +191 -0
- examples/slack_trial2-000150.png +0 -0
- fetch_data.sh +60 -0
- install.sh +16 -0
- lib/__init__.py +0 -0
- lib/common/__init__.py +0 -0
- lib/common/cloth_extraction.py +170 -0
- lib/common/config.py +218 -0
- lib/common/render.py +387 -0
- lib/common/render_utils.py +221 -0
- lib/common/seg3d_lossless.py +604 -0
- lib/common/seg3d_utils.py +392 -0
- lib/common/smpl_vert_segmentation.json +0 -0
- lib/common/train_util.py +597 -0
- lib/dataloader_demo.py +58 -0
- lib/dataset/Evaluator.py +264 -0
- lib/dataset/NormalDataset.py +212 -0
- lib/dataset/NormalModule.py +94 -0
- lib/dataset/PIFuDataModule.py +71 -0
- lib/dataset/PIFuDataset.py +662 -0
- lib/dataset/TestDataset.py +342 -0
- lib/dataset/__init__.py +0 -0
.gitignore
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data/*/*
|
2 |
+
data/thuman*
|
3 |
+
!data/tbfo.ttf
|
4 |
+
__pycache__
|
5 |
+
debug/
|
6 |
+
log/
|
7 |
+
results/*
|
8 |
+
.vscode
|
9 |
+
!.gitignore
|
10 |
+
force_push.sh
|
11 |
+
.idea
|
12 |
+
smplx/
|
13 |
+
human_det/
|
14 |
+
kaolin/
|
15 |
+
neural_voxelization_layer/
|
16 |
+
pytorch3d/
|
17 |
+
force_push.sh
|
LICENSE
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
License
|
2 |
+
|
3 |
+
Software Copyright License for non-commercial scientific research purposes
|
4 |
+
Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the ICON model, data and software, (the "Data & Software"), including 3D meshes, images, videos, textures, software, scripts, and animations. By downloading and/or using the Data & Software (including downloading, cloning, installing, and any other use of the corresponding github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Data & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License
|
5 |
+
|
6 |
+
Ownership / Licensees
|
7 |
+
The Software and the associated materials has been developed at the Max Planck Institute for Intelligent Systems (hereinafter "MPI"). Any copyright or patent right is owned by and proprietary material of the Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”) hereinafter the “Licensor”.
|
8 |
+
|
9 |
+
License Grant
|
10 |
+
Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right:
|
11 |
+
|
12 |
+
• To install the Model & Software on computers owned, leased or otherwise controlled by you and/or your organization;
|
13 |
+
• To use the Model & Software for the sole purpose of performing peaceful non-commercial scientific research, non-commercial education, or non-commercial artistic projects;
|
14 |
+
• To modify, adapt, translate or create derivative works based upon the Model & Software.
|
15 |
+
|
16 |
+
Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding analyses in peer-reviewed scientific research. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission.
|
17 |
+
|
18 |
+
The Data & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Software to train methods/algorithms/neural networks/etc. for commercial, pornographic, military, surveillance, or defamatory use of any kind. By downloading the Data & Software, you agree not to reverse engineer it.
|
19 |
+
|
20 |
+
No Distribution
|
21 |
+
The Data & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only.
|
22 |
+
|
23 |
+
Disclaimer of Representations and Warranties
|
24 |
+
You expressly acknowledge and agree that the Data & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Data & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE DATA & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Data & Software, (ii) that the use of the Data & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Data & Software will not cause any damage of any kind to you or a third party.
|
25 |
+
|
26 |
+
Limitation of Liability
|
27 |
+
Because this Data & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage.
|
28 |
+
Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded.
|
29 |
+
Patent claims generated through the usage of the Data & Software cannot be directed towards the copyright holders.
|
30 |
+
The Data & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Data & Software and is not responsible for any problems such modifications cause.
|
31 |
+
|
32 |
+
No Maintenance Services
|
33 |
+
You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Data & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Data & Software at any time.
|
34 |
+
|
35 |
+
Defects of the Data & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication.
|
36 |
+
|
37 |
+
Publications using the Model & Software
|
38 |
+
You acknowledge that the Data & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Data & Software.
|
39 |
+
|
40 |
+
Citation:
|
41 |
+
|
42 |
+
@inproceedings{xiu2022icon,
|
43 |
+
title={{ICON}: {I}mplicit {C}lothed humans {O}btained from {N}ormals},
|
44 |
+
author={Xiu, Yuliang and Yang, Jinlong and Tzionas, Dimitrios and Black, Michael J.},
|
45 |
+
booktitle={IEEE/CVF Conf.~on Computer Vision and Pattern Recognition (CVPR)},
|
46 |
+
month = jun,
|
47 |
+
year={2022}
|
48 |
+
}
|
49 |
+
|
50 |
+
Commercial licensing opportunities
|
51 |
+
For commercial uses of the Model & Software, please send email to [email protected]
|
52 |
+
|
53 |
+
This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention.
|
README.md
CHANGED
@@ -1,13 +1,12 @@
|
|
1 |
---
|
2 |
title: ICON
|
3 |
-
|
|
|
4 |
colorFrom: indigo
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.1.1
|
8 |
-
app_file: app.py
|
9 |
-
pinned:
|
10 |
-
|
11 |
-
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: ICON
|
3 |
+
metaTitle: "Image2Human by Yuliang Xiu"
|
4 |
+
emoji: 🤼
|
5 |
colorFrom: indigo
|
6 |
+
colorTo: yellow
|
7 |
sdk: gradio
|
8 |
sdk_version: 3.1.1
|
9 |
+
app_file: ./apps/app.py
|
10 |
+
pinned: true
|
11 |
+
python_version: 3.8
|
12 |
+
---
|
|
|
|
apps/ICON.py
ADDED
@@ -0,0 +1,762 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
4 |
+
# holder of all proprietary rights on this computer program.
|
5 |
+
# You can only use this computer program if you have closed
|
6 |
+
# a license agreement with MPG or you get the right to use the computer
|
7 |
+
# program from someone who is authorized to grant you that right.
|
8 |
+
# Any use of the computer program without a valid license is prohibited and
|
9 |
+
# liable to prosecution.
|
10 |
+
#
|
11 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
12 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
13 |
+
# for Intelligent Systems. All rights reserved.
|
14 |
+
#
|
15 |
+
# Contact: [email protected]
|
16 |
+
|
17 |
+
from lib.common.seg3d_lossless import Seg3dLossless
|
18 |
+
from lib.dataset.Evaluator import Evaluator
|
19 |
+
from lib.net import HGPIFuNet
|
20 |
+
from lib.common.train_util import *
|
21 |
+
from lib.renderer.gl.init_gl import initialize_GL_context
|
22 |
+
from lib.common.render import Render
|
23 |
+
from lib.dataset.mesh_util import SMPLX, update_mesh_shape_prior_losses, get_visibility
|
24 |
+
import warnings
|
25 |
+
import logging
|
26 |
+
import torch
|
27 |
+
import smplx
|
28 |
+
import numpy as np
|
29 |
+
from torch import nn
|
30 |
+
from skimage.transform import resize
|
31 |
+
import pytorch_lightning as pl
|
32 |
+
|
33 |
+
torch.backends.cudnn.benchmark = True
|
34 |
+
|
35 |
+
logging.getLogger("lightning").setLevel(logging.ERROR)
|
36 |
+
|
37 |
+
warnings.filterwarnings("ignore")
|
38 |
+
|
39 |
+
|
40 |
+
class ICON(pl.LightningModule):
|
41 |
+
def __init__(self, cfg):
|
42 |
+
super(ICON, self).__init__()
|
43 |
+
|
44 |
+
self.cfg = cfg
|
45 |
+
self.batch_size = self.cfg.batch_size
|
46 |
+
self.lr_G = self.cfg.lr_G
|
47 |
+
|
48 |
+
self.use_sdf = cfg.sdf
|
49 |
+
self.prior_type = cfg.net.prior_type
|
50 |
+
self.mcube_res = cfg.mcube_res
|
51 |
+
self.clean_mesh_flag = cfg.clean_mesh
|
52 |
+
|
53 |
+
self.netG = HGPIFuNet(
|
54 |
+
self.cfg,
|
55 |
+
self.cfg.projection_mode,
|
56 |
+
error_term=nn.SmoothL1Loss() if self.use_sdf else nn.MSELoss(),
|
57 |
+
)
|
58 |
+
|
59 |
+
# TODO: replace the renderer from opengl to pytorch3d
|
60 |
+
self.evaluator = Evaluator(
|
61 |
+
device=torch.device(f"cuda:{self.cfg.gpus[0]}"))
|
62 |
+
|
63 |
+
self.resolutions = (
|
64 |
+
np.logspace(
|
65 |
+
start=5,
|
66 |
+
stop=np.log2(self.mcube_res),
|
67 |
+
base=2,
|
68 |
+
num=int(np.log2(self.mcube_res) - 4),
|
69 |
+
endpoint=True,
|
70 |
+
)
|
71 |
+
+ 1.0
|
72 |
+
)
|
73 |
+
self.resolutions = self.resolutions.astype(np.int16).tolist()
|
74 |
+
|
75 |
+
self.icon_keys = ["smpl_verts", "smpl_faces", "smpl_vis", "smpl_cmap"]
|
76 |
+
self.pamir_keys = ["voxel_verts",
|
77 |
+
"voxel_faces", "pad_v_num", "pad_f_num"]
|
78 |
+
|
79 |
+
self.reconEngine = Seg3dLossless(
|
80 |
+
query_func=query_func,
|
81 |
+
b_min=[[-1.0, 1.0, -1.0]],
|
82 |
+
b_max=[[1.0, -1.0, 1.0]],
|
83 |
+
resolutions=self.resolutions,
|
84 |
+
align_corners=True,
|
85 |
+
balance_value=0.50,
|
86 |
+
device=torch.device(f"cuda:{self.cfg.test_gpus[0]}"),
|
87 |
+
visualize=False,
|
88 |
+
debug=False,
|
89 |
+
use_cuda_impl=False,
|
90 |
+
faster=True,
|
91 |
+
)
|
92 |
+
|
93 |
+
self.render = Render(
|
94 |
+
size=512, device=torch.device(f"cuda:{self.cfg.test_gpus[0]}")
|
95 |
+
)
|
96 |
+
self.smpl_data = SMPLX()
|
97 |
+
|
98 |
+
self.get_smpl_model = lambda smpl_type, gender, age, v_template: smplx.create(
|
99 |
+
self.smpl_data.model_dir,
|
100 |
+
kid_template_path=osp.join(
|
101 |
+
osp.realpath(self.smpl_data.model_dir),
|
102 |
+
f"{smpl_type}/{smpl_type}_kid_template.npy",
|
103 |
+
),
|
104 |
+
model_type=smpl_type,
|
105 |
+
gender=gender,
|
106 |
+
age=age,
|
107 |
+
v_template=v_template,
|
108 |
+
use_face_contour=False,
|
109 |
+
ext="pkl",
|
110 |
+
)
|
111 |
+
|
112 |
+
self.in_geo = [item[0] for item in cfg.net.in_geo]
|
113 |
+
self.in_nml = [item[0] for item in cfg.net.in_nml]
|
114 |
+
self.in_geo_dim = [item[1] for item in cfg.net.in_geo]
|
115 |
+
self.in_total = self.in_geo + self.in_nml
|
116 |
+
self.smpl_dim = cfg.net.smpl_dim
|
117 |
+
|
118 |
+
self.export_dir = None
|
119 |
+
self.result_eval = {}
|
120 |
+
|
121 |
+
def get_progress_bar_dict(self):
|
122 |
+
tqdm_dict = super().get_progress_bar_dict()
|
123 |
+
if "v_num" in tqdm_dict:
|
124 |
+
del tqdm_dict["v_num"]
|
125 |
+
return tqdm_dict
|
126 |
+
|
127 |
+
# Training related
|
128 |
+
def configure_optimizers(self):
|
129 |
+
|
130 |
+
# set optimizer
|
131 |
+
weight_decay = self.cfg.weight_decay
|
132 |
+
momentum = self.cfg.momentum
|
133 |
+
|
134 |
+
optim_params_G = [
|
135 |
+
{"params": self.netG.if_regressor.parameters(), "lr": self.lr_G}
|
136 |
+
]
|
137 |
+
|
138 |
+
if self.cfg.net.use_filter:
|
139 |
+
optim_params_G.append(
|
140 |
+
{"params": self.netG.F_filter.parameters(), "lr": self.lr_G}
|
141 |
+
)
|
142 |
+
|
143 |
+
if self.cfg.net.prior_type == "pamir":
|
144 |
+
optim_params_G.append(
|
145 |
+
{"params": self.netG.ve.parameters(), "lr": self.lr_G}
|
146 |
+
)
|
147 |
+
|
148 |
+
if self.cfg.optim == "Adadelta":
|
149 |
+
|
150 |
+
optimizer_G = torch.optim.Adadelta(
|
151 |
+
optim_params_G, lr=self.lr_G, weight_decay=weight_decay
|
152 |
+
)
|
153 |
+
|
154 |
+
elif self.cfg.optim == "Adam":
|
155 |
+
|
156 |
+
optimizer_G = torch.optim.Adam(
|
157 |
+
optim_params_G, lr=self.lr_G, weight_decay=weight_decay
|
158 |
+
)
|
159 |
+
|
160 |
+
elif self.cfg.optim == "RMSprop":
|
161 |
+
|
162 |
+
optimizer_G = torch.optim.RMSprop(
|
163 |
+
optim_params_G,
|
164 |
+
lr=self.lr_G,
|
165 |
+
weight_decay=weight_decay,
|
166 |
+
momentum=momentum,
|
167 |
+
)
|
168 |
+
|
169 |
+
else:
|
170 |
+
raise NotImplementedError
|
171 |
+
|
172 |
+
# set scheduler
|
173 |
+
scheduler_G = torch.optim.lr_scheduler.MultiStepLR(
|
174 |
+
optimizer_G, milestones=self.cfg.schedule, gamma=self.cfg.gamma
|
175 |
+
)
|
176 |
+
|
177 |
+
return [optimizer_G], [scheduler_G]
|
178 |
+
|
179 |
+
def training_step(self, batch, batch_idx):
|
180 |
+
|
181 |
+
if not self.cfg.fast_dev:
|
182 |
+
export_cfg(self.logger, self.cfg)
|
183 |
+
|
184 |
+
self.netG.train()
|
185 |
+
|
186 |
+
in_tensor_dict = {
|
187 |
+
"sample": batch["samples_geo"].permute(0, 2, 1),
|
188 |
+
"calib": batch["calib"],
|
189 |
+
"label": batch["labels_geo"].unsqueeze(1),
|
190 |
+
}
|
191 |
+
|
192 |
+
for name in self.in_total:
|
193 |
+
in_tensor_dict.update({name: batch[name]})
|
194 |
+
|
195 |
+
if self.prior_type == "icon":
|
196 |
+
for key in self.icon_keys:
|
197 |
+
in_tensor_dict.update({key: batch[key]})
|
198 |
+
elif self.prior_type == "pamir":
|
199 |
+
for key in self.pamir_keys:
|
200 |
+
in_tensor_dict.update({key: batch[key]})
|
201 |
+
else:
|
202 |
+
pass
|
203 |
+
|
204 |
+
preds_G, error_G = self.netG(in_tensor_dict)
|
205 |
+
|
206 |
+
acc, iou, prec, recall = self.evaluator.calc_acc(
|
207 |
+
preds_G.flatten(),
|
208 |
+
in_tensor_dict["label"].flatten(),
|
209 |
+
0.5,
|
210 |
+
use_sdf=self.cfg.sdf,
|
211 |
+
)
|
212 |
+
|
213 |
+
# metrics processing
|
214 |
+
metrics_log = {
|
215 |
+
"train_loss": error_G.item(),
|
216 |
+
"train_acc": acc.item(),
|
217 |
+
"train_iou": iou.item(),
|
218 |
+
"train_prec": prec.item(),
|
219 |
+
"train_recall": recall.item(),
|
220 |
+
}
|
221 |
+
|
222 |
+
tf_log = tf_log_convert(metrics_log)
|
223 |
+
bar_log = bar_log_convert(metrics_log)
|
224 |
+
|
225 |
+
if batch_idx % int(self.cfg.freq_show_train) == 0:
|
226 |
+
|
227 |
+
with torch.no_grad():
|
228 |
+
self.render_func(in_tensor_dict, dataset="train")
|
229 |
+
|
230 |
+
metrics_return = {
|
231 |
+
k.replace("train_", ""): torch.tensor(v) for k, v in metrics_log.items()
|
232 |
+
}
|
233 |
+
|
234 |
+
metrics_return.update(
|
235 |
+
{"loss": error_G, "log": tf_log, "progress_bar": bar_log})
|
236 |
+
|
237 |
+
return metrics_return
|
238 |
+
|
239 |
+
def training_epoch_end(self, outputs):
|
240 |
+
|
241 |
+
if [] in outputs:
|
242 |
+
outputs = outputs[0]
|
243 |
+
|
244 |
+
# metrics processing
|
245 |
+
metrics_log = {
|
246 |
+
"train_avgloss": batch_mean(outputs, "loss"),
|
247 |
+
"train_avgiou": batch_mean(outputs, "iou"),
|
248 |
+
"train_avgprec": batch_mean(outputs, "prec"),
|
249 |
+
"train_avgrecall": batch_mean(outputs, "recall"),
|
250 |
+
"train_avgacc": batch_mean(outputs, "acc"),
|
251 |
+
}
|
252 |
+
|
253 |
+
tf_log = tf_log_convert(metrics_log)
|
254 |
+
|
255 |
+
return {"log": tf_log}
|
256 |
+
|
257 |
+
def validation_step(self, batch, batch_idx):
|
258 |
+
|
259 |
+
self.netG.eval()
|
260 |
+
self.netG.training = False
|
261 |
+
|
262 |
+
in_tensor_dict = {
|
263 |
+
"sample": batch["samples_geo"].permute(0, 2, 1),
|
264 |
+
"calib": batch["calib"],
|
265 |
+
"label": batch["labels_geo"].unsqueeze(1),
|
266 |
+
}
|
267 |
+
|
268 |
+
for name in self.in_total:
|
269 |
+
in_tensor_dict.update({name: batch[name]})
|
270 |
+
|
271 |
+
if self.prior_type == "icon":
|
272 |
+
for key in self.icon_keys:
|
273 |
+
in_tensor_dict.update({key: batch[key]})
|
274 |
+
elif self.prior_type == "pamir":
|
275 |
+
for key in self.pamir_keys:
|
276 |
+
in_tensor_dict.update({key: batch[key]})
|
277 |
+
else:
|
278 |
+
pass
|
279 |
+
|
280 |
+
preds_G, error_G = self.netG(in_tensor_dict)
|
281 |
+
|
282 |
+
acc, iou, prec, recall = self.evaluator.calc_acc(
|
283 |
+
preds_G.flatten(),
|
284 |
+
in_tensor_dict["label"].flatten(),
|
285 |
+
0.5,
|
286 |
+
use_sdf=self.cfg.sdf,
|
287 |
+
)
|
288 |
+
|
289 |
+
if batch_idx % int(self.cfg.freq_show_val) == 0:
|
290 |
+
with torch.no_grad():
|
291 |
+
self.render_func(in_tensor_dict, dataset="val", idx=batch_idx)
|
292 |
+
|
293 |
+
metrics_return = {
|
294 |
+
"val_loss": error_G,
|
295 |
+
"val_acc": acc,
|
296 |
+
"val_iou": iou,
|
297 |
+
"val_prec": prec,
|
298 |
+
"val_recall": recall,
|
299 |
+
}
|
300 |
+
|
301 |
+
return metrics_return
|
302 |
+
|
303 |
+
def validation_epoch_end(self, outputs):
|
304 |
+
|
305 |
+
# metrics processing
|
306 |
+
metrics_log = {
|
307 |
+
"val_avgloss": batch_mean(outputs, "val_loss"),
|
308 |
+
"val_avgacc": batch_mean(outputs, "val_acc"),
|
309 |
+
"val_avgiou": batch_mean(outputs, "val_iou"),
|
310 |
+
"val_avgprec": batch_mean(outputs, "val_prec"),
|
311 |
+
"val_avgrecall": batch_mean(outputs, "val_recall"),
|
312 |
+
}
|
313 |
+
|
314 |
+
tf_log = tf_log_convert(metrics_log)
|
315 |
+
|
316 |
+
return {"log": tf_log}
|
317 |
+
|
318 |
+
def compute_vis_cmap(self, smpl_type, smpl_verts, smpl_faces):
|
319 |
+
|
320 |
+
(xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=1)
|
321 |
+
smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())
|
322 |
+
if smpl_type == "smpl":
|
323 |
+
smplx_ind = self.smpl_data.smpl2smplx(np.arange(smpl_vis.shape[0]))
|
324 |
+
else:
|
325 |
+
smplx_ind = np.arange(smpl_vis.shape[0])
|
326 |
+
smpl_cmap = self.smpl_data.get_smpl_mat(smplx_ind)
|
327 |
+
|
328 |
+
return {
|
329 |
+
"smpl_vis": smpl_vis.unsqueeze(0).to(self.device),
|
330 |
+
"smpl_cmap": smpl_cmap.unsqueeze(0).to(self.device),
|
331 |
+
"smpl_verts": smpl_verts.unsqueeze(0),
|
332 |
+
}
|
333 |
+
|
334 |
+
@torch.enable_grad()
|
335 |
+
def optim_body(self, in_tensor_dict, batch):
|
336 |
+
|
337 |
+
smpl_model = self.get_smpl_model(
|
338 |
+
batch["type"][0], batch["gender"][0], batch["age"][0], None
|
339 |
+
).to(self.device)
|
340 |
+
in_tensor_dict["smpl_faces"] = (
|
341 |
+
torch.tensor(smpl_model.faces.astype(np.int))
|
342 |
+
.long()
|
343 |
+
.unsqueeze(0)
|
344 |
+
.to(self.device)
|
345 |
+
)
|
346 |
+
|
347 |
+
# The optimizer and variables
|
348 |
+
optimed_pose = torch.tensor(
|
349 |
+
batch["body_pose"][0], device=self.device, requires_grad=True
|
350 |
+
) # [1,23,3,3]
|
351 |
+
optimed_trans = torch.tensor(
|
352 |
+
batch["transl"][0], device=self.device, requires_grad=True
|
353 |
+
) # [3]
|
354 |
+
optimed_betas = torch.tensor(
|
355 |
+
batch["betas"][0], device=self.device, requires_grad=True
|
356 |
+
) # [1,10]
|
357 |
+
optimed_orient = torch.tensor(
|
358 |
+
batch["global_orient"][0], device=self.device, requires_grad=True
|
359 |
+
) # [1,1,3,3]
|
360 |
+
|
361 |
+
optimizer_smpl = torch.optim.SGD(
|
362 |
+
[optimed_pose, optimed_trans, optimed_betas, optimed_orient],
|
363 |
+
lr=1e-3,
|
364 |
+
momentum=0.9,
|
365 |
+
)
|
366 |
+
scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
367 |
+
optimizer_smpl, mode="min", factor=0.5, verbose=0, min_lr=1e-5, patience=5
|
368 |
+
)
|
369 |
+
loop_smpl = range(50)
|
370 |
+
for i in loop_smpl:
|
371 |
+
|
372 |
+
optimizer_smpl.zero_grad()
|
373 |
+
|
374 |
+
# prior_loss, optimed_pose = dataset.vposer_prior(optimed_pose)
|
375 |
+
smpl_out = smpl_model(
|
376 |
+
betas=optimed_betas,
|
377 |
+
body_pose=optimed_pose,
|
378 |
+
global_orient=optimed_orient,
|
379 |
+
transl=optimed_trans,
|
380 |
+
return_verts=True,
|
381 |
+
)
|
382 |
+
|
383 |
+
smpl_verts = smpl_out.vertices[0] * 100.0
|
384 |
+
smpl_verts = projection(
|
385 |
+
smpl_verts, batch["calib"][0], format="tensor")
|
386 |
+
smpl_verts[:, 1] *= -1
|
387 |
+
# render optimized mesh (normal, T_normal, image [-1,1])
|
388 |
+
self.render.load_meshes(
|
389 |
+
smpl_verts, in_tensor_dict["smpl_faces"])
|
390 |
+
(
|
391 |
+
in_tensor_dict["T_normal_F"],
|
392 |
+
in_tensor_dict["T_normal_B"],
|
393 |
+
) = self.render.get_rgb_image()
|
394 |
+
|
395 |
+
T_mask_F, T_mask_B = self.render.get_silhouette_image()
|
396 |
+
|
397 |
+
with torch.no_grad():
|
398 |
+
(
|
399 |
+
in_tensor_dict["normal_F"],
|
400 |
+
in_tensor_dict["normal_B"],
|
401 |
+
) = self.netG.normal_filter(in_tensor_dict)
|
402 |
+
|
403 |
+
# mask = torch.abs(in_tensor['T_normal_F']).sum(dim=0, keepdims=True) > 0.0
|
404 |
+
diff_F_smpl = torch.abs(
|
405 |
+
in_tensor_dict["T_normal_F"] - in_tensor_dict["normal_F"]
|
406 |
+
)
|
407 |
+
diff_B_smpl = torch.abs(
|
408 |
+
in_tensor_dict["T_normal_B"] - in_tensor_dict["normal_B"]
|
409 |
+
)
|
410 |
+
loss = (diff_F_smpl + diff_B_smpl).mean()
|
411 |
+
|
412 |
+
# silhouette loss
|
413 |
+
smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)[0]
|
414 |
+
gt_arr = torch.cat(
|
415 |
+
[in_tensor_dict["normal_F"][0], in_tensor_dict["normal_B"][0]], dim=2
|
416 |
+
).permute(1, 2, 0)
|
417 |
+
gt_arr = ((gt_arr + 1.0) * 0.5).to(self.device)
|
418 |
+
bg_color = (
|
419 |
+
torch.Tensor([0.5, 0.5, 0.5]).unsqueeze(
|
420 |
+
0).unsqueeze(0).to(self.device)
|
421 |
+
)
|
422 |
+
gt_arr = ((gt_arr - bg_color).sum(dim=-1) != 0.0).float()
|
423 |
+
loss += torch.abs(smpl_arr - gt_arr).mean()
|
424 |
+
|
425 |
+
# Image.fromarray(((in_tensor_dict['T_normal_F'][0].permute(1,2,0)+1.0)*0.5*255.0).detach().cpu().numpy().astype(np.uint8)).show()
|
426 |
+
|
427 |
+
# loop_smpl.set_description(f"smpl = {loss:.3f}")
|
428 |
+
|
429 |
+
loss.backward(retain_graph=True)
|
430 |
+
optimizer_smpl.step()
|
431 |
+
scheduler_smpl.step(loss)
|
432 |
+
in_tensor_dict["smpl_verts"] = smpl_verts.unsqueeze(0)
|
433 |
+
|
434 |
+
in_tensor_dict.update(
|
435 |
+
self.compute_vis_cmap(
|
436 |
+
batch["type"][0],
|
437 |
+
in_tensor_dict["smpl_verts"][0],
|
438 |
+
in_tensor_dict["smpl_faces"][0],
|
439 |
+
)
|
440 |
+
)
|
441 |
+
|
442 |
+
features, inter = self.netG.filter(in_tensor_dict, return_inter=True)
|
443 |
+
|
444 |
+
return features, inter, in_tensor_dict
|
445 |
+
|
446 |
+
@torch.enable_grad()
|
447 |
+
def optim_cloth(self, verts_pr, faces_pr, inter):
|
448 |
+
|
449 |
+
# convert from GT to SDF
|
450 |
+
verts_pr -= (self.resolutions[-1] - 1) / 2.0
|
451 |
+
verts_pr /= (self.resolutions[-1] - 1) / 2.0
|
452 |
+
|
453 |
+
losses = {
|
454 |
+
"cloth": {"weight": 5.0, "value": 0.0},
|
455 |
+
"edge": {"weight": 100.0, "value": 0.0},
|
456 |
+
"normal": {"weight": 0.2, "value": 0.0},
|
457 |
+
"laplacian": {"weight": 100.0, "value": 0.0},
|
458 |
+
"smpl": {"weight": 1.0, "value": 0.0},
|
459 |
+
"deform": {"weight": 20.0, "value": 0.0},
|
460 |
+
}
|
461 |
+
|
462 |
+
deform_verts = torch.full(
|
463 |
+
verts_pr.shape, 0.0, device=self.device, requires_grad=True
|
464 |
+
)
|
465 |
+
optimizer_cloth = torch.optim.SGD(
|
466 |
+
[deform_verts], lr=1e-1, momentum=0.9)
|
467 |
+
scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
468 |
+
optimizer_cloth, mode="min", factor=0.1, verbose=0, min_lr=1e-3, patience=5
|
469 |
+
)
|
470 |
+
# cloth optimization
|
471 |
+
loop_cloth = range(100)
|
472 |
+
|
473 |
+
for i in loop_cloth:
|
474 |
+
|
475 |
+
optimizer_cloth.zero_grad()
|
476 |
+
|
477 |
+
self.render.load_meshes(
|
478 |
+
verts_pr.unsqueeze(0).to(self.device),
|
479 |
+
faces_pr.unsqueeze(0).to(self.device).long(),
|
480 |
+
deform_verts,
|
481 |
+
)
|
482 |
+
P_normal_F, P_normal_B = self.render.get_rgb_image()
|
483 |
+
|
484 |
+
update_mesh_shape_prior_losses(self.render.mesh, losses)
|
485 |
+
diff_F_cloth = torch.abs(P_normal_F[0] - inter[:3])
|
486 |
+
diff_B_cloth = torch.abs(P_normal_B[0] - inter[3:])
|
487 |
+
losses["cloth"]["value"] = (diff_F_cloth + diff_B_cloth).mean()
|
488 |
+
losses["deform"]["value"] = torch.topk(
|
489 |
+
torch.abs(deform_verts.flatten()), 30
|
490 |
+
)[0].mean()
|
491 |
+
|
492 |
+
# Weighted sum of the losses
|
493 |
+
cloth_loss = torch.tensor(0.0, device=self.device)
|
494 |
+
pbar_desc = ""
|
495 |
+
|
496 |
+
for k in losses.keys():
|
497 |
+
if k != "smpl":
|
498 |
+
cloth_loss_per_cls = losses[k]["value"] * \
|
499 |
+
losses[k]["weight"]
|
500 |
+
pbar_desc += f"{k}: {cloth_loss_per_cls:.3f} | "
|
501 |
+
cloth_loss += cloth_loss_per_cls
|
502 |
+
|
503 |
+
# loop_cloth.set_description(pbar_desc)
|
504 |
+
cloth_loss.backward(retain_graph=True)
|
505 |
+
optimizer_cloth.step()
|
506 |
+
scheduler_cloth.step(cloth_loss)
|
507 |
+
|
508 |
+
# convert from GT to SDF
|
509 |
+
deform_verts = deform_verts.flatten().detach()
|
510 |
+
deform_verts[torch.topk(torch.abs(deform_verts), 30)[
|
511 |
+
1]] = deform_verts.mean()
|
512 |
+
deform_verts = deform_verts.view(-1, 3).cpu()
|
513 |
+
|
514 |
+
verts_pr += deform_verts
|
515 |
+
verts_pr *= (self.resolutions[-1] - 1) / 2.0
|
516 |
+
verts_pr += (self.resolutions[-1] - 1) / 2.0
|
517 |
+
|
518 |
+
return verts_pr
|
519 |
+
|
520 |
+
def test_step(self, batch, batch_idx):
|
521 |
+
|
522 |
+
# dict_keys(['dataset', 'subject', 'rotation', 'scale', 'calib',
|
523 |
+
# 'normal_F', 'normal_B', 'image', 'T_normal_F', 'T_normal_B',
|
524 |
+
# 'z-trans', 'verts', 'faces', 'samples_geo', 'labels_geo',
|
525 |
+
# 'smpl_verts', 'smpl_faces', 'smpl_vis', 'smpl_cmap', 'pts_signs',
|
526 |
+
# 'type', 'gender', 'age', 'body_pose', 'global_orient', 'betas', 'transl'])
|
527 |
+
|
528 |
+
if self.evaluator._normal_render is None:
|
529 |
+
self.evaluator.init_gl()
|
530 |
+
|
531 |
+
self.netG.eval()
|
532 |
+
self.netG.training = False
|
533 |
+
in_tensor_dict = {}
|
534 |
+
|
535 |
+
# export paths
|
536 |
+
mesh_name = batch["subject"][0]
|
537 |
+
mesh_rot = batch["rotation"][0].item()
|
538 |
+
ckpt_dir = self.cfg.name
|
539 |
+
|
540 |
+
for kid, key in enumerate(self.cfg.dataset.noise_type):
|
541 |
+
ckpt_dir += f"_{key}_{self.cfg.dataset.noise_scale[kid]}"
|
542 |
+
|
543 |
+
if self.cfg.optim_cloth:
|
544 |
+
ckpt_dir += "_optim_cloth"
|
545 |
+
if self.cfg.optim_body:
|
546 |
+
ckpt_dir += "_optim_body"
|
547 |
+
|
548 |
+
self.export_dir = osp.join(self.cfg.results_path, ckpt_dir, mesh_name)
|
549 |
+
os.makedirs(self.export_dir, exist_ok=True)
|
550 |
+
|
551 |
+
for name in self.in_total:
|
552 |
+
if name in batch.keys():
|
553 |
+
in_tensor_dict.update({name: batch[name]})
|
554 |
+
|
555 |
+
# update the new T_normal_F/B
|
556 |
+
in_tensor_dict.update(
|
557 |
+
self.evaluator.render_normal(
|
558 |
+
batch["smpl_verts"], batch["smpl_faces"])
|
559 |
+
)
|
560 |
+
|
561 |
+
# update the new smpl_vis
|
562 |
+
(xy, z) = batch["smpl_verts"][0].split([2, 1], dim=1)
|
563 |
+
smpl_vis = get_visibility(
|
564 |
+
xy,
|
565 |
+
z,
|
566 |
+
torch.as_tensor(self.smpl_data.faces).type_as(
|
567 |
+
batch["smpl_verts"]).long(),
|
568 |
+
)
|
569 |
+
in_tensor_dict.update({"smpl_vis": smpl_vis.unsqueeze(0)})
|
570 |
+
|
571 |
+
if self.prior_type == "icon":
|
572 |
+
for key in self.icon_keys:
|
573 |
+
in_tensor_dict.update({key: batch[key]})
|
574 |
+
elif self.prior_type == "pamir":
|
575 |
+
for key in self.pamir_keys:
|
576 |
+
in_tensor_dict.update({key: batch[key]})
|
577 |
+
else:
|
578 |
+
pass
|
579 |
+
|
580 |
+
with torch.no_grad():
|
581 |
+
if self.cfg.optim_body:
|
582 |
+
features, inter, in_tensor_dict = self.optim_body(
|
583 |
+
in_tensor_dict, batch)
|
584 |
+
else:
|
585 |
+
features, inter = self.netG.filter(
|
586 |
+
in_tensor_dict, return_inter=True)
|
587 |
+
sdf = self.reconEngine(
|
588 |
+
opt=self.cfg, netG=self.netG, features=features, proj_matrix=None
|
589 |
+
)
|
590 |
+
|
591 |
+
# save inter results
|
592 |
+
image = (
|
593 |
+
in_tensor_dict["image"][0].permute(
|
594 |
+
1, 2, 0).detach().cpu().numpy() + 1.0
|
595 |
+
) * 0.5
|
596 |
+
smpl_F = (
|
597 |
+
in_tensor_dict["T_normal_F"][0].permute(
|
598 |
+
1, 2, 0).detach().cpu().numpy()
|
599 |
+
+ 1.0
|
600 |
+
) * 0.5
|
601 |
+
smpl_B = (
|
602 |
+
in_tensor_dict["T_normal_B"][0].permute(
|
603 |
+
1, 2, 0).detach().cpu().numpy()
|
604 |
+
+ 1.0
|
605 |
+
) * 0.5
|
606 |
+
image_inter = np.concatenate(
|
607 |
+
self.tensor2image(512, inter[0]) + [smpl_F, smpl_B, image], axis=1
|
608 |
+
)
|
609 |
+
Image.fromarray((image_inter * 255.0).astype(np.uint8)).save(
|
610 |
+
osp.join(self.export_dir, f"{mesh_rot}_inter.png")
|
611 |
+
)
|
612 |
+
|
613 |
+
verts_pr, faces_pr = self.reconEngine.export_mesh(sdf)
|
614 |
+
|
615 |
+
if self.clean_mesh_flag:
|
616 |
+
verts_pr, faces_pr = clean_mesh(verts_pr, faces_pr)
|
617 |
+
|
618 |
+
if self.cfg.optim_cloth:
|
619 |
+
verts_pr = self.optim_cloth(verts_pr, faces_pr, inter[0].detach())
|
620 |
+
|
621 |
+
verts_gt = batch["verts"][0]
|
622 |
+
faces_gt = batch["faces"][0]
|
623 |
+
|
624 |
+
self.result_eval.update(
|
625 |
+
{
|
626 |
+
"verts_gt": verts_gt,
|
627 |
+
"faces_gt": faces_gt,
|
628 |
+
"verts_pr": verts_pr,
|
629 |
+
"faces_pr": faces_pr,
|
630 |
+
"recon_size": (self.resolutions[-1] - 1.0),
|
631 |
+
"calib": batch["calib"][0],
|
632 |
+
}
|
633 |
+
)
|
634 |
+
|
635 |
+
self.evaluator.set_mesh(self.result_eval, scale_factor=1.0)
|
636 |
+
self.evaluator.space_transfer()
|
637 |
+
|
638 |
+
chamfer, p2s = self.evaluator.calculate_chamfer_p2s(
|
639 |
+
sampled_points=1000)
|
640 |
+
normal_consist = self.evaluator.calculate_normal_consist(
|
641 |
+
save_demo_img=osp.join(self.export_dir, f"{mesh_rot}_nc.png")
|
642 |
+
)
|
643 |
+
|
644 |
+
test_log = {"chamfer": chamfer, "p2s": p2s, "NC": normal_consist}
|
645 |
+
|
646 |
+
return test_log
|
647 |
+
|
648 |
+
def test_epoch_end(self, outputs):
|
649 |
+
|
650 |
+
# make_test_gif("/".join(self.export_dir.split("/")[:-2]))
|
651 |
+
|
652 |
+
accu_outputs = accumulate(
|
653 |
+
outputs,
|
654 |
+
rot_num=3,
|
655 |
+
split={
|
656 |
+
"thuman2": (0, 5),
|
657 |
+
},
|
658 |
+
)
|
659 |
+
|
660 |
+
print(colored(self.cfg.name, "green"))
|
661 |
+
print(colored(self.cfg.dataset.noise_scale, "green"))
|
662 |
+
|
663 |
+
self.logger.experiment.add_hparams(
|
664 |
+
hparam_dict={"lr_G": self.lr_G, "bsize": self.batch_size},
|
665 |
+
metric_dict=accu_outputs,
|
666 |
+
)
|
667 |
+
|
668 |
+
np.save(
|
669 |
+
osp.join(self.export_dir, "../test_results.npy"),
|
670 |
+
accu_outputs,
|
671 |
+
allow_pickle=True,
|
672 |
+
)
|
673 |
+
|
674 |
+
return accu_outputs
|
675 |
+
|
676 |
+
def tensor2image(self, height, inter):
|
677 |
+
|
678 |
+
all = []
|
679 |
+
for dim in self.in_geo_dim:
|
680 |
+
img = resize(
|
681 |
+
np.tile(
|
682 |
+
((inter[:dim].cpu().numpy() + 1.0) /
|
683 |
+
2.0).transpose(1, 2, 0),
|
684 |
+
(1, 1, int(3 / dim)),
|
685 |
+
),
|
686 |
+
(height, height),
|
687 |
+
anti_aliasing=True,
|
688 |
+
)
|
689 |
+
|
690 |
+
all.append(img)
|
691 |
+
inter = inter[dim:]
|
692 |
+
|
693 |
+
return all
|
694 |
+
|
695 |
+
def render_func(self, in_tensor_dict, dataset="title", idx=0):
|
696 |
+
|
697 |
+
for name in in_tensor_dict.keys():
|
698 |
+
in_tensor_dict[name] = in_tensor_dict[name][0:1]
|
699 |
+
|
700 |
+
self.netG.eval()
|
701 |
+
features, inter = self.netG.filter(in_tensor_dict, return_inter=True)
|
702 |
+
sdf = self.reconEngine(
|
703 |
+
opt=self.cfg, netG=self.netG, features=features, proj_matrix=None
|
704 |
+
)
|
705 |
+
|
706 |
+
if sdf is not None:
|
707 |
+
render = self.reconEngine.display(sdf)
|
708 |
+
|
709 |
+
image_pred = np.flip(render[:, :, ::-1], axis=0)
|
710 |
+
height = image_pred.shape[0]
|
711 |
+
|
712 |
+
image_gt = resize(
|
713 |
+
((in_tensor_dict["image"].cpu().numpy()[0] + 1.0) / 2.0).transpose(
|
714 |
+
1, 2, 0
|
715 |
+
),
|
716 |
+
(height, height),
|
717 |
+
anti_aliasing=True,
|
718 |
+
)
|
719 |
+
image_inter = self.tensor2image(height, inter[0])
|
720 |
+
image = np.concatenate(
|
721 |
+
[image_pred, image_gt] + image_inter, axis=1)
|
722 |
+
|
723 |
+
step_id = self.global_step if dataset == "train" else self.global_step + idx
|
724 |
+
self.logger.experiment.add_image(
|
725 |
+
tag=f"Occupancy-{dataset}/{step_id}",
|
726 |
+
img_tensor=image.transpose(2, 0, 1),
|
727 |
+
global_step=step_id,
|
728 |
+
)
|
729 |
+
|
730 |
+
def test_single(self, batch):
|
731 |
+
|
732 |
+
self.netG.eval()
|
733 |
+
self.netG.training = False
|
734 |
+
in_tensor_dict = {}
|
735 |
+
|
736 |
+
for name in self.in_total:
|
737 |
+
if name in batch.keys():
|
738 |
+
in_tensor_dict.update({name: batch[name]})
|
739 |
+
|
740 |
+
if self.prior_type == "icon":
|
741 |
+
for key in self.icon_keys:
|
742 |
+
in_tensor_dict.update({key: batch[key]})
|
743 |
+
elif self.prior_type == "pamir":
|
744 |
+
for key in self.pamir_keys:
|
745 |
+
in_tensor_dict.update({key: batch[key]})
|
746 |
+
else:
|
747 |
+
pass
|
748 |
+
|
749 |
+
features, inter = self.netG.filter(in_tensor_dict, return_inter=True)
|
750 |
+
sdf = self.reconEngine(
|
751 |
+
opt=self.cfg, netG=self.netG, features=features, proj_matrix=None
|
752 |
+
)
|
753 |
+
|
754 |
+
verts_pr, faces_pr = self.reconEngine.export_mesh(sdf)
|
755 |
+
|
756 |
+
if self.clean_mesh_flag:
|
757 |
+
verts_pr, faces_pr = clean_mesh(verts_pr, faces_pr)
|
758 |
+
|
759 |
+
verts_pr -= (self.resolutions[-1] - 1) / 2.0
|
760 |
+
verts_pr /= (self.resolutions[-1] - 1) / 2.0
|
761 |
+
|
762 |
+
return verts_pr, faces_pr, inter
|
apps/Normal.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from lib.net import NormalNet
|
2 |
+
from lib.common.train_util import *
|
3 |
+
import logging
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
from torch import nn
|
7 |
+
from skimage.transform import resize
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
|
10 |
+
torch.backends.cudnn.benchmark = True
|
11 |
+
|
12 |
+
logging.getLogger("lightning").setLevel(logging.ERROR)
|
13 |
+
|
14 |
+
|
15 |
+
class Normal(pl.LightningModule):
|
16 |
+
def __init__(self, cfg):
|
17 |
+
super(Normal, self).__init__()
|
18 |
+
self.cfg = cfg
|
19 |
+
self.batch_size = self.cfg.batch_size
|
20 |
+
self.lr_N = self.cfg.lr_N
|
21 |
+
|
22 |
+
self.schedulers = []
|
23 |
+
|
24 |
+
self.netG = NormalNet(self.cfg, error_term=nn.SmoothL1Loss())
|
25 |
+
|
26 |
+
self.in_nml = [item[0] for item in cfg.net.in_nml]
|
27 |
+
|
28 |
+
def get_progress_bar_dict(self):
|
29 |
+
tqdm_dict = super().get_progress_bar_dict()
|
30 |
+
if "v_num" in tqdm_dict:
|
31 |
+
del tqdm_dict["v_num"]
|
32 |
+
return tqdm_dict
|
33 |
+
|
34 |
+
# Training related
|
35 |
+
def configure_optimizers(self):
|
36 |
+
|
37 |
+
# set optimizer
|
38 |
+
weight_decay = self.cfg.weight_decay
|
39 |
+
momentum = self.cfg.momentum
|
40 |
+
|
41 |
+
optim_params_N_F = [
|
42 |
+
{"params": self.netG.netF.parameters(), "lr": self.lr_N}]
|
43 |
+
optim_params_N_B = [
|
44 |
+
{"params": self.netG.netB.parameters(), "lr": self.lr_N}]
|
45 |
+
|
46 |
+
optimizer_N_F = torch.optim.Adam(
|
47 |
+
optim_params_N_F, lr=self.lr_N, weight_decay=weight_decay
|
48 |
+
)
|
49 |
+
|
50 |
+
optimizer_N_B = torch.optim.Adam(
|
51 |
+
optim_params_N_B, lr=self.lr_N, weight_decay=weight_decay
|
52 |
+
)
|
53 |
+
|
54 |
+
scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
|
55 |
+
optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
|
56 |
+
)
|
57 |
+
|
58 |
+
scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
|
59 |
+
optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
|
60 |
+
)
|
61 |
+
|
62 |
+
self.schedulers = [scheduler_N_F, scheduler_N_B]
|
63 |
+
optims = [optimizer_N_F, optimizer_N_B]
|
64 |
+
|
65 |
+
return optims, self.schedulers
|
66 |
+
|
67 |
+
def render_func(self, render_tensor):
|
68 |
+
|
69 |
+
height = render_tensor["image"].shape[2]
|
70 |
+
result_list = []
|
71 |
+
|
72 |
+
for name in render_tensor.keys():
|
73 |
+
result_list.append(
|
74 |
+
resize(
|
75 |
+
((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(
|
76 |
+
1, 2, 0
|
77 |
+
),
|
78 |
+
(height, height),
|
79 |
+
anti_aliasing=True,
|
80 |
+
)
|
81 |
+
)
|
82 |
+
result_array = np.concatenate(result_list, axis=1)
|
83 |
+
|
84 |
+
return result_array
|
85 |
+
|
86 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
87 |
+
|
88 |
+
export_cfg(self.logger, self.cfg)
|
89 |
+
|
90 |
+
# retrieve the data
|
91 |
+
in_tensor = {}
|
92 |
+
for name in self.in_nml:
|
93 |
+
in_tensor[name] = batch[name]
|
94 |
+
|
95 |
+
FB_tensor = {"normal_F": batch["normal_F"],
|
96 |
+
"normal_B": batch["normal_B"]}
|
97 |
+
|
98 |
+
self.netG.train()
|
99 |
+
|
100 |
+
preds_F, preds_B = self.netG(in_tensor)
|
101 |
+
error_NF, error_NB = self.netG.get_norm_error(
|
102 |
+
preds_F, preds_B, FB_tensor)
|
103 |
+
|
104 |
+
(opt_nf, opt_nb) = self.optimizers()
|
105 |
+
|
106 |
+
opt_nf.zero_grad()
|
107 |
+
opt_nb.zero_grad()
|
108 |
+
|
109 |
+
self.manual_backward(error_NF, opt_nf)
|
110 |
+
self.manual_backward(error_NB, opt_nb)
|
111 |
+
|
112 |
+
opt_nf.step()
|
113 |
+
opt_nb.step()
|
114 |
+
|
115 |
+
if batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0:
|
116 |
+
|
117 |
+
self.netG.eval()
|
118 |
+
with torch.no_grad():
|
119 |
+
nmlF, nmlB = self.netG(in_tensor)
|
120 |
+
in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
|
121 |
+
result_array = self.render_func(in_tensor)
|
122 |
+
|
123 |
+
self.logger.experiment.add_image(
|
124 |
+
tag=f"Normal-train/{self.global_step}",
|
125 |
+
img_tensor=result_array.transpose(2, 0, 1),
|
126 |
+
global_step=self.global_step,
|
127 |
+
)
|
128 |
+
|
129 |
+
# metrics processing
|
130 |
+
metrics_log = {
|
131 |
+
"train_loss-NF": error_NF.item(),
|
132 |
+
"train_loss-NB": error_NB.item(),
|
133 |
+
}
|
134 |
+
|
135 |
+
tf_log = tf_log_convert(metrics_log)
|
136 |
+
bar_log = bar_log_convert(metrics_log)
|
137 |
+
|
138 |
+
return {
|
139 |
+
"loss": error_NF + error_NB,
|
140 |
+
"loss-NF": error_NF,
|
141 |
+
"loss-NB": error_NB,
|
142 |
+
"log": tf_log,
|
143 |
+
"progress_bar": bar_log,
|
144 |
+
}
|
145 |
+
|
146 |
+
def training_epoch_end(self, outputs):
|
147 |
+
|
148 |
+
if [] in outputs:
|
149 |
+
outputs = outputs[0]
|
150 |
+
|
151 |
+
# metrics processing
|
152 |
+
metrics_log = {
|
153 |
+
"train_avgloss": batch_mean(outputs, "loss"),
|
154 |
+
"train_avgloss-NF": batch_mean(outputs, "loss-NF"),
|
155 |
+
"train_avgloss-NB": batch_mean(outputs, "loss-NB"),
|
156 |
+
}
|
157 |
+
|
158 |
+
tf_log = tf_log_convert(metrics_log)
|
159 |
+
|
160 |
+
tf_log["lr-NF"] = self.schedulers[0].get_last_lr()[0]
|
161 |
+
tf_log["lr-NB"] = self.schedulers[1].get_last_lr()[0]
|
162 |
+
|
163 |
+
return {"log": tf_log}
|
164 |
+
|
165 |
+
def validation_step(self, batch, batch_idx):
|
166 |
+
|
167 |
+
# retrieve the data
|
168 |
+
in_tensor = {}
|
169 |
+
for name in self.in_nml:
|
170 |
+
in_tensor[name] = batch[name]
|
171 |
+
|
172 |
+
FB_tensor = {"normal_F": batch["normal_F"],
|
173 |
+
"normal_B": batch["normal_B"]}
|
174 |
+
|
175 |
+
self.netG.train()
|
176 |
+
|
177 |
+
preds_F, preds_B = self.netG(in_tensor)
|
178 |
+
error_NF, error_NB = self.netG.get_norm_error(
|
179 |
+
preds_F, preds_B, FB_tensor)
|
180 |
+
|
181 |
+
if (batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0) or (
|
182 |
+
batch_idx == 0
|
183 |
+
):
|
184 |
+
|
185 |
+
with torch.no_grad():
|
186 |
+
nmlF, nmlB = self.netG(in_tensor)
|
187 |
+
in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
|
188 |
+
result_array = self.render_func(in_tensor)
|
189 |
+
|
190 |
+
self.logger.experiment.add_image(
|
191 |
+
tag=f"Normal-val/{self.global_step}",
|
192 |
+
img_tensor=result_array.transpose(2, 0, 1),
|
193 |
+
global_step=self.global_step,
|
194 |
+
)
|
195 |
+
|
196 |
+
return {
|
197 |
+
"val_loss": error_NF + error_NB,
|
198 |
+
"val_loss-NF": error_NF,
|
199 |
+
"val_loss-NB": error_NB,
|
200 |
+
}
|
201 |
+
|
202 |
+
def validation_epoch_end(self, outputs):
|
203 |
+
|
204 |
+
# metrics processing
|
205 |
+
metrics_log = {
|
206 |
+
"val_avgloss": batch_mean(outputs, "val_loss"),
|
207 |
+
"val_avgloss-NF": batch_mean(outputs, "val_loss-NF"),
|
208 |
+
"val_avgloss-NB": batch_mean(outputs, "val_loss-NB"),
|
209 |
+
}
|
210 |
+
|
211 |
+
tf_log = tf_log_convert(metrics_log)
|
212 |
+
|
213 |
+
return {"log": tf_log}
|
apps/__pycache__/app.cpython-38.pyc
ADDED
Binary file (555 Bytes). View file
|
|
apps/app.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# install
|
2 |
+
|
3 |
+
import os
|
4 |
+
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
|
5 |
+
os.environ["CUDA_VISIBLE_DEVICES"]="0"
|
6 |
+
try:
|
7 |
+
os.system("bash install.sh")
|
8 |
+
except Exception as e:
|
9 |
+
print(e)
|
10 |
+
|
11 |
+
|
12 |
+
# running
|
13 |
+
|
14 |
+
import gradio as gr
|
15 |
+
|
16 |
+
def image_classifier(inp):
|
17 |
+
return {'cat': 0.3, 'dog': 0.7}
|
18 |
+
|
19 |
+
demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label")
|
20 |
+
demo.launch(auth=("[email protected]", "icon_2022"),
|
21 |
+
auth_message="Register at icon.is.tue.mpg.de/download to get the username and password.")
|
apps/infer.py
ADDED
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
4 |
+
# holder of all proprietary rights on this computer program.
|
5 |
+
# You can only use this computer program if you have closed
|
6 |
+
# a license agreement with MPG or you get the right to use the computer
|
7 |
+
# program from someone who is authorized to grant you that right.
|
8 |
+
# Any use of the computer program without a valid license is prohibited and
|
9 |
+
# liable to prosecution.
|
10 |
+
#
|
11 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
12 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
13 |
+
# for Intelligent Systems. All rights reserved.
|
14 |
+
#
|
15 |
+
# Contact: [email protected]
|
16 |
+
|
17 |
+
import logging
|
18 |
+
from lib.common.render import query_color, image2vid
|
19 |
+
from lib.common.config import cfg
|
20 |
+
from lib.common.cloth_extraction import extract_cloth
|
21 |
+
from lib.dataset.mesh_util import (
|
22 |
+
load_checkpoint,
|
23 |
+
update_mesh_shape_prior_losses,
|
24 |
+
get_optim_grid_image,
|
25 |
+
blend_rgb_norm,
|
26 |
+
unwrap,
|
27 |
+
remesh,
|
28 |
+
tensor2variable,
|
29 |
+
normal_loss
|
30 |
+
)
|
31 |
+
|
32 |
+
from lib.dataset.TestDataset import TestDataset
|
33 |
+
from lib.net.local_affine import LocalAffine
|
34 |
+
from pytorch3d.structures import Meshes
|
35 |
+
from apps.ICON import ICON
|
36 |
+
|
37 |
+
import os
|
38 |
+
from termcolor import colored
|
39 |
+
import argparse
|
40 |
+
import numpy as np
|
41 |
+
from PIL import Image
|
42 |
+
import trimesh
|
43 |
+
import pickle
|
44 |
+
import numpy as np
|
45 |
+
|
46 |
+
import torch
|
47 |
+
torch.backends.cudnn.benchmark = True
|
48 |
+
|
49 |
+
logging.getLogger("trimesh").setLevel(logging.ERROR)
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == "__main__":
|
53 |
+
|
54 |
+
# loading cfg file
|
55 |
+
parser = argparse.ArgumentParser()
|
56 |
+
|
57 |
+
parser.add_argument("-gpu", "--gpu_device", type=int, default=0)
|
58 |
+
parser.add_argument("-colab", action="store_true")
|
59 |
+
parser.add_argument("-loop_smpl", "--loop_smpl", type=int, default=100)
|
60 |
+
parser.add_argument("-patience", "--patience", type=int, default=5)
|
61 |
+
parser.add_argument("-vis_freq", "--vis_freq", type=int, default=10)
|
62 |
+
parser.add_argument("-loop_cloth", "--loop_cloth", type=int, default=200)
|
63 |
+
parser.add_argument("-hps_type", "--hps_type", type=str, default="pymaf")
|
64 |
+
parser.add_argument("-export_video", action="store_true")
|
65 |
+
parser.add_argument("-in_dir", "--in_dir", type=str, default="./examples")
|
66 |
+
parser.add_argument("-out_dir", "--out_dir",
|
67 |
+
type=str, default="./results")
|
68 |
+
parser.add_argument('-seg_dir', '--seg_dir', type=str, default=None)
|
69 |
+
parser.add_argument(
|
70 |
+
"-cfg", "--config", type=str, default="./configs/icon-filter.yaml"
|
71 |
+
)
|
72 |
+
|
73 |
+
args = parser.parse_args()
|
74 |
+
|
75 |
+
# cfg read and merge
|
76 |
+
cfg.merge_from_file(args.config)
|
77 |
+
cfg.merge_from_file("./lib/pymaf/configs/pymaf_config.yaml")
|
78 |
+
|
79 |
+
cfg_show_list = [
|
80 |
+
"test_gpus",
|
81 |
+
[args.gpu_device],
|
82 |
+
"mcube_res",
|
83 |
+
256,
|
84 |
+
"clean_mesh",
|
85 |
+
True,
|
86 |
+
]
|
87 |
+
|
88 |
+
cfg.merge_from_list(cfg_show_list)
|
89 |
+
cfg.freeze()
|
90 |
+
|
91 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
92 |
+
device = torch.device(f"cuda:{args.gpu_device}")
|
93 |
+
|
94 |
+
if args.colab:
|
95 |
+
print(colored("colab environment...", "red"))
|
96 |
+
from tqdm.notebook import tqdm
|
97 |
+
else:
|
98 |
+
print(colored("normal environment...", "red"))
|
99 |
+
from tqdm import tqdm
|
100 |
+
|
101 |
+
# load model and dataloader
|
102 |
+
model = ICON(cfg)
|
103 |
+
model = load_checkpoint(model, cfg)
|
104 |
+
|
105 |
+
dataset_param = {
|
106 |
+
'image_dir': args.in_dir,
|
107 |
+
'seg_dir': args.seg_dir,
|
108 |
+
'has_det': True, # w/ or w/o detection
|
109 |
+
'hps_type': args.hps_type # pymaf/pare/pixie
|
110 |
+
}
|
111 |
+
|
112 |
+
if args.hps_type == "pixie" and "pamir" in args.config:
|
113 |
+
print(colored("PIXIE isn't compatible with PaMIR, thus switch to PyMAF", "red"))
|
114 |
+
dataset_param["hps_type"] = "pymaf"
|
115 |
+
|
116 |
+
dataset = TestDataset(dataset_param, device)
|
117 |
+
|
118 |
+
print(colored(f"Dataset Size: {len(dataset)}", "green"))
|
119 |
+
|
120 |
+
pbar = tqdm(dataset)
|
121 |
+
|
122 |
+
for data in pbar:
|
123 |
+
|
124 |
+
pbar.set_description(f"{data['name']}")
|
125 |
+
|
126 |
+
in_tensor = {"smpl_faces": data["smpl_faces"], "image": data["image"]}
|
127 |
+
|
128 |
+
# The optimizer and variables
|
129 |
+
optimed_pose = torch.tensor(
|
130 |
+
data["body_pose"], device=device, requires_grad=True
|
131 |
+
) # [1,23,3,3]
|
132 |
+
optimed_trans = torch.tensor(
|
133 |
+
data["trans"], device=device, requires_grad=True
|
134 |
+
) # [3]
|
135 |
+
optimed_betas = torch.tensor(
|
136 |
+
data["betas"], device=device, requires_grad=True
|
137 |
+
) # [1,10]
|
138 |
+
optimed_orient = torch.tensor(
|
139 |
+
data["global_orient"], device=device, requires_grad=True
|
140 |
+
) # [1,1,3,3]
|
141 |
+
|
142 |
+
optimizer_smpl = torch.optim.SGD(
|
143 |
+
[optimed_pose, optimed_trans, optimed_betas, optimed_orient],
|
144 |
+
lr=1e-3,
|
145 |
+
momentum=0.9,
|
146 |
+
)
|
147 |
+
scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
148 |
+
optimizer_smpl,
|
149 |
+
mode="min",
|
150 |
+
factor=0.5,
|
151 |
+
verbose=0,
|
152 |
+
min_lr=1e-5,
|
153 |
+
patience=args.patience,
|
154 |
+
)
|
155 |
+
|
156 |
+
losses = {
|
157 |
+
"cloth": {"weight": 1e1, "value": 0.0}, # Cloth: Normal_recon - Normal_pred
|
158 |
+
"stiffness": {"weight": 1e5, "value": 0.0}, # Cloth: [RT]_v1 - [RT]_v2 (v1-edge-v2)
|
159 |
+
"rigid": {"weight": 1e5, "value": 0.0}, # Cloth: det(R) = 1
|
160 |
+
"edge": {"weight": 0, "value": 0.0}, # Cloth: edge length
|
161 |
+
"nc": {"weight": 0, "value": 0.0}, # Cloth: normal consistency
|
162 |
+
"laplacian": {"weight": 1e2, "value": 0.0}, # Cloth: laplacian smoonth
|
163 |
+
"normal": {"weight": 1e0, "value": 0.0}, # Body: Normal_pred - Normal_smpl
|
164 |
+
"silhouette": {"weight": 1e1, "value": 0.0}, # Body: Silhouette_pred - Silhouette_smpl
|
165 |
+
}
|
166 |
+
|
167 |
+
# smpl optimization
|
168 |
+
|
169 |
+
loop_smpl = tqdm(
|
170 |
+
range(args.loop_smpl if cfg.net.prior_type != "pifu" else 1))
|
171 |
+
|
172 |
+
per_data_lst = []
|
173 |
+
|
174 |
+
for i in loop_smpl:
|
175 |
+
|
176 |
+
per_loop_lst = []
|
177 |
+
|
178 |
+
optimizer_smpl.zero_grad()
|
179 |
+
|
180 |
+
if dataset_param["hps_type"] != "pixie":
|
181 |
+
smpl_out = dataset.smpl_model(
|
182 |
+
betas=optimed_betas,
|
183 |
+
body_pose=optimed_pose,
|
184 |
+
global_orient=optimed_orient,
|
185 |
+
pose2rot=False,
|
186 |
+
)
|
187 |
+
|
188 |
+
smpl_verts = ((smpl_out.vertices) +
|
189 |
+
optimed_trans) * data["scale"]
|
190 |
+
else:
|
191 |
+
smpl_verts, _, _ = dataset.smpl_model(
|
192 |
+
shape_params=optimed_betas,
|
193 |
+
expression_params=tensor2variable(data["exp"], device),
|
194 |
+
body_pose=optimed_pose,
|
195 |
+
global_pose=optimed_orient,
|
196 |
+
jaw_pose=tensor2variable(data["jaw_pose"], device),
|
197 |
+
left_hand_pose=tensor2variable(
|
198 |
+
data["left_hand_pose"], device),
|
199 |
+
right_hand_pose=tensor2variable(
|
200 |
+
data["right_hand_pose"], device),
|
201 |
+
)
|
202 |
+
|
203 |
+
smpl_verts = (smpl_verts + optimed_trans) * data["scale"]
|
204 |
+
|
205 |
+
# render optimized mesh (normal, T_normal, image [-1,1])
|
206 |
+
in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
|
207 |
+
smpl_verts *
|
208 |
+
torch.tensor([1.0, -1.0, -1.0]
|
209 |
+
).to(device), in_tensor["smpl_faces"]
|
210 |
+
)
|
211 |
+
T_mask_F, T_mask_B = dataset.render.get_silhouette_image()
|
212 |
+
|
213 |
+
with torch.no_grad():
|
214 |
+
in_tensor["normal_F"], in_tensor["normal_B"] = model.netG.normal_filter(
|
215 |
+
in_tensor
|
216 |
+
)
|
217 |
+
|
218 |
+
diff_F_smpl = torch.abs(
|
219 |
+
in_tensor["T_normal_F"] - in_tensor["normal_F"])
|
220 |
+
diff_B_smpl = torch.abs(
|
221 |
+
in_tensor["T_normal_B"] - in_tensor["normal_B"])
|
222 |
+
|
223 |
+
loss_F_smpl = normal_loss(
|
224 |
+
in_tensor["T_normal_F"], in_tensor["normal_F"])
|
225 |
+
loss_B_smpl = normal_loss(
|
226 |
+
in_tensor["T_normal_B"], in_tensor["normal_B"])
|
227 |
+
|
228 |
+
losses["normal"]["value"] = (loss_F_smpl + loss_B_smpl).mean()
|
229 |
+
|
230 |
+
# silhouette loss
|
231 |
+
smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)[0]
|
232 |
+
gt_arr = torch.cat(
|
233 |
+
[in_tensor["normal_F"][0], in_tensor["normal_B"][0]], dim=2
|
234 |
+
).permute(1, 2, 0)
|
235 |
+
gt_arr = ((gt_arr + 1.0) * 0.5).to(device)
|
236 |
+
bg_color = (
|
237 |
+
torch.Tensor([0.5, 0.5, 0.5]).unsqueeze(
|
238 |
+
0).unsqueeze(0).to(device)
|
239 |
+
)
|
240 |
+
gt_arr = ((gt_arr - bg_color).sum(dim=-1) != 0.0).float()
|
241 |
+
diff_S = torch.abs(smpl_arr - gt_arr)
|
242 |
+
losses["silhouette"]["value"] = diff_S.mean()
|
243 |
+
|
244 |
+
# Weighted sum of the losses
|
245 |
+
smpl_loss = 0.0
|
246 |
+
pbar_desc = "Body Fitting --- "
|
247 |
+
for k in ["normal", "silhouette"]:
|
248 |
+
pbar_desc += f"{k}: {losses[k]['value'] * losses[k]['weight']:.3f} | "
|
249 |
+
smpl_loss += losses[k]["value"] * losses[k]["weight"]
|
250 |
+
pbar_desc += f"Total: {smpl_loss:.3f}"
|
251 |
+
loop_smpl.set_description(pbar_desc)
|
252 |
+
|
253 |
+
if i % args.vis_freq == 0:
|
254 |
+
|
255 |
+
per_loop_lst.extend(
|
256 |
+
[
|
257 |
+
in_tensor["image"],
|
258 |
+
in_tensor["T_normal_F"],
|
259 |
+
in_tensor["normal_F"],
|
260 |
+
diff_F_smpl / 2.0,
|
261 |
+
diff_S[:, :512].unsqueeze(
|
262 |
+
0).unsqueeze(0).repeat(1, 3, 1, 1),
|
263 |
+
]
|
264 |
+
)
|
265 |
+
per_loop_lst.extend(
|
266 |
+
[
|
267 |
+
in_tensor["image"],
|
268 |
+
in_tensor["T_normal_B"],
|
269 |
+
in_tensor["normal_B"],
|
270 |
+
diff_B_smpl / 2.0,
|
271 |
+
diff_S[:, 512:].unsqueeze(
|
272 |
+
0).unsqueeze(0).repeat(1, 3, 1, 1),
|
273 |
+
]
|
274 |
+
)
|
275 |
+
per_data_lst.append(
|
276 |
+
get_optim_grid_image(
|
277 |
+
per_loop_lst, None, nrow=5, type="smpl")
|
278 |
+
)
|
279 |
+
|
280 |
+
smpl_loss.backward()
|
281 |
+
optimizer_smpl.step()
|
282 |
+
scheduler_smpl.step(smpl_loss)
|
283 |
+
in_tensor["smpl_verts"] = smpl_verts * \
|
284 |
+
torch.tensor([1.0, 1.0, -1.0]).to(device)
|
285 |
+
|
286 |
+
# visualize the optimization process
|
287 |
+
# 1. SMPL Fitting
|
288 |
+
# 2. Clothes Refinement
|
289 |
+
|
290 |
+
os.makedirs(os.path.join(args.out_dir, cfg.name,
|
291 |
+
"refinement"), exist_ok=True)
|
292 |
+
|
293 |
+
# visualize the final results in self-rotation mode
|
294 |
+
os.makedirs(os.path.join(args.out_dir, cfg.name, "vid"), exist_ok=True)
|
295 |
+
|
296 |
+
# final results rendered as image
|
297 |
+
# 1. Render the final fitted SMPL (xxx_smpl.png)
|
298 |
+
# 2. Render the final reconstructed clothed human (xxx_cloth.png)
|
299 |
+
# 3. Blend the original image with predicted cloth normal (xxx_overlap.png)
|
300 |
+
|
301 |
+
os.makedirs(os.path.join(args.out_dir, cfg.name, "png"), exist_ok=True)
|
302 |
+
|
303 |
+
# final reconstruction meshes
|
304 |
+
# 1. SMPL mesh (xxx_smpl.obj)
|
305 |
+
# 2. SMPL params (xxx_smpl.npy)
|
306 |
+
# 3. clohted mesh (xxx_recon.obj)
|
307 |
+
# 4. remeshed clothed mesh (xxx_remesh.obj)
|
308 |
+
# 5. refined clothed mesh (xxx_refine.obj)
|
309 |
+
|
310 |
+
os.makedirs(os.path.join(args.out_dir, cfg.name, "obj"), exist_ok=True)
|
311 |
+
|
312 |
+
if cfg.net.prior_type != "pifu":
|
313 |
+
|
314 |
+
per_data_lst[0].save(
|
315 |
+
os.path.join(
|
316 |
+
args.out_dir, cfg.name, f"refinement/{data['name']}_smpl.gif"
|
317 |
+
),
|
318 |
+
save_all=True,
|
319 |
+
append_images=per_data_lst[1:],
|
320 |
+
duration=500,
|
321 |
+
loop=0,
|
322 |
+
)
|
323 |
+
|
324 |
+
if args.vis_freq == 1:
|
325 |
+
image2vid(
|
326 |
+
per_data_lst,
|
327 |
+
os.path.join(
|
328 |
+
args.out_dir, cfg.name, f"refinement/{data['name']}_smpl.avi"
|
329 |
+
),
|
330 |
+
)
|
331 |
+
|
332 |
+
per_data_lst[-1].save(
|
333 |
+
os.path.join(args.out_dir, cfg.name,
|
334 |
+
f"png/{data['name']}_smpl.png")
|
335 |
+
)
|
336 |
+
|
337 |
+
norm_pred = (
|
338 |
+
((in_tensor["normal_F"][0].permute(1, 2, 0) + 1.0) * 255.0 / 2.0)
|
339 |
+
.detach()
|
340 |
+
.cpu()
|
341 |
+
.numpy()
|
342 |
+
.astype(np.uint8)
|
343 |
+
)
|
344 |
+
|
345 |
+
norm_orig = unwrap(norm_pred, data)
|
346 |
+
mask_orig = unwrap(
|
347 |
+
np.repeat(
|
348 |
+
data["mask"].permute(1, 2, 0).detach().cpu().numpy(), 3, axis=2
|
349 |
+
).astype(np.uint8),
|
350 |
+
data,
|
351 |
+
)
|
352 |
+
rgb_norm = blend_rgb_norm(data["ori_image"], norm_orig, mask_orig)
|
353 |
+
|
354 |
+
Image.fromarray(
|
355 |
+
np.concatenate(
|
356 |
+
[data["ori_image"].astype(np.uint8), rgb_norm], axis=1)
|
357 |
+
).save(os.path.join(args.out_dir, cfg.name, f"png/{data['name']}_overlap.png"))
|
358 |
+
|
359 |
+
smpl_obj = trimesh.Trimesh(
|
360 |
+
in_tensor["smpl_verts"].detach().cpu()[0] *
|
361 |
+
torch.tensor([1.0, -1.0, 1.0]),
|
362 |
+
in_tensor['smpl_faces'].detach().cpu()[0],
|
363 |
+
process=False,
|
364 |
+
maintains_order=True
|
365 |
+
)
|
366 |
+
smpl_obj.export(
|
367 |
+
f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl.obj")
|
368 |
+
|
369 |
+
smpl_info = {'betas': optimed_betas,
|
370 |
+
'pose': optimed_pose,
|
371 |
+
'orient': optimed_orient,
|
372 |
+
'trans': optimed_trans}
|
373 |
+
|
374 |
+
np.save(
|
375 |
+
f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl.npy", smpl_info, allow_pickle=True)
|
376 |
+
|
377 |
+
# ------------------------------------------------------------------------------------------------------------------
|
378 |
+
|
379 |
+
# cloth optimization
|
380 |
+
|
381 |
+
per_data_lst = []
|
382 |
+
|
383 |
+
# cloth recon
|
384 |
+
in_tensor.update(
|
385 |
+
dataset.compute_vis_cmap(
|
386 |
+
in_tensor["smpl_verts"][0], in_tensor["smpl_faces"][0]
|
387 |
+
)
|
388 |
+
)
|
389 |
+
|
390 |
+
if cfg.net.prior_type == "pamir":
|
391 |
+
in_tensor.update(
|
392 |
+
dataset.compute_voxel_verts(
|
393 |
+
optimed_pose,
|
394 |
+
optimed_orient,
|
395 |
+
optimed_betas,
|
396 |
+
optimed_trans,
|
397 |
+
data["scale"],
|
398 |
+
)
|
399 |
+
)
|
400 |
+
|
401 |
+
with torch.no_grad():
|
402 |
+
verts_pr, faces_pr, _ = model.test_single(in_tensor)
|
403 |
+
|
404 |
+
recon_obj = trimesh.Trimesh(
|
405 |
+
verts_pr, faces_pr, process=False, maintains_order=True
|
406 |
+
)
|
407 |
+
recon_obj.export(
|
408 |
+
os.path.join(args.out_dir, cfg.name,
|
409 |
+
f"obj/{data['name']}_recon.obj")
|
410 |
+
)
|
411 |
+
|
412 |
+
# Isotropic Explicit Remeshing for better geometry topology
|
413 |
+
verts_refine, faces_refine = remesh(os.path.join(args.out_dir, cfg.name,
|
414 |
+
f"obj/{data['name']}_recon.obj"), 0.5, device)
|
415 |
+
|
416 |
+
# define local_affine deform verts
|
417 |
+
mesh_pr = Meshes(verts_refine, faces_refine).to(device)
|
418 |
+
local_affine_model = LocalAffine(
|
419 |
+
mesh_pr.verts_padded().shape[1], mesh_pr.verts_padded().shape[0], mesh_pr.edges_packed()).to(device)
|
420 |
+
optimizer_cloth = torch.optim.Adam(
|
421 |
+
[{'params': local_affine_model.parameters()}], lr=1e-4, amsgrad=True)
|
422 |
+
|
423 |
+
scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
424 |
+
optimizer_cloth,
|
425 |
+
mode="min",
|
426 |
+
factor=0.1,
|
427 |
+
verbose=0,
|
428 |
+
min_lr=1e-5,
|
429 |
+
patience=args.patience,
|
430 |
+
)
|
431 |
+
|
432 |
+
with torch.no_grad():
|
433 |
+
per_loop_lst = []
|
434 |
+
rotate_recon_lst = dataset.render.get_rgb_image(cam_ids=[
|
435 |
+
0, 1, 2, 3])
|
436 |
+
per_loop_lst.extend(rotate_recon_lst)
|
437 |
+
per_data_lst.append(get_optim_grid_image(
|
438 |
+
per_loop_lst, None, type="cloth"))
|
439 |
+
|
440 |
+
final = None
|
441 |
+
|
442 |
+
if args.loop_cloth > 0:
|
443 |
+
|
444 |
+
loop_cloth = tqdm(range(args.loop_cloth))
|
445 |
+
|
446 |
+
for i in loop_cloth:
|
447 |
+
|
448 |
+
per_loop_lst = []
|
449 |
+
|
450 |
+
optimizer_cloth.zero_grad()
|
451 |
+
|
452 |
+
deformed_verts, stiffness, rigid = local_affine_model(
|
453 |
+
verts_refine.to(device), return_stiff=True)
|
454 |
+
mesh_pr = mesh_pr.update_padded(deformed_verts)
|
455 |
+
|
456 |
+
# losses for laplacian, edge, normal consistency
|
457 |
+
update_mesh_shape_prior_losses(mesh_pr, losses)
|
458 |
+
|
459 |
+
in_tensor["P_normal_F"], in_tensor["P_normal_B"] = dataset.render_normal(
|
460 |
+
mesh_pr.verts_padded(), mesh_pr.faces_padded())
|
461 |
+
|
462 |
+
diff_F_cloth = torch.abs(
|
463 |
+
in_tensor["P_normal_F"] - in_tensor["normal_F"])
|
464 |
+
diff_B_cloth = torch.abs(
|
465 |
+
in_tensor["P_normal_B"] - in_tensor["normal_B"])
|
466 |
+
|
467 |
+
losses["cloth"]["value"] = (diff_F_cloth + diff_B_cloth).mean()
|
468 |
+
losses["stiffness"]["value"] = torch.mean(stiffness)
|
469 |
+
losses["rigid"]["value"] = torch.mean(rigid)
|
470 |
+
|
471 |
+
# Weighted sum of the losses
|
472 |
+
cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
|
473 |
+
pbar_desc = "Cloth Refinement --- "
|
474 |
+
|
475 |
+
for k in losses.keys():
|
476 |
+
if k not in ["normal", "silhouette"] and losses[k]["weight"] > 0.0:
|
477 |
+
cloth_loss = cloth_loss + \
|
478 |
+
losses[k]["value"] * losses[k]["weight"]
|
479 |
+
pbar_desc += f"{k}:{losses[k]['value']* losses[k]['weight']:.5f} | "
|
480 |
+
|
481 |
+
pbar_desc += f"Total: {cloth_loss:.5f}"
|
482 |
+
loop_cloth.set_description(pbar_desc)
|
483 |
+
|
484 |
+
# update params
|
485 |
+
cloth_loss.backward(retain_graph=True)
|
486 |
+
optimizer_cloth.step()
|
487 |
+
scheduler_cloth.step(cloth_loss)
|
488 |
+
|
489 |
+
# for vis
|
490 |
+
with torch.no_grad():
|
491 |
+
if i % args.vis_freq == 0:
|
492 |
+
|
493 |
+
rotate_recon_lst = dataset.render.get_rgb_image(cam_ids=[
|
494 |
+
0, 1, 2, 3])
|
495 |
+
|
496 |
+
per_loop_lst.extend(
|
497 |
+
[
|
498 |
+
in_tensor["image"],
|
499 |
+
in_tensor["P_normal_F"],
|
500 |
+
in_tensor["normal_F"],
|
501 |
+
diff_F_cloth / 2.0,
|
502 |
+
]
|
503 |
+
)
|
504 |
+
per_loop_lst.extend(
|
505 |
+
[
|
506 |
+
in_tensor["image"],
|
507 |
+
in_tensor["P_normal_B"],
|
508 |
+
in_tensor["normal_B"],
|
509 |
+
diff_B_cloth / 2.0,
|
510 |
+
]
|
511 |
+
)
|
512 |
+
per_loop_lst.extend(rotate_recon_lst)
|
513 |
+
per_data_lst.append(
|
514 |
+
get_optim_grid_image(
|
515 |
+
per_loop_lst, None, type="cloth")
|
516 |
+
)
|
517 |
+
|
518 |
+
# gif for optimization
|
519 |
+
per_data_lst[1].save(
|
520 |
+
os.path.join(
|
521 |
+
args.out_dir, cfg.name, f"refinement/{data['name']}_cloth.gif"
|
522 |
+
),
|
523 |
+
save_all=True,
|
524 |
+
append_images=per_data_lst[2:],
|
525 |
+
duration=500,
|
526 |
+
loop=0,
|
527 |
+
)
|
528 |
+
|
529 |
+
if args.vis_freq == 1:
|
530 |
+
image2vid(
|
531 |
+
per_data_lst,
|
532 |
+
os.path.join(
|
533 |
+
args.out_dir, cfg.name, f"refinement/{data['name']}_cloth.avi"
|
534 |
+
),
|
535 |
+
)
|
536 |
+
|
537 |
+
final = trimesh.Trimesh(
|
538 |
+
mesh_pr.verts_packed().detach().squeeze(0).cpu(),
|
539 |
+
mesh_pr.faces_packed().detach().squeeze(0).cpu(),
|
540 |
+
process=False, maintains_order=True
|
541 |
+
)
|
542 |
+
final_colors = query_color(
|
543 |
+
mesh_pr.verts_packed().detach().squeeze(0).cpu(),
|
544 |
+
mesh_pr.faces_packed().detach().squeeze(0).cpu(),
|
545 |
+
in_tensor["image"],
|
546 |
+
device=device,
|
547 |
+
)
|
548 |
+
final.visual.vertex_colors = final_colors
|
549 |
+
final.export(
|
550 |
+
f"{args.out_dir}/{cfg.name}/obj/{data['name']}_refine.obj")
|
551 |
+
|
552 |
+
# always export visualized png regardless of the cloth refinment
|
553 |
+
per_data_lst[-1].save(
|
554 |
+
os.path.join(args.out_dir, cfg.name,
|
555 |
+
f"png/{data['name']}_cloth.png")
|
556 |
+
)
|
557 |
+
|
558 |
+
# always export visualized video regardless of the cloth refinment
|
559 |
+
if args.export_video:
|
560 |
+
if final is not None:
|
561 |
+
verts_lst = [verts_pr, final.vertices]
|
562 |
+
faces_lst = [faces_pr, final.faces]
|
563 |
+
else:
|
564 |
+
verts_lst = [verts_pr]
|
565 |
+
faces_lst = [faces_pr]
|
566 |
+
|
567 |
+
# self-rotated video
|
568 |
+
dataset.render.load_meshes(
|
569 |
+
verts_lst, faces_lst)
|
570 |
+
dataset.render.get_rendered_video(
|
571 |
+
[data["ori_image"], rgb_norm],
|
572 |
+
os.path.join(args.out_dir, cfg.name,
|
573 |
+
f"vid/{data['name']}_cloth.mp4"),
|
574 |
+
)
|
575 |
+
|
576 |
+
# garment extraction from deepfashion images
|
577 |
+
if not (args.seg_dir is None):
|
578 |
+
if final is not None:
|
579 |
+
recon_obj = final.copy()
|
580 |
+
|
581 |
+
os.makedirs(os.path.join(
|
582 |
+
args.out_dir, cfg.name, "clothes"), exist_ok=True)
|
583 |
+
os.makedirs(os.path.join(args.out_dir, cfg.name,
|
584 |
+
"clothes", "info"), exist_ok=True)
|
585 |
+
for seg in data['segmentations']:
|
586 |
+
# These matrices work for PyMaf, not sure about the other hps type
|
587 |
+
K = np.array([[1.0000, 0.0000, 0.0000, 0.0000],
|
588 |
+
[0.0000, 1.0000, 0.0000, 0.0000],
|
589 |
+
[0.0000, 0.0000, -0.5000, 0.0000],
|
590 |
+
[-0.0000, -0.0000, 0.5000, 1.0000]]).T
|
591 |
+
|
592 |
+
R = np.array([[-1., 0., 0.],
|
593 |
+
[0., 1., 0.],
|
594 |
+
[0., 0., -1.]])
|
595 |
+
|
596 |
+
t = np.array([[-0., -0., 100.]])
|
597 |
+
clothing_obj = extract_cloth(recon_obj, seg, K, R, t, smpl_obj)
|
598 |
+
if clothing_obj is not None:
|
599 |
+
cloth_type = seg['type'].replace(' ', '_')
|
600 |
+
cloth_info = {
|
601 |
+
'betas': optimed_betas,
|
602 |
+
'body_pose': optimed_pose,
|
603 |
+
'global_orient': optimed_orient,
|
604 |
+
'pose2rot': False,
|
605 |
+
'clothing_type': cloth_type,
|
606 |
+
}
|
607 |
+
|
608 |
+
file_id = f"{data['name']}_{cloth_type}"
|
609 |
+
with open(os.path.join(args.out_dir, cfg.name, "clothes", "info", f"{file_id}_info.pkl"), 'wb') as fp:
|
610 |
+
pickle.dump(cloth_info, fp)
|
611 |
+
|
612 |
+
clothing_obj.export(os.path.join(
|
613 |
+
args.out_dir, cfg.name, "clothes", f"{file_id}.obj"))
|
614 |
+
else:
|
615 |
+
print(
|
616 |
+
f"Unable to extract clothing of type {seg['type']} from image {data['name']}")
|
assets/garment_teaser.png
ADDED
assets/intermediate_results.png
ADDED
assets/teaser.gif
ADDED
configs/icon-filter.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: icon-filter
|
2 |
+
ckpt_dir: "./data/ckpt/"
|
3 |
+
resume_path: "./data/ckpt/icon-filter.ckpt"
|
4 |
+
normal_path: "./data/ckpt/normal.ckpt"
|
5 |
+
|
6 |
+
test_mode: True
|
7 |
+
batch_size: 1
|
8 |
+
|
9 |
+
net:
|
10 |
+
mlp_dim: [256, 512, 256, 128, 1]
|
11 |
+
res_layers: [2,3,4]
|
12 |
+
num_stack: 2
|
13 |
+
prior_type: "icon" # icon/pamir/icon
|
14 |
+
use_filter: True
|
15 |
+
in_geo: (('normal_F',3), ('normal_B',3))
|
16 |
+
in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
|
17 |
+
smpl_feats: ['sdf', 'norm', 'vis', 'cmap']
|
18 |
+
gtype: 'HGPIFuNet'
|
19 |
+
norm_mlp: 'batch'
|
20 |
+
hourglass_dim: 6
|
21 |
+
smpl_dim: 7
|
22 |
+
|
23 |
+
# user defined
|
24 |
+
mcube_res: 512 # occupancy field resolution, higher --> more details
|
25 |
+
clean_mesh: False # if True, will remove floating pieces
|
configs/icon-nofilter.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: icon-nofilter
|
2 |
+
ckpt_dir: "./data/ckpt/"
|
3 |
+
resume_path: "./data/ckpt/icon-nofilter.ckpt"
|
4 |
+
normal_path: "./data/ckpt/normal.ckpt"
|
5 |
+
|
6 |
+
test_mode: True
|
7 |
+
batch_size: 1
|
8 |
+
|
9 |
+
net:
|
10 |
+
mlp_dim: [256, 512, 256, 128, 1]
|
11 |
+
res_layers: [2,3,4]
|
12 |
+
num_stack: 2
|
13 |
+
prior_type: "icon" # icon/pamir/icon
|
14 |
+
use_filter: False
|
15 |
+
in_geo: (('normal_F',3), ('normal_B',3))
|
16 |
+
in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
|
17 |
+
smpl_feats: ['sdf', 'norm', 'vis', 'cmap']
|
18 |
+
gtype: 'HGPIFuNet'
|
19 |
+
norm_mlp: 'batch'
|
20 |
+
hourglass_dim: 6
|
21 |
+
smpl_dim: 7
|
22 |
+
|
23 |
+
# user defined
|
24 |
+
mcube_res: 512 # occupancy field resolution, higher --> more details
|
25 |
+
clean_mesh: False # if True, will remove floating pieces
|
configs/pamir.yaml
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: pamir
|
2 |
+
ckpt_dir: "./data/ckpt/"
|
3 |
+
resume_path: "./data/ckpt/pamir.ckpt"
|
4 |
+
normal_path: "./data/ckpt/normal.ckpt"
|
5 |
+
|
6 |
+
test_mode: True
|
7 |
+
batch_size: 1
|
8 |
+
|
9 |
+
net:
|
10 |
+
mlp_dim: [256, 512, 256, 128, 1]
|
11 |
+
res_layers: [2,3,4]
|
12 |
+
num_stack: 2
|
13 |
+
prior_type: "pamir" # icon/pamir/icon
|
14 |
+
use_filter: True
|
15 |
+
in_geo: (('image',3), ('normal_F',3), ('normal_B',3))
|
16 |
+
in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
|
17 |
+
gtype: 'HGPIFuNet'
|
18 |
+
norm_mlp: 'batch'
|
19 |
+
hourglass_dim: 6
|
20 |
+
voxel_dim: 7
|
21 |
+
|
22 |
+
# user defined
|
23 |
+
mcube_res: 512 # occupancy field resolution, higher --> more details
|
24 |
+
clean_mesh: False # if True, will remove floating pieces
|
configs/pifu.yaml
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: pifu
|
2 |
+
ckpt_dir: "./data/ckpt/"
|
3 |
+
resume_path: "./data/ckpt/pifu.ckpt"
|
4 |
+
normal_path: "./data/ckpt/normal.ckpt"
|
5 |
+
|
6 |
+
test_mode: True
|
7 |
+
batch_size: 1
|
8 |
+
|
9 |
+
net:
|
10 |
+
mlp_dim: [256, 512, 256, 128, 1]
|
11 |
+
res_layers: [2,3,4]
|
12 |
+
num_stack: 2
|
13 |
+
prior_type: "pifu" # icon/pamir/icon
|
14 |
+
use_filter: True
|
15 |
+
in_geo: (('image',3), ('normal_F',3), ('normal_B',3))
|
16 |
+
in_nml: (('image',3), ('T_normal_F',3), ('T_normal_B',3))
|
17 |
+
gtype: 'HGPIFuNet'
|
18 |
+
norm_mlp: 'batch'
|
19 |
+
hourglass_dim: 12
|
20 |
+
|
21 |
+
|
22 |
+
# user defined
|
23 |
+
mcube_res: 512 # occupancy field resolution, higher --> more details
|
24 |
+
clean_mesh: False # if True, will remove floating pieces
|
environment.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: icon
|
2 |
+
channels:
|
3 |
+
- pytorch-lts
|
4 |
+
- nvidia
|
5 |
+
- conda-forge
|
6 |
+
- fvcore
|
7 |
+
- iopath
|
8 |
+
- bottler
|
9 |
+
- defaults
|
10 |
+
dependencies:
|
11 |
+
- pytorch
|
12 |
+
- torchvision
|
13 |
+
- fvcore
|
14 |
+
- iopath
|
15 |
+
- nvidiacub
|
16 |
+
- pyembree
|
examples/22097467bffc92d4a5c4246f7d4edb75.png
ADDED
examples/44c0f84c957b6b9bdf77662af5bb7078.png
ADDED
examples/5a6a25963db2f667441d5076972c207c.png
ADDED
examples/8da7ceb94669c2f65cbd28022e1f9876.png
ADDED
examples/923d65f767c85a42212cae13fba3750b.png
ADDED
examples/959c4c726a69901ce71b93a9242ed900.png
ADDED
examples/c9856a2bc31846d684cbb965457fad59.png
ADDED
examples/e1e7622af7074a022f5d96dc16672517.png
ADDED
examples/fb9d20fdb93750584390599478ecf86e.png
ADDED
examples/segmentation/003883.jpg
ADDED
examples/segmentation/003883.json
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"item2": {
|
3 |
+
"segmentation": [
|
4 |
+
[
|
5 |
+
232.29572649572654, 34.447388414055126, 237.0364672364673,
|
6 |
+
40.57084520417861, 244.9377018043686, 47.089363722697165,
|
7 |
+
252.04881291547974, 49.65726495726508, 262.5179487179489,
|
8 |
+
51.43504273504287, 269.233998100665, 50.447388414055204,
|
9 |
+
277.5446343779678, 49.12725546058881, 285.64339981006657,
|
10 |
+
46.16429249762584, 294.9273504273506, 41.22602089268754,
|
11 |
+
299.9377967711301, 36.514245014245084, 304.67853751187084,
|
12 |
+
30.588319088319132, 306.0612535612536, 25.65004748338083,
|
13 |
+
307.64150047483383, 23.477207977207982, 311.19705603038943,
|
14 |
+
24.859924026590704, 317.12298195631536, 28.020417853751216,
|
15 |
+
323.04890788224134, 29.008072174738874, 331.34520417853764,
|
16 |
+
30.193257359924065, 339.4439696106365, 34.7364672364673,
|
17 |
+
346.75261158594515, 39.279677113010536, 350.11063627730323,
|
18 |
+
44.61301044634389, 355.00541310541314, 61.422317188983875,
|
19 |
+
358.9560303893638, 77.6198480531815, 362.1165242165243,
|
20 |
+
90.26182336182353, 364.88195631528976, 103.29886039886063,
|
21 |
+
367.6473884140552, 118.11367521367552, 369.42516619183294,
|
22 |
+
129.37293447293484, 369.2324786324788, 132.60550807217476,
|
23 |
+
365.6769230769232, 134.77834757834762, 359.15840455840464,
|
24 |
+
138.3339031339032, 353.43000949667623, 140.70427350427357,
|
25 |
+
351.4547008547009, 141.4943969610637, 351.25716999050337,
|
26 |
+
138.5314339981007, 351.05963912630585, 136.75365622032294,
|
27 |
+
345.7263057929725, 137.34624881291552, 337.8250712250712,
|
28 |
+
139.51908831908838, 331.5040835707502, 141.09933523266864,
|
29 |
+
324.7880341880341, 143.66723646723653, 322.2201329534662,
|
30 |
+
146.43266856600198, 322.2201329534662, 151.5684710351378,
|
31 |
+
323.0102564102563, 160.6548907882243, 324.95185185185176,
|
32 |
+
173.44615384615395, 325.34691358024685, 190.23627730294416,
|
33 |
+
325.93950617283946, 205.64368471035164, 325.93950617283946,
|
34 |
+
215.71775878442577, 325.93950617283946, 220.06343779677147,
|
35 |
+
322.7790123456789, 223.22393162393197, 315.0753086419752,
|
36 |
+
228.55726495726532, 309.34691358024673, 230.53257359924066,
|
37 |
+
290.1866096866098, 230.87929724596398, 263.91500474833805,
|
38 |
+
229.6941120607788, 236.45821462488112, 229.29905033238373,
|
39 |
+
218.48290598290572, 226.73114909781583, 202.65650522317188,
|
40 |
+
224.82811016144353, 197.71823361823357, 221.07502374169044,
|
41 |
+
195.15033238366567, 214.55650522317188, 195.74292497625825,
|
42 |
+
200.53181386514711, 197.125641025641, 180.5811965811964,
|
43 |
+
197.33285849952523, 164.68736942070285, 198.51804368471042,
|
44 |
+
154.21823361823365, 198.51804368471042, 138.61329534662863,
|
45 |
+
193.5797720797721, 136.4404558404558, 185.08594491927823,
|
46 |
+
133.08243114909774, 177.77730294396957, 128.73675213675205,
|
47 |
+
174.41927825261152, 128.53922127255453, 173.82668566001894,
|
48 |
+
133.2799620132953, 174.02421652421646, 136.24292497625825,
|
49 |
+
172.83903133903127, 137.03304843304838, 167.11063627730283,
|
50 |
+
134.86020892687554, 159.9995251661917, 130.51452991452985,
|
51 |
+
159.01187084520404, 129.1318138651471, 159.60446343779662,
|
52 |
+
123.60094966761622, 162.6012345679013, 111.57578347578357,
|
53 |
+
165.95925925925934, 98.53874643874646, 170.30493827160504,
|
54 |
+
82.7362773029439, 173.92307692307693, 70.05584045584048,
|
55 |
+
177.08357075023744, 54.84596391263053, 180.58129154795822,
|
56 |
+
41.73190883190885, 183.14919278252614, 34.423266856600165,
|
57 |
+
188.51623931623936, 30.279962013295354, 195.6273504273505,
|
58 |
+
25.539221272554588, 201.75080721747398, 22.971320037986676,
|
59 |
+
211.23228869895553, 22.37872744539408, 221.10883190883212,
|
60 |
+
20.996011396011355, 224.8619183285852, 20.996011396011355,
|
61 |
+
226.04710351377042, 23.56391263057927, 229.01006647673339,
|
62 |
+
30.279962013295354
|
63 |
+
]
|
64 |
+
],
|
65 |
+
"category_id": 1,
|
66 |
+
"category_name": "short sleeve top"
|
67 |
+
},
|
68 |
+
"item1": {
|
69 |
+
"segmentation": [
|
70 |
+
[
|
71 |
+
201.51804815682925, 224.7401022799914, 218.41555508203712,
|
72 |
+
227.23317707223518, 236.42109524824218, 228.89522693373104,
|
73 |
+
256.91971020669104, 229.44924355422967, 280.188408267633,
|
74 |
+
230.2802684849776, 296.53189857234224, 230.2802684849776,
|
75 |
+
313.7064138077994, 229.72625186447897, 315.32667803111013,
|
76 |
+
236.8076070743661, 317.8197528233539, 240.96273172810572,
|
77 |
+
318.65077775410185, 246.2258896228426, 321.4208608565949,
|
78 |
+
253.15109737907534, 322.8059024078415, 265.0624547197956,
|
79 |
+
324.74496057958663, 273.6497123375242, 325.9612827615598,
|
80 |
+
284.4076070743661, 325.40726614106114, 299.9200724483274,
|
81 |
+
324.29923290006394, 316.8175793735353, 322.0831664180694,
|
82 |
+
325.9588536117625, 320.16803750266354, 336.5366716386107,
|
83 |
+
316.0129128489239, 344.01589601534204, 315.18188791817596,
|
84 |
+
357.86631152780745, 312.4118048156829, 368.1156190070319,
|
85 |
+
308.5336884721926, 378.64193479650567, 306.31762199019806,
|
86 |
+
385.29013424248905, 305.76360536969946, 398.3095248242066,
|
87 |
+
305.48659705945016, 409.6668655444283, 304.94393777967184,
|
88 |
+
419.3418708715109, 302.7278712976774, 427.0981035584915,
|
89 |
+
301.3428297464308, 433.74630300447495, 301.3428297464308,
|
90 |
+
445.3806520349459, 300.5118048156829, 461.72414233965515,
|
91 |
+
299.89735776688684, 467.352311953974, 297.9582995951417,
|
92 |
+
477.60161943319844, 295.1882164926486, 491.7290432559132,
|
93 |
+
293.52616663115276, 497.2692094608994, 291.8641167696569,
|
94 |
+
503.36339228638417, 291.3101001491583, 510.8426166631155,
|
95 |
+
289.37104197741314, 513.8897080758579, 287.4433411463882,
|
96 |
+
519.2043682079693, 283.0112081823993, 519.7583848284679,
|
97 |
+
275.5319838056679, 519.4813765182186, 270.26882591093107,
|
98 |
+
518.096334966972, 265.8366929469421, 513.6642020029831,
|
99 |
+
263.62062646494763, 509.78608565949276, 264.7286597059449,
|
100 |
+
498.9827615597697, 265.2826763264435, 478.76115491157015,
|
101 |
+
266.1137012571914, 467.1268058810992, 266.1137012571914,
|
102 |
+
454.6614319198803, 264.17464308544623, 441.64204133816276,
|
103 |
+
263.06660984444903, 424.19051779245626, 261.5834221180482,
|
104 |
+
407.2581504368212, 259.92137225655233, 396.45482633709815,
|
105 |
+
257.1512891540592, 380.1113360323889, 257.42829746430857,
|
106 |
+
359.05870445344146, 256.8742808438099, 338.56008949499255,
|
107 |
+
256.8742808438099, 321.3855742595354, 254.10419774131685,
|
108 |
+
320.5545493287875, 251.05710632857443, 326.6487321542723,
|
109 |
+
249.39505646707858, 339.1141061154912, 249.11804815682927,
|
110 |
+
356.28862135094835, 248.28702322608135, 372.3551033454083,
|
111 |
+
245.23993181333896, 387.59056040912026, 243.5766673769444,
|
112 |
+
409.1404219049649, 241.91461751544855, 424.92989558917554,
|
113 |
+
240.52957596420202, 440.4423609631369, 238.86752610270617,
|
114 |
+
455.40080971659955, 238.86752610270617, 470.91327509056083,
|
115 |
+
238.31350948220754, 486.42574046452216, 238.81966759002768,
|
116 |
+
501.19639889196685, 239.6506925207756, 511.168698060942,
|
117 |
+
236.0495844875346, 515.6008310249309, 229.40138504155118,
|
118 |
+
519.4789473684212, 221.6451523545705, 520.3099722991692,
|
119 |
+
216.65900277008296, 517.2628808864267, 213.33490304709125,
|
120 |
+
509.50664819944615, 208.3487534626037, 491.50110803324105,
|
121 |
+
205.8556786703599, 475.1576177285318, 203.63961218836545,
|
122 |
+
460.75318559556774, 203.63961218836545, 443.3016620498613,
|
123 |
+
203.63961218836545, 421.9720221606645, 200.59252077562303,
|
124 |
+
415.60083102493036, 197.5052844662264, 406.9847858512679,
|
125 |
+
195.28921798423193, 392.0263370978052, 193.35015981248677,
|
126 |
+
370.97370551885774, 190.857085020243, 343.82689111442545,
|
127 |
+
187.8099936075006, 322.77425953547794, 187.0028979330919,
|
128 |
+
309.89237161730256, 186.17187300234397, 291.33281483059886,
|
129 |
+
188.11093117408916, 266.67907521841033, 191.15802258683155,
|
130 |
+
250.3355849137011, 196.69818879181773, 234.82311953973982
|
131 |
+
]
|
132 |
+
],
|
133 |
+
"category_id": 8,
|
134 |
+
"category_name": "trousers"
|
135 |
+
}
|
136 |
+
}
|
examples/segmentation/028009.jpg
ADDED
examples/segmentation/028009.json
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"item2": {
|
3 |
+
"segmentation": [
|
4 |
+
[
|
5 |
+
314.7474747474744, 204.84848484848482, 328.9696969696967,
|
6 |
+
209.7373737373737, 342.74747474747454, 211.95959595959593,
|
7 |
+
360.0808080808079, 211.07070707070704, 375.19191919191906,
|
8 |
+
210.18181818181816, 384.5252525252524, 207.07070707070704,
|
9 |
+
390.30303030303025, 204.84848484848482, 396.080808080808,
|
10 |
+
201.29292929292924, 402.3030303030303, 204.40404040404036,
|
11 |
+
412.969696969697, 203.9595959595959, 425.8585858585859,
|
12 |
+
206.18181818181813, 434.3030303030304, 211.95959595959593,
|
13 |
+
439.63636363636374, 223.0707070707071, 444.0808080808082,
|
14 |
+
234.18181818181824, 448.52525252525265, 250.62626262626276,
|
15 |
+
449.41414141414157, 260.848484848485, 452.08080808080825,
|
16 |
+
279.0707070707073, 456.08080808080825, 300.84848484848516,
|
17 |
+
457.858585858586, 308.40404040404076, 460.5252525252526,
|
18 |
+
315.7575757575756, 460.96969696969705, 329.97979797979787,
|
19 |
+
460.5252525252526, 345.9797979797979, 456.969696969697,
|
20 |
+
363.75757575757575, 453.41414141414145, 373.5353535353536,
|
21 |
+
450.3030303030303, 385.97979797979804, 447.1919191919192,
|
22 |
+
393.9797979797981, 443.6363636363636, 401.9797979797981,
|
23 |
+
438.3030303030303, 403.7575757575759, 433.85858585858585,
|
24 |
+
401.09090909090924, 430.7474747474747, 393.0909090909092,
|
25 |
+
426.7474747474747, 383.3131313131314, 424.9696969696969,
|
26 |
+
374.8686868686869, 424.9696969696969, 369.0909090909091,
|
27 |
+
423.63636363636357, 363.3131313131313, 423.63636363636357,
|
28 |
+
359.3131313131313, 423.63636363636357, 352.6464646464646,
|
29 |
+
420.9696969696969, 350.86868686868684, 422.74747474747466,
|
30 |
+
345.53535353535347, 422.74747474747466, 340.64646464646455,
|
31 |
+
422.74747474747466, 332.2020202020201, 421.8585858585858,
|
32 |
+
321.53535353535335, 418.74747474747466, 313.0909090909089,
|
33 |
+
416.5252525252524, 306.4242424242422, 412.9696969696969,
|
34 |
+
314.8686868686867, 410.3030303030302, 320.20202020202004,
|
35 |
+
411.6363636363635, 327.3131313131312, 414.74747474747466,
|
36 |
+
336.2020202020201, 418.74747474747466, 351.7575757575757,
|
37 |
+
420.9696969696969, 365.0909090909091, 423.1919191919191,
|
38 |
+
377.0909090909091, 423.1919191919191, 385.0909090909092,
|
39 |
+
424.5252525252525, 398.42424242424255, 396.0808080808079,
|
40 |
+
398.42424242424255, 374.7474747474745, 400.6464646464648,
|
41 |
+
354.7474747474744, 400.6464646464648, 331.6363636363632,
|
42 |
+
400.6464646464648, 313.41414141414094, 400.6464646464648,
|
43 |
+
305.4141414141409, 399.3131313131314, 297.4141414141409,
|
44 |
+
396.6464646464648, 284.525252525252, 396.2020202020203,
|
45 |
+
282.8686868686866, 391.59595959595964, 282.42424242424215,
|
46 |
+
373.81818181818176, 282.42424242424215, 358.26262626262616,
|
47 |
+
281.09090909090884, 334.70707070707056, 281.5353535353533,
|
48 |
+
313.37373737373713, 283.31313131313107, 297.3737373737371,
|
49 |
+
282.8686868686866, 283.1515151515148, 280.6464646464644,
|
50 |
+
266.7070707070703, 271.313131313131, 253.3737373737369,
|
51 |
+
264.6464646464643, 246.70707070707022, 257.5353535353532,
|
52 |
+
239.59595959595907, 249.9797979797976, 228.9292929292924,
|
53 |
+
242.42424242424204, 220.92929292929236, 233.17171717171723,
|
54 |
+
209.01010101010093, 225.1717171717172, 194.78787878787867,
|
55 |
+
222.06060606060606, 185.4545454545453, 224.2828282828283,
|
56 |
+
179.6767676767675, 230.0606060606061, 171.67676767676747,
|
57 |
+
232.72727272727278, 169.89898989898967, 243.83838383838392,
|
58 |
+
167.67676767676744, 256.2828282828284, 165.4545454545452,
|
59 |
+
274.06060606060623, 165.4545454545452, 291.8383838383841,
|
60 |
+
167.67676767676744, 302.5050505050508, 168.1212121212119,
|
61 |
+
310.94949494949526, 177.0101010101008, 314.0606060606064,
|
62 |
+
181.45454545454527, 314.94949494949526, 187.2323232323231,
|
63 |
+
312.7272727272731, 193.01010101010087, 307.8383838383842,
|
64 |
+
191.2323232323231, 302.94949494949526, 193.45454545454533,
|
65 |
+
292.727272727273, 193.45454545454533, 290.50505050505075,
|
66 |
+
195.67676767676755, 287.39393939393966, 197.45454545454533,
|
67 |
+
285.61616161616183, 197.45454545454533, 283.3939393939396,
|
68 |
+
193.89898989898978, 278.94949494949515, 197.45454545454533,
|
69 |
+
274.94949494949515, 199.67676767676755, 279.83838383838406,
|
70 |
+
201.45454545454535, 286.50505050505075, 201.45454545454535,
|
71 |
+
291.8383838383841, 201.8989898989898, 296.2828282828286,
|
72 |
+
202.7878787878787, 303.3939393939397, 202.34343434343424
|
73 |
+
]
|
74 |
+
],
|
75 |
+
"category_id": 2,
|
76 |
+
"category_name": "long sleeve top"
|
77 |
+
},
|
78 |
+
"item1": {
|
79 |
+
"segmentation": [
|
80 |
+
[
|
81 |
+
346.9494949494949, 660.6868686868687, 397.6161616161618,
|
82 |
+
661.5757575757576, 398.06060606060623, 674.0202020202021,
|
83 |
+
398.94949494949515, 691.3535353535356, 397.6161616161618,
|
84 |
+
710.0202020202022, 395.838383838384, 726.0202020202023,
|
85 |
+
393.1717171717173, 742.0202020202023, 346.9494949494949,
|
86 |
+
738.9090909090912, 346.50505050505046, 724.2424242424245,
|
87 |
+
347.3939393939394, 713.5757575757578, 348.72727272727275,
|
88 |
+
706.0202020202022, 349.17171717171715, 686.0202020202022,
|
89 |
+
348.72727272727275, 675.7979797979799, 347.3939393939394,
|
90 |
+
667.7979797979799
|
91 |
+
],
|
92 |
+
[
|
93 |
+
283.71717171717165, 396.68686868686876, 289.9393939393939,
|
94 |
+
396.68686868686876, 303.27272727272725, 397.1313131313132,
|
95 |
+
312.16161616161617, 399.7979797979799, 334.3838383838385,
|
96 |
+
400.68686868686876, 351.7171717171719, 400.68686868686876,
|
97 |
+
361.93939393939417, 401.5757575757577, 376.60606060606085,
|
98 |
+
401.5757575757577, 390.82828282828314, 398.46464646464653,
|
99 |
+
410.3838383838388, 397.5757575757577, 425.0505050505055,
|
100 |
+
394.46464646464653, 431.71717171717216, 422.9090909090911,
|
101 |
+
434.38383838383885, 447.79797979798, 430.38383838383885,
|
102 |
+
478.0202020202024, 423.2727272727277, 507.79797979798025,
|
103 |
+
418.3838383838388, 530.0202020202025, 411.8787878787878,
|
104 |
+
557.3333333333333, 403.43434343434336, 590.6666666666666,
|
105 |
+
400.7676767676767, 611.5555555555557, 399.8787878787878,
|
106 |
+
619.1111111111112, 399.8787878787878, 630.6666666666669,
|
107 |
+
398.10101010101, 635.1111111111113, 399.43434343434336,
|
108 |
+
641.7777777777779, 399.43434343434336, 656.4444444444447,
|
109 |
+
398.10101010101, 662.666666666667, 347.4343434343432, 662.666666666667,
|
110 |
+
346.1010101010098, 637.7777777777779, 347.4343434343432,
|
111 |
+
610.6666666666667, 349.21212121212096, 576.4444444444445,
|
112 |
+
350.98989898989873, 556.4444444444443, 349.6565656565654,
|
113 |
+
541.3333333333331, 348.32323232323205, 535.9999999999998,
|
114 |
+
348.32323232323205, 523.5555555555553, 349.21212121212096,
|
115 |
+
505.33333333333303, 342.5454545454543, 511.5555555555553,
|
116 |
+
338.9898989898987, 516.8888888888887, 334.5454545454542,
|
117 |
+
523.5555555555553, 325.6565656565653, 543.111111111111,
|
118 |
+
319.87878787878753, 556.4444444444443, 314.1010101010097,
|
119 |
+
568.4444444444443, 307.8787878787875, 583.1111111111111,
|
120 |
+
300.3232323232319, 608.0000000000001, 298.10101010100965,
|
121 |
+
617.7777777777778, 298.5454545454541, 624.0000000000001,
|
122 |
+
295.43434343434296, 628.0000000000001, 293.2121212121208,
|
123 |
+
628.0000000000001, 293.6565656565652, 632.4444444444446,
|
124 |
+
291.43434343434296, 638.6666666666669, 290.54545454545405,
|
125 |
+
644.4444444444447, 292.3232323232319, 648.8888888888891,
|
126 |
+
303.8787878787875, 667.1111111111114, 313.65656565656525,
|
127 |
+
684.0000000000003, 319.87878787878753, 700.8888888888893,
|
128 |
+
322.54545454545416, 712.8888888888894, 324.323232323232,
|
129 |
+
720.0000000000005, 327.87878787878753, 731.5555555555561,
|
130 |
+
330.9898989898987, 738.6666666666672, 331.87878787878753,
|
131 |
+
743.1111111111117, 334.5454545454542, 745.7777777777783,
|
132 |
+
336.3232323232325, 749.1313131313133, 338.54545454545473,
|
133 |
+
754.0202020202022, 338.54545454545473, 757.5757575757577,
|
134 |
+
341.6565656565658, 760.6868686868688, 344.76767676767696,
|
135 |
+
767.3535353535356, 345.2121212121214, 770.9090909090911,
|
136 |
+
346.9898989898992, 754.0202020202022, 347.43434343434365,
|
137 |
+
738.909090909091, 393.2121212121216, 740.6868686868687,
|
138 |
+
389.65656565656604, 764.6868686868688, 386.5454545454549,
|
139 |
+
784.2424242424245, 384.3232323232327, 806.9090909090912,
|
140 |
+
382.54545454545485, 812.686868686869, 381.13131313131316,
|
141 |
+
818.7070707070708, 378.020202020202, 828.4848484848485,
|
142 |
+
375.35353535353534, 839.5959595959597, 374.9090909090909,
|
143 |
+
854.2626262626264, 373.1313131313131, 856.9292929292931,
|
144 |
+
376.24242424242425, 864.9292929292931, 372.24242424242425,
|
145 |
+
874.2626262626264, 366.4646464646464, 880.9292929292932,
|
146 |
+
357.13131313131305, 872.9292929292932, 345.13131313131305,
|
147 |
+
868.0404040404043, 337.131313131313, 867.1515151515154,
|
148 |
+
337.131313131313, 856.0404040404042, 338.4646464646463,
|
149 |
+
850.7070707070709, 336.2424242424241, 846.2626262626264,
|
150 |
+
335.3535353535352, 841.3737373737375, 338.4646464646463,
|
151 |
+
827.5959595959597, 342.0202020202019, 815.5959595959596,
|
152 |
+
344.6868686868686, 809.3737373737374, 344.6868686868686,
|
153 |
+
796.4848484848484, 344.6868686868686, 786.7070707070707,
|
154 |
+
346.0202020202019, 779.151515151515, 344.24242424242414,
|
155 |
+
776.0404040404039, 343.3535353535352, 786.2626262626262,
|
156 |
+
342.0202020202019, 796.0404040404039, 338.90909090909076,
|
157 |
+
801.8181818181818, 333.57575757575745, 809.3737373737374,
|
158 |
+
326.02020202020185, 813.8181818181819, 320.242424242424,
|
159 |
+
812.4848484848485, 318.02020202020185, 810.7070707070707,
|
160 |
+
317.13131313131294, 807.1515151515151, 315.79797979797956,
|
161 |
+
803.5959595959596, 313.57575757575734, 799.5959595959596,
|
162 |
+
311.3535353535351, 793.8181818181818, 306.90909090909065,
|
163 |
+
791.1515151515151, 305.57575757575734, 787.5959595959595,
|
164 |
+
304.242424242424, 782.7070707070706, 302.02020202020174,
|
165 |
+
776.4848484848484, 298.90909090909065, 773.8181818181816,
|
166 |
+
294.90909090909065, 771.151515151515, 290.34343434343435,
|
167 |
+
758.909090909091, 284.5656565656566, 742.020202020202,
|
168 |
+
278.78787878787875, 729.5757575757575, 270.3434343434343,
|
169 |
+
713.131313131313, 257.8989898989898, 689.1313131313129,
|
170 |
+
247.2323232323231, 669.1313131313128, 239.23232323232307,
|
171 |
+
657.5757575757573, 233.89898989898973, 642.9090909090905,
|
172 |
+
233.0101010101008, 634.0202020202016, 233.45454545454527,
|
173 |
+
630.0202020202016, 235.23232323232304, 611.7979797979793,
|
174 |
+
241.93939393939402, 583.0707070707073, 245.93939393939405,
|
175 |
+
567.5151515151516, 251.2727272727274, 540.4040404040404,
|
176 |
+
256.1616161616163, 518.6262626262626, 260.60606060606074,
|
177 |
+
501.2929292929292, 263.7171717171719, 493.7373737373736,
|
178 |
+
268.16161616161634, 481.73737373737356, 270.38383838383857,
|
179 |
+
469.73737373737356, 272.6060606060608, 462.18181818181796,
|
180 |
+
276.1616161616164, 457.7373737373735, 276.1616161616164,
|
181 |
+
454.1818181818179, 277.05050505050525, 450.1818181818179,
|
182 |
+
278.828282828283, 433.292929292929, 278.3838383838386,
|
183 |
+
419.0707070707067, 278.828282828283, 417.29292929292893,
|
184 |
+
281.0505050505053, 414.1818181818178, 281.93939393939417,
|
185 |
+
404.8484848484844, 283.71717171717194, 401.2929292929289
|
186 |
+
]
|
187 |
+
],
|
188 |
+
"category_id": 8,
|
189 |
+
"category_name": "trousers"
|
190 |
+
}
|
191 |
+
}
|
examples/slack_trial2-000150.png
ADDED
fetch_data.sh
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; }
|
3 |
+
|
4 |
+
mkdir -p data/smpl_related/models
|
5 |
+
|
6 |
+
# username and password input
|
7 |
+
echo -e "\nYou need to register at https://icon.is.tue.mpg.de/, according to Installation Instruction."
|
8 |
+
read -p "Username (ICON):" username
|
9 |
+
read -p "Password (ICON):" password
|
10 |
+
username=$(urle $username)
|
11 |
+
password=$(urle $password)
|
12 |
+
|
13 |
+
# SMPL (Male, Female)
|
14 |
+
echo -e "\nDownloading SMPL..."
|
15 |
+
wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smpl&sfile=SMPL_python_v.1.0.0.zip&resume=1' -O './data/smpl_related/models/SMPL_python_v.1.0.0.zip' --no-check-certificate --continue
|
16 |
+
unzip data/smpl_related/models/SMPL_python_v.1.0.0.zip -d data/smpl_related/models
|
17 |
+
mv data/smpl_related/models/smpl/models/basicModel_f_lbs_10_207_0_v1.0.0.pkl data/smpl_related/models/smpl/SMPL_FEMALE.pkl
|
18 |
+
mv data/smpl_related/models/smpl/models/basicmodel_m_lbs_10_207_0_v1.0.0.pkl data/smpl_related/models/smpl/SMPL_MALE.pkl
|
19 |
+
cd data/smpl_related/models
|
20 |
+
rm -rf *.zip __MACOSX smpl/models smpl/smpl_webuser
|
21 |
+
cd ../../..
|
22 |
+
|
23 |
+
# SMPL (Neutral, from SMPLIFY)
|
24 |
+
echo -e "\nDownloading SMPLify..."
|
25 |
+
wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smplify&sfile=mpips_smplify_public_v2.zip&resume=1' -O './data/smpl_related/models/mpips_smplify_public_v2.zip' --no-check-certificate --continue
|
26 |
+
unzip data/smpl_related/models/mpips_smplify_public_v2.zip -d data/smpl_related/models
|
27 |
+
mv data/smpl_related/models/smplify_public/code/models/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl data/smpl_related/models/smpl/SMPL_NEUTRAL.pkl
|
28 |
+
cd data/smpl_related/models
|
29 |
+
rm -rf *.zip smplify_public
|
30 |
+
cd ../../..
|
31 |
+
|
32 |
+
# ICON
|
33 |
+
echo -e "\nDownloading ICON..."
|
34 |
+
wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=icon&sfile=icon_data.zip&resume=1' -O './data/icon_data.zip' --no-check-certificate --continue
|
35 |
+
cd data && unzip icon_data.zip
|
36 |
+
mv smpl_data smpl_related/
|
37 |
+
rm -f icon_data.zip
|
38 |
+
cd ..
|
39 |
+
|
40 |
+
function download_for_training () {
|
41 |
+
|
42 |
+
# SMPL-X (optional)
|
43 |
+
echo -e "\nDownloading SMPL-X..."
|
44 |
+
wget --post-data "username=$1&password=$2" 'https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=models_smplx_v1_1.zip&resume=1' -O './data/smpl_related/models/models_smplx_v1_1.zip' --no-check-certificate --continue
|
45 |
+
unzip data/smpl_related/models/models_smplx_v1_1.zip -d data/smpl_related
|
46 |
+
rm -f data/smpl_related/models/models_smplx_v1_1.zip
|
47 |
+
|
48 |
+
# SMIL (optional)
|
49 |
+
echo -e "\nDownloading SMIL..."
|
50 |
+
wget --post-data "username=$1&password=$2" 'https://download.is.tue.mpg.de/download.php?domain=agora&sfile=smpl_kid_template.npy&resume=1' -O './data/smpl_related/models/smpl/smpl_kid_template.npy' --no-check-certificate --continue
|
51 |
+
wget --post-data "username=$1&password=$2" 'https://download.is.tue.mpg.de/download.php?domain=agora&sfile=smplx_kid_template.npy&resume=1' -O './data/smpl_related/models/smplx/smplx_kid_template.npy' --no-check-certificate --continue
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
+
read -p "(optional) Download models used for training (y/n)?" choice
|
56 |
+
case "$choice" in
|
57 |
+
y|Y ) download_for_training $username $password;;
|
58 |
+
n|N ) echo "Great job! Try the demo for now!";;
|
59 |
+
* ) echo "Invalid input! Please use y|Y or n|N";;
|
60 |
+
esac
|
install.sh
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# # conda installation
|
2 |
+
# wget https://repo.anaconda.com/miniconda/Miniconda3-py38_4.10.3-Linux-x86_64.sh
|
3 |
+
# chmod +x Miniconda3-py38_4.10.3-Linux-x86_64.sh
|
4 |
+
# bash Miniconda3-py38_4.10.3-Linux-x86_64.sh -b -f -p /home/user/.local
|
5 |
+
# rm Miniconda3-py38_4.10.3-Linux-x86_64.sh
|
6 |
+
# conda config --env --set always_yes true
|
7 |
+
# conda update -n base -c defaults conda -y
|
8 |
+
|
9 |
+
# # conda environment setup
|
10 |
+
# conda env create -f environment.yaml
|
11 |
+
# conda init bash
|
12 |
+
# source /home/user/.bashrc
|
13 |
+
# source activate icon
|
14 |
+
nvidia-smi
|
15 |
+
pip install torch==1.8.2 torchvision==0.9.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cu111
|
16 |
+
pip install -r requirement.txt
|
lib/__init__.py
ADDED
File without changes
|
lib/common/__init__.py
ADDED
File without changes
|
lib/common/cloth_extraction.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import itertools
|
5 |
+
import trimesh
|
6 |
+
from matplotlib.path import Path
|
7 |
+
from collections import Counter
|
8 |
+
from sklearn.neighbors import KNeighborsClassifier
|
9 |
+
|
10 |
+
|
11 |
+
def load_segmentation(path, shape):
|
12 |
+
"""
|
13 |
+
Get a segmentation mask for a given image
|
14 |
+
Arguments:
|
15 |
+
path: path to the segmentation json file
|
16 |
+
shape: shape of the output mask
|
17 |
+
Returns:
|
18 |
+
Returns a segmentation mask
|
19 |
+
"""
|
20 |
+
with open(path) as json_file:
|
21 |
+
dict = json.load(json_file)
|
22 |
+
segmentations = []
|
23 |
+
for key, val in dict.items():
|
24 |
+
if not key.startswith('item'):
|
25 |
+
continue
|
26 |
+
|
27 |
+
# Each item can have multiple polygons. Combine them to one
|
28 |
+
# segmentation_coord = list(itertools.chain.from_iterable(val['segmentation']))
|
29 |
+
# segmentation_coord = np.round(np.array(segmentation_coord)).astype(int)
|
30 |
+
|
31 |
+
coordinates = []
|
32 |
+
for segmentation_coord in val['segmentation']:
|
33 |
+
# The format before is [x1,y1, x2, y2, ....]
|
34 |
+
x = segmentation_coord[::2]
|
35 |
+
y = segmentation_coord[1::2]
|
36 |
+
xy = np.vstack((x, y)).T
|
37 |
+
coordinates.append(xy)
|
38 |
+
|
39 |
+
segmentations.append(
|
40 |
+
{'type': val['category_name'], 'type_id': val['category_id'], 'coordinates': coordinates})
|
41 |
+
|
42 |
+
return segmentations
|
43 |
+
|
44 |
+
|
45 |
+
def smpl_to_recon_labels(recon, smpl, k=1):
|
46 |
+
"""
|
47 |
+
Get the bodypart labels for the recon object by using the labels from the corresponding smpl object
|
48 |
+
Arguments:
|
49 |
+
recon: trimesh object (fully clothed model)
|
50 |
+
shape: trimesh object (smpl model)
|
51 |
+
k: number of nearest neighbours to use
|
52 |
+
Returns:
|
53 |
+
Returns a dictionary containing the bodypart and the corresponding indices
|
54 |
+
"""
|
55 |
+
smpl_vert_segmentation = json.load(
|
56 |
+
open(os.path.join(os.path.dirname(__file__), 'smpl_vert_segmentation.json')))
|
57 |
+
n = smpl.vertices.shape[0]
|
58 |
+
y = np.array([None] * n)
|
59 |
+
for key, val in smpl_vert_segmentation.items():
|
60 |
+
y[val] = key
|
61 |
+
|
62 |
+
classifier = KNeighborsClassifier(n_neighbors=1)
|
63 |
+
classifier.fit(smpl.vertices, y)
|
64 |
+
|
65 |
+
y_pred = classifier.predict(recon.vertices)
|
66 |
+
|
67 |
+
recon_labels = {}
|
68 |
+
for key in smpl_vert_segmentation.keys():
|
69 |
+
recon_labels[key] = list(np.argwhere(
|
70 |
+
y_pred == key).flatten().astype(int))
|
71 |
+
|
72 |
+
return recon_labels
|
73 |
+
|
74 |
+
|
75 |
+
def extract_cloth(recon, segmentation, K, R, t, smpl=None):
|
76 |
+
"""
|
77 |
+
Extract a portion of a mesh using 2d segmentation coordinates
|
78 |
+
Arguments:
|
79 |
+
recon: fully clothed mesh
|
80 |
+
seg_coord: segmentation coordinates in 2D (NDC)
|
81 |
+
K: intrinsic matrix of the projection
|
82 |
+
R: rotation matrix of the projection
|
83 |
+
t: translation vector of the projection
|
84 |
+
Returns:
|
85 |
+
Returns a submesh using the segmentation coordinates
|
86 |
+
"""
|
87 |
+
seg_coord = segmentation['coord_normalized']
|
88 |
+
mesh = trimesh.Trimesh(recon.vertices, recon.faces)
|
89 |
+
extrinsic = np.zeros((3, 4))
|
90 |
+
extrinsic[:3, :3] = R
|
91 |
+
extrinsic[:, 3] = t
|
92 |
+
P = K[:3, :3] @ extrinsic
|
93 |
+
|
94 |
+
P_inv = np.linalg.pinv(P)
|
95 |
+
|
96 |
+
# Each segmentation can contain multiple polygons
|
97 |
+
# We need to check them separately
|
98 |
+
points_so_far = []
|
99 |
+
faces = recon.faces
|
100 |
+
for polygon in seg_coord:
|
101 |
+
n = len(polygon)
|
102 |
+
coords_h = np.hstack((polygon, np.ones((n, 1))))
|
103 |
+
# Apply the inverse projection on homogeneus 2D coordinates to get the corresponding 3d Coordinates
|
104 |
+
XYZ = P_inv @ coords_h[:, :, None]
|
105 |
+
XYZ = XYZ.reshape((XYZ.shape[0], XYZ.shape[1]))
|
106 |
+
XYZ = XYZ[:, :3] / XYZ[:, 3, None]
|
107 |
+
|
108 |
+
p = Path(XYZ[:, :2])
|
109 |
+
|
110 |
+
grid = p.contains_points(recon.vertices[:, :2])
|
111 |
+
indeces = np.argwhere(grid == True)
|
112 |
+
points_so_far += list(indeces.flatten())
|
113 |
+
|
114 |
+
if smpl is not None:
|
115 |
+
num_verts = recon.vertices.shape[0]
|
116 |
+
recon_labels = smpl_to_recon_labels(recon, smpl)
|
117 |
+
body_parts_to_remove = ['rightHand', 'leftToeBase', 'leftFoot', 'rightFoot', 'head',
|
118 |
+
'leftHandIndex1', 'rightHandIndex1', 'rightToeBase', 'leftHand', 'rightHand']
|
119 |
+
type = segmentation['type_id']
|
120 |
+
|
121 |
+
# Remove additional bodyparts that are most likely not part of the segmentation but might intersect (e.g. hand in front of torso)
|
122 |
+
# https://github.com/switchablenorms/DeepFashion2
|
123 |
+
# Short sleeve clothes
|
124 |
+
if type == 1 or type == 3 or type == 10:
|
125 |
+
body_parts_to_remove += ['leftForeArm', 'rightForeArm']
|
126 |
+
# No sleeves at all or lower body clothes
|
127 |
+
elif type == 5 or type == 6 or type == 12 or type == 13 or type == 8 or type == 9:
|
128 |
+
body_parts_to_remove += ['leftForeArm',
|
129 |
+
'rightForeArm', 'leftArm', 'rightArm']
|
130 |
+
# Shorts
|
131 |
+
elif type == 7:
|
132 |
+
body_parts_to_remove += ['leftLeg', 'rightLeg',
|
133 |
+
'leftForeArm', 'rightForeArm', 'leftArm', 'rightArm']
|
134 |
+
|
135 |
+
verts_to_remove = list(itertools.chain.from_iterable(
|
136 |
+
[recon_labels[part] for part in body_parts_to_remove]))
|
137 |
+
|
138 |
+
label_mask = np.zeros(num_verts, dtype=bool)
|
139 |
+
label_mask[verts_to_remove] = True
|
140 |
+
|
141 |
+
seg_mask = np.zeros(num_verts, dtype=bool)
|
142 |
+
seg_mask[points_so_far] = True
|
143 |
+
|
144 |
+
# Remove points that belong to other bodyparts
|
145 |
+
# If a vertice in pointsSoFar is included in the bodyparts to remove, then these points should be removed
|
146 |
+
extra_verts_to_remove = np.array(list(seg_mask) and list(label_mask))
|
147 |
+
|
148 |
+
combine_mask = np.zeros(num_verts, dtype=bool)
|
149 |
+
combine_mask[points_so_far] = True
|
150 |
+
combine_mask[extra_verts_to_remove] = False
|
151 |
+
|
152 |
+
all_indices = np.argwhere(combine_mask == True).flatten()
|
153 |
+
|
154 |
+
i_x = np.where(np.in1d(faces[:, 0], all_indices))[0]
|
155 |
+
i_y = np.where(np.in1d(faces[:, 1], all_indices))[0]
|
156 |
+
i_z = np.where(np.in1d(faces[:, 2], all_indices))[0]
|
157 |
+
|
158 |
+
faces_to_keep = np.array(list(set(i_x).union(i_y).union(i_z)))
|
159 |
+
mask = np.zeros(len(recon.faces), dtype=bool)
|
160 |
+
if len(faces_to_keep) > 0:
|
161 |
+
mask[faces_to_keep] = True
|
162 |
+
|
163 |
+
mesh.update_faces(mask)
|
164 |
+
mesh.remove_unreferenced_vertices()
|
165 |
+
|
166 |
+
# mesh.rezero()
|
167 |
+
|
168 |
+
return mesh
|
169 |
+
|
170 |
+
return None
|
lib/common/config.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
from yacs.config import CfgNode as CN
|
19 |
+
import os
|
20 |
+
|
21 |
+
_C = CN(new_allowed=True)
|
22 |
+
|
23 |
+
# needed by trainer
|
24 |
+
_C.name = 'default'
|
25 |
+
_C.gpus = [0]
|
26 |
+
_C.test_gpus = [1]
|
27 |
+
_C.root = "./data/"
|
28 |
+
_C.ckpt_dir = './data/ckpt/'
|
29 |
+
_C.resume_path = ''
|
30 |
+
_C.normal_path = ''
|
31 |
+
_C.corr_path = ''
|
32 |
+
_C.results_path = './data/results/'
|
33 |
+
_C.projection_mode = 'orthogonal'
|
34 |
+
_C.num_views = 1
|
35 |
+
_C.sdf = False
|
36 |
+
_C.sdf_clip = 5.0
|
37 |
+
|
38 |
+
_C.lr_G = 1e-3
|
39 |
+
_C.lr_C = 1e-3
|
40 |
+
_C.lr_N = 2e-4
|
41 |
+
_C.weight_decay = 0.0
|
42 |
+
_C.momentum = 0.0
|
43 |
+
_C.optim = 'RMSprop'
|
44 |
+
_C.schedule = [5, 10, 15]
|
45 |
+
_C.gamma = 0.1
|
46 |
+
|
47 |
+
_C.overfit = False
|
48 |
+
_C.resume = False
|
49 |
+
_C.test_mode = False
|
50 |
+
_C.test_uv = False
|
51 |
+
_C.draw_geo_thres = 0.60
|
52 |
+
_C.num_sanity_val_steps = 2
|
53 |
+
_C.fast_dev = 0
|
54 |
+
_C.get_fit = False
|
55 |
+
_C.agora = False
|
56 |
+
_C.optim_cloth = False
|
57 |
+
_C.optim_body = False
|
58 |
+
_C.mcube_res = 256
|
59 |
+
_C.clean_mesh = True
|
60 |
+
_C.remesh = False
|
61 |
+
|
62 |
+
_C.batch_size = 4
|
63 |
+
_C.num_threads = 8
|
64 |
+
|
65 |
+
_C.num_epoch = 10
|
66 |
+
_C.freq_plot = 0.01
|
67 |
+
_C.freq_show_train = 0.1
|
68 |
+
_C.freq_show_val = 0.2
|
69 |
+
_C.freq_eval = 0.5
|
70 |
+
_C.accu_grad_batch = 4
|
71 |
+
|
72 |
+
_C.test_items = ['sv', 'mv', 'mv-fusion', 'hybrid', 'dc-pred', 'gt']
|
73 |
+
|
74 |
+
_C.net = CN()
|
75 |
+
_C.net.gtype = 'HGPIFuNet'
|
76 |
+
_C.net.ctype = 'resnet18'
|
77 |
+
_C.net.classifierIMF = 'MultiSegClassifier'
|
78 |
+
_C.net.netIMF = 'resnet18'
|
79 |
+
_C.net.norm = 'group'
|
80 |
+
_C.net.norm_mlp = 'group'
|
81 |
+
_C.net.norm_color = 'group'
|
82 |
+
_C.net.hg_down = 'ave_pool'
|
83 |
+
_C.net.num_views = 1
|
84 |
+
|
85 |
+
# kernel_size, stride, dilation, padding
|
86 |
+
|
87 |
+
_C.net.conv1 = [7, 2, 1, 3]
|
88 |
+
_C.net.conv3x3 = [3, 1, 1, 1]
|
89 |
+
|
90 |
+
_C.net.num_stack = 4
|
91 |
+
_C.net.num_hourglass = 2
|
92 |
+
_C.net.hourglass_dim = 256
|
93 |
+
_C.net.voxel_dim = 32
|
94 |
+
_C.net.resnet_dim = 120
|
95 |
+
_C.net.mlp_dim = [320, 1024, 512, 256, 128, 1]
|
96 |
+
_C.net.mlp_dim_knn = [320, 1024, 512, 256, 128, 3]
|
97 |
+
_C.net.mlp_dim_color = [513, 1024, 512, 256, 128, 3]
|
98 |
+
_C.net.mlp_dim_multiseg = [1088, 2048, 1024, 500]
|
99 |
+
_C.net.res_layers = [2, 3, 4]
|
100 |
+
_C.net.filter_dim = 256
|
101 |
+
_C.net.smpl_dim = 3
|
102 |
+
|
103 |
+
_C.net.cly_dim = 3
|
104 |
+
_C.net.soft_dim = 64
|
105 |
+
_C.net.z_size = 200.0
|
106 |
+
_C.net.N_freqs = 10
|
107 |
+
_C.net.geo_w = 0.1
|
108 |
+
_C.net.norm_w = 0.1
|
109 |
+
_C.net.dc_w = 0.1
|
110 |
+
_C.net.C_cat_to_G = False
|
111 |
+
|
112 |
+
_C.net.skip_hourglass = True
|
113 |
+
_C.net.use_tanh = True
|
114 |
+
_C.net.soft_onehot = True
|
115 |
+
_C.net.no_residual = True
|
116 |
+
_C.net.use_attention = False
|
117 |
+
|
118 |
+
_C.net.prior_type = "sdf"
|
119 |
+
_C.net.smpl_feats = ['sdf', 'cmap', 'norm', 'vis']
|
120 |
+
_C.net.use_filter = True
|
121 |
+
_C.net.use_cc = False
|
122 |
+
_C.net.use_PE = False
|
123 |
+
_C.net.use_IGR = False
|
124 |
+
_C.net.in_geo = ()
|
125 |
+
_C.net.in_nml = ()
|
126 |
+
|
127 |
+
_C.dataset = CN()
|
128 |
+
_C.dataset.root = ''
|
129 |
+
_C.dataset.set_splits = [0.95, 0.04]
|
130 |
+
_C.dataset.types = [
|
131 |
+
"3dpeople", "axyz", "renderpeople", "renderpeople_p27", "humanalloy"
|
132 |
+
]
|
133 |
+
_C.dataset.scales = [1.0, 100.0, 1.0, 1.0, 100.0 / 39.37]
|
134 |
+
_C.dataset.rp_type = "pifu900"
|
135 |
+
_C.dataset.th_type = 'train'
|
136 |
+
_C.dataset.input_size = 512
|
137 |
+
_C.dataset.rotation_num = 3
|
138 |
+
_C.dataset.num_precomp = 10 # Number of segmentation classifiers
|
139 |
+
_C.dataset.num_multiseg = 500 # Number of categories per classifier
|
140 |
+
_C.dataset.num_knn = 10 # for loss/error
|
141 |
+
_C.dataset.num_knn_dis = 20 # for accuracy
|
142 |
+
_C.dataset.num_verts_max = 20000
|
143 |
+
_C.dataset.zray_type = False
|
144 |
+
_C.dataset.online_smpl = False
|
145 |
+
_C.dataset.noise_type = ['z-trans', 'pose', 'beta']
|
146 |
+
_C.dataset.noise_scale = [0.0, 0.0, 0.0]
|
147 |
+
_C.dataset.num_sample_geo = 10000
|
148 |
+
_C.dataset.num_sample_color = 0
|
149 |
+
_C.dataset.num_sample_seg = 0
|
150 |
+
_C.dataset.num_sample_knn = 10000
|
151 |
+
|
152 |
+
_C.dataset.sigma_geo = 5.0
|
153 |
+
_C.dataset.sigma_color = 0.10
|
154 |
+
_C.dataset.sigma_seg = 0.10
|
155 |
+
_C.dataset.thickness_threshold = 20.0
|
156 |
+
_C.dataset.ray_sample_num = 2
|
157 |
+
_C.dataset.semantic_p = False
|
158 |
+
_C.dataset.remove_outlier = False
|
159 |
+
|
160 |
+
_C.dataset.train_bsize = 1.0
|
161 |
+
_C.dataset.val_bsize = 1.0
|
162 |
+
_C.dataset.test_bsize = 1.0
|
163 |
+
|
164 |
+
|
165 |
+
def get_cfg_defaults():
|
166 |
+
"""Get a yacs CfgNode object with default values for my_project."""
|
167 |
+
# Return a clone so that the defaults will not be altered
|
168 |
+
# This is for the "local variable" use pattern
|
169 |
+
return _C.clone()
|
170 |
+
|
171 |
+
|
172 |
+
# Alternatively, provide a way to import the defaults as
|
173 |
+
# a global singleton:
|
174 |
+
cfg = _C # users can `from config import cfg`
|
175 |
+
|
176 |
+
# cfg = get_cfg_defaults()
|
177 |
+
# cfg.merge_from_file('./configs/example.yaml')
|
178 |
+
|
179 |
+
# # Now override from a list (opts could come from the command line)
|
180 |
+
# opts = ['dataset.root', './data/XXXX', 'learning_rate', '1e-2']
|
181 |
+
# cfg.merge_from_list(opts)
|
182 |
+
|
183 |
+
|
184 |
+
def update_cfg(cfg_file):
|
185 |
+
# cfg = get_cfg_defaults()
|
186 |
+
_C.merge_from_file(cfg_file)
|
187 |
+
# return cfg.clone()
|
188 |
+
return _C
|
189 |
+
|
190 |
+
|
191 |
+
def parse_args(args):
|
192 |
+
cfg_file = args.cfg_file
|
193 |
+
if args.cfg_file is not None:
|
194 |
+
cfg = update_cfg(args.cfg_file)
|
195 |
+
else:
|
196 |
+
cfg = get_cfg_defaults()
|
197 |
+
|
198 |
+
# if args.misc is not None:
|
199 |
+
# cfg.merge_from_list(args.misc)
|
200 |
+
|
201 |
+
return cfg
|
202 |
+
|
203 |
+
|
204 |
+
def parse_args_extend(args):
|
205 |
+
if args.resume:
|
206 |
+
if not os.path.exists(args.log_dir):
|
207 |
+
raise ValueError(
|
208 |
+
'Experiment are set to resume mode, but log directory does not exist.'
|
209 |
+
)
|
210 |
+
|
211 |
+
# load log's cfg
|
212 |
+
cfg_file = os.path.join(args.log_dir, 'cfg.yaml')
|
213 |
+
cfg = update_cfg(cfg_file)
|
214 |
+
|
215 |
+
if args.misc is not None:
|
216 |
+
cfg.merge_from_list(args.misc)
|
217 |
+
else:
|
218 |
+
parse_args(args)
|
lib/common/render.py
ADDED
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
4 |
+
# holder of all proprietary rights on this computer program.
|
5 |
+
# You can only use this computer program if you have closed
|
6 |
+
# a license agreement with MPG or you get the right to use the computer
|
7 |
+
# program from someone who is authorized to grant you that right.
|
8 |
+
# Any use of the computer program without a valid license is prohibited and
|
9 |
+
# liable to prosecution.
|
10 |
+
#
|
11 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
12 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
13 |
+
# for Intelligent Systems. All rights reserved.
|
14 |
+
#
|
15 |
+
# Contact: [email protected]
|
16 |
+
|
17 |
+
from pytorch3d.renderer import (
|
18 |
+
BlendParams,
|
19 |
+
blending,
|
20 |
+
look_at_view_transform,
|
21 |
+
FoVOrthographicCameras,
|
22 |
+
PointLights,
|
23 |
+
RasterizationSettings,
|
24 |
+
PointsRasterizationSettings,
|
25 |
+
PointsRenderer,
|
26 |
+
AlphaCompositor,
|
27 |
+
PointsRasterizer,
|
28 |
+
MeshRenderer,
|
29 |
+
MeshRasterizer,
|
30 |
+
SoftPhongShader,
|
31 |
+
SoftSilhouetteShader,
|
32 |
+
TexturesVertex,
|
33 |
+
)
|
34 |
+
from pytorch3d.renderer.mesh import TexturesVertex
|
35 |
+
from pytorch3d.structures import Meshes
|
36 |
+
from lib.dataset.mesh_util import SMPLX, get_visibility
|
37 |
+
|
38 |
+
import lib.common.render_utils as util
|
39 |
+
import torch
|
40 |
+
import numpy as np
|
41 |
+
from PIL import Image
|
42 |
+
from tqdm import tqdm
|
43 |
+
import os
|
44 |
+
import cv2
|
45 |
+
import math
|
46 |
+
from termcolor import colored
|
47 |
+
|
48 |
+
|
49 |
+
def image2vid(images, vid_path):
|
50 |
+
|
51 |
+
w, h = images[0].size
|
52 |
+
videodims = (w, h)
|
53 |
+
fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
54 |
+
video = cv2.VideoWriter(vid_path, fourcc, 30, videodims)
|
55 |
+
for image in images:
|
56 |
+
video.write(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
|
57 |
+
video.release()
|
58 |
+
|
59 |
+
|
60 |
+
def query_color(verts, faces, image, device):
|
61 |
+
"""query colors from points and image
|
62 |
+
|
63 |
+
Args:
|
64 |
+
verts ([B, 3]): [query verts]
|
65 |
+
faces ([M, 3]): [query faces]
|
66 |
+
image ([B, 3, H, W]): [full image]
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
[np.float]: [return colors]
|
70 |
+
"""
|
71 |
+
|
72 |
+
verts = verts.float().to(device)
|
73 |
+
faces = faces.long().to(device)
|
74 |
+
|
75 |
+
(xy, z) = verts.split([2, 1], dim=1)
|
76 |
+
visibility = get_visibility(xy, z, faces[:, [0, 2, 1]]).flatten()
|
77 |
+
uv = xy.unsqueeze(0).unsqueeze(2) # [B, N, 2]
|
78 |
+
uv = uv * torch.tensor([1.0, -1.0]).type_as(uv)
|
79 |
+
colors = (torch.nn.functional.grid_sample(image, uv, align_corners=True)[
|
80 |
+
0, :, :, 0].permute(1, 0) + 1.0) * 0.5 * 255.0
|
81 |
+
colors[visibility == 0.0] = ((Meshes(verts.unsqueeze(0), faces.unsqueeze(
|
82 |
+
0)).verts_normals_padded().squeeze(0) + 1.0) * 0.5 * 255.0)[visibility == 0.0]
|
83 |
+
|
84 |
+
return colors.detach().cpu()
|
85 |
+
|
86 |
+
|
87 |
+
class cleanShader(torch.nn.Module):
|
88 |
+
def __init__(self, device="cpu", cameras=None, blend_params=None):
|
89 |
+
super().__init__()
|
90 |
+
self.cameras = cameras
|
91 |
+
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
92 |
+
|
93 |
+
def forward(self, fragments, meshes, **kwargs):
|
94 |
+
cameras = kwargs.get("cameras", self.cameras)
|
95 |
+
if cameras is None:
|
96 |
+
msg = "Cameras must be specified either at initialization \
|
97 |
+
or in the forward pass of TexturedSoftPhongShader"
|
98 |
+
|
99 |
+
raise ValueError(msg)
|
100 |
+
|
101 |
+
# get renderer output
|
102 |
+
blend_params = kwargs.get("blend_params", self.blend_params)
|
103 |
+
texels = meshes.sample_textures(fragments)
|
104 |
+
images = blending.softmax_rgb_blend(
|
105 |
+
texels, fragments, blend_params, znear=-256, zfar=256
|
106 |
+
)
|
107 |
+
|
108 |
+
return images
|
109 |
+
|
110 |
+
|
111 |
+
class Render:
|
112 |
+
def __init__(self, size=512, device=torch.device("cuda:0")):
|
113 |
+
self.device = device
|
114 |
+
self.mesh_y_center = 100.0
|
115 |
+
self.dis = 100.0
|
116 |
+
self.scale = 1.0
|
117 |
+
self.size = size
|
118 |
+
self.cam_pos = [(0, 100, 100)]
|
119 |
+
|
120 |
+
self.mesh = None
|
121 |
+
self.deform_mesh = None
|
122 |
+
self.pcd = None
|
123 |
+
self.renderer = None
|
124 |
+
self.meshRas = None
|
125 |
+
self.type = None
|
126 |
+
self.knn = None
|
127 |
+
self.knn_inverse = None
|
128 |
+
|
129 |
+
self.smpl_seg = None
|
130 |
+
self.smpl_cmap = None
|
131 |
+
|
132 |
+
self.smplx = SMPLX()
|
133 |
+
|
134 |
+
self.uv_rasterizer = util.Pytorch3dRasterizer(self.size)
|
135 |
+
|
136 |
+
def get_camera(self, cam_id):
|
137 |
+
|
138 |
+
R, T = look_at_view_transform(
|
139 |
+
eye=[self.cam_pos[cam_id]],
|
140 |
+
at=((0, self.mesh_y_center, 0),),
|
141 |
+
up=((0, 1, 0),),
|
142 |
+
)
|
143 |
+
|
144 |
+
camera = FoVOrthographicCameras(
|
145 |
+
device=self.device,
|
146 |
+
R=R,
|
147 |
+
T=T,
|
148 |
+
znear=100.0,
|
149 |
+
zfar=-100.0,
|
150 |
+
max_y=100.0,
|
151 |
+
min_y=-100.0,
|
152 |
+
max_x=100.0,
|
153 |
+
min_x=-100.0,
|
154 |
+
scale_xyz=(self.scale * np.ones(3),),
|
155 |
+
)
|
156 |
+
|
157 |
+
return camera
|
158 |
+
|
159 |
+
def init_renderer(self, camera, type="clean_mesh", bg="gray"):
|
160 |
+
|
161 |
+
if "mesh" in type:
|
162 |
+
|
163 |
+
# rasterizer
|
164 |
+
self.raster_settings_mesh = RasterizationSettings(
|
165 |
+
image_size=self.size,
|
166 |
+
blur_radius=np.log(1.0 / 1e-4) * 1e-7,
|
167 |
+
faces_per_pixel=30,
|
168 |
+
)
|
169 |
+
self.meshRas = MeshRasterizer(
|
170 |
+
cameras=camera, raster_settings=self.raster_settings_mesh
|
171 |
+
)
|
172 |
+
|
173 |
+
if bg == "black":
|
174 |
+
blendparam = BlendParams(1e-4, 1e-4, (0.0, 0.0, 0.0))
|
175 |
+
elif bg == "white":
|
176 |
+
blendparam = BlendParams(1e-4, 1e-8, (1.0, 1.0, 1.0))
|
177 |
+
elif bg == "gray":
|
178 |
+
blendparam = BlendParams(1e-4, 1e-8, (0.5, 0.5, 0.5))
|
179 |
+
|
180 |
+
if type == "ori_mesh":
|
181 |
+
|
182 |
+
lights = PointLights(
|
183 |
+
device=self.device,
|
184 |
+
ambient_color=((0.8, 0.8, 0.8),),
|
185 |
+
diffuse_color=((0.2, 0.2, 0.2),),
|
186 |
+
specular_color=((0.0, 0.0, 0.0),),
|
187 |
+
location=[[0.0, 200.0, 0.0]],
|
188 |
+
)
|
189 |
+
|
190 |
+
self.renderer = MeshRenderer(
|
191 |
+
rasterizer=self.meshRas,
|
192 |
+
shader=SoftPhongShader(
|
193 |
+
device=self.device,
|
194 |
+
cameras=camera,
|
195 |
+
lights=lights,
|
196 |
+
blend_params=blendparam,
|
197 |
+
),
|
198 |
+
)
|
199 |
+
|
200 |
+
if type == "silhouette":
|
201 |
+
self.raster_settings_silhouette = RasterizationSettings(
|
202 |
+
image_size=self.size,
|
203 |
+
blur_radius=np.log(1.0 / 1e-4 - 1.0) * 5e-5,
|
204 |
+
faces_per_pixel=50,
|
205 |
+
cull_backfaces=True,
|
206 |
+
)
|
207 |
+
|
208 |
+
self.silhouetteRas = MeshRasterizer(
|
209 |
+
cameras=camera, raster_settings=self.raster_settings_silhouette
|
210 |
+
)
|
211 |
+
self.renderer = MeshRenderer(
|
212 |
+
rasterizer=self.silhouetteRas, shader=SoftSilhouetteShader()
|
213 |
+
)
|
214 |
+
|
215 |
+
if type == "pointcloud":
|
216 |
+
self.raster_settings_pcd = PointsRasterizationSettings(
|
217 |
+
image_size=self.size, radius=0.006, points_per_pixel=10
|
218 |
+
)
|
219 |
+
|
220 |
+
self.pcdRas = PointsRasterizer(
|
221 |
+
cameras=camera, raster_settings=self.raster_settings_pcd
|
222 |
+
)
|
223 |
+
self.renderer = PointsRenderer(
|
224 |
+
rasterizer=self.pcdRas,
|
225 |
+
compositor=AlphaCompositor(background_color=(0, 0, 0)),
|
226 |
+
)
|
227 |
+
|
228 |
+
if type == "clean_mesh":
|
229 |
+
|
230 |
+
self.renderer = MeshRenderer(
|
231 |
+
rasterizer=self.meshRas,
|
232 |
+
shader=cleanShader(
|
233 |
+
device=self.device, cameras=camera, blend_params=blendparam
|
234 |
+
),
|
235 |
+
)
|
236 |
+
|
237 |
+
def VF2Mesh(self, verts, faces):
|
238 |
+
|
239 |
+
if not torch.is_tensor(verts):
|
240 |
+
verts = torch.tensor(verts)
|
241 |
+
if not torch.is_tensor(faces):
|
242 |
+
faces = torch.tensor(faces)
|
243 |
+
|
244 |
+
if verts.ndimension() == 2:
|
245 |
+
verts = verts.unsqueeze(0).float()
|
246 |
+
if faces.ndimension() == 2:
|
247 |
+
faces = faces.unsqueeze(0).long()
|
248 |
+
|
249 |
+
verts = verts.to(self.device)
|
250 |
+
faces = faces.to(self.device)
|
251 |
+
|
252 |
+
mesh = Meshes(verts, faces).to(self.device)
|
253 |
+
|
254 |
+
mesh.textures = TexturesVertex(
|
255 |
+
verts_features=(mesh.verts_normals_padded() + 1.0) * 0.5
|
256 |
+
)
|
257 |
+
|
258 |
+
return mesh
|
259 |
+
|
260 |
+
def load_meshes(self, verts, faces):
|
261 |
+
"""load mesh into the pytorch3d renderer
|
262 |
+
|
263 |
+
Args:
|
264 |
+
verts ([N,3]): verts
|
265 |
+
faces ([N,3]): faces
|
266 |
+
offset ([N,3]): offset
|
267 |
+
"""
|
268 |
+
|
269 |
+
# camera setting
|
270 |
+
self.scale = 100.0
|
271 |
+
self.mesh_y_center = 0.0
|
272 |
+
|
273 |
+
self.cam_pos = [
|
274 |
+
(0, self.mesh_y_center, 100.0),
|
275 |
+
(100.0, self.mesh_y_center, 0),
|
276 |
+
(0, self.mesh_y_center, -100.0),
|
277 |
+
(-100.0, self.mesh_y_center, 0),
|
278 |
+
]
|
279 |
+
|
280 |
+
self.type = "color"
|
281 |
+
|
282 |
+
if isinstance(verts, list):
|
283 |
+
self.meshes = []
|
284 |
+
for V, F in zip(verts, faces):
|
285 |
+
self.meshes.append(self.VF2Mesh(V, F))
|
286 |
+
else:
|
287 |
+
self.meshes = [self.VF2Mesh(verts, faces)]
|
288 |
+
|
289 |
+
def get_depth_map(self, cam_ids=[0, 2]):
|
290 |
+
|
291 |
+
depth_maps = []
|
292 |
+
for cam_id in cam_ids:
|
293 |
+
self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray")
|
294 |
+
fragments = self.meshRas(self.meshes[0])
|
295 |
+
depth_map = fragments.zbuf[..., 0].squeeze(0)
|
296 |
+
if cam_id == 2:
|
297 |
+
depth_map = torch.fliplr(depth_map)
|
298 |
+
depth_maps.append(depth_map)
|
299 |
+
|
300 |
+
return depth_maps
|
301 |
+
|
302 |
+
def get_rgb_image(self, cam_ids=[0, 2]):
|
303 |
+
|
304 |
+
images = []
|
305 |
+
for cam_id in range(len(self.cam_pos)):
|
306 |
+
if cam_id in cam_ids:
|
307 |
+
self.init_renderer(self.get_camera(
|
308 |
+
cam_id), "clean_mesh", "gray")
|
309 |
+
if len(cam_ids) == 4:
|
310 |
+
rendered_img = (
|
311 |
+
self.renderer(self.meshes[0])[
|
312 |
+
0:1, :, :, :3].permute(0, 3, 1, 2)
|
313 |
+
- 0.5
|
314 |
+
) * 2.0
|
315 |
+
else:
|
316 |
+
rendered_img = (
|
317 |
+
self.renderer(self.meshes[0])[
|
318 |
+
0:1, :, :, :3].permute(0, 3, 1, 2)
|
319 |
+
- 0.5
|
320 |
+
) * 2.0
|
321 |
+
if cam_id == 2 and len(cam_ids) == 2:
|
322 |
+
rendered_img = torch.flip(rendered_img, dims=[3])
|
323 |
+
images.append(rendered_img)
|
324 |
+
|
325 |
+
return images
|
326 |
+
|
327 |
+
def get_rendered_video(self, images, save_path):
|
328 |
+
|
329 |
+
self.cam_pos = []
|
330 |
+
for angle in range(360):
|
331 |
+
self.cam_pos.append(
|
332 |
+
(
|
333 |
+
100.0 * math.cos(np.pi / 180 * angle),
|
334 |
+
self.mesh_y_center,
|
335 |
+
100.0 * math.sin(np.pi / 180 * angle),
|
336 |
+
)
|
337 |
+
)
|
338 |
+
|
339 |
+
old_shape = np.array(images[0].shape[:2])
|
340 |
+
new_shape = np.around(
|
341 |
+
(self.size / old_shape[0]) * old_shape).astype(np.int)
|
342 |
+
|
343 |
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
344 |
+
video = cv2.VideoWriter(
|
345 |
+
save_path, fourcc, 30, (self.size * len(self.meshes) +
|
346 |
+
new_shape[1] * len(images), self.size)
|
347 |
+
)
|
348 |
+
|
349 |
+
pbar = tqdm(range(len(self.cam_pos)))
|
350 |
+
pbar.set_description(colored(f"exporting video {os.path.basename(save_path)}...", "blue"))
|
351 |
+
for cam_id in pbar:
|
352 |
+
self.init_renderer(self.get_camera(cam_id), "clean_mesh", "gray")
|
353 |
+
|
354 |
+
img_lst = [
|
355 |
+
np.array(Image.fromarray(img).resize(new_shape[::-1])).astype(np.uint8)[
|
356 |
+
:, :, [2, 1, 0]
|
357 |
+
]
|
358 |
+
for img in images
|
359 |
+
]
|
360 |
+
|
361 |
+
for mesh in self.meshes:
|
362 |
+
rendered_img = (
|
363 |
+
(self.renderer(mesh)[0, :, :, :3] * 255.0)
|
364 |
+
.detach()
|
365 |
+
.cpu()
|
366 |
+
.numpy()
|
367 |
+
.astype(np.uint8)
|
368 |
+
)
|
369 |
+
|
370 |
+
img_lst.append(rendered_img)
|
371 |
+
final_img = np.concatenate(img_lst, axis=1)
|
372 |
+
video.write(final_img)
|
373 |
+
|
374 |
+
video.release()
|
375 |
+
|
376 |
+
def get_silhouette_image(self, cam_ids=[0, 2]):
|
377 |
+
|
378 |
+
images = []
|
379 |
+
for cam_id in range(len(self.cam_pos)):
|
380 |
+
if cam_id in cam_ids:
|
381 |
+
self.init_renderer(self.get_camera(cam_id), "silhouette")
|
382 |
+
rendered_img = self.renderer(self.meshes[0])[0:1, :, :, 3]
|
383 |
+
if cam_id == 2 and len(cam_ids) == 2:
|
384 |
+
rendered_img = torch.flip(rendered_img, dims=[2])
|
385 |
+
images.append(rendered_img)
|
386 |
+
|
387 |
+
return images
|
lib/common/render_utils.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
import trimesh
|
21 |
+
import math
|
22 |
+
from typing import NewType
|
23 |
+
from pytorch3d.structures import Meshes
|
24 |
+
from pytorch3d.renderer.mesh import rasterize_meshes
|
25 |
+
|
26 |
+
Tensor = NewType('Tensor', torch.Tensor)
|
27 |
+
|
28 |
+
|
29 |
+
def solid_angles(points: Tensor,
|
30 |
+
triangles: Tensor,
|
31 |
+
thresh: float = 1e-8) -> Tensor:
|
32 |
+
''' Compute solid angle between the input points and triangles
|
33 |
+
Follows the method described in:
|
34 |
+
The Solid Angle of a Plane Triangle
|
35 |
+
A. VAN OOSTEROM AND J. STRACKEE
|
36 |
+
IEEE TRANSACTIONS ON BIOMEDICAL ENGINEERING,
|
37 |
+
VOL. BME-30, NO. 2, FEBRUARY 1983
|
38 |
+
Parameters
|
39 |
+
-----------
|
40 |
+
points: BxQx3
|
41 |
+
Tensor of input query points
|
42 |
+
triangles: BxFx3x3
|
43 |
+
Target triangles
|
44 |
+
thresh: float
|
45 |
+
float threshold
|
46 |
+
Returns
|
47 |
+
-------
|
48 |
+
solid_angles: BxQxF
|
49 |
+
A tensor containing the solid angle between all query points
|
50 |
+
and input triangles
|
51 |
+
'''
|
52 |
+
# Center the triangles on the query points. Size should be BxQxFx3x3
|
53 |
+
centered_tris = triangles[:, None] - points[:, :, None, None]
|
54 |
+
|
55 |
+
# BxQxFx3
|
56 |
+
norms = torch.norm(centered_tris, dim=-1)
|
57 |
+
|
58 |
+
# Should be BxQxFx3
|
59 |
+
cross_prod = torch.cross(centered_tris[:, :, :, 1],
|
60 |
+
centered_tris[:, :, :, 2],
|
61 |
+
dim=-1)
|
62 |
+
# Should be BxQxF
|
63 |
+
numerator = (centered_tris[:, :, :, 0] * cross_prod).sum(dim=-1)
|
64 |
+
del cross_prod
|
65 |
+
|
66 |
+
dot01 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 1]).sum(dim=-1)
|
67 |
+
dot12 = (centered_tris[:, :, :, 1] * centered_tris[:, :, :, 2]).sum(dim=-1)
|
68 |
+
dot02 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 2]).sum(dim=-1)
|
69 |
+
del centered_tris
|
70 |
+
|
71 |
+
denominator = (norms.prod(dim=-1) + dot01 * norms[:, :, :, 2] +
|
72 |
+
dot02 * norms[:, :, :, 1] + dot12 * norms[:, :, :, 0])
|
73 |
+
del dot01, dot12, dot02, norms
|
74 |
+
|
75 |
+
# Should be BxQ
|
76 |
+
solid_angle = torch.atan2(numerator, denominator)
|
77 |
+
del numerator, denominator
|
78 |
+
|
79 |
+
torch.cuda.empty_cache()
|
80 |
+
|
81 |
+
return 2 * solid_angle
|
82 |
+
|
83 |
+
|
84 |
+
def winding_numbers(points: Tensor,
|
85 |
+
triangles: Tensor,
|
86 |
+
thresh: float = 1e-8) -> Tensor:
|
87 |
+
''' Uses winding_numbers to compute inside/outside
|
88 |
+
Robust inside-outside segmentation using generalized winding numbers
|
89 |
+
Alec Jacobson,
|
90 |
+
Ladislav Kavan,
|
91 |
+
Olga Sorkine-Hornung
|
92 |
+
Fast Winding Numbers for Soups and Clouds SIGGRAPH 2018
|
93 |
+
Gavin Barill
|
94 |
+
NEIL G. Dickson
|
95 |
+
Ryan Schmidt
|
96 |
+
David I.W. Levin
|
97 |
+
and Alec Jacobson
|
98 |
+
Parameters
|
99 |
+
-----------
|
100 |
+
points: BxQx3
|
101 |
+
Tensor of input query points
|
102 |
+
triangles: BxFx3x3
|
103 |
+
Target triangles
|
104 |
+
thresh: float
|
105 |
+
float threshold
|
106 |
+
Returns
|
107 |
+
-------
|
108 |
+
winding_numbers: BxQ
|
109 |
+
A tensor containing the Generalized winding numbers
|
110 |
+
'''
|
111 |
+
# The generalized winding number is the sum of solid angles of the point
|
112 |
+
# with respect to all triangles.
|
113 |
+
return 1 / (4 * math.pi) * solid_angles(points, triangles,
|
114 |
+
thresh=thresh).sum(dim=-1)
|
115 |
+
|
116 |
+
|
117 |
+
def batch_contains(verts, faces, points):
|
118 |
+
|
119 |
+
B = verts.shape[0]
|
120 |
+
N = points.shape[1]
|
121 |
+
|
122 |
+
verts = verts.detach().cpu()
|
123 |
+
faces = faces.detach().cpu()
|
124 |
+
points = points.detach().cpu()
|
125 |
+
contains = torch.zeros(B, N)
|
126 |
+
|
127 |
+
for i in range(B):
|
128 |
+
contains[i] = torch.as_tensor(
|
129 |
+
trimesh.Trimesh(verts[i], faces[i]).contains(points[i]))
|
130 |
+
|
131 |
+
return 2.0 * (contains - 0.5)
|
132 |
+
|
133 |
+
|
134 |
+
def dict2obj(d):
|
135 |
+
# if isinstance(d, list):
|
136 |
+
# d = [dict2obj(x) for x in d]
|
137 |
+
if not isinstance(d, dict):
|
138 |
+
return d
|
139 |
+
|
140 |
+
class C(object):
|
141 |
+
pass
|
142 |
+
|
143 |
+
o = C()
|
144 |
+
for k in d:
|
145 |
+
o.__dict__[k] = dict2obj(d[k])
|
146 |
+
return o
|
147 |
+
|
148 |
+
|
149 |
+
def face_vertices(vertices, faces):
|
150 |
+
"""
|
151 |
+
:param vertices: [batch size, number of vertices, 3]
|
152 |
+
:param faces: [batch size, number of faces, 3]
|
153 |
+
:return: [batch size, number of faces, 3, 3]
|
154 |
+
"""
|
155 |
+
|
156 |
+
bs, nv = vertices.shape[:2]
|
157 |
+
bs, nf = faces.shape[:2]
|
158 |
+
device = vertices.device
|
159 |
+
faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) *
|
160 |
+
nv)[:, None, None]
|
161 |
+
vertices = vertices.reshape((bs * nv, vertices.shape[-1]))
|
162 |
+
|
163 |
+
return vertices[faces.long()]
|
164 |
+
|
165 |
+
|
166 |
+
class Pytorch3dRasterizer(nn.Module):
|
167 |
+
""" Borrowed from https://github.com/facebookresearch/pytorch3d
|
168 |
+
Notice:
|
169 |
+
x,y,z are in image space, normalized
|
170 |
+
can only render squared image now
|
171 |
+
"""
|
172 |
+
|
173 |
+
def __init__(self, image_size=224):
|
174 |
+
"""
|
175 |
+
use fixed raster_settings for rendering faces
|
176 |
+
"""
|
177 |
+
super().__init__()
|
178 |
+
raster_settings = {
|
179 |
+
'image_size': image_size,
|
180 |
+
'blur_radius': 0.0,
|
181 |
+
'faces_per_pixel': 1,
|
182 |
+
'bin_size': None,
|
183 |
+
'max_faces_per_bin': None,
|
184 |
+
'perspective_correct': True,
|
185 |
+
'cull_backfaces': True,
|
186 |
+
}
|
187 |
+
raster_settings = dict2obj(raster_settings)
|
188 |
+
self.raster_settings = raster_settings
|
189 |
+
|
190 |
+
def forward(self, vertices, faces, attributes=None):
|
191 |
+
fixed_vertices = vertices.clone()
|
192 |
+
fixed_vertices[..., :2] = -fixed_vertices[..., :2]
|
193 |
+
meshes_screen = Meshes(verts=fixed_vertices.float(),
|
194 |
+
faces=faces.long())
|
195 |
+
raster_settings = self.raster_settings
|
196 |
+
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
|
197 |
+
meshes_screen,
|
198 |
+
image_size=raster_settings.image_size,
|
199 |
+
blur_radius=raster_settings.blur_radius,
|
200 |
+
faces_per_pixel=raster_settings.faces_per_pixel,
|
201 |
+
bin_size=raster_settings.bin_size,
|
202 |
+
max_faces_per_bin=raster_settings.max_faces_per_bin,
|
203 |
+
perspective_correct=raster_settings.perspective_correct,
|
204 |
+
)
|
205 |
+
vismask = (pix_to_face > -1).float()
|
206 |
+
D = attributes.shape[-1]
|
207 |
+
attributes = attributes.clone()
|
208 |
+
attributes = attributes.view(attributes.shape[0] * attributes.shape[1],
|
209 |
+
3, attributes.shape[-1])
|
210 |
+
N, H, W, K, _ = bary_coords.shape
|
211 |
+
mask = pix_to_face == -1
|
212 |
+
pix_to_face = pix_to_face.clone()
|
213 |
+
pix_to_face[mask] = 0
|
214 |
+
idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
|
215 |
+
pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
|
216 |
+
pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
|
217 |
+
pixel_vals[mask] = 0 # Replace masked values in output.
|
218 |
+
pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
|
219 |
+
pixel_vals = torch.cat(
|
220 |
+
[pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
|
221 |
+
return pixel_vals
|
lib/common/seg3d_lossless.py
ADDED
@@ -0,0 +1,604 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
|
19 |
+
from .seg3d_utils import (
|
20 |
+
create_grid3D,
|
21 |
+
plot_mask3D,
|
22 |
+
SmoothConv3D,
|
23 |
+
)
|
24 |
+
|
25 |
+
import torch
|
26 |
+
import torch.nn as nn
|
27 |
+
import numpy as np
|
28 |
+
import torch.nn.functional as F
|
29 |
+
import mcubes
|
30 |
+
from kaolin.ops.conversions import voxelgrids_to_trianglemeshes
|
31 |
+
import logging
|
32 |
+
|
33 |
+
logging.getLogger("lightning").setLevel(logging.ERROR)
|
34 |
+
|
35 |
+
|
36 |
+
class Seg3dLossless(nn.Module):
|
37 |
+
def __init__(self,
|
38 |
+
query_func,
|
39 |
+
b_min,
|
40 |
+
b_max,
|
41 |
+
resolutions,
|
42 |
+
channels=1,
|
43 |
+
balance_value=0.5,
|
44 |
+
align_corners=False,
|
45 |
+
visualize=False,
|
46 |
+
debug=False,
|
47 |
+
use_cuda_impl=False,
|
48 |
+
faster=False,
|
49 |
+
use_shadow=False,
|
50 |
+
**kwargs):
|
51 |
+
"""
|
52 |
+
align_corners: same with how you process gt. (grid_sample / interpolate)
|
53 |
+
"""
|
54 |
+
super().__init__()
|
55 |
+
self.query_func = query_func
|
56 |
+
self.register_buffer(
|
57 |
+
'b_min',
|
58 |
+
torch.tensor(b_min).float().unsqueeze(1)) # [bz, 1, 3]
|
59 |
+
self.register_buffer(
|
60 |
+
'b_max',
|
61 |
+
torch.tensor(b_max).float().unsqueeze(1)) # [bz, 1, 3]
|
62 |
+
|
63 |
+
# ti.init(arch=ti.cuda)
|
64 |
+
# self.mciso_taichi = MCISO(dim=3, N=resolutions[-1]-1)
|
65 |
+
|
66 |
+
if type(resolutions[0]) is int:
|
67 |
+
resolutions = torch.tensor([(res, res, res)
|
68 |
+
for res in resolutions])
|
69 |
+
else:
|
70 |
+
resolutions = torch.tensor(resolutions)
|
71 |
+
self.register_buffer('resolutions', resolutions)
|
72 |
+
self.batchsize = self.b_min.size(0)
|
73 |
+
assert self.batchsize == 1
|
74 |
+
self.balance_value = balance_value
|
75 |
+
self.channels = channels
|
76 |
+
assert self.channels == 1
|
77 |
+
self.align_corners = align_corners
|
78 |
+
self.visualize = visualize
|
79 |
+
self.debug = debug
|
80 |
+
self.use_cuda_impl = use_cuda_impl
|
81 |
+
self.faster = faster
|
82 |
+
self.use_shadow = use_shadow
|
83 |
+
|
84 |
+
for resolution in resolutions:
|
85 |
+
assert resolution[0] % 2 == 1 and resolution[1] % 2 == 1, \
|
86 |
+
f"resolution {resolution} need to be odd becuase of align_corner."
|
87 |
+
|
88 |
+
# init first resolution
|
89 |
+
init_coords = create_grid3D(0,
|
90 |
+
resolutions[-1] - 1,
|
91 |
+
steps=resolutions[0]) # [N, 3]
|
92 |
+
init_coords = init_coords.unsqueeze(0).repeat(self.batchsize, 1,
|
93 |
+
1) # [bz, N, 3]
|
94 |
+
self.register_buffer('init_coords', init_coords)
|
95 |
+
|
96 |
+
# some useful tensors
|
97 |
+
calculated = torch.zeros(
|
98 |
+
(self.resolutions[-1][2], self.resolutions[-1][1],
|
99 |
+
self.resolutions[-1][0]),
|
100 |
+
dtype=torch.bool)
|
101 |
+
self.register_buffer('calculated', calculated)
|
102 |
+
|
103 |
+
gird8_offsets = torch.stack(
|
104 |
+
torch.meshgrid([
|
105 |
+
torch.tensor([-1, 0, 1]),
|
106 |
+
torch.tensor([-1, 0, 1]),
|
107 |
+
torch.tensor([-1, 0, 1])
|
108 |
+
])).int().view(3, -1).t() # [27, 3]
|
109 |
+
self.register_buffer('gird8_offsets', gird8_offsets)
|
110 |
+
|
111 |
+
# smooth convs
|
112 |
+
self.smooth_conv3x3 = SmoothConv3D(in_channels=1,
|
113 |
+
out_channels=1,
|
114 |
+
kernel_size=3)
|
115 |
+
self.smooth_conv5x5 = SmoothConv3D(in_channels=1,
|
116 |
+
out_channels=1,
|
117 |
+
kernel_size=5)
|
118 |
+
self.smooth_conv7x7 = SmoothConv3D(in_channels=1,
|
119 |
+
out_channels=1,
|
120 |
+
kernel_size=7)
|
121 |
+
self.smooth_conv9x9 = SmoothConv3D(in_channels=1,
|
122 |
+
out_channels=1,
|
123 |
+
kernel_size=9)
|
124 |
+
|
125 |
+
def batch_eval(self, coords, **kwargs):
|
126 |
+
"""
|
127 |
+
coords: in the coordinates of last resolution
|
128 |
+
**kwargs: for query_func
|
129 |
+
"""
|
130 |
+
coords = coords.detach()
|
131 |
+
# normalize coords to fit in [b_min, b_max]
|
132 |
+
if self.align_corners:
|
133 |
+
coords2D = coords.float() / (self.resolutions[-1] - 1)
|
134 |
+
else:
|
135 |
+
step = 1.0 / self.resolutions[-1].float()
|
136 |
+
coords2D = coords.float() / self.resolutions[-1] + step / 2
|
137 |
+
coords2D = coords2D * (self.b_max - self.b_min) + self.b_min
|
138 |
+
# query function
|
139 |
+
occupancys = self.query_func(**kwargs, points=coords2D)
|
140 |
+
if type(occupancys) is list:
|
141 |
+
occupancys = torch.stack(occupancys) # [bz, C, N]
|
142 |
+
assert len(occupancys.size()) == 3, \
|
143 |
+
"query_func should return a occupancy with shape of [bz, C, N]"
|
144 |
+
return occupancys
|
145 |
+
|
146 |
+
def forward(self, **kwargs):
|
147 |
+
if self.faster:
|
148 |
+
return self._forward_faster(**kwargs)
|
149 |
+
else:
|
150 |
+
return self._forward(**kwargs)
|
151 |
+
|
152 |
+
def _forward_faster(self, **kwargs):
|
153 |
+
"""
|
154 |
+
In faster mode, we make following changes to exchange accuracy for speed:
|
155 |
+
1. no conflict checking: 4.88 fps -> 6.56 fps
|
156 |
+
2. smooth_conv9x9 ~ smooth_conv3x3 for different resolution
|
157 |
+
3. last step no examine
|
158 |
+
"""
|
159 |
+
final_W = self.resolutions[-1][0]
|
160 |
+
final_H = self.resolutions[-1][1]
|
161 |
+
final_D = self.resolutions[-1][2]
|
162 |
+
|
163 |
+
for resolution in self.resolutions:
|
164 |
+
W, H, D = resolution
|
165 |
+
stride = (self.resolutions[-1] - 1) / (resolution - 1)
|
166 |
+
|
167 |
+
# first step
|
168 |
+
if torch.equal(resolution, self.resolutions[0]):
|
169 |
+
coords = self.init_coords.clone() # torch.long
|
170 |
+
occupancys = self.batch_eval(coords, **kwargs)
|
171 |
+
occupancys = occupancys.view(self.batchsize, self.channels, D,
|
172 |
+
H, W)
|
173 |
+
if (occupancys > 0.5).sum() == 0:
|
174 |
+
# return F.interpolate(
|
175 |
+
# occupancys, size=(final_D, final_H, final_W),
|
176 |
+
# mode="linear", align_corners=True)
|
177 |
+
return None
|
178 |
+
|
179 |
+
if self.visualize:
|
180 |
+
self.plot(occupancys, coords, final_D, final_H, final_W)
|
181 |
+
|
182 |
+
with torch.no_grad():
|
183 |
+
coords_accum = coords / stride
|
184 |
+
|
185 |
+
# last step
|
186 |
+
elif torch.equal(resolution, self.resolutions[-1]):
|
187 |
+
|
188 |
+
with torch.no_grad():
|
189 |
+
# here true is correct!
|
190 |
+
valid = F.interpolate(
|
191 |
+
(occupancys > self.balance_value).float(),
|
192 |
+
size=(D, H, W),
|
193 |
+
mode="trilinear",
|
194 |
+
align_corners=True)
|
195 |
+
|
196 |
+
# here true is correct!
|
197 |
+
occupancys = F.interpolate(occupancys.float(),
|
198 |
+
size=(D, H, W),
|
199 |
+
mode="trilinear",
|
200 |
+
align_corners=True)
|
201 |
+
|
202 |
+
# is_boundary = (valid > 0.0) & (valid < 1.0)
|
203 |
+
is_boundary = valid == 0.5
|
204 |
+
|
205 |
+
# next steps
|
206 |
+
else:
|
207 |
+
coords_accum *= 2
|
208 |
+
|
209 |
+
with torch.no_grad():
|
210 |
+
# here true is correct!
|
211 |
+
valid = F.interpolate(
|
212 |
+
(occupancys > self.balance_value).float(),
|
213 |
+
size=(D, H, W),
|
214 |
+
mode="trilinear",
|
215 |
+
align_corners=True)
|
216 |
+
|
217 |
+
# here true is correct!
|
218 |
+
occupancys = F.interpolate(occupancys.float(),
|
219 |
+
size=(D, H, W),
|
220 |
+
mode="trilinear",
|
221 |
+
align_corners=True)
|
222 |
+
|
223 |
+
is_boundary = (valid > 0.0) & (valid < 1.0)
|
224 |
+
|
225 |
+
with torch.no_grad():
|
226 |
+
if torch.equal(resolution, self.resolutions[1]):
|
227 |
+
is_boundary = (self.smooth_conv9x9(is_boundary.float())
|
228 |
+
> 0)[0, 0]
|
229 |
+
elif torch.equal(resolution, self.resolutions[2]):
|
230 |
+
is_boundary = (self.smooth_conv7x7(is_boundary.float())
|
231 |
+
> 0)[0, 0]
|
232 |
+
else:
|
233 |
+
is_boundary = (self.smooth_conv3x3(is_boundary.float())
|
234 |
+
> 0)[0, 0]
|
235 |
+
|
236 |
+
coords_accum = coords_accum.long()
|
237 |
+
is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
|
238 |
+
coords_accum[0, :, 0]] = False
|
239 |
+
point_coords = is_boundary.permute(
|
240 |
+
2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
|
241 |
+
point_indices = (point_coords[:, :, 2] * H * W +
|
242 |
+
point_coords[:, :, 1] * W +
|
243 |
+
point_coords[:, :, 0])
|
244 |
+
|
245 |
+
R, C, D, H, W = occupancys.shape
|
246 |
+
|
247 |
+
# inferred value
|
248 |
+
coords = point_coords * stride
|
249 |
+
|
250 |
+
if coords.size(1) == 0:
|
251 |
+
continue
|
252 |
+
occupancys_topk = self.batch_eval(coords, **kwargs)
|
253 |
+
|
254 |
+
# put mask point predictions to the right places on the upsampled grid.
|
255 |
+
R, C, D, H, W = occupancys.shape
|
256 |
+
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
|
257 |
+
occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
|
258 |
+
2, point_indices, occupancys_topk).view(R, C, D, H, W))
|
259 |
+
|
260 |
+
with torch.no_grad():
|
261 |
+
voxels = coords / stride
|
262 |
+
coords_accum = torch.cat([voxels, coords_accum],
|
263 |
+
dim=1).unique(dim=1)
|
264 |
+
|
265 |
+
return occupancys[0, 0]
|
266 |
+
|
267 |
+
def _forward(self, **kwargs):
|
268 |
+
"""
|
269 |
+
output occupancy field would be:
|
270 |
+
(bz, C, res, res)
|
271 |
+
"""
|
272 |
+
final_W = self.resolutions[-1][0]
|
273 |
+
final_H = self.resolutions[-1][1]
|
274 |
+
final_D = self.resolutions[-1][2]
|
275 |
+
|
276 |
+
calculated = self.calculated.clone()
|
277 |
+
|
278 |
+
for resolution in self.resolutions:
|
279 |
+
W, H, D = resolution
|
280 |
+
stride = (self.resolutions[-1] - 1) / (resolution - 1)
|
281 |
+
|
282 |
+
if self.visualize:
|
283 |
+
this_stage_coords = []
|
284 |
+
|
285 |
+
# first step
|
286 |
+
if torch.equal(resolution, self.resolutions[0]):
|
287 |
+
coords = self.init_coords.clone() # torch.long
|
288 |
+
occupancys = self.batch_eval(coords, **kwargs)
|
289 |
+
occupancys = occupancys.view(self.batchsize, self.channels, D,
|
290 |
+
H, W)
|
291 |
+
|
292 |
+
if self.visualize:
|
293 |
+
self.plot(occupancys, coords, final_D, final_H, final_W)
|
294 |
+
|
295 |
+
with torch.no_grad():
|
296 |
+
coords_accum = coords / stride
|
297 |
+
calculated[coords[0, :, 2], coords[0, :, 1],
|
298 |
+
coords[0, :, 0]] = True
|
299 |
+
|
300 |
+
# next steps
|
301 |
+
else:
|
302 |
+
coords_accum *= 2
|
303 |
+
|
304 |
+
with torch.no_grad():
|
305 |
+
# here true is correct!
|
306 |
+
valid = F.interpolate(
|
307 |
+
(occupancys > self.balance_value).float(),
|
308 |
+
size=(D, H, W),
|
309 |
+
mode="trilinear",
|
310 |
+
align_corners=True)
|
311 |
+
|
312 |
+
# here true is correct!
|
313 |
+
occupancys = F.interpolate(occupancys.float(),
|
314 |
+
size=(D, H, W),
|
315 |
+
mode="trilinear",
|
316 |
+
align_corners=True)
|
317 |
+
|
318 |
+
is_boundary = (valid > 0.0) & (valid < 1.0)
|
319 |
+
|
320 |
+
with torch.no_grad():
|
321 |
+
# TODO
|
322 |
+
if self.use_shadow and torch.equal(resolution,
|
323 |
+
self.resolutions[-1]):
|
324 |
+
# larger z means smaller depth here
|
325 |
+
depth_res = resolution[2].item()
|
326 |
+
depth_index = torch.linspace(0,
|
327 |
+
depth_res - 1,
|
328 |
+
steps=depth_res).type_as(
|
329 |
+
occupancys.device)
|
330 |
+
depth_index_max = torch.max(
|
331 |
+
(occupancys > self.balance_value) *
|
332 |
+
(depth_index + 1),
|
333 |
+
dim=-1,
|
334 |
+
keepdim=True)[0] - 1
|
335 |
+
shadow = depth_index < depth_index_max
|
336 |
+
is_boundary[shadow] = False
|
337 |
+
is_boundary = is_boundary[0, 0]
|
338 |
+
else:
|
339 |
+
is_boundary = (self.smooth_conv3x3(is_boundary.float())
|
340 |
+
> 0)[0, 0]
|
341 |
+
# is_boundary = is_boundary[0, 0]
|
342 |
+
|
343 |
+
is_boundary[coords_accum[0, :, 2], coords_accum[0, :, 1],
|
344 |
+
coords_accum[0, :, 0]] = False
|
345 |
+
point_coords = is_boundary.permute(
|
346 |
+
2, 1, 0).nonzero(as_tuple=False).unsqueeze(0)
|
347 |
+
point_indices = (point_coords[:, :, 2] * H * W +
|
348 |
+
point_coords[:, :, 1] * W +
|
349 |
+
point_coords[:, :, 0])
|
350 |
+
|
351 |
+
R, C, D, H, W = occupancys.shape
|
352 |
+
# interpolated value
|
353 |
+
occupancys_interp = torch.gather(
|
354 |
+
occupancys.reshape(R, C, D * H * W), 2,
|
355 |
+
point_indices.unsqueeze(1))
|
356 |
+
|
357 |
+
# inferred value
|
358 |
+
coords = point_coords * stride
|
359 |
+
|
360 |
+
if coords.size(1) == 0:
|
361 |
+
continue
|
362 |
+
occupancys_topk = self.batch_eval(coords, **kwargs)
|
363 |
+
if self.visualize:
|
364 |
+
this_stage_coords.append(coords)
|
365 |
+
|
366 |
+
# put mask point predictions to the right places on the upsampled grid.
|
367 |
+
R, C, D, H, W = occupancys.shape
|
368 |
+
point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
|
369 |
+
occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
|
370 |
+
2, point_indices, occupancys_topk).view(R, C, D, H, W))
|
371 |
+
|
372 |
+
with torch.no_grad():
|
373 |
+
# conflicts
|
374 |
+
conflicts = ((occupancys_interp - self.balance_value) *
|
375 |
+
(occupancys_topk - self.balance_value) < 0)[0,
|
376 |
+
0]
|
377 |
+
|
378 |
+
if self.visualize:
|
379 |
+
self.plot(occupancys, coords, final_D, final_H,
|
380 |
+
final_W)
|
381 |
+
|
382 |
+
voxels = coords / stride
|
383 |
+
coords_accum = torch.cat([voxels, coords_accum],
|
384 |
+
dim=1).unique(dim=1)
|
385 |
+
calculated[coords[0, :, 2], coords[0, :, 1],
|
386 |
+
coords[0, :, 0]] = True
|
387 |
+
|
388 |
+
while conflicts.sum() > 0:
|
389 |
+
if self.use_shadow and torch.equal(resolution,
|
390 |
+
self.resolutions[-1]):
|
391 |
+
break
|
392 |
+
|
393 |
+
with torch.no_grad():
|
394 |
+
conflicts_coords = coords[0, conflicts, :]
|
395 |
+
|
396 |
+
if self.debug:
|
397 |
+
self.plot(occupancys,
|
398 |
+
conflicts_coords.unsqueeze(0),
|
399 |
+
final_D,
|
400 |
+
final_H,
|
401 |
+
final_W,
|
402 |
+
title='conflicts')
|
403 |
+
|
404 |
+
conflicts_boundary = (conflicts_coords.int() +
|
405 |
+
self.gird8_offsets.unsqueeze(1) *
|
406 |
+
stride.int()).reshape(
|
407 |
+
-1, 3).long().unique(dim=0)
|
408 |
+
conflicts_boundary[:, 0] = (
|
409 |
+
conflicts_boundary[:, 0].clamp(
|
410 |
+
0,
|
411 |
+
calculated.size(2) - 1))
|
412 |
+
conflicts_boundary[:, 1] = (
|
413 |
+
conflicts_boundary[:, 1].clamp(
|
414 |
+
0,
|
415 |
+
calculated.size(1) - 1))
|
416 |
+
conflicts_boundary[:, 2] = (
|
417 |
+
conflicts_boundary[:, 2].clamp(
|
418 |
+
0,
|
419 |
+
calculated.size(0) - 1))
|
420 |
+
|
421 |
+
coords = conflicts_boundary[calculated[
|
422 |
+
conflicts_boundary[:, 2], conflicts_boundary[:, 1],
|
423 |
+
conflicts_boundary[:, 0]] == False]
|
424 |
+
|
425 |
+
if self.debug:
|
426 |
+
self.plot(occupancys,
|
427 |
+
coords.unsqueeze(0),
|
428 |
+
final_D,
|
429 |
+
final_H,
|
430 |
+
final_W,
|
431 |
+
title='coords')
|
432 |
+
|
433 |
+
coords = coords.unsqueeze(0)
|
434 |
+
point_coords = coords / stride
|
435 |
+
point_indices = (point_coords[:, :, 2] * H * W +
|
436 |
+
point_coords[:, :, 1] * W +
|
437 |
+
point_coords[:, :, 0])
|
438 |
+
|
439 |
+
R, C, D, H, W = occupancys.shape
|
440 |
+
# interpolated value
|
441 |
+
occupancys_interp = torch.gather(
|
442 |
+
occupancys.reshape(R, C, D * H * W), 2,
|
443 |
+
point_indices.unsqueeze(1))
|
444 |
+
|
445 |
+
# inferred value
|
446 |
+
coords = point_coords * stride
|
447 |
+
|
448 |
+
if coords.size(1) == 0:
|
449 |
+
break
|
450 |
+
occupancys_topk = self.batch_eval(coords, **kwargs)
|
451 |
+
if self.visualize:
|
452 |
+
this_stage_coords.append(coords)
|
453 |
+
|
454 |
+
with torch.no_grad():
|
455 |
+
# conflicts
|
456 |
+
conflicts = ((occupancys_interp - self.balance_value) *
|
457 |
+
(occupancys_topk - self.balance_value) <
|
458 |
+
0)[0, 0]
|
459 |
+
|
460 |
+
# put mask point predictions to the right places on the upsampled grid.
|
461 |
+
point_indices = point_indices.unsqueeze(1).expand(
|
462 |
+
-1, C, -1)
|
463 |
+
occupancys = (occupancys.reshape(R, C, D * H * W).scatter_(
|
464 |
+
2, point_indices, occupancys_topk).view(R, C, D, H, W))
|
465 |
+
|
466 |
+
with torch.no_grad():
|
467 |
+
voxels = coords / stride
|
468 |
+
coords_accum = torch.cat([voxels, coords_accum],
|
469 |
+
dim=1).unique(dim=1)
|
470 |
+
calculated[coords[0, :, 2], coords[0, :, 1],
|
471 |
+
coords[0, :, 0]] = True
|
472 |
+
|
473 |
+
if self.visualize:
|
474 |
+
this_stage_coords = torch.cat(this_stage_coords, dim=1)
|
475 |
+
self.plot(occupancys, this_stage_coords, final_D, final_H,
|
476 |
+
final_W)
|
477 |
+
|
478 |
+
return occupancys[0, 0]
|
479 |
+
|
480 |
+
def plot(self,
|
481 |
+
occupancys,
|
482 |
+
coords,
|
483 |
+
final_D,
|
484 |
+
final_H,
|
485 |
+
final_W,
|
486 |
+
title='',
|
487 |
+
**kwargs):
|
488 |
+
final = F.interpolate(occupancys.float(),
|
489 |
+
size=(final_D, final_H, final_W),
|
490 |
+
mode="trilinear",
|
491 |
+
align_corners=True) # here true is correct!
|
492 |
+
x = coords[0, :, 0].to("cpu")
|
493 |
+
y = coords[0, :, 1].to("cpu")
|
494 |
+
z = coords[0, :, 2].to("cpu")
|
495 |
+
|
496 |
+
plot_mask3D(final[0, 0].to("cpu"), title, (x, y, z), **kwargs)
|
497 |
+
|
498 |
+
def find_vertices(self, sdf, direction="front"):
|
499 |
+
'''
|
500 |
+
- direction: "front" | "back" | "left" | "right"
|
501 |
+
'''
|
502 |
+
resolution = sdf.size(2)
|
503 |
+
if direction == "front":
|
504 |
+
pass
|
505 |
+
elif direction == "left":
|
506 |
+
sdf = sdf.permute(2, 1, 0)
|
507 |
+
elif direction == "back":
|
508 |
+
inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
|
509 |
+
sdf = sdf[inv_idx, :, :]
|
510 |
+
elif direction == "right":
|
511 |
+
inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
|
512 |
+
sdf = sdf[:, :, inv_idx]
|
513 |
+
sdf = sdf.permute(2, 1, 0)
|
514 |
+
|
515 |
+
inv_idx = torch.arange(sdf.size(2) - 1, -1, -1).long()
|
516 |
+
sdf = sdf[inv_idx, :, :]
|
517 |
+
sdf_all = sdf.permute(2, 1, 0)
|
518 |
+
|
519 |
+
# shadow
|
520 |
+
grad_v = (sdf_all > 0.5) * torch.linspace(
|
521 |
+
resolution, 1, steps=resolution).to(sdf.device)
|
522 |
+
grad_c = torch.ones_like(sdf_all) * torch.linspace(
|
523 |
+
0, resolution - 1, steps=resolution).to(sdf.device)
|
524 |
+
max_v, max_c = grad_v.max(dim=2)
|
525 |
+
shadow = grad_c > max_c.view(resolution, resolution, 1)
|
526 |
+
keep = (sdf_all > 0.5) & (~shadow)
|
527 |
+
|
528 |
+
p1 = keep.nonzero(as_tuple=False).t() # [3, N]
|
529 |
+
p2 = p1.clone() # z
|
530 |
+
p2[2, :] = (p2[2, :] - 2).clamp(0, resolution)
|
531 |
+
p3 = p1.clone() # y
|
532 |
+
p3[1, :] = (p3[1, :] - 2).clamp(0, resolution)
|
533 |
+
p4 = p1.clone() # x
|
534 |
+
p4[0, :] = (p4[0, :] - 2).clamp(0, resolution)
|
535 |
+
|
536 |
+
v1 = sdf_all[p1[0, :], p1[1, :], p1[2, :]]
|
537 |
+
v2 = sdf_all[p2[0, :], p2[1, :], p2[2, :]]
|
538 |
+
v3 = sdf_all[p3[0, :], p3[1, :], p3[2, :]]
|
539 |
+
v4 = sdf_all[p4[0, :], p4[1, :], p4[2, :]]
|
540 |
+
|
541 |
+
X = p1[0, :].long() # [N,]
|
542 |
+
Y = p1[1, :].long() # [N,]
|
543 |
+
Z = p2[2, :].float() * (0.5 - v1) / (v2 - v1) + \
|
544 |
+
p1[2, :].float() * (v2 - 0.5) / (v2 - v1) # [N,]
|
545 |
+
Z = Z.clamp(0, resolution)
|
546 |
+
|
547 |
+
# normal
|
548 |
+
norm_z = v2 - v1
|
549 |
+
norm_y = v3 - v1
|
550 |
+
norm_x = v4 - v1
|
551 |
+
# print (v2.min(dim=0)[0], v2.max(dim=0)[0], v3.min(dim=0)[0], v3.max(dim=0)[0])
|
552 |
+
|
553 |
+
norm = torch.stack([norm_x, norm_y, norm_z], dim=1)
|
554 |
+
norm = norm / torch.norm(norm, p=2, dim=1, keepdim=True)
|
555 |
+
|
556 |
+
return X, Y, Z, norm
|
557 |
+
|
558 |
+
def render_normal(self, resolution, X, Y, Z, norm):
|
559 |
+
image = torch.ones((1, 3, resolution, resolution),
|
560 |
+
dtype=torch.float32).to(norm.device)
|
561 |
+
color = (norm + 1) / 2.0
|
562 |
+
color = color.clamp(0, 1)
|
563 |
+
image[0, :, Y, X] = color.t()
|
564 |
+
return image
|
565 |
+
|
566 |
+
def display(self, sdf):
|
567 |
+
|
568 |
+
# render
|
569 |
+
X, Y, Z, norm = self.find_vertices(sdf, direction="front")
|
570 |
+
image1 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
|
571 |
+
X, Y, Z, norm = self.find_vertices(sdf, direction="left")
|
572 |
+
image2 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
|
573 |
+
X, Y, Z, norm = self.find_vertices(sdf, direction="right")
|
574 |
+
image3 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
|
575 |
+
X, Y, Z, norm = self.find_vertices(sdf, direction="back")
|
576 |
+
image4 = self.render_normal(self.resolutions[-1, -1], X, Y, Z, norm)
|
577 |
+
|
578 |
+
image = torch.cat([image1, image2, image3, image4], axis=3)
|
579 |
+
image = image.detach().cpu().numpy()[0].transpose(1, 2, 0) * 255.0
|
580 |
+
|
581 |
+
return np.uint8(image)
|
582 |
+
|
583 |
+
def export_mesh(self, occupancys):
|
584 |
+
|
585 |
+
final = occupancys[1:, 1:, 1:].contiguous()
|
586 |
+
|
587 |
+
if final.shape[0] > 256:
|
588 |
+
# for voxelgrid larger than 256^3, the required GPU memory will be > 9GB
|
589 |
+
# thus we use CPU marching_cube to avoid "CUDA out of memory"
|
590 |
+
occu_arr = final.detach().cpu().numpy() # non-smooth surface
|
591 |
+
# occu_arr = mcubes.smooth(final.detach().cpu().numpy()) # smooth surface
|
592 |
+
vertices, triangles = mcubes.marching_cubes(
|
593 |
+
occu_arr, self.balance_value)
|
594 |
+
verts = torch.as_tensor(vertices[:, [2, 1, 0]])
|
595 |
+
faces = torch.as_tensor(triangles.astype(
|
596 |
+
np.long), dtype=torch.long)[:, [0, 2, 1]]
|
597 |
+
else:
|
598 |
+
torch.cuda.empty_cache()
|
599 |
+
vertices, triangles = voxelgrids_to_trianglemeshes(
|
600 |
+
final.unsqueeze(0))
|
601 |
+
verts = vertices[0][:, [2, 1, 0]].cpu()
|
602 |
+
faces = triangles[0][:, [0, 2, 1]].cpu()
|
603 |
+
|
604 |
+
return verts, faces
|
lib/common/seg3d_utils.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
import matplotlib.pyplot as plt
|
22 |
+
|
23 |
+
|
24 |
+
def plot_mask2D(mask,
|
25 |
+
title="",
|
26 |
+
point_coords=None,
|
27 |
+
figsize=10,
|
28 |
+
point_marker_size=5):
|
29 |
+
'''
|
30 |
+
Simple plotting tool to show intermediate mask predictions and points
|
31 |
+
where PointRend is applied.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
mask (Tensor): mask prediction of shape HxW
|
35 |
+
title (str): title for the plot
|
36 |
+
point_coords ((Tensor, Tensor)): x and y point coordinates
|
37 |
+
figsize (int): size of the figure to plot
|
38 |
+
point_marker_size (int): marker size for points
|
39 |
+
'''
|
40 |
+
|
41 |
+
H, W = mask.shape
|
42 |
+
plt.figure(figsize=(figsize, figsize))
|
43 |
+
if title:
|
44 |
+
title += ", "
|
45 |
+
plt.title("{}resolution {}x{}".format(title, H, W), fontsize=30)
|
46 |
+
plt.ylabel(H, fontsize=30)
|
47 |
+
plt.xlabel(W, fontsize=30)
|
48 |
+
plt.xticks([], [])
|
49 |
+
plt.yticks([], [])
|
50 |
+
plt.imshow(mask.detach(),
|
51 |
+
interpolation="nearest",
|
52 |
+
cmap=plt.get_cmap('gray'))
|
53 |
+
if point_coords is not None:
|
54 |
+
plt.scatter(x=point_coords[0],
|
55 |
+
y=point_coords[1],
|
56 |
+
color="red",
|
57 |
+
s=point_marker_size,
|
58 |
+
clip_on=True)
|
59 |
+
plt.xlim(-0.5, W - 0.5)
|
60 |
+
plt.ylim(H - 0.5, -0.5)
|
61 |
+
plt.show()
|
62 |
+
|
63 |
+
|
64 |
+
def plot_mask3D(mask=None,
|
65 |
+
title="",
|
66 |
+
point_coords=None,
|
67 |
+
figsize=1500,
|
68 |
+
point_marker_size=8,
|
69 |
+
interactive=True):
|
70 |
+
'''
|
71 |
+
Simple plotting tool to show intermediate mask predictions and points
|
72 |
+
where PointRend is applied.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
mask (Tensor): mask prediction of shape DxHxW
|
76 |
+
title (str): title for the plot
|
77 |
+
point_coords ((Tensor, Tensor, Tensor)): x and y and z point coordinates
|
78 |
+
figsize (int): size of the figure to plot
|
79 |
+
point_marker_size (int): marker size for points
|
80 |
+
'''
|
81 |
+
import trimesh
|
82 |
+
import vtkplotter
|
83 |
+
from skimage import measure
|
84 |
+
|
85 |
+
vp = vtkplotter.Plotter(title=title, size=(figsize, figsize))
|
86 |
+
vis_list = []
|
87 |
+
|
88 |
+
if mask is not None:
|
89 |
+
mask = mask.detach().to("cpu").numpy()
|
90 |
+
mask = mask.transpose(2, 1, 0)
|
91 |
+
|
92 |
+
# marching cube to find surface
|
93 |
+
verts, faces, normals, values = measure.marching_cubes_lewiner(
|
94 |
+
mask, 0.5, gradient_direction='ascent')
|
95 |
+
|
96 |
+
# create a mesh
|
97 |
+
mesh = trimesh.Trimesh(verts, faces)
|
98 |
+
mesh.visual.face_colors = [200, 200, 250, 100]
|
99 |
+
vis_list.append(mesh)
|
100 |
+
|
101 |
+
if point_coords is not None:
|
102 |
+
point_coords = torch.stack(point_coords, 1).to("cpu").numpy()
|
103 |
+
|
104 |
+
# import numpy as np
|
105 |
+
# select_x = np.logical_and(point_coords[:, 0] >= 16, point_coords[:, 0] <= 112)
|
106 |
+
# select_y = np.logical_and(point_coords[:, 1] >= 48, point_coords[:, 1] <= 272)
|
107 |
+
# select_z = np.logical_and(point_coords[:, 2] >= 16, point_coords[:, 2] <= 112)
|
108 |
+
# select = np.logical_and(np.logical_and(select_x, select_y), select_z)
|
109 |
+
# point_coords = point_coords[select, :]
|
110 |
+
|
111 |
+
pc = vtkplotter.Points(point_coords, r=point_marker_size, c='red')
|
112 |
+
vis_list.append(pc)
|
113 |
+
|
114 |
+
vp.show(*vis_list,
|
115 |
+
bg="white",
|
116 |
+
axes=1,
|
117 |
+
interactive=interactive,
|
118 |
+
azimuth=30,
|
119 |
+
elevation=30)
|
120 |
+
|
121 |
+
|
122 |
+
def create_grid3D(min, max, steps):
|
123 |
+
if type(min) is int:
|
124 |
+
min = (min, min, min) # (x, y, z)
|
125 |
+
if type(max) is int:
|
126 |
+
max = (max, max, max) # (x, y)
|
127 |
+
if type(steps) is int:
|
128 |
+
steps = (steps, steps, steps) # (x, y, z)
|
129 |
+
arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
|
130 |
+
arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
|
131 |
+
arrangeZ = torch.linspace(min[2], max[2], steps[2]).long()
|
132 |
+
gridD, girdH, gridW = torch.meshgrid([arrangeZ, arrangeY, arrangeX])
|
133 |
+
coords = torch.stack([gridW, girdH,
|
134 |
+
gridD]) # [2, steps[0], steps[1], steps[2]]
|
135 |
+
coords = coords.view(3, -1).t() # [N, 3]
|
136 |
+
return coords
|
137 |
+
|
138 |
+
|
139 |
+
def create_grid2D(min, max, steps):
|
140 |
+
if type(min) is int:
|
141 |
+
min = (min, min) # (x, y)
|
142 |
+
if type(max) is int:
|
143 |
+
max = (max, max) # (x, y)
|
144 |
+
if type(steps) is int:
|
145 |
+
steps = (steps, steps) # (x, y)
|
146 |
+
arrangeX = torch.linspace(min[0], max[0], steps[0]).long()
|
147 |
+
arrangeY = torch.linspace(min[1], max[1], steps[1]).long()
|
148 |
+
girdH, gridW = torch.meshgrid([arrangeY, arrangeX])
|
149 |
+
coords = torch.stack([gridW, girdH]) # [2, steps[0], steps[1]]
|
150 |
+
coords = coords.view(2, -1).t() # [N, 2]
|
151 |
+
return coords
|
152 |
+
|
153 |
+
|
154 |
+
class SmoothConv2D(nn.Module):
|
155 |
+
def __init__(self, in_channels, out_channels, kernel_size=3):
|
156 |
+
super().__init__()
|
157 |
+
assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
|
158 |
+
self.padding = (kernel_size - 1) // 2
|
159 |
+
|
160 |
+
weight = torch.ones(
|
161 |
+
(in_channels, out_channels, kernel_size, kernel_size),
|
162 |
+
dtype=torch.float32) / (kernel_size**2)
|
163 |
+
self.register_buffer('weight', weight)
|
164 |
+
|
165 |
+
def forward(self, input):
|
166 |
+
return F.conv2d(input, self.weight, padding=self.padding)
|
167 |
+
|
168 |
+
|
169 |
+
class SmoothConv3D(nn.Module):
|
170 |
+
def __init__(self, in_channels, out_channels, kernel_size=3):
|
171 |
+
super().__init__()
|
172 |
+
assert kernel_size % 2 == 1, "kernel_size for smooth_conv must be odd: {3, 5, ...}"
|
173 |
+
self.padding = (kernel_size - 1) // 2
|
174 |
+
|
175 |
+
weight = torch.ones(
|
176 |
+
(in_channels, out_channels, kernel_size, kernel_size, kernel_size),
|
177 |
+
dtype=torch.float32) / (kernel_size**3)
|
178 |
+
self.register_buffer('weight', weight)
|
179 |
+
|
180 |
+
def forward(self, input):
|
181 |
+
return F.conv3d(input, self.weight, padding=self.padding)
|
182 |
+
|
183 |
+
|
184 |
+
def build_smooth_conv3D(in_channels=1,
|
185 |
+
out_channels=1,
|
186 |
+
kernel_size=3,
|
187 |
+
padding=1):
|
188 |
+
smooth_conv = torch.nn.Conv3d(in_channels=in_channels,
|
189 |
+
out_channels=out_channels,
|
190 |
+
kernel_size=kernel_size,
|
191 |
+
padding=padding)
|
192 |
+
smooth_conv.weight.data = torch.ones(
|
193 |
+
(in_channels, out_channels, kernel_size, kernel_size, kernel_size),
|
194 |
+
dtype=torch.float32) / (kernel_size**3)
|
195 |
+
smooth_conv.bias.data = torch.zeros(out_channels)
|
196 |
+
return smooth_conv
|
197 |
+
|
198 |
+
|
199 |
+
def build_smooth_conv2D(in_channels=1,
|
200 |
+
out_channels=1,
|
201 |
+
kernel_size=3,
|
202 |
+
padding=1):
|
203 |
+
smooth_conv = torch.nn.Conv2d(in_channels=in_channels,
|
204 |
+
out_channels=out_channels,
|
205 |
+
kernel_size=kernel_size,
|
206 |
+
padding=padding)
|
207 |
+
smooth_conv.weight.data = torch.ones(
|
208 |
+
(in_channels, out_channels, kernel_size, kernel_size),
|
209 |
+
dtype=torch.float32) / (kernel_size**2)
|
210 |
+
smooth_conv.bias.data = torch.zeros(out_channels)
|
211 |
+
return smooth_conv
|
212 |
+
|
213 |
+
|
214 |
+
def get_uncertain_point_coords_on_grid3D(uncertainty_map, num_points,
|
215 |
+
**kwargs):
|
216 |
+
"""
|
217 |
+
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
218 |
+
Args:
|
219 |
+
uncertainty_map (Tensor): A tensor of shape (N, 1, H, W, D) that contains uncertainty
|
220 |
+
values for a set of points on a regular H x W x D grid.
|
221 |
+
num_points (int): The number of points P to select.
|
222 |
+
Returns:
|
223 |
+
point_indices (Tensor): A tensor of shape (N, P) that contains indices from
|
224 |
+
[0, H x W x D) of the most uncertain points.
|
225 |
+
point_coords (Tensor): A tensor of shape (N, P, 3) that contains [0, 1] x [0, 1] normalized
|
226 |
+
coordinates of the most uncertain points from the H x W x D grid.
|
227 |
+
"""
|
228 |
+
R, _, D, H, W = uncertainty_map.shape
|
229 |
+
# h_step = 1.0 / float(H)
|
230 |
+
# w_step = 1.0 / float(W)
|
231 |
+
# d_step = 1.0 / float(D)
|
232 |
+
|
233 |
+
num_points = min(D * H * W, num_points)
|
234 |
+
point_scores, point_indices = torch.topk(uncertainty_map.view(
|
235 |
+
R, D * H * W),
|
236 |
+
k=num_points,
|
237 |
+
dim=1)
|
238 |
+
point_coords = torch.zeros(R,
|
239 |
+
num_points,
|
240 |
+
3,
|
241 |
+
dtype=torch.float,
|
242 |
+
device=uncertainty_map.device)
|
243 |
+
# point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
|
244 |
+
# point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
|
245 |
+
# point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
|
246 |
+
point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
|
247 |
+
point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
|
248 |
+
point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
|
249 |
+
print(f"resolution {D} x {H} x {W}", point_scores.min(),
|
250 |
+
point_scores.max())
|
251 |
+
return point_indices, point_coords
|
252 |
+
|
253 |
+
|
254 |
+
def get_uncertain_point_coords_on_grid3D_faster(uncertainty_map, num_points,
|
255 |
+
clip_min):
|
256 |
+
"""
|
257 |
+
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
258 |
+
Args:
|
259 |
+
uncertainty_map (Tensor): A tensor of shape (N, 1, H, W, D) that contains uncertainty
|
260 |
+
values for a set of points on a regular H x W x D grid.
|
261 |
+
num_points (int): The number of points P to select.
|
262 |
+
Returns:
|
263 |
+
point_indices (Tensor): A tensor of shape (N, P) that contains indices from
|
264 |
+
[0, H x W x D) of the most uncertain points.
|
265 |
+
point_coords (Tensor): A tensor of shape (N, P, 3) that contains [0, 1] x [0, 1] normalized
|
266 |
+
coordinates of the most uncertain points from the H x W x D grid.
|
267 |
+
"""
|
268 |
+
R, _, D, H, W = uncertainty_map.shape
|
269 |
+
# h_step = 1.0 / float(H)
|
270 |
+
# w_step = 1.0 / float(W)
|
271 |
+
# d_step = 1.0 / float(D)
|
272 |
+
|
273 |
+
assert R == 1, "batchsize > 1 is not implemented!"
|
274 |
+
uncertainty_map = uncertainty_map.view(D * H * W)
|
275 |
+
indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
|
276 |
+
num_points = min(num_points, indices.size(0))
|
277 |
+
point_scores, point_indices = torch.topk(uncertainty_map[indices],
|
278 |
+
k=num_points,
|
279 |
+
dim=0)
|
280 |
+
point_indices = indices[point_indices].unsqueeze(0)
|
281 |
+
|
282 |
+
point_coords = torch.zeros(R,
|
283 |
+
num_points,
|
284 |
+
3,
|
285 |
+
dtype=torch.float,
|
286 |
+
device=uncertainty_map.device)
|
287 |
+
# point_coords[:, :, 0] = h_step / 2.0 + (point_indices // (W * D)).to(torch.float) * h_step
|
288 |
+
# point_coords[:, :, 1] = w_step / 2.0 + (point_indices % (W * D) // D).to(torch.float) * w_step
|
289 |
+
# point_coords[:, :, 2] = d_step / 2.0 + (point_indices % D).to(torch.float) * d_step
|
290 |
+
point_coords[:, :, 0] = (point_indices % W).to(torch.float) # x
|
291 |
+
point_coords[:, :, 1] = (point_indices % (H * W) // W).to(torch.float) # y
|
292 |
+
point_coords[:, :, 2] = (point_indices // (H * W)).to(torch.float) # z
|
293 |
+
# print (f"resolution {D} x {H} x {W}", point_scores.min(), point_scores.max())
|
294 |
+
return point_indices, point_coords
|
295 |
+
|
296 |
+
|
297 |
+
def get_uncertain_point_coords_on_grid2D(uncertainty_map, num_points,
|
298 |
+
**kwargs):
|
299 |
+
"""
|
300 |
+
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
301 |
+
Args:
|
302 |
+
uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
|
303 |
+
values for a set of points on a regular H x W grid.
|
304 |
+
num_points (int): The number of points P to select.
|
305 |
+
Returns:
|
306 |
+
point_indices (Tensor): A tensor of shape (N, P) that contains indices from
|
307 |
+
[0, H x W) of the most uncertain points.
|
308 |
+
point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized
|
309 |
+
coordinates of the most uncertain points from the H x W grid.
|
310 |
+
"""
|
311 |
+
R, _, H, W = uncertainty_map.shape
|
312 |
+
# h_step = 1.0 / float(H)
|
313 |
+
# w_step = 1.0 / float(W)
|
314 |
+
|
315 |
+
num_points = min(H * W, num_points)
|
316 |
+
point_scores, point_indices = torch.topk(uncertainty_map.view(R, H * W),
|
317 |
+
k=num_points,
|
318 |
+
dim=1)
|
319 |
+
point_coords = torch.zeros(R,
|
320 |
+
num_points,
|
321 |
+
2,
|
322 |
+
dtype=torch.long,
|
323 |
+
device=uncertainty_map.device)
|
324 |
+
# point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
|
325 |
+
# point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
|
326 |
+
point_coords[:, :, 0] = (point_indices % W).to(torch.long)
|
327 |
+
point_coords[:, :, 1] = (point_indices // W).to(torch.long)
|
328 |
+
# print (point_scores.min(), point_scores.max())
|
329 |
+
return point_indices, point_coords
|
330 |
+
|
331 |
+
|
332 |
+
def get_uncertain_point_coords_on_grid2D_faster(uncertainty_map, num_points,
|
333 |
+
clip_min):
|
334 |
+
"""
|
335 |
+
Find `num_points` most uncertain points from `uncertainty_map` grid.
|
336 |
+
Args:
|
337 |
+
uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty
|
338 |
+
values for a set of points on a regular H x W grid.
|
339 |
+
num_points (int): The number of points P to select.
|
340 |
+
Returns:
|
341 |
+
point_indices (Tensor): A tensor of shape (N, P) that contains indices from
|
342 |
+
[0, H x W) of the most uncertain points.
|
343 |
+
point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized
|
344 |
+
coordinates of the most uncertain points from the H x W grid.
|
345 |
+
"""
|
346 |
+
R, _, H, W = uncertainty_map.shape
|
347 |
+
# h_step = 1.0 / float(H)
|
348 |
+
# w_step = 1.0 / float(W)
|
349 |
+
|
350 |
+
assert R == 1, "batchsize > 1 is not implemented!"
|
351 |
+
uncertainty_map = uncertainty_map.view(H * W)
|
352 |
+
indices = (uncertainty_map >= clip_min).nonzero().squeeze(1)
|
353 |
+
num_points = min(num_points, indices.size(0))
|
354 |
+
point_scores, point_indices = torch.topk(uncertainty_map[indices],
|
355 |
+
k=num_points,
|
356 |
+
dim=0)
|
357 |
+
point_indices = indices[point_indices].unsqueeze(0)
|
358 |
+
|
359 |
+
point_coords = torch.zeros(R,
|
360 |
+
num_points,
|
361 |
+
2,
|
362 |
+
dtype=torch.long,
|
363 |
+
device=uncertainty_map.device)
|
364 |
+
# point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step
|
365 |
+
# point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step
|
366 |
+
point_coords[:, :, 0] = (point_indices % W).to(torch.long)
|
367 |
+
point_coords[:, :, 1] = (point_indices // W).to(torch.long)
|
368 |
+
# print (point_scores.min(), point_scores.max())
|
369 |
+
return point_indices, point_coords
|
370 |
+
|
371 |
+
|
372 |
+
def calculate_uncertainty(logits, classes=None, balance_value=0.5):
|
373 |
+
"""
|
374 |
+
We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
|
375 |
+
foreground class in `classes`.
|
376 |
+
Args:
|
377 |
+
logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) for class-specific or
|
378 |
+
class-agnostic, where R is the total number of predicted masks in all images and C is
|
379 |
+
the number of foreground classes. The values are logits.
|
380 |
+
classes (list): A list of length R that contains either predicted of ground truth class
|
381 |
+
for eash predicted mask.
|
382 |
+
Returns:
|
383 |
+
scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
|
384 |
+
the most uncertain locations having the highest uncertainty score.
|
385 |
+
"""
|
386 |
+
if logits.shape[1] == 1:
|
387 |
+
gt_class_logits = logits
|
388 |
+
else:
|
389 |
+
gt_class_logits = logits[
|
390 |
+
torch.arange(logits.shape[0], device=logits.device),
|
391 |
+
classes].unsqueeze(1)
|
392 |
+
return -torch.abs(gt_class_logits - balance_value)
|
lib/common/smpl_vert_segmentation.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
lib/common/train_util.py
ADDED
@@ -0,0 +1,597 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
import yaml
|
19 |
+
import os.path as osp
|
20 |
+
import torch
|
21 |
+
import numpy as np
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from ..dataset.mesh_util import *
|
24 |
+
from ..net.geometry import orthogonal
|
25 |
+
from pytorch3d.renderer.mesh import rasterize_meshes
|
26 |
+
from .render_utils import Pytorch3dRasterizer
|
27 |
+
from pytorch3d.structures import Meshes
|
28 |
+
import cv2
|
29 |
+
from PIL import Image
|
30 |
+
from tqdm import tqdm
|
31 |
+
import os
|
32 |
+
from termcolor import colored
|
33 |
+
|
34 |
+
|
35 |
+
def reshape_sample_tensor(sample_tensor, num_views):
|
36 |
+
if num_views == 1:
|
37 |
+
return sample_tensor
|
38 |
+
# Need to repeat sample_tensor along the batch dim num_views times
|
39 |
+
sample_tensor = sample_tensor.unsqueeze(dim=1)
|
40 |
+
sample_tensor = sample_tensor.repeat(1, num_views, 1, 1)
|
41 |
+
sample_tensor = sample_tensor.view(
|
42 |
+
sample_tensor.shape[0] * sample_tensor.shape[1],
|
43 |
+
sample_tensor.shape[2], sample_tensor.shape[3])
|
44 |
+
return sample_tensor
|
45 |
+
|
46 |
+
|
47 |
+
def gen_mesh_eval(opt, net, cuda, data, resolution=None):
|
48 |
+
resolution = opt.resolution if resolution is None else resolution
|
49 |
+
image_tensor = data['img'].to(device=cuda)
|
50 |
+
calib_tensor = data['calib'].to(device=cuda)
|
51 |
+
|
52 |
+
net.filter(image_tensor)
|
53 |
+
|
54 |
+
b_min = data['b_min']
|
55 |
+
b_max = data['b_max']
|
56 |
+
try:
|
57 |
+
verts, faces, _, _ = reconstruction_faster(net,
|
58 |
+
cuda,
|
59 |
+
calib_tensor,
|
60 |
+
resolution,
|
61 |
+
b_min,
|
62 |
+
b_max,
|
63 |
+
use_octree=False)
|
64 |
+
|
65 |
+
except Exception as e:
|
66 |
+
print(e)
|
67 |
+
print('Can not create marching cubes at this time.')
|
68 |
+
verts, faces = None, None
|
69 |
+
return verts, faces
|
70 |
+
|
71 |
+
|
72 |
+
def gen_mesh(opt, net, cuda, data, save_path, resolution=None):
|
73 |
+
resolution = opt.resolution if resolution is None else resolution
|
74 |
+
image_tensor = data['img'].to(device=cuda)
|
75 |
+
calib_tensor = data['calib'].to(device=cuda)
|
76 |
+
|
77 |
+
net.filter(image_tensor)
|
78 |
+
|
79 |
+
b_min = data['b_min']
|
80 |
+
b_max = data['b_max']
|
81 |
+
try:
|
82 |
+
save_img_path = save_path[:-4] + '.png'
|
83 |
+
save_img_list = []
|
84 |
+
for v in range(image_tensor.shape[0]):
|
85 |
+
save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(),
|
86 |
+
(1, 2, 0)) * 0.5 +
|
87 |
+
0.5)[:, :, ::-1] * 255.0
|
88 |
+
save_img_list.append(save_img)
|
89 |
+
save_img = np.concatenate(save_img_list, axis=1)
|
90 |
+
Image.fromarray(np.uint8(save_img[:, :, ::-1])).save(save_img_path)
|
91 |
+
|
92 |
+
verts, faces, _, _ = reconstruction_faster(net, cuda, calib_tensor,
|
93 |
+
resolution, b_min, b_max)
|
94 |
+
verts_tensor = torch.from_numpy(
|
95 |
+
verts.T).unsqueeze(0).to(device=cuda).float()
|
96 |
+
xyz_tensor = net.projection(verts_tensor, calib_tensor[:1])
|
97 |
+
uv = xyz_tensor[:, :2, :]
|
98 |
+
color = netG.index(image_tensor[:1], uv).detach().cpu().numpy()[0].T
|
99 |
+
color = color * 0.5 + 0.5
|
100 |
+
save_obj_mesh_with_color(save_path, verts, faces, color)
|
101 |
+
except Exception as e:
|
102 |
+
print(e)
|
103 |
+
print('Can not create marching cubes at this time.')
|
104 |
+
verts, faces, color = None, None, None
|
105 |
+
return verts, faces, color
|
106 |
+
|
107 |
+
|
108 |
+
def gen_mesh_color(opt, netG, netC, cuda, data, save_path, use_octree=True):
|
109 |
+
image_tensor = data['img'].to(device=cuda)
|
110 |
+
calib_tensor = data['calib'].to(device=cuda)
|
111 |
+
|
112 |
+
netG.filter(image_tensor)
|
113 |
+
netC.filter(image_tensor)
|
114 |
+
netC.attach(netG.get_im_feat())
|
115 |
+
|
116 |
+
b_min = data['b_min']
|
117 |
+
b_max = data['b_max']
|
118 |
+
try:
|
119 |
+
save_img_path = save_path[:-4] + '.png'
|
120 |
+
save_img_list = []
|
121 |
+
for v in range(image_tensor.shape[0]):
|
122 |
+
save_img = (np.transpose(image_tensor[v].detach().cpu().numpy(),
|
123 |
+
(1, 2, 0)) * 0.5 +
|
124 |
+
0.5)[:, :, ::-1] * 255.0
|
125 |
+
save_img_list.append(save_img)
|
126 |
+
save_img = np.concatenate(save_img_list, axis=1)
|
127 |
+
Image.fromarray(np.uint8(save_img[:, :, ::-1])).save(save_img_path)
|
128 |
+
|
129 |
+
verts, faces, _, _ = reconstruction_faster(netG,
|
130 |
+
cuda,
|
131 |
+
calib_tensor,
|
132 |
+
opt.resolution,
|
133 |
+
b_min,
|
134 |
+
b_max,
|
135 |
+
use_octree=use_octree)
|
136 |
+
|
137 |
+
# Now Getting colors
|
138 |
+
verts_tensor = torch.from_numpy(
|
139 |
+
verts.T).unsqueeze(0).to(device=cuda).float()
|
140 |
+
verts_tensor = reshape_sample_tensor(verts_tensor, opt.num_views)
|
141 |
+
color = np.zeros(verts.shape)
|
142 |
+
interval = 10000
|
143 |
+
for i in range(len(color) // interval):
|
144 |
+
left = i * interval
|
145 |
+
right = i * interval + interval
|
146 |
+
if i == len(color) // interval - 1:
|
147 |
+
right = -1
|
148 |
+
netC.query(verts_tensor[:, :, left:right], calib_tensor)
|
149 |
+
rgb = netC.get_preds()[0].detach().cpu().numpy() * 0.5 + 0.5
|
150 |
+
color[left:right] = rgb.T
|
151 |
+
|
152 |
+
save_obj_mesh_with_color(save_path, verts, faces, color)
|
153 |
+
except Exception as e:
|
154 |
+
print(e)
|
155 |
+
print('Can not create marching cubes at this time.')
|
156 |
+
verts, faces, color = None, None, None
|
157 |
+
return verts, faces, color
|
158 |
+
|
159 |
+
|
160 |
+
def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
|
161 |
+
"""Sets the learning rate to the initial LR decayed by schedule"""
|
162 |
+
if epoch in schedule:
|
163 |
+
lr *= gamma
|
164 |
+
for param_group in optimizer.param_groups:
|
165 |
+
param_group['lr'] = lr
|
166 |
+
return lr
|
167 |
+
|
168 |
+
|
169 |
+
def compute_acc(pred, gt, thresh=0.5):
|
170 |
+
'''
|
171 |
+
return:
|
172 |
+
IOU, precision, and recall
|
173 |
+
'''
|
174 |
+
with torch.no_grad():
|
175 |
+
vol_pred = pred > thresh
|
176 |
+
vol_gt = gt > thresh
|
177 |
+
|
178 |
+
union = vol_pred | vol_gt
|
179 |
+
inter = vol_pred & vol_gt
|
180 |
+
|
181 |
+
true_pos = inter.sum().float()
|
182 |
+
|
183 |
+
union = union.sum().float()
|
184 |
+
if union == 0:
|
185 |
+
union = 1
|
186 |
+
vol_pred = vol_pred.sum().float()
|
187 |
+
if vol_pred == 0:
|
188 |
+
vol_pred = 1
|
189 |
+
vol_gt = vol_gt.sum().float()
|
190 |
+
if vol_gt == 0:
|
191 |
+
vol_gt = 1
|
192 |
+
return true_pos / union, true_pos / vol_pred, true_pos / vol_gt
|
193 |
+
|
194 |
+
|
195 |
+
# def calc_metrics(opt, net, cuda, dataset, num_tests,
|
196 |
+
# resolution=128, sampled_points=1000, use_kaolin=True):
|
197 |
+
# if num_tests > len(dataset):
|
198 |
+
# num_tests = len(dataset)
|
199 |
+
# with torch.no_grad():
|
200 |
+
# chamfer_arr, p2s_arr = [], []
|
201 |
+
# for idx in tqdm(range(num_tests)):
|
202 |
+
# data = dataset[idx * len(dataset) // num_tests]
|
203 |
+
|
204 |
+
# verts, faces = gen_mesh_eval(opt, net, cuda, data, resolution)
|
205 |
+
# if verts is None:
|
206 |
+
# continue
|
207 |
+
|
208 |
+
# mesh_gt = trimesh.load(data['mesh_path'])
|
209 |
+
# mesh_gt = mesh_gt.split(only_watertight=False)
|
210 |
+
# comp_num = [mesh.vertices.shape[0] for mesh in mesh_gt]
|
211 |
+
# mesh_gt = mesh_gt[comp_num.index(max(comp_num))]
|
212 |
+
|
213 |
+
# mesh_pred = trimesh.Trimesh(verts, faces)
|
214 |
+
|
215 |
+
# gt_surface_pts, _ = trimesh.sample.sample_surface_even(
|
216 |
+
# mesh_gt, sampled_points)
|
217 |
+
# pred_surface_pts, _ = trimesh.sample.sample_surface_even(
|
218 |
+
# mesh_pred, sampled_points)
|
219 |
+
|
220 |
+
# if use_kaolin and has_kaolin:
|
221 |
+
# kal_mesh_gt = kal.rep.TriangleMesh.from_tensors(
|
222 |
+
# torch.tensor(mesh_gt.vertices).float().to(device=cuda),
|
223 |
+
# torch.tensor(mesh_gt.faces).long().to(device=cuda))
|
224 |
+
# kal_mesh_pred = kal.rep.TriangleMesh.from_tensors(
|
225 |
+
# torch.tensor(mesh_pred.vertices).float().to(device=cuda),
|
226 |
+
# torch.tensor(mesh_pred.faces).long().to(device=cuda))
|
227 |
+
|
228 |
+
# kal_distance_0 = kal.metrics.mesh.point_to_surface(
|
229 |
+
# torch.tensor(pred_surface_pts).float().to(device=cuda), kal_mesh_gt)
|
230 |
+
# kal_distance_1 = kal.metrics.mesh.point_to_surface(
|
231 |
+
# torch.tensor(gt_surface_pts).float().to(device=cuda), kal_mesh_pred)
|
232 |
+
|
233 |
+
# dist_gt_pred = torch.sqrt(kal_distance_0).cpu().numpy()
|
234 |
+
# dist_pred_gt = torch.sqrt(kal_distance_1).cpu().numpy()
|
235 |
+
# else:
|
236 |
+
# try:
|
237 |
+
# _, dist_pred_gt, _ = trimesh.proximity.closest_point(mesh_pred, gt_surface_pts)
|
238 |
+
# _, dist_gt_pred, _ = trimesh.proximity.closest_point(mesh_gt, pred_surface_pts)
|
239 |
+
# except Exception as e:
|
240 |
+
# print (e)
|
241 |
+
# continue
|
242 |
+
|
243 |
+
# chamfer_dist = 0.5 * (dist_pred_gt.mean() + dist_gt_pred.mean())
|
244 |
+
# p2s_dist = dist_pred_gt.mean()
|
245 |
+
|
246 |
+
# chamfer_arr.append(chamfer_dist)
|
247 |
+
# p2s_arr.append(p2s_dist)
|
248 |
+
|
249 |
+
# return np.average(chamfer_arr), np.average(p2s_arr)
|
250 |
+
|
251 |
+
|
252 |
+
def calc_error(opt, net, cuda, dataset, num_tests):
|
253 |
+
if num_tests > len(dataset):
|
254 |
+
num_tests = len(dataset)
|
255 |
+
with torch.no_grad():
|
256 |
+
erorr_arr, IOU_arr, prec_arr, recall_arr = [], [], [], []
|
257 |
+
for idx in tqdm(range(num_tests)):
|
258 |
+
data = dataset[idx * len(dataset) // num_tests]
|
259 |
+
# retrieve the data
|
260 |
+
image_tensor = data['img'].to(device=cuda)
|
261 |
+
calib_tensor = data['calib'].to(device=cuda)
|
262 |
+
sample_tensor = data['samples'].to(device=cuda).unsqueeze(0)
|
263 |
+
if opt.num_views > 1:
|
264 |
+
sample_tensor = reshape_sample_tensor(sample_tensor,
|
265 |
+
opt.num_views)
|
266 |
+
label_tensor = data['labels'].to(device=cuda).unsqueeze(0)
|
267 |
+
|
268 |
+
res, error = net.forward(image_tensor,
|
269 |
+
sample_tensor,
|
270 |
+
calib_tensor,
|
271 |
+
labels=label_tensor)
|
272 |
+
|
273 |
+
IOU, prec, recall = compute_acc(res, label_tensor)
|
274 |
+
|
275 |
+
# print(
|
276 |
+
# '{0}/{1} | Error: {2:06f} IOU: {3:06f} prec: {4:06f} recall: {5:06f}'
|
277 |
+
# .format(idx, num_tests, error.item(), IOU.item(), prec.item(), recall.item()))
|
278 |
+
erorr_arr.append(error.item())
|
279 |
+
IOU_arr.append(IOU.item())
|
280 |
+
prec_arr.append(prec.item())
|
281 |
+
recall_arr.append(recall.item())
|
282 |
+
|
283 |
+
return np.average(erorr_arr), np.average(IOU_arr), np.average(
|
284 |
+
prec_arr), np.average(recall_arr)
|
285 |
+
|
286 |
+
|
287 |
+
def calc_error_color(opt, netG, netC, cuda, dataset, num_tests):
|
288 |
+
if num_tests > len(dataset):
|
289 |
+
num_tests = len(dataset)
|
290 |
+
with torch.no_grad():
|
291 |
+
error_color_arr = []
|
292 |
+
|
293 |
+
for idx in tqdm(range(num_tests)):
|
294 |
+
data = dataset[idx * len(dataset) // num_tests]
|
295 |
+
# retrieve the data
|
296 |
+
image_tensor = data['img'].to(device=cuda)
|
297 |
+
calib_tensor = data['calib'].to(device=cuda)
|
298 |
+
color_sample_tensor = data['color_samples'].to(
|
299 |
+
device=cuda).unsqueeze(0)
|
300 |
+
|
301 |
+
if opt.num_views > 1:
|
302 |
+
color_sample_tensor = reshape_sample_tensor(
|
303 |
+
color_sample_tensor, opt.num_views)
|
304 |
+
|
305 |
+
rgb_tensor = data['rgbs'].to(device=cuda).unsqueeze(0)
|
306 |
+
|
307 |
+
netG.filter(image_tensor)
|
308 |
+
_, errorC = netC.forward(image_tensor,
|
309 |
+
netG.get_im_feat(),
|
310 |
+
color_sample_tensor,
|
311 |
+
calib_tensor,
|
312 |
+
labels=rgb_tensor)
|
313 |
+
|
314 |
+
# print('{0}/{1} | Error inout: {2:06f} | Error color: {3:06f}'
|
315 |
+
# .format(idx, num_tests, errorG.item(), errorC.item()))
|
316 |
+
error_color_arr.append(errorC.item())
|
317 |
+
|
318 |
+
return np.average(error_color_arr)
|
319 |
+
|
320 |
+
|
321 |
+
# pytorch lightning training related fucntions
|
322 |
+
|
323 |
+
|
324 |
+
def query_func(opt, netG, features, points, proj_matrix=None):
|
325 |
+
'''
|
326 |
+
- points: size of (bz, N, 3)
|
327 |
+
- proj_matrix: size of (bz, 4, 4)
|
328 |
+
return: size of (bz, 1, N)
|
329 |
+
'''
|
330 |
+
assert len(points) == 1
|
331 |
+
samples = points.repeat(opt.num_views, 1, 1)
|
332 |
+
samples = samples.permute(0, 2, 1) # [bz, 3, N]
|
333 |
+
|
334 |
+
# view specific query
|
335 |
+
if proj_matrix is not None:
|
336 |
+
samples = orthogonal(samples, proj_matrix)
|
337 |
+
|
338 |
+
calib_tensor = torch.stack([torch.eye(4).float()], dim=0).type_as(samples)
|
339 |
+
|
340 |
+
preds = netG.query(features=features,
|
341 |
+
points=samples,
|
342 |
+
calibs=calib_tensor,
|
343 |
+
regressor=netG.if_regressor)
|
344 |
+
|
345 |
+
if type(preds) is list:
|
346 |
+
preds = preds[0]
|
347 |
+
|
348 |
+
return preds
|
349 |
+
|
350 |
+
|
351 |
+
def isin(ar1, ar2):
|
352 |
+
return (ar1[..., None] == ar2).any(-1)
|
353 |
+
|
354 |
+
|
355 |
+
def in1d(ar1, ar2):
|
356 |
+
mask = ar2.new_zeros((max(ar1.max(), ar2.max()) + 1, ), dtype=torch.bool)
|
357 |
+
mask[ar2.unique()] = True
|
358 |
+
return mask[ar1]
|
359 |
+
|
360 |
+
|
361 |
+
def get_visibility(xy, z, faces):
|
362 |
+
"""get the visibility of vertices
|
363 |
+
|
364 |
+
Args:
|
365 |
+
xy (torch.tensor): [N,2]
|
366 |
+
z (torch.tensor): [N,1]
|
367 |
+
faces (torch.tensor): [N,3]
|
368 |
+
size (int): resolution of rendered image
|
369 |
+
"""
|
370 |
+
|
371 |
+
xyz = torch.cat((xy, -z), dim=1)
|
372 |
+
xyz = (xyz + 1.0) / 2.0
|
373 |
+
faces = faces.long()
|
374 |
+
|
375 |
+
rasterizer = Pytorch3dRasterizer(image_size=2**12)
|
376 |
+
meshes_screen = Meshes(verts=xyz[None, ...], faces=faces[None, ...])
|
377 |
+
raster_settings = rasterizer.raster_settings
|
378 |
+
|
379 |
+
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
|
380 |
+
meshes_screen,
|
381 |
+
image_size=raster_settings.image_size,
|
382 |
+
blur_radius=raster_settings.blur_radius,
|
383 |
+
faces_per_pixel=raster_settings.faces_per_pixel,
|
384 |
+
bin_size=raster_settings.bin_size,
|
385 |
+
max_faces_per_bin=raster_settings.max_faces_per_bin,
|
386 |
+
perspective_correct=raster_settings.perspective_correct,
|
387 |
+
cull_backfaces=raster_settings.cull_backfaces,
|
388 |
+
)
|
389 |
+
|
390 |
+
vis_vertices_id = torch.unique(faces[torch.unique(pix_to_face), :])
|
391 |
+
vis_mask = torch.zeros(size=(z.shape[0], 1))
|
392 |
+
vis_mask[vis_vertices_id] = 1.0
|
393 |
+
|
394 |
+
# print("------------------------\n")
|
395 |
+
# print(f"keep points : {vis_mask.sum()/len(vis_mask)}")
|
396 |
+
|
397 |
+
return vis_mask
|
398 |
+
|
399 |
+
|
400 |
+
def batch_mean(res, key):
|
401 |
+
# recursive mean for multilevel dicts
|
402 |
+
return torch.stack([
|
403 |
+
x[key] if isinstance(x, dict) else batch_mean(x, key) for x in res
|
404 |
+
]).mean()
|
405 |
+
|
406 |
+
|
407 |
+
def tf_log_convert(log_dict):
|
408 |
+
new_log_dict = log_dict.copy()
|
409 |
+
for k, v in log_dict.items():
|
410 |
+
new_log_dict[k.replace("_", "/")] = v
|
411 |
+
del new_log_dict[k]
|
412 |
+
|
413 |
+
return new_log_dict
|
414 |
+
|
415 |
+
|
416 |
+
def bar_log_convert(log_dict, name=None, rot=None):
|
417 |
+
from decimal import Decimal
|
418 |
+
|
419 |
+
new_log_dict = {}
|
420 |
+
|
421 |
+
if name is not None:
|
422 |
+
new_log_dict['name'] = name[0]
|
423 |
+
if rot is not None:
|
424 |
+
new_log_dict['rot'] = rot[0]
|
425 |
+
|
426 |
+
for k, v in log_dict.items():
|
427 |
+
color = "yellow"
|
428 |
+
if 'loss' in k:
|
429 |
+
color = "red"
|
430 |
+
k = k.replace("loss", "L")
|
431 |
+
elif 'acc' in k:
|
432 |
+
color = "green"
|
433 |
+
k = k.replace("acc", "A")
|
434 |
+
elif 'iou' in k:
|
435 |
+
color = "green"
|
436 |
+
k = k.replace("iou", "I")
|
437 |
+
elif 'prec' in k:
|
438 |
+
color = "green"
|
439 |
+
k = k.replace("prec", "P")
|
440 |
+
elif 'recall' in k:
|
441 |
+
color = "green"
|
442 |
+
k = k.replace("recall", "R")
|
443 |
+
|
444 |
+
if 'lr' not in k:
|
445 |
+
new_log_dict[colored(k.split("_")[1],
|
446 |
+
color)] = colored(f"{v:.3f}", color)
|
447 |
+
else:
|
448 |
+
new_log_dict[colored(k.split("_")[1],
|
449 |
+
color)] = colored(f"{Decimal(str(v)):.1E}",
|
450 |
+
color)
|
451 |
+
|
452 |
+
if 'loss' in new_log_dict.keys():
|
453 |
+
del new_log_dict['loss']
|
454 |
+
|
455 |
+
return new_log_dict
|
456 |
+
|
457 |
+
|
458 |
+
def accumulate(outputs, rot_num, split):
|
459 |
+
|
460 |
+
hparam_log_dict = {}
|
461 |
+
|
462 |
+
metrics = outputs[0].keys()
|
463 |
+
datasets = split.keys()
|
464 |
+
|
465 |
+
for dataset in datasets:
|
466 |
+
for metric in metrics:
|
467 |
+
keyword = f"hparam/{dataset}-{metric}"
|
468 |
+
if keyword not in hparam_log_dict.keys():
|
469 |
+
hparam_log_dict[keyword] = 0
|
470 |
+
for idx in range(split[dataset][0] * rot_num,
|
471 |
+
split[dataset][1] * rot_num):
|
472 |
+
hparam_log_dict[keyword] += outputs[idx][metric]
|
473 |
+
hparam_log_dict[keyword] /= (split[dataset][1] -
|
474 |
+
split[dataset][0]) * rot_num
|
475 |
+
|
476 |
+
print(colored(hparam_log_dict, "green"))
|
477 |
+
|
478 |
+
return hparam_log_dict
|
479 |
+
|
480 |
+
|
481 |
+
def calc_error_N(outputs, targets):
|
482 |
+
"""calculate the error of normal (IGR)
|
483 |
+
|
484 |
+
Args:
|
485 |
+
outputs (torch.tensor): [B, 3, N]
|
486 |
+
target (torch.tensor): [B, N, 3]
|
487 |
+
|
488 |
+
# manifold loss and grad_loss in IGR paper
|
489 |
+
grad_loss = ((nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
|
490 |
+
normals_loss = ((mnfld_grad - normals).abs()).norm(2, dim=1).mean()
|
491 |
+
|
492 |
+
Returns:
|
493 |
+
torch.tensor: error of valid normals on the surface
|
494 |
+
"""
|
495 |
+
# outputs = torch.tanh(-outputs.permute(0,2,1).reshape(-1,3))
|
496 |
+
outputs = -outputs.permute(0, 2, 1).reshape(-1, 1)
|
497 |
+
targets = targets.reshape(-1, 3)[:, 2:3]
|
498 |
+
with_normals = targets.sum(dim=1).abs() > 0.0
|
499 |
+
|
500 |
+
# eikonal loss
|
501 |
+
grad_loss = ((outputs[with_normals].norm(2, dim=-1) - 1)**2).mean()
|
502 |
+
# normals loss
|
503 |
+
normal_loss = (outputs - targets)[with_normals].abs().norm(2, dim=1).mean()
|
504 |
+
|
505 |
+
return grad_loss * 0.0 + normal_loss
|
506 |
+
|
507 |
+
|
508 |
+
def calc_knn_acc(preds, carn_verts, labels, pick_num):
|
509 |
+
"""calculate knn accuracy
|
510 |
+
|
511 |
+
Args:
|
512 |
+
preds (torch.tensor): [B, 3, N]
|
513 |
+
carn_verts (torch.tensor): [SMPLX_V_num, 3]
|
514 |
+
labels (torch.tensor): [B, N_knn, N]
|
515 |
+
"""
|
516 |
+
N_knn_full = labels.shape[1]
|
517 |
+
preds = preds.permute(0, 2, 1).reshape(-1, 3)
|
518 |
+
labels = labels.permute(0, 2, 1).reshape(-1, N_knn_full) # [BxN, num_knn]
|
519 |
+
labels = labels[:, :pick_num]
|
520 |
+
|
521 |
+
dist = torch.cdist(preds, carn_verts, p=2) # [BxN, SMPL_V_num]
|
522 |
+
knn = dist.topk(k=pick_num, dim=1, largest=False)[1] # [BxN, num_knn]
|
523 |
+
cat_mat = torch.sort(torch.cat((knn, labels), dim=1))[0]
|
524 |
+
bool_col = torch.zeros_like(cat_mat)[:, 0]
|
525 |
+
for i in range(pick_num * 2 - 1):
|
526 |
+
bool_col += cat_mat[:, i] == cat_mat[:, i + 1]
|
527 |
+
acc = (bool_col > 0).sum() / len(bool_col)
|
528 |
+
|
529 |
+
return acc
|
530 |
+
|
531 |
+
|
532 |
+
def calc_acc_seg(output, target, num_multiseg):
|
533 |
+
from pytorch_lightning.metrics import Accuracy
|
534 |
+
return Accuracy()(output.reshape(-1, num_multiseg).cpu(),
|
535 |
+
target.flatten().cpu())
|
536 |
+
|
537 |
+
|
538 |
+
def add_watermark(imgs, titles):
|
539 |
+
|
540 |
+
# Write some Text
|
541 |
+
|
542 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
543 |
+
bottomLeftCornerOfText = (350, 50)
|
544 |
+
bottomRightCornerOfText = (800, 50)
|
545 |
+
fontScale = 1
|
546 |
+
fontColor = (1.0, 1.0, 1.0)
|
547 |
+
lineType = 2
|
548 |
+
|
549 |
+
for i in range(len(imgs)):
|
550 |
+
|
551 |
+
title = titles[i + 1]
|
552 |
+
cv2.putText(imgs[i], title, bottomLeftCornerOfText, font, fontScale,
|
553 |
+
fontColor, lineType)
|
554 |
+
|
555 |
+
if i == 0:
|
556 |
+
cv2.putText(imgs[i], str(titles[i][0]), bottomRightCornerOfText,
|
557 |
+
font, fontScale, fontColor, lineType)
|
558 |
+
|
559 |
+
result = np.concatenate(imgs, axis=0).transpose(2, 0, 1)
|
560 |
+
|
561 |
+
return result
|
562 |
+
|
563 |
+
|
564 |
+
def make_test_gif(img_dir):
|
565 |
+
|
566 |
+
if img_dir is not None and len(os.listdir(img_dir)) > 0:
|
567 |
+
for dataset in os.listdir(img_dir):
|
568 |
+
for subject in sorted(os.listdir(osp.join(img_dir, dataset))):
|
569 |
+
img_lst = []
|
570 |
+
im1 = None
|
571 |
+
for file in sorted(
|
572 |
+
os.listdir(osp.join(img_dir, dataset, subject))):
|
573 |
+
if file[-3:] not in ['obj', 'gif']:
|
574 |
+
img_path = os.path.join(img_dir, dataset, subject,
|
575 |
+
file)
|
576 |
+
if im1 == None:
|
577 |
+
im1 = Image.open(img_path)
|
578 |
+
else:
|
579 |
+
img_lst.append(Image.open(img_path))
|
580 |
+
|
581 |
+
print(os.path.join(img_dir, dataset, subject, "out.gif"))
|
582 |
+
im1.save(os.path.join(img_dir, dataset, subject, "out.gif"),
|
583 |
+
save_all=True,
|
584 |
+
append_images=img_lst,
|
585 |
+
duration=500,
|
586 |
+
loop=0)
|
587 |
+
|
588 |
+
|
589 |
+
def export_cfg(logger, cfg):
|
590 |
+
|
591 |
+
cfg_export_file = osp.join(logger.save_dir, logger.name,
|
592 |
+
f"version_{logger.version}", "cfg.yaml")
|
593 |
+
|
594 |
+
if not osp.exists(cfg_export_file):
|
595 |
+
os.makedirs(osp.dirname(cfg_export_file), exist_ok=True)
|
596 |
+
with open(cfg_export_file, "w+") as file:
|
597 |
+
_ = yaml.dump(cfg, file)
|
lib/dataloader_demo.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from lib.common.config import get_cfg_defaults
|
3 |
+
from lib.dataset.PIFuDataset import PIFuDataset
|
4 |
+
|
5 |
+
if __name__ == '__main__':
|
6 |
+
|
7 |
+
parser = argparse.ArgumentParser()
|
8 |
+
parser.add_argument('-v',
|
9 |
+
'--show',
|
10 |
+
action='store_true',
|
11 |
+
help='vis sampler 3D')
|
12 |
+
parser.add_argument('-s',
|
13 |
+
'--speed',
|
14 |
+
action='store_true',
|
15 |
+
help='vis sampler 3D')
|
16 |
+
parser.add_argument('-l',
|
17 |
+
'--list',
|
18 |
+
action='store_true',
|
19 |
+
help='vis sampler 3D')
|
20 |
+
parser.add_argument('-c',
|
21 |
+
'--config',
|
22 |
+
default='./configs/train/icon-filter.yaml',
|
23 |
+
help='vis sampler 3D')
|
24 |
+
args_c = parser.parse_args()
|
25 |
+
|
26 |
+
args = get_cfg_defaults()
|
27 |
+
args.merge_from_file(args_c.config)
|
28 |
+
|
29 |
+
dataset = PIFuDataset(args, split='train', vis=args_c.show)
|
30 |
+
print(f"Number of subjects :{len(dataset.subject_list)}")
|
31 |
+
data_dict = dataset[0]
|
32 |
+
|
33 |
+
if args_c.list:
|
34 |
+
for k in data_dict.keys():
|
35 |
+
if not hasattr(data_dict[k], "shape"):
|
36 |
+
print(f"{k}: {data_dict[k]}")
|
37 |
+
else:
|
38 |
+
print(f"{k}: {data_dict[k].shape}")
|
39 |
+
|
40 |
+
if args_c.show:
|
41 |
+
# for item in dataset:
|
42 |
+
item = dataset[0]
|
43 |
+
dataset.visualize_sampling3D(item, mode='occ')
|
44 |
+
|
45 |
+
if args_c.speed:
|
46 |
+
# original: 2 it/s
|
47 |
+
# smpl online compute: 2 it/s
|
48 |
+
# normal online compute: 1.5 it/s
|
49 |
+
from tqdm import tqdm
|
50 |
+
for item in tqdm(dataset):
|
51 |
+
# pass
|
52 |
+
for k in item.keys():
|
53 |
+
if 'voxel' in k:
|
54 |
+
if not hasattr(item[k], "shape"):
|
55 |
+
print(f"{k}: {item[k]}")
|
56 |
+
else:
|
57 |
+
print(f"{k}: {item[k].shape}")
|
58 |
+
print("--------------------")
|
lib/dataset/Evaluator.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
from lib.renderer.gl.normal_render import NormalRender
|
19 |
+
from lib.dataset.mesh_util import projection
|
20 |
+
from lib.common.render import Render
|
21 |
+
from PIL import Image
|
22 |
+
import os
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
from torch import nn
|
26 |
+
import trimesh
|
27 |
+
import os.path as osp
|
28 |
+
from PIL import Image
|
29 |
+
|
30 |
+
|
31 |
+
class Evaluator:
|
32 |
+
|
33 |
+
_normal_render = None
|
34 |
+
|
35 |
+
@staticmethod
|
36 |
+
def init_gl():
|
37 |
+
Evaluator._normal_render = NormalRender(width=512, height=512)
|
38 |
+
|
39 |
+
def __init__(self, device):
|
40 |
+
self.device = device
|
41 |
+
self.render = Render(size=512, device=self.device)
|
42 |
+
self.error_term = nn.MSELoss()
|
43 |
+
|
44 |
+
self.offset = 0.0
|
45 |
+
self.scale_factor = None
|
46 |
+
|
47 |
+
def set_mesh(self, result_dict, scale_factor=1.0, offset=0.0):
|
48 |
+
|
49 |
+
for key in result_dict.keys():
|
50 |
+
if torch.is_tensor(result_dict[key]):
|
51 |
+
result_dict[key] = result_dict[key].detach().cpu().numpy()
|
52 |
+
|
53 |
+
for k, v in result_dict.items():
|
54 |
+
setattr(self, k, v)
|
55 |
+
|
56 |
+
self.scale_factor = scale_factor
|
57 |
+
self.offset = offset
|
58 |
+
|
59 |
+
def _render_normal(self, mesh, deg, norms=None):
|
60 |
+
view_mat = np.identity(4)
|
61 |
+
rz = deg / 180.0 * np.pi
|
62 |
+
model_mat = np.identity(4)
|
63 |
+
model_mat[:3, :3] = self._normal_render.euler_to_rot_mat(0, rz, 0)
|
64 |
+
model_mat[1, 3] = self.offset
|
65 |
+
view_mat[2, 2] *= -1
|
66 |
+
|
67 |
+
self._normal_render.set_matrices(view_mat, model_mat)
|
68 |
+
if norms is None:
|
69 |
+
norms = mesh.vertex_normals
|
70 |
+
self._normal_render.set_normal_mesh(self.scale_factor * mesh.vertices,
|
71 |
+
mesh.faces, norms, mesh.faces)
|
72 |
+
self._normal_render.draw()
|
73 |
+
normal_img = self._normal_render.get_color()
|
74 |
+
return normal_img
|
75 |
+
|
76 |
+
def render_mesh_list(self, mesh_lst):
|
77 |
+
|
78 |
+
self.offset = 0.0
|
79 |
+
self.scale_factor = 1.0
|
80 |
+
|
81 |
+
full_list = []
|
82 |
+
for mesh in mesh_lst:
|
83 |
+
row_lst = []
|
84 |
+
for deg in np.arange(0, 360, 90):
|
85 |
+
normal = self._render_normal(mesh, deg)
|
86 |
+
row_lst.append(normal)
|
87 |
+
full_list.append(np.concatenate(row_lst, axis=1))
|
88 |
+
|
89 |
+
res_array = np.concatenate(full_list, axis=0)
|
90 |
+
|
91 |
+
return res_array
|
92 |
+
|
93 |
+
def _get_reproj_normal_error(self, deg):
|
94 |
+
|
95 |
+
tgt_normal = self._render_normal(self.tgt_mesh, deg)
|
96 |
+
src_normal = self._render_normal(self.src_mesh, deg)
|
97 |
+
error = (((src_normal[:, :, :3] -
|
98 |
+
tgt_normal[:, :, :3])**2).sum(axis=2).mean(axis=(0, 1)))
|
99 |
+
|
100 |
+
return error, [src_normal, tgt_normal]
|
101 |
+
|
102 |
+
def render_normal(self, verts, faces):
|
103 |
+
|
104 |
+
verts = verts[0].detach().cpu().numpy()
|
105 |
+
faces = faces[0].detach().cpu().numpy()
|
106 |
+
|
107 |
+
mesh_F = trimesh.Trimesh(verts * np.array([1.0, -1.0, 1.0]), faces)
|
108 |
+
mesh_B = trimesh.Trimesh(verts * np.array([1.0, -1.0, -1.0]), faces)
|
109 |
+
|
110 |
+
self.scale_factor = 1.0
|
111 |
+
|
112 |
+
normal_F = self._render_normal(mesh_F, 0)
|
113 |
+
normal_B = self._render_normal(mesh_B,
|
114 |
+
0,
|
115 |
+
norms=mesh_B.vertex_normals *
|
116 |
+
np.array([-1.0, -1.0, 1.0]))
|
117 |
+
|
118 |
+
mask = normal_F[:, :, 3:4]
|
119 |
+
normal_F = (torch.as_tensor(2.0 * (normal_F - 0.5) * mask).permute(
|
120 |
+
2, 0, 1)[:3, :, :].float().unsqueeze(0).to(self.device))
|
121 |
+
normal_B = (torch.as_tensor(2.0 * (normal_B - 0.5) * mask).permute(
|
122 |
+
2, 0, 1)[:3, :, :].float().unsqueeze(0).to(self.device))
|
123 |
+
|
124 |
+
return {"T_normal_F": normal_F, "T_normal_B": normal_B}
|
125 |
+
|
126 |
+
def calculate_normal_consist(
|
127 |
+
self,
|
128 |
+
frontal=True,
|
129 |
+
back=True,
|
130 |
+
left=True,
|
131 |
+
right=True,
|
132 |
+
save_demo_img=None,
|
133 |
+
return_demo=False,
|
134 |
+
):
|
135 |
+
|
136 |
+
# reproj error
|
137 |
+
# if save_demo_img is not None, save a visualization at the given path (etc, "./test.png")
|
138 |
+
if self._normal_render is None:
|
139 |
+
print(
|
140 |
+
"In order to use normal render, "
|
141 |
+
"you have to call init_gl() before initialing any evaluator objects."
|
142 |
+
)
|
143 |
+
return -1
|
144 |
+
|
145 |
+
side_cnt = 0
|
146 |
+
total_error = 0
|
147 |
+
demo_list = []
|
148 |
+
|
149 |
+
if frontal:
|
150 |
+
side_cnt += 1
|
151 |
+
error, normal_lst = self._get_reproj_normal_error(0)
|
152 |
+
total_error += error
|
153 |
+
demo_list.append(np.concatenate(normal_lst, axis=0))
|
154 |
+
if back:
|
155 |
+
side_cnt += 1
|
156 |
+
error, normal_lst = self._get_reproj_normal_error(180)
|
157 |
+
total_error += error
|
158 |
+
demo_list.append(np.concatenate(normal_lst, axis=0))
|
159 |
+
if left:
|
160 |
+
side_cnt += 1
|
161 |
+
error, normal_lst = self._get_reproj_normal_error(90)
|
162 |
+
total_error += error
|
163 |
+
demo_list.append(np.concatenate(normal_lst, axis=0))
|
164 |
+
if right:
|
165 |
+
side_cnt += 1
|
166 |
+
error, normal_lst = self._get_reproj_normal_error(270)
|
167 |
+
total_error += error
|
168 |
+
demo_list.append(np.concatenate(normal_lst, axis=0))
|
169 |
+
if save_demo_img is not None:
|
170 |
+
res_array = np.concatenate(demo_list, axis=1)
|
171 |
+
res_img = Image.fromarray((res_array * 255).astype(np.uint8))
|
172 |
+
res_img.save(save_demo_img)
|
173 |
+
|
174 |
+
if return_demo:
|
175 |
+
res_array = np.concatenate(demo_list, axis=1)
|
176 |
+
return res_array
|
177 |
+
else:
|
178 |
+
return total_error
|
179 |
+
|
180 |
+
def space_transfer(self):
|
181 |
+
|
182 |
+
# convert from GT to SDF
|
183 |
+
self.verts_pr -= self.recon_size / 2.0
|
184 |
+
self.verts_pr /= self.recon_size / 2.0
|
185 |
+
|
186 |
+
self.verts_gt = projection(self.verts_gt, self.calib)
|
187 |
+
self.verts_gt[:, 1] *= -1
|
188 |
+
|
189 |
+
self.tgt_mesh = trimesh.Trimesh(self.verts_gt, self.faces_gt)
|
190 |
+
self.src_mesh = trimesh.Trimesh(self.verts_pr, self.faces_pr)
|
191 |
+
|
192 |
+
# (self.tgt_mesh+self.src_mesh).show()
|
193 |
+
|
194 |
+
def export_mesh(self, dir, name):
|
195 |
+
self.tgt_mesh.visual.vertex_colors = np.array([255, 0, 0])
|
196 |
+
self.src_mesh.visual.vertex_colors = np.array([0, 255, 0])
|
197 |
+
|
198 |
+
(self.tgt_mesh + self.src_mesh).export(
|
199 |
+
osp.join(dir, f"{name}_gt_pr.obj"))
|
200 |
+
|
201 |
+
def calculate_chamfer_p2s(self, sampled_points=1000):
|
202 |
+
"""calculate the geometry metrics [chamfer, p2s, chamfer_H, p2s_H]
|
203 |
+
|
204 |
+
Args:
|
205 |
+
verts_gt (torch.cuda.tensor): [N, 3]
|
206 |
+
faces_gt (torch.cuda.tensor): [M, 3]
|
207 |
+
verts_pr (torch.cuda.tensor): [N', 3]
|
208 |
+
faces_pr (torch.cuda.tensor): [M', 3]
|
209 |
+
sampled_points (int, optional): use smaller number for faster testing. Defaults to 1000.
|
210 |
+
|
211 |
+
Returns:
|
212 |
+
tuple: chamfer, p2s, chamfer_H, p2s_H
|
213 |
+
"""
|
214 |
+
|
215 |
+
gt_surface_pts, _ = trimesh.sample.sample_surface_even(
|
216 |
+
self.tgt_mesh, sampled_points)
|
217 |
+
pred_surface_pts, _ = trimesh.sample.sample_surface_even(
|
218 |
+
self.src_mesh, sampled_points)
|
219 |
+
|
220 |
+
_, dist_pred_gt, _ = trimesh.proximity.closest_point(
|
221 |
+
self.src_mesh, gt_surface_pts)
|
222 |
+
_, dist_gt_pred, _ = trimesh.proximity.closest_point(
|
223 |
+
self.tgt_mesh, pred_surface_pts)
|
224 |
+
|
225 |
+
dist_pred_gt[np.isnan(dist_pred_gt)] = 0
|
226 |
+
dist_gt_pred[np.isnan(dist_gt_pred)] = 0
|
227 |
+
chamfer_dist = 0.5 * (dist_pred_gt.mean() +
|
228 |
+
dist_gt_pred.mean()).item() * 100
|
229 |
+
p2s_dist = dist_pred_gt.mean().item() * 100
|
230 |
+
|
231 |
+
return chamfer_dist, p2s_dist
|
232 |
+
|
233 |
+
def calc_acc(self, output, target, thres=0.5, use_sdf=False):
|
234 |
+
|
235 |
+
# # remove the surface points with thres
|
236 |
+
# non_surf_ids = (target != thres)
|
237 |
+
# output = output[non_surf_ids]
|
238 |
+
# target = target[non_surf_ids]
|
239 |
+
|
240 |
+
with torch.no_grad():
|
241 |
+
output = output.masked_fill(output < thres, 0.0)
|
242 |
+
output = output.masked_fill(output > thres, 1.0)
|
243 |
+
|
244 |
+
if use_sdf:
|
245 |
+
target = target.masked_fill(target < thres, 0.0)
|
246 |
+
target = target.masked_fill(target > thres, 1.0)
|
247 |
+
|
248 |
+
acc = output.eq(target).float().mean()
|
249 |
+
|
250 |
+
# iou, precison, recall
|
251 |
+
output = output > thres
|
252 |
+
target = target > thres
|
253 |
+
|
254 |
+
union = output | target
|
255 |
+
inter = output & target
|
256 |
+
|
257 |
+
_max = torch.tensor(1.0).to(output.device)
|
258 |
+
|
259 |
+
union = max(union.sum().float(), _max)
|
260 |
+
true_pos = max(inter.sum().float(), _max)
|
261 |
+
vol_pred = max(output.sum().float(), _max)
|
262 |
+
vol_gt = max(target.sum().float(), _max)
|
263 |
+
|
264 |
+
return acc, true_pos / union, true_pos / vol_pred, true_pos / vol_gt
|
lib/dataset/NormalDataset.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
import os.path as osp
|
19 |
+
import numpy as np
|
20 |
+
from PIL import Image
|
21 |
+
import torchvision.transforms as transforms
|
22 |
+
|
23 |
+
|
24 |
+
class NormalDataset():
|
25 |
+
def __init__(self, cfg, split='train'):
|
26 |
+
|
27 |
+
self.split = split
|
28 |
+
self.root = cfg.root
|
29 |
+
self.overfit = cfg.overfit
|
30 |
+
|
31 |
+
self.opt = cfg.dataset
|
32 |
+
self.datasets = self.opt.types
|
33 |
+
self.input_size = self.opt.input_size
|
34 |
+
self.set_splits = self.opt.set_splits
|
35 |
+
self.scales = self.opt.scales
|
36 |
+
self.pifu = self.opt.pifu
|
37 |
+
|
38 |
+
# input data types and dimensions
|
39 |
+
self.in_nml = [item[0] for item in cfg.net.in_nml]
|
40 |
+
self.in_nml_dim = [item[1] for item in cfg.net.in_nml]
|
41 |
+
self.in_total = self.in_nml + ['normal_F', 'normal_B']
|
42 |
+
self.in_total_dim = self.in_nml_dim + [3, 3]
|
43 |
+
|
44 |
+
if self.split != 'train':
|
45 |
+
self.rotations = range(0, 360, 120)
|
46 |
+
else:
|
47 |
+
self.rotations = np.arange(0, 360, 360 /
|
48 |
+
self.opt.rotation_num).astype(np.int)
|
49 |
+
|
50 |
+
self.datasets_dict = {}
|
51 |
+
for dataset_id, dataset in enumerate(self.datasets):
|
52 |
+
dataset_dir = osp.join(self.root, dataset, "smplx")
|
53 |
+
self.datasets_dict[dataset] = {
|
54 |
+
"subjects":
|
55 |
+
np.loadtxt(osp.join(self.root, dataset, "all.txt"), dtype=str),
|
56 |
+
"path":
|
57 |
+
dataset_dir,
|
58 |
+
"scale":
|
59 |
+
self.scales[dataset_id]
|
60 |
+
}
|
61 |
+
|
62 |
+
self.subject_list = self.get_subject_list(split)
|
63 |
+
|
64 |
+
# PIL to tensor
|
65 |
+
self.image_to_tensor = transforms.Compose([
|
66 |
+
transforms.Resize(self.input_size),
|
67 |
+
transforms.ToTensor(),
|
68 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
69 |
+
])
|
70 |
+
|
71 |
+
# PIL to tensor
|
72 |
+
self.mask_to_tensor = transforms.Compose([
|
73 |
+
transforms.Resize(self.input_size),
|
74 |
+
transforms.ToTensor(),
|
75 |
+
transforms.Normalize((0.0, ), (1.0, ))
|
76 |
+
])
|
77 |
+
|
78 |
+
def get_subject_list(self, split):
|
79 |
+
|
80 |
+
subject_list = []
|
81 |
+
|
82 |
+
for dataset in self.datasets:
|
83 |
+
|
84 |
+
if self.pifu:
|
85 |
+
txt = osp.join(self.root, dataset, f'{split}_pifu.txt')
|
86 |
+
else:
|
87 |
+
txt = osp.join(self.root, dataset, f'{split}.txt')
|
88 |
+
|
89 |
+
if osp.exists(txt):
|
90 |
+
print(f"load from {txt}")
|
91 |
+
subject_list += sorted(np.loadtxt(txt, dtype=str).tolist())
|
92 |
+
|
93 |
+
if self.pifu:
|
94 |
+
miss_pifu = sorted(
|
95 |
+
np.loadtxt(osp.join(self.root, dataset,
|
96 |
+
"miss_pifu.txt"),
|
97 |
+
dtype=str).tolist())
|
98 |
+
subject_list = [
|
99 |
+
subject for subject in subject_list
|
100 |
+
if subject not in miss_pifu
|
101 |
+
]
|
102 |
+
subject_list = [
|
103 |
+
"renderpeople/" + subject for subject in subject_list
|
104 |
+
]
|
105 |
+
|
106 |
+
else:
|
107 |
+
train_txt = osp.join(self.root, dataset, 'train.txt')
|
108 |
+
val_txt = osp.join(self.root, dataset, 'val.txt')
|
109 |
+
test_txt = osp.join(self.root, dataset, 'test.txt')
|
110 |
+
|
111 |
+
print(
|
112 |
+
f"generate lists of [train, val, test] \n {train_txt} \n {val_txt} \n {test_txt} \n"
|
113 |
+
)
|
114 |
+
|
115 |
+
split_txt = osp.join(self.root, dataset, f'{split}.txt')
|
116 |
+
|
117 |
+
subjects = self.datasets_dict[dataset]['subjects']
|
118 |
+
train_split = int(len(subjects) * self.set_splits[0])
|
119 |
+
val_split = int(
|
120 |
+
len(subjects) * self.set_splits[1]) + train_split
|
121 |
+
|
122 |
+
with open(train_txt, "w") as f:
|
123 |
+
f.write("\n".join(dataset + "/" + item
|
124 |
+
for item in subjects[:train_split]))
|
125 |
+
with open(val_txt, "w") as f:
|
126 |
+
f.write("\n".join(
|
127 |
+
dataset + "/" + item
|
128 |
+
for item in subjects[train_split:val_split]))
|
129 |
+
with open(test_txt, "w") as f:
|
130 |
+
f.write("\n".join(dataset + "/" + item
|
131 |
+
for item in subjects[val_split:]))
|
132 |
+
|
133 |
+
subject_list += sorted(
|
134 |
+
np.loadtxt(split_txt, dtype=str).tolist())
|
135 |
+
|
136 |
+
bug_list = sorted(
|
137 |
+
np.loadtxt(osp.join(self.root, 'bug.txt'), dtype=str).tolist())
|
138 |
+
|
139 |
+
subject_list = [
|
140 |
+
subject for subject in subject_list if (subject not in bug_list)
|
141 |
+
]
|
142 |
+
|
143 |
+
return subject_list
|
144 |
+
|
145 |
+
def __len__(self):
|
146 |
+
return len(self.subject_list) * len(self.rotations)
|
147 |
+
|
148 |
+
def __getitem__(self, index):
|
149 |
+
|
150 |
+
# only pick the first data if overfitting
|
151 |
+
if self.overfit:
|
152 |
+
index = 0
|
153 |
+
|
154 |
+
rid = index % len(self.rotations)
|
155 |
+
mid = index // len(self.rotations)
|
156 |
+
|
157 |
+
rotation = self.rotations[rid]
|
158 |
+
|
159 |
+
# choose specific test sets
|
160 |
+
subject = self.subject_list[mid]
|
161 |
+
|
162 |
+
subject_render = "/".join(
|
163 |
+
[subject.split("/")[0] + "_12views",
|
164 |
+
subject.split("/")[1]])
|
165 |
+
|
166 |
+
# setup paths
|
167 |
+
data_dict = {
|
168 |
+
'dataset':
|
169 |
+
subject.split("/")[0],
|
170 |
+
'subject':
|
171 |
+
subject,
|
172 |
+
'rotation':
|
173 |
+
rotation,
|
174 |
+
'image_path':
|
175 |
+
osp.join(self.root, subject_render, 'render',
|
176 |
+
f'{rotation:03d}.png')
|
177 |
+
}
|
178 |
+
|
179 |
+
# image/normal/depth loader
|
180 |
+
for name, channel in zip(self.in_total, self.in_total_dim):
|
181 |
+
|
182 |
+
if name != 'image':
|
183 |
+
data_dict.update({
|
184 |
+
f'{name}_path':
|
185 |
+
osp.join(self.root, subject_render, name,
|
186 |
+
f'{rotation:03d}.png')
|
187 |
+
})
|
188 |
+
data_dict.update({
|
189 |
+
name:
|
190 |
+
self.imagepath2tensor(data_dict[f'{name}_path'],
|
191 |
+
channel,
|
192 |
+
inv='depth_B' in name)
|
193 |
+
})
|
194 |
+
|
195 |
+
path_keys = [
|
196 |
+
key for key in data_dict.keys() if '_path' in key or '_dir' in key
|
197 |
+
]
|
198 |
+
for key in path_keys:
|
199 |
+
del data_dict[key]
|
200 |
+
|
201 |
+
return data_dict
|
202 |
+
|
203 |
+
def imagepath2tensor(self, path, channel=3, inv=False):
|
204 |
+
|
205 |
+
rgba = Image.open(path).convert('RGBA')
|
206 |
+
mask = rgba.split()[-1]
|
207 |
+
image = rgba.convert('RGB')
|
208 |
+
image = self.image_to_tensor(image)
|
209 |
+
mask = self.mask_to_tensor(mask)
|
210 |
+
image = (image * mask)[:channel]
|
211 |
+
|
212 |
+
return (image * (0.5 - inv) * 2.0).float()
|
lib/dataset/NormalModule.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
from torch.utils.data import DataLoader
|
20 |
+
from .NormalDataset import NormalDataset
|
21 |
+
|
22 |
+
# pytorch lightning related libs
|
23 |
+
import pytorch_lightning as pl
|
24 |
+
|
25 |
+
|
26 |
+
class NormalModule(pl.LightningDataModule):
|
27 |
+
def __init__(self, cfg):
|
28 |
+
super(NormalModule, self).__init__()
|
29 |
+
self.cfg = cfg
|
30 |
+
self.overfit = self.cfg.overfit
|
31 |
+
|
32 |
+
if self.overfit:
|
33 |
+
self.batch_size = 1
|
34 |
+
else:
|
35 |
+
self.batch_size = self.cfg.batch_size
|
36 |
+
|
37 |
+
self.data_size = {}
|
38 |
+
|
39 |
+
def prepare_data(self):
|
40 |
+
|
41 |
+
pass
|
42 |
+
|
43 |
+
@staticmethod
|
44 |
+
def worker_init_fn(worker_id):
|
45 |
+
np.random.seed(np.random.get_state()[1][0] + worker_id)
|
46 |
+
|
47 |
+
def setup(self, stage):
|
48 |
+
|
49 |
+
if stage == 'fit' or stage is None:
|
50 |
+
self.train_dataset = NormalDataset(cfg=self.cfg, split="train")
|
51 |
+
self.val_dataset = NormalDataset(cfg=self.cfg, split="val")
|
52 |
+
self.data_size = {
|
53 |
+
'train': len(self.train_dataset),
|
54 |
+
'val': len(self.val_dataset)
|
55 |
+
}
|
56 |
+
|
57 |
+
if stage == 'test' or stage is None:
|
58 |
+
self.test_dataset = NormalDataset(cfg=self.cfg, split="test")
|
59 |
+
|
60 |
+
def train_dataloader(self):
|
61 |
+
|
62 |
+
train_data_loader = DataLoader(self.train_dataset,
|
63 |
+
batch_size=self.batch_size,
|
64 |
+
shuffle=not self.overfit,
|
65 |
+
num_workers=self.cfg.num_threads,
|
66 |
+
pin_memory=True,
|
67 |
+
worker_init_fn=self.worker_init_fn)
|
68 |
+
|
69 |
+
return train_data_loader
|
70 |
+
|
71 |
+
def val_dataloader(self):
|
72 |
+
|
73 |
+
if self.overfit:
|
74 |
+
current_dataset = self.train_dataset
|
75 |
+
else:
|
76 |
+
current_dataset = self.val_dataset
|
77 |
+
|
78 |
+
val_data_loader = DataLoader(current_dataset,
|
79 |
+
batch_size=self.batch_size,
|
80 |
+
shuffle=False,
|
81 |
+
num_workers=self.cfg.num_threads,
|
82 |
+
pin_memory=True)
|
83 |
+
|
84 |
+
return val_data_loader
|
85 |
+
|
86 |
+
def test_dataloader(self):
|
87 |
+
|
88 |
+
test_data_loader = DataLoader(self.test_dataset,
|
89 |
+
batch_size=1,
|
90 |
+
shuffle=False,
|
91 |
+
num_workers=self.cfg.num_threads,
|
92 |
+
pin_memory=True)
|
93 |
+
|
94 |
+
return test_data_loader
|
lib/dataset/PIFuDataModule.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
from .PIFuDataset import PIFuDataset
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
|
6 |
+
|
7 |
+
class PIFuDataModule(pl.LightningDataModule):
|
8 |
+
def __init__(self, cfg):
|
9 |
+
super(PIFuDataModule, self).__init__()
|
10 |
+
self.cfg = cfg
|
11 |
+
self.overfit = self.cfg.overfit
|
12 |
+
|
13 |
+
if self.overfit:
|
14 |
+
self.batch_size = 1
|
15 |
+
else:
|
16 |
+
self.batch_size = self.cfg.batch_size
|
17 |
+
|
18 |
+
self.data_size = {}
|
19 |
+
|
20 |
+
def prepare_data(self):
|
21 |
+
|
22 |
+
pass
|
23 |
+
|
24 |
+
@staticmethod
|
25 |
+
def worker_init_fn(worker_id):
|
26 |
+
np.random.seed(np.random.get_state()[1][0] + worker_id)
|
27 |
+
|
28 |
+
def setup(self, stage):
|
29 |
+
|
30 |
+
if stage == 'fit':
|
31 |
+
self.train_dataset = PIFuDataset(cfg=self.cfg, split="train")
|
32 |
+
self.val_dataset = PIFuDataset(cfg=self.cfg, split="val")
|
33 |
+
self.data_size = {'train': len(self.train_dataset),
|
34 |
+
'val': len(self.val_dataset)}
|
35 |
+
|
36 |
+
if stage == 'test':
|
37 |
+
self.test_dataset = PIFuDataset(cfg=self.cfg, split="test")
|
38 |
+
|
39 |
+
def train_dataloader(self):
|
40 |
+
|
41 |
+
train_data_loader = DataLoader(
|
42 |
+
self.train_dataset,
|
43 |
+
batch_size=self.batch_size, shuffle=True,
|
44 |
+
num_workers=self.cfg.num_threads, pin_memory=True,
|
45 |
+
worker_init_fn=self.worker_init_fn)
|
46 |
+
|
47 |
+
return train_data_loader
|
48 |
+
|
49 |
+
def val_dataloader(self):
|
50 |
+
|
51 |
+
if self.overfit:
|
52 |
+
current_dataset = self.train_dataset
|
53 |
+
else:
|
54 |
+
current_dataset = self.val_dataset
|
55 |
+
|
56 |
+
val_data_loader = DataLoader(
|
57 |
+
current_dataset,
|
58 |
+
batch_size=1, shuffle=False,
|
59 |
+
num_workers=self.cfg.num_threads, pin_memory=True,
|
60 |
+
worker_init_fn=self.worker_init_fn)
|
61 |
+
|
62 |
+
return val_data_loader
|
63 |
+
|
64 |
+
def test_dataloader(self):
|
65 |
+
|
66 |
+
test_data_loader = DataLoader(
|
67 |
+
self.test_dataset,
|
68 |
+
batch_size=1, shuffle=False,
|
69 |
+
num_workers=self.cfg.num_threads, pin_memory=True)
|
70 |
+
|
71 |
+
return test_data_loader
|
lib/dataset/PIFuDataset.py
ADDED
@@ -0,0 +1,662 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from lib.renderer.mesh import load_fit_body
|
2 |
+
from lib.dataset.hoppeMesh import HoppeMesh
|
3 |
+
from lib.dataset.body_model import TetraSMPLModel
|
4 |
+
from lib.common.render import Render
|
5 |
+
from lib.dataset.mesh_util import SMPLX, projection, cal_sdf_batch, get_visibility
|
6 |
+
from lib.pare.pare.utils.geometry import rotation_matrix_to_angle_axis
|
7 |
+
from termcolor import colored
|
8 |
+
import os.path as osp
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image
|
11 |
+
import random
|
12 |
+
import trimesh
|
13 |
+
import torch
|
14 |
+
import vedo
|
15 |
+
from kaolin.ops.mesh import check_sign
|
16 |
+
import torchvision.transforms as transforms
|
17 |
+
from ipdb import set_trace
|
18 |
+
|
19 |
+
|
20 |
+
class PIFuDataset():
|
21 |
+
def __init__(self, cfg, split='train', vis=False):
|
22 |
+
|
23 |
+
self.split = split
|
24 |
+
self.root = cfg.root
|
25 |
+
self.bsize = cfg.batch_size
|
26 |
+
self.overfit = cfg.overfit
|
27 |
+
|
28 |
+
# for debug, only used in visualize_sampling3D
|
29 |
+
self.vis = vis
|
30 |
+
|
31 |
+
self.opt = cfg.dataset
|
32 |
+
self.datasets = self.opt.types
|
33 |
+
self.input_size = self.opt.input_size
|
34 |
+
self.scales = self.opt.scales
|
35 |
+
self.workers = cfg.num_threads
|
36 |
+
self.prior_type = cfg.net.prior_type
|
37 |
+
|
38 |
+
self.noise_type = self.opt.noise_type
|
39 |
+
self.noise_scale = self.opt.noise_scale
|
40 |
+
|
41 |
+
noise_joints = [4, 5, 7, 8, 13, 14, 16, 17, 18, 19, 20, 21]
|
42 |
+
|
43 |
+
self.noise_smpl_idx = []
|
44 |
+
self.noise_smplx_idx = []
|
45 |
+
|
46 |
+
for idx in noise_joints:
|
47 |
+
self.noise_smpl_idx.append(idx * 3)
|
48 |
+
self.noise_smpl_idx.append(idx * 3 + 1)
|
49 |
+
self.noise_smpl_idx.append(idx * 3 + 2)
|
50 |
+
|
51 |
+
self.noise_smplx_idx.append((idx-1) * 3)
|
52 |
+
self.noise_smplx_idx.append((idx-1) * 3 + 1)
|
53 |
+
self.noise_smplx_idx.append((idx-1) * 3 + 2)
|
54 |
+
|
55 |
+
self.use_sdf = cfg.sdf
|
56 |
+
self.sdf_clip = cfg.sdf_clip
|
57 |
+
|
58 |
+
# [(feat_name, channel_num),...]
|
59 |
+
self.in_geo = [item[0] for item in cfg.net.in_geo]
|
60 |
+
self.in_nml = [item[0] for item in cfg.net.in_nml]
|
61 |
+
|
62 |
+
self.in_geo_dim = [item[1] for item in cfg.net.in_geo]
|
63 |
+
self.in_nml_dim = [item[1] for item in cfg.net.in_nml]
|
64 |
+
|
65 |
+
self.in_total = self.in_geo + self.in_nml
|
66 |
+
self.in_total_dim = self.in_geo_dim + self.in_nml_dim
|
67 |
+
|
68 |
+
if self.split == 'train':
|
69 |
+
self.rotations = np.arange(
|
70 |
+
0, 360, 360 / self.opt.rotation_num).astype(np.int32)
|
71 |
+
else:
|
72 |
+
self.rotations = range(0, 360, 120)
|
73 |
+
|
74 |
+
self.datasets_dict = {}
|
75 |
+
|
76 |
+
for dataset_id, dataset in enumerate(self.datasets):
|
77 |
+
|
78 |
+
mesh_dir = None
|
79 |
+
smplx_dir = None
|
80 |
+
|
81 |
+
dataset_dir = osp.join(self.root, dataset)
|
82 |
+
|
83 |
+
if dataset in ['thuman2']:
|
84 |
+
mesh_dir = osp.join(dataset_dir, "scans")
|
85 |
+
smplx_dir = osp.join(dataset_dir, "fits")
|
86 |
+
smpl_dir = osp.join(dataset_dir, "smpl")
|
87 |
+
|
88 |
+
self.datasets_dict[dataset] = {
|
89 |
+
"subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"), dtype=str),
|
90 |
+
"smplx_dir": smplx_dir,
|
91 |
+
"smpl_dir": smpl_dir,
|
92 |
+
"mesh_dir": mesh_dir,
|
93 |
+
"scale": self.scales[dataset_id]
|
94 |
+
}
|
95 |
+
|
96 |
+
self.subject_list = self.get_subject_list(split)
|
97 |
+
self.smplx = SMPLX()
|
98 |
+
|
99 |
+
# PIL to tensor
|
100 |
+
self.image_to_tensor = transforms.Compose([
|
101 |
+
transforms.Resize(self.input_size),
|
102 |
+
transforms.ToTensor(),
|
103 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
104 |
+
])
|
105 |
+
|
106 |
+
# PIL to tensor
|
107 |
+
self.mask_to_tensor = transforms.Compose([
|
108 |
+
transforms.Resize(self.input_size),
|
109 |
+
transforms.ToTensor(),
|
110 |
+
transforms.Normalize((0.0, ), (1.0, ))
|
111 |
+
])
|
112 |
+
|
113 |
+
self.device = torch.device(f"cuda:{cfg.gpus[0]}")
|
114 |
+
self.render = Render(size=512, device=self.device)
|
115 |
+
|
116 |
+
def render_normal(self, verts, faces):
|
117 |
+
|
118 |
+
# render optimized mesh (normal, T_normal, image [-1,1])
|
119 |
+
self.render.load_meshes(verts, faces)
|
120 |
+
return self.render.get_rgb_image()
|
121 |
+
|
122 |
+
def get_subject_list(self, split):
|
123 |
+
|
124 |
+
subject_list = []
|
125 |
+
|
126 |
+
for dataset in self.datasets:
|
127 |
+
|
128 |
+
split_txt = osp.join(self.root, dataset, f'{split}.txt')
|
129 |
+
|
130 |
+
if osp.exists(split_txt):
|
131 |
+
print(f"load from {split_txt}")
|
132 |
+
subject_list += np.loadtxt(split_txt, dtype=str).tolist()
|
133 |
+
else:
|
134 |
+
full_txt = osp.join(self.root, dataset, 'all.txt')
|
135 |
+
print(f"split {full_txt} into train/val/test")
|
136 |
+
|
137 |
+
full_lst = np.loadtxt(full_txt, dtype=str)
|
138 |
+
full_lst = [dataset+"/"+item for item in full_lst]
|
139 |
+
[train_lst, test_lst, val_lst] = np.split(
|
140 |
+
full_lst, [500, 500+5, ])
|
141 |
+
|
142 |
+
np.savetxt(full_txt.replace(
|
143 |
+
"all", "train"), train_lst, fmt="%s")
|
144 |
+
np.savetxt(full_txt.replace("all", "test"), test_lst, fmt="%s")
|
145 |
+
np.savetxt(full_txt.replace("all", "val"), val_lst, fmt="%s")
|
146 |
+
|
147 |
+
print(f"load from {split_txt}")
|
148 |
+
subject_list += np.loadtxt(split_txt, dtype=str).tolist()
|
149 |
+
|
150 |
+
if self.split != 'test':
|
151 |
+
subject_list += subject_list[:self.bsize -
|
152 |
+
len(subject_list) % self.bsize]
|
153 |
+
print(colored(f"total: {len(subject_list)}", "yellow"))
|
154 |
+
random.shuffle(subject_list)
|
155 |
+
|
156 |
+
# subject_list = ["thuman2/0008"]
|
157 |
+
return subject_list
|
158 |
+
|
159 |
+
def __len__(self):
|
160 |
+
return len(self.subject_list) * len(self.rotations)
|
161 |
+
|
162 |
+
def __getitem__(self, index):
|
163 |
+
|
164 |
+
# only pick the first data if overfitting
|
165 |
+
if self.overfit:
|
166 |
+
index = 0
|
167 |
+
|
168 |
+
rid = index % len(self.rotations)
|
169 |
+
mid = index // len(self.rotations)
|
170 |
+
|
171 |
+
rotation = self.rotations[rid]
|
172 |
+
subject = self.subject_list[mid].split("/")[1]
|
173 |
+
dataset = self.subject_list[mid].split("/")[0]
|
174 |
+
render_folder = "/".join([dataset +
|
175 |
+
f"_{self.opt.rotation_num}views", subject])
|
176 |
+
|
177 |
+
# setup paths
|
178 |
+
data_dict = {
|
179 |
+
'dataset': dataset,
|
180 |
+
'subject': subject,
|
181 |
+
'rotation': rotation,
|
182 |
+
'scale': self.datasets_dict[dataset]["scale"],
|
183 |
+
'mesh_path': osp.join(self.datasets_dict[dataset]["mesh_dir"], f"{subject}/{subject}.obj"),
|
184 |
+
'smplx_path': osp.join(self.datasets_dict[dataset]["smplx_dir"], f"{subject}/smplx_param.pkl"),
|
185 |
+
'smpl_path': osp.join(self.datasets_dict[dataset]["smpl_dir"], f"{subject}.pkl"),
|
186 |
+
'calib_path': osp.join(self.root, render_folder, 'calib', f'{rotation:03d}.txt'),
|
187 |
+
'vis_path': osp.join(self.root, render_folder, 'vis', f'{rotation:03d}.pt'),
|
188 |
+
'image_path': osp.join(self.root, render_folder, 'render', f'{rotation:03d}.png')
|
189 |
+
}
|
190 |
+
|
191 |
+
# load training data
|
192 |
+
data_dict.update(self.load_calib(data_dict))
|
193 |
+
|
194 |
+
# image/normal/depth loader
|
195 |
+
for name, channel in zip(self.in_total, self.in_total_dim):
|
196 |
+
|
197 |
+
if f'{name}_path' not in data_dict.keys():
|
198 |
+
data_dict.update({
|
199 |
+
f'{name}_path': osp.join(self.root, render_folder, name, f'{rotation:03d}.png')
|
200 |
+
})
|
201 |
+
|
202 |
+
# tensor update
|
203 |
+
data_dict.update({
|
204 |
+
name: self.imagepath2tensor(
|
205 |
+
data_dict[f'{name}_path'], channel, inv=False)
|
206 |
+
})
|
207 |
+
|
208 |
+
data_dict.update(self.load_mesh(data_dict))
|
209 |
+
data_dict.update(self.get_sampling_geo(
|
210 |
+
data_dict, is_valid=self.split == "val", is_sdf=self.use_sdf))
|
211 |
+
data_dict.update(self.load_smpl(data_dict, self.vis))
|
212 |
+
|
213 |
+
if self.prior_type == 'pamir':
|
214 |
+
data_dict.update(self.load_smpl_voxel(data_dict))
|
215 |
+
|
216 |
+
if (self.split != 'test') and (not self.vis):
|
217 |
+
|
218 |
+
del data_dict['verts']
|
219 |
+
del data_dict['faces']
|
220 |
+
|
221 |
+
if not self.vis:
|
222 |
+
del data_dict['mesh']
|
223 |
+
|
224 |
+
path_keys = [
|
225 |
+
key for key in data_dict.keys() if '_path' in key or '_dir' in key
|
226 |
+
]
|
227 |
+
for key in path_keys:
|
228 |
+
del data_dict[key]
|
229 |
+
|
230 |
+
return data_dict
|
231 |
+
|
232 |
+
def imagepath2tensor(self, path, channel=3, inv=False):
|
233 |
+
|
234 |
+
rgba = Image.open(path).convert('RGBA')
|
235 |
+
mask = rgba.split()[-1]
|
236 |
+
image = rgba.convert('RGB')
|
237 |
+
image = self.image_to_tensor(image)
|
238 |
+
mask = self.mask_to_tensor(mask)
|
239 |
+
image = (image * mask)[:channel]
|
240 |
+
|
241 |
+
return (image * (0.5 - inv) * 2.0).float()
|
242 |
+
|
243 |
+
def load_calib(self, data_dict):
|
244 |
+
calib_data = np.loadtxt(data_dict['calib_path'], dtype=float)
|
245 |
+
extrinsic = calib_data[:4, :4]
|
246 |
+
intrinsic = calib_data[4:8, :4]
|
247 |
+
calib_mat = np.matmul(intrinsic, extrinsic)
|
248 |
+
calib_mat = torch.from_numpy(calib_mat).float()
|
249 |
+
return {'calib': calib_mat}
|
250 |
+
|
251 |
+
def load_mesh(self, data_dict):
|
252 |
+
mesh_path = data_dict['mesh_path']
|
253 |
+
scale = data_dict['scale']
|
254 |
+
|
255 |
+
mesh_ori = trimesh.load(mesh_path,
|
256 |
+
skip_materials=True,
|
257 |
+
process=False,
|
258 |
+
maintain_order=True)
|
259 |
+
verts = mesh_ori.vertices * scale
|
260 |
+
faces = mesh_ori.faces
|
261 |
+
|
262 |
+
vert_normals = np.array(mesh_ori.vertex_normals)
|
263 |
+
face_normals = np.array(mesh_ori.face_normals)
|
264 |
+
|
265 |
+
mesh = HoppeMesh(verts, faces, vert_normals, face_normals)
|
266 |
+
|
267 |
+
return {
|
268 |
+
'mesh': mesh,
|
269 |
+
'verts': torch.as_tensor(mesh.verts).float(),
|
270 |
+
'faces': torch.as_tensor(mesh.faces).long()
|
271 |
+
}
|
272 |
+
|
273 |
+
def add_noise(self,
|
274 |
+
beta_num,
|
275 |
+
smpl_pose,
|
276 |
+
smpl_betas,
|
277 |
+
noise_type,
|
278 |
+
noise_scale,
|
279 |
+
type,
|
280 |
+
hashcode):
|
281 |
+
|
282 |
+
np.random.seed(hashcode)
|
283 |
+
|
284 |
+
if type == 'smplx':
|
285 |
+
noise_idx = self.noise_smplx_idx
|
286 |
+
else:
|
287 |
+
noise_idx = self.noise_smpl_idx
|
288 |
+
|
289 |
+
if 'beta' in noise_type and noise_scale[noise_type.index("beta")] > 0.0:
|
290 |
+
smpl_betas += (np.random.rand(beta_num) -
|
291 |
+
0.5) * 2.0 * noise_scale[noise_type.index("beta")]
|
292 |
+
smpl_betas = smpl_betas.astype(np.float32)
|
293 |
+
|
294 |
+
if 'pose' in noise_type and noise_scale[noise_type.index("pose")] > 0.0:
|
295 |
+
smpl_pose[noise_idx] += (
|
296 |
+
np.random.rand(len(noise_idx)) -
|
297 |
+
0.5) * 2.0 * np.pi * noise_scale[noise_type.index("pose")]
|
298 |
+
smpl_pose = smpl_pose.astype(np.float32)
|
299 |
+
if type == 'smplx':
|
300 |
+
return torch.as_tensor(smpl_pose[None, ...]), torch.as_tensor(smpl_betas[None, ...])
|
301 |
+
else:
|
302 |
+
return smpl_pose, smpl_betas
|
303 |
+
|
304 |
+
def compute_smpl_verts(self, data_dict, noise_type=None, noise_scale=None):
|
305 |
+
|
306 |
+
dataset = data_dict['dataset']
|
307 |
+
smplx_dict = {}
|
308 |
+
|
309 |
+
smplx_param = np.load(data_dict['smplx_path'], allow_pickle=True)
|
310 |
+
smplx_pose = smplx_param["body_pose"] # [1,63]
|
311 |
+
smplx_betas = smplx_param["betas"] # [1,10]
|
312 |
+
smplx_pose, smplx_betas = self.add_noise(
|
313 |
+
smplx_betas.shape[1],
|
314 |
+
smplx_pose[0],
|
315 |
+
smplx_betas[0],
|
316 |
+
noise_type,
|
317 |
+
noise_scale,
|
318 |
+
type='smplx',
|
319 |
+
hashcode=(hash(f"{data_dict['subject']}_{data_dict['rotation']}")) % (10**8))
|
320 |
+
|
321 |
+
smplx_out, _ = load_fit_body(fitted_path=data_dict['smplx_path'],
|
322 |
+
scale=self.datasets_dict[dataset]['scale'],
|
323 |
+
smpl_type='smplx',
|
324 |
+
smpl_gender='male',
|
325 |
+
noise_dict=dict(betas=smplx_betas, body_pose=smplx_pose))
|
326 |
+
|
327 |
+
smplx_dict.update({"type": "smplx",
|
328 |
+
"gender": 'male',
|
329 |
+
"body_pose": torch.as_tensor(smplx_pose),
|
330 |
+
"betas": torch.as_tensor(smplx_betas)})
|
331 |
+
|
332 |
+
return smplx_out.vertices, smplx_dict
|
333 |
+
|
334 |
+
def compute_voxel_verts(self,
|
335 |
+
data_dict,
|
336 |
+
noise_type=None,
|
337 |
+
noise_scale=None):
|
338 |
+
|
339 |
+
smpl_param = np.load(data_dict['smpl_path'], allow_pickle=True)
|
340 |
+
smplx_param = np.load(data_dict['smplx_path'], allow_pickle=True)
|
341 |
+
|
342 |
+
smpl_pose = rotation_matrix_to_angle_axis(
|
343 |
+
torch.as_tensor(smpl_param['full_pose'][0])).numpy()
|
344 |
+
smpl_betas = smpl_param["betas"]
|
345 |
+
|
346 |
+
smpl_path = osp.join(self.smplx.model_dir, "smpl/SMPL_MALE.pkl")
|
347 |
+
tetra_path = osp.join(self.smplx.tedra_dir,
|
348 |
+
"tetra_male_adult_smpl.npz")
|
349 |
+
|
350 |
+
smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')
|
351 |
+
|
352 |
+
smpl_pose, smpl_betas = self.add_noise(
|
353 |
+
smpl_model.beta_shape[0],
|
354 |
+
smpl_pose.flatten(),
|
355 |
+
smpl_betas[0],
|
356 |
+
noise_type,
|
357 |
+
noise_scale,
|
358 |
+
type='smpl',
|
359 |
+
hashcode=(hash(f"{data_dict['subject']}_{data_dict['rotation']}")) % (10**8))
|
360 |
+
|
361 |
+
smpl_model.set_params(pose=smpl_pose.reshape(-1, 3),
|
362 |
+
beta=smpl_betas,
|
363 |
+
trans=smpl_param["transl"])
|
364 |
+
|
365 |
+
verts = (np.concatenate([smpl_model.verts, smpl_model.verts_added],
|
366 |
+
axis=0) * smplx_param["scale"] + smplx_param["translation"]
|
367 |
+
) * self.datasets_dict[data_dict['dataset']]['scale']
|
368 |
+
faces = np.loadtxt(osp.join(self.smplx.tedra_dir, "tetrahedrons_male_adult.txt"),
|
369 |
+
dtype=np.int32) - 1
|
370 |
+
|
371 |
+
pad_v_num = int(8000 - verts.shape[0])
|
372 |
+
pad_f_num = int(25100 - faces.shape[0])
|
373 |
+
|
374 |
+
verts = np.pad(verts, ((0, pad_v_num), (0, 0)),
|
375 |
+
mode='constant',
|
376 |
+
constant_values=0.0).astype(np.float32)
|
377 |
+
faces = np.pad(faces, ((0, pad_f_num), (0, 0)),
|
378 |
+
mode='constant',
|
379 |
+
constant_values=0.0).astype(np.int32)
|
380 |
+
|
381 |
+
|
382 |
+
return verts, faces, pad_v_num, pad_f_num
|
383 |
+
|
384 |
+
def load_smpl(self, data_dict, vis=False):
|
385 |
+
|
386 |
+
smplx_verts, smplx_dict = self.compute_smpl_verts(
|
387 |
+
data_dict, self.noise_type,
|
388 |
+
self.noise_scale) # compute using smpl model
|
389 |
+
|
390 |
+
smplx_verts = projection(smplx_verts, data_dict['calib']).float()
|
391 |
+
smplx_faces = torch.as_tensor(self.smplx.faces).long()
|
392 |
+
smplx_vis = torch.load(data_dict['vis_path']).float()
|
393 |
+
smplx_cmap = torch.as_tensor(
|
394 |
+
np.load(self.smplx.cmap_vert_path)).float()
|
395 |
+
|
396 |
+
# get smpl_signs
|
397 |
+
query_points = projection(data_dict['samples_geo'],
|
398 |
+
data_dict['calib']).float()
|
399 |
+
|
400 |
+
pts_signs = 2.0 * (check_sign(smplx_verts.unsqueeze(0),
|
401 |
+
smplx_faces,
|
402 |
+
query_points.unsqueeze(0)).float() - 0.5).squeeze(0)
|
403 |
+
|
404 |
+
return_dict = {
|
405 |
+
'smpl_verts': smplx_verts,
|
406 |
+
'smpl_faces': smplx_faces,
|
407 |
+
'smpl_vis': smplx_vis,
|
408 |
+
'smpl_cmap': smplx_cmap,
|
409 |
+
'pts_signs': pts_signs
|
410 |
+
}
|
411 |
+
if smplx_dict is not None:
|
412 |
+
return_dict.update(smplx_dict)
|
413 |
+
|
414 |
+
if vis:
|
415 |
+
|
416 |
+
(xy, z) = torch.as_tensor(smplx_verts).to(
|
417 |
+
self.device).split([2, 1], dim=1)
|
418 |
+
smplx_vis = get_visibility(xy, z, torch.as_tensor(
|
419 |
+
smplx_faces).to(self.device).long())
|
420 |
+
|
421 |
+
T_normal_F, T_normal_B = self.render_normal(
|
422 |
+
(smplx_verts*torch.tensor([1.0, -1.0, 1.0])).to(self.device),
|
423 |
+
smplx_faces.to(self.device))
|
424 |
+
|
425 |
+
return_dict.update({"T_normal_F": T_normal_F.squeeze(0),
|
426 |
+
"T_normal_B": T_normal_B.squeeze(0)})
|
427 |
+
query_points = projection(data_dict['samples_geo'],
|
428 |
+
data_dict['calib']).float()
|
429 |
+
|
430 |
+
smplx_sdf, smplx_norm, smplx_cmap, smplx_vis = cal_sdf_batch(
|
431 |
+
smplx_verts.unsqueeze(0).to(self.device),
|
432 |
+
smplx_faces.unsqueeze(0).to(self.device),
|
433 |
+
smplx_cmap.unsqueeze(0).to(self.device),
|
434 |
+
smplx_vis.unsqueeze(0).to(self.device),
|
435 |
+
query_points.unsqueeze(0).contiguous().to(self.device))
|
436 |
+
|
437 |
+
return_dict.update({
|
438 |
+
'smpl_feat':
|
439 |
+
torch.cat(
|
440 |
+
(smplx_sdf[0].detach().cpu(),
|
441 |
+
smplx_cmap[0].detach().cpu(),
|
442 |
+
smplx_norm[0].detach().cpu(),
|
443 |
+
smplx_vis[0].detach().cpu()),
|
444 |
+
dim=1)
|
445 |
+
})
|
446 |
+
|
447 |
+
return return_dict
|
448 |
+
|
449 |
+
def load_smpl_voxel(self, data_dict):
|
450 |
+
|
451 |
+
smpl_verts, smpl_faces, pad_v_num, pad_f_num = self.compute_voxel_verts(
|
452 |
+
data_dict, self.noise_type,
|
453 |
+
self.noise_scale) # compute using smpl model
|
454 |
+
smpl_verts = projection(smpl_verts, data_dict['calib'])
|
455 |
+
|
456 |
+
smpl_verts *= 0.5
|
457 |
+
|
458 |
+
return {
|
459 |
+
'voxel_verts': smpl_verts,
|
460 |
+
'voxel_faces': smpl_faces,
|
461 |
+
'pad_v_num': pad_v_num,
|
462 |
+
'pad_f_num': pad_f_num
|
463 |
+
}
|
464 |
+
|
465 |
+
def get_sampling_geo(self, data_dict, is_valid=False, is_sdf=False):
|
466 |
+
|
467 |
+
mesh = data_dict['mesh']
|
468 |
+
calib = data_dict['calib']
|
469 |
+
|
470 |
+
# Samples are around the true surface with an offset
|
471 |
+
n_samples_surface = 4 * self.opt.num_sample_geo
|
472 |
+
vert_ids = np.arange(mesh.verts.shape[0])
|
473 |
+
thickness_sample_ratio = np.ones_like(vert_ids).astype(np.float32)
|
474 |
+
|
475 |
+
thickness_sample_ratio /= thickness_sample_ratio.sum()
|
476 |
+
|
477 |
+
samples_surface_ids = np.random.choice(vert_ids,
|
478 |
+
n_samples_surface,
|
479 |
+
replace=True,
|
480 |
+
p=thickness_sample_ratio)
|
481 |
+
|
482 |
+
samples_normal_ids = np.random.choice(vert_ids,
|
483 |
+
self.opt.num_sample_geo // 2,
|
484 |
+
replace=False,
|
485 |
+
p=thickness_sample_ratio)
|
486 |
+
|
487 |
+
surf_samples = mesh.verts[samples_normal_ids, :]
|
488 |
+
surf_normals = mesh.vert_normals[samples_normal_ids, :]
|
489 |
+
|
490 |
+
samples_surface = mesh.verts[samples_surface_ids, :]
|
491 |
+
|
492 |
+
# Sampling offsets are random noise with constant scale (15cm - 20cm)
|
493 |
+
offset = np.random.normal(scale=self.opt.sigma_geo,
|
494 |
+
size=(n_samples_surface, 1))
|
495 |
+
samples_surface += mesh.vert_normals[samples_surface_ids, :] * offset
|
496 |
+
|
497 |
+
# Uniform samples in [-1, 1]
|
498 |
+
calib_inv = np.linalg.inv(calib)
|
499 |
+
n_samples_space = self.opt.num_sample_geo // 4
|
500 |
+
samples_space_img = 2.0 * np.random.rand(n_samples_space, 3) - 1.0
|
501 |
+
samples_space = projection(samples_space_img, calib_inv)
|
502 |
+
|
503 |
+
# z-ray direction samples
|
504 |
+
if self.opt.zray_type and not is_valid:
|
505 |
+
n_samples_rayz = self.opt.ray_sample_num
|
506 |
+
samples_surface_cube = projection(samples_surface, calib)
|
507 |
+
samples_surface_cube_repeat = np.repeat(samples_surface_cube,
|
508 |
+
n_samples_rayz,
|
509 |
+
axis=0)
|
510 |
+
|
511 |
+
thickness_repeat = np.repeat(0.5 *
|
512 |
+
np.ones_like(samples_surface_ids),
|
513 |
+
n_samples_rayz,
|
514 |
+
axis=0)
|
515 |
+
|
516 |
+
noise_repeat = np.random.normal(scale=0.40,
|
517 |
+
size=(n_samples_surface *
|
518 |
+
n_samples_rayz, ))
|
519 |
+
samples_surface_cube_repeat[:,
|
520 |
+
-1] += thickness_repeat * noise_repeat
|
521 |
+
samples_surface_rayz = projection(samples_surface_cube_repeat,
|
522 |
+
calib_inv)
|
523 |
+
|
524 |
+
samples = np.concatenate(
|
525 |
+
[samples_surface, samples_space, samples_surface_rayz], 0)
|
526 |
+
else:
|
527 |
+
samples = np.concatenate([samples_surface, samples_space], 0)
|
528 |
+
|
529 |
+
np.random.shuffle(samples)
|
530 |
+
|
531 |
+
# labels: in->1.0; out->0.0.
|
532 |
+
if is_sdf:
|
533 |
+
sdfs = mesh.get_sdf(samples)
|
534 |
+
inside_samples = samples[sdfs < 0]
|
535 |
+
outside_samples = samples[sdfs >= 0]
|
536 |
+
|
537 |
+
inside_sdfs = sdfs[sdfs < 0]
|
538 |
+
outside_sdfs = sdfs[sdfs >= 0]
|
539 |
+
else:
|
540 |
+
inside = mesh.contains(samples)
|
541 |
+
inside_samples = samples[inside >= 0.5]
|
542 |
+
outside_samples = samples[inside < 0.5]
|
543 |
+
|
544 |
+
nin = inside_samples.shape[0]
|
545 |
+
|
546 |
+
if nin > self.opt.num_sample_geo // 2:
|
547 |
+
inside_samples = inside_samples[:self.opt.num_sample_geo // 2]
|
548 |
+
outside_samples = outside_samples[:self.opt.num_sample_geo // 2]
|
549 |
+
if is_sdf:
|
550 |
+
inside_sdfs = inside_sdfs[:self.opt.num_sample_geo // 2]
|
551 |
+
outside_sdfs = outside_sdfs[:self.opt.num_sample_geo // 2]
|
552 |
+
else:
|
553 |
+
outside_samples = outside_samples[:(self.opt.num_sample_geo - nin)]
|
554 |
+
if is_sdf:
|
555 |
+
outside_sdfs = outside_sdfs[:(self.opt.num_sample_geo - nin)]
|
556 |
+
|
557 |
+
if is_sdf:
|
558 |
+
samples = np.concatenate(
|
559 |
+
[inside_samples, outside_samples, surf_samples], 0)
|
560 |
+
|
561 |
+
labels = np.concatenate([
|
562 |
+
inside_sdfs, outside_sdfs, 0.0 * np.ones(surf_samples.shape[0])
|
563 |
+
])
|
564 |
+
|
565 |
+
normals = np.zeros_like(samples)
|
566 |
+
normals[-self.opt.num_sample_geo // 2:, :] = surf_normals
|
567 |
+
|
568 |
+
# convert sdf from [-14, 130] to [0, 1]
|
569 |
+
# outside: 0, inside: 1
|
570 |
+
# Note: Marching cubes is defined on occupancy space (inside=1.0, outside=0.0)
|
571 |
+
|
572 |
+
labels = -labels.clip(min=-self.sdf_clip, max=self.sdf_clip)
|
573 |
+
labels += self.sdf_clip
|
574 |
+
labels /= (self.sdf_clip * 2)
|
575 |
+
|
576 |
+
else:
|
577 |
+
samples = np.concatenate([inside_samples, outside_samples])
|
578 |
+
labels = np.concatenate([
|
579 |
+
np.ones(inside_samples.shape[0]),
|
580 |
+
np.zeros(outside_samples.shape[0])
|
581 |
+
])
|
582 |
+
|
583 |
+
normals = np.zeros_like(samples)
|
584 |
+
|
585 |
+
samples = torch.from_numpy(samples).float()
|
586 |
+
labels = torch.from_numpy(labels).float()
|
587 |
+
normals = torch.from_numpy(normals).float()
|
588 |
+
|
589 |
+
return {'samples_geo': samples, 'labels_geo': labels}
|
590 |
+
|
591 |
+
def visualize_sampling3D(self, data_dict, mode='vis'):
|
592 |
+
|
593 |
+
# create plot
|
594 |
+
vp = vedo.Plotter(title="", size=(1500, 1500), axes=0, bg='white')
|
595 |
+
vis_list = []
|
596 |
+
|
597 |
+
assert mode in ['vis', 'sdf', 'normal', 'cmap', 'occ']
|
598 |
+
|
599 |
+
# sdf-1 cmap-3 norm-3 vis-1
|
600 |
+
if mode == 'vis':
|
601 |
+
labels = data_dict[f'smpl_feat'][:, [-1]] # visibility
|
602 |
+
colors = np.concatenate([labels, labels, labels], axis=1)
|
603 |
+
elif mode == 'occ':
|
604 |
+
labels = data_dict[f'labels_geo'][..., None] # occupancy
|
605 |
+
colors = np.concatenate([labels, labels, labels], axis=1)
|
606 |
+
elif mode == 'sdf':
|
607 |
+
labels = data_dict[f'smpl_feat'][:, [0]] # sdf
|
608 |
+
labels -= labels.min()
|
609 |
+
labels /= labels.max()
|
610 |
+
colors = np.concatenate([labels, labels, labels], axis=1)
|
611 |
+
elif mode == 'normal':
|
612 |
+
labels = data_dict[f'smpl_feat'][:, -4:-1] # normal
|
613 |
+
colors = (labels + 1.0) * 0.5
|
614 |
+
elif mode == 'cmap':
|
615 |
+
labels = data_dict[f'smpl_feat'][:, -7:-4] # colormap
|
616 |
+
colors = np.array(labels)
|
617 |
+
|
618 |
+
points = projection(data_dict['samples_geo'], data_dict['calib'])
|
619 |
+
verts = projection(data_dict['verts'], data_dict['calib'])
|
620 |
+
points[:, 1] *= -1
|
621 |
+
verts[:, 1] *= -1
|
622 |
+
|
623 |
+
# create a mesh
|
624 |
+
mesh = trimesh.Trimesh(verts, data_dict['faces'], process=True)
|
625 |
+
mesh.visual.vertex_colors = [128.0, 128.0, 128.0, 255.0]
|
626 |
+
vis_list.append(mesh)
|
627 |
+
|
628 |
+
if 'voxel_verts' in data_dict.keys():
|
629 |
+
print(colored("voxel verts", "green"))
|
630 |
+
voxel_verts = data_dict['voxel_verts'] * 2.0
|
631 |
+
voxel_faces = data_dict['voxel_faces']
|
632 |
+
voxel_verts[:, 1] *= -1
|
633 |
+
voxel = trimesh.Trimesh(
|
634 |
+
voxel_verts, voxel_faces[:, [0, 2, 1]], process=False, maintain_order=True)
|
635 |
+
voxel.visual.vertex_colors = [0.0, 128.0, 0.0, 255.0]
|
636 |
+
vis_list.append(voxel)
|
637 |
+
|
638 |
+
if 'smpl_verts' in data_dict.keys():
|
639 |
+
print(colored("smpl verts", "green"))
|
640 |
+
smplx_verts = data_dict['smpl_verts']
|
641 |
+
smplx_faces = data_dict['smpl_faces']
|
642 |
+
smplx_verts[:, 1] *= -1
|
643 |
+
smplx = trimesh.Trimesh(
|
644 |
+
smplx_verts, smplx_faces[:, [0, 2, 1]], process=False, maintain_order=True)
|
645 |
+
smplx.visual.vertex_colors = [128.0, 128.0, 0.0, 255.0]
|
646 |
+
vis_list.append(smplx)
|
647 |
+
|
648 |
+
# create a picure
|
649 |
+
img_pos = [1.0, 0.0, -1.0]
|
650 |
+
for img_id, img_key in enumerate(['normal_F', 'image', 'T_normal_B']):
|
651 |
+
image_arr = (data_dict[img_key].detach().cpu().permute(
|
652 |
+
1, 2, 0).numpy() + 1.0) * 0.5 * 255.0
|
653 |
+
image_dim = image_arr.shape[0]
|
654 |
+
image = vedo.Picture(image_arr).scale(
|
655 |
+
2.0 / image_dim).pos(-1.0, -1.0, img_pos[img_id])
|
656 |
+
vis_list.append(image)
|
657 |
+
|
658 |
+
# create a pointcloud
|
659 |
+
pc = vedo.Points(points, r=15, c=np.float32(colors))
|
660 |
+
vis_list.append(pc)
|
661 |
+
|
662 |
+
vp.show(*vis_list, bg="white", axes=1.0, interactive=True)
|
lib/dataset/TestDataset.py
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
5 |
+
# holder of all proprietary rights on this computer program.
|
6 |
+
# You can only use this computer program if you have closed
|
7 |
+
# a license agreement with MPG or you get the right to use the computer
|
8 |
+
# program from someone who is authorized to grant you that right.
|
9 |
+
# Any use of the computer program without a valid license is prohibited and
|
10 |
+
# liable to prosecution.
|
11 |
+
#
|
12 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
13 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
14 |
+
# for Intelligent Systems. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: [email protected]
|
17 |
+
|
18 |
+
import smplx
|
19 |
+
from lib.pymaf.utils.geometry import rotation_matrix_to_angle_axis, batch_rodrigues
|
20 |
+
from lib.pymaf.utils.imutils import process_image
|
21 |
+
from lib.pymaf.core import path_config
|
22 |
+
from lib.pymaf.models import pymaf_net
|
23 |
+
from lib.common.config import cfg
|
24 |
+
from lib.common.render import Render
|
25 |
+
from lib.dataset.body_model import TetraSMPLModel
|
26 |
+
from lib.dataset.mesh_util import get_visibility, SMPLX
|
27 |
+
import os.path as osp
|
28 |
+
import os
|
29 |
+
import torch
|
30 |
+
import glob
|
31 |
+
import numpy as np
|
32 |
+
import random
|
33 |
+
import human_det
|
34 |
+
from termcolor import colored
|
35 |
+
from PIL import ImageFile
|
36 |
+
|
37 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
38 |
+
|
39 |
+
|
40 |
+
class TestDataset():
|
41 |
+
def __init__(self, cfg, device):
|
42 |
+
|
43 |
+
random.seed(1993)
|
44 |
+
|
45 |
+
self.image_dir = cfg['image_dir']
|
46 |
+
self.seg_dir = cfg['seg_dir']
|
47 |
+
self.has_det = cfg['has_det']
|
48 |
+
self.hps_type = cfg['hps_type']
|
49 |
+
self.smpl_type = 'smpl' if cfg['hps_type'] != 'pixie' else 'smplx'
|
50 |
+
self.smpl_gender = 'neutral'
|
51 |
+
|
52 |
+
self.device = device
|
53 |
+
|
54 |
+
if self.has_det:
|
55 |
+
self.det = human_det.Detection()
|
56 |
+
else:
|
57 |
+
self.det = None
|
58 |
+
|
59 |
+
keep_lst = sorted(glob.glob(f"{self.image_dir}/*"))
|
60 |
+
img_fmts = ['jpg', 'png', 'jpeg', "JPG", 'bmp']
|
61 |
+
keep_lst = [
|
62 |
+
item for item in keep_lst if item.split(".")[-1] in img_fmts
|
63 |
+
]
|
64 |
+
|
65 |
+
self.subject_list = sorted(
|
66 |
+
[item for item in keep_lst if item.split(".")[-1] in img_fmts])
|
67 |
+
|
68 |
+
# smpl related
|
69 |
+
self.smpl_data = SMPLX()
|
70 |
+
|
71 |
+
self.get_smpl_model = lambda smpl_type, smpl_gender: smplx.create(
|
72 |
+
model_path=self.smpl_data.model_dir,
|
73 |
+
gender=smpl_gender,
|
74 |
+
model_type=smpl_type,
|
75 |
+
ext='npz')
|
76 |
+
|
77 |
+
# Load SMPL model
|
78 |
+
self.smpl_model = self.get_smpl_model(
|
79 |
+
self.smpl_type, self.smpl_gender).to(self.device)
|
80 |
+
self.faces = self.smpl_model.faces
|
81 |
+
|
82 |
+
self.hps = pymaf_net(path_config.SMPL_MEAN_PARAMS,
|
83 |
+
pretrained=True).to(self.device)
|
84 |
+
self.hps.load_state_dict(torch.load(
|
85 |
+
path_config.CHECKPOINT_FILE)['model'],
|
86 |
+
strict=True)
|
87 |
+
self.hps.eval()
|
88 |
+
|
89 |
+
print(colored(f"Using {self.hps_type} as HPS Estimator\n", "green"))
|
90 |
+
|
91 |
+
self.render = Render(size=512, device=device)
|
92 |
+
|
93 |
+
def __len__(self):
|
94 |
+
return len(self.subject_list)
|
95 |
+
|
96 |
+
def compute_vis_cmap(self, smpl_verts, smpl_faces):
|
97 |
+
|
98 |
+
(xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=1)
|
99 |
+
smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())
|
100 |
+
if self.smpl_type == 'smpl':
|
101 |
+
smplx_ind = self.smpl_data.smpl2smplx(np.arange(smpl_vis.shape[0]))
|
102 |
+
else:
|
103 |
+
smplx_ind = np.arange(smpl_vis.shape[0])
|
104 |
+
smpl_cmap = self.smpl_data.get_smpl_mat(smplx_ind)
|
105 |
+
|
106 |
+
return {
|
107 |
+
'smpl_vis': smpl_vis.unsqueeze(0).to(self.device),
|
108 |
+
'smpl_cmap': smpl_cmap.unsqueeze(0).to(self.device),
|
109 |
+
'smpl_verts': smpl_verts.unsqueeze(0)
|
110 |
+
}
|
111 |
+
|
112 |
+
def compute_voxel_verts(self, body_pose, global_orient, betas, trans,
|
113 |
+
scale):
|
114 |
+
|
115 |
+
smpl_path = osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl")
|
116 |
+
tetra_path = osp.join(self.smpl_data.tedra_dir,
|
117 |
+
'tetra_neutral_adult_smpl.npz')
|
118 |
+
smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')
|
119 |
+
|
120 |
+
pose = torch.cat([global_orient[0], body_pose[0]], dim=0)
|
121 |
+
smpl_model.set_params(rotation_matrix_to_angle_axis(pose),
|
122 |
+
beta=betas[0])
|
123 |
+
|
124 |
+
verts = np.concatenate(
|
125 |
+
[smpl_model.verts, smpl_model.verts_added],
|
126 |
+
axis=0) * scale.item() + trans.detach().cpu().numpy()
|
127 |
+
faces = np.loadtxt(osp.join(self.smpl_data.tedra_dir,
|
128 |
+
'tetrahedrons_neutral_adult.txt'),
|
129 |
+
dtype=np.int32) - 1
|
130 |
+
|
131 |
+
pad_v_num = int(8000 - verts.shape[0])
|
132 |
+
pad_f_num = int(25100 - faces.shape[0])
|
133 |
+
|
134 |
+
verts = np.pad(verts, ((0, pad_v_num), (0, 0)),
|
135 |
+
mode='constant',
|
136 |
+
constant_values=0.0).astype(np.float32) * 0.5
|
137 |
+
faces = np.pad(faces, ((0, pad_f_num), (0, 0)),
|
138 |
+
mode='constant',
|
139 |
+
constant_values=0.0).astype(np.int32)
|
140 |
+
|
141 |
+
verts[:, 2] *= -1.0
|
142 |
+
|
143 |
+
voxel_dict = {
|
144 |
+
'voxel_verts':
|
145 |
+
torch.from_numpy(verts).to(self.device).unsqueeze(0).float(),
|
146 |
+
'voxel_faces':
|
147 |
+
torch.from_numpy(faces).to(self.device).unsqueeze(0).long(),
|
148 |
+
'pad_v_num':
|
149 |
+
torch.tensor(pad_v_num).to(self.device).unsqueeze(0).long(),
|
150 |
+
'pad_f_num':
|
151 |
+
torch.tensor(pad_f_num).to(self.device).unsqueeze(0).long()
|
152 |
+
}
|
153 |
+
|
154 |
+
return voxel_dict
|
155 |
+
|
156 |
+
def __getitem__(self, index):
|
157 |
+
|
158 |
+
img_path = self.subject_list[index]
|
159 |
+
img_name = img_path.split("/")[-1].rsplit(".", 1)[0]
|
160 |
+
|
161 |
+
if self.seg_dir is None:
|
162 |
+
img_icon, img_hps, img_ori, img_mask, uncrop_param = process_image(
|
163 |
+
img_path, self.det, self.hps_type, 512, self.device)
|
164 |
+
|
165 |
+
data_dict = {
|
166 |
+
'name': img_name,
|
167 |
+
'image': img_icon.to(self.device).unsqueeze(0),
|
168 |
+
'ori_image': img_ori,
|
169 |
+
'mask': img_mask,
|
170 |
+
'uncrop_param': uncrop_param
|
171 |
+
}
|
172 |
+
|
173 |
+
else:
|
174 |
+
img_icon, img_hps, img_ori, img_mask, uncrop_param, segmentations = process_image(
|
175 |
+
img_path, self.det, self.hps_type, 512, self.device,
|
176 |
+
seg_path=os.path.join(self.seg_dir, f'{img_name}.json'))
|
177 |
+
data_dict = {
|
178 |
+
'name': img_name,
|
179 |
+
'image': img_icon.to(self.device).unsqueeze(0),
|
180 |
+
'ori_image': img_ori,
|
181 |
+
'mask': img_mask,
|
182 |
+
'uncrop_param': uncrop_param,
|
183 |
+
'segmentations': segmentations
|
184 |
+
}
|
185 |
+
|
186 |
+
with torch.no_grad():
|
187 |
+
# import ipdb; ipdb.set_trace()
|
188 |
+
preds_dict = self.hps.forward(img_hps)
|
189 |
+
|
190 |
+
data_dict['smpl_faces'] = torch.Tensor(
|
191 |
+
self.faces.astype(np.int16)).long().unsqueeze(0).to(
|
192 |
+
self.device)
|
193 |
+
|
194 |
+
if self.hps_type == 'pymaf':
|
195 |
+
output = preds_dict['smpl_out'][-1]
|
196 |
+
scale, tranX, tranY = output['theta'][0, :3]
|
197 |
+
data_dict['betas'] = output['pred_shape']
|
198 |
+
data_dict['body_pose'] = output['rotmat'][:, 1:]
|
199 |
+
data_dict['global_orient'] = output['rotmat'][:, 0:1]
|
200 |
+
data_dict['smpl_verts'] = output['verts']
|
201 |
+
|
202 |
+
elif self.hps_type == 'pare':
|
203 |
+
data_dict['body_pose'] = preds_dict['pred_pose'][:, 1:]
|
204 |
+
data_dict['global_orient'] = preds_dict['pred_pose'][:, 0:1]
|
205 |
+
data_dict['betas'] = preds_dict['pred_shape']
|
206 |
+
data_dict['smpl_verts'] = preds_dict['smpl_vertices']
|
207 |
+
scale, tranX, tranY = preds_dict['pred_cam'][0, :3]
|
208 |
+
|
209 |
+
elif self.hps_type == 'pixie':
|
210 |
+
data_dict.update(preds_dict)
|
211 |
+
data_dict['body_pose'] = preds_dict['body_pose']
|
212 |
+
data_dict['global_orient'] = preds_dict['global_pose']
|
213 |
+
data_dict['betas'] = preds_dict['shape']
|
214 |
+
data_dict['smpl_verts'] = preds_dict['vertices']
|
215 |
+
scale, tranX, tranY = preds_dict['cam'][0, :3]
|
216 |
+
|
217 |
+
elif self.hps_type == 'hybrik':
|
218 |
+
data_dict['body_pose'] = preds_dict['pred_theta_mats'][:, 1:]
|
219 |
+
data_dict['global_orient'] = preds_dict['pred_theta_mats'][:, [0]]
|
220 |
+
data_dict['betas'] = preds_dict['pred_shape']
|
221 |
+
data_dict['smpl_verts'] = preds_dict['pred_vertices']
|
222 |
+
scale, tranX, tranY = preds_dict['pred_camera'][0, :3]
|
223 |
+
scale = scale * 2
|
224 |
+
|
225 |
+
elif self.hps_type == 'bev':
|
226 |
+
data_dict['betas'] = torch.from_numpy(preds_dict['smpl_betas'])[
|
227 |
+
[0], :10].to(self.device).float()
|
228 |
+
pred_thetas = batch_rodrigues(torch.from_numpy(
|
229 |
+
preds_dict['smpl_thetas'][0]).reshape(-1, 3)).float()
|
230 |
+
data_dict['body_pose'] = pred_thetas[1:][None].to(self.device)
|
231 |
+
data_dict['global_orient'] = pred_thetas[[0]][None].to(self.device)
|
232 |
+
data_dict['smpl_verts'] = torch.from_numpy(
|
233 |
+
preds_dict['verts'][[0]]).to(self.device).float()
|
234 |
+
tranX = preds_dict['cam_trans'][0, 0]
|
235 |
+
tranY = preds_dict['cam'][0, 1] + 0.28
|
236 |
+
scale = preds_dict['cam'][0, 0] * 1.1
|
237 |
+
|
238 |
+
data_dict['scale'] = scale
|
239 |
+
data_dict['trans'] = torch.tensor(
|
240 |
+
[tranX, tranY, 0.0]).to(self.device).float()
|
241 |
+
|
242 |
+
# data_dict info (key-shape):
|
243 |
+
# scale, tranX, tranY - tensor.float
|
244 |
+
# betas - [1,10] / [1, 200]
|
245 |
+
# body_pose - [1, 23, 3, 3] / [1, 21, 3, 3]
|
246 |
+
# global_orient - [1, 1, 3, 3]
|
247 |
+
# smpl_verts - [1, 6890, 3] / [1, 10475, 3]
|
248 |
+
|
249 |
+
return data_dict
|
250 |
+
|
251 |
+
def render_normal(self, verts, faces):
|
252 |
+
|
253 |
+
# render optimized mesh (normal, T_normal, image [-1,1])
|
254 |
+
self.render.load_meshes(verts, faces)
|
255 |
+
return self.render.get_rgb_image()
|
256 |
+
|
257 |
+
def render_depth(self, verts, faces):
|
258 |
+
|
259 |
+
# render optimized mesh (normal, T_normal, image [-1,1])
|
260 |
+
self.render.load_meshes(verts, faces)
|
261 |
+
return self.render.get_depth_map(cam_ids=[0, 2])
|
262 |
+
|
263 |
+
def visualize_alignment(self, data):
|
264 |
+
|
265 |
+
import vedo
|
266 |
+
import trimesh
|
267 |
+
|
268 |
+
if self.hps_type != 'pixie':
|
269 |
+
smpl_out = self.smpl_model(betas=data['betas'],
|
270 |
+
body_pose=data['body_pose'],
|
271 |
+
global_orient=data['global_orient'],
|
272 |
+
pose2rot=False)
|
273 |
+
smpl_verts = (
|
274 |
+
(smpl_out.vertices + data['trans']) * data['scale']).detach().cpu().numpy()[0]
|
275 |
+
else:
|
276 |
+
smpl_verts, _, _ = self.smpl_model(shape_params=data['betas'],
|
277 |
+
expression_params=data['exp'],
|
278 |
+
body_pose=data['body_pose'],
|
279 |
+
global_pose=data['global_orient'],
|
280 |
+
jaw_pose=data['jaw_pose'],
|
281 |
+
left_hand_pose=data['left_hand_pose'],
|
282 |
+
right_hand_pose=data['right_hand_pose'])
|
283 |
+
|
284 |
+
smpl_verts = (
|
285 |
+
(smpl_verts + data['trans']) * data['scale']).detach().cpu().numpy()[0]
|
286 |
+
|
287 |
+
smpl_verts *= np.array([1.0, -1.0, -1.0])
|
288 |
+
faces = data['smpl_faces'][0].detach().cpu().numpy()
|
289 |
+
|
290 |
+
image_P = data['image']
|
291 |
+
image_F, image_B = self.render_normal(smpl_verts, faces)
|
292 |
+
|
293 |
+
# create plot
|
294 |
+
vp = vedo.Plotter(title="", size=(1500, 1500))
|
295 |
+
vis_list = []
|
296 |
+
|
297 |
+
image_F = (
|
298 |
+
0.5 * (1.0 + image_F[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)
|
299 |
+
image_B = (
|
300 |
+
0.5 * (1.0 + image_B[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)
|
301 |
+
image_P = (
|
302 |
+
0.5 * (1.0 + image_P[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)
|
303 |
+
|
304 |
+
vis_list.append(vedo.Picture(image_P*0.5+image_F *
|
305 |
+
0.5).scale(2.0/image_P.shape[0]).pos(-1.0, -1.0, 1.0))
|
306 |
+
vis_list.append(vedo.Picture(image_F).scale(
|
307 |
+
2.0/image_F.shape[0]).pos(-1.0, -1.0, -0.5))
|
308 |
+
vis_list.append(vedo.Picture(image_B).scale(
|
309 |
+
2.0/image_B.shape[0]).pos(-1.0, -1.0, -1.0))
|
310 |
+
|
311 |
+
# create a mesh
|
312 |
+
mesh = trimesh.Trimesh(smpl_verts, faces, process=False)
|
313 |
+
mesh.visual.vertex_colors = [200, 200, 0]
|
314 |
+
vis_list.append(mesh)
|
315 |
+
|
316 |
+
vp.show(*vis_list, bg="white", axes=1, interactive=True)
|
317 |
+
|
318 |
+
|
319 |
+
if __name__ == '__main__':
|
320 |
+
|
321 |
+
cfg.merge_from_file("./configs/icon-filter.yaml")
|
322 |
+
cfg.merge_from_file('./lib/pymaf/configs/pymaf_config.yaml')
|
323 |
+
|
324 |
+
cfg_show_list = [
|
325 |
+
'test_gpus', ['0'], 'mcube_res', 512, 'clean_mesh', False
|
326 |
+
]
|
327 |
+
|
328 |
+
cfg.merge_from_list(cfg_show_list)
|
329 |
+
cfg.freeze()
|
330 |
+
|
331 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
|
332 |
+
device = torch.device('cuda:0')
|
333 |
+
|
334 |
+
dataset = TestDataset(
|
335 |
+
{
|
336 |
+
'image_dir': "./examples",
|
337 |
+
'has_det': True, # w/ or w/o detection
|
338 |
+
'hps_type': 'bev' # pymaf/pare/pixie/hybrik/bev
|
339 |
+
}, device)
|
340 |
+
|
341 |
+
for i in range(len(dataset)):
|
342 |
+
dataset.visualize_alignment(dataset[i])
|
lib/dataset/__init__.py
ADDED
File without changes
|