diff --git a/.gitattributes b/.gitattributes
index 271605e9cbc6bf6a8cc45433a9902bcceb29ffdf..c4cf3ca026764cf36f4e3b479d40e3de165373f2 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.obj filter=lfs diff=lfs merge=lfs -text
+*.ply filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index e16244a21186b6529c86d110a7a969f5a6115eb9..c834a7d92efe274d91e28a171c24a8cc6d2a6a37 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,7 @@ emoji: 🗿
colorFrom: gray
colorTo: blue
sdk: gradio
+sdk_version: 4.44.1
pinned: true
license: mit
suggested_hardware: a10g-small
diff --git a/app.py b/app.py
index 85760c084fb06e30eb63b9ea328e622417620c24..116dc06619117d27352f5b292a8b2197958488a7 100644
--- a/app.py
+++ b/app.py
@@ -91,8 +91,7 @@ def main():
subprocess.run(['pip', 'list'])
- description_header = '# PPSurf: Combining Patches and Point Convolutions for Detailed Surface Reconstruction\n
- Note: Hugginface disabled docker support for Zero-GPU without notification and no solution. I will fix this space when I find the time.'
+ description_header = '# PPSurf: Combining Patches and Point Convolutions for Detailed Surface Reconstruction'
description_col0 = '''## [Github](https://github.com/cg-tuwien/ppsurf)
Supported input file formats:
diff --git a/ppsurf/.gitignore b/ppsurf/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..04cbef4d4223e88a910c37728e488bf818ee2712
--- /dev/null
+++ b/ppsurf/.gitignore
@@ -0,0 +1,145 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+.idea/
+
+*.zip
+*.pth
+
+results/
+models/
+logs/
+datasets/*
+!datasets/abc_minimal/
+logs_framework/
+debug/
+.vscode/
+backups/
+lightning_logs/
+pl_models/
diff --git a/ppsurf/LICENSE b/ppsurf/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..f779382cb7201eaa170849f6ad1a9fb000e0a37a
--- /dev/null
+++ b/ppsurf/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Philipp Erler
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/ppsurf/README.md b/ppsurf/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7336fe733c515270811a6661aa423de727eac8ec
--- /dev/null
+++ b/ppsurf/README.md
@@ -0,0 +1,323 @@
+# PPSurf
+Combining Patches and Point Convolutions for Detailed Surface Reconstruction
+
+This is our implementation of [PPSurf](https://www.cg.tuwien.ac.at/research/publications/2024/erler_2024_ppsurf/),
+a network that estimates a signed distance function from point clouds. This SDF is turned into a mesh with Marching Cubes.
+
+
+
+This is our follow-up work for [Points2Surf](https://www.cg.tuwien.ac.at/research/publications/2020/erler-p2s/).
+It uses parts of [POCO](https://github.com/valeoai/POCO), mainly the network and mesh extraction.
+This work was published in [Computer Graphics Forum (Jan 2024)](https://onlinelibrary.wiley.com/doi/10.1111/cgf.15000).
+
+
+## Setup
+
+We tested this repository on these systems:
+* Windows 10/11 and Ubuntu 22.04 LTS
+* CUDA 11.7, 11.8 and 12.1
+
+To manage the Python environments, we recommend using [Micromamba](https://github.com/mamba-org/mamba),
+a much faster Anaconda alternative.
+To install it, [follow this guide](https://mamba.readthedocs.io/en/latest/micromamba-installation.html#umamba-install).
+
+Alternatively, you can install the required packages with conda by simply replacing the 'mamba' calls with 'conda'.
+Finally, you can use Pip with the requirements.txt.
+
+``` bash
+# clone this repo, a minimal dataset is included
+git clone https://github.com/ErlerPhilipp/ppsurf.git
+
+# go into the cloned dir
+cd ppsurf
+
+# create the environment with the required packages
+mamba env create --file pps{_win}.yml
+
+# activate the new environment
+mamba activate pps
+```
+Use `pps_win.yml` for Windows and `pps.yml` for other OS.
+
+Test the setup with the minimal dataset included in the repo:
+``` bash
+python full_run_pps_mini.py
+```
+
+## Datasets, Model and Results
+
+Datesets:
+``` bash
+# download the ABC training and validation set
+python datasets/download_abc_training.py
+
+# download the test datasets
+python datasets/download_testsets.py
+```
+
+Model:
+``` bash
+python models/download_ppsurf_50nn.py
+```
+Let us know in case you need the other models from the ablation.
+They were trained using old, unclean code and are not directly compatible with this repo.
+
+Results:
+
+Download the results used for the paper from [here](https://www.cg.tuwien.ac.at/research/publications/2024/erler_2024_ppsurf/).
+This includes meshes and metrics for the 50NN variant.
+
+## Reconstruct single Point Clouds
+
+After the setup, you can reconstruct a point cloud with this simple command:
+``` bash
+python pps.py rec {in_file} {out_dir} {extra params}
+
+# example
+python pps.py rec "datasets/abc_minimal/04_pts_vis/00010009_d97409455fa543b3a224250f_trimesh_000.xyz.ply" "results/my clouds/" --model.init_args.gen_resolution_global 129
+```
+Where *in_file* is the path to the point cloud and *out_dir* is the path to the output directory.
+This will download our pre-trained 50NN model if necessary and reconstruct the point cloud with it.
+You can append additional parameters as described in [Command Line Interface Section](#Command-Line-Interface).
+
+*rec* is actually not a sub-command but is converted to *predict* with the default parameters before parsing.
+You can use the *predict* sub-command directly for more control over the reconstruction:
+``` bash
+python pps.py predict -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/ppsurf_50nn.yaml \
+ --ckpt_path models/ppsurf_50nn/version_0/checkpoints/last.ckpt --trainer.logger False --trainer.devices 1 \
+ --data.init_args.in_file {in_file} --model.init_args.results_dir {out_dir}
+```
+
+Using the *predict* sub-command will **not** download our pre-trained model. You can download it manually:
+``` bash
+python models/download_ppsurf_50nn.py
+```
+
+Supported file formats are:
+- PLY, STL, OBJ and other mesh files loaded by [trimesh](https://github.com/mikedh/trimesh).
+- XYZ as whitespace-separated text file, read by [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.loadtxt.html).
+Load first 3 columns as XYZ coordinates. All other columns will be ignored.
+- NPY and NPZ, read by [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.load.html).
+NPZ assumes default key='arr_0'. All columns after the first 3 columns will be ignored.
+- LAS and LAZ (version 1.0-1.4), COPC and CRS loaded by [Laspy](https://github.com/laspy/laspy).
+You may want to sub-sample large point clouds to ~250k points to avoid speed and memory issues.
+For detailed reconstruction, you'll need to extract parts of large point clouds.
+
+
+## Replicate Results
+
+Train, reconstruct and evaluate to replicate the main results (PPSurf 50NN) from the paper
+``` bash
+python full_run_pps.py
+```
+
+Training takes about 5 hours on 4 A40 GPUs. By default, training will use all available GPUs and CPUs.
+Reconstructing one object takes about 1 minute on a single A40. The test sets have almost 1000 objects in total.
+
+Logging during training with Tensorboard is enabled by default.
+We log the loss, accuracy, recall and F1 score for the sign prediction.
+You can start a Tensorboard server with:
+``` bash
+tensorboard --logdir models
+```
+
+
+## Command Line Interface
+
+PPSurf uses the Pytorch-Lightning [CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html).
+The basic structure is:
+``` bash
+# CLI command template
+python {CLI entry point} {sub-command} {configs} {extra params}
+```
+Where the *CLI entry point* is either `pps.py` or `poco.py` and *sub-command* can be one of *[fit, test, predict]*.
+*Fit* trains a model, *test* evaluates it and *predict* reconstructs a whole dataset or a single point cloud.
+
+*Configs* can be any number of YAML files. Later ones override values from earlier ones.
+This example adapts the default POCO parameters to PPSurf needs and uses all GPUs of our training server:
+``` bash
+-c configs/poco.yaml -c configs/ppsurf.yaml -c configs/device_server.yaml
+```
+
+You can override any available parameter explicitly:
+``` bash
+--model.init_args.gen_resolution_global 129 --debug True
+```
+
+When running *test*, *predict* or *rec*, you need to consider a few more things.
+Make sure to specify the model checkpoint!
+Also, you need to specify a dataset, since the default is the training set.
+Finally, you should disable the logger, or it will create empty folders and logs.
+``` bash
+--ckpt_path 'models/{name}/version_{version}/checkpoints/last.ckpt' --data.init_args.in_file 'datasets/abc_minimal/testset.txt' --trainer.logger False
+```
+where *name* is e.g. ppsurf and *version* is usually 0.
+If you run the training multiple times, you need to increment the version number.
+
+Appending this will print the assembled config without running anything:
+``` bash
+--print_config
+```
+
+These are the commands called by full_run_pps.py to reproduce our results *PPSurf 50 NN*:
+``` bash
+# train
+python pps.py fit -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/device_server.yaml -c configs/ppsurf_50nn.yaml
+
+# test
+python pps.py test -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/ppsurf_50nn.yaml \
+ --data.init_args.in_file datasets/abc/testset.txt --ckpt_path models/ppsurf_50nn/version_0/checkpoints/last.ckpt \
+ --trainer.logger False --trainer.devices 1
+
+# predict all ABC datasets
+python pps.py predict -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/ppsurf_50nn.yaml \
+ --data.init_args.in_file datasets/abc/testset.txt --ckpt_path models/ppsurf_50nn/version_0/checkpoints/last.ckpt \
+ --trainer.logger False --trainer.devices 1
+python pps.py predict -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/ppsurf_50nn.yaml \
+ --data.init_args.in_file datasets/abc_extra_noisy/testset.txt --ckpt_path models/ppsurf_50nn/version_0/checkpoints/last.ckpt \
+ --trainer.logger False --trainer.devices 1
+python pps.py predict -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/ppsurf_50nn.yaml \
+ --data.init_args.in_file datasets/abc_noisefree/testset.txt --ckpt_path models/ppsurf_50nn/version_0/checkpoints/last.ckpt \
+ --trainer.logger False --trainer.devices 1
+
+# predict all Famous datasets
+python pps.py predict -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/ppsurf_50nn.yaml \
+ --data.init_args.in_file datasets/famous_original/testset.txt --ckpt_path models/ppsurf_50nn/version_0/checkpoints/last.ckpt \
+ --trainer.logger False --trainer.devices 1
+python pps.py predict -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/ppsurf_50nn.yaml \
+ --data.init_args.in_file datasets/famous_noisefree/testset.txt --ckpt_path models/ppsurf_50nn/version_0/checkpoints/last.ckpt \
+ --trainer.logger False --trainer.devices 1
+python pps.py predict -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/ppsurf_50nn.yaml \
+ --data.init_args.in_file datasets/famous_sparse/testset.txt --ckpt_path models/ppsurf_50nn/version_0/checkpoints/last.ckpt \
+ --trainer.logger False --trainer.devices 1
+python pps.py predict -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/ppsurf_50nn.yaml \
+ --data.init_args.in_file datasets/famous_dense/testset.txt --ckpt_path models/ppsurf_50nn/version_0/checkpoints/last.ckpt \
+ --trainer.logger False --trainer.devices 1
+python pps.py predict -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/ppsurf_50nn.yaml \
+ --data.init_args.in_file datasets/famous_extra_noisy/testset.txt --ckpt_path models/ppsurf_50nn/version_0/checkpoints/last.ckpt \
+ --trainer.logger False --trainer.devices 1
+
+# predict all Thingi10k datasets
+python pps.py predict -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/ppsurf_50nn.yaml \
+ --data.init_args.in_file datasets/thingi10k_scans_original/testset.txt --ckpt_path models/ppsurf_50nn/version_0/checkpoints/last.ckpt \
+ --trainer.logger False --trainer.devices 1
+python pps.py predict -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/ppsurf_50nn.yaml \
+ --data.init_args.in_file datasets/thingi10k_scans_noisefree/testset.txt --ckpt_path models/ppsurf_50nn/version_0/checkpoints/last.ckpt \
+ --trainer.logger False --trainer.devices 1
+python pps.py predict -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/ppsurf_50nn.yaml \
+ --data.init_args.in_file datasets/thingi10k_scans_sparse/testset.txt --ckpt_path models/ppsurf_50nn/version_0/checkpoints/last.ckpt \
+ --trainer.logger False --trainer.devices 1
+python pps.py predict -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/ppsurf_50nn.yaml \
+ --data.init_args.in_file datasets/thingi10k_scans_dense/testset.txt --ckpt_path models/ppsurf_50nn/version_0/checkpoints/last.ckpt \
+ --trainer.logger False --trainer.devices 1
+python pps.py predict -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/ppsurf_50nn.yaml \
+ --data.init_args.in_file datasets/thingi10k_scans_extra_noisy/testset.txt --ckpt_path models/ppsurf_50nn/version_0/checkpoints/last.ckpt \
+ --trainer.logger False --trainer.devices 1
+
+# predict the real-world dataset
+python pps.py predict -c configs/poco.yaml -c configs/ppsurf.yaml -c configs/ppsurf_50nn.yaml \
+ --data.init_args.in_file datasets/real_world/testset.txt --ckpt_path models/ppsurf_50nn/version_0/checkpoints/last.ckpt \
+ --trainer.logger False --trainer.devices 1
+
+ # create comparison tables (will have only 50NN column)
+ python source/figures/comp_all.py
+```
+
+
+## Outputs and Evaluation
+
+**Training**:
+Model checkpoints, hyperparameters and logs are stored in `models/{model}/version_{version}/`.
+The version number is incremented with each training run.
+The checkpoint for further use is `models/{model}/version_{version}/checkpoints/last.ckpt`.
+
+**Testing**:
+Test results are stored in `results/{model}/{dataset}/metrics_{model}.xlsx`.
+This is like the validation but on all data of the test/val set with additional metrics.
+
+**Reconstruction**:
+Reconstructed meshes are stored in `results/{model}/{dataset}/meshes`.
+After reconstruction, metrics are computed and stored in `results/{model}/{dataset}/{metric}_{model}.xlsx`,
+where *metric* is one of *[chamfer_distance, f1, iou, normal_error]*.
+
+**Metrics**:
+You can (re-)run the metrics, e.g. for other methods, with:
+``` bash
+python source/make_evaluation.py
+```
+You may need to adjust *model_names* and *dataset_names* in this script.
+This supports the results of other methods if they are in the same structure as ours.
+
+**Comparisons**:
+We provide scripts to generate comparisons in `source/figures`:
+``` bash
+python source/figures/comp_{comp_name}.py
+```
+This will:
+- assemble the per-shape metrics spreadsheets of all relevant methods in `results/comp/{dataset}/{metric}.xlsx`.
+- compute and visualize the Chamfer distance, encoded as vertex colors in
+ `results/comp/{dataset}/{method}/mesh_cd_vis` as PLY.
+- render the reconstructed mesh with and without distance colors in `results/comp/{dataset}/{method}/mesh_rend` and
+ `results/comp/{dataset}/{method}/cd_vis_rend` as PNG.
+- render the GT mesh in `results/comp/{dataset}/mesh_gt_rend` as PNG. Note that this does only work if a real screen is attached.
+- assemble the per-method mean, median and stddev for all metrics in `results/comp/{comp_name}.xlsx`.
+- assemble all renderings as a qualitative report in `results/comp/{comp_name}.html`.
+- assemble per-dataset mean for all relevant datasets, methods and metrics in `results/comp/reports/{comp_name}`
+ as spreadsheet and LaTex table.
+
+**Figures**:
+You can prepare Chamfer distance data and render the results with Blender using these scripts:
+``` bash
+python source/figures/prepare_figures.py
+python source/figures/render_meshes_blender.py
+```
+This requires some manual camera adjustment in Blender for some objects.
+Please don't ask for support on this messy last-minute code.
+
+
+## Trouble Shooting
+
+On Windows, you might run into DLL load issues. If so, try re-installing intel-openmp:
+``` bash
+mamba install -c defaults intel-openmp --force-reinstall
+```
+
+Conda/Mamba might run into a compile error while installing the environment. If so, try updating conda:
+``` bash
+conda update -n base -c defaults conda
+```
+
+Pip might fail when creating the environment. If so, try installing the Pip packages from the `pps.yml` manually.
+
+On Windows, Pip install may raise a
+"Microsoft Visual C++ 14.0 or greater is required.
+Get it with "Microsoft C++ Build Tools" error.
+In this case, install the MS Visual Studio build tools,
+as described on [Stackoverflow](https://stackoverflow.com/questions/64261546/how-to-solve-error-microsoft-visual-c-14-0-or-greater-is-required-when-inst).
+
+
+## Updates
+
+### 2023-10-13
+
+Improved speed by using [pykdtree](https://github.com/storpipfugl/pykdtree)
+instead of [Scipy KDTree](https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.KDTree.html) for k-NN queries
+
+
+## Citation
+If you use our work, please cite our paper:
+```
+@article{ppsurf2024,
+author = {Erler, Philipp and Fuentes-Perez, Lizeth and Hermosilla, Pedro and Guerrero, Paul and Pajarola, Renato and Wimmer, Michael},
+title = {PPSurf: Combining Patches and Point Convolutions for Detailed Surface Reconstruction},
+journal = {Computer Graphics Forum},
+volume = {n/a},
+number = {n/a},
+pages = {e15000},
+keywords = {modeling, surface reconstruction},
+doi = {https://doi.org/10.1111/cgf.15000},
+url = {https://onlinelibrary.wiley.com/doi/abs/10.1111/cgf.15000},
+eprint = {https://onlinelibrary.wiley.com/doi/pdf/10.1111/cgf.15000},
+abstract = {Abstract 3D surface reconstruction from point clouds is a key step in areas such as content creation, archaeology, digital cultural heritage and engineering. Current approaches either try to optimize a non-data-driven surface representation to fit the points, or learn a data-driven prior over the distribution of commonly occurring surfaces and how they correlate with potentially noisy point clouds. Data-driven methods enable robust handling of noise and typically either focus on a global or a local prior, which trade-off between robustness to noise on the global end and surface detail preservation on the local end. We propose PPSurf as a method that combines a global prior based on point convolutions and a local prior based on processing local point cloud patches. We show that this approach is robust to noise while recovering surface details more accurately than the current state-of-the-art. Our source code, pre-trained model and dataset are available at https://github.com/cg-tuwien/ppsurf.}
+}
+```
diff --git a/ppsurf/configs/device_server.yaml b/ppsurf/configs/device_server.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b0007474e894090b9a79b5aeb7a1008da483f041
--- /dev/null
+++ b/ppsurf/configs/device_server.yaml
@@ -0,0 +1,13 @@
+trainer:
+ strategy: ddp
+ # strategy: ddp_find_unused_parameters_true
+
+model:
+ init_args:
+ workers: &num_workers 48
+
+data:
+ init_args:
+ use_ddp: True
+ workers: *num_workers
+ batch_size: 12 # 50 / 4 = 12.5
\ No newline at end of file
diff --git a/ppsurf/configs/poco.yaml b/ppsurf/configs/poco.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ce9214a9e253d126236041cd7ce9af48b4fee943
--- /dev/null
+++ b/ppsurf/configs/poco.yaml
@@ -0,0 +1,77 @@
+debug: False
+seed_everything: 42
+
+trainer:
+ max_epochs: 150
+ default_root_dir: 'models/poco'
+ strategy: auto
+ accelerator: gpu
+ devices: -1
+ precision: 16-mixed
+ num_sanity_val_steps: 0
+ log_every_n_steps: 1
+ logger:
+ class_path: pytorch_lightning.loggers.TensorBoardLogger
+ init_args:
+ save_dir: 'models'
+ callbacks:
+ - class_path: source.cli.PPSProgressBar
+ - class_path: LearningRateMonitor
+ init_args:
+ logging_interval: step
+ - class_path: ModelCheckpoint
+ init_args:
+ save_last: True
+ save_top_k: 0
+
+data:
+ class_path: source.poco_data_loader.PocoDataModule
+ init_args:
+ use_ddp: False
+ in_file: datasets/abc_train/testset.txt
+ padding_factor: 0.05
+ seed: 42
+ manifold_points: 10000
+ patches_per_shape: -1
+ do_data_augmentation: True
+ batch_size: 10
+ workers: 8
+
+model:
+ class_path: source.poco_model.PocoModel
+ init_args:
+ output_names:
+ - 'imp_surf_sign'
+ in_channels: 3
+ out_channels: 2
+ k: 64
+ network_latent_size: 32
+ gen_subsample_manifold_iter: 10
+ gen_subsample_manifold: 10000
+ gen_resolution_global: 257
+ rec_batch_size: 50000
+ gen_refine_iter: 10
+ workers: 8
+ lambda_l1: 0.0
+ results_dir: 'results'
+ name: 'poco'
+ debug: False
+
+optimizer:
+ class_path: torch.optim.AdamW
+ init_args:
+ lr: 0.001
+ betas:
+ - 0.9
+ - 0.999
+ eps: 1e-5
+ weight_decay: 1e-2
+ amsgrad: False
+
+lr_scheduler:
+ class_path: torch.optim.lr_scheduler.MultiStepLR
+ init_args:
+ milestones:
+ - 75
+ - 125
+ gamma: 0.1
diff --git a/ppsurf/configs/poco_mini.yaml b/ppsurf/configs/poco_mini.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d7d306f272ccc6a472a7c70301b9521779966d70
--- /dev/null
+++ b/ppsurf/configs/poco_mini.yaml
@@ -0,0 +1,13 @@
+trainer:
+ default_root_dir: 'models/poco_mini'
+
+model:
+ init_args:
+ name: 'poco_mini'
+ gen_resolution_global: 129 # half resolution
+ rec_batch_size: 25000 # half memory
+
+data:
+ init_args:
+ in_file: datasets/abc_minimal/testset.txt # small dataset
+ batch_size: 10 # 16 GB GPU memory
\ No newline at end of file
diff --git a/ppsurf/configs/ppsurf.yaml b/ppsurf/configs/ppsurf.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8e060dce0695064a87615d2c7ecbd67ef6e58fcd
--- /dev/null
+++ b/ppsurf/configs/ppsurf.yaml
@@ -0,0 +1,13 @@
+trainer:
+ default_root_dir: 'models/ppsurf'
+
+model:
+ class_path: source.ppsurf_model.PPSurfModel
+ init_args:
+ network_latent_size: 256
+ num_pts_local: 50
+ pointnet_latent_size: 256
+ debug: False
+
+data:
+ class_path: source.ppsurf_data_loader.PPSurfDataModule
\ No newline at end of file
diff --git a/ppsurf/configs/ppsurf_100nn.yaml b/ppsurf/configs/ppsurf_100nn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..90868099faacd7192c0919443f13f8419b8fd93c
--- /dev/null
+++ b/ppsurf/configs/ppsurf_100nn.yaml
@@ -0,0 +1,8 @@
+trainer:
+ default_root_dir: 'models/ppsurf_100nn'
+
+model:
+ init_args:
+ name: 'ppsurf_100nn'
+ num_pts_local: 100
+
diff --git a/ppsurf/configs/ppsurf_10nn.yaml b/ppsurf/configs/ppsurf_10nn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..73492af3d5dd15f7da3b36271f0c403d26bc8c4b
--- /dev/null
+++ b/ppsurf/configs/ppsurf_10nn.yaml
@@ -0,0 +1,8 @@
+trainer:
+ default_root_dir: 'models/ppsurf_10nn'
+
+model:
+ init_args:
+ name: 'ppsurf_10nn'
+ num_pts_local: 10
+
diff --git a/ppsurf/configs/ppsurf_200nn.yaml b/ppsurf/configs/ppsurf_200nn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e25b4ada1984b8e2d43be76e21a77f43a4da7564
--- /dev/null
+++ b/ppsurf/configs/ppsurf_200nn.yaml
@@ -0,0 +1,9 @@
+trainer:
+ default_root_dir: 'models/ppsurf_200nn'
+
+model:
+ init_args:
+ name: 'ppsurf_200nn'
+ num_pts_local: 200
+ rec_batch_size: 25000
+
diff --git a/ppsurf/configs/ppsurf_25nn.yaml b/ppsurf/configs/ppsurf_25nn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f555ad95b3889362880981a9f32656f298d442de
--- /dev/null
+++ b/ppsurf/configs/ppsurf_25nn.yaml
@@ -0,0 +1,8 @@
+trainer:
+ default_root_dir: 'models/ppsurf_25nn'
+
+model:
+ init_args:
+ name: 'ppsurf_25nn'
+ num_pts_local: 25
+
diff --git a/ppsurf/configs/ppsurf_50nn.yaml b/ppsurf/configs/ppsurf_50nn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c559163055ead23e619304205445bb2e0ef2f844
--- /dev/null
+++ b/ppsurf/configs/ppsurf_50nn.yaml
@@ -0,0 +1,8 @@
+trainer:
+ default_root_dir: 'models/ppsurf_50nn'
+
+model:
+ init_args:
+ name: 'ppsurf_50nn'
+ num_pts_local: 50
+
diff --git a/ppsurf/configs/ppsurf_mini.yaml b/ppsurf/configs/ppsurf_mini.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e67e25e3551aec5b5c4b4456b2d158f90df4d43a
--- /dev/null
+++ b/ppsurf/configs/ppsurf_mini.yaml
@@ -0,0 +1,13 @@
+trainer:
+ default_root_dir: 'models/ppsurf_mini'
+
+model:
+ init_args:
+ name: 'ppsurf_mini'
+ gen_resolution_global: 129 # half resolution
+ rec_batch_size: 25000 # half memory
+
+data:
+ init_args:
+ in_file: datasets/abc_minimal/testset.txt # small dataset
+ batch_size: 10 # 16 GB GPU memory
\ No newline at end of file
diff --git a/ppsurf/configs/profiler.yaml b/ppsurf/configs/profiler.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1c77a623802d18cd33cc6250cb132b1cdfc455b0
--- /dev/null
+++ b/ppsurf/configs/profiler.yaml
@@ -0,0 +1,8 @@
+trainer:
+ profiler:
+ class_path: source.cli.PPSProfiler
+ init_args:
+ export_to_chrome: False
+ emit_nvtx: False
+ with_stack: False
+
diff --git a/ppsurf/datasets/abc_minimal/03_meshes/00010009_d97409455fa543b3a224250f_trimesh_000.ply b/ppsurf/datasets/abc_minimal/03_meshes/00010009_d97409455fa543b3a224250f_trimesh_000.ply
new file mode 100644
index 0000000000000000000000000000000000000000..147e0d70b8898ba6bac714ca4260f213fb4166b1
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/03_meshes/00010009_d97409455fa543b3a224250f_trimesh_000.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d0e0aae89613fc83b37e886a060da3d2b224ee7f6788a9fdc85127a3fa2d5ae
+size 366326
diff --git a/ppsurf/datasets/abc_minimal/03_meshes/00010039_75f31cb4dff84986aadc622b_trimesh_000.ply b/ppsurf/datasets/abc_minimal/03_meshes/00010039_75f31cb4dff84986aadc622b_trimesh_000.ply
new file mode 100644
index 0000000000000000000000000000000000000000..71821bc1136b398268206b2ffbfedd64c7e1cd6c
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/03_meshes/00010039_75f31cb4dff84986aadc622b_trimesh_000.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aedcc0a61e16c8af1c3972bb9fd455107aa6bbaa4cc970e5cb6192465534fd8c
+size 145139
diff --git a/ppsurf/datasets/abc_minimal/03_meshes/00010045_75f31cb4dff84986aadc622b_trimesh_006.ply b/ppsurf/datasets/abc_minimal/03_meshes/00010045_75f31cb4dff84986aadc622b_trimesh_006.ply
new file mode 100644
index 0000000000000000000000000000000000000000..d7274c6559d4bfe52b52bbf81c25b4b63375caa5
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/03_meshes/00010045_75f31cb4dff84986aadc622b_trimesh_006.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4df8a5d100893dc2f941101ef2a726ae5f36c0673b591b5f4e5404d575a1a63
+size 104689
diff --git a/ppsurf/datasets/abc_minimal/03_meshes/00010071_493cf58028d24a5b97528c11_trimesh_001.ply b/ppsurf/datasets/abc_minimal/03_meshes/00010071_493cf58028d24a5b97528c11_trimesh_001.ply
new file mode 100644
index 0000000000000000000000000000000000000000..b4c157d75af650b4034ddc5d32b8430ded64ae2f
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/03_meshes/00010071_493cf58028d24a5b97528c11_trimesh_001.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:141c70ef2dec82550a4f71565e80d7d1dafe44745ab1b9a139a4fbe49096a500
+size 119909
diff --git a/ppsurf/datasets/abc_minimal/03_meshes/00010074_493cf58028d24a5b97528c11_trimesh_004.ply b/ppsurf/datasets/abc_minimal/03_meshes/00010074_493cf58028d24a5b97528c11_trimesh_004.ply
new file mode 100644
index 0000000000000000000000000000000000000000..0584cc7df9764d30f668b97429856233880b2f25
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/03_meshes/00010074_493cf58028d24a5b97528c11_trimesh_004.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c7a6526052d923223c335a254f96e6ea197a583a79afbd31a89e273e0d8616df
+size 58733
diff --git a/ppsurf/datasets/abc_minimal/03_meshes/00010089_5ae1ee45b583467fa009adc4_trimesh_000.ply b/ppsurf/datasets/abc_minimal/03_meshes/00010089_5ae1ee45b583467fa009adc4_trimesh_000.ply
new file mode 100644
index 0000000000000000000000000000000000000000..aad92a4ce34714634a90532fa728552bdd7be2b1
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/03_meshes/00010089_5ae1ee45b583467fa009adc4_trimesh_000.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2f0beae183bc8a3251af07f9bc4b431eb5dd89a95261704a11b3a2df84ac7268
+size 176173
diff --git a/ppsurf/datasets/abc_minimal/03_meshes/00010098_1f6110e499fb41c582c50527_trimesh_001.ply b/ppsurf/datasets/abc_minimal/03_meshes/00010098_1f6110e499fb41c582c50527_trimesh_001.ply
new file mode 100644
index 0000000000000000000000000000000000000000..cd96e185b84f9d37c17d6e1ca6a1339702db6ff9
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/03_meshes/00010098_1f6110e499fb41c582c50527_trimesh_001.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:376e1efe00155feed1f1e1dbd49d56157beaa5428e1000525b2ab405de784de1
+size 88089
diff --git a/ppsurf/datasets/abc_minimal/03_meshes/00011084_fddd53ce45f640f3ab922328_trimesh_019.ply b/ppsurf/datasets/abc_minimal/03_meshes/00011084_fddd53ce45f640f3ab922328_trimesh_019.ply
new file mode 100644
index 0000000000000000000000000000000000000000..0fe5ac2550bf8cf420e28dde995c7dfc20f16b8a
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/03_meshes/00011084_fddd53ce45f640f3ab922328_trimesh_019.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b8c9009b48a4a61fbb7e5180d6876a4ee4df202cc5823fa2ca93a001d3df380a
+size 307116
diff --git a/ppsurf/datasets/abc_minimal/03_meshes/00016513_3d6966cd42eb44ab8f4224f2_trimesh_053.ply b/ppsurf/datasets/abc_minimal/03_meshes/00016513_3d6966cd42eb44ab8f4224f2_trimesh_053.ply
new file mode 100644
index 0000000000000000000000000000000000000000..a6a6f1ad2e8b5f469399244b3ba44394a90eedf4
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/03_meshes/00016513_3d6966cd42eb44ab8f4224f2_trimesh_053.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5eebfc53813899cdbad4e05a02ff39a4a2472f0d23b83c74acc225d57d595f43
+size 100913
diff --git a/ppsurf/datasets/abc_minimal/03_meshes/00994122_57d9d4755722f9d2d7436f0a_trimesh_000.ply b/ppsurf/datasets/abc_minimal/03_meshes/00994122_57d9d4755722f9d2d7436f0a_trimesh_000.ply
new file mode 100644
index 0000000000000000000000000000000000000000..161e9f4b7596725cced1195d2a3dc8b52ea90636
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/03_meshes/00994122_57d9d4755722f9d2d7436f0a_trimesh_000.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:65a87b1b08c4e75beea98e47786c8d33c61d9a07d78197627ad0c17f808a4851
+size 54681
diff --git a/ppsurf/datasets/abc_minimal/04_pts_vis/00010009_d97409455fa543b3a224250f_trimesh_000.xyz.ply b/ppsurf/datasets/abc_minimal/04_pts_vis/00010009_d97409455fa543b3a224250f_trimesh_000.xyz.ply
new file mode 100644
index 0000000000000000000000000000000000000000..15a79fdf3a400fff099623fd27223fa99e4873f7
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/04_pts_vis/00010009_d97409455fa543b3a224250f_trimesh_000.xyz.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b5f040b9396998a99f70f597b771e5e4d62acd03c3201784e51c3ba5506e62c
+size 720587
diff --git a/ppsurf/datasets/abc_minimal/04_pts_vis/00010039_75f31cb4dff84986aadc622b_trimesh_000.xyz.ply b/ppsurf/datasets/abc_minimal/04_pts_vis/00010039_75f31cb4dff84986aadc622b_trimesh_000.xyz.ply
new file mode 100644
index 0000000000000000000000000000000000000000..7d2bd27938ffd80dbb5558d6d91b131e9bb3cb91
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/04_pts_vis/00010039_75f31cb4dff84986aadc622b_trimesh_000.xyz.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f821d95c3185e42274707c1f19387d018d89d94f9191e498f14680a5ce5e756
+size 83446
diff --git a/ppsurf/datasets/abc_minimal/04_pts_vis/00010045_75f31cb4dff84986aadc622b_trimesh_006.xyz.ply b/ppsurf/datasets/abc_minimal/04_pts_vis/00010045_75f31cb4dff84986aadc622b_trimesh_006.xyz.ply
new file mode 100644
index 0000000000000000000000000000000000000000..214039154848173dc9480c324995840937a92d7f
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/04_pts_vis/00010045_75f31cb4dff84986aadc622b_trimesh_006.xyz.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db0d39e6373d6052b3344fe11f983e01cdd857c683051f1ee0bb88d17dd0cf53
+size 230699
diff --git a/ppsurf/datasets/abc_minimal/04_pts_vis/00010071_493cf58028d24a5b97528c11_trimesh_001.xyz.ply b/ppsurf/datasets/abc_minimal/04_pts_vis/00010071_493cf58028d24a5b97528c11_trimesh_001.xyz.ply
new file mode 100644
index 0000000000000000000000000000000000000000..76ed11d6c343211aa388cded367c2f8c1d9ebbd7
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/04_pts_vis/00010071_493cf58028d24a5b97528c11_trimesh_001.xyz.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e96621d17d9d10535d5149f9560a43701f701ab2856381f297886bbaa6c78944
+size 807491
diff --git a/ppsurf/datasets/abc_minimal/04_pts_vis/00010074_493cf58028d24a5b97528c11_trimesh_004.xyz.ply b/ppsurf/datasets/abc_minimal/04_pts_vis/00010074_493cf58028d24a5b97528c11_trimesh_004.xyz.ply
new file mode 100644
index 0000000000000000000000000000000000000000..f13e01a2161289addd9c72c85864f70422c1a3c6
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/04_pts_vis/00010074_493cf58028d24a5b97528c11_trimesh_004.xyz.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8456c94b3dd6060f7406387063b8e47c60a47eca0582bdf1873e089a17d88a09
+size 152183
diff --git a/ppsurf/datasets/abc_minimal/04_pts_vis/00010089_5ae1ee45b583467fa009adc4_trimesh_000.xyz.ply b/ppsurf/datasets/abc_minimal/04_pts_vis/00010089_5ae1ee45b583467fa009adc4_trimesh_000.xyz.ply
new file mode 100644
index 0000000000000000000000000000000000000000..63a9690c0ee237361906b1320cda6ee3fea0d516
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/04_pts_vis/00010089_5ae1ee45b583467fa009adc4_trimesh_000.xyz.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:acfe34a146d7400566157045626b4cec26b6fb11e95448c3925b67b6330eca33
+size 137495
diff --git a/ppsurf/datasets/abc_minimal/04_pts_vis/00010098_1f6110e499fb41c582c50527_trimesh_001.xyz.ply b/ppsurf/datasets/abc_minimal/04_pts_vis/00010098_1f6110e499fb41c582c50527_trimesh_001.xyz.ply
new file mode 100644
index 0000000000000000000000000000000000000000..e5194109abcc48d3e95244fbf38daa6d2b880ff9
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/04_pts_vis/00010098_1f6110e499fb41c582c50527_trimesh_001.xyz.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e367d04a8af767eb902da9e4047ba41db080d8dbcea958781e2b22e8259f90c
+size 147875
diff --git a/ppsurf/datasets/abc_minimal/04_pts_vis/00011084_fddd53ce45f640f3ab922328_trimesh_019.xyz.ply b/ppsurf/datasets/abc_minimal/04_pts_vis/00011084_fddd53ce45f640f3ab922328_trimesh_019.xyz.ply
new file mode 100644
index 0000000000000000000000000000000000000000..0c337e61426a9e2f267877d72a7db883f2bb500b
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/04_pts_vis/00011084_fddd53ce45f640f3ab922328_trimesh_019.xyz.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dff0fca9f7e9cebfcea4142bafd860ca0ecb7a916a852d2f96e99f4c2d35b06d
+size 719963
diff --git a/ppsurf/datasets/abc_minimal/04_pts_vis/00016513_3d6966cd42eb44ab8f4224f2_trimesh_053.xyz.ply b/ppsurf/datasets/abc_minimal/04_pts_vis/00016513_3d6966cd42eb44ab8f4224f2_trimesh_053.xyz.ply
new file mode 100644
index 0000000000000000000000000000000000000000..6570a55ea0557b46158ed6500d7c12d5f449f019
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/04_pts_vis/00016513_3d6966cd42eb44ab8f4224f2_trimesh_053.xyz.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8ddfcb74e2a20c3086116bc18dd7cfc6501585e2fa483b6431997150cac9685
+size 1039991
diff --git a/ppsurf/datasets/abc_minimal/04_pts_vis/00994122_57d9d4755722f9d2d7436f0a_trimesh_000.xyz.ply b/ppsurf/datasets/abc_minimal/04_pts_vis/00994122_57d9d4755722f9d2d7436f0a_trimesh_000.xyz.ply
new file mode 100644
index 0000000000000000000000000000000000000000..8113845061e757749540d3acf6875f8b88720503
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/04_pts_vis/00994122_57d9d4755722f9d2d7436f0a_trimesh_000.xyz.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c5cf81d7822dcabbf89d5699929670a0c3fdfa89041a2bad88114bad6c7a5b08
+size 416531
diff --git a/ppsurf/datasets/abc_minimal/05_query_dist/00010009_d97409455fa543b3a224250f_trimesh_000.ply.npy b/ppsurf/datasets/abc_minimal/05_query_dist/00010009_d97409455fa543b3a224250f_trimesh_000.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..7efe68bd7e80c2ad4210a27ba69de1a616de5942
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_dist/00010009_d97409455fa543b3a224250f_trimesh_000.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e5ac5cbe1f8a37ceac430dbca46372641835264bb9705f3303756413297ef9b4
+size 8128
diff --git a/ppsurf/datasets/abc_minimal/05_query_dist/00010039_75f31cb4dff84986aadc622b_trimesh_000.ply.npy b/ppsurf/datasets/abc_minimal/05_query_dist/00010039_75f31cb4dff84986aadc622b_trimesh_000.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..796383d4111d288a820e9480066897167ddbb8db
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_dist/00010039_75f31cb4dff84986aadc622b_trimesh_000.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:91f79d74bc3dca2b959c5569fd8099578ae5381570151a72030549ee36b91748
+size 8128
diff --git a/ppsurf/datasets/abc_minimal/05_query_dist/00010045_75f31cb4dff84986aadc622b_trimesh_006.ply.npy b/ppsurf/datasets/abc_minimal/05_query_dist/00010045_75f31cb4dff84986aadc622b_trimesh_006.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..10f855992eb9c3338326e0b4abe9c101472ac17b
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_dist/00010045_75f31cb4dff84986aadc622b_trimesh_006.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:de26d05ad83c10705ad8e1125f2141b9b99decc431c7503f6b174b40a39f9d1b
+size 8128
diff --git a/ppsurf/datasets/abc_minimal/05_query_dist/00010071_493cf58028d24a5b97528c11_trimesh_001.ply.npy b/ppsurf/datasets/abc_minimal/05_query_dist/00010071_493cf58028d24a5b97528c11_trimesh_001.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..5f360204f4f894701eff3afd8eb6c41d27cfde58
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_dist/00010071_493cf58028d24a5b97528c11_trimesh_001.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b8a931c4ef1917b02a68c9c226a208d940e4f92dcec64a1d5d5602377b6b7634
+size 8128
diff --git a/ppsurf/datasets/abc_minimal/05_query_dist/00010074_493cf58028d24a5b97528c11_trimesh_004.ply.npy b/ppsurf/datasets/abc_minimal/05_query_dist/00010074_493cf58028d24a5b97528c11_trimesh_004.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..47b6aad7aea9e0f796fe6030a26a52160e9d0a0b
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_dist/00010074_493cf58028d24a5b97528c11_trimesh_004.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:35f5ae18831517514f8d2e4a3ae6f9d9ca684d09775ed6a4b5ce1cd2ef2622f1
+size 8128
diff --git a/ppsurf/datasets/abc_minimal/05_query_dist/00010089_5ae1ee45b583467fa009adc4_trimesh_000.ply.npy b/ppsurf/datasets/abc_minimal/05_query_dist/00010089_5ae1ee45b583467fa009adc4_trimesh_000.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..b455918efdc430f3d112f92b8424b110e67feece
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_dist/00010089_5ae1ee45b583467fa009adc4_trimesh_000.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b36dc2750703fb29a23dc19b3d3862020b8d31c293d5a50a8a117f1a6fa8be43
+size 8128
diff --git a/ppsurf/datasets/abc_minimal/05_query_dist/00010098_1f6110e499fb41c582c50527_trimesh_001.ply.npy b/ppsurf/datasets/abc_minimal/05_query_dist/00010098_1f6110e499fb41c582c50527_trimesh_001.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..3f7305e229a15d36495b647aa736955833969f2d
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_dist/00010098_1f6110e499fb41c582c50527_trimesh_001.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:33a7ada18b00f5859b72961d9eec834735efa7657bcb649dc6f2f6dd90438eb4
+size 8128
diff --git a/ppsurf/datasets/abc_minimal/05_query_dist/00011084_fddd53ce45f640f3ab922328_trimesh_019.ply.npy b/ppsurf/datasets/abc_minimal/05_query_dist/00011084_fddd53ce45f640f3ab922328_trimesh_019.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..50ca2b25bf3d5ee49db3e0aaa6fd78d353be8470
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_dist/00011084_fddd53ce45f640f3ab922328_trimesh_019.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db4342b68b37276a04473869b173ed022c3a760771e0dfe2defcc6595ecf2818
+size 8128
diff --git a/ppsurf/datasets/abc_minimal/05_query_dist/00016513_3d6966cd42eb44ab8f4224f2_trimesh_053.ply.npy b/ppsurf/datasets/abc_minimal/05_query_dist/00016513_3d6966cd42eb44ab8f4224f2_trimesh_053.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..5293a99a12ca6e9d368235bf93338cc343db6e1c
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_dist/00016513_3d6966cd42eb44ab8f4224f2_trimesh_053.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:67abf7b61b07dfd4cf9b36a3419015e62d40c06d2134a27ffe5f38a28c152178
+size 8128
diff --git a/ppsurf/datasets/abc_minimal/05_query_dist/00994122_57d9d4755722f9d2d7436f0a_trimesh_000.ply.npy b/ppsurf/datasets/abc_minimal/05_query_dist/00994122_57d9d4755722f9d2d7436f0a_trimesh_000.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..46ec4a02b7315a97082e46de023287b2baccb8f2
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_dist/00994122_57d9d4755722f9d2d7436f0a_trimesh_000.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d987baa6647b4683bea3693263658eb35cc9a0324eae0a9782bbb9a75511ea76
+size 8128
diff --git a/ppsurf/datasets/abc_minimal/05_query_pts/00010009_d97409455fa543b3a224250f_trimesh_000.ply.npy b/ppsurf/datasets/abc_minimal/05_query_pts/00010009_d97409455fa543b3a224250f_trimesh_000.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..ec2746236b1511103dcbe8ad949bda61186fa042
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_pts/00010009_d97409455fa543b3a224250f_trimesh_000.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbf87f437ca8fbd595c7fd6357f0eb3922998fa142f50023ff5864be96787ad5
+size 24128
diff --git a/ppsurf/datasets/abc_minimal/05_query_pts/00010039_75f31cb4dff84986aadc622b_trimesh_000.ply.npy b/ppsurf/datasets/abc_minimal/05_query_pts/00010039_75f31cb4dff84986aadc622b_trimesh_000.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..7069fa9ce15aa26e2a939e6a5a56096492a4e422
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_pts/00010039_75f31cb4dff84986aadc622b_trimesh_000.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:572874bc63bdda88b5afd0c5dd3c0d7dbeb276db8ca63f3bf065e23bb09ef8c5
+size 24128
diff --git a/ppsurf/datasets/abc_minimal/05_query_pts/00010045_75f31cb4dff84986aadc622b_trimesh_006.ply.npy b/ppsurf/datasets/abc_minimal/05_query_pts/00010045_75f31cb4dff84986aadc622b_trimesh_006.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..bd6132911f6d4008a2c0bbd4beb2842d5dcdfb64
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_pts/00010045_75f31cb4dff84986aadc622b_trimesh_006.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14d224b89581c3c000f6b8963b0991fcb54f5ac1b1fe245e7f74fa52b852d9c2
+size 24128
diff --git a/ppsurf/datasets/abc_minimal/05_query_pts/00010071_493cf58028d24a5b97528c11_trimesh_001.ply.npy b/ppsurf/datasets/abc_minimal/05_query_pts/00010071_493cf58028d24a5b97528c11_trimesh_001.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..f26132650b6642afafa1b17a80f3291dbd9eae32
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_pts/00010071_493cf58028d24a5b97528c11_trimesh_001.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e10c9c8eec830436accb2a9ecfc875995bbf8bf8285527bd3e68d26c4d5b722f
+size 24128
diff --git a/ppsurf/datasets/abc_minimal/05_query_pts/00010074_493cf58028d24a5b97528c11_trimesh_004.ply.npy b/ppsurf/datasets/abc_minimal/05_query_pts/00010074_493cf58028d24a5b97528c11_trimesh_004.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..4f4c8043e9999ca9946ed4ee642d98e2d9078ea5
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_pts/00010074_493cf58028d24a5b97528c11_trimesh_004.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1dfb63ee40bdba4edfca68454a63b0a0bed0edeac3ed3fe1eb4a6873fa4f05dd
+size 24128
diff --git a/ppsurf/datasets/abc_minimal/05_query_pts/00010089_5ae1ee45b583467fa009adc4_trimesh_000.ply.npy b/ppsurf/datasets/abc_minimal/05_query_pts/00010089_5ae1ee45b583467fa009adc4_trimesh_000.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..6fa65959edd7a930c1918d6c04deb793e47e6674
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_pts/00010089_5ae1ee45b583467fa009adc4_trimesh_000.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bcd36101b53bd8c6f405e784eed08920fb8ce563f604750942d5ae1b0bbe225f
+size 24128
diff --git a/ppsurf/datasets/abc_minimal/05_query_pts/00010098_1f6110e499fb41c582c50527_trimesh_001.ply.npy b/ppsurf/datasets/abc_minimal/05_query_pts/00010098_1f6110e499fb41c582c50527_trimesh_001.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..4af65468e99dd3cfbeeaa2860ebe5ec850a27cf5
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_pts/00010098_1f6110e499fb41c582c50527_trimesh_001.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63a838eaaca572ce08f3c9f39a8441f050a53a212ceb5fdbb7dd6a84ccb504e0
+size 24128
diff --git a/ppsurf/datasets/abc_minimal/05_query_pts/00011084_fddd53ce45f640f3ab922328_trimesh_019.ply.npy b/ppsurf/datasets/abc_minimal/05_query_pts/00011084_fddd53ce45f640f3ab922328_trimesh_019.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..f339305e84a584f23ceb737e0169e219af8aafb9
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_pts/00011084_fddd53ce45f640f3ab922328_trimesh_019.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:58cd55593b459db924a600183579b803c675a627d1abcd0c7f0f8d2603f5ca10
+size 24128
diff --git a/ppsurf/datasets/abc_minimal/05_query_pts/00016513_3d6966cd42eb44ab8f4224f2_trimesh_053.ply.npy b/ppsurf/datasets/abc_minimal/05_query_pts/00016513_3d6966cd42eb44ab8f4224f2_trimesh_053.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..06c2563ce1163f934a02b70fdf1c65ec9930f714
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_pts/00016513_3d6966cd42eb44ab8f4224f2_trimesh_053.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4bfc1f45badf01cdbb30aa725e7364fd9e1d6744fbeb8d0ce4fd729cb7713c84
+size 24128
diff --git a/ppsurf/datasets/abc_minimal/05_query_pts/00994122_57d9d4755722f9d2d7436f0a_trimesh_000.ply.npy b/ppsurf/datasets/abc_minimal/05_query_pts/00994122_57d9d4755722f9d2d7436f0a_trimesh_000.ply.npy
new file mode 100644
index 0000000000000000000000000000000000000000..8d32c4ef82f0a15bd2aeaf28a04ec9a5788cfac9
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_pts/00994122_57d9d4755722f9d2d7436f0a_trimesh_000.ply.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:40228fe5ab4c33dc66ef8b283e1c13807f1f1718923b17cea4d379b570c49dc1
+size 24128
diff --git a/ppsurf/datasets/abc_minimal/05_query_vis/00011084_fddd53ce45f640f3ab922328_trimesh_019.ply.ply b/ppsurf/datasets/abc_minimal/05_query_vis/00011084_fddd53ce45f640f3ab922328_trimesh_019.ply.ply
new file mode 100644
index 0000000000000000000000000000000000000000..96003b09ff358c2a0eadfbe4fa8be091cf19bdfc
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_vis/00011084_fddd53ce45f640f3ab922328_trimesh_019.ply.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:55d5a12025587c6b9e1f78e358ffc16abc6dfa595024123f376eb9995222aa98
+size 32287
diff --git a/ppsurf/datasets/abc_minimal/05_query_vis/00016513_3d6966cd42eb44ab8f4224f2_trimesh_053.ply.ply b/ppsurf/datasets/abc_minimal/05_query_vis/00016513_3d6966cd42eb44ab8f4224f2_trimesh_053.ply.ply
new file mode 100644
index 0000000000000000000000000000000000000000..1e2c8c514ccd769f02a62702171257e41d28ad4a
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_vis/00016513_3d6966cd42eb44ab8f4224f2_trimesh_053.ply.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8038e5f033111cfe51b95a8cd0cda265179c8ed3e536331a8b899d1d0e343a75
+size 32287
diff --git a/ppsurf/datasets/abc_minimal/05_query_vis/00994122_57d9d4755722f9d2d7436f0a_trimesh_000.ply.ply b/ppsurf/datasets/abc_minimal/05_query_vis/00994122_57d9d4755722f9d2d7436f0a_trimesh_000.ply.ply
new file mode 100644
index 0000000000000000000000000000000000000000..b7c44fc2798d27c9d8e9d177d0f7855980489747
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/05_query_vis/00994122_57d9d4755722f9d2d7436f0a_trimesh_000.ply.ply
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f3c6501eae6ded00eb5645b1163445ceae0e0dcef356636ba64828502bcfe2ba
+size 32287
diff --git a/ppsurf/datasets/abc_minimal/settings.ini b/ppsurf/datasets/abc_minimal/settings.ini
new file mode 100644
index 0000000000000000000000000000000000000000..dd28007aaaaef835b5d210c04252b3ea3b02bc72
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/settings.ini
@@ -0,0 +1,8 @@
+[general]
+only_for_evaluation = 0
+grid_resolution = 256
+epsilon = 5
+num_scans_per_mesh_min = 5
+num_scans_per_mesh_max = 30
+scanner_noise_sigma_min = 0.0
+scanner_noise_sigma_max = 0.05
\ No newline at end of file
diff --git a/ppsurf/datasets/abc_minimal/testset.txt b/ppsurf/datasets/abc_minimal/testset.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1d303e17ded2bdc4334420c80dba57828326bed6
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/testset.txt
@@ -0,0 +1,3 @@
+00010009_d97409455fa543b3a224250f_trimesh_000
+00010074_493cf58028d24a5b97528c11_trimesh_004
+00994122_57d9d4755722f9d2d7436f0a_trimesh_000
\ No newline at end of file
diff --git a/ppsurf/datasets/abc_minimal/trainset.txt b/ppsurf/datasets/abc_minimal/trainset.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5309c65848d2c1df93fe313231abe0404e5e6a5c
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/trainset.txt
@@ -0,0 +1,7 @@
+00010039_75f31cb4dff84986aadc622b_trimesh_000
+00010045_75f31cb4dff84986aadc622b_trimesh_006
+00010071_493cf58028d24a5b97528c11_trimesh_001
+00010089_5ae1ee45b583467fa009adc4_trimesh_000
+00010098_1f6110e499fb41c582c50527_trimesh_001
+00011084_fddd53ce45f640f3ab922328_trimesh_019
+00016513_3d6966cd42eb44ab8f4224f2_trimesh_053
\ No newline at end of file
diff --git a/ppsurf/datasets/abc_minimal/valset.txt b/ppsurf/datasets/abc_minimal/valset.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1d303e17ded2bdc4334420c80dba57828326bed6
--- /dev/null
+++ b/ppsurf/datasets/abc_minimal/valset.txt
@@ -0,0 +1,3 @@
+00010009_d97409455fa543b3a224250f_trimesh_000
+00010074_493cf58028d24a5b97528c11_trimesh_004
+00994122_57d9d4755722f9d2d7436f0a_trimesh_000
\ No newline at end of file
diff --git a/ppsurf/full_run_poco.py b/ppsurf/full_run_poco.py
new file mode 100644
index 0000000000000000000000000000000000000000..f17c9515870d913b205f1ddd9089a30078b1df14
--- /dev/null
+++ b/ppsurf/full_run_poco.py
@@ -0,0 +1,75 @@
+# this script runs a model through training, testing and prediction of all datasets
+
+# profiling with tree visualization
+# pip install snakeviz
+# https://jiffyclub.github.io/snakeviz/
+# python -m cProfile -o poco.prof poco.py
+# snakeviz poco.prof
+
+import os
+from source.base.mp import get_multi_gpu_params
+
+if __name__ == '__main__':
+ python_call = 'python'
+ main_cmd = 'poco.py'
+ name = 'poco'
+ version = '0'
+ # on_server = False
+
+ debug = ''
+ print_config = ''
+
+ # uncomment for debugging
+ # debug += '--debug True'
+ # print_config += '--print_config'
+
+ # python_call += ' -m cProfile -o poco.prof' # uncomment for profiling
+
+ main_cmd = python_call + ' ' + main_cmd
+
+ cmd_template = '{main_cmd} {sub_cmd} {configs} {debug} {print_config}'
+ configs = '{server} -c configs/{name}.yaml'
+
+ # training
+
+ # configs_train = configs.format(server='-c configs/device_server.yaml' if on_server else '', name=name)
+ configs_train = configs.format(server=' '.join(get_multi_gpu_params()), name=name)
+ cmd_train = cmd_template.format(main_cmd=main_cmd, sub_cmd='fit', configs=configs_train, debug=debug, print_config=print_config)
+ cmd_train += ' --data.init_args.in_file datasets/abc_train/trainset.txt'
+ os.system(cmd_train)
+
+ args_no_train = (
+ '--ckpt_path models/{name}/version_{version}/checkpoints/last.ckpt '
+ '--trainer.logger False ' # comment for profiling
+ '--trainer.devices 1'
+ ).format(name=name, version=version)
+ configs_no_train = configs.format(server='', name=name)
+ cmd_template_no_train = cmd_template + ' --data.init_args.in_file {dataset}/testset.txt ' + args_no_train
+
+ # testing
+ cmd_test = cmd_template_no_train.format(main_cmd=main_cmd, sub_cmd='test', configs=configs_no_train,
+ dataset='datasets/abc', debug=debug, print_config=print_config)
+ os.system(cmd_test)
+
+ # prediction
+ datasets = [
+ # 'abc_minimal',
+ 'abc',
+ 'abc_extra_noisy',
+ 'abc_noisefree',
+ 'real_world',
+ 'famous_original', 'famous_noisefree', 'famous_sparse', 'famous_dense', 'famous_extra_noisy',
+ 'thingi10k_scans_original', 'thingi10k_scans_noisefree', 'thingi10k_scans_sparse',
+ 'thingi10k_scans_dense', 'thingi10k_scans_extra_noisy'
+ ]
+ for ds in datasets:
+ cmd_pred = cmd_template_no_train.format(main_cmd=main_cmd, sub_cmd='predict', configs=configs_no_train,
+ dataset='datasets/' + ds, debug=debug, print_config=print_config)
+ # cmd_pred += ' -c configs/profiler.yaml'
+ # cmd_pred += ' --model.init_args.gen_resolution_global 129'
+ os.system(cmd_pred)
+
+ # make comparison
+ os.system('python source/figures/comp_all.py')
+
+ print('All done. You should find the results in results/comp/reports/comp_all.xlsx.')
diff --git a/ppsurf/full_run_poco_mini.py b/ppsurf/full_run_poco_mini.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc35702d13f7638bbff31a6f671bea23fbaa3fbd
--- /dev/null
+++ b/ppsurf/full_run_poco_mini.py
@@ -0,0 +1,69 @@
+# a short run of POCO for testing, debugging and profiling
+
+# profiling with tree visualization
+# pip install snakeviz
+# https://jiffyclub.github.io/snakeviz/
+# python -m cProfile -o poco.prof poco.py
+# snakeviz poco.prof
+
+import os
+from source.base.mp import get_multi_gpu_params
+
+if __name__ == '__main__':
+ python_call = 'python'
+ main_cmd = 'poco.py'
+ name = 'poco_mini'
+ version = '0'
+ # on_server = False
+
+ debug = ''
+ print_config = ''
+
+ # uncomment for debugging
+ # debug += '--debug True'
+ # print_config += '--print_config'
+
+ # python_call += ' -m cProfile -o poco.prof' # uncomment for profiling
+
+ main_cmd = python_call + ' ' + main_cmd
+
+ cmd_template = '{main_cmd} {sub_cmd} {configs} {debug} {print_config}'
+ configs = '-c configs/poco.yaml {server} -c configs/{name}.yaml'
+
+ # training
+ # configs_train = configs.format(server='-c configs/device_server.yaml' if on_server else '', name=name)
+ configs_train = configs.format(server=' '.join(get_multi_gpu_params()), name=name)
+ cmd_train = cmd_template.format(main_cmd=main_cmd, sub_cmd='fit', configs=configs_train, debug=debug, print_config=print_config)
+ os.system(cmd_train)
+
+ args_no_train = (
+ '--ckpt_path models/{name}/version_{version}/checkpoints/last.ckpt '
+ '--trainer.logger False ' # comment for tensorboard profiling
+ '--trainer.devices 1'
+ ).format(name=name, version=version)
+ configs_no_train = configs.format(server='', name=name)
+ cmd_template_no_train = cmd_template + ' --data.init_args.in_file {dataset}/testset.txt ' + args_no_train
+
+ # testing
+ cmd_test = cmd_template_no_train.format(main_cmd=main_cmd, sub_cmd='test', configs=configs_no_train,
+ dataset='datasets/abc_minimal', debug=debug, print_config=print_config)
+ os.system(cmd_test)
+
+ # prediction
+ datasets = [
+ 'abc_minimal',
+ # 'abc',
+ # 'abc_extra_noisy',
+ # 'abc_noisefree',
+ # 'real_world',
+ # 'famous_original', 'famous_noisefree', 'famous_sparse', 'famous_dense', 'famous_extra_noisy',
+ # 'thingi10k_scans_original', 'thingi10k_scans_noisefree', 'thingi10k_scans_sparse',
+ # 'thingi10k_scans_dense', 'thingi10k_scans_extra_noisy'
+ ]
+ # configs_no_train += ' --model.init_args.rec_batch_size 100'
+ for ds in datasets:
+ cmd_pred = cmd_template_no_train.format(main_cmd=main_cmd, sub_cmd='predict', configs=configs_no_train,
+ dataset='datasets/' + ds, debug=debug, print_config=print_config)
+ # cmd_pred += ' -c configs/profiler.yaml'
+ cmd_pred += ' --model.init_args.gen_resolution_global 129'
+ os.system(cmd_pred)
diff --git a/ppsurf/full_run_pps.py b/ppsurf/full_run_pps.py
new file mode 100644
index 0000000000000000000000000000000000000000..90e123ab78440f71141f8bea534aec972696e760
--- /dev/null
+++ b/ppsurf/full_run_pps.py
@@ -0,0 +1,53 @@
+# this script runs a model through training, testing and prediction of all datasets
+
+import os
+from source.base.mp import get_multi_gpu_params
+
+if __name__ == '__main__':
+ python_call = 'python'
+ main_cmd = 'pps.py'
+ name = 'ppsurf_50nn'
+ version = '0'
+ # on_server = True
+
+ main_cmd = python_call + ' ' + main_cmd
+
+ cmd_template = '{main_cmd} {sub_cmd} {configs}'
+ configs = '-c configs/poco.yaml -c configs/ppsurf.yaml {server} -c configs/{name}.yaml'
+
+ # training
+ # configs_train = configs.format(server='-c configs/device_server.yaml' if on_server else '', name=name)
+ configs_train = configs.format(server=' '.join(get_multi_gpu_params()), name=name)
+ cmd_train = cmd_template.format(main_cmd=main_cmd, sub_cmd='fit', configs=configs_train)
+ os.system(cmd_train)
+
+ args_no_train = ('--ckpt_path models/{name}/version_{version}/checkpoints/last.ckpt '
+ '--trainer.logger False --trainer.devices 1').format(name=name, version=version)
+ configs_no_train = configs.format(server='', name=name)
+ cmd_template_no_train = cmd_template + ' --data.init_args.in_file {dataset}/testset.txt ' + args_no_train
+
+ # testing
+ cmd_test = cmd_template_no_train.format(main_cmd=main_cmd, sub_cmd='test', configs=configs_no_train,
+ dataset='datasets/abc_train')
+ os.system(cmd_test)
+
+ # prediction
+ datasets = [
+ # 'abc_minimal',
+ 'abc',
+ 'abc_extra_noisy',
+ 'abc_noisefree',
+ 'real_world',
+ 'famous_original', 'famous_noisefree', 'famous_sparse', 'famous_dense', 'famous_extra_noisy',
+ 'thingi10k_scans_original', 'thingi10k_scans_noisefree', 'thingi10k_scans_sparse',
+ 'thingi10k_scans_dense', 'thingi10k_scans_extra_noisy'
+ ]
+ for ds in datasets:
+ cmd_pred = cmd_template_no_train.format(main_cmd=main_cmd, sub_cmd='predict', configs=configs_no_train,
+ dataset='datasets/' + ds)
+ os.system(cmd_pred)
+
+ # make comparison
+ os.system('python source/figures/comp_all.py')
+
+ print('All done. You should find the results in results/comp/reports/comp_all.xlsx.')
diff --git a/ppsurf/full_run_pps_mini.py b/ppsurf/full_run_pps_mini.py
new file mode 100644
index 0000000000000000000000000000000000000000..581e0b007920e2f3c50929edbbb2137cafcf5114
--- /dev/null
+++ b/ppsurf/full_run_pps_mini.py
@@ -0,0 +1,69 @@
+# a short run of PPSurf for testing, debugging and profiling
+
+# profiling with tree visualization
+# pip install snakeviz
+# https://jiffyclub.github.io/snakeviz/
+# python -m cProfile -o pps.prof pps.py
+# snakeviz pps.prof
+
+import os
+from source.base.mp import get_multi_gpu_params
+
+if __name__ == '__main__':
+ python_call = 'python'
+ main_cmd = 'pps.py'
+ name = 'ppsurf_mini'
+ version = '0'
+ # on_server = False
+
+ debug = ''
+ print_config = ''
+
+ # uncomment for debugging
+ # debug += '--debug True'
+ # print_config += '--print_config'
+
+ # python_call += ' -m cProfile -o pps.prof' # uncomment for profiling
+
+ main_cmd = python_call + ' ' + main_cmd
+
+ cmd_template = '{main_cmd} {sub_cmd} {configs} {debug} {print_config}'
+ configs = '-c configs/poco.yaml -c configs/ppsurf.yaml {server} -c configs/{name}.yaml'
+
+ # training
+ # configs_train = configs.format(server='-c configs/device_server.yaml' if on_server else '', name=name)
+ configs_train = configs.format(server=' '.join(get_multi_gpu_params()), name=name)
+ cmd_train = cmd_template.format(main_cmd=main_cmd, sub_cmd='fit',
+ configs=configs_train, debug=debug, print_config=print_config)
+ os.system(cmd_train)
+
+ args_no_train = ('--ckpt_path models/{name}/version_{version}/checkpoints/last.ckpt '
+ '--trainer.logger False ' # comment for tensorboard profiling
+ '--trainer.devices 1'
+ ).format(name=name, version=version)
+ configs_no_train = configs.format(server='', name=name)
+ cmd_template_no_train = cmd_template + ' --data.init_args.in_file {dataset}/testset.txt ' + args_no_train
+
+ # testing
+ cmd_test = cmd_template_no_train.format(main_cmd=main_cmd, sub_cmd='test', configs=configs_no_train,
+ dataset='datasets/abc_minimal', debug=debug, print_config=print_config)
+ os.system(cmd_test)
+
+ # prediction
+ datasets = [
+ 'abc_minimal',
+ # 'abc',
+ # 'abc_extra_noisy',
+ # 'abc_noisefree',
+ # 'real_world',
+ # 'famous_original', 'famous_noisefree', 'famous_sparse', 'famous_dense', 'famous_extra_noisy',
+ # 'thingi10k_scans_original', 'thingi10k_scans_noisefree', 'thingi10k_scans_sparse',
+ # 'thingi10k_scans_dense', 'thingi10k_scans_extra_noisy'
+ ]
+ # configs_no_train += ' --model.init_args.rec_batch_size 100'
+ for ds in datasets:
+ cmd_pred = cmd_template_no_train.format(main_cmd=main_cmd, sub_cmd='predict', configs=configs_no_train,
+ dataset='datasets/' + ds, debug=debug, print_config=print_config)
+ # cmd_pred += ' -c configs/profiler.yaml'
+ cmd_pred += ' --model.init_args.gen_resolution_global 129'
+ os.system(cmd_pred)
diff --git a/ppsurf/images/teaser.png b/ppsurf/images/teaser.png
new file mode 100644
index 0000000000000000000000000000000000000000..37e810e3c38f4d2a7b77e7042556627c3f299624
Binary files /dev/null and b/ppsurf/images/teaser.png differ
diff --git a/ppsurf/poco.py b/ppsurf/poco.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b8278e27d3dd45e1cefac87c56d30a5166b2ebd
--- /dev/null
+++ b/ppsurf/poco.py
@@ -0,0 +1,42 @@
+import typing
+
+from pytorch_lightning import cli
+
+from source.poco_model import PocoModel
+from source.occupancy_data_module import OccupancyDataModule
+
+from source.cli import CLI
+
+
+class PocoCLI(CLI):
+
+ def add_arguments_to_parser(self, parser: cli.LightningArgumentParser) -> None:
+ super().add_arguments_to_parser(parser)
+
+ parser.link_arguments('data.init_args.in_file', 'model.init_args.in_file')
+ parser.link_arguments('data.init_args.padding_factor', 'model.init_args.padding_factor')
+
+ # this direction because logger is not available for test/predict
+ parser.link_arguments('model.init_args.name', 'trainer.logger.init_args.name')
+
+ def handle_rec_subcommand(self, args: typing.List[str]) -> typing.List[str]:
+ """Replace 'rec' subcommand with predict and its default parameters.
+ Download model if necessary.
+ """
+ raise NotImplementedError()
+
+
+def cli_main():
+ PocoCLI(model_class=PocoModel, subclass_mode_model=True,
+ datamodule_class=OccupancyDataModule, subclass_mode_data=True)
+
+
+if __name__ == '__main__':
+ # for testing
+ # sys.argv = ['poco.py', 'fit',
+ # '-c', 'configs/poco.yaml',
+ # # '--print_config'
+ # ]
+
+ # Run PPS, run!
+ cli_main()
diff --git a/ppsurf/pps.py b/ppsurf/pps.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e21f9b1105c82008e7d2f332a594ae5a8b9bdad
--- /dev/null
+++ b/ppsurf/pps.py
@@ -0,0 +1,127 @@
+import sys
+import os
+import typing
+
+from pytorch_lightning import cli
+
+from source.poco_model import PocoModel
+from source.occupancy_data_module import OccupancyDataModule
+
+from poco import PocoCLI
+
+# run with:
+# python pps.py fit
+# python pps.py validate
+# python pps.py test
+# python pps.py predict
+# configs as below
+
+
+class PPSCLI(PocoCLI):
+
+ def add_arguments_to_parser(self, parser: cli.LightningArgumentParser) -> None:
+ super().add_arguments_to_parser(parser)
+
+ parser.link_arguments('model.init_args.num_pts_local', 'data.init_args.num_pts_local')
+
+ def handle_rec_subcommand(self, args: typing.List[str]) -> typing.List[str]:
+ """Replace 'rec' subcommand with predict and its default parameters.
+ Download model if necessary.
+ """
+
+ # no rec -> nothing to do
+ if len(args) <= 1 or args[1] != 'rec':
+ return args
+
+ # check syntax
+ if len(args) < 4 or args[0] != os.path.basename(__file__):
+ raise ValueError(
+ 'Invalid syntax for rec subcommand: {}\n'
+ 'Make sure that it matches this example: '
+ 'pps.py rec in_file.ply out_file.ply --model.init_args.rec_batch_size 50000'.format(' '.join(sys.argv)))
+
+ in_file = args[2]
+ if not os.path.exists(in_file):
+ raise ValueError('Input file does not exist: {}'.format(in_file))
+ out_dir = args[3]
+ os.makedirs(out_dir, exist_ok=True)
+ extra_params = args[4:]
+ model_path = os.path.join('models/ppsurf_50nn/version_0/checkpoints/last.ckpt')
+
+ # assemble predict subcommand
+ args_pred = args[:1]
+ args_pred += [
+ 'predict',
+ '-c', 'configs/poco.yaml',
+ '-c', 'configs/ppsurf.yaml',
+ '-c', 'configs/ppsurf_50nn.yaml',
+ '--ckpt_path', model_path,
+ '--data.init_args.in_file', in_file,
+ '--model.init_args.results_dir', out_dir,
+ '--trainer.logger', 'False',
+ '--trainer.devices', '1'
+ ]
+ args_pred += extra_params
+ print('Converted rec subcommand to predict subcommand: {}'.format(' '.join(args_pred)))
+
+ # download model if necessary
+ if not os.path.exists(model_path):
+ print('Model checkpoint not found at {}. Downloading...'.format(model_path))
+ os.system('python models/download_ppsurf_50nn.py')
+
+ return args_pred
+
+
+def cli_main():
+ PPSCLI(model_class=PocoModel, subclass_mode_model=True,
+ datamodule_class=OccupancyDataModule, subclass_mode_data=True)
+
+
+def fixed_cmd():
+ # for debugging
+
+ # train
+ sys.argv = ['pps.py',
+ 'fit',
+ '-c', 'configs/poco.yaml',
+ '-c', 'configs/ppsurf.yaml',
+ '-c', 'configs/ppsurf_mini.yaml',
+ # '--debug', 'True',
+ # '--print_config'
+ ]
+ cli_main()
+
+ # test
+ sys.argv = ['pps.py',
+ 'test',
+ '-c', 'configs/poco.yaml',
+ '-c', 'configs/ppsurf.yaml',
+ '-c', 'configs/ppsurf_mini.yaml',
+ '--ckpt_path', 'models/ppsurf_mini/version_0/checkpoints/last.ckpt', '--trainer.logger', 'False',
+ # '--print_config'
+ ]
+ cli_main()
+
+ # predict
+ sys.argv = ['pps.py',
+ 'predict',
+ '-c', 'configs/poco.yaml',
+ '-c', 'configs/ppsurf.yaml',
+ '-c', 'configs/ppsurf_mini.yaml',
+ '--ckpt_path', 'models/ppsurf_mini/version_0/checkpoints/last.ckpt', '--trainer.logger', 'False',
+ # '--print_config'
+ ]
+ cli_main()
+
+ # rec
+ sys.argv = ['pps.py',
+ 'rec',
+ 'datasets/abc_minimal/04_pts_vis/00011084_fddd53ce45f640f3ab922328_trimesh_019.xyz.ply',
+ 'results/rec/test/00011084_fddd53ce45f640f3ab922328_trimesh_019.ply',
+ ]
+ cli_main()
+
+
+if __name__ == '__main__':
+ # fixed_cmd()
+ cli_main()
diff --git a/ppsurf/pps.yml b/ppsurf/pps.yml
new file mode 100644
index 0000000000000000000000000000000000000000..6d139b430c06e6c6150468482bd6affadc50f35e
--- /dev/null
+++ b/ppsurf/pps.yml
@@ -0,0 +1,35 @@
+name: pps
+channels:
+ - pytorch
+ - nvidia
+ - anaconda
+ - conda-forge
+ - pyg
+ - defaults
+dependencies:
+ - python>=3.11
+ - pytorch=2
+ - pytorch-cuda
+ - pytorch-lightning=2
+ - pyg=2
+ - pytorch-cluster=1
+ - scikit-learn-intelex
+ - numpy=1
+ - scikit-image=0
+ - scipy=1
+ - pandas=1
+ - openpyxl=
+ - overrides=7
+ - pykdtree=1
+ - laspy=2
+ - pip>=22
+ - pillow
+ - intel-openmp
+ - tqdm
+ - pyglet=1
+ - rtree=1
+ - pip:
+ - tensorboard>=2.14
+ - trimesh>=3.23
+ - pysdf>=0.1
+ - jsonargparse[signatures]>=4.7
diff --git a/ppsurf/pps_win.yml b/ppsurf/pps_win.yml
new file mode 100644
index 0000000000000000000000000000000000000000..41974c6a265441d96f8e4148c17ce996e8aef64c
--- /dev/null
+++ b/ppsurf/pps_win.yml
@@ -0,0 +1,35 @@
+name: pps
+channels:
+ - pytorch
+ - nvidia
+ - anaconda
+ - conda-forge
+ - pyg
+ - defaults
+dependencies:
+ - python>=3.10
+ - pytorch>=1.12
+ - pytorch-cuda=11
+ - pytorch-lightning=2
+ - pyg=2
+ - pytorch-cluster>=1.6.0 # no newer version, requires pytorch 1 and cuda 11
+ - scikit-learn-intelex
+ - numpy>=1.23.5
+ - scikit-image>=0.19.3
+ - scipy>=1.10.0
+ - pandas>=1.5.2
+ - openpyxl>=3.0.10
+ - pillow
+ - intel-openmp
+ - overrides>=7.4
+ - pykdtree>=1.3.9
+ - laspy>=2.5.1
+ - pip>=22.3.1
+ - tqdm
+ - pyglet=1
+ - rtree=1
+ - pip:
+ - tensorboard>=2.14.0
+ - trimesh>=3.23.5
+ - pysdf>=0.1.8
+ - jsonargparse[signatures]>=4.7.3
diff --git a/ppsurf/requirements.txt b/ppsurf/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..40b31cdf03b933f9630f68ff4d94446252421994
--- /dev/null
+++ b/ppsurf/requirements.txt
@@ -0,0 +1,23 @@
+torch>=2
+--index-url https://download.pytorch.org/whl/cu117
+pytorch-lightning>=2.0
+pyg>=2.3
+torch-cluster>=1.6.0
+scikit-learn-intelex
+jsonargparse>=4.7.3
+numpy>=1
+python>=3.10
+scikit-image>=0.19
+scipy>=1.10
+pandas>=1.5
+openpyxl>=3.0
+pillow
+intel-openmp
+overrides>=7.4
+pykdtree>=1.3
+tensorboard>=2.14
+trimesh>=3.23
+tqdm>=4
+pysdf>=0.1
+jsonargparse[signatures]>=4.7
+laspy>=2.5
\ No newline at end of file
diff --git a/ppsurf/source/__init__.py b/ppsurf/source/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ppsurf/source/base/__init__.py b/ppsurf/source/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ppsurf/source/base/container.py b/ppsurf/source/base/container.py
new file mode 100644
index 0000000000000000000000000000000000000000..108271f2e8f008cbce9e71bab3034332066a3d8c
--- /dev/null
+++ b/ppsurf/source/base/container.py
@@ -0,0 +1,151 @@
+import typing
+
+import numpy as np
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ import torch
+
+
+def flatten_dicts(dicts: typing.Sequence[typing.Dict[typing.Any, typing.Any]]):
+ """
+ flatten dicts containing other dicts
+ :param dicts:
+ :return:
+ """
+ if len(dicts) == 0:
+ return dict()
+ elif len(dicts) == 1:
+ return dicts[0]
+
+ new_dicts = []
+ for d in dicts:
+ new_dict = {}
+ for k in d.keys():
+ value = d[k]
+ if isinstance(value, typing.Dict):
+ new_dict.update(value)
+ else:
+ new_dict[k] = value
+ new_dicts.append(new_dict)
+
+ return new_dicts
+
+
+def aggregate_dicts_np(dicts: typing.Sequence[typing.Mapping[typing.Any, np.ndarray]], method: str):
+ """
+
+ :param dicts:
+ :param method: one of ['mean', 'concat', 'stack']
+ :return:
+ """
+ import numbers
+
+ if len(dicts) == 0:
+ return dict()
+ elif len(dicts) == 1 and method != 'stack': # add singleton dimension for stack
+ return dicts[0]
+
+ valid_methods = ['mean', 'concat', 'stack']
+ if method not in valid_methods:
+ raise ValueError('Invalid method {} must be one of {}'.format(method, valid_methods))
+
+ dict_aggregated = dict()
+ for k in dicts[0].keys():
+ values = [d[k] for d in dicts]
+ if isinstance(values[0], numbers.Number):
+ values = [np.asarray(v) for v in values]
+
+ if method == 'concat':
+ values_np = np.concatenate(values)
+ elif method == 'stack':
+ values_np = np.stack(values)
+ elif method == 'mean':
+ values_np = np.array(values)
+ if values_np.dtype.type is np.str_:
+ values_np = values_np[0]
+ else:
+ values_np = np.nanmean(values_np)
+ else:
+ raise ValueError()
+ dict_aggregated[k] = values_np
+ return dict_aggregated
+
+
+def aggregate_dicts(dicts: typing.Sequence[typing.Mapping[typing.Any, 'torch.Tensor']], method: str):
+ """
+
+ :param dicts:
+ :param method: one of ['mean', 'concat', 'stack']
+ :return:
+ """
+ import torch
+ import numbers
+
+ if len(dicts) == 0:
+ return dict()
+ elif len(dicts) == 1:
+ return dicts[0]
+
+ valid_methods = ['mean', 'concat', 'stack']
+ if method not in valid_methods:
+ raise ValueError('Invalid method {} must be one of {}'.format(method, valid_methods))
+
+ dict_aggregated = dict()
+ for k in dicts[0].keys():
+ values = [d[k] for d in dicts]
+ if isinstance(values[0], numbers.Number):
+ values = [torch.as_tensor(v) for v in values]
+
+ if method == 'concat':
+ values = torch.cat(values)
+ elif method == 'stack':
+ if isinstance(values[0], str):
+ pass # keep list of strings
+ else:
+ values = torch.stack(values)
+ elif method == 'mean':
+ values = torch.tensor(values)
+ if values.dtype == str:
+ values = values[0]
+ else:
+ values = torch.nanmean(values).item()
+ else:
+ raise ValueError()
+ dict_aggregated[k] = values
+ return dict_aggregated
+
+
+def tensor_list_to_array(tensors: typing.Sequence['torch.Tensor']) -> np.ndarray:
+ """
+ use this if the tensors can be on different devices
+ :param tensors:
+ :return: array
+ """
+ import torch
+ tensors_cpu = [t.detach().cpu() for t in tensors]
+ arr = torch.concat(tensors_cpu).numpy()
+ return arr
+
+
+def dict_np_to_torch(patch_data: dict):
+ # convert values to tensors if necessary
+ from torch import from_numpy, Tensor, tensor
+ import numbers
+
+ for key in patch_data.keys():
+ val = patch_data[key]
+ if isinstance(val, np.ndarray):
+ patch_data[key] = from_numpy(val.copy())
+ elif isinstance(val, Tensor):
+ pass # nothing to do
+ elif np.isscalar(val):
+ if isinstance(val, numbers.Number):
+ patch_data[key] = tensor(val)
+ elif isinstance(val, str):
+ patch_data[key] = val
+ else: # try to keep type and let pytorch handle it
+ patch_data[key] = val
+ else:
+ raise NotImplementedError('Key: {}, Type:{}'.format(key, type(val)))
+ return patch_data
diff --git a/ppsurf/source/base/evaluation.py b/ppsurf/source/base/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0252a71dab7be9a3994574432488a86ac58cb54
--- /dev/null
+++ b/ppsurf/source/base/evaluation.py
@@ -0,0 +1,604 @@
+import math
+
+import numpy as np
+import os
+import typing
+
+from typing import TYPE_CHECKING
+
+from source.base.metrics import get_metric_meshes
+
+if TYPE_CHECKING:
+ import pandas as pd
+
+
+def make_excel_file_comparison(cd_pred_list, human_readable_results, output_file, val_set,
+ low_metrics_better: typing.Union[bool, typing.Sequence] = True):
+ import pandas as pd
+ # try https://realpython.com/openpyxl-excel-spreadsheets-python/
+
+ # one shape per line, dataset per column
+ cd_pred = np.array(cd_pred_list).transpose()
+ data_headers_human_readable = ['Shape'] + [hr for hr in human_readable_results]
+ # data_headers = [''] + [rft for rft in result_file_templates]
+ data_body = [[val_set[i]] + cd_pred[i].tolist() for i in range(len(val_set))]
+ df = pd.DataFrame(data=data_body, columns=data_headers_human_readable)
+ df = df.set_index('Shape')
+
+ export_xlsx(df=df, low_metrics_better=low_metrics_better, output_file=output_file, add_stats=True,
+ header=True, independent_cols=True)
+
+
+def make_quantitative_comparison(
+ shape_names: typing.Sequence[str], gt_mesh_files: typing.Sequence[str],
+ result_headers: typing.Sequence[str], result_file_templates: typing.Sequence[str],
+ comp_output_dir: str, num_samples=10000, num_processes=0):
+
+ cd_pred_list = get_metric_meshes(
+ result_file_template=result_file_templates, shape_list=shape_names, gt_mesh_files=gt_mesh_files,
+ num_samples=num_samples, metric='chamfer', num_processes=num_processes)
+ cd_output_file = os.path.join(comp_output_dir, 'chamfer_distance.xlsx')
+ make_excel_file_comparison(cd_pred_list, result_headers, cd_output_file, shape_names, low_metrics_better=[True])
+
+ f1_pred_list = get_metric_meshes(
+ result_file_template=result_file_templates, shape_list=shape_names, gt_mesh_files=gt_mesh_files,
+ num_samples=num_samples, metric='f1', num_processes=num_processes)
+ cd_output_file = os.path.join(comp_output_dir, 'f1.xlsx')
+ make_excel_file_comparison(f1_pred_list, result_headers, cd_output_file, shape_names, low_metrics_better=[False])
+
+ iou_pred_list = get_metric_meshes(
+ result_file_template=result_file_templates, shape_list=shape_names,
+ gt_mesh_files=gt_mesh_files, num_samples=num_samples, metric='iou', num_processes=num_processes)
+ iou_output_file = os.path.join(comp_output_dir, 'iou.xlsx')
+ make_excel_file_comparison(iou_pred_list, result_headers, iou_output_file, shape_names, low_metrics_better=[False])
+
+ nc_pred_list = get_metric_meshes(
+ result_file_template=result_file_templates, shape_list=shape_names,
+ gt_mesh_files=gt_mesh_files, num_samples=num_samples, metric='normals', num_processes=num_processes)
+ nc_output_file = os.path.join(comp_output_dir, 'normal_error.xlsx')
+ make_excel_file_comparison(nc_pred_list, result_headers, nc_output_file, shape_names, low_metrics_better=[True])
+
+
+def make_html_report(report_file_out, comp_name, pc_renders, gt_renders,
+ cd_vis_renders, dist_cut_off, metrics_cd, metrics_iou, metrics_nc):
+
+ num_rows = len(gt_renders)
+ num_recs = len(metrics_cd)
+ num_cols = len(metrics_cd) + 3
+
+ def clean_path_list(path_list: typing.Sequence[str]):
+ return [p.replace('\\', '/') for p in path_list]
+
+ gt_renders = clean_path_list(gt_renders)
+ pc_renders = clean_path_list(pc_renders)
+ cd_vis_renders = [clean_path_list(rec_shapes) for rec_shapes in cd_vis_renders]
+
+ def make_relative(path_list: typing.Sequence[str]):
+ import pathlib
+ return [pathlib.Path(*pathlib.Path(p).parts[3:]) for p in path_list]
+
+ gt_renders = make_relative(gt_renders)
+ pc_renders = make_relative(pc_renders)
+ cd_vis_renders = [make_relative(rec_shapes) for rec_shapes in cd_vis_renders]
+
+ def make_human_readable(path_list):
+ return [str(p).replace('_', ' ') for p in path_list]
+
+ shape_names_hr = make_human_readable([os.path.basename(p) for p in gt_renders])
+ rec_names = make_human_readable([os.path.split(os.path.dirname(os.path.dirname(cd_vis_renders[ir][0])))[1]
+ for ir in range(num_recs)])
+ gt_renders_hr = make_human_readable(gt_renders)
+ pc_renders_hr = make_human_readable(pc_renders)
+ cd_vis_renders_hr = [make_human_readable(rec_shapes) for rec_shapes in cd_vis_renders]
+
+ # template draft by Chat-GPT
+ html_template = r'''
+
+
+
+ PPSurf Comparison Results
+
+
+
+
+ Dataset: {title}
+
+
+
+{th_rec}
+
+
+
+{tr}
+
+
+
+
+ '''
+
+ table_header_rec_template = """
+ {rec_name} | """
+ table_row_template = '''
+
+ {file_name} |
+  |
+  |
+{recs}
+
'''
+ table_row_rec_template = '  {metrics} | '
+ table_row_rec_metrics_template = 'CD: {cd:.2f}, IoU: {iou:.2f}, NCE: {nc:.2f}'
+
+ img_size = 300
+ table_row_rec_metrics = [[table_row_rec_metrics_template.format(
+ cd=metrics_cd[ir][i] * 100.0, iou=metrics_iou[ir][i], nc=metrics_nc[ir][i])
+ for i in range(num_rows)]
+ for ir in range(num_recs)]
+ table_row_rec = [[table_row_rec_template.format(
+ rec_file=cd_vis_renders[ir][i], rec_file_hr=cd_vis_renders_hr[ir][i],
+ metrics=table_row_rec_metrics[ir][i], img_size=img_size)
+ for ir in range(num_recs)]
+ for i in range(num_rows)]
+
+ table_rows = [table_row_template.format(
+ file_name=shape_names_hr[i],
+ pc_file=pc_renders[i], pc_file_hr=pc_renders_hr[i],
+ gt_file=gt_renders[i], gt_file_hr=gt_renders_hr[i],
+ recs='\n'.join(table_row_rec[i]), img_size=img_size)
+ for i in range(num_rows)]
+
+ th_width = int(math.floor(100 / num_cols))
+ th_names = ['Shape Name', 'Point Cloud', 'GT Object'] + rec_names
+ th_sticky = [' class="sticky"'] * 3 + [''] * len(rec_names)
+ table_header_rec = ''.join([table_header_rec_template.format(th_sticky=th_sticky[ni], rec_name=n)
+ for ni, n in enumerate(th_names)])
+
+ html_text = html_template.format(
+ th_width=th_width,
+ num_rec=num_recs, title=comp_name,
+ th_rec=table_header_rec, tr=''.join(table_rows))
+
+ with open(report_file_out, 'w') as text_file:
+ text_file.write(html_text)
+
+
+def make_test_report(shape_names: list, results: typing.Union[list, dict],
+ output_file: str, output_names: list, is_dict=True):
+ import pandas as pd
+ from torch import stack
+
+ metrics_keys_to_log = ['abs_dist_rms', 'accuracy', 'precision', 'recall', 'f1_score']
+ headers = ['Shape', 'Loss total'] + output_names + metrics_keys_to_log
+ low_metrics_better = [True] * (1 + len(output_names)) + [True, False, False, False, False]
+
+ if not is_dict:
+ loss_total = [r[0] for r in results]
+ loss_components = [r[1] for r in results]
+ metrics_dicts = [r[2] for r in results]
+ metrics_lists = []
+ for m in metrics_keys_to_log:
+ metrics_list = [md[m] for md in metrics_dicts]
+ metrics_lists.append(metrics_list)
+ metrics = np.array(metrics_lists).transpose()
+ else:
+ loss_total = results['loss'].detach().cpu()
+ loss_components = results['loss_components_mean'].detach().cpu()
+ if len(loss_components.shape) == 1:
+ loss_components = loss_components.unsqueeze(1)
+ metrics = stack([results[k] for k in metrics_keys_to_log]).transpose(0, 1).detach().cpu()
+
+ if len(loss_total.shape) == 2: # DP -> squeeze
+ loss_total = loss_total.squeeze(-1)
+ metrics = metrics.squeeze(-1)
+
+ data = [[shape_names[i]] + [loss_total[i].item()] + loss_components[i].tolist() + metrics[i].tolist()
+ for i in range(len(loss_total))]
+ df = pd.DataFrame(data=data, columns=headers)
+ df = df.set_index('Shape')
+
+ export_xlsx(df=df, low_metrics_better=low_metrics_better, output_file=output_file,
+ add_stats=True, header=True, independent_cols=True)
+
+ loss_total_mean = np.mean(np.array(loss_total))
+ abs_dist_rms_mean = np.nanmean(metrics[:, 0])
+ f1_mean = np.nanmean(metrics[:, -1])
+ return loss_total_mean, abs_dist_rms_mean, f1_mean
+
+
+def export_xlsx(df: 'pd.DataFrame', low_metrics_better: typing.Union[None, typing.Sequence[bool], bool],
+ output_file: str, add_stats=True, header=True, independent_cols=True):
+ import datetime
+ from source.base import fs
+
+ # export with conditional formatting and average
+ from openpyxl import Workbook
+ from openpyxl.utils.dataframe import dataframe_to_rows
+ from openpyxl.utils.cell import get_column_letter
+ from openpyxl.formatting.rule import ColorScaleRule
+ wb = Workbook()
+ ws = wb.active
+
+ df_export = df.copy()
+ df_export.reset_index(inplace=True) # revert index to normal column to get rid of extra header row
+ for r in dataframe_to_rows(df_export, index=False, header=header):
+ ws.append(r)
+
+ # no direction given, assume near 0 or near 1 results
+ if low_metrics_better is None:
+ cols = df.to_numpy()
+ cols = np.vectorize(lambda x: x.timestamp() if isinstance(x, datetime.datetime) else x)(cols)
+ cols_mean = np.nanmean(cols, axis=0)
+ if not independent_cols:
+ cols_mean = np.mean(cols_mean) # scalar for dependent cols
+ low_metrics_better = np.logical_or(cols_mean > 1.0, cols_mean < 0.5)
+
+ top_row = 2
+ col_ids = df.index.shape[1] if len(df.index.shape) > 1 else 1
+ ws.freeze_panes = '{}{}'.format(get_column_letter(col_ids + 1), top_row)
+ bottom_row = df.shape[0] + top_row - 1
+ if add_stats:
+ for di in range(df.shape[1]):
+ column = col_ids + 1 + di
+ column_letter = get_column_letter(column)
+ ws.cell(row=bottom_row + 1, column=column).value = '=AVERAGE({}{}:{}{})'.format(
+ column_letter, top_row, column_letter, bottom_row)
+ ws.cell(row=bottom_row + 2, column=column).value = '=MEDIAN({}{}:{}{})'.format(
+ column_letter, top_row, column_letter, bottom_row)
+ # ws.cell(row=bottom_row + 3, column=column).value = '=STDEV.P({}{}:{}{})'.format( # strange '@' appears
+ ws.cell(row=bottom_row + 3, column=column).value = '=STDEV({}{}:{}{})'.format( # rely on compatibility
+ column_letter, top_row, column_letter, bottom_row)
+
+ # Stat names
+ ws.cell(row=bottom_row + 1, column=1).value = 'AVERAGE'
+ ws.cell(row=bottom_row + 2, column=1).value = 'MEDIAN'
+ ws.cell(row=bottom_row + 3, column=1).value = 'STDEV'
+
+ def add_formatting_rule(col_start_id, row_start_id, col_end_id, row_end_id, lower_is_better):
+ col_start_letter = get_column_letter(col_start_id)
+ col_end_letter = get_column_letter(col_end_id)
+ col_range_str = '{col_start}{row_start}:{col_end}{row_end}'.format(
+ col_start=col_start_letter, row_start=row_start_id, col_end=col_end_letter, row_end=row_end_id)
+ if lower_is_better: # error here means that this list has an invalid length
+ start_color = 'FF00AA00'
+ end_color = 'FFAA0000'
+ else:
+ end_color = 'FF00AA00'
+ start_color = 'FFAA0000'
+ rule = ColorScaleRule(start_type='percentile', start_value=0, start_color=start_color,
+ mid_type='percentile', mid_value=50, mid_color='FFFFFFFF',
+ end_type='percentile', end_value=100, end_color=end_color)
+ ws.conditional_formatting.add(col_range_str, rule)
+
+ # highlight optimum
+ from openpyxl.formatting.rule import FormulaRule
+ from openpyxl.styles import Font
+
+ asc_desc = 'MIN' if lower_is_better else 'MAX'
+ # should be like =H2=MIN(H$2:H$11)
+ formula = '={col_start}{row_start}={func}({col_start}${row_start}:{col_end}${row_end})'.format(
+ col_start=col_start_letter, row_start=row_start_id, func=asc_desc,
+ col_end=col_end_letter, row_end=row_end_id)
+ rule = FormulaRule(formula=(formula,), font=Font(underline='single'))
+ ws.conditional_formatting.add(col_range_str, rule)
+
+ # color scale over shapes
+ if independent_cols:
+ bottom_row_formatting = bottom_row + (2 if add_stats else 0) # not for STDEV
+ for col in range(df.shape[1]):
+ if not np.isnan(low_metrics_better[col]):
+ add_formatting_rule(col_start_id=col+col_ids+1, row_start_id=top_row,
+ col_end_id=col+col_ids+1, row_end_id=bottom_row_formatting,
+ lower_is_better=low_metrics_better[col])
+ else: # dependent cols
+ for shape_id in range(df.shape[0]):
+ row = top_row + shape_id
+ add_formatting_rule(col_start_id=col_ids+1, row_start_id=row,
+ col_end_id=df.shape[1]+col_ids+1, row_end_id=row,
+ lower_is_better=low_metrics_better)
+
+ # color scale over stats (horizontal)
+ lower_better = [low_metrics_better] * 2 + [True] # lower stdev is always better, mean and avg depend on metric
+ for stat_id in range(3):
+ row = bottom_row + 1 + stat_id
+ add_formatting_rule(col_start_id=col_ids+1, row_start_id=row,
+ col_end_id=df.shape[1] + col_ids+1, row_end_id=row,
+ lower_is_better=lower_better[stat_id])
+
+ fs.make_dir_for_file(output_file)
+ wb.save(output_file)
+
+
+def _drop_stats_rows(df: 'pd.DataFrame',
+ stats: typing.Sequence[str] = ('AVG', 'AVERAGE', 'MEAN', 'MEDIAN', 'STDEV.P', 'STDEV'))\
+ -> 'pd.DataFrame':
+ df = df.copy()
+ for stat in stats:
+ df = df.drop(stat, errors='ignore')
+ return df
+
+
+def make_dataset_comparison(results_reports: typing.Sequence[typing.Sequence[str]], output_file: str):
+ import time
+ import pandas as pd
+
+ def _get_header_and_mean(report_file: typing.Union[str, typing.Any]):
+ metrics_type = os.path.basename(report_file)
+ metrics_type = os.path.splitext(metrics_type)[0]
+
+ if not os.path.isfile(report_file):
+ method_name = os.path.basename(os.path.split(os.path.split(report_file)[0])[0])
+ headers = ['Model', 'Mean {}'.format(metrics_type),
+ 'Median {}'.format(metrics_type), 'Stdev {}'.format(metrics_type), ]
+ data = [method_name, np.nan, np.nan, np.nan, ]
+
+ df_missing = pd.DataFrame(data=[data], columns=headers)
+ df_missing = df_missing.set_index('Model')
+ return df_missing
+
+ df = pd.read_excel(io=report_file, header=0, index_col=0)
+ df = _drop_stats_rows(df)
+
+ if len(df.columns) == 1: # CD, IoU, NC -> single columns in multiple files
+ df_name = df.columns[0]
+ df_mean = df.mean(axis=0)[0]
+ df_median = df.median(axis=0)[0]
+ df_stdev = df.std(axis=0)[0]
+ headers = ['Model', 'Mean {}'.format(metrics_type),
+ 'Median {}'.format(metrics_type), 'Stdev {}'.format(metrics_type), ]
+ data = [df_name, df_mean, df_median, df_stdev, ]
+
+ df_means = pd.DataFrame(data=[data], columns=headers)
+ df_means = df_means.set_index('Model')
+ else: # Test, only one file with multiple columns
+ series_means = df.mean(axis=0)
+ df_means: pd.DataFrame = series_means.to_frame().transpose()
+ model_name = os.path.basename(report_file).split('metrics_')[1]
+ model_name = os.path.splitext(model_name)[0]
+ df_means.insert(0, 'Model', [model_name])
+ df_means = df_means.set_index('Model')
+ test_time = np.datetime64(time.strftime('%Y-%m-%dT%H:%M:%S', time.gmtime(os.path.getmtime(report_file))))
+ df_means.insert(0, 'Date', [test_time])
+ df_means.insert(0, 'Count', [float(df.shape[0])]) # float to keep the array dtype
+ return df_means
+
+ def assemble_model_data(reports_model: typing.Sequence[str]):
+ df_model_list = [_get_header_and_mean(f) for f in reports_model]
+ df_model = pd.concat(df_model_list, axis=1)
+ return df_model
+
+ df_mean_list = [assemble_model_data(l) for l in results_reports]
+ df_mean_all = pd.concat(df_mean_list, axis=0)
+
+ # df_mean_all.insert(0, 'Model', [str(f) for f in results_reports])
+ # df_mean_all = df_mean_all.set_index('Model')
+
+ # df_mean_all = df_mean_all.sort_values('f1_score', ascending=False)
+ df_mean_all = df_mean_all.sort_values('Mean chamfer_distance', ascending=False)
+ export_xlsx(df=df_mean_all, low_metrics_better=None, output_file=output_file, add_stats=False,
+ header=True, independent_cols=True)
+
+
+def assemble_quantitative_comparison(
+ comp_output_dir: str,
+ report_path_templates=('results/poco_blensor_prec32_again/{}.xlsx',),
+ metrics=('chamfer_distance', 'iou', 'normal_error', 'f1'),
+ metrics_lower_better=(True, False, True, False)):
+ import pandas as pd
+
+ def assemble_report(report_paths: typing.Sequence[str]):
+ reports_list = []
+ for p in report_paths:
+ if not os.path.isfile(p):
+ print('Missing report: {}'.format(p))
+
+ model_name = os.path.split(os.path.split(os.path.dirname(p))[0])[1]
+ headers = ['Shape', model_name]
+ df_report = pd.DataFrame(columns=headers)
+ df_report = df_report.set_index('Shape')
+ reports_list.append(df_report)
+ else:
+ df_report: pd.DataFrame = pd.read_excel(io=p, header=0, index_col=0)
+ reports_list.append(df_report)
+
+ df = pd.concat(reports_list, axis=1)
+ df = _drop_stats_rows(df)
+ return df
+
+ results_per_shape_dict = {}
+ for mi, m in enumerate(metrics):
+ report_paths_metric = [t.format(m) for t in report_path_templates]
+ df_m = assemble_report(report_paths_metric)
+ results_per_shape_dict[m] = df_m.to_numpy()
+
+ report_file = os.path.join(comp_output_dir, '{}.xlsx'.format(m))
+ export_xlsx(df=df_m, low_metrics_better=metrics_lower_better[mi], output_file=report_file, add_stats=True,
+ header=True, independent_cols=False)
+
+ return results_per_shape_dict
+
+
+def _prettify_df(df: 'pd.DataFrame'):
+ import re
+
+ def _replace_case_insensitive(text, word, replacement):
+ regex_ignore_case = '(?i){}'.format(word)
+ return re.sub(regex_ignore_case, replacement, text)
+
+ # keep index title in row with CD, IoU, NC
+ df.reset_index(inplace=True)
+
+ find_replace = (
+ ('abc', 'ABC'),
+ ('famous', 'Famous'),
+ ('thingi10k_scans', 'Thingi10k'),
+ ('chamfer_distance', 'Chamfer Distance (x100)'),
+ ('iou', 'IoU'),
+ ('normal_error', 'Normal Error'),
+ ('normal_consistency', 'Normal Error'),
+ ('sap', 'SAP'),
+ ('p2s', 'P2S'),
+ ('poco', 'POCO'),
+ ('pts_gen_sub3k_iter10', ''),
+ ('ppsurf', 'PPSurf'),
+ ('_vanilla_zeros_global', ' only local'),
+ ('_vanilla_zeros_local', ' only global'),
+ ('_vanilla_sym_max', ' sym max'),
+ ('_vanilla_qpoints', ' qpoints'),
+ ('_vanilla', ' merge cat'),
+ ('_merge_sum', ' merge sum'),
+ ('optim', 'O'),
+ ('mean ', ''),
+ ('_', ' '),
+ )
+ for fr in find_replace:
+ # replace in values
+ regex_ignore_case_values = '(?i){}'.format(fr[0])
+ df.replace(to_replace=regex_ignore_case_values, value=fr[1], inplace=True, regex=True)
+
+ # rename in multi
+ for multi_col in df.columns:
+ col: str # type annotation not allowed in for
+ for col in multi_col:
+ if col.lower().find(fr[0].lower()) >= 0:
+ df.rename(columns={col: _replace_case_insensitive(col, fr[0], fr[1])}, inplace=True)
+
+ # factor 100 for chamfer distance
+ cols_chamfer = [c for c in df.columns if c[0].find('Chamfer Distance') >= 0]
+ for c in cols_chamfer:
+ df[c] = df[c] * 100
+
+ return df
+
+
+def xslx_to_latex(xlsx_file: str, latex_file: str):
+ import pandas as pd
+
+ df: pd.DataFrame = pd.read_excel(io=xlsx_file, header=0, index_col=0)
+
+ # nicer column names
+ columns = [
+ ('Chamfer Distance (x100)', 'Mean'),
+ ('Chamfer Distance (x100)', 'Median'),
+ ('Chamfer Distance (x100)', 'Stdev'),
+ ('IoU', 'Mean'),
+ ('IoU', 'Median'),
+ ('IoU', 'Stdev'),
+ ('F1', 'Mean'),
+ ('F1', 'Median'),
+ ('F1', 'Stdev'),
+ ('Normal Error', 'Mean'),
+ ('Normal Error', 'Median'),
+ ('Normal Error', 'Stdev'),
+ ]
+ df.columns = pd.MultiIndex.from_tuples(columns)
+
+ df = _prettify_df(df)
+ df.to_latex(buf=latex_file, float_format='%.2f', na_rep='-', index=False, bold_rows=True,
+ column_format='l' + 'c' * (df.shape[1] - 1), escape=False)
+
+ # strange stuff with styler. why can't I give this to the df for export?
+ # styler = df.style.highlight_max(axis=None, props='font-weight:bold;', subset=columns)
+ # styler.format('{:.2f}', na_rep='-', subset=columns)
+ # styler.to_latex(buf=latex_file, column_format='l' + 'c' * (df.shape[1] - 1))
+
+
+def merge_comps(comp_list: typing.Sequence[str], comp_merged_out_file: str,
+ comp_merged_out_latex: str, methods_order: typing.Optional[list], float_format='%.2f'):
+ import pandas as pd
+ comp_list_existing = [f for f in comp_list if os.path.isfile(f)]
+ if len(comp_list_existing) == 0:
+ print('WARNING: No metrics found for: {}'.format(comp_list))
+ return
+
+ dfs = [pd.read_excel(io=f, header=0, index_col=0) for f in comp_list_existing]
+ datasets = [os.path.split(os.path.dirname(f))[1] for f in comp_list_existing]
+ dfs_with_ds = [df.assign(dataset=datasets[dfi]) for dfi, df in enumerate(dfs)]
+ dfs_multiindex = [df.set_index(['dataset', df.index]).T for df in dfs_with_ds]
+
+ def _extract_metric(df_list: typing.Sequence[pd.DataFrame], order: typing.Optional[list], metric: str):
+ df_metric = [df.xs(metric, axis=0) for df in df_list]
+
+ # dataset name as index, metric as column
+ df_metric_id = [df.reset_index(level=0) for df in df_metric]
+ df_metric_for_col = [df.rename(columns={metric: datasets[dfi]}) for dfi, df in enumerate(df_metric_id)]
+ df_metric_for_col = [df.drop(columns=['dataset']).T for df in df_metric_for_col]
+
+ # xs removes the column name, so we need to add it again
+ # df_metric_ds_col = [df.rename(index=datasets[dfi]).T for dfi, df in enumerate(df_metric)]
+
+ df_metric_merged = pd.concat(df_metric_for_col, axis=0)
+
+ if order is not None and len(order) > 0:
+ df_metric_merged = df_metric_merged[order]
+
+ df_metric_merged_with_ds = df_metric_merged.T.assign(metric=metric)
+ df_metric_merged_id = df_metric_merged_with_ds.set_index(['metric', df_metric_merged_with_ds.index]).T
+
+ return df_metric_merged_id
+
+ df_cd = _extract_metric(df_list=dfs_multiindex, order=methods_order, metric='Mean chamfer_distance')
+ df_iou = _extract_metric(df_list=dfs_multiindex, order=methods_order, metric='Mean iou')
+ df_f1 = _extract_metric(df_list=dfs_multiindex, order=methods_order, metric='Mean f1')
+ df_nc = _extract_metric(df_list=dfs_multiindex, order=methods_order, metric='Mean normal_error')
+
+ df_merged: pd.DataFrame = pd.concat((df_cd, df_iou, df_f1, df_nc), axis=1)
+
+ # add mean row
+ df_mean_row = df_merged.mean(axis=0).rename('Mean')
+ df_merged = pd.concat((df_merged, pd.DataFrame(df_mean_row).T), axis=0)
+
+ df_merged = _prettify_df(df_merged)
+ df_merged.rename(columns={'index': 'Dataset'}, inplace=True)
+
+ from source.base.fs import make_dir_for_file
+ make_dir_for_file(comp_merged_out_file)
+ df_merged.to_excel(comp_merged_out_file, float_format=float_format)
+ make_dir_for_file(comp_merged_out_latex)
+ # TODO: to_latex is deprecated, use df.style.to_latex instead
+ df_merged.to_latex(buf=comp_merged_out_latex, float_format=float_format, na_rep='-', index=False, bold_rows=True,
+ column_format='l' + 'c' * (df_merged.shape[1] - 1), escape=False)
+
diff --git a/ppsurf/source/base/fs.py b/ppsurf/source/base/fs.py
new file mode 100644
index 0000000000000000000000000000000000000000..a22ad0e1c8f26195b4eecf1da328363a9b1dc164
--- /dev/null
+++ b/ppsurf/source/base/fs.py
@@ -0,0 +1,145 @@
+import typing
+import os
+
+
+def create_activate_env(env_name: str):
+ import subprocess
+
+ # check if conda is installed
+ def _check_conda_installed(command: str):
+ try:
+ conda_output = subprocess.check_output([command, '-V']).decode('utf-8').strip()
+ except Exception as _:
+ return False
+ conda_output_regex = r'conda (\d+\.\d+\.\d+)' # mamba and conda both output 'conda 23.3.1'
+ import re
+ conda_match = re.match(conda_output_regex, conda_output)
+ return conda_match is not None
+
+ if _check_conda_installed('mamba'):
+ conda_exe = 'mamba'
+ elif _check_conda_installed('conda'):
+ conda_exe = 'conda'
+ else:
+ raise ValueError('Conda not found')
+
+ # check if env already exists. example outputs:
+ # conda env list --json
+ # {
+ # "envs": [
+ # "C:\\miniforge",
+ # "C:\\miniforge\\envs\\pps"
+ # ]
+ # }
+ import json
+ env_list_str = subprocess.check_output([conda_exe, 'env', 'list', '--json']).decode('utf-8')
+ env_list_json = json.loads(env_list_str)
+ envs = env_list_json['envs']
+ envs_dirs = [os.path.split(env)[1] for env in envs]
+ first_run = env_name not in envs_dirs
+ if first_run:
+ import subprocess
+ on_windows = os.name == 'nt'
+ yml_file = '{}{}.yml'.format(env_name, '_win' if on_windows else '')
+ env_install_cmd = [conda_exe, 'env', 'create', '--file', yml_file]
+ print('Creating conda environment from {}\n{}'.format(yml_file, env_install_cmd))
+ subprocess.call(env_install_cmd)
+
+ # conda activate pps
+ subprocess.call([conda_exe, 'activate', env_name])
+
+ if first_run:
+ print('Downloading datasets')
+ subprocess.call(['python', 'source/datasets/download_abc_training.py'])
+ subprocess.call(['python', 'source/datasets/download_testsets.py'])
+
+
+def make_dir_for_file(file):
+ file_dir = os.path.dirname(file)
+ if file_dir != '':
+ if not os.path.exists(file_dir):
+ try:
+ os.makedirs(os.path.dirname(file), exist_ok=True)
+ except FileExistsError as exc:
+ pass
+ except OSError as exc: # Guard against race condition
+ raise
+
+
+def call_necessary(file_in: typing.Union[str, typing.Sequence[str]], file_out: typing.Union[str, typing.Sequence[str]],
+ min_file_size=0, verbose=False):
+ """
+ Check if all input files exist and at least one output file does not exist or is invalid.
+ :param file_in: list of str or str
+ :param file_out: list of str or str
+ :param min_file_size: int
+ :return:
+ """
+
+ def check_parameter_types(param):
+ if isinstance(param, str):
+ return [param]
+ elif isinstance(param, list):
+ return param
+ elif isinstance(param, tuple):
+ return param
+ else:
+ raise ValueError('Wrong input type')
+
+ file_in = check_parameter_types(file_in)
+ file_out = check_parameter_types(file_out)
+
+ def print_result(msg: str):
+ if verbose:
+ print('call_necessary\n {}\n ->\n {}: \n{}'.format(file_in, file_out, msg))
+
+ if len(file_out) == 0:
+ print_result('No output')
+ return True
+
+ inputs_missing = [f for f in file_in if not os.path.isfile(f)]
+ if len(inputs_missing) > 0:
+ print_result('WARNING: Input files are missing: {}'.format(inputs_missing))
+ return False
+
+ outputs_missing = [f for f in file_out if not os.path.isfile(f)]
+ if len(outputs_missing) > 0:
+ print_result('Some output files are missing: {}'.format(outputs_missing))
+ return True
+
+ min_output_file_size = min([os.path.getsize(f) for f in file_out])
+ if min_output_file_size < min_file_size:
+ print_result('Output too small')
+ return True
+
+ oldest_input_file_mtime = max([os.path.getmtime(f) for f in file_in])
+ youngest_output_file_mtime = min([os.path.getmtime(f) for f in file_out])
+ if oldest_input_file_mtime >= youngest_output_file_mtime:
+ if verbose:
+ import time
+ import numpy as np
+ input_file_mtime_arg_max = np.argmax(np.array([os.path.getmtime(f) for f in file_in]))
+ output_file_mtime_arg_min = np.argmin(np.array([os.path.getmtime(f) for f in file_out]))
+ input_file_mtime_max = time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(oldest_input_file_mtime))
+ output_file_mtime_min = time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(youngest_output_file_mtime))
+ print_result('Input file {} is newer than output file {}: {} >= {}'.format(
+ file_in[input_file_mtime_arg_max], file_out[output_file_mtime_arg_min],
+ input_file_mtime_max, output_file_mtime_min))
+ return True
+
+ return False
+
+
+def text_file_lf_to_crlf(file):
+ """
+ Convert line endings of a text file from LF to CRLF.
+ :param file:
+ :return:
+ """
+
+ with open(file, 'r') as fp:
+ lines = fp.readlines()
+
+ with open(file, 'w') as fp:
+ for line in lines:
+ fp.write(line.rstrip() + '\r\n')
diff --git a/ppsurf/source/base/math.py b/ppsurf/source/base/math.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ecc698d875ca8b2dcfd1f8a10927a4d157562cf
--- /dev/null
+++ b/ppsurf/source/base/math.py
@@ -0,0 +1,179 @@
+import typing
+
+import numpy as np
+import trimesh
+
+
+def cartesian_dist_1_n(vec_x: np.array, vec_y: np.array, axis=1) -> np.ndarray:
+ """
+ L2 distance
+ :param vec_x: array[d]
+ :param vec_y: array[n, d]
+ :param axis: int
+ :return: array[n]
+ """
+ vec_x_bc = np.tile(np.expand_dims(vec_x, 0), (vec_y.shape[0], 1))
+
+ dist = np.linalg.norm(vec_x_bc - vec_y, axis=axis)
+ return dist
+
+
+def cartesian_dist(vec_x: np.array, vec_y: np.array, axis=1) -> np.ndarray:
+ """
+ L2 distance
+ :param vec_x: array[n, d]
+ :param vec_y: array[n, d]
+ :param axis: int
+ :return: array[n]
+ """
+ dist = np.linalg.norm(vec_x - vec_y, axis=axis)
+ return dist
+
+
+def vector_length(vecs: np.array, axis=1) -> np.ndarray:
+ dist = np.linalg.norm(vecs, axis=axis)
+ return dist
+
+
+def normalize_vectors(vecs: np.array):
+ """
+ :param vecs: array[n, dims]
+ :return:
+ """
+ n_dims = vecs.shape[1]
+ vecs_normalized = vecs / np.repeat(vector_length(vecs)[:, np.newaxis], repeats=n_dims, axis=1)
+ return vecs_normalized
+
+
+def get_patch_radii(pts_patch: np.array, query_pts: np.array):
+ if pts_patch.shape[0] == 0:
+ patch_radius = 0.0
+ elif pts_patch.shape == query_pts.shape:
+ patch_radius = np.linalg.norm(pts_patch - query_pts, axis=0)
+ else:
+ dist = cartesian_dist(np.repeat(np.expand_dims(query_pts, axis=0), pts_patch.shape[0], axis=0),
+ pts_patch, axis=1)
+ patch_radius = np.max(dist, axis=-1)
+ return patch_radius
+
+
+def model_space_to_patch_space_single_point(
+ pts_to_convert_ms: np.array, pts_patch_center_ms: np.array, patch_radius_ms: typing.Union[float, np.ndarray]):
+
+ pts_patch_space = pts_to_convert_ms - pts_patch_center_ms
+ pts_patch_space = pts_patch_space / patch_radius_ms
+ return pts_patch_space
+
+
+def model_space_to_patch_space(
+ pts_to_convert_ms: np.array, pts_patch_center_ms: np.array, patch_radius_ms: typing.Union[float, np.ndarray]):
+
+ pts_patch_center_ms_repeated = \
+ np.repeat(np.expand_dims(pts_patch_center_ms, axis=0), pts_to_convert_ms.shape[-2], axis=-2)
+ pts_patch_space = pts_to_convert_ms - pts_patch_center_ms_repeated
+ pts_patch_space = pts_patch_space / patch_radius_ms
+
+ return pts_patch_space
+
+
+def lerp(
+ a: np.ndarray,
+ b: np.ndarray,
+ factor: typing.Union[np.ndarray, float]):
+ interpolated = a + factor * (b - a)
+ return interpolated
+
+
+def normalize_data(arr: np.ndarray, in_max: float, in_min: float, out_max=1.0, out_min=-1.0, clip=False):
+
+ arr = arr.copy()
+ in_range = in_max - in_min
+ out_range = out_max - out_min
+
+ if in_range == 0.0 or out_range == 0.0:
+ print('Warning: normalization would result in NaN, kept raw values')
+ return arr - in_max
+
+ # scale so that in_max=1.0 and in_min=0.0
+ arr -= in_min
+ arr /= in_range
+
+ # scale to out_max..out_min
+ arr *= out_range
+ arr += out_min
+
+ if clip:
+ arr = np.clip(arr, out_min, out_max)
+
+ return arr
+
+
+def get_points_normalization_info(pts: np.ndarray, padding_factor: float = 0.05):
+ pts_bb_min = np.min(pts, axis=0)
+ pts_bb_max = np.max(pts, axis=0)
+
+ bb_center = (pts_bb_min + pts_bb_max) * 0.5
+ scale = np.max(pts_bb_max - pts_bb_min) * (1.0 + padding_factor)
+ return bb_center, scale
+
+
+def normalize_points_with_info(pts: np.ndarray, bb_center: np.ndarray, scale: float):
+ pts_new = pts - np.tile(bb_center, reps=(pts.shape[0], 1))
+ pts_new /= scale
+
+ # pts_new = pts / scale
+ # pts_new -= np.tile(-bb_center, reps=(pts.shape[0], 1))
+ return pts_new
+
+
+def denormalize_points_with_info(pts: np.ndarray, bb_center: np.ndarray, scale: float):
+ pts_new = pts * scale
+ pts_new += np.tile(bb_center, reps=(pts.shape[0], 1))
+ return pts_new
+
+
+def rotate_points_around_pivot(pts: np.ndarray, rotation_mat: np.ndarray, pivot: np.ndarray):
+ """
+ rotate_points_around_pivot
+ :param pts: np.ndarray[n, dims=3]
+ :param rotation_mat: np.ndarray[4, 4]
+ :param pivot: np.ndarray[dims=3]
+ :return:
+ """
+ pivot_bc = np.broadcast_to(pivot[np.newaxis, :], pts.shape)
+
+ pts -= pivot_bc
+ pts = trimesh.transformations.transform_points(pts, rotation_mat)
+ pts += pivot_bc
+
+ return pts
+
+
+def _test_normalize():
+ ms = 0.75
+ vs = 1.0 / 32
+ # padding_factor = 0.0
+ padding_factor = 0.05
+ pts_ms = np.array([[-ms, -ms], [-ms, +ms], [+ms, -ms], [+ms, +ms], [0.0, 0.0],
+ [vs*0.3, -vs*0.3], [vs*0.5, -vs*0.5], [vs*0.6, -vs*0.6]])
+ pts_ms *= 76.0
+ pts_ms += 123.0
+
+ # vertices = np.random.random(size=(25, 2)) * 2.0 - 1.0
+ vertices = pts_ms
+
+ bb_center, scale = get_points_normalization_info(pts=pts_ms, padding_factor=padding_factor)
+ vertices_norm = normalize_points_with_info(pts=vertices, bb_center=bb_center, scale=scale)
+ vertices_denorm = denormalize_points_with_info(pts=vertices_norm, bb_center=bb_center, scale=scale)
+
+ if not np.allclose(vertices_denorm, vertices):
+ raise ValueError()
+
+ if vertices_norm.max() > 0.5 or vertices_norm.min() < -0.5:
+ raise ValueError()
+
+ return 0
+
+
+if __name__ == '__main__':
+ _test_normalize()
diff --git a/ppsurf/source/base/mesh.py b/ppsurf/source/base/mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..1407bf18d99c86f3886b40c61245a873f0912fa9
--- /dev/null
+++ b/ppsurf/source/base/mesh.py
@@ -0,0 +1,38 @@
+import typing
+
+import numpy as np
+import trimesh
+
+
+def clean_simple_inplace(mesh: trimesh.Trimesh):
+ # extra function because this overlaps only partially with mesh.process(validate=True)
+ mesh.remove_unreferenced_vertices()
+ mesh.remove_infinite_values()
+ mesh.merge_vertices()
+ mesh.remove_degenerate_faces()
+ mesh.remove_duplicate_faces()
+
+
+def remove_small_connected_components(mesh: trimesh.Trimesh, num_faces: typing.Optional[int] = 100):
+ from trimesh import graph
+
+ # https://github.com/Wuziyi616/IF-Defense/blob/main/ONet/data_proc/make_watertight.py
+
+ # if num_faces not given, take 1 % of faces
+ total_num_faces = len(mesh.faces)
+ if num_faces is None:
+ num_faces = total_num_faces // 100
+
+ cc = graph.connected_components(mesh.face_adjacency, min_len=3)
+ mask = np.zeros(total_num_faces, dtype=bool)
+ cc_large_enough = [c for c in cc if len(c) > num_faces]
+ if len(cc_large_enough) == 0:
+ cc_large_enough = np.empty()
+ cc = np.concatenate(cc_large_enough, axis=0)
+ mask[cc] = True
+ mesh.update_faces(mask)
+
+ # clean to keep only used data
+ clean_simple_inplace(mesh=mesh)
+
+ return mesh
diff --git a/ppsurf/source/base/metrics.py b/ppsurf/source/base/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4cf056b7a29ae352747f300ecc5e40db9404265
--- /dev/null
+++ b/ppsurf/source/base/metrics.py
@@ -0,0 +1,324 @@
+import os
+import typing
+
+import numpy as np
+
+from source.base.point_cloud import sample_mesh
+from source.base.proximity import kdtree_query_oneshot
+
+
+def calc_accuracy(num_true, num_predictions):
+ if num_predictions == 0:
+ return float('NaN')
+ else:
+ return num_true / num_predictions
+
+
+def calc_precision(num_true_pos, num_false_pos):
+ if isinstance(num_true_pos, (int, float)) and isinstance(num_false_pos, (int, float)) and \
+ num_true_pos + num_false_pos == 0:
+ return float('NaN')
+ else:
+ return num_true_pos / (num_true_pos + num_false_pos)
+
+
+def calc_recall(num_true_pos, num_false_neg):
+ if isinstance(num_true_pos, (int, float)) and isinstance(num_false_neg, (int, float)) and \
+ num_true_pos + num_false_neg == 0:
+ return float('NaN')
+ else:
+ return num_true_pos / (num_true_pos + num_false_neg)
+
+
+def calc_f1(precision, recall):
+ if isinstance(precision, (int, float)) and isinstance(recall, (int, float)) and \
+ precision + recall == 0:
+ return float('NaN')
+ else:
+ return 2.0 * (precision * recall) / (precision + recall)
+
+
+def compare_predictions_binary_tensors(ground_truth, predicted, prediction_name):
+ """
+
+ :param ground_truth:
+ :param predicted:
+ :param prediction_name:
+ :return: res_dict, prec_per_patch
+ """
+
+ import torch
+
+ if ground_truth.shape != predicted.shape:
+ raise ValueError('The ground truth matrix and the predicted matrix have different sizes!')
+
+ if not isinstance(ground_truth, torch.Tensor) or not isinstance(predicted, torch.Tensor):
+ raise ValueError('Both matrices must be dense of type torch.tensor!')
+
+ ground_truth_int = (ground_truth > 0.0).to(dtype=torch.int32)
+ predicted_int = (predicted > 0.0).to(dtype=torch.int32)
+ res_dict = dict()
+ if prediction_name is not None:
+ res_dict['comp_name'] = prediction_name
+
+ res_dict['predictions'] = float(torch.numel(ground_truth_int))
+ res_dict['pred_gt'] = float(torch.numel(ground_truth_int))
+ res_dict['positives'] = float(torch.nonzero(predicted_int).shape[0])
+ res_dict['pos_gt'] = float(torch.nonzero(ground_truth_int).shape[0])
+ res_dict['true_neg'] = res_dict['predictions'] - float(torch.nonzero(predicted_int + ground_truth_int).shape[0])
+ res_dict['negatives'] = res_dict['predictions'] - res_dict['positives']
+ res_dict['neg_gt'] = res_dict['pred_gt'] - res_dict['pos_gt']
+ true_pos = ((predicted_int + ground_truth_int) == 2).sum().to(dtype=torch.float32)
+ res_dict['true_pos'] = float(true_pos.sum())
+ res_dict['true'] = res_dict['true_pos'] + res_dict['true_neg']
+ false_pos = ((predicted_int * 2 + ground_truth_int) == 2).sum().to(dtype=torch.float32)
+ res_dict['false_pos'] = float(false_pos.sum())
+ false_neg = ((predicted_int + 2 * ground_truth_int) == 2).sum().to(dtype=torch.float32)
+ res_dict['false_neg'] = float(false_neg.sum())
+ res_dict['false'] = res_dict['false_pos'] + res_dict['false_neg']
+ res_dict['accuracy'] = calc_accuracy(res_dict['true'], res_dict['predictions'])
+ res_dict['precision'] = calc_precision(res_dict['true_pos'], res_dict['false_pos'])
+ res_dict['recall'] = calc_recall(res_dict['true_pos'], res_dict['false_neg'])
+ res_dict['f1_score'] = calc_f1(res_dict['precision'], res_dict['recall'])
+
+ return res_dict
+
+
+def compare_predictions_binary_arrays(ground_truth: np.ndarray, predicted: np.ndarray, prediction_name):
+
+ if ground_truth.shape != predicted.shape:
+ raise ValueError('The ground truth matrix and the predicted matrix have different sizes!')
+
+ ground_truth_int = (ground_truth > 0.0).astype(dtype=np.int32)
+ predicted_int = (predicted > 0.0).astype(dtype=np.int32)
+ res_dict = dict()
+ res_dict['comp_name'] = prediction_name
+
+ res_dict['predictions'] = float(ground_truth_int.size)
+ res_dict['pred_gt'] = float(ground_truth_int.size)
+ res_dict['positives'] = float(np.nonzero(predicted_int)[0].shape[0])
+ res_dict['pos_gt'] = float(np.nonzero(ground_truth_int)[0].shape[0])
+ res_dict['true_neg'] = res_dict['predictions'] - float(np.nonzero(predicted_int + ground_truth_int)[0].shape[0])
+ res_dict['negatives'] = res_dict['predictions'] - res_dict['positives']
+ res_dict['neg_gt'] = res_dict['pred_gt'] - res_dict['pos_gt']
+ true_pos = ((predicted_int + ground_truth_int) == 2).sum().astype(dtype=np.float32)
+ res_dict['true_pos'] = float(true_pos.sum())
+ res_dict['true'] = res_dict['true_pos'] + res_dict['true_neg']
+ false_pos = ((predicted_int * 2 + ground_truth_int) == 2).sum().astype(dtype=np.float32)
+ res_dict['false_pos'] = float(false_pos.sum())
+ false_neg = ((predicted_int + 2 * ground_truth_int) == 2).sum().astype(dtype=np.float32)
+ res_dict['false_neg'] = float(false_neg.sum())
+ res_dict['false'] = res_dict['false_pos'] + res_dict['false_neg']
+ res_dict['accuracy'] = calc_accuracy(res_dict['true'], res_dict['predictions'])
+ res_dict['precision'] = calc_precision(res_dict['true_pos'], res_dict['false_pos'])
+ res_dict['recall'] = calc_recall(res_dict['true_pos'], res_dict['false_neg'])
+ res_dict['f1_score'] = calc_f1(res_dict['precision'], res_dict['recall'])
+
+ return res_dict
+
+
+def chamfer_distance(file_in, file_ref, samples_per_model, num_processes=1):
+ # http://graphics.stanford.edu/courses/cs468-17-spring/LectureSlides/L14%20-%203d%20deep%20learning%20on%20point%20cloud%20representation%20(analysis).pdf
+
+ new_mesh_samples = sample_mesh(file_in, samples_per_model, rejection_radius=0.0)
+ ref_mesh_samples = sample_mesh(file_ref, samples_per_model, rejection_radius=0.0)
+
+ if new_mesh_samples.shape[0] == 0 or ref_mesh_samples.shape[0] == 0:
+ return file_in, file_ref, -1.0
+
+ ref_new_dist, corr_new_ids = kdtree_query_oneshot(pts=new_mesh_samples, pts_query=ref_mesh_samples,
+ k=1, sqr_dists=False)
+ new_ref_dist, corr_ref_ids = kdtree_query_oneshot(pts=ref_mesh_samples, pts_query=new_mesh_samples,
+ k=1, sqr_dists=False)
+
+ ref_new_dist_sum = np.sum(ref_new_dist)
+ new_ref_dist_sum = np.sum(new_ref_dist)
+ chamfer_dist = ref_new_dist_sum + new_ref_dist_sum
+ chamfer_dist_mean = chamfer_dist / (new_mesh_samples.shape[0] + ref_mesh_samples.shape[0])
+
+ return file_in, file_ref, chamfer_dist_mean
+
+
+def hausdorff_distance(file_in, file_ref, samples_per_model):
+ import scipy.spatial as spatial
+
+ new_mesh_samples = sample_mesh(file_in, samples_per_model)
+ ref_mesh_samples = sample_mesh(file_ref, samples_per_model)
+
+ if new_mesh_samples.shape[0] == 0 or ref_mesh_samples.shape[0] == 0:
+ return file_in, file_ref, -1.0, -1.0, -1.0
+
+ dist_new_ref, _, _ = spatial.distance.directed_hausdorff(new_mesh_samples, ref_mesh_samples)
+ dist_ref_new, _, _ = spatial.distance.directed_hausdorff(ref_mesh_samples, new_mesh_samples)
+ dist = max(dist_new_ref, dist_ref_new)
+ return file_in, file_ref, dist_new_ref, dist_ref_new, dist
+
+
+def intersection_over_union(file_in, file_ref, num_samples, num_dims=3):
+ # https://learnopencv.com/intersection-over-union-iou-in-object-detection-and-segmentation/
+
+ import trimesh
+ from source.base.proximity import get_signed_distance_pysdf_inaccurate
+
+ rng = np.random.default_rng(seed=42)
+ samples = rng.random(size=(num_samples, num_dims)) - 0.5
+
+ try:
+ mesh_in = trimesh.load(file_in)
+ mesh_ref = trimesh.load(file_ref)
+ except:
+ return file_in, file_ref, np.nan
+
+ sdf_in = get_signed_distance_pysdf_inaccurate(mesh_in, samples)
+ sdf_ref = get_signed_distance_pysdf_inaccurate(mesh_ref, samples)
+
+ occ_in = sdf_in > 0.0
+ occ_ref = sdf_ref > 0.0
+
+ intersection = np.logical_and(occ_in, occ_ref)
+ union = np.logical_or(occ_in, occ_ref)
+ intersection_sum = np.sum(intersection)
+ union_sum = np.sum(union)
+
+ if union_sum == 0.0:
+ iou = 0.0
+ else:
+ iou = intersection_sum / union_sum
+
+ return file_in, file_ref, iou
+
+
+def f1_approx(file_in, file_ref, num_samples, num_dims=3):
+ # https://learnopencv.com/intersection-over-union-iou-in-object-detection-and-segmentation/
+
+ import trimesh
+ from source.base.proximity import get_signed_distance_pysdf_inaccurate
+
+ rng = np.random.default_rng(seed=42)
+ samples = rng.random(size=(num_samples, num_dims)) - 0.5
+
+ try:
+ mesh_in = trimesh.load(file_in)
+ mesh_ref = trimesh.load(file_ref)
+ except:
+ return file_in, file_ref, np.nan
+
+ sdf_in = get_signed_distance_pysdf_inaccurate(mesh_in, samples)
+ sdf_ref = get_signed_distance_pysdf_inaccurate(mesh_ref, samples)
+
+ occ_in = sdf_in > 0.0
+ occ_ref = sdf_ref > 0.0
+
+ stats = compare_predictions_binary_arrays(occ_ref, occ_in, prediction_name='f1_approx')
+
+ if np.isnan(stats['f1_score']):
+ f1 = 0.0
+ else:
+ f1 = stats['f1_score']
+
+ return file_in, file_ref, f1
+
+
+def normal_error(file_in, file_ref, num_samples):
+
+ import trimesh.sample
+ from source.base import proximity
+
+ try:
+ mesh_in = trimesh.load(file_in)
+ mesh_ref = trimesh.load(file_ref)
+ except:
+ return file_in, file_ref, np.nan
+
+ samples, face_index = trimesh.sample.sample_surface(mesh_ref, num_samples)
+ face_normals_ref = mesh_ref.face_normals[face_index]
+
+ closest_points_in, distance, faces_in = proximity.get_closest_point_on_mesh(mesh_in, samples)
+ face_normals_in = mesh_in.face_normals[faces_in]
+
+ cosine = np.einsum('ij,ij->i', face_normals_ref, face_normals_in)
+ cosine = np.clip(cosine, -1, 1)
+ normal_c = np.nanmean(np.arccos(cosine))
+
+ return file_in, file_ref, normal_c
+
+
+def normal_error_approx(file_in, file_ref, num_samples=100000, num_processes=1):
+ import trimesh.sample
+
+ try:
+ mesh_in = trimesh.load(file_in)
+ mesh_ref = trimesh.load(file_ref)
+ except:
+ return file_in, file_ref, np.nan
+
+ samples_rec, face_index_rec = trimesh.sample.sample_surface(mesh_in, num_samples)
+ face_normals_rec = mesh_in.face_normals[face_index_rec]
+
+ samples_gt, face_index_gt = trimesh.sample.sample_surface(mesh_ref, num_samples)
+ face_normals_gt = mesh_ref.face_normals[face_index_gt]
+
+ _, rec_ids = kdtree_query_oneshot(pts=samples_gt, pts_query=samples_rec, k=1, sqr_dists=True)
+
+ face_normals_gt_nn = face_normals_gt[rec_ids]
+
+ cosine = np.einsum('ij,ij->i', face_normals_rec, face_normals_gt_nn)
+ cosine = np.clip(cosine, -1, 1)
+ normal_c = np.nanmean(np.arccos(cosine))
+
+ return file_in, file_ref, normal_c
+
+
+def rmse(predictions: np.ndarray, targets: np.ndarray):
+ return np.sqrt(((predictions - targets) ** 2).mean())
+
+
+def get_metric_mesh_single_file(gt_mesh_file: str, mesh_file: str, num_samples: int,
+ metric: typing.Literal['chamfer', 'iou', 'normals', 'f1'] = 'chamfer') -> float:
+
+ if os.path.isfile(mesh_file) and os.path.isfile(gt_mesh_file):
+ if metric == 'chamfer':
+ file_in, file_ref, metric_result = chamfer_distance(
+ file_in=mesh_file, file_ref=gt_mesh_file, samples_per_model=num_samples)
+ elif metric == 'iou':
+ file_in, file_ref, metric_result = intersection_over_union(
+ file_in=mesh_file, file_ref=gt_mesh_file, num_samples=num_samples)
+ elif metric == 'normals':
+ file_in, file_ref, metric_result = normal_error_approx(
+ file_in=mesh_file, file_ref=gt_mesh_file, num_samples=num_samples)
+ elif metric == 'f1':
+ file_in, file_ref, metric_result = f1_approx(
+ file_in=mesh_file, file_ref=gt_mesh_file, num_samples=num_samples)
+ else:
+ raise ValueError()
+ elif not os.path.isfile(mesh_file):
+ print('WARNING: mesh missing: {}'.format(mesh_file))
+ metric_result = np.nan
+ # raise FileExistsError()
+ elif not os.path.isfile(gt_mesh_file):
+ raise FileExistsError()
+ else:
+ raise NotImplementedError()
+
+ return metric_result
+
+
+def get_metric_meshes(result_file_template: typing.Sequence[str],
+ shape_list: typing.Sequence[str], gt_mesh_files: typing.Sequence[str],
+ num_samples=10000, metric: typing.Literal['chamfer', 'iou', 'normals', 'f1'] = 'chamfer',
+ num_processes=1) \
+ -> typing.Iterable[np.ndarray]:
+ from source.base.mp import start_process_pool
+
+ metric_results = []
+ for template in result_file_template:
+ cd_params = []
+ for sni, shape_name in enumerate(shape_list):
+ gt_mesh_file = gt_mesh_files[sni]
+ mesh_file = template.format(shape_name)
+ cd_params.append((gt_mesh_file, mesh_file, num_samples, metric))
+
+ metric_results.append(np.array(start_process_pool(
+ worker_function=get_metric_mesh_single_file, parameters=cd_params, num_processes=num_processes)))
+
+ return metric_results
diff --git a/ppsurf/source/base/mp.py b/ppsurf/source/base/mp.py
new file mode 100644
index 0000000000000000000000000000000000000000..afcaaa363e7c71d3fb70ccdbe1e77691c99fba58
--- /dev/null
+++ b/ppsurf/source/base/mp.py
@@ -0,0 +1,94 @@
+import subprocess
+import multiprocessing
+import typing
+import os
+
+
+def mp_worker(call):
+ """
+ Small function that starts a new thread with a system call. Used for thread pooling.
+ :param call:
+ :return:
+ """
+ call = call.split(' ')
+ verbose = call[-1] == '--verbose'
+ if verbose:
+ call = call[:-1]
+ subprocess.run(call)
+ else:
+ # subprocess.run(call, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) # suppress outputs
+ subprocess.run(call, stdout=subprocess.DEVNULL)
+
+
+def start_process_pool(worker_function, parameters: typing.Iterable[typing.Iterable], num_processes, timeout=None):
+ from tqdm import tqdm
+
+ if len(parameters) > 0:
+ if num_processes <= 1:
+ print('Running loop for {} with {} calls on {} workers'.format(
+ str(worker_function), len(parameters), num_processes))
+ results = []
+ for c in tqdm(parameters):
+ results.append(worker_function(*c))
+ return results
+ else:
+ print('Running loop for {} with {} calls on {} subprocess workers'.format(
+ str(worker_function), len(parameters), num_processes))
+
+ results = []
+ context = multiprocessing.get_context('spawn') # 2023-10-25 fork got stuck on Linux (Python 3.9.12)
+ pool = context.Pool(processes=num_processes, maxtasksperchild=1)
+ try:
+ # quick and easy TQDM, a bit laggy but works
+ for result in pool.starmap(worker_function, tqdm(parameters, total=len(parameters))):
+ # for result in pool.starmap(worker_function, parameters): # without TQDM
+ results.append(result)
+ except KeyboardInterrupt:
+ # Allow ^C to interrupt from any thread.
+ exit()
+ pool.close()
+ return results
+ else:
+ return None
+
+
+def start_thread(func, args: typing.Sequence, kwargs={}):
+ import threading
+
+ thread = threading.Thread(target=func, args=args, kwargs=kwargs)
+ thread.start()
+ return thread
+
+
+def start_process(func, args: typing.Sequence, start_method=None):
+ import multiprocessing as mp
+ if start_method is None:
+ proc = mp.Process(target=func, args=args)
+ else:
+ ctx = mp.get_context(start_method)
+ proc = ctx.Process(target=func, args=args)
+ proc.start()
+ return proc
+
+
+def get_multi_gpu_params(max_workers: typing.Optional[int] = None) -> typing.List[str]:
+ from torch.cuda import device_count
+
+ num_gpus = device_count()
+ if num_gpus <= 1:
+ return []
+
+ num_workers = os.cpu_count()
+ if max_workers is not None:
+ num_workers = min(num_workers, max_workers)
+
+ res_args = [
+ '--trainer.strategy', 'ddp',
+ # '--trainer.strategy', 'ddp_find_unused_parameters_true', # for debugging
+ '--model.init_args.workers', str(num_workers),
+ '--data.init_args.use_ddp', str(True),
+ '--data.init_args.workers', str(num_workers),
+ '--data.init_args.batch_size', str(50 // num_gpus),
+ ]
+
+ return res_args
diff --git a/ppsurf/source/base/nn.py b/ppsurf/source/base/nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..251f57e1a5d7883c67e9e7150c4e4c7b4c8eb85b
--- /dev/null
+++ b/ppsurf/source/base/nn.py
@@ -0,0 +1,701 @@
+import typing
+
+import pytorch_lightning as pl
+import torch
+from torch.nn import functional as f
+
+
+# https://github.com/numpy/numpy/issues/5228
+def cartesian_to_polar(pts_cart: torch.Tensor):
+ batch_size = pts_cart.shape[0]
+ num_pts = pts_cart.shape[1]
+ num_dim = pts_cart.shape[2]
+ pts_cart_flat = pts_cart.reshape((-1, num_dim))
+
+ def pol_2d():
+ x = pts_cart_flat[:, 0]
+ y = pts_cart_flat[:, 1]
+
+ r = torch.sqrt(x ** 2 + y ** 2)
+ phi = torch.atan2(y, x)
+ return torch.stack((r, phi), dim=1)
+
+ def pol_3d():
+ x = pts_cart_flat[:, 0]
+ y = pts_cart_flat[:, 1]
+ z = pts_cart_flat[:, 2]
+
+ hxy = torch.hypot(x, y)
+ r = torch.hypot(hxy, z)
+ el = torch.atan2(z, hxy)
+ az = torch.atan2(y, x)
+ return torch.stack((az, el, r), dim=1)
+
+ pts_spherical_flat = pol_2d() if num_dim == 2 else pol_3d()
+ pts_spherical = pts_spherical_flat.reshape((batch_size, num_pts, num_dim))
+
+ return pts_spherical
+
+
+def pos_encoding(pts: torch.Tensor, pos_encoding_levels: int, skip_last_dim=False):
+ """
+ use positional encoding on points
+ 3d example: [x, y, z] -> [f(cos, x), f(cos, y), f(cos, z), f(sin, x), f(sin, y), f(sin, z)]
+ :param pts: tensor[b, n, 2 or 3]
+ :param pos_encoding_levels: int
+ :param skip_last_dim: bool, skip last dim of input points (necessary for radius of polar coordinates)
+ :return:
+ """
+
+ if pos_encoding_levels <= 0:
+ return pts
+
+ batch_size = pts.shape[0]
+ num_pts = pts.shape[1]
+ num_dim = pts.shape[2]
+ num_dim_out = num_dim * 2 * pos_encoding_levels
+ pts_enc = torch.zeros((batch_size, num_pts, num_dim_out), device=pts.device)
+
+ for dim in range(num_dim):
+ for lvl in range(pos_encoding_levels):
+ dim_out = dim * lvl * 2
+ if skip_last_dim and dim == num_dim - 1:
+ pts_enc[..., dim_out] = pts[..., dim]
+ pts_enc[..., dim_out + num_dim] = pts[..., dim]
+ else:
+ pts_enc[..., dim_out] = torch.cos(pts[..., dim] * lvl * torch.pi * pow(2.0, lvl))
+ pts_enc[..., dim_out + num_dim] = torch.sin(pts[..., dim] * lvl * torch.pi * pow(2.0, lvl))
+
+ return pts_enc
+
+
+class AttentionPoco(pl.LightningModule):
+ # self-attention for feature vectors
+ # adapted from POCO attention
+ # https://github.com/valeoai/POCO/blob/4e39b5e722c82e91570df5f688e2c6e4870ffe65/networks/decoder/interp_attention.py
+
+ def __init__(self, net_size_max=1024, reduce=True):
+ super(AttentionPoco, self).__init__()
+
+ self.fc_query = torch.nn.Conv2d(net_size_max, 1, 1)
+ self.fc_value = torch.nn.Conv2d(net_size_max, net_size_max, 1)
+ self.reduce = reduce
+
+ def forward(self, feature_vectors: torch.Tensor):
+ # [feat_len, batch, num_feat] expected -> feature dim to dim 0
+ feature_vectors_t = torch.permute(feature_vectors, (1, 0, 2))
+
+ query = self.fc_query(feature_vectors_t).squeeze(0) # fc over feature dim -> [batch, num_feat]
+ value = self.fc_value(feature_vectors_t).permute(1, 2, 0) # -> [batch, num_feat, feat_len]
+
+ weights = torch.nn.functional.softmax(query, dim=-1) # softmax over num_feat -> [batch, num_feat]
+ if self.reduce:
+ feature_vector_out = torch.sum(value * weights.unsqueeze(-1).broadcast_to(value.shape), dim=1)
+ else:
+ feature_vector_out = (weights.unsqueeze(2) * value).permute(0, 2, 1)
+ return feature_vector_out
+
+
+def batch_quat_to_rotmat(q, out=None):
+ """
+ quaternion a + bi + cj + dk should be given in the form [a,b,c,d]
+ :param q:
+ :param out:
+ :return:
+ """
+
+ batchsize = q.size(0)
+
+ if out is None:
+ out = q.new_empty(batchsize, 3, 3)
+
+ # 2 / squared quaternion 2-norm
+ s = 2 / torch.sum(q.pow(2), 1)
+
+ # coefficients of the Hamilton product of the quaternion with itself
+ h = torch.bmm(q.unsqueeze(2), q.unsqueeze(1))
+
+ out[:, 0, 0] = 1 - (h[:, 2, 2] + h[:, 3, 3]).mul(s)
+ out[:, 0, 1] = (h[:, 1, 2] - h[:, 3, 0]).mul(s)
+ out[:, 0, 2] = (h[:, 1, 3] + h[:, 2, 0]).mul(s)
+
+ out[:, 1, 0] = (h[:, 1, 2] + h[:, 3, 0]).mul(s)
+ out[:, 1, 1] = 1 - (h[:, 1, 1] + h[:, 3, 3]).mul(s)
+ out[:, 1, 2] = (h[:, 2, 3] - h[:, 1, 0]).mul(s)
+
+ out[:, 2, 0] = (h[:, 1, 3] - h[:, 2, 0]).mul(s)
+ out[:, 2, 1] = (h[:, 2, 3] + h[:, 1, 0]).mul(s)
+ out[:, 2, 2] = 1 - (h[:, 1, 1] + h[:, 2, 2]).mul(s)
+
+ return out
+
+
+class STN(pl.LightningModule):
+ def __init__(self, net_size_max=1024, num_scales=1, num_points=500, dim=3, sym_op='max'):
+ super(STN, self).__init__()
+
+ self.net_size_max = net_size_max
+ self.dim = dim
+ self.sym_op = sym_op
+ self.num_scales = num_scales
+ self.num_points = num_points
+
+ self.conv1 = torch.nn.Conv1d(self.dim, 64, 1)
+ self.conv2 = torch.nn.Conv1d(64, 128, 1)
+ self.conv3 = torch.nn.Conv1d(128, self.net_size_max, 1)
+ self.mp1 = torch.nn.MaxPool1d(num_points)
+
+ self.fc1 = torch.nn.Linear(self.net_size_max, int(self.net_size_max / 2))
+ self.fc2 = torch.nn.Linear(int(self.net_size_max / 2), int(self.net_size_max / 4))
+ self.fc3 = torch.nn.Linear(int(self.net_size_max / 4), self.dim*self.dim)
+
+ self.bn1 = torch.nn.BatchNorm1d(64)
+ self.bn2 = torch.nn.BatchNorm1d(128)
+ self.bn3 = torch.nn.BatchNorm1d(self.net_size_max)
+ self.bn4 = torch.nn.BatchNorm1d(int(self.net_size_max / 2))
+ self.bn5 = torch.nn.BatchNorm1d(int(self.net_size_max / 4))
+
+ if self.num_scales > 1:
+ self.fc0 = torch.nn.Linear(self.net_size_max * self.num_scales, self.net_size_max)
+ self.bn0 = torch.nn.BatchNorm1d(self.net_size_max)
+
+ def forward(self, x):
+ batch_size = x.size()[0]
+ x = f.relu(self.bn1(self.conv1(x)))
+ x = f.relu(self.bn2(self.conv2(x)))
+ x = f.relu(self.bn3(self.conv3(x)))
+
+ # symmetric operation over all points
+ if self.num_scales == 1:
+ x = self.mp1(x)
+ else:
+ x_scales = x.new_empty(x.size(0), self.net_size_max * self.num_scales, 1)
+ for s in range(self.num_scales):
+ x_scales[:, s*self.net_size_max:(s+1)*self.net_size_max, :] = \
+ self.mp1(x[:, :, s*self.num_points:(s+1)*self.num_points])
+ x = x_scales
+
+ x = x.view(-1, self.net_size_max*self.num_scales)
+
+ if self.num_scales > 1:
+ x = f.relu(self.bn0(self.fc0(x)))
+
+ x = f.relu(self.bn4(self.fc1(x)))
+ x = f.relu(self.bn5(self.fc2(x)))
+ x = self.fc3(x)
+
+ iden = torch.eye(self.dim, dtype=x.dtype, device=x.device).view(1, self.dim*self.dim).repeat(batch_size, 1)
+ x = x + iden
+ x = x.view(-1, self.dim, self.dim)
+ return x
+
+
+class QSTN(pl.LightningModule):
+ def __init__(self, net_size_max=1024, num_scales=1, num_points=500, dim=3, sym_op='max'):
+ super(QSTN, self).__init__()
+
+ self.net_size_max = net_size_max
+ self.dim = dim
+ self.sym_op = sym_op
+ self.num_scales = num_scales
+ self.num_points = num_points
+
+ self.conv1 = torch.nn.Conv1d(self.dim, 64, 1)
+ self.conv2 = torch.nn.Conv1d(64, 128, 1)
+ self.conv3 = torch.nn.Conv1d(128, self.net_size_max, 1)
+ self.mp1 = torch.nn.MaxPool1d(num_points)
+ self.fc1 = torch.nn.Linear(self.net_size_max, int(self.net_size_max / 2))
+ self.fc2 = torch.nn.Linear(int(self.net_size_max / 2), int(self.net_size_max / 4))
+ self.fc3 = torch.nn.Linear(int(self.net_size_max / 4), 4)
+
+ self.bn1 = torch.nn.BatchNorm1d(64)
+ self.bn2 = torch.nn.BatchNorm1d(128)
+ self.bn3 = torch.nn.BatchNorm1d(self.net_size_max)
+ self.bn4 = torch.nn.BatchNorm1d(int(self.net_size_max / 2))
+ self.bn5 = torch.nn.BatchNorm1d(int(self.net_size_max / 4))
+
+ if self.num_scales > 1:
+ self.fc0 = torch.nn.Linear(self.net_size_max*self.num_scales, self.net_size_max)
+ self.bn0 = torch.nn.BatchNorm1d(self.net_size_max)
+
+ def forward(self, x):
+ x = f.relu(self.bn1(self.conv1(x)))
+ x = f.relu(self.bn2(self.conv2(x)))
+ x = f.relu(self.bn3(self.conv3(x)))
+
+ # symmetric operation over all points
+ if self.num_scales == 1:
+ x = self.mp1(x)
+ else:
+ x_scales = x.new_empty(x.size(0), self.net_size_max*self.num_scales, 1)
+ for s in range(self.num_scales):
+ x_scales[:, s*self.net_size_max:(s+1)*self.net_size_max, :] = \
+ self.mp1(x[:, :, s*self.num_points:(s+1)*self.num_points])
+ x = x_scales
+
+ x = x.view(-1, self.net_size_max*self.num_scales)
+
+ if self.num_scales > 1:
+ x = f.relu(self.bn0(self.fc0(x)))
+
+ x = f.relu(self.bn4(self.fc1(x)))
+ x = f.relu(self.bn5(self.fc2(x)))
+ x = self.fc3(x)
+
+ # add identity quaternion (so the network can output 0 to leave the point cloud identical)
+ iden = x.new_tensor([1, 0, 0, 0])
+ x_quat = x + iden
+
+ # convert quaternion to rotation matrix
+ x = batch_quat_to_rotmat(x_quat)
+
+ return x, x_quat
+
+
+class PointNetfeat(pl.LightningModule):
+ def __init__(self, net_size_max=1024, num_scales=1, num_points=500,
+ polar=False, use_point_stn=True, use_feat_stn=True,
+ output_size=100, sym_op='max', dim=3):
+ super(PointNetfeat, self).__init__()
+
+ self.net_size_max = net_size_max
+ self.num_points = num_points
+ self.num_scales = num_scales
+ self.polar = polar
+ self.use_point_stn = use_point_stn
+ self.use_feat_stn = use_feat_stn
+ self.sym_op = sym_op
+ self.output_size = output_size
+ self.dim = dim
+
+ if self.use_point_stn:
+ self.stn1 = QSTN(net_size_max=net_size_max, num_scales=self.num_scales,
+ num_points=num_points, dim=dim, sym_op=self.sym_op)
+
+ if self.use_feat_stn:
+ self.stn2 = STN(net_size_max=net_size_max, num_scales=self.num_scales,
+ num_points=num_points, dim=64, sym_op=self.sym_op)
+
+ self.conv0a = torch.nn.Conv1d(self.dim, 64, 1)
+ self.conv0b = torch.nn.Conv1d(64, 64, 1)
+ self.bn0a = torch.nn.BatchNorm1d(64)
+ self.bn0b = torch.nn.BatchNorm1d(64)
+ self.conv1 = torch.nn.Conv1d(64, 64, 1)
+ self.conv2 = torch.nn.Conv1d(64, 128, 1)
+ self.conv3 = torch.nn.Conv1d(128, output_size, 1)
+ self.bn1 = torch.nn.BatchNorm1d(64)
+ self.bn2 = torch.nn.BatchNorm1d(128)
+ self.bn3 = torch.nn.BatchNorm1d(output_size)
+
+ if self.num_scales > 1:
+ self.conv4 = torch.nn.Conv1d(output_size, output_size*self.num_scales, 1)
+ self.bn4 = torch.nn.BatchNorm1d(output_size*self.num_scales)
+
+ if self.sym_op == 'max':
+ self.mp1 = torch.nn.MaxPool1d(num_points)
+ elif self.sym_op == 'sum':
+ pass
+ elif self.sym_op == 'wsum':
+ pass
+ elif self.sym_op == 'att':
+ self.att = AttentionPoco(output_size)
+ else:
+ raise ValueError('Unsupported symmetric operation: {}'.format(self.sym_op))
+
+ def forward(self, x, pts_weights):
+
+ # input transform
+ if self.use_point_stn:
+ trans, trans_quat = self.stn1(x[:, :3, :]) # transform only point data
+ # an error here can mean that your input size is wrong (e.g. added normals in the point cloud files)
+ x_transformed = torch.bmm(trans, x[:, :3, :]) # transform only point data
+ x = torch.cat((x_transformed, x[:, 3:, :]), dim=1)
+ else:
+ trans = None
+ trans_quat = None
+
+ if bool(self.polar):
+ x = torch.permute(x, (0, 2, 1))
+ x = cartesian_to_polar(pts_cart=x)
+ x = torch.permute(x, (0, 2, 1))
+
+ # mlp (64,64)
+ x = f.relu(self.bn0a(self.conv0a(x)))
+ x = f.relu(self.bn0b(self.conv0b(x)))
+
+ # feature transform
+ if self.use_feat_stn:
+ trans2 = self.stn2(x)
+ x = torch.bmm(trans2, x)
+ else:
+ trans2 = None
+
+ # mlp (64,128,output_size)
+ x = f.relu(self.bn1(self.conv1(x)))
+ x = f.relu(self.bn2(self.conv2(x)))
+ x = self.bn3(self.conv3(x))
+
+ # mlp (output_size,output_size*num_scales)
+ if self.num_scales > 1:
+ x = self.bn4(self.conv4(f.relu(x)))
+
+ # symmetric max operation over all points
+ if self.num_scales == 1:
+ if self.sym_op == 'max':
+ x = self.mp1(x)
+ elif self.sym_op == 'sum':
+ x = torch.sum(x, 2, keepdim=True)
+ elif self.sym_op == 'wsum':
+ pts_weights_bc = torch.broadcast_to(torch.unsqueeze(pts_weights, 1), size=x.shape)
+ x = x * pts_weights_bc
+ x = torch.sum(x, 2, keepdim=True)
+ elif self.sym_op == 'att':
+ x = self.att(x)
+ else:
+ raise ValueError('Unsupported symmetric operation: {}'.format(self.sym_op))
+
+ else:
+ x_scales = x.new_empty(x.size(0), self.output_size*self.num_scales**2, 1)
+ if self.sym_op == 'max':
+ for s in range(self.num_scales):
+ x_scales[:, s*self.num_scales*self.output_size:(s+1)*self.num_scales*self.output_size, :] = \
+ self.mp1(x[:, :, s*self.num_points:(s+1)*self.num_points])
+ elif self.sym_op == 'sum':
+ for s in range(self.num_scales):
+ x_scales[:, s*self.num_scales*self.output_size:(s+1)*self.num_scales*self.output_size, :] = \
+ torch.sum(x[:, :, s*self.num_points:(s+1)*self.num_points], 2, keepdim=True)
+ else:
+ raise ValueError('Unsupported symmetric operation: %s' % self.sym_op)
+ x = x_scales
+
+ x = x.view(-1, self.output_size * self.num_scales ** 2)
+
+ return x, trans, trans_quat, trans2
+
+
+class MLP(pl.LightningModule):
+ def __init__(self, input_size: int, output_size: int, num_layers: int,
+ halving_size=True, final_bn_act=False, final_layer_norm=False,
+ activation: typing.Optional[typing.Callable[..., torch.nn.Module]] = torch.nn.ReLU,
+ norm: typing.Optional[typing.Callable[..., torch.nn.Module]] = torch.nn.BatchNorm1d,
+ fc_layer=torch.nn.Linear, dropout=0.0):
+ super(MLP, self).__init__()
+
+ self.num_layers = num_layers
+
+ if halving_size:
+ layer_sizes = [int(input_size / (2 ** i)) for i in range(num_layers)]
+ else:
+ layer_sizes = [input_size for _ in range(num_layers)]
+
+ fully_connected = [fc_layer(layer_sizes[i], layer_sizes[i+1]) for i in range(num_layers-1)]
+ norms = [norm(layer_sizes[i + 1]) for i in range(num_layers - 1)]
+
+ layers_list = []
+ for i in range(self.num_layers-1):
+ layers_list.append(torch.nn.Sequential(
+ fully_connected[i],
+ norms[i],
+ activation(),
+ torch.nn.Dropout(dropout),
+ ))
+
+ final_modules = [fc_layer(layer_sizes[-1], output_size)]
+ if final_bn_act:
+ if final_layer_norm:
+ final_modules.append(torch.nn.LayerNorm(output_size))
+ else:
+ final_modules.append(norm(output_size))
+ final_modules.append(activation())
+ final_layer = torch.nn.Sequential(*final_modules)
+ layers_list.append(final_layer)
+
+ self.layers = torch.nn.Sequential(*layers_list)
+
+ def forward(self, x):
+ x = self.layers.forward(x)
+ return x
+
+
+class ResidualBlock(pl.LightningModule):
+
+ def __init__(self, in_channels, out_channels, kernel_size, activation=torch.nn.ReLU()):
+ super().__init__()
+ bn = torch.nn.BatchNorm1d
+
+ self.cv0 = torch.nn.Conv1d(in_channels, in_channels // 2, 1)
+ self.bn0 = bn(in_channels // 2)
+ self.cv1 = FKAConvLayer(in_channels // 2, in_channels // 2, kernel_size, activation=activation)
+ self.bn1 = bn(in_channels // 2)
+ self.cv2 = torch.nn.Conv1d(in_channels // 2, out_channels, 1)
+ self.bn2 = bn(out_channels)
+ self.activation = torch.nn.ReLU(inplace=True)
+
+ self.shortcut = torch.nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels \
+ else torch.nn.Identity()
+ self.bn_shortcut = bn(out_channels) if in_channels != out_channels else torch.nn.Identity()
+
+ def forward(self, x, pts, support_points, neighbors_indices):
+ x_short = x
+ x = self.activation(self.bn0(self.cv0(x)))
+ x = self.activation(self.bn1(self.cv1(x, pts, support_points, neighbors_indices)))
+ x = self.bn2(self.cv2(x))
+
+ x_short = self.bn_shortcut(self.shortcut(x_short))
+ if x_short.shape[2] != x.shape[2]:
+ x_short = max_pool(x_short, neighbors_indices)
+
+ x = self.activation(x + x_short)
+
+ return x
+
+
+class FKAConvNetwork(pl.LightningModule):
+
+ def __init__(self, in_channels, out_channels, segmentation=False, hidden=64, dropout=0.5,
+ last_layer_additional_size=None, fix_support_number=False,
+ activation=torch.nn.ReLU(), x4d_bug_fixed=False):
+ super().__init__()
+
+ self.fixed = x4d_bug_fixed
+
+ self.lcp_preprocess = True
+ self.segmentation = segmentation
+ self.fix_support_point_number = fix_support_number
+ self.kernel_size = 16
+
+ self.cv0 = FKAConvLayer(in_channels, hidden, 16, activation=activation)
+
+ bn = torch.nn.BatchNorm1d
+ self.bn0 = bn(hidden)
+
+ def _make_resnet_block(in_channels_resnetb, out_channels_resnetb):
+ return ResidualBlock(in_channels=in_channels_resnetb, out_channels=out_channels_resnetb,
+ kernel_size=self.kernel_size, activation=activation)
+
+ self.resnetb01 = _make_resnet_block(hidden, hidden)
+ self.resnetb10 = _make_resnet_block(hidden, 2 * hidden)
+ self.resnetb11 = _make_resnet_block(2 * hidden, 2 * hidden)
+ self.resnetb20 = _make_resnet_block(2 * hidden, 4 * hidden)
+ self.resnetb21 = _make_resnet_block(4 * hidden, 4 * hidden)
+ self.resnetb30 = _make_resnet_block(4 * hidden, 8 * hidden)
+ self.resnetb31 = _make_resnet_block(8 * hidden, 8 * hidden)
+ self.resnetb40 = _make_resnet_block(8 * hidden, 16 * hidden)
+ self.resnetb41 = _make_resnet_block(16 * hidden, 16 * hidden)
+ if self.segmentation:
+
+ self.cv5 = torch.nn.Conv1d(32 * hidden, 16 * hidden, 1)
+ self.bn5 = bn(16 * hidden)
+ self.cv3d = torch.nn.Conv1d(24 * hidden, 8 * hidden, 1)
+ self.bn3d = bn(8 * hidden)
+ self.cv2d = torch.nn.Conv1d(12 * hidden, 4 * hidden, 1)
+ self.bn2d = bn(4 * hidden)
+ self.cv1d = torch.nn.Conv1d(6 * hidden, 2 * hidden, 1)
+ self.bn1d = bn(2 * hidden)
+ self.cv0d = torch.nn.Conv1d(3 * hidden, hidden, 1)
+ self.bn0d = bn(hidden)
+
+ if last_layer_additional_size is not None:
+ self.fcout = torch.nn.Conv1d(hidden + last_layer_additional_size, out_channels, 1)
+ else:
+ self.fcout = torch.nn.Conv1d(hidden, out_channels, 1)
+ else:
+ self.fcout = torch.nn.Conv1d(16 * hidden, out_channels, 1)
+
+ self.dropout = torch.nn.Dropout(dropout)
+ self.activation = torch.nn.ReLU()
+
+ def forward(self, data, spectral_only=False):
+ if not spectral_only:
+ from source.poco_data_loader import get_fkaconv_ids
+ spatial_data = get_fkaconv_ids(data)
+ for key, value in spatial_data.items():
+ data[key] = value
+
+ # x = data['x']
+ pts = data['pts']
+ x = torch.ones_like(pts)
+
+ x0 = self.activation(self.bn0(self.cv0(x, pts, pts, data['ids00'])))
+ x0 = self.resnetb01(x0, pts, pts, data['ids00'])
+ x1 = self.resnetb10(x0, pts, data['support1'], data['ids01'])
+ x1 = self.resnetb11(x1, data['support1'], data['support1'], data['ids11'])
+ x2 = self.resnetb20(x1, data['support1'], data['support2'], data['ids12'])
+ x2 = self.resnetb21(x2, data['support2'], data['support2'], data['ids22'])
+ x3 = self.resnetb30(x2, data['support2'], data['support3'], data['ids23'])
+ x3 = self.resnetb31(x3, data['support3'], data['support3'], data['ids33'])
+ x4 = self.resnetb40(x3, data['support3'], data['support4'], data['ids34'])
+ x4 = self.resnetb41(x4, data['support4'], data['support4'], data['ids44'])
+
+ if self.segmentation:
+ x5 = x4.max(dim=2, keepdim=True)[0].expand_as(x4)
+ x4d = self.activation(self.bn5(self.cv5(torch.cat([x4, x5], dim=1))))
+ if not self.fixed:
+ x4d = x4
+
+ x3d = interpolate(x4d, data['ids43'])
+ x3d = self.activation(self.bn3d(self.cv3d(torch.cat([x3d, x3], dim=1))))
+
+ x2d = interpolate(x3d, data['ids32'])
+ x2d = self.activation(self.bn2d(self.cv2d(torch.cat([x2d, x2], dim=1))))
+
+ x1d = interpolate(x2d, data['ids21'])
+ x1d = self.activation(self.bn1d(self.cv1d(torch.cat([x1d, x1], dim=1))))
+
+ xout = interpolate(x1d, data['ids10'])
+ xout = self.activation(self.bn0d(self.cv0d(torch.cat([xout, x0], dim=1))))
+ xout = self.dropout(xout)
+ xout = self.fcout(xout)
+ else:
+ xout = x4
+ xout = self.dropout(xout)
+ xout = self.fcout(xout)
+ xout = xout.mean(dim=2)
+ return xout
+
+
+class FKAConvLayer(pl.LightningModule):
+
+ def __init__(self, in_channels, out_channels, kernel_size=16, bias=False, dim=3,
+ activation=torch.nn.ReLU()):
+ super().__init__()
+
+ # parameters
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.bias = bias
+ self.dim = dim
+
+ # convolution kernel
+ self.cv = torch.nn.Conv2d(in_channels, out_channels, (1, kernel_size), bias=bias)
+
+ # normalization radius
+ self.norm_radius_momentum = 0.1
+ self.register_buffer('norm_radius', torch.Tensor(1,))
+ self.alpha = torch.nn.Parameter(torch.Tensor(1,), requires_grad=True)
+ self.beta = torch.nn.Parameter(torch.Tensor(1,), requires_grad=True)
+ torch.nn.init.ones_(self.norm_radius.data)
+ torch.nn.init.ones_(self.alpha.data)
+ torch.nn.init.ones_(self.beta.data)
+
+ # features to kernel weights
+ self.fc1 = torch.nn.Conv2d(self.dim, self.kernel_size, 1, bias=False)
+ self.fc2 = torch.nn.Conv2d(2 * self.kernel_size, self.kernel_size, 1, bias=False)
+ self.fc3 = torch.nn.Conv2d(2 * self.kernel_size, self.kernel_size, 1, bias=False)
+ self.bn1 = torch.nn.InstanceNorm2d(self.kernel_size, affine=True)
+ self.bn2 = torch.nn.InstanceNorm2d(self.kernel_size, affine=True)
+
+ self.activation = activation
+
+ # TODO: try sigmoid again
+ def forward(self, x, pts, support_points, neighbors_indices):
+
+ if x is None:
+ return None
+
+ pts = batch_gather(pts, dim=2, index=neighbors_indices).contiguous()
+ x = batch_gather(x, dim=2, index=neighbors_indices).contiguous()
+
+ # center the neighborhoods (local coordinates)
+ pts = pts - support_points.unsqueeze(3)
+
+ # normalize points
+ # compute distances from points to their support point
+ distances = torch.sqrt((pts.detach() ** 2).sum(1))
+
+ # update the normalization radius
+ if self.training:
+ mean_radius = distances.max(2)[0].mean()
+ self.norm_radius.data = (
+ self.norm_radius.data * (1 - self.norm_radius_momentum)
+ + mean_radius * self.norm_radius_momentum
+ )
+
+ # normalize
+ pts = pts / self.norm_radius
+
+ # estimate distance weights
+ distance_weight = torch.sigmoid(-self.alpha * distances + self.beta)
+ distance_weight_s = distance_weight.sum(2, keepdim=True)
+ distance_weight_s = distance_weight_s + (distance_weight_s == 0) + 1e-6
+ distance_weight = (
+ distance_weight / distance_weight_s * distances.shape[2]
+ ).unsqueeze(1)
+
+ # feature weighting matrix estimation
+ if pts.shape[3] == 1:
+ mat = self.activation(self.fc1(pts))
+ else:
+ mat = self.activation(self.bn1(self.fc1(pts)))
+ mp1 = torch.max(mat * distance_weight, dim=3, keepdim=True)[0].expand(
+ (-1, -1, -1, mat.shape[3])
+ )
+ mat = torch.cat([mat, mp1], dim=1)
+ if pts.shape[3] == 1:
+ mat = self.activation(self.fc2(mat))
+ else:
+ mat = self.activation(self.bn2(self.fc2(mat)))
+ mp2 = torch.max(mat * distance_weight, dim=3, keepdim=True)[0].expand(
+ (-1, -1, -1, mat.shape[3])
+ )
+ mat = torch.cat([mat, mp2], dim=1)
+ mat = self.activation(self.fc3(mat)) * distance_weight
+ # mat = torch.sigmoid(self.fc3(mat)) * distance_weight
+
+ # compute features
+ features = torch.matmul(
+ x.transpose(1, 2), mat.permute(0, 2, 3, 1)
+ ).transpose(1, 2)
+ features = self.cv(features).squeeze(3)
+
+ return features
+
+
+@torch.jit.script
+def batch_gather(data: torch.Tensor, dim: int, index: torch.Tensor):
+
+ index_shape = list(index.shape)
+ input_shape = list(data.shape)
+
+ views = [data.shape[0]] + [
+ 1 if i != dim else -1 for i in range(1, len(data.shape))
+ ]
+ expanse = list(data.shape)
+ expanse[0] = -1
+ expanse[dim] = -1
+ index = index.view(views).expand(expanse)
+
+ output = torch.gather(data, dim, index)
+
+ # compute final shape
+ output_shape = input_shape[0:dim] + index_shape[1:] + input_shape[dim+1:]
+
+ return output.reshape(output_shape)
+
+
+def max_pool(data: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+ features = batch_gather(data, dim=2, index=indices).contiguous()
+ features = features.max(dim=3)[0]
+ return features
+
+
+# TODO: test sum
+def interpolate(x, neighbors_indices, method='mean'):
+
+ mask = (neighbors_indices > -1)
+ neighbors_indices[~mask] = 0
+
+ x = batch_gather(x, 2, neighbors_indices)
+
+ if neighbors_indices.shape[-1] > 1:
+ if method == 'mean':
+ return x.mean(-1)
+ elif method == 'max':
+ return x.mean(-1)[0]
+ else:
+ return x.squeeze(-1)
+
+
+def count_parameters(model):
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
diff --git a/ppsurf/source/base/point_cloud.py b/ppsurf/source/base/point_cloud.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c10fe83209f6cc0d9475e83e55d3852466c9c7b
--- /dev/null
+++ b/ppsurf/source/base/point_cloud.py
@@ -0,0 +1,233 @@
+import numpy as np
+
+from source.base import fs
+
+
+def load_xyz(file_path):
+ data = np.loadtxt(file_path).astype('float32')
+ nan_lines = np.isnan(data).any(axis=1)
+ num_nan_lines = np.sum(nan_lines)
+ if num_nan_lines > 0:
+ data = data[~nan_lines] # filter rows with nan values
+ print('Ignored {} points containing NaN coordinates in point cloud {}'.format(num_nan_lines, file_path))
+ return data
+
+
+def write_ply(file_path: str, points: np.ndarray, normals=None, colors=None):
+ """
+ Write point cloud file as .ply.
+ :param file_path:
+ :param points:
+ :param normals:
+ :param colors:
+ :return: None
+ """
+
+ import trimesh
+
+ assert(file_path.endswith('.ply'))
+
+ def sanitize_inputs(arr: np.ndarray):
+ if arr is None:
+ return arr
+
+ # should be array
+ arr = np.asarray(arr)
+
+ # should be 2 dims
+ if arr.ndim == 1:
+ arr = np.expand_dims(arr, axis=0)
+
+ # convert 2d points to 3d
+ if arr.shape[1] == 2:
+ arr_2p5d = np.zeros((arr.shape[0], 3))
+ arr_2p5d[:, :2] = arr
+ arr_2p5d[:, 2] = 0.0
+ arr = arr_2p5d
+
+ # should be (n, dims)
+ if arr.shape[0] == 3 and arr.shape[1] != 3:
+ arr = arr.transpose([1, 0])
+
+ return arr
+
+ points = sanitize_inputs(points)
+ colors = sanitize_inputs(colors)
+ normals = sanitize_inputs(normals)
+
+ mesh = trimesh.Trimesh(vertices=points, vertex_colors=colors, vertex_normals=normals)
+ fs.make_dir_for_file(file_path)
+ mesh.export(file_path)
+
+
+def write_xyz(file_path, points: np.ndarray, normals=None, colors=None):
+ """
+ Write point cloud file.
+ :param file_path:
+ :param points:
+ :param normals:
+ :param colors:
+ :return: None
+ """
+
+ fs.make_dir_for_file(file_path)
+
+ if points.shape == (3,):
+ points = np.expand_dims(points, axis=0)
+
+ if points.shape[0] == 3 and points.shape[1] != 3:
+ points = points.transpose([1, 0])
+
+ if colors is not None and colors.shape[0] == 3 and colors.shape[1] != 3:
+ colors = colors.transpose([1, 0])
+
+ if normals is not None and normals.shape[0] == 3 and normals.shape[1] != 3:
+ normals = normals.transpose([1, 0])
+
+ with open(file_path, 'w') as fp:
+
+ # convert 2d points to 3d
+ if points.shape[1] == 2:
+ vertices_2p5d = np.zeros((points.shape[0], 3))
+ vertices_2p5d[:, :2] = points
+ vertices_2p5d[:, 2] = 0.0
+ points = vertices_2p5d
+
+ # write points
+ # meshlab doesn't like colors, only using normals. try cloud compare instead.
+ for vi, v in enumerate(points):
+ line_vertex = str(v[0]) + ' ' + str(v[1]) + ' ' + str(v[2]) + ' '
+ if normals is not None:
+ line_vertex += str(normals[vi][0]) + ' ' + str(normals[vi][1]) + ' ' + str(normals[vi][2]) + ' '
+ if colors is not None:
+ line_vertex += str(colors[vi][0]) + ' ' + str(colors[vi][1]) + ' ' + str(colors[vi][2]) + ' '
+ fp.write(line_vertex + '\n')
+
+
+def load_pcd(file_in):
+ # PCD: https://pointclouds.org/documentation/tutorials/pcd_file_format.html
+ # PCD RGB: http://docs.pointclouds.org/trunk/structpcl_1_1_r_g_b.html#a4ad91ab9726a3580e6dfc734ab77cd18
+
+ def read_header(lines_header):
+ header_info = dict()
+
+ def add_line_to_header_dict(header_dict, line, expected_field):
+ line_parts = line.split(sep=' ')
+ assert (line_parts[0] == expected_field), \
+ ('Warning: "' + expected_field + '" expected but not found in pcd header!')
+ header_dict[expected_field] = (' '.join(line_parts[1:])).replace('\n', '')
+
+ add_line_to_header_dict(header_info, lines_header[0], '#')
+ add_line_to_header_dict(header_info, lines_header[1], 'VERSION')
+ add_line_to_header_dict(header_info, lines_header[2], 'FIELDS')
+ add_line_to_header_dict(header_info, lines_header[3], 'SIZE')
+ add_line_to_header_dict(header_info, lines_header[4], 'TYPE')
+ add_line_to_header_dict(header_info, lines_header[5], 'COUNT')
+ add_line_to_header_dict(header_info, lines_header[6], 'WIDTH')
+ add_line_to_header_dict(header_info, lines_header[7], 'HEIGHT')
+ add_line_to_header_dict(header_info, lines_header[8], 'VIEWPOINT')
+ add_line_to_header_dict(header_info, lines_header[9], 'POINTS')
+ add_line_to_header_dict(header_info, lines_header[10], 'DATA')
+
+ assert header_info['VERSION'] == '0.7'
+ assert header_info['FIELDS'] == 'x y z rgb label'
+ assert header_info['SIZE'] == '4 4 4 4 4'
+ assert header_info['TYPE'] == 'F F F F U'
+ assert header_info['COUNT'] == '1 1 1 1 1'
+ # assert header_info['HEIGHT'] == '1'
+ assert header_info['DATA'] == 'ascii'
+ # assert header_info['WIDTH'] == header_info['POINTS']
+
+ return header_info
+
+ f = open(file_in, 'r')
+ f_lines = f.readlines()
+ f_lines_header = f_lines[:11]
+ f_lines_points = f_lines[11:]
+ header_info = read_header(f_lines_header)
+ header_info['_file_'] = file_in
+
+ num_points = int(header_info['POINTS'])
+ point_data_list_str_ = [l.split(sep=' ')[:3] for l in f_lines_points]
+ point_data_list = [[float(l[0]), float(l[1]), float(l[2])] for l in point_data_list_str_]
+
+ # filter nan points that appear through the blensor kinect sensor
+ point_data_list = [p for p in point_data_list if
+ (not np.isnan(p[0]) and not np.isnan(p[1]) and not np.isnan(p[2]))]
+
+ point_data = np.array(point_data_list)
+
+ f.close()
+
+ return point_data, header_info
+
+
+def numpy_to_ply(npy_file_in: str, ply_file_out: str):
+ pts_in = np.load(npy_file_in)
+
+ if pts_in.shape[1] >= 6:
+ normals = pts_in[:, 3:6]
+ else:
+ normals = None
+
+ if pts_in.shape[1] >= 9:
+ colors = pts_in[:, 6:9]
+ else:
+ colors = None
+
+ write_ply(file_path=ply_file_out, points=pts_in[:, :3], normals=normals, colors=colors)
+
+
+def sample_mesh(mesh_file, num_samples, rejection_radius=None):
+ import trimesh.sample
+
+ try:
+ mesh = trimesh.load(mesh_file)
+ except:
+ return np.zeros((0, 3))
+ samples, face_indices = trimesh.sample.sample_surface_even(mesh, num_samples, rejection_radius)
+ return samples
+
+
+if __name__ == '__main__':
+ # convert all datasets to ply
+ import os
+ from source.base import mp
+
+ datasets = [
+ 'abc_train',
+ # 'abc',
+ # 'abc_extra_noisy',
+ # 'abc_noisefree',
+ # 'famous_noisefree',
+ # 'famous_original',
+ # 'famous_extra_noisy',
+ # 'famous_sparse',
+ # 'famous_dense',
+ # 'thingi10k_scans_original',
+ # 'thingi10k_scans_dense',
+ # 'thingi10k_scans_sparse',
+ # 'thingi10k_scans_extra_noisy',
+ # 'thingi10k_scans_noisefree',
+ ]
+
+ # test on dir, multi-threaded
+ # num_processes = 0
+ # num_processes = 4
+ num_processes = 15
+ # num_processes = 48
+
+ for dataset in datasets:
+ in_dir = r'D:\repos\meshnet\datasets\{}\04_pts'.format(dataset)
+ in_files = os.listdir(in_dir)
+ in_files = [os.path.join(in_dir, f) for f in in_files if
+ os.path.isfile(os.path.join(in_dir, f)) and f.endswith('.npy')]
+ out_dir = in_dir + '_vis'
+ calls = []
+ for fi, f in enumerate(in_files):
+ file_base_name = os.path.basename(f)
+ file_out = os.path.join(out_dir, file_base_name[:-4] + '.ply')
+ if fs.call_necessary(f, file_out):
+ calls.append([f, file_out])
+ mp.start_process_pool(numpy_to_ply, calls, num_processes)
+
diff --git a/ppsurf/source/base/profiling.py b/ppsurf/source/base/profiling.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4a4a8ee581d5323c8d8dc02cdde817ebf2730af
--- /dev/null
+++ b/ppsurf/source/base/profiling.py
@@ -0,0 +1,77 @@
+import tracemalloc
+import linecache
+import typing
+
+
+def init_profiling():
+ tracemalloc.start()
+
+
+def compare_snaps(snap1, snap2, limit=50):
+ top_stats = snap1.compare_to(snap2, 'lineno')
+
+ for stat in top_stats[:limit]:
+ line = str(stat)
+ # if '~/' in line: # filter only lines from my own code
+ print(line)
+
+
+def display_top(snapshot: typing.Union[tracemalloc.Snapshot, None], key_type='lineno', limit=10):
+ if snapshot is None:
+ snapshot = tracemalloc.take_snapshot()
+
+ snapshot = snapshot.filter_traces((
+ tracemalloc.Filter(False, ''),
+ tracemalloc.Filter(False, ''),
+ ))
+ top_stats = snapshot.statistics(key_type)
+
+ print('Top %s lines' % limit)
+ for index, stat in enumerate(top_stats[:limit], 1):
+ frame = stat.traceback[0]
+ print('#%s: %s:%s: %.1f KiB'
+ % (index, frame.filename, frame.lineno, stat.size / 1024))
+ line = linecache.getline(frame.filename, frame.lineno).strip()
+ if line:
+ print(' %s' % line)
+
+ other = top_stats[limit:]
+ if other:
+ size = sum(stat.size for stat in other)
+ print('%s other: %.1f KiB' % (len(other), size / 1024))
+ total = sum(stat.size for stat in top_stats)
+ print('Total allocated size: %.1f KiB' % (total / 1024))
+
+
+def print_duration(func, params: dict, name: str):
+ import time
+ start = time.time()
+ res = func(**params)
+ end = time.time()
+ print('{} took: {}'.format(name, end - start))
+ return res
+
+
+def print_memory(min_num_bytes=0):
+ import sys
+ import gc
+
+ objects = gc.get_objects()
+
+ objects_sizes = dict()
+ for obj_id, obj in enumerate(objects):
+ num_bytes = sys.getsizeof(obj)
+ if num_bytes >= min_num_bytes:
+ name = str(type(obj)) + str(obj_id)
+ objects_sizes[name] = num_bytes
+
+ objects_sizes_sorted = dict(sorted(objects_sizes.items(), key=lambda item: item[1], reverse=True))
+ print('Objects in scope:')
+ for name, num_bytes in objects_sizes_sorted.items():
+ print('{}: {} kB'.format(name, num_bytes / 1024))
+ print('')
+
+
+def get_now_str():
+ import datetime
+ return str(datetime.datetime.now())
diff --git a/ppsurf/source/base/proximity.py b/ppsurf/source/base/proximity.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd359f346611dc699d4845c34a99d94ee545d510
--- /dev/null
+++ b/ppsurf/source/base/proximity.py
@@ -0,0 +1,89 @@
+import typing
+
+import numpy as np
+import trimesh
+from torch.jit import ignore
+from pysdf import SDF
+from scipy.spatial import KDTree as ScipyKDTree
+from pykdtree.kdtree import KDTree as PyKDTree
+
+
+def get_signed_distance_pysdf_inaccurate(in_mesh: trimesh.Trimesh, query_pts_ms: np.ndarray):
+ # pysdf is inaccurate up to +-0.1 of BB
+ # this is not good enough for P2S at resolution > 64
+ # but the sign is correct
+ sdf = SDF(in_mesh.vertices, in_mesh.faces)
+ dists_ms = sdf(query_pts_ms)
+ return dists_ms
+
+
+def get_closest_point_on_mesh(mesh: trimesh.Trimesh, query_pts, batch_size=1000):
+
+ import trimesh.proximity as prox
+
+ # process batches because trimesh's closest_point very inefficient on memory, similar to signed_distance()
+ closest_pts = np.zeros((query_pts.shape[0], 3), dtype=np.float32)
+ dists = np.zeros(query_pts.shape[0], dtype=np.float32)
+ tri_ids = np.zeros(query_pts.shape[0], dtype=np.int32)
+ pts_ids = np.arange(query_pts.shape[0])
+ pts_ids_split = np.array_split(pts_ids, max(1, int(query_pts.shape[0] / batch_size)))
+ for pts_ids_batch in pts_ids_split:
+ query_pts_batch = query_pts[pts_ids_batch]
+ closest_pts_batch, dist_batch, tri_id_batch = prox.closest_point(mesh=mesh, points=query_pts_batch)
+ closest_pts[pts_ids_batch] = closest_pts_batch
+ dists[pts_ids_batch] = dist_batch
+ tri_ids[pts_ids_batch] = tri_id_batch
+
+ return closest_pts, dists, tri_ids
+
+
+def make_kdtree(pts: np.ndarray):
+
+ # old reliable
+ def _make_kdtree_scipy(pts_np: np.ndarray):
+ # otherwise KDTree construction may run out of recursions
+ import sys
+ leaf_size = 1000
+ sys.setrecursionlimit(int(max(1000, round(pts_np.shape[0] / leaf_size))))
+ _kdtree = ScipyKDTree(pts_np, leaf_size)
+ return _kdtree
+
+ # a lot slower than scipy
+ # def _make_kdtree_sklearn(pts: np.ndarray):
+ # from sklearn.neighbors import NearestNeighbors
+ # nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto', n_jobs=workers).fit(pts)
+ # # indices_batch: np.ndarray = nbrs.kneighbors(pts_query_np, return_distance=False)
+
+ # fastest even without multiprocessing
+ def _make_kdtree_pykdtree(pts_np: np.ndarray):
+ _kdtree = PyKDTree(pts_np, leafsize=10)
+ return _kdtree
+
+ # kdtree = _make_kdtree_scipy(pts)
+ kdtree = _make_kdtree_pykdtree(pts)
+ return kdtree
+
+
+def query_kdtree(kdtree: typing.Union[ScipyKDTree, PyKDTree],
+ pts_query: np.ndarray, k: int, sqr_dists=False, **kwargs):
+ # sqr_dists: some speed-up if True but distorted distances
+
+ if isinstance(kdtree, ScipyKDTree):
+ kdtree = typing.cast(ScipyKDTree, kdtree)
+ nn_dists, nn_ids = kdtree.query(x=pts_query, k=k, workers=kwargs.get('workers', -1))
+ if not sqr_dists:
+ nn_dists = nn_dists ** 2
+ elif isinstance(kdtree, PyKDTree):
+ kdtree = typing.cast(PyKDTree, kdtree)
+ nn_dists, nn_ids = kdtree.query(pts_query, k=k, sqr_dists=sqr_dists)
+ else:
+ raise NotImplementedError('Unknown kdtree type: {}'.format(type(kdtree)))
+ return nn_dists, nn_ids
+
+
+@ignore # can't compile kdtree
+def kdtree_query_oneshot(pts: np.ndarray, pts_query: np.ndarray, k: int, sqr_dists=False, **kwargs):
+ # sqr_dists: some speed-up if True but distorted distances
+ kdtree = make_kdtree(pts)
+ nn_dists, nn_ids = query_kdtree(kdtree=kdtree, pts_query=pts_query, k=k, sqr_dists=sqr_dists, **kwargs)
+ return nn_dists, nn_ids
diff --git a/ppsurf/source/base/visualization.py b/ppsurf/source/base/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..312e2eb76b47dee2ef9d0a16a89f2d97a71794c3
--- /dev/null
+++ b/ppsurf/source/base/visualization.py
@@ -0,0 +1,394 @@
+import typing
+
+import numpy as np
+
+from source.base import fs
+
+
+def plot_pts_scalar_data(pts: np.ndarray, data: np.ndarray, file_path: str, prop_min: float, prop_max: float,
+ color_channel=0):
+ from source.base import point_cloud
+
+ prop = np.nan_to_num(data.astype(np.float))
+ # [prop_min, prop_max] -> [0, 1]
+ prop[prop < prop_min] = prop_min
+ prop[prop > prop_max] = prop_max
+ prop -= prop_min
+ prop /= prop_max - prop_min
+
+ colors = np.zeros(pts.shape, dtype=np.float)
+ colors[:, color_channel] = prop.flatten()
+
+ point_cloud.write_ply(file_path=file_path, points=pts, colors=colors)
+
+
+def render_scene(mesh_file: str, rendering_file: str):
+ import os
+ import typing
+ import trimesh
+ import pyglet
+ pyglet.options["headless"] = True
+
+ if not os.path.isfile(mesh_file):
+ print('Rendering failed, file not found: ' + mesh_file)
+ return
+
+ # trimesh.util.attach_to_log() # print logged messages
+
+ if mesh_file.endswith('.npy'): # assume point cloud
+ vertices = np.load(mesh_file)
+ mesh: typing.Union[trimesh.Trimesh, trimesh.PointCloud] = trimesh.PointCloud(vertices)
+ else:
+ mesh: typing.Union[trimesh.Trimesh, trimesh.PointCloud] = trimesh.load(file_obj=mesh_file)
+
+ scene = mesh.scene()
+
+ img_bytes = None
+ while img_bytes is None:
+ try:
+ scene.set_camera(angles=(np.pi * 0.25, np.pi * 0.25, 0.0), distance=2.2, fov=(45, 45))
+ img_bytes = scene.save_image(resolution=(1024, 1024), visible=True)
+ except pyglet.canvas.xlib.NoSuchDisplayException as _:
+ img_bytes = bytes([0]) # Pyglet can't render without real screen attached -> will always fail on servers by default
+ except BaseException as E:
+ print('ERROR rendering {} to {}: {}'.format(mesh_file, rendering_file, str(E)))
+
+ # try again after waiting for window to have opened fully (hopefully)
+ import time
+ time.sleep(1.0)
+
+ if img_bytes is not None:
+ fs.make_dir_for_file(rendering_file)
+ with open(rendering_file, 'wb') as text_file:
+ text_file.write(img_bytes)
+
+
+def distances_to_vertex_colors(dist_per_vertex: np.ndarray, cut_off=0.3):
+
+ dist_per_vertex[dist_per_vertex > cut_off] = cut_off
+ dist_per_vertex /= cut_off
+
+ # use parula colormap: dist=0 -> blue, dist=0.5 -> green, dist=1.0 -> yellow
+ parulas_indices = (dist_per_vertex * (parula_cm.shape[0] - 1)).astype(np.int32)
+ dist_greater_than_norm_target = parulas_indices >= parula_cm.shape[0]
+ parulas_indices[dist_greater_than_norm_target] = parula_cm.shape[0] - 1
+ dist_colors_rgb = [parula_cm[parula_indices] for parula_indices in parulas_indices]
+
+ return dist_colors_rgb
+
+
+def visualize_chamfer_distance(
+ input_mesh_file: str, reference_mesh_file: str, output_mesh_file: str,
+ min_vertex_count: typing.Union[int, None], dist_cut_off=0.3, distance_batch_size=1000):
+
+ import trimesh
+ from trimesh.base import Trimesh
+ from trimesh.visual.color import VertexColor
+
+ from source.base import proximity
+
+ in_mesh: Trimesh = trimesh.load(input_mesh_file)
+ ref_mesh = trimesh.load(reference_mesh_file)
+
+ if min_vertex_count is not None:
+ while in_mesh.vertices.shape[0] < min_vertex_count:
+ in_mesh = in_mesh.subdivide()
+
+ closest_pts, dist_rec_verts_to_ref, tri_id = proximity.get_closest_point_on_mesh(
+ mesh=ref_mesh, query_pts=in_mesh.vertices, batch_size=int(distance_batch_size))
+ vertex_colors = distances_to_vertex_colors(dist_rec_verts_to_ref, float(dist_cut_off))
+
+ in_mesh.visual = VertexColor(vertex_colors)
+ fs.make_dir_for_file(output_mesh_file)
+ in_mesh.export(output_mesh_file)
+
+
+def visualize_chamfer_distance_pool(
+ rec_meshes: typing.Sequence[str], gt_meshes: typing.Sequence[str], output_mesh_files: typing.Sequence[str],
+ min_vertex_count=10000, dist_cut_off=0.3, distance_batch_size=1000, num_processes=0):
+
+ from source.base.mp import start_process_pool
+
+ assert(len(rec_meshes) == len(gt_meshes))
+
+ cd_vis_params = [(rec_meshes[i], gt_meshes[i], output_mesh_files[i],
+ min_vertex_count, dist_cut_off, distance_batch_size)
+ for i in range(len(rec_meshes))
+ if fs.call_necessary([rec_meshes[i], gt_meshes[i]], output_mesh_files[i])]
+ start_process_pool(worker_function=visualize_chamfer_distance,
+ parameters=cd_vis_params, num_processes=num_processes)
+
+
+def render_meshes(all_meshes_in, all_renders_out, workers=1):
+ from source.base.mp import start_process_pool
+
+ def render_meshes_pool(mesh_paths_in: typing.Sequence[str], image_paths_out: typing.Sequence[str]):
+ assert (len(mesh_paths_in) == len(image_paths_out))
+ render_scene_params = [(mesh_paths_in[i], image_paths_out[i])
+ for i in range(len(mesh_paths_in))
+ if fs.call_necessary(mesh_paths_in[i], image_paths_out[i])]
+ start_process_pool(worker_function=render_scene,
+ parameters=render_scene_params, num_processes=workers)
+
+ # do rendering
+ render_meshes_pool(all_meshes_in, all_renders_out)
+
+
+parula_cm = np.array([
+ [0.2422, 0.1504, 0.6603],
+ [0.2444, 0.1534, 0.6728],
+ [0.2464, 0.1569, 0.6847],
+ [0.2484, 0.1607, 0.6961],
+ [0.2503, 0.1648, 0.7071],
+ [0.2522, 0.1689, 0.7179],
+ [0.254, 0.1732, 0.7286],
+ [0.2558, 0.1773, 0.7393],
+ [0.2576, 0.1814, 0.7501],
+ [0.2594, 0.1854, 0.761],
+ [0.2611, 0.1893, 0.7719],
+ [0.2628, 0.1932, 0.7828],
+ [0.2645, 0.1972, 0.7937],
+ [0.2661, 0.2011, 0.8043],
+ [0.2676, 0.2052, 0.8148],
+ [0.2691, 0.2094, 0.8249],
+ [0.2704, 0.2138, 0.8346],
+ [0.2717, 0.2184, 0.8439],
+ [0.2729, 0.2231, 0.8528],
+ [0.274, 0.228, 0.8612],
+ [0.2749, 0.233, 0.8692],
+ [0.2758, 0.2382, 0.8767],
+ [0.2766, 0.2435, 0.884],
+ [0.2774, 0.2489, 0.8908],
+ [0.2781, 0.2543, 0.8973],
+ [0.2788, 0.2598, 0.9035],
+ [0.2794, 0.2653, 0.9094],
+ [0.2798, 0.2708, 0.915],
+ [0.2802, 0.2764, 0.9204],
+ [0.2806, 0.2819, 0.9255],
+ [0.2809, 0.2875, 0.9305],
+ [0.2811, 0.293, 0.9352],
+ [0.2813, 0.2985, 0.9397],
+ [0.2814, 0.304, 0.9441],
+ [0.2814, 0.3095, 0.9483],
+ [0.2813, 0.315, 0.9524],
+ [0.2811, 0.3204, 0.9563],
+ [0.2809, 0.3259, 0.96],
+ [0.2807, 0.3313, 0.9636],
+ [0.2803, 0.3367, 0.967],
+ [0.2798, 0.3421, 0.9702],
+ [0.2791, 0.3475, 0.9733],
+ [0.2784, 0.3529, 0.9763],
+ [0.2776, 0.3583, 0.9791],
+ [0.2766, 0.3638, 0.9817],
+ [0.2754, 0.3693, 0.984],
+ [0.2741, 0.3748, 0.9862],
+ [0.2726, 0.3804, 0.9881],
+ [0.271, 0.386, 0.9898],
+ [0.2691, 0.3916, 0.9912],
+ [0.267, 0.3973, 0.9924],
+ [0.2647, 0.403, 0.9935],
+ [0.2621, 0.4088, 0.9946],
+ [0.2591, 0.4145, 0.9955],
+ [0.2556, 0.4203, 0.9965],
+ [0.2517, 0.4261, 0.9974],
+ [0.2473, 0.4319, 0.9983],
+ [0.2424, 0.4378, 0.9991],
+ [0.2369, 0.4437, 0.9996],
+ [0.2311, 0.4497, 0.9995],
+ [0.225, 0.4559, 0.9985],
+ [0.2189, 0.462, 0.9968],
+ [0.2128, 0.4682, 0.9948],
+ [0.2066, 0.4743, 0.9926],
+ [0.2006, 0.4803, 0.9906],
+ [0.195, 0.4861, 0.9887],
+ [0.1903, 0.4919, 0.9867],
+ [0.1869, 0.4975, 0.9844],
+ [0.1847, 0.503, 0.9819],
+ [0.1831, 0.5084, 0.9793],
+ [0.1818, 0.5138, 0.9766],
+ [0.1806, 0.5191, 0.9738],
+ [0.1795, 0.5244, 0.9709],
+ [0.1785, 0.5296, 0.9677],
+ [0.1778, 0.5349, 0.9641],
+ [0.1773, 0.5401, 0.9602],
+ [0.1768, 0.5452, 0.956],
+ [0.1764, 0.5504, 0.9516],
+ [0.1755, 0.5554, 0.9473],
+ [0.174, 0.5605, 0.9432],
+ [0.1716, 0.5655, 0.9393],
+ [0.1686, 0.5705, 0.9357],
+ [0.1649, 0.5755, 0.9323],
+ [0.161, 0.5805, 0.9289],
+ [0.1573, 0.5854, 0.9254],
+ [0.154, 0.5902, 0.9218],
+ [0.1513, 0.595, 0.9182],
+ [0.1492, 0.5997, 0.9147],
+ [0.1475, 0.6043, 0.9113],
+ [0.1461, 0.6089, 0.908],
+ [0.1446, 0.6135, 0.905],
+ [0.1429, 0.618, 0.9022],
+ [0.1408, 0.6226, 0.8998],
+ [0.1383, 0.6272, 0.8975],
+ [0.1354, 0.6317, 0.8953],
+ [0.1321, 0.6363, 0.8932],
+ [0.1288, 0.6408, 0.891],
+ [0.1253, 0.6453, 0.8887],
+ [0.1219, 0.6497, 0.8862],
+ [0.1185, 0.6541, 0.8834],
+ [0.1152, 0.6584, 0.8804],
+ [0.1119, 0.6627, 0.877],
+ [0.1085, 0.6669, 0.8734],
+ [0.1048, 0.671, 0.8695],
+ [0.1009, 0.675, 0.8653],
+ [0.0964, 0.6789, 0.8609],
+ [0.0914, 0.6828, 0.8562],
+ [0.0855, 0.6865, 0.8513],
+ [0.0789, 0.6902, 0.8462],
+ [0.0713, 0.6938, 0.8409],
+ [0.0628, 0.6972, 0.8355],
+ [0.0535, 0.7006, 0.8299],
+ [0.0433, 0.7039, 0.8242],
+ [0.0328, 0.7071, 0.8183],
+ [0.0234, 0.7103, 0.8124],
+ [0.0155, 0.7133, 0.8064],
+ [0.0091, 0.7163, 0.8003],
+ [0.0046, 0.7192, 0.7941],
+ [0.0019, 0.722, 0.7878],
+ [0.0009, 0.7248, 0.7815],
+ [0.0018, 0.7275, 0.7752],
+ [0.0046, 0.7301, 0.7688],
+ [0.0094, 0.7327, 0.7623],
+ [0.0162, 0.7352, 0.7558],
+ [0.0253, 0.7376, 0.7492],
+ [0.0369, 0.74, 0.7426],
+ [0.0504, 0.7423, 0.7359],
+ [0.0638, 0.7446, 0.7292],
+ [0.077, 0.7468, 0.7224],
+ [0.0899, 0.7489, 0.7156],
+ [0.1023, 0.751, 0.7088],
+ [0.1141, 0.7531, 0.7019],
+ [0.1252, 0.7552, 0.695],
+ [0.1354, 0.7572, 0.6881],
+ [0.1448, 0.7593, 0.6812],
+ [0.1532, 0.7614, 0.6741],
+ [0.1609, 0.7635, 0.6671],
+ [0.1678, 0.7656, 0.6599],
+ [0.1741, 0.7678, 0.6527],
+ [0.1799, 0.7699, 0.6454],
+ [0.1853, 0.7721, 0.6379],
+ [0.1905, 0.7743, 0.6303],
+ [0.1954, 0.7765, 0.6225],
+ [0.2003, 0.7787, 0.6146],
+ [0.2061, 0.7808, 0.6065],
+ [0.2118, 0.7828, 0.5983],
+ [0.2178, 0.7849, 0.5899],
+ [0.2244, 0.7869, 0.5813],
+ [0.2318, 0.7887, 0.5725],
+ [0.2401, 0.7905, 0.5636],
+ [0.2491, 0.7922, 0.5546],
+ [0.2589, 0.7937, 0.5454],
+ [0.2695, 0.7951, 0.536],
+ [0.2809, 0.7964, 0.5266],
+ [0.2929, 0.7975, 0.517],
+ [0.3052, 0.7985, 0.5074],
+ [0.3176, 0.7994, 0.4975],
+ [0.3301, 0.8002, 0.4876],
+ [0.3424, 0.8009, 0.4774],
+ [0.3548, 0.8016, 0.4669],
+ [0.3671, 0.8021, 0.4563],
+ [0.3795, 0.8026, 0.4454],
+ [0.3921, 0.8029, 0.4344],
+ [0.405, 0.8031, 0.4233],
+ [0.4184, 0.803, 0.4122],
+ [0.4322, 0.8028, 0.4013],
+ [0.4463, 0.8024, 0.3904],
+ [0.4608, 0.8018, 0.3797],
+ [0.4753, 0.8011, 0.3691],
+ [0.4899, 0.8002, 0.3586],
+ [0.5044, 0.7993, 0.348],
+ [0.5187, 0.7982, 0.3374],
+ [0.5329, 0.797, 0.3267],
+ [0.547, 0.7957, 0.3159],
+ [0.5609, 0.7943, 0.305],
+ [0.5748, 0.7929, 0.2941],
+ [0.5886, 0.7913, 0.2833],
+ [0.6024, 0.7896, 0.2726],
+ [0.6161, 0.7878, 0.2622],
+ [0.6297, 0.7859, 0.2521],
+ [0.6433, 0.7839, 0.2423],
+ [0.6567, 0.7818, 0.2329],
+ [0.6701, 0.7796, 0.2239],
+ [0.6833, 0.7773, 0.2155],
+ [0.6963, 0.775, 0.2075],
+ [0.7091, 0.7727, 0.1998],
+ [0.7218, 0.7703, 0.1924],
+ [0.7344, 0.7679, 0.1852],
+ [0.7468, 0.7654, 0.1782],
+ [0.759, 0.7629, 0.1717],
+ [0.771, 0.7604, 0.1658],
+ [0.7829, 0.7579, 0.1608],
+ [0.7945, 0.7554, 0.157],
+ [0.806, 0.7529, 0.1546],
+ [0.8172, 0.7505, 0.1535],
+ [0.8281, 0.7481, 0.1536],
+ [0.8389, 0.7457, 0.1546],
+ [0.8495, 0.7435, 0.1564],
+ [0.86, 0.7413, 0.1587],
+ [0.8703, 0.7392, 0.1615],
+ [0.8804, 0.7372, 0.165],
+ [0.8903, 0.7353, 0.1695],
+ [0.9, 0.7336, 0.1749],
+ [0.9093, 0.7321, 0.1815],
+ [0.9184, 0.7308, 0.189],
+ [0.9272, 0.7298, 0.1973],
+ [0.9357, 0.729, 0.2061],
+ [0.944, 0.7285, 0.2151],
+ [0.9523, 0.7284, 0.2237],
+ [0.9606, 0.7285, 0.2312],
+ [0.9689, 0.7292, 0.2373],
+ [0.977, 0.7304, 0.2418],
+ [0.9842, 0.733, 0.2446],
+ [0.99, 0.7365, 0.2429],
+ [0.9946, 0.7407, 0.2394],
+ [0.9966, 0.7458, 0.2351],
+ [0.9971, 0.7513, 0.2309],
+ [0.9972, 0.7569, 0.2267],
+ [0.9971, 0.7626, 0.2224],
+ [0.9969, 0.7683, 0.2181],
+ [0.9966, 0.774, 0.2138],
+ [0.9962, 0.7798, 0.2095],
+ [0.9957, 0.7856, 0.2053],
+ [0.9949, 0.7915, 0.2012],
+ [0.9938, 0.7974, 0.1974],
+ [0.9923, 0.8034, 0.1939],
+ [0.9906, 0.8095, 0.1906],
+ [0.9885, 0.8156, 0.1875],
+ [0.9861, 0.8218, 0.1846],
+ [0.9835, 0.828, 0.1817],
+ [0.9807, 0.8342, 0.1787],
+ [0.9778, 0.8404, 0.1757],
+ [0.9748, 0.8467, 0.1726],
+ [0.972, 0.8529, 0.1695],
+ [0.9694, 0.8591, 0.1665],
+ [0.9671, 0.8654, 0.1636],
+ [0.9651, 0.8716, 0.1608],
+ [0.9634, 0.8778, 0.1582],
+ [0.9619, 0.884, 0.1557],
+ [0.9608, 0.8902, 0.1532],
+ [0.9601, 0.8963, 0.1507],
+ [0.9596, 0.9023, 0.148],
+ [0.9595, 0.9084, 0.145],
+ [0.9597, 0.9143, 0.1418],
+ [0.9601, 0.9203, 0.1382],
+ [0.9608, 0.9262, 0.1344],
+ [0.9618, 0.932, 0.1304],
+ [0.9629, 0.9379, 0.1261],
+ [0.9642, 0.9437, 0.1216],
+ [0.9657, 0.9494, 0.1168],
+ [0.9674, 0.9552, 0.1116],
+ [0.9692, 0.9609, 0.1061],
+ [0.9711, 0.9667, 0.1001],
+ [0.973, 0.9724, 0.0938],
+ [0.9749, 0.9782, 0.0872],
+ [0.9769, 0.9839, 0.0805]
+])
diff --git a/ppsurf/source/cli.py b/ppsurf/source/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..19bb6b091511a36aab0fabeac5c90aea5934139f
--- /dev/null
+++ b/ppsurf/source/cli.py
@@ -0,0 +1,118 @@
+import os
+import sys
+import typing
+from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING, Union, Set
+from pathlib import Path
+import abc
+
+from tqdm import tqdm
+
+import pytorch_lightning.profilers
+from pytorch_lightning.cli import LightningCLI, Namespace, LightningArgumentParser
+from pytorch_lightning.callbacks import TQDMProgressBar
+
+from source.base.profiling import get_now_str
+
+
+class PPSProgressBar(TQDMProgressBar): # disable validation prog bar
+ def init_validation_tqdm(self):
+ bar_disabled = tqdm(disable=True)
+ return bar_disabled
+
+
+class PPSProfiler(pytorch_lightning.profilers.PyTorchProfiler):
+ def __init__(
+ self,
+ dirpath: Optional[Union[str, Path]] = None,
+ filename: Optional[str] = None,
+ group_by_input_shapes: bool = False,
+ emit_nvtx: bool = False,
+ export_to_chrome: bool = True,
+ row_limit: int = 20,
+ sort_by_key: Optional[str] = None,
+ record_module_names: bool = True,
+ with_stack: bool = False,
+ **profiler_kwargs: Any,
+ ) -> None:
+ super().__init__(dirpath=dirpath, filename=filename, group_by_input_shapes=group_by_input_shapes,
+ emit_nvtx=emit_nvtx, export_to_chrome=export_to_chrome, row_limit=row_limit,
+ sort_by_key=sort_by_key, record_module_names=record_module_names, with_stack=with_stack,
+ **profiler_kwargs)
+
+
+class CLI(LightningCLI):
+ def __init__(self, model_class, subclass_mode_model, datamodule_class, subclass_mode_data):
+ print('{}: Starting {}'.format(get_now_str(), ' '.join(sys.argv)))
+ sys.argv = self.handle_rec_subcommand(sys.argv) # only call this with args from system command line
+ super().__init__(
+ model_class=model_class, subclass_mode_model=subclass_mode_model,
+ datamodule_class=datamodule_class, subclass_mode_data=subclass_mode_data,
+ save_config_kwargs={'overwrite': True})
+ print('{}: Finished {}'.format(get_now_str(), ' '.join(sys.argv)))
+
+ def cur_config(self) -> Namespace:
+ return self.config[self.config.subcommand]
+
+ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+ # fundamentals
+ parser.add_argument('--debug', type=bool, default=False,
+ help='set to True if you want debug outputs to validate the model')
+
+ @abc.abstractmethod
+ def handle_rec_subcommand(self, args: typing.List[str]) -> typing.List[str]:
+ """
+ Replace rec subcommand with predict and its default parameters before any argparse.
+ Args:
+ args: typing.List[str]
+
+ Returns:
+ new_args: typing.List[str]
+ """
+ pass
+
+ # def before_fit(self):
+ # pass
+ #
+ # def after_fit(self):
+ # pass
+ #
+ # def before_predict(self):
+ # pass
+ #
+ # def after_predict(self):
+ # pass
+
+ def before_instantiate_classes(self):
+ import torch
+ # torch.set_float32_matmul_precision('medium') # PPSurf 50NN: 5.123h, ABC CD 0.012920511
+ torch.set_float32_matmul_precision('high') # PPSurf 50NN: xh, ABC CD y
+ # torch.set_float32_matmul_precision('highest') # PPSurf 50NN: xh, ABC CD y
+
+ if bool(self.cur_config().debug):
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+ os.environ['TORCH_DISTRIBUTED_DEBUG '] = '1'
+
+ self.cur_config().trainer.detect_anomaly = True
+
+ # def instantiate_classes(self):
+ # pass
+
+ # def instantiate_trainer(self):
+ # pass
+
+ # def parse_arguments(self, parser, args):
+ # pass
+
+ # def setup_parser(self, add_subcommands, main_kwargs, subparser_kwargs):
+ # pass
+
+ @staticmethod
+ def subcommands() -> Dict[str, Set[str]]:
+ """Defines the list of available subcommands and the arguments to skip."""
+ return {
+ 'fit': {'model', 'train_dataloaders', 'val_dataloaders', 'datamodule'},
+ # 'validate': {'model', 'dataloaders', 'datamodule'}, # no val for this
+ 'test': {'model', 'dataloaders', 'datamodule'},
+ 'predict': {'model', 'dataloaders', 'datamodule'},
+ # 'tune': {'model', 'train_dataloaders', 'val_dataloaders', 'datamodule'},
+ }
diff --git a/ppsurf/source/figures/__init__.py b/ppsurf/source/figures/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ppsurf/source/figures/cmap_YlOrRd.npy b/ppsurf/source/figures/cmap_YlOrRd.npy
new file mode 100644
index 0000000000000000000000000000000000000000..50bfed513350ef86f21bd73fd87ec21f5640c806
--- /dev/null
+++ b/ppsurf/source/figures/cmap_YlOrRd.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b16fd2b27bed54dd85f847a6f6695cd822a1bec4649fa77272376cc9889b0a1a
+size 6272
diff --git a/ppsurf/source/figures/comp_ablation_abc_maxnoise.py b/ppsurf/source/figures/comp_ablation_abc_maxnoise.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bdf2982b53c15a0c189b8200a1fe5d10fc1cf41
--- /dev/null
+++ b/ppsurf/source/figures/comp_ablation_abc_maxnoise.py
@@ -0,0 +1,61 @@
+import os
+import sys
+sys.path.append(os.path.abspath('.'))
+
+from source import make_comparison
+
+if __name__ == '__main__':
+
+ workers = 15 # for training PC
+ # workers = 8 # for Windows missing fork
+ # workers = 4 # for strange window size bug
+ # workers = 0 # debug
+
+ comp_name = 'comp_ablation_abc_extra_noisy'
+ # comp_name = 'comp_ablation_abc_varnoise'
+ comp_dir = 'results/comp'
+
+ dataset = 'abc_extra_noisy'
+
+ params = [
+ '--comp_name', comp_name,
+ '--comp_dir', comp_dir,
+ '--comp_mean_name', 'comp_mean',
+ '--html_name', 'comp_all',
+ '--data_dir', 'datasets/' + dataset,
+ '--testset', 'testset.txt',
+ '--results_dir', 'results',
+ '--result_headers',
+ 'ppsurf_merge_cat',
+ 'ppsurf_vanilla_zeros_global',
+ 'ppsurf_vanilla_zeros_local',
+ 'ppsurf_vanilla_qpoints',
+ 'ppsurf_sym_max',
+ 'ppsurf_10nn',
+ 'p2s2_25nn',
+ 'p2s2_50nn',
+ 'p2s2_merge_sum',
+ 'p2s2_200nn',
+ '--result_paths',
+ r'results/p2s2_vanilla/' + dataset,
+ r'results/p2s2_vanilla_zeros_global/' + dataset,
+ r'results/p2s2_vanilla_zeros_local/' + dataset,
+ r'results/p2s2_vanilla_qpoints/' + dataset,
+ r'results/p2s2_sym_max/' + dataset,
+ r'results/p2s2_10nn/' + dataset,
+ r'results/p2s2_25nn/' + dataset,
+ r'results/p2s2_50nn/' + dataset,
+ r'results/p2s2_merge_sum/' + dataset,
+ r'results/p2s2_200nn/' + dataset,
+
+ '--workers', str(workers),
+ '--dist_cut_off', str(0.01),
+ ]
+ make_comparison.main(argv=params)
+
+ # Convert xlsx to latex
+ from source.base.evaluation import xslx_to_latex
+ ablation_xlsx = os.path.join(comp_dir, comp_name, 'comp_mean.xlsx')
+ xslx_to_latex(ablation_xlsx, ablation_xlsx[:-5] + '.tex')
+
+ print('Points2Surf2 Comparison is finished!')
diff --git a/ppsurf/source/figures/comp_ablation_abc_varnoise.py b/ppsurf/source/figures/comp_ablation_abc_varnoise.py
new file mode 100644
index 0000000000000000000000000000000000000000..36b1850d81417871d5e8970ff326066bf2169fd7
--- /dev/null
+++ b/ppsurf/source/figures/comp_ablation_abc_varnoise.py
@@ -0,0 +1,50 @@
+import os
+import sys
+sys.path.append(os.path.abspath('.'))
+
+from source import make_comparison
+
+if __name__ == '__main__':
+
+ workers = 15 # for training PC
+ # workers = 8 # for Windows missing fork
+ # workers = 4 # for strange window size bug
+ # workers = 0 # debug
+
+ comp_name = 'comp_ablation_abc_varnoise'
+ dataset = 'abc'
+
+ methods = [
+ 'ppsurf_vanilla',
+ 'ppsurf_vanilla_zeros_global',
+ 'ppsurf_vanilla_zeros_local',
+ 'ppsurf_vanilla_qpoints',
+ 'ppsurf_sym_max',
+ 'ppsurf_10nn',
+ 'ppsurf_25nn',
+ 'ppsurf_50nn',
+ 'ppsurf_merge_sum',
+ 'ppsurf_200nn',
+ ]
+
+ params = [
+ '--comp_name', dataset,
+ '--comp_dir', 'results/comp',
+ '--comp_mean_name', comp_name,
+ '--html_name', comp_name,
+ '--data_dir', 'datasets/' + dataset,
+ '--testset', 'testset.txt',
+ '--results_dir', 'results',
+
+ '--workers', str(workers),
+ '--dist_cut_off', str(0.01),
+
+ '--result_headers', *methods,
+ '--result_paths', *[r'results/{}/'.format(m) + dataset for m in methods],
+ ]
+ make_comparison.main(argv=params)
+
+ # Convert xlsx to latex
+ from source.base.evaluation import xslx_to_latex
+ ablation_xlsx = os.path.join('results', 'comp', dataset, comp_name + '.xlsx')
+ xslx_to_latex(ablation_xlsx, ablation_xlsx[:-5] + '.tex')
diff --git a/ppsurf/source/figures/comp_ablation_all.py b/ppsurf/source/figures/comp_ablation_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..d455181e6c681d09d3d0bfb5af7cbf4d5da66204
--- /dev/null
+++ b/ppsurf/source/figures/comp_ablation_all.py
@@ -0,0 +1,66 @@
+import os
+import sys
+sys.path.append(os.path.abspath('.'))
+
+from source.base.evaluation import merge_comps
+from source import make_comparison
+
+if __name__ == '__main__':
+ workers = 15 # for training PC
+ # workers = 8 # for Windows missing fork
+ # workers = 4 # for strange window size bug
+ # workers = 0 # debug
+
+ comp_name = 'comp_ablation_all'
+
+ datasets = [
+ 'abc',
+ 'abc_extra_noisy',
+ 'abc_noisefree',
+ 'famous_noisefree',
+ 'famous_original',
+ 'famous_extra_noisy',
+ 'famous_sparse',
+ 'famous_dense',
+ 'thingi10k_scans_original',
+ 'thingi10k_scans_dense',
+ 'thingi10k_scans_sparse',
+ 'thingi10k_scans_extra_noisy',
+ 'thingi10k_scans_noisefree',
+ ]
+
+ methods = [
+ 'ppsurf_25nn',
+ 'ppsurf_50nn',
+ 'ppsurf_vanilla',
+ 'ppsurf_merge_sum',
+ ]
+
+ # Run all comparisons
+ for dataset in datasets:
+ print('Running comparison for dataset {}'.format(dataset))
+ params = [
+ '--comp_name', dataset,
+ '--comp_dir', 'results/comp',
+ '--comp_mean_name', comp_name,
+ '--html_name', comp_name,
+ '--data_dir', 'datasets/' + dataset,
+ '--testset', 'testset.txt',
+ '--results_dir', 'results',
+
+ '--workers', str(workers),
+ '--dist_cut_off', str(0.01),
+
+ '--result_headers', *methods,
+ '--result_paths', *[r'results/{}/'.format(m) + dataset for m in methods],
+ ]
+ try:
+ make_comparison.main(argv=params)
+ except Exception as e:
+ print('Error in dataset {}: {}'.format(dataset, e))
+
+ # Merge all comparisons
+ comp_files = ['results/comp/{}/{}.xlsx'.format(dataset, comp_name) for dataset in datasets]
+ comp_merged_xlsx = 'results/comp/reports/{}.xlsx'.format(comp_name)
+ comp_merged_latex = 'results/comp/reports/{}.tex'.format(comp_name)
+ merge_comps(comp_files, comp_merged_xlsx, comp_merged_latex, methods_order=methods, float_format='%.3f')
diff --git a/ppsurf/source/figures/comp_ablation_dense.py b/ppsurf/source/figures/comp_ablation_dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b49a884c6aaae69f94d67d63748715f76cf5bf2
--- /dev/null
+++ b/ppsurf/source/figures/comp_ablation_dense.py
@@ -0,0 +1,55 @@
+import os
+import sys
+sys.path.append(os.path.abspath('.'))
+
+from source.base.evaluation import merge_comps
+from source import make_comparison
+
+if __name__ == '__main__':
+ workers = 15 # for training PC
+ # workers = 8 # for Windows missing fork
+ # workers = 4 # for strange window size bug
+ # workers = 0 # debug
+
+ comp_name = 'comp_ablation_dense'
+
+ datasets = [
+ 'famous_dense',
+ 'thingi10k_scans_dense',
+ ]
+
+ methods = [
+ 'ppsurf_25nn',
+ 'ppsurf_50nn',
+ 'ppsurf_vanilla',
+ 'ppsurf_merge_sum',
+ ]
+
+ # Run all comparisons
+ for dataset in datasets:
+ print('Running comparison for dataset {}'.format(dataset))
+ params = [
+ '--comp_name', dataset,
+ '--comp_dir', 'results/comp',
+ '--comp_mean_name', comp_name,
+ '--html_name', comp_name,
+ '--data_dir', 'datasets/' + dataset,
+ '--testset', 'testset.txt',
+ '--results_dir', 'results',
+
+ '--workers', str(workers),
+ '--dist_cut_off', str(0.01),
+
+ '--result_headers', *methods,
+ '--result_paths', *[r'results/{}/'.format(m) + dataset for m in methods],
+ ]
+ try:
+ make_comparison.main(argv=params)
+ except Exception as e:
+ print('Error in dataset {}: {}'.format(dataset, e))
+
+ # Merge all comparisons
+ comp_files = ['results/comp/{}/{}.xlsx'.format(dataset, comp_name) for dataset in datasets]
+ comp_merged_xlsx = 'results/comp/reports/{}.xlsx'.format(comp_name)
+ comp_merged_latex = 'results/comp/reports/{}.tex'.format(comp_name)
+ merge_comps(comp_files, comp_merged_xlsx, comp_merged_latex, methods_order=methods, float_format='%.3f')
diff --git a/ppsurf/source/figures/comp_ablation_noisefree.py b/ppsurf/source/figures/comp_ablation_noisefree.py
new file mode 100644
index 0000000000000000000000000000000000000000..98ef6f56607714f568c5d2e94ae8b9b8fc377ad7
--- /dev/null
+++ b/ppsurf/source/figures/comp_ablation_noisefree.py
@@ -0,0 +1,56 @@
+import os
+import sys
+sys.path.append(os.path.abspath('.'))
+
+from source.base.evaluation import merge_comps
+from source import make_comparison
+
+if __name__ == '__main__':
+ workers = 15 # for training PC
+ # workers = 8 # for Windows missing fork
+ # workers = 4 # for strange window size bug
+ # workers = 0 # debug
+
+ comp_name = 'comp_ablation_noisefree'
+
+ datasets = [
+ 'abc_noisefree',
+ 'famous_noisefree',
+ 'thingi10k_scans_noisefree',
+ ]
+
+ methods = [
+ 'ppsurf_25nn',
+ 'ppsurf_50nn',
+ 'ppsurf_vanilla',
+ 'ppsurf_merge_sum',
+ ]
+
+ # Run all comparisons
+ for dataset in datasets:
+ print('Running comparison for dataset {}'.format(dataset))
+ params = [
+ '--comp_name', dataset,
+ '--comp_dir', 'results/comp',
+ '--comp_mean_name', comp_name,
+ '--html_name', comp_name,
+ '--data_dir', 'datasets/' + dataset,
+ '--testset', 'testset.txt',
+ '--results_dir', 'results',
+
+ '--workers', str(workers),
+ '--dist_cut_off', str(0.01),
+
+ '--result_headers', *methods,
+ '--result_paths', *[r'results/{}/'.format(m) + dataset for m in methods],
+ ]
+ try:
+ make_comparison.main(argv=params)
+ except Exception as e:
+ print('Error in dataset {}: {}'.format(dataset, e))
+
+ # Merge all comparisons
+ comp_files = ['results/comp/{}/{}.xlsx'.format(dataset, comp_name) for dataset in datasets]
+ comp_merged_xlsx = 'results/comp/reports/{}.xlsx'.format(comp_name)
+ comp_merged_latex = 'results/comp/reports/{}.tex'.format(comp_name)
+ merge_comps(comp_files, comp_merged_xlsx, comp_merged_latex, methods_order=methods, float_format='%.3f')
diff --git a/ppsurf/source/figures/comp_ablation_noisy.py b/ppsurf/source/figures/comp_ablation_noisy.py
new file mode 100644
index 0000000000000000000000000000000000000000..34c6e4fb8d3609466537b26fa59e8cd10e7bd0e5
--- /dev/null
+++ b/ppsurf/source/figures/comp_ablation_noisy.py
@@ -0,0 +1,56 @@
+import os
+import sys
+sys.path.append(os.path.abspath('.'))
+
+from source.base.evaluation import merge_comps
+from source import make_comparison
+
+if __name__ == '__main__':
+ workers = 15 # for training PC
+ # workers = 8 # for Windows missing fork
+ # workers = 4 # for strange window size bug
+ # workers = 0 # debug
+
+ comp_name = 'comp_ablation_noisy'
+
+ datasets = [
+ 'abc_extra_noisy',
+ 'famous_extra_noisy',
+ 'thingi10k_scans_extra_noisy',
+ ]
+
+ methods = [
+ 'ppsurf_25nn',
+ 'ppsurf_50nn',
+ 'ppsurf_vanilla',
+ 'ppsurf_merge_sum',
+ ]
+
+ # Run all comparisons
+ for dataset in datasets:
+ print('Running comparison for dataset {}'.format(dataset))
+ params = [
+ '--comp_name', dataset,
+ '--comp_dir', 'results/comp',
+ '--comp_mean_name', comp_name,
+ '--html_name', comp_name,
+ '--data_dir', 'datasets/' + dataset,
+ '--testset', 'testset.txt',
+ '--results_dir', 'results',
+
+ '--workers', str(workers),
+ '--dist_cut_off', str(0.01),
+
+ '--result_headers', *methods,
+ '--result_paths', *[r'results/{}/'.format(m) + dataset for m in methods],
+ ]
+ try:
+ make_comparison.main(argv=params)
+ except Exception as e:
+ print('Error in dataset {}: {}'.format(dataset, e))
+
+ # Merge all comparisons
+ comp_files = ['results/comp/{}/{}.xlsx'.format(dataset, comp_name) for dataset in datasets]
+ comp_merged_xlsx = 'results/comp/reports/{}.xlsx'.format(comp_name)
+ comp_merged_latex = 'results/comp/reports/{}.tex'.format(comp_name)
+ merge_comps(comp_files, comp_merged_xlsx, comp_merged_latex, methods_order=methods, float_format='%.3f')
diff --git a/ppsurf/source/figures/comp_ablation_original.py b/ppsurf/source/figures/comp_ablation_original.py
new file mode 100644
index 0000000000000000000000000000000000000000..77f35d568975d8996ae9fb626254202ba030b370
--- /dev/null
+++ b/ppsurf/source/figures/comp_ablation_original.py
@@ -0,0 +1,55 @@
+import os
+import sys
+sys.path.append(os.path.abspath('.'))
+
+from source.base.evaluation import merge_comps
+from source import make_comparison
+
+if __name__ == '__main__':
+ workers = 15 # for training PC
+ # workers = 8 # for Windows missing fork
+ # workers = 4 # for strange window size bug
+ # workers = 0 # debug
+
+ comp_name = 'comp_ablation_original'
+
+ datasets = [
+ 'famous_original',
+ 'thingi10k_scans_original',
+ ]
+
+ methods = [
+ 'ppsurf_25nn',
+ 'p2s2_50nn',
+ 'p2s2_vanilla',
+ 'p2s2_merge_sum',
+ ]
+
+ # Run all comparisons
+ for dataset in datasets:
+ print('Running comparison for dataset {}'.format(dataset))
+ params = [
+ '--comp_name', dataset,
+ '--comp_dir', 'results/comp',
+ '--comp_mean_name', comp_name,
+ '--html_name', comp_name,
+ '--data_dir', 'datasets/' + dataset,
+ '--testset', 'testset.txt',
+ '--results_dir', 'results',
+
+ '--workers', str(workers),
+ '--dist_cut_off', str(0.01),
+
+ '--result_headers', *methods,
+ '--result_paths', *[r'results/{}/'.format(m) + dataset for m in methods],
+ ]
+ try:
+ make_comparison.main(argv=params)
+ except Exception as e:
+ print('Error in dataset {}: {}'.format(dataset, e))
+
+ # Merge all comparisons
+ comp_files = ['results/comp/{}/{}.xlsx'.format(dataset, comp_name) for dataset in datasets]
+ comp_merged_xlsx = 'results/comp/reports/{}.xlsx'.format(comp_name)
+ comp_merged_latex = 'results/comp/reports/{}.tex'.format(comp_name)
+ merge_comps(comp_files, comp_merged_xlsx, comp_merged_latex, methods_order=methods, float_format='%.3f')
diff --git a/ppsurf/source/figures/comp_ablation_sparse.py b/ppsurf/source/figures/comp_ablation_sparse.py
new file mode 100644
index 0000000000000000000000000000000000000000..6405dd29fe461566f046fc06f784f501fee17a2c
--- /dev/null
+++ b/ppsurf/source/figures/comp_ablation_sparse.py
@@ -0,0 +1,55 @@
+import os
+import sys
+sys.path.append(os.path.abspath('.'))
+
+from source.base.evaluation import merge_comps
+from source import make_comparison
+
+if __name__ == '__main__':
+ workers = 15 # for training PC
+ # workers = 8 # for Windows missing fork
+ # workers = 4 # for strange window size bug
+ # workers = 0 # debug
+
+ comp_name = 'comp_ablation_sparse'
+
+ datasets = [
+ 'famous_sparse',
+ 'thingi10k_scans_sparse',
+ ]
+
+ methods = [
+ 'ppsurf_25nn',
+ 'p2s2_50nn',
+ 'p2s2_vanilla',
+ 'p2s2_merge_sum',
+ ]
+
+ # Run all comparisons
+ for dataset in datasets:
+ print('Running comparison for dataset {}'.format(dataset))
+ params = [
+ '--comp_name', dataset,
+ '--comp_dir', 'results/comp',
+ '--comp_mean_name', comp_name,
+ '--html_name', comp_name,
+ '--data_dir', 'datasets/' + dataset,
+ '--testset', 'testset.txt',
+ '--results_dir', 'results',
+
+ '--workers', str(workers),
+ '--dist_cut_off', str(0.01),
+
+ '--result_headers', *methods,
+ '--result_paths', *[r'results/{}/'.format(m) + dataset for m in methods],
+ ]
+ try:
+ make_comparison.main(argv=params)
+ except Exception as e:
+ print('Error in dataset {}: {}'.format(dataset, e))
+
+ # Merge all comparisons
+ comp_files = ['results/comp/{}/{}.xlsx'.format(dataset, comp_name) for dataset in datasets]
+ comp_merged_xlsx = 'results/comp/reports/{}.xlsx'.format(comp_name)
+ comp_merged_latex = 'results/comp/reports/{}.tex'.format(comp_name)
+ merge_comps(comp_files, comp_merged_xlsx, comp_merged_latex, methods_order=methods, float_format='%.3f')
diff --git a/ppsurf/source/figures/comp_all.py b/ppsurf/source/figures/comp_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..0262251439ff92efca4d2e54185ad0c935bbdd4a
--- /dev/null
+++ b/ppsurf/source/figures/comp_all.py
@@ -0,0 +1,69 @@
+import os
+import sys
+sys.path.append(os.path.abspath('.'))
+
+from source.base.evaluation import merge_comps
+from source import make_comparison
+
+if __name__ == '__main__':
+ workers = 15 # for training PC
+ # workers = 8 # for Windows missing fork
+ # workers = 4 # for strange window size bug
+ # workers = 0 # debug
+
+ comp_name = 'comp_all'
+
+ datasets = [
+ 'abc',
+ 'abc_extra_noisy',
+ 'abc_noisefree',
+ 'famous_noisefree',
+ 'famous_original',
+ 'famous_extra_noisy',
+ 'famous_sparse',
+ 'famous_dense',
+ 'thingi10k_scans_original',
+ 'thingi10k_scans_dense',
+ 'thingi10k_scans_sparse',
+ 'thingi10k_scans_extra_noisy',
+ 'thingi10k_scans_noisefree',
+ ]
+
+ methods = [
+ 'neural_imls',
+ 'pgr',
+ 'sap_optim',
+ 'sap',
+ 'p2s',
+ 'poco Pts_gen_sub3k_iter10',
+ 'ppsurf_50nn',
+ ]
+
+ # Run all comparisons
+ for dataset in datasets:
+ print('Running comparison for dataset {}'.format(dataset))
+ params = [
+ '--comp_name', dataset,
+ '--comp_dir', 'results/comp',
+ '--comp_mean_name', comp_name,
+ '--html_name', comp_name,
+ '--data_dir', 'datasets/' + dataset,
+ '--testset', 'testset.txt',
+ '--results_dir', 'results',
+
+ '--workers', str(workers),
+ '--dist_cut_off', str(0.01),
+
+ '--result_headers', *methods,
+ '--result_paths', *[r'results/{}/'.format(m) + dataset for m in methods],
+ ]
+ try:
+ make_comparison.main(argv=params)
+ except Exception as e:
+ print('Error in dataset {}: {}'.format(dataset, e))
+
+ # Merge all comparisons
+ comp_files = ['results/comp/{}/{}.xlsx'.format(dataset, comp_name) for dataset in datasets]
+ comp_merged_xlsx = 'results/comp/reports/{}.xlsx'.format(comp_name)
+ comp_merged_latex = 'results/comp/reports/{}.tex'.format(comp_name)
+ merge_comps(comp_files, comp_merged_xlsx, comp_merged_latex, methods_order=methods, float_format='%.2f')
diff --git a/ppsurf/source/figures/comp_dense.py b/ppsurf/source/figures/comp_dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fa4c4a45dbec04822ef22d551682026ce469fed
--- /dev/null
+++ b/ppsurf/source/figures/comp_dense.py
@@ -0,0 +1,58 @@
+import os
+import sys
+sys.path.append(os.path.abspath('.'))
+
+from source.base.evaluation import merge_comps
+from source import make_comparison
+
+if __name__ == '__main__':
+ workers = 15 # for training PC
+ # workers = 8 # for Windows missing fork
+ # workers = 4 # for strange window size bug
+ # workers = 0 # debug
+
+ comp_name = 'comp_dense'
+
+ datasets = [
+ 'famous_dense',
+ 'thingi10k_scans_dense',
+ ]
+
+ methods = [
+ 'neural_imls',
+ 'pgr',
+ 'sap_optim',
+ 'sap',
+ 'p2s',
+ 'poco Pts_gen_sub3k_iter10',
+ 'ppsurf_50nn',
+ ]
+
+ # Run all comparisons
+ for dataset in datasets:
+ print('Running comparison for dataset {}'.format(dataset))
+ params = [
+ '--comp_name', dataset,
+ '--comp_dir', 'results/comp',
+ '--comp_mean_name', comp_name,
+ '--html_name', comp_name,
+ '--data_dir', 'datasets/' + dataset,
+ '--testset', 'testset.txt',
+ '--results_dir', 'results',
+
+ '--workers', str(workers),
+ '--dist_cut_off', str(0.01),
+
+ '--result_headers', *methods,
+ '--result_paths', *[r'results/{}/'.format(m) + dataset for m in methods],
+ ]
+ try:
+ make_comparison.main(argv=params)
+ except Exception as e:
+ print('Error in dataset {}: {}'.format(dataset, e))
+
+ # Merge all comparisons
+ comp_files = ['results/comp/{}/{}.xlsx'.format(dataset, comp_name) for dataset in datasets]
+ comp_merged_xlsx = 'results/comp/reports/{}.xlsx'.format(comp_name)
+ comp_merged_latex = 'results/comp/reports/{}.tex'.format(comp_name)
+ merge_comps(comp_files, comp_merged_xlsx, comp_merged_latex, methods_order=methods, float_format='%.2f')
diff --git a/ppsurf/source/figures/comp_noisefree.py b/ppsurf/source/figures/comp_noisefree.py
new file mode 100644
index 0000000000000000000000000000000000000000..57b0db82398a9e0d276696df0dcb4ec329711fed
--- /dev/null
+++ b/ppsurf/source/figures/comp_noisefree.py
@@ -0,0 +1,59 @@
+import os
+import sys
+sys.path.append(os.path.abspath('.'))
+
+from source.base.evaluation import merge_comps
+from source import make_comparison
+
+if __name__ == '__main__':
+ workers = 15 # for training PC
+ # workers = 8 # for Windows missing fork
+ # workers = 4 # for strange window size bug
+ # workers = 0 # debug
+
+ comp_name = 'comp_noisefree'
+
+ datasets = [
+ 'abc_noisefree',
+ 'famous_noisefree',
+ 'thingi10k_scans_noisefree',
+ ]
+
+ methods = [
+ 'neural_imls',
+ 'pgr',
+ 'sap_optim',
+ 'sap',
+ 'p2s',
+ 'poco Pts_gen_sub3k_iter10',
+ 'ppsurf_50nn',
+ ]
+
+ # Run all comparisons
+ for dataset in datasets:
+ print('Running comparison for dataset {}'.format(dataset))
+ params = [
+ '--comp_name', dataset,
+ '--comp_dir', 'results/comp',
+ '--comp_mean_name', comp_name,
+ '--html_name', comp_name,
+ '--data_dir', 'datasets/' + dataset,
+ '--testset', 'testset.txt',
+ '--results_dir', 'results',
+
+ '--workers', str(workers),
+ '--dist_cut_off', str(0.01),
+
+ '--result_headers', *methods,
+ '--result_paths', *[r'results/{}/'.format(m) + dataset for m in methods],
+ ]
+ try:
+ make_comparison.main(argv=params)
+ except Exception as e:
+ print('Error in dataset {}: {}'.format(dataset, e))
+
+ # Merge all comparisons
+ comp_files = ['results/comp/{}/{}.xlsx'.format(dataset, comp_name) for dataset in datasets]
+ comp_merged_xlsx = 'results/comp/reports/{}.xlsx'.format(comp_name)
+ comp_merged_latex = 'results/comp/reports/{}.tex'.format(comp_name)
+ merge_comps(comp_files, comp_merged_xlsx, comp_merged_latex, methods_order=methods, float_format='%.2f')
diff --git a/ppsurf/source/figures/comp_noisy.py b/ppsurf/source/figures/comp_noisy.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f5edb760510428c0543cd71f593c052f8e88923
--- /dev/null
+++ b/ppsurf/source/figures/comp_noisy.py
@@ -0,0 +1,59 @@
+import os
+import sys
+sys.path.append(os.path.abspath('.'))
+
+from source.base.evaluation import merge_comps
+from source import make_comparison
+
+if __name__ == '__main__':
+ workers = 15 # for training PC
+ # workers = 8 # for Windows missing fork
+ # workers = 4 # for strange window size bug
+ # workers = 0 # debug
+
+ comp_name = 'comp_noisy'
+
+ datasets = [
+ 'abc_extra_noisy',
+ 'famous_extra_noisy',
+ 'thingi10k_scans_extra_noisy',
+ ]
+
+ methods = [
+ 'neural_imls',
+ 'pgr',
+ 'sap_optim',
+ 'sap',
+ 'p2s',
+ 'poco Pts_gen_sub3k_iter10',
+ 'ppsurf_50nn',
+ ]
+
+ # Run all comparisons
+ for dataset in datasets:
+ print('Running comparison for dataset {}'.format(dataset))
+ params = [
+ '--comp_name', dataset,
+ '--comp_dir', 'results/comp',
+ '--comp_mean_name', comp_name,
+ '--html_name', comp_name,
+ '--data_dir', 'datasets/' + dataset,
+ '--testset', 'testset.txt',
+ '--results_dir', 'results',
+
+ '--workers', str(workers),
+ '--dist_cut_off', str(0.01),
+
+ '--result_headers', *methods,
+ '--result_paths', *[r'results/{}/'.format(m) + dataset for m in methods],
+ ]
+ try:
+ make_comparison.main(argv=params)
+ except Exception as e:
+ print('Error in dataset {}: {}'.format(dataset, e))
+
+ # Merge all comparisons
+ comp_files = ['results/comp/{}/{}.xlsx'.format(dataset, comp_name) for dataset in datasets]
+ comp_merged_xlsx = 'results/comp/reports/{}.xlsx'.format(comp_name)
+ comp_merged_latex = 'results/comp/reports/{}.tex'.format(comp_name)
+ merge_comps(comp_files, comp_merged_xlsx, comp_merged_latex, methods_order=methods, float_format='%.2f')
diff --git a/ppsurf/source/figures/comp_original.py b/ppsurf/source/figures/comp_original.py
new file mode 100644
index 0000000000000000000000000000000000000000..983f5785e6c29f297673eadefc45b2155598da0b
--- /dev/null
+++ b/ppsurf/source/figures/comp_original.py
@@ -0,0 +1,58 @@
+import os
+import sys
+sys.path.append(os.path.abspath('.'))
+
+from source.base.evaluation import merge_comps
+from source import make_comparison
+
+if __name__ == '__main__':
+ workers = 15 # for training PC
+ # workers = 8 # for Windows missing fork
+ # workers = 4 # for strange window size bug
+ # workers = 0 # debug
+
+ comp_name = 'comp_original'
+
+ datasets = [
+ 'famous_original',
+ 'thingi10k_scans_original',
+ ]
+
+ methods = [
+ 'neural_imls',
+ 'pgr',
+ 'sap_optim',
+ 'sap',
+ 'p2s',
+ 'poco Pts_gen_sub3k_iter10',
+ 'ppsurf_50nn',
+ ]
+
+ # Run all comparisons
+ for dataset in datasets:
+ print('Running comparison for dataset {}'.format(dataset))
+ params = [
+ '--comp_name', dataset,
+ '--comp_dir', 'results/comp',
+ '--comp_mean_name', comp_name,
+ '--html_name', comp_name,
+ '--data_dir', 'datasets/' + dataset,
+ '--testset', 'testset.txt',
+ '--results_dir', 'results',
+
+ '--workers', str(workers),
+ '--dist_cut_off', str(0.01),
+
+ '--result_headers', *methods,
+ '--result_paths', *[r'results/{}/'.format(m) + dataset for m in methods],
+ ]
+ try:
+ make_comparison.main(argv=params)
+ except Exception as e:
+ print('Error in dataset {}: {}'.format(dataset, e))
+
+ # Merge all comparisons
+ comp_files = ['results/comp/{}/{}.xlsx'.format(dataset, comp_name) for dataset in datasets]
+ comp_merged_xlsx = 'results/comp/reports/{}.xlsx'.format(comp_name)
+ comp_merged_latex = 'results/comp/reports/{}.tex'.format(comp_name)
+ merge_comps(comp_files, comp_merged_xlsx, comp_merged_latex, methods_order=methods, float_format='%.2f')
diff --git a/ppsurf/source/figures/comp_sparse.py b/ppsurf/source/figures/comp_sparse.py
new file mode 100644
index 0000000000000000000000000000000000000000..594a38e3cb7cc387c006c7cf9b1eb5db2692f226
--- /dev/null
+++ b/ppsurf/source/figures/comp_sparse.py
@@ -0,0 +1,58 @@
+import os
+import sys
+sys.path.append(os.path.abspath('.'))
+
+from source.base.evaluation import merge_comps
+from source import make_comparison
+
+if __name__ == '__main__':
+ workers = 15 # for training PC
+ # workers = 8 # for Windows missing fork
+ # workers = 4 # for strange window size bug
+ # workers = 0 # debug
+
+ comp_name = 'comp_sparse'
+
+ datasets = [
+ 'famous_sparse',
+ 'thingi10k_scans_sparse',
+ ]
+
+ methods = [
+ 'neural_imls',
+ 'pgr',
+ 'sap_optim',
+ 'sap',
+ 'p2s',
+ 'poco Pts_gen_sub3k_iter10',
+ 'ppsurf_50nn',
+ ]
+
+ # Run all comparisons
+ for dataset in datasets:
+ print('Running comparison for dataset {}'.format(dataset))
+ params = [
+ '--comp_name', dataset,
+ '--comp_dir', 'results/comp',
+ '--comp_mean_name', comp_name,
+ '--html_name', comp_name,
+ '--data_dir', 'datasets/' + dataset,
+ '--testset', 'testset.txt',
+ '--results_dir', 'results',
+
+ '--workers', str(workers),
+ '--dist_cut_off', str(0.01),
+
+ '--result_headers', *methods,
+ '--result_paths', *[r'results/{}/'.format(m) + dataset for m in methods],
+ ]
+ try:
+ make_comparison.main(argv=params)
+ except Exception as e:
+ print('Error in dataset {}: {}'.format(dataset, e))
+
+ # Merge all comparisons
+ comp_files = ['results/comp/{}/{}.xlsx'.format(dataset, comp_name) for dataset in datasets]
+ comp_merged_xlsx = 'results/comp/reports/{}.xlsx'.format(comp_name)
+ comp_merged_latex = 'results/comp/reports/{}.tex'.format(comp_name)
+ merge_comps(comp_files, comp_merged_xlsx, comp_merged_latex, methods_order=methods, float_format='%.2f')
diff --git a/ppsurf/source/figures/fix_imls.py b/ppsurf/source/figures/fix_imls.py
new file mode 100644
index 0000000000000000000000000000000000000000..eba48442887a60e036e1ecb20999a5de056f4ada
--- /dev/null
+++ b/ppsurf/source/figures/fix_imls.py
@@ -0,0 +1,77 @@
+import os
+import sys
+sys.path.append(os.path.abspath('.'))
+
+from source.base.fs import call_necessary, make_dir_for_file
+from source.base.mp import start_process_pool
+
+
+def _revert_normalization(src, gt, dst):
+ import trimesh
+
+ if not os.path.isfile(src):
+ print('File not found: {}'.format(src))
+ return
+
+ mesh_gt = trimesh.load(gt)
+ bounds = mesh_gt.extents
+ if bounds.min() == 0.0:
+ return
+
+ # translate to origin
+ translation = (mesh_gt.bounds[0] + mesh_gt.bounds[1]) * 0.5
+ translation_inv = trimesh.transformations.translation_matrix(direction=translation)
+
+ # scale to unit cube
+ scale = 1.0 / bounds.max()
+ scale_trafo_inv = trimesh.transformations.scale_matrix(factor=1.0 / scale)
+
+ mesh_rec = trimesh.load(src)
+
+ mesh_rec.apply_transform(scale_trafo_inv)
+ mesh_rec.apply_transform(translation_inv)
+
+ make_dir_for_file(dst)
+ mesh_rec.export(dst)
+
+
+if __name__ == '__main__':
+ workers = 15 # for training PC
+ # workers = 8 # for Windows missing fork
+ # workers = 4 # for strange window size bug
+ # workers = 0 # debug
+
+ datasets_path = 'datasets'
+ # extra-noisy is not provided
+ datasets = [
+ 'abc',
+ # 'abc_extra_noisy',
+ 'abc_noisefree',
+ 'famous_noisefree',
+ 'famous_original',
+ # 'famous_extra_noisy',
+ 'famous_sparse',
+ 'famous_dense',
+ 'thingi10k_scans_original',
+ 'thingi10k_scans_dense',
+ 'thingi10k_scans_sparse',
+ # 'thingi10k_scans_extra_noisy',
+ 'thingi10k_scans_noisefree',
+ ]
+ results_path = 'results'
+
+ for d in datasets:
+ test_set = os.path.join(datasets_path, d, 'testset.txt')
+ test_shapes = [l.strip() for l in open(test_set, 'r').readlines()]
+ test_files = [os.path.join(datasets_path, d, '03_meshes', s + '.ply') for s in test_shapes]
+
+ rec_meshes_in = [os.path.join(results_path, 'neural_imls misaligned', d, 'meshes', s + '.ply') for s in test_shapes]
+ rec_meshes_out = [os.path.join(results_path, 'neural_imls', d, 'meshes', s + '.ply') for s in test_shapes]
+
+ def _make_params(l1, l2, l3):
+ params = tuple(zip(l1, l2, l3))
+ params_necessary = [p for p in params if call_necessary((p[0], p[1]), p[2], verbose=False)]
+ return params_necessary
+
+ start_process_pool(_revert_normalization, _make_params(rec_meshes_in, test_files, rec_meshes_out),
+ num_processes=workers)
diff --git a/ppsurf/source/figures/prepare_figures.py b/ppsurf/source/figures/prepare_figures.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb0d482d4b4f93d87748075b44bb75131ed3e654
--- /dev/null
+++ b/ppsurf/source/figures/prepare_figures.py
@@ -0,0 +1,218 @@
+import os
+import shutil
+import typing
+import sys
+sys.path.append(os.path.abspath('.'))
+
+from source.base.mp import start_process_pool
+from source.base.fs import call_necessary
+
+
+def _copy_files(src, dst):
+ os.makedirs(os.path.dirname(dst), exist_ok=True)
+ if os.path.isfile(src):
+ shutil.copy(src, dst)
+ else:
+ print('File not found: {}'.format(src))
+
+
+def _get_vertex_distances(
+ input_mesh_file: str, reference_mesh_file: str, output_mesh_file: str,
+ min_vertex_count: typing.Union[int, None], distance_batch_size=10000):
+
+ import numpy as np
+ import trimesh
+ from trimesh.base import Trimesh
+
+ from source.base import proximity, fs
+
+ in_mesh: Trimesh = trimesh.load(input_mesh_file)
+ ref_mesh = trimesh.load(reference_mesh_file)
+
+ if min_vertex_count is not None:
+ while in_mesh.vertices.shape[0] < min_vertex_count:
+ in_mesh = in_mesh.subdivide()
+
+ closest_pts, dist_rec_verts_to_ref, tri_id = proximity.get_closest_point_on_mesh(
+ mesh=ref_mesh, query_pts=in_mesh.vertices, batch_size=int(distance_batch_size))
+
+ fs.make_dir_for_file(output_mesh_file)
+ np.savez(file=output_mesh_file, vertices=in_mesh.vertices, faces=in_mesh.faces, distances=dist_rec_verts_to_ref)
+
+ # debug output
+ from trimesh.visual.color import VertexColor
+ from source.base.visualization import distances_to_vertex_colors
+ dist_cut_off = 0.1
+ vertex_colors = distances_to_vertex_colors(dist_rec_verts_to_ref, dist_cut_off)
+ in_mesh.visual = VertexColor(vertex_colors)
+ in_mesh.export(output_mesh_file[:-4] + '_dist_col.ply')
+ pass
+
+
+def _assemble_figure_data(figure_path, objects, datasets_path, results_path, methods, workers=0):
+ gt_in = [os.path.join(datasets_path, o[0], '03_meshes', o[1] + '.ply') for o in objects]
+ gt_out = [os.path.join(figure_path, o[0], o[1], 'gt.ply') for o in objects]
+
+ pc_in = [os.path.join(datasets_path, o[0], '04_pts_vis', o[1] + '.xyz.plys') for o in objects]
+ pc_out = [os.path.join(figure_path, o[0], o[1], 'pc.ply') for o in objects]
+
+ method_in = [[os.path.join(results_path, m, o[0], 'meshes', o[1] + '.ply') for m in methods] for o in objects]
+ method_out = [[os.path.join(figure_path, o[0], o[1], m + '.ply') for m in methods] for o in objects]
+ method_dist_out = [[os.path.join(figure_path, o[0], o[1], m + '_dist.npz') for m in methods] for o in objects]
+
+ def _flatten(l):
+ return [item for sublist in l for item in sublist]
+
+ def _make_params(l1, l2):
+ params = tuple(zip(l1, l2))
+ params_necessary = [p for p in params if call_necessary(p[0], p[1], verbose=False)]
+ return params_necessary
+
+ start_process_pool(_copy_files, _make_params(gt_in, gt_out), num_processes=workers)
+ start_process_pool(_copy_files, _make_params(_flatten(method_in), _flatten(method_out)), num_processes=workers)
+
+ from source.base.point_cloud import numpy_to_ply
+ start_process_pool(numpy_to_ply, _make_params(pc_in, pc_out), num_processes=workers)
+
+ min_vertex_count = 10000
+ distance_batch_size = 1000
+ params = [tuple(zip(m, [gt_out[mi]] * len(m), method_dist_out[mi],
+ [min_vertex_count] * len(m), [distance_batch_size] * len(m)))
+ for mi, m in enumerate(method_out)]
+ params_flat = _flatten(params)
+ params_flat_necessary = [p for p in params_flat if call_necessary((p[0], p[1]), p[2], verbose=False)]
+ start_process_pool(_get_vertex_distances, params_flat_necessary, num_processes=workers)
+
+
+if __name__ == '__main__':
+ workers = 15 # for training PC
+ # workers = 8 # for Windows missing fork
+ # workers = 4 # for strange window size bug
+ # workers = 0 # debug
+
+ datasets_path = 'datasets'
+
+ results_path = 'results'
+ methods_comp = [
+ 'neural_imls',
+ 'pgr',
+ 'sap_optim',
+ 'sap',
+ 'p2s',
+ 'poco Pts_gen_sub3k_iter10',
+ 'ppsurf_merge_sum',
+ ]
+
+ figure_path_comp = 'results/figures/comp'
+ objects_comp = [
+ ('abc', '00010429_fc56088abf10474bba06f659_trimesh_004'),
+ ('abc', '00011602_c087f04c99464bf7ab2380c4_trimesh_000'),
+ ('abc', '00013052_9084b77631834dd584b2ac93_trimesh_033'),
+ ('abc', '00014452_55263057b8f440a0bb50b260_trimesh_017'),
+ ('abc', '00017014_fbef9df8f24940a0a2df6ccb_trimesh_001'),
+ ('abc', '00990573_d1914c7f68f9a6b58bed9421_trimesh_000'),
+ ('abc_noisefree', '00012754_b17656deace54b61b3130c7e_trimesh_019'),
+ ('abc_noisefree', '00011696_1ca1ad2a09504ff1bf83cf74_trimesh_029'),
+ ('abc_noisefree', '00016680_5a9a2a2a5eb64501863164e9_trimesh_000'),
+ ('abc_noisefree', '00017682_f0ea0b827ae34675a4162390_trimesh_003'),
+ ('abc_noisefree', '00019114_87f2e2e15b2746ffa4a2fd9a_trimesh_003'),
+ ('abc_noisefree', '00011171_db6e2de6f4ae4ec493ebe2aa_trimesh_047'),
+ ('abc_noisefree', '00011171_db6e2de6f4ae4ec493ebe2aa_trimesh_047'),
+ ('abc_extra_noisy', '00013052_9084b77631834dd584b2ac93_trimesh_033'),
+ ('abc_extra_noisy', '00014101_7b2cf2f0fd464e80a5062901_trimesh_000'),
+ ('abc_extra_noisy', '00014155_a04f003ab9b74295bbed8248_trimesh_000'),
+ ('abc_extra_noisy', '00016144_8dadc1c5885e427292f34e71_trimesh_026'),
+ ('abc_extra_noisy', '00018947_b302da1a26764dd0afcd55ff_trimesh_075'),
+ ('abc_extra_noisy', '00019203_1bcd132f82c84761b4e9851d_trimesh_001'),
+ ('abc_extra_noisy', '00992690_ed0f9f06ad21b92e7ffab606_trimesh_002'),
+ ('famous_dense', 'tortuga'),
+ ('famous_dense', 'yoda'),
+ ('famous_dense', 'armadillo'),
+ ('famous_extra_noisy', 'Utah_teapot_(solid)'),
+ ('famous_extra_noisy', 'happy'),
+ ('famous_noisefree', 'galera'),
+ ('famous_original', 'hand'),
+ ('famous_original', 'horse'),
+ ('famous_sparse', 'xyzrgb_statuette'),
+ ('famous_sparse', 'dragon'),
+ ('thingi10k_scans_dense', '58982'),
+ ('thingi10k_scans_dense', '70558'),
+ ('thingi10k_scans_dense', '77245'),
+ ('thingi10k_scans_dense', '88053'),
+ ('thingi10k_scans_extra_noisy', '86848'),
+ ('thingi10k_scans_extra_noisy', '83022'),
+ ('thingi10k_scans_noisefree', '103354'),
+ ('thingi10k_scans_noisefree', '53159'),
+ ('thingi10k_scans_noisefree', '54725'),
+ ('thingi10k_scans_original', '53920'),
+ ('thingi10k_scans_original', '64194'),
+ ('thingi10k_scans_original', '73075'),
+ ('thingi10k_scans_sparse', '80650'),
+ ('thingi10k_scans_sparse', '81368'),
+ ('thingi10k_scans_sparse', '81762'),
+ ('real_world', 'madersperger_cropped'),
+ ('real_world', 'statue_ps_outliers2_cropped'),
+ # ('real_world', 'statue_ps_pointcleannet_cropped'),
+ ('real_world', 'torch_ps_outliers2'),
+ ]
+ # for general comparison
+ _assemble_figure_data(figure_path=figure_path_comp, methods=methods_comp, objects=objects_comp,
+ datasets_path=datasets_path, results_path=results_path, workers=workers)
+
+ figure_path_ablation = 'results/figures/ablation'
+ objects_ablation = [
+ ('abc', '00012451_f54bcfcb352445bf90726b58_trimesh_001'),
+ ('abc', '00014221_57e4213b31844b5b95cc62cd_trimesh_000'),
+ ('abc', '00015159_57353d3381fb481182d9bdc6_trimesh_013'),
+ ('abc', '00990546_db31ddca9d3585c330dcce3a_trimesh_000'),
+ ('abc', '00993692_494894597fe7b39310a44a99_trimesh_000'),
+ ]
+ methods_ablation = [
+ 'ppsurf_vanilla_zeros_local',
+ 'ppsurf_vanilla_zeros_global',
+ 'ppsurf_vanilla_sym_max',
+ 'ppsurf_vanilla_qpoints',
+ 'ppsurf_vanilla',
+ 'ppsurf_merge_sum',
+ ]
+ # for ablation study
+ _assemble_figure_data(figure_path=figure_path_ablation, methods=methods_ablation, objects=objects_ablation,
+ datasets_path=datasets_path, results_path=results_path, workers=workers)
+
+ figure_path_real = 'results/figures/real_world'
+ objects_real = [
+ ('real_world', 'madersperger_cropped'),
+ ('real_world', 'statue_ps_outliers2_cropped'),
+ ('real_world', 'torch_ps_outliers2'),
+ ]
+ # for ablation study
+ _assemble_figure_data(figure_path=figure_path_real, methods=methods_comp, objects=objects_real,
+ datasets_path=datasets_path, results_path=results_path, workers=workers)
+
+ figure_path_dataset = 'results/figures/datasets'
+ objects_dataset = [
+ ('abc', '00013052_9084b77631834dd584b2ac93_trimesh_033'),
+ ('abc_noisefree', '00013052_9084b77631834dd584b2ac93_trimesh_033'),
+ ('abc_extra_noisy', '00013052_9084b77631834dd584b2ac93_trimesh_033'),
+ ('famous_dense', 'hand'),
+ ('famous_extra_noisy', 'hand'),
+ ('famous_noisefree', 'hand'),
+ ('famous_original', 'hand'),
+ ('famous_sparse', 'hand'),
+ ('thingi10k_scans_dense', '54725'),
+ ('thingi10k_scans_extra_noisy', '54725'),
+ ('thingi10k_scans_noisefree', '54725'),
+ ('thingi10k_scans_original', '54725'),
+ ('thingi10k_scans_sparse', '54725'),
+ ]
+ # for datasets figure
+ _assemble_figure_data(figure_path=figure_path_dataset, methods=[], objects=objects_dataset,
+ datasets_path=datasets_path, results_path=results_path, workers=workers)
+
+ figure_path_limitations = 'results/figures/limitations'
+ objects_limitations = [
+ ('thingi10k_scans_sparse', '274379'),
+ ]
+ # for limitations figure
+ _assemble_figure_data(figure_path=figure_path_limitations, methods=['ppsurf_merge_sum'], objects=objects_limitations,
+ datasets_path=datasets_path, results_path=results_path, workers=workers)
diff --git a/ppsurf/source/figures/render_meshes_blender.py b/ppsurf/source/figures/render_meshes_blender.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4b10fbca6939510196ab26acfaddaeb2441c7b1
--- /dev/null
+++ b/ppsurf/source/figures/render_meshes_blender.py
@@ -0,0 +1,548 @@
+import os
+import sys
+import bpy
+import site
+
+# # run blender with elevated privileges and uncomment to following lines to install required packages
+# import ensurepip
+# import subprocess
+# ensurepip.bootstrap()
+# os.environ.pop('PIP_REQ_TRACKER', None)
+# from pathlib import Path
+#
+# python_path = next((Path(sys.prefix)/'bin').glob('python*'))
+#
+# subprocess.check_output([python_path, '-m', 'pip', 'install', 'numpy'])
+# subprocess.check_output([python_path, '-m', 'pip', 'install', 'scipy'])
+# subprocess.check_output([python_path, '-m', 'pip', 'install', 'trimesh'])
+# subprocess.check_output([python_path, '-m', 'pip', 'install', 'networkx'])
+
+usersitepackagespath = site.getusersitepackages()
+if os.path.exists(usersitepackagespath) and usersitepackagespath not in sys.path:
+ sys.path.append(usersitepackagespath)
+
+import numpy as np
+import scipy.spatial
+import trimesh
+import bmesh
+
+
+def eval_cmap(vals, cmap_colors):
+ # print((np.clip(vals, a_min=0.0, a_max=1.0)*(cmap_colors.shape[0]-1)).round().astype('int32').max())
+ # print((np.clip(vals, a_min=0.0, a_max=1.0)*(cmap_colors.shape[0]-1)).round().astype('int32').min())
+ # print(vals.min())
+ # print(np.isnan(vals).sum())
+ colors = cmap_colors[(np.clip(vals, a_min=0.0, a_max=1.0) * (cmap_colors.shape[0] - 1)).round().astype('int32'), :]
+ return np.concatenate([colors, np.ones(shape=[colors.shape[0], 1], dtype=colors.dtype)], axis=1) # add alpha
+
+
+def rotation_between_vectors(vec1, vec2):
+ """ Find the rotation matrix that aligns vec1 to vec2
+ :param vec1: A 3d "source" vector
+ :param vec2: A 3d "destination" vector
+ :return mat: A transform matrix (3x3) which when applied to vec1, aligns it with vec2.
+ """
+ a, b = (vec1 / np.linalg.norm(vec1, axis=1, keepdims=True)), (vec2 / np.linalg.norm(vec2, axis=1, keepdims=True))
+ cos_angle = (a * b).sum(axis=1, keepdims=True)
+ axis = np.cross(a, b)
+ axis = axis / np.linalg.norm(axis, axis=1, keepdims=True)
+ axis[cos_angle[:, 0] < -0.99999, :] = np.array([[1.0, 0.0, 0.0]])
+
+ rot = scipy.spatial.transform.Rotation.from_rotvec(axis * np.arccos(cos_angle))
+ return rot.as_matrix()
+
+
+def copy_animation_data(src_obj, dst_obj):
+ src_ad = src_obj.animation_data
+
+ if dst_obj.animation_data is None:
+ dst_obj.animation_data_create()
+ dst_ad = dst_obj.animation_data
+
+ for src_prop in src_ad.bl_rna.properties:
+ if not src_prop.is_readonly:
+ setattr(dst_ad, src_prop.identifier, getattr(src_ad, src_prop.identifier))
+
+
+def render_meshes(input_dir, output_dir):
+ clear = False
+ fix_wires = True
+ turning_animation = False
+ render_wireframe = False
+
+ # input_dir = '/home/lizeth/Downloads/for rendering/comp/abc/00013052_9084b77631834dd584b2ac93_trimesh_033/'
+ # output_dir = '/home/lizeth/Downloads/for rendering/rendered/abc/00013052_9084b77631834dd584b2ac93_trimesh_033/'
+ # input_dir = '/home/lizeth/Downloads/for rendering/comp/abc/00014452_55263057b8f440a0bb50b260_trimesh_017/'
+ # output_dir = '/home/lizeth/Downloads/for rendering/rendered/abc/00014452_55263057b8f440a0bb50b260_trimesh_017/'
+ # input_dir = '/home/lizeth/Downloads/for rendering/comp/abc/00017014_fbef9df8f24940a0a2df6ccb_trimesh_001/'
+ # output_dir = '/home/lizeth/Downloads/for rendering/rendered/abc/00017014_fbef9df8f24940a0a2df6ccb_trimesh_001/'
+ # input_dir = '/home/lizeth/Downloads/for rendering/comp/abc/00990573_d1914c7f68f9a6b58bed9421_trimesh_000/'
+ # output_dir = '/home/lizeth/Downloads/for rendering/rendered/abc/00990573_d1914c7f68f9a6b58bed9421_trimesh_000/'
+
+
+ scale_vecfield = True
+ vecfield_prefix = None
+ use_boundaries = True
+ boundary_edges_prefix = 'boundary_edges_'
+ boundary_verts_prefix = 'boundary_coordinates_'
+ vcolor_max = 1.5560769862496584
+ vcolor_min = -1.60086702642704
+ vcolor_prefix = 'trig_size_'
+ vcolor_suffix = ''
+ render_wireframe = False
+ # boundary_exclude_prefix = ['gt', 'baseline', 'ours']
+ boundary_exclude_prefix = []
+ wireframe_exclude_prefix = ['gt']
+ vert_colors_exclude_prefix = ['ours']
+ scale_vecfield_exclude_prefix = ['gt']
+ vecfield_exclude_prefix = []
+ recompute_vcolor_range = True # only for steps figure and video
+ recompute_vcolor_range_each_mesh = False
+ cmap = np.load('/home/lizeth/Downloads/ppsurf/ppsurf/figures/blender_script/cmap_YlOrRd.npy')
+
+ # shared (models in both size and curv applications)
+ model_list = [
+ # '75106',
+ # '75660',
+ '75667',
+ # '78251',
+ # '100349',
+ # '100478',
+ # '101865',
+ # '103141',
+ # '116066',
+ # '762604',
+ ]
+
+ method_list = [
+ # 'init',
+ 'ours',
+ # 'gt',
+ # 'baseline'
+ ]
+
+ # # analytic surfaces
+ # model_list = [
+ # # 'catenoid_curvature_trig_size',
+ # # 'catenoid_equal_area_triangles',
+ # 'enneper_equal_area_triangles',
+ # ]
+
+ # method_list = [
+ # 'init3D',
+ # 'opt3D',
+ # ]
+
+ # method_list = [
+ # 'init2D',
+ # 'opt2D',
+ # ]
+
+ steps = [None]
+ # steps = [2, 6, 18, 66, 254, 998] # for triangle size steps figure
+ # steps = [1, 3, 9, 33, 127, 499] # for cuvature steps figure (every two steps)
+ # steps = list(range(1295))
+ # steps = list(range(500))
+ # mesh_color = np.array([255.0, 255, 255, 255])
+ mesh_color = np.array([231.0, 166, 130, 255]) # clay # clay
+ vec_size = 0.025
+
+ # test
+ # axes = 'x'
+ # rot = [90]
+ # y_offset = 0.0
+ # scale = 1.0
+
+ # default
+ # axes = ['x', 'y', 'z']
+ # rot = [90, 0, 0]
+ # y_offset = 0.0
+ # scale = 1.0
+
+ # happy model
+ # axes = ['x', 'y', 'z']
+ # rot = [10, 80, -40]
+ # y_offset = 0.0
+ # scale = 1.0
+
+ import json
+ camera_filename = input_dir + 'camera_params.json'
+
+ def distances_to_vertex_colors(dist_per_vertex: np.ndarray, cut_off=0.3):
+
+ dist_per_vertex[dist_per_vertex > cut_off] = cut_off
+ dist_per_vertex /= cut_off
+
+ # use parula colormap: dist=0 -> blue, dist=0.5 -> green, dist=1.0 -> yellow
+ parulas_indices = (dist_per_vertex * (cmap.shape[0] - 1)).astype(np.int32)
+ dist_greater_than_norm_target = parulas_indices >= cmap.shape[0]
+ parulas_indices[dist_greater_than_norm_target] = cmap.shape[0] - 1
+ dist_colors_rgb = [cmap[parula_indices] for parula_indices in parulas_indices]
+
+ return dist_colors_rgb
+
+ if os.path.exists(camera_filename):
+ with open(camera_filename, 'r') as file:
+ camera_settings = json.load(file)
+ axes = camera_settings['axes']
+ rot = camera_settings['rot']
+ y_offset = camera_settings['y_offset']
+ scale = camera_settings['scale']
+
+ else: # default
+ axes = ['x', 'y', 'z']
+ rot = [0, 0, 0]
+ y_offset = 0.0
+ scale = 1.0
+
+ camera_settings = {}
+ camera_settings['axes'] = axes
+ camera_settings['rot'] = rot
+ camera_settings['y_offset'] = y_offset
+ camera_settings['scale'] = scale
+
+ write_camera_params = False
+ test = True
+ use_vert_colors = True
+ use_vecfield = False
+
+ if write_camera_params:
+ # Save the dictionary as a JSON file
+ with open(camera_filename, 'w') as file:
+ json.dump(camera_settings, file)
+
+ # get mesh list from list of models and list of methods
+ # mesh_names = []
+ # for model_name in model_list:
+ # for method_name in method_list:
+ # mesh_names.append(f'{method_name}_{model_name}')
+ #
+ # # pre-process meshes to get vertex colors
+ # print(f'getting vertex colors ...')
+
+ # if steps[0] is None:
+ # raise RuntimeError('vcolor_max should only be determined once and then set consistently throughout all experiments')
+
+ def get_ply_files(directory, output_path):
+
+ ply_files = []
+ output_files = []
+ for root, dirs, files in os.walk(directory):
+ for filename in files:
+ if filename.endswith('.ply'):
+ file_path = os.path.join(root, filename)
+ ply_files.append(os.path.join(root, filename))
+ output_file = file_path.replace(directory, output_path, 1)
+ output_file = os.path.splitext(output_file)[0] + '.png'
+ output_files.append(output_file)
+ return ply_files, output_files
+
+ mesh_names, output_files = get_ply_files(input_dir, output_dir)
+ # mesh_names = ['/home/lizeth/Downloads/ppsurf/ppsurf/figures/comp mathods/abc/ppsurf.ply']
+ # vcolor_filename = '/home/lizeth/Downloads/for rendering/comp/abc/00010429_fc56088abf10474bba06f659_trimesh_004/ppsurf_merge_sum_dist.npz'
+ # vert_colors = np.load(vcolor_filename)
+
+ # save camera parameters
+ # camera_config = np.array([axes, rot, y_offset, scale])
+
+ if use_vert_colors and recompute_vcolor_range:
+ all_vcolors = []
+ for mesh_name in mesh_names:
+ if not os.path.basename(mesh_name) == 'gt.ply' and not os.path.basename(mesh_name) == 'pc.ply':
+ vcolor_filename = os.path.splitext(mesh_name)[0] + '_dist.npz'
+ vert_color_vals = np.load(vcolor_filename)['distances']
+ all_vcolors.append(vert_color_vals)
+ vcolor_max = np.percentile(np.concatenate(all_vcolors), 95)
+ vcolor_min = np.percentile(np.concatenate(all_vcolors), 5)
+ print(f'vcolor_max: {vcolor_max}')
+ print(f'vcolor_min: {vcolor_min}')
+ # save color map values
+ colormap_min_max = np.array([vcolor_min, vcolor_max])
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ np.savetxt(output_dir + 'vcolor_min_max.txt', colormap_min_max, delimiter=' ')
+
+ scene = bpy.context.scene
+
+ scene.render.engine = 'CYCLES'
+
+ # Set the device_type
+ bpy.context.preferences.addons[
+ 'cycles'
+ ].preferences.compute_device_type = 'CUDA'
+
+ # Set the device and feature set
+ scene.cycles.device = 'GPU'
+
+ # scene.render.tile_x = 256
+ # scene.render.tile_y = 256
+
+ # get_devices() to let Blender detects GPU device
+ bpy.context.preferences.addons['cycles'].preferences.get_devices()
+ print(bpy.context.preferences.addons['cycles'].preferences.compute_device_type)
+ for d in bpy.context.preferences.addons['cycles'].preferences.devices:
+ if d.type == 'CPU':
+ d.use = False
+ else:
+ d.use = True
+ print(d.name, d.use)
+
+ for mesh_ind, mesh_name in enumerate(mesh_names):
+ # pre-process meshes to get bounding box for each step
+ print(f'preprocessing mesh {mesh_name} ...')
+ mesh_bbmin = np.array([np.inf, np.inf, np.inf])
+ mesh_bbmax = np.array([-np.inf, -np.inf, -np.inf])
+ # for step in steps:
+ # sample_name = f'{mesh_name}_{step}' if step is not None else mesh_name
+ # mesh_filename = os.path.join(input_dir, f'{sample_name}.ply')
+ mesh_filename = mesh_name
+ mesh = trimesh.load(mesh_filename, process=False) # disable processing to preserve vertex order
+
+ # align abc var-noise to abc extra noisy
+ if '/abc/' in mesh_name or '\\abc\\' in mesh_name:
+ import trimesh
+ import trimesh.registration
+ ref_mesh_name = mesh_name.replace('/abc/', '/abc_extra_noisy/').replace('\\abc\\', '\\abc_extra_noisy\\')
+ ref_mesh: trimesh.Trimesh = trimesh.load(ref_mesh_name, process=False)
+
+ # get the transformation matrix
+ align_matrix, cost = trimesh.registration.mesh_other(mesh, ref_mesh.vertices, samples=10000)
+ mesh.apply_transform(align_matrix)
+
+ # automatic view parameters for missing
+ if rot == [0, 0, 0]:
+ import trimesh
+ import trimesh.geometry
+ import trimesh.transformations as trafo
+ up = [0, 0, 1]
+ points_pit = mesh.bounding_box_oriented.principal_inertia_transform
+ up_rotated = trimesh.transform_points([up], points_pit)[0]
+ rotate_to_up = trimesh.geometry.align_vectors(up_rotated, up)
+ mesh.apply_transform(rotate_to_up)
+
+ # a little bit of rotation
+ mesh.apply_transform(trafo.rotation_matrix(np.pi/4, [0, 1, 0]))
+
+ if not os.path.basename(mesh_filename) == 'pc.ply':
+ faces = np.array(mesh.faces).astype('int32')
+ verts = np.array(mesh.vertices).astype('float32')
+ mesh_bbmin = np.minimum(mesh_bbmin, verts.min(axis=0))
+ mesh_bbmax = np.maximum(mesh_bbmax, verts.max(axis=0))
+ mesh_bbcenter = 0.5 * (mesh_bbmin + mesh_bbmax)
+ mesh_bbsize = (mesh_bbmax - mesh_bbmin).max()
+
+ for step_ind, step in enumerate(steps):
+ # print(f'[{step_ind + mesh_ind*len(steps) + 1} / {len(mesh_names)*len(steps)}] rendering {sample_name}')
+
+ sample_name = f'{mesh_name}_{step}' if step is not None else mesh_name
+
+ # mesh_filename = os.path.join(input_dir, f'{sample_name}.ply')
+ # mesh_filename = '/home/lizeth/Downloads/ppsurf/ppsurf/figures/comp mathods/abc/ppsurf.ply'
+ # output_filename = os.path.join(output_dir, f'{mesh_name}_{step:05d}{output_suffix}.png' if step is not None else f'{mesh_name}{output_suffix}.png')
+ output_filename = output_files[mesh_ind]
+
+ # remove objects from previous iteration
+ if 'object' in bpy.data.objects:
+ bpy.data.objects.remove(bpy.data.objects['object'], do_unlink=True)
+
+ if 'wireframe' in bpy.data.objects:
+ bpy.data.objects.remove(bpy.data.objects['wireframe'], do_unlink=True)
+
+ if 'attachments' in bpy.data.objects:
+ bpy.data.objects.remove(bpy.data.objects['attachments'], do_unlink=True)
+
+ if 'vecfield' in bpy.data.objects:
+ bpy.data.objects.remove(bpy.data.objects['vecfield'], do_unlink=True)
+
+ if 'boundary' in bpy.data.objects:
+ bpy.data.objects.remove(bpy.data.objects['boundary'], do_unlink=True)
+
+ if 'object' in bpy.data.meshes:
+ bpy.data.meshes.remove(bpy.data.meshes['object'], do_unlink=True)
+
+ if 'wireframe' in bpy.data.meshes:
+ bpy.data.meshes.remove(bpy.data.meshes['wireframe'], do_unlink=True)
+
+ if 'attachments' in bpy.data.meshes:
+ bpy.data.meshes.remove(bpy.data.meshes['attachments'], do_unlink=True)
+
+ if 'vecfield' in bpy.data.meshes:
+ bpy.data.meshes.remove(bpy.data.meshes['vecfield'], do_unlink=True)
+
+ if 'boundary' in bpy.data.meshes:
+ bpy.data.meshes.remove(bpy.data.meshes['boundary'], do_unlink=True)
+
+ if 'vec' in bpy.data.meshes:
+ bpy.data.meshes.remove(bpy.data.meshes['vec'], do_unlink=True)
+
+ if clear:
+ break
+
+ # create vectorfield 'arrow' mesh
+ if use_vecfield and os.path.basename(mesh_filename) == 'pc.ply':
+ vec_mesh = bpy.data.meshes.new('vec')
+ vec_bmesh = bmesh.new()
+ # bmesh.ops.create_cone(vec_bmesh, cap_ends=True, cap_tris=True, segments=5, diameter1=vec_size*0.3, diameter2=vec_size*0.05, depth=vec_size)
+ # bmesh.ops.create_cone(vec_bmesh, cap_ends=True, cap_tris=True, segments=12, diameter1=vec_size * 0.08,
+ # diameter2=vec_size * 0.01, depth=vec_size)
+ bmesh.ops.create_icosphere(vec_bmesh, subdivisions=2, radius=0.005)
+ bmesh.ops.triangulate(vec_bmesh, faces=vec_bmesh.faces[:])
+ # bmesh.ops.create_cone(vec_bmesh)
+ vec_bmesh.to_mesh(vec_mesh)
+ vec_bmesh.free()
+
+ vec_verts = np.array([[v.co.x, v.co.y, v.co.z] for v in vec_mesh.vertices])
+ vec_faces = np.array([[p.vertices[0], p.vertices[1], p.vertices[2]] for p in
+ vec_mesh.polygons]) # vec_mesh.loop_triangles
+ vec_verts[:, 2] -= vec_verts.min(axis=0)[2]
+
+ # load mesh of main object
+ mesh = trimesh.load(mesh_filename, process=False) # disable processing to preserve vertex order
+ if not os.path.basename(mesh_filename) == 'pc.ply':
+ faces = np.array(mesh.faces).astype('int32')
+ verts = np.array(mesh.vertices).astype('float32')
+
+ # move bounding box center to origin and normalize max. bounding box side length to 1
+ # mesh_bbcenter = (verts.max(axis=0)+verts.min(axis=0))/2
+ # mesh_bbsize = (verts.max(axis=0)-verts.min(axis=0)).max()
+ verts = verts - mesh_bbcenter
+ verts = verts / mesh_bbsize
+
+ # to blender coordinates and apply any custom rotation, scaling, and translation
+ coord_rot = np.array([[-1, 0, 0], [0, 0, 1], [0, 1, 0]])
+ for i in range(len(rot)):
+ coord_rot = np.matmul(
+ scipy.spatial.transform.Rotation.from_euler(axes[i], rot[i], degrees=True).as_matrix(), coord_rot)
+ coord_rot = np.matmul(np.array([[scale, 0, 0], [0, scale, 0], [0, 0, scale]]), coord_rot)
+ verts = np.transpose(np.matmul(coord_rot, np.transpose(verts)))
+ y_min = verts.min(axis=0)[2]
+ verts[:, 2] -= y_min # make objects 'stand' on the xz coordinate plane (y_min = 0)
+ verts[:, 2] += y_offset # apply custom translation in y direction
+
+ mesh_scale_vecfield = scale_vecfield and not any(
+ os.path.basename(sample_name).startswith(x) for x in scale_vecfield_exclude_prefix)
+
+ if (use_vert_colors and not os.path.basename(mesh_filename) == 'gt.ply' and not os.path.basename(
+ mesh_filename) == 'pc.ply'):
+ # vcolor_filename = os.path.join(input_dir, os.path.dirname(sample_name), f'{vcolor_prefix}{os.path.basename(sample_name)}{vcolor_suffix}.npy')
+ vcolor_filename = os.path.splitext(mesh_filename)[0] + '_dist.npz'
+ vert_color_vals = np.load(vcolor_filename)['distances']
+ if np.isnan(vert_color_vals).any():
+ print('WARNING: some vertex color values are NaN! Setting them to zero.')
+ vert_color_vals[np.isnan(vert_color_vals)] = 0
+ if recompute_vcolor_range_each_mesh:
+ vcolor_max = np.percentile(vert_color_vals, 95)
+ vcolor_min = np.percentile(vert_color_vals, 5)
+ print(vert_color_vals.min())
+ print(vert_color_vals.max())
+ vert_color_vals = (vert_color_vals - vcolor_min) / (vcolor_max - vcolor_min)
+ vert_colors = eval_cmap(vert_color_vals, cmap_colors=cmap)
+ mix = 1.0
+ vert_colors = vert_colors * mix + (mesh_color / 255.0) * (1 - mix)
+ else:
+ vert_colors = (np.repeat([mesh_color], verts.shape[0], axis=0) / 255).astype('float32').clip(min=0.0,
+ max=1.0)
+
+ if use_vecfield and os.path.basename(mesh_filename) == 'pc.ply':
+ # # get rotated instances of the arrow mesh
+ # rotmats = np.concatenate([
+ # rotation_between_vectors(np.array([[0, 0, 1.0]]), vecfield_dirs),
+ # rotation_between_vectors(np.array([[0, 0, 1.0]]), -vecfield_dirs)], axis=0)
+ # vecfield_verts = np.matmul(np.expand_dims(rotmats, 1), np.expand_dims(vecfield_verts, 3)).squeeze(
+ # -1)
+ # rotation
+ vecfield_verts = (
+ np.expand_dims(vec_verts, axis=0) + np.expand_dims(verts, axis=1)).reshape(-1, 3)
+
+ # translation
+ vecfield_faces = (np.expand_dims(vec_faces, axis=0) + (
+ np.arange(verts.shape[0] * 2) * vec_verts.shape[0]).reshape(-1, 1, 1)).reshape(-1, 3)
+
+ vecfield_verts = vecfield_verts.tolist()
+ vecfield_faces = vecfield_faces.tolist()
+
+ verts = verts.tolist()
+
+ if not os.path.basename(mesh_filename) == 'pc.ply':
+ faces = faces.tolist()
+
+ vert_colors = vert_colors.tolist()
+
+ # create blender mesh for cuboids
+ mesh = bpy.data.meshes.new('object')
+
+ if not os.path.basename(mesh_filename) == 'pc.ply':
+ mesh.from_pydata(verts, [], faces)
+ mesh.validate()
+
+ mesh.vertex_colors.new(name='Col') # named 'Col' by default
+ mesh_vert_colors = mesh.vertex_colors['Col']
+
+ # wireframe_mesh = mesh.copy()
+ # wireframe_mesh.name = 'wireframe'
+
+ for poly in mesh.polygons:
+ for loop_index in poly.loop_indices:
+ loop_vert_index = mesh.loops[loop_index].vertex_index
+ if loop_vert_index < len(vert_colors):
+ mesh.vertex_colors['Col'].data[loop_index].color = vert_colors[loop_vert_index]
+
+ # create blender object for cuboids
+ obj = bpy.data.objects.new('object', mesh)
+ if not os.path.basename(mesh_filename) == 'pc.ply':
+ obj.data.materials.append(bpy.data.materials['sphere_material'])
+ scene.collection.objects.link(obj)
+ if turning_animation:
+ copy_animation_data(src_obj=scene.objects['turntable'], dst_obj=obj)
+
+ if use_vecfield and os.path.basename(mesh_filename) == 'pc.ply':
+ vecfield_mesh = bpy.data.meshes.new('vecfield')
+ vecfield_mesh.from_pydata(vecfield_verts, [], vecfield_faces)
+ vecfield_mesh.validate()
+
+ vecfield_mesh.vertex_colors.new(name='Col') # named 'Col' by default
+ mesh_vert_color = np.array((mesh_color / 255.0).tolist(), dtype=np.float32).clip(min=0.0, max=1.0)
+
+ for poly in vecfield_mesh.polygons:
+ for loop_index in poly.loop_indices:
+ loop_vert_index = vecfield_mesh.loops[loop_index].vertex_index
+ vecfield_mesh.vertex_colors['Col'].data[loop_index].color = mesh_vert_color
+
+ # create blender object for vector field
+ vecfield_obj = bpy.data.objects.new('vecfield', vecfield_mesh)
+ vecfield_obj.data.materials.append(bpy.data.materials['sphere_material'])
+ scene.collection.objects.link(vecfield_obj)
+ if turning_animation:
+ copy_animation_data(src_obj=scene.objects['turntable'], dst_obj=vecfield_obj)
+
+ # render scene
+ scene.render.image_settings.file_format = 'PNG'
+ # print(f'rendering to {scene.render.filepath}')
+
+ if use_vecfield == False and use_vert_colors == True and os.path.basename(mesh_filename) == 'pc.ply':
+ break
+
+ else:
+ scene.render.filepath = output_files[mesh_ind]
+ bpy.ops.render.render(write_still=True)
+
+ if clear or test:
+ break
+
+
+if __name__ == '__main__':
+ input_dir_par = '/home/lizeth/Downloads/for rendering/comp/'
+ output_dir_par = '/home/lizeth/Downloads/for rendering/rendered/'
+ input_dirs_datasets = [os.path.join(input_dir_par, d) for d in os.listdir(input_dir_par) if os.path.isdir(os.path.join(input_dir_par, d))]
+ output_dirs_datasets = [os.path.join(output_dir_par, d) for d in os.listdir(input_dir_par) if os.path.isdir(os.path.join(input_dir_par, d))]
+
+ # input_dir = '/home/lizeth/Downloads/for rendering/comp/abc/00014452_55263057b8f440a0bb50b260_trimesh_017/'
+ # output_dir = '/home/lizeth/Downloads/for rendering/rendered/abc/00014452_55263057b8f440a0bb50b260_trimesh_017/'
+ for input_dir_dataset, output_dir_dataset in zip(input_dirs_datasets, output_dirs_datasets):
+ input_dirs_meshes = [os.path.join(input_dir_dataset, d) for d in os.listdir(input_dir_dataset)
+ if os.path.isdir(os.path.join(input_dir_dataset, d))]
+ output_dirs_meshes = [os.path.join(output_dir_dataset, d) for d in os.listdir(input_dir_dataset)
+ if os.path.isdir(os.path.join(input_dir_dataset, d))]
+ for input_dir, output_dir in zip(input_dirs_meshes, output_dirs_meshes):
+ print('Rendering meshes in {} to {}'.format(input_dir, output_dir))
+ render_meshes(input_dir+'/', output_dir+'/')
diff --git a/ppsurf/source/make_comparison.py b/ppsurf/source/make_comparison.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5ffbb1f3bb59c05b2c7753aad0b05a77a8e7321
--- /dev/null
+++ b/ppsurf/source/make_comparison.py
@@ -0,0 +1,118 @@
+import os
+import argparse
+import sys
+
+import source.base.visualization
+import source.occupancy_data_module
+
+sys.path.append(os.path.abspath('.'))
+
+debug = False
+
+
+def parse_arguments(args=None):
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('--comp_name', type=str, default='abc_minimal', help='comp name')
+
+ parser.add_argument('--comp_dir', type=str, default='results/comp', help='folder for comparisons')
+ parser.add_argument('--data_dir', type=str, default='datasets/abc_minimal/03_meshes',
+ help='input folder (meshes)')
+ parser.add_argument('--testset', type=str, default='datasets/abc_minimal/testset.txt',
+ help='test set file name')
+ parser.add_argument('--results_dir', type=str, default='results',
+ help='output folder (reconstructions)')
+ parser.add_argument('--result_headers', type=str, nargs='+', default=[],
+ help='list of strings for comparison (human readable table headers)')
+ parser.add_argument('--result_paths', type=str, nargs='+', default=[],
+ help='list of strings for comparison (result path templates)')
+ parser.add_argument('--comp_mean_name', type=str, default='comp_mean',
+ help='file name for dataset means')
+ parser.add_argument('--html_name', type=str, default='comp_html',
+ help='file name for dataset means')
+
+ parser.add_argument('--workers', type=int, default=8,
+ help='number of data loading workers - 0 means same thread as main execution')
+
+ parser.add_argument('--dist_cut_off', type=float, default=0.05,
+ help='cutoff for color-coded distance visualization')
+
+ return parser.parse_args(args=args)
+
+
+def comparison_rec_mesh_template(args):
+ from source.base import evaluation
+ from itertools import chain
+ from source.occupancy_data_module import read_shape_list
+
+ comp_dir = os.path.join(args.comp_dir, args.comp_name)
+ os.makedirs(comp_dir, exist_ok=True)
+
+ shape_names = read_shape_list(os.path.join(args.data_dir, args.testset))
+ gt_meshes = [os.path.join(args.data_dir, '03_meshes', '{}.ply'.format(vs)) for vs in shape_names]
+
+ # quantitative comparison
+ report_path_templates = [os.path.join(r, '{}.xlsx') for r in args.result_paths]
+ results_per_shape_dict = evaluation.assemble_quantitative_comparison(
+ comp_output_dir=comp_dir, report_path_templates=report_path_templates)
+ cd_results = results_per_shape_dict['chamfer_distance'].transpose().tolist()
+ iou_results = results_per_shape_dict['iou'].transpose().tolist()
+ nc_results = results_per_shape_dict['normal_error'].transpose().tolist()
+
+ # assemble dataset means
+ # def _get_all_reports(results_dir, results_report_template):
+ # from pathlib import Path
+ # report_files = list(Path(results_dir).rglob(results_report_template))
+ # return report_files
+ test_report_comp = os.path.join(comp_dir, '{}.xlsx'.format(args.comp_mean_name))
+ cd_report_path = [os.path.join(r, 'chamfer_distance.xlsx') for r in args.result_paths]
+ f1_report_path = [os.path.join(r, 'f1.xlsx') for r in args.result_paths]
+ iou_report_path = [os.path.join(r, 'iou.xlsx') for r in args.result_paths]
+ nc_report_path = [os.path.join(r, 'normal_error.xlsx') for r in args.result_paths]
+ report_path_templates = [(cd_report_path[i], iou_report_path[i], f1_report_path[i], nc_report_path[i])
+ for i in range(len(args.result_paths))]
+ evaluation.make_dataset_comparison(results_reports=report_path_templates, output_file=test_report_comp)
+
+ # visualize chamfer distance as vertex colors
+ gt_meshes_bc = [gt_meshes] * len(args.result_paths)
+ gt_meshes_bc_flat = list(chain.from_iterable(gt_meshes_bc))
+ cd_meshes_out = [[os.path.join(comp_dir, res, 'mesh_cd_vis', '{}.ply'.format(s))
+ for s in shape_names] for res in args.result_headers]
+ cd_meshes_out_flat = list(chain.from_iterable(cd_meshes_out))
+ rec_paths = [os.path.join(res, 'meshes/{}.xyz.ply') for res in args.result_paths]
+ rec_meshes = [[res.format(s) for s in shape_names] for res in rec_paths]
+ rec_meshes = [[s if os.path.isfile(s) else s[:-4] + '.obj' for s in res] for res in rec_meshes] # if no ply, try obj
+ rec_meshes_flat = list(chain.from_iterable(rec_meshes))
+ source.base.visualization.visualize_chamfer_distance_pool(
+ rec_meshes=rec_meshes_flat, gt_meshes=gt_meshes_bc_flat, output_mesh_files=cd_meshes_out_flat,
+ min_vertex_count=10000, dist_cut_off=args.dist_cut_off, distance_batch_size=1000, num_processes=args.workers)
+
+ # render meshes
+ gt_renders_out = [os.path.join(comp_dir, 'mesh_gt_rend', '{}.png'.format(vs)) for vs in shape_names]
+ rec_renders_out = [[os.path.join(comp_dir, res, 'mesh_rend', '{}.png'.format(s))
+ for s in shape_names] for res in args.result_headers]
+ cd_vis_renders_out = [[os.path.join(comp_dir, res, 'cd_vis_rend', '{}.png'.format(s))
+ for s in shape_names] for res in args.result_headers]
+ cd_vis_renders_out_flat = list(chain.from_iterable(cd_vis_renders_out))
+ rec_renders_flat = list(chain.from_iterable(rec_renders_out))
+ pc = [os.path.join(args.data_dir, '04_pts_vis', '{}.xyz.ply'.format(vs)) for vs in shape_names]
+ pc_renders_out = [os.path.join(comp_dir, 'pc_rend', '{}.png'.format(vs)) for vs in shape_names]
+ all_meshes_in = rec_meshes_flat + gt_meshes + cd_meshes_out_flat + pc
+ all_renders_out = rec_renders_flat + gt_renders_out + cd_vis_renders_out_flat + pc_renders_out
+ source.base.visualization.render_meshes(all_meshes_in, all_renders_out, workers=args.workers)
+
+ # qualitative comparison as a HTML table
+ report_file_out = os.path.join(comp_dir, args.html_name + '.html')
+ evaluation.make_html_report(report_file_out=report_file_out, comp_name=args.comp_name,
+ pc_renders=pc_renders_out, gt_renders=gt_renders_out,
+ cd_vis_renders=cd_vis_renders_out, dist_cut_off=args.dist_cut_off,
+ metrics_cd=cd_results, metrics_iou=iou_results, metrics_nc=nc_results)
+
+
+def main(argv=None):
+ args = parse_arguments(argv)
+ comparison_rec_mesh_template(args=args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/ppsurf/source/make_evaluation.py b/ppsurf/source/make_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..05b3ae4ab736911e153ebef9210fe07e73dfa447
--- /dev/null
+++ b/ppsurf/source/make_evaluation.py
@@ -0,0 +1,103 @@
+import os
+import argparse
+import sys
+
+sys.path.append(os.path.abspath('.'))
+
+
+def parse_arguments(args=None):
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('--name', type=str, default='ppsurf', help='name')
+
+ parser.add_argument('--workers', type=int, default=8,
+ help='number of data loading workers - 0 means same thread as main execution')
+
+ parser.add_argument('--results_dir', type=str, default='results',
+ help='output folder (reconstructions)')
+ parser.add_argument('--data_dir', type=str, default='datasets/abc_minimal/03_meshes',
+ help='input folder (meshes)')
+ parser.add_argument('--testset', type=str, default='datasets/abc_minimal/testset.txt',
+ help='test set file name')
+
+ parser.add_argument('--num_samples', type=int, default=10000,
+ help='number of samples for metrics')
+
+ return parser.parse_args(args=args)
+
+
+def make_evaluation(args):
+ from source.base import evaluation
+ from source.occupancy_data_module import read_shape_list
+
+ model_results_rec_dir = os.path.join(args.results_dir, args.name, os.path.basename(args.data_dir))
+ shape_names = read_shape_list(os.path.join(args.data_dir, args.testset))
+ gt_meshes_dir = os.path.join(args.data_dir, '03_meshes')
+ if not os.path.exists(gt_meshes_dir):
+ print('Warning: {} not found. Skipping evaluation.'.format(gt_meshes_dir))
+ else:
+ gt_meshes = [os.path.join(gt_meshes_dir, '{}.ply'.format(vs)) for vs in shape_names]
+ os.makedirs(model_results_rec_dir, exist_ok=True)
+ result_headers = [args.name]
+ result_file_templates = [os.path.join(model_results_rec_dir, 'meshes/{}.xyz.ply')]
+ _ = evaluation.make_quantitative_comparison(
+ shape_names=shape_names, gt_mesh_files=gt_meshes,
+ result_headers=result_headers, result_file_templates=result_file_templates,
+ comp_output_dir=model_results_rec_dir, num_processes=args.workers, num_samples=args.num_samples)
+
+
+def main(argv=None):
+ args = parse_arguments(argv)
+ make_evaluation(args=args)
+
+
+if __name__ == '__main__':
+ # main()
+
+ # test
+ model_names = [
+ 'pgr',
+ 'neural_imls',
+ 'sap_optim',
+ 'sap',
+ 'p2s',
+ 'poco Pts_gen_sub3k_iter10',
+ 'ppsurf_qpoints',
+ 'ppsurf_merge_sum',
+ 'ppsurf_vanilla_zeros_local',
+ 'ppsurf_vanilla_zeros_global',
+ 'ppsurf_10nn',
+ 'ppsurf_25nn',
+ 'ppsurf_50nn',
+ 'ppsurf_100nn',
+ 'ppsurf_200nn',
+ ]
+ dataset_names = [
+ 'abc',
+ 'abc_extra_noisy',
+ 'abc_noisefree',
+ # 'real_world',
+ 'famous_original',
+ 'famous_noisefree',
+ 'famous_sparse',
+ 'famous_dense',
+ 'famous_extra_noisy',
+ 'thingi10k_scans_original',
+ 'thingi10k_scans_noisefree',
+ 'thingi10k_scans_sparse',
+ 'thingi10k_scans_dense',
+ 'thingi10k_scans_extra_noisy'
+ ]
+ for dataset_name in dataset_names:
+ for model_name in model_names:
+ print('Evaluating {} on {}'.format(model_name, dataset_name))
+ params = [
+ '--name', model_name,
+ '--workers', '15',
+ # '--workers', '0',
+ '--results_dir', 'results',
+ '--data_dir', 'datasets/{}'.format(dataset_name),
+ '--testset', 'testset.txt',
+ '--num_samples', '100000',
+ ]
+ main(argv=params)
diff --git a/ppsurf/source/occupancy_data_module.py b/ppsurf/source/occupancy_data_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..244eb5905f94c5a731cd8b1f33d9abdbf2141194
--- /dev/null
+++ b/ppsurf/source/occupancy_data_module.py
@@ -0,0 +1,253 @@
+import os
+import typing
+from abc import ABC, abstractmethod
+
+import torch.utils.data as data
+import numpy as np
+
+from pytorch_lightning import LightningDataModule
+
+import source.base
+import source.base.math
+from source.base.proximity import make_kdtree
+
+
+# Adapted from POCO: https://github.com/valeoai/POCO
+# which is published under Apache 2.0: https://github.com/valeoai/POCO/blob/main/LICENSE
+
+
+def in_file_is_dataset(in_file: str):
+ return os.path.splitext(in_file)[1].lower() == '.txt'
+
+
+def get_dataset_dir(in_file: str):
+ dataset_dir = os.path.dirname(in_file)
+ return dataset_dir
+
+
+def get_dataset_name(in_file: str):
+ dataset_dir = get_dataset_dir(in_file)
+ dataset_name = os.path.basename(dataset_dir)
+ return dataset_name
+
+
+def get_meshes_dir(in_file: str):
+ dataset_dir = get_dataset_dir(in_file)
+ meshes_dir = os.path.join(dataset_dir, '03_meshes')
+ return meshes_dir
+
+
+def get_pc_dir(in_file: str):
+ dataset_dir = get_dataset_dir(in_file)
+ pc_dir = os.path.join(dataset_dir, '04_pts_vis')
+ return pc_dir
+
+
+def get_pc_file(in_file, shape_name):
+ if in_file_is_dataset(in_file):
+ dataset_dir = get_dataset_dir(in_file)
+ pc_file = os.path.join(dataset_dir, '04_pts_vis', shape_name + '.xyz.ply')
+ return pc_file
+ else:
+ return in_file
+
+
+def get_training_data_dir(in_file: str):
+ dataset_dir = get_dataset_dir(in_file)
+ query_pts_dir = os.path.join(dataset_dir, '05_query_pts')
+ query_dist_dir = os.path.join(dataset_dir, '05_query_dist')
+ return query_pts_dir, query_dist_dir
+
+
+def get_set_files(in_file: str):
+ if in_file_is_dataset(in_file):
+ train_set = os.path.join(os.path.dirname(in_file), 'trainset.txt')
+ val_set = os.path.join(os.path.dirname(in_file), 'valset.txt')
+ test_set = os.path.join(os.path.dirname(in_file), 'testset.txt')
+ else:
+ train_set = in_file
+ val_set = in_file
+ test_set = in_file
+ return train_set, val_set, test_set
+
+
+def get_results_dir(out_dir: str, name: str, in_file: str):
+ dataset_name = get_dataset_name(in_file)
+ model_results_rec_dir = os.path.join(out_dir, name, dataset_name)
+ return model_results_rec_dir
+
+
+def read_shape_list(shape_list_file: str):
+ with open(shape_list_file) as f:
+ shape_names = f.readlines()
+ shape_names = [x.strip() for x in shape_names]
+ shape_names = list(filter(None, shape_names))
+ return shape_names
+
+
+class OccupancyDataModule(LightningDataModule, ABC):
+
+ def __init__(self, use_ddp, workers, in_file, patches_per_shape: typing.Optional[int],
+ do_data_augmentation: bool, batch_size: int):
+ super(OccupancyDataModule, self).__init__()
+ self.use_ddp = use_ddp
+ self.workers = workers
+ self.in_file = in_file
+ self.trainset, self.valset, self.testset = get_set_files(in_file)
+ self.patches_per_shape = patches_per_shape
+ self.do_data_augmentation = do_data_augmentation
+ self.batch_size = batch_size
+
+ @staticmethod
+ def seed_train_worker(worker_id):
+ import random
+ import torch
+ worker_seed = torch.initial_seed() % 2 ** 32 + worker_id
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
+
+ @abstractmethod
+ def make_dataset(
+ self, in_file: typing.Union[str, list], reconstruction: bool, patches_per_shape: typing.Optional[int],
+ do_data_augmentation: bool):
+ pass
+
+ def make_datasampler(self, dataset, shuffle=False):
+ from torch.cuda import device_count
+ if bool(self.use_ddp) and device_count() > 1:
+ from torch.utils.data.distributed import DistributedSampler
+ datasampler = DistributedSampler(
+ dataset, num_replicas=None, rank=None,
+ shuffle=shuffle, seed=0, drop_last=False)
+ else:
+ datasampler = None
+ return datasampler
+
+ def make_dataloader(self, dataset, data_sampler, batch_size: int, shuffle: bool = False):
+
+ dataloader = data.DataLoader(
+ dataset,
+ sampler=data_sampler,
+ batch_size=batch_size,
+ num_workers=int(self.workers),
+ persistent_workers=True if int(self.workers) > 0 else False,
+ pin_memory=True,
+ worker_init_fn=OccupancyDataModule.seed_train_worker,
+ shuffle=shuffle)
+ return dataloader
+
+ def train_dataloader(self):
+ dataset = self.make_dataset(in_file=self.trainset, reconstruction=False,
+ patches_per_shape=self.patches_per_shape,
+ do_data_augmentation=self.do_data_augmentation)
+ data_sampler = self.make_datasampler(dataset=dataset, shuffle=True)
+ dataloader = self.make_dataloader(dataset=dataset, data_sampler=data_sampler,
+ batch_size=self.batch_size, shuffle=data_sampler is None)
+ return dataloader
+
+ def val_dataloader(self):
+ dataset = self.make_dataset(in_file=self.valset, reconstruction=False,
+ patches_per_shape=self.patches_per_shape, do_data_augmentation=False)
+ data_sampler = self.make_datasampler(dataset=dataset, shuffle=False)
+ dataloader = self.make_dataloader(dataset=dataset, data_sampler=data_sampler,
+ batch_size=self.batch_size)
+ return dataloader
+
+ def test_dataloader(self):
+ batch_size = 1
+ dataset = self.make_dataset(in_file=self.testset, reconstruction=False,
+ patches_per_shape=None, do_data_augmentation=False)
+ data_sampler = None
+ dataloader = self.make_dataloader(dataset=dataset, data_sampler=data_sampler,
+ batch_size=batch_size)
+ return dataloader
+
+ def predict_dataloader(self):
+ batch_size = 1
+ dataset = self.make_dataset(in_file=self.testset, reconstruction=True,
+ patches_per_shape=None, do_data_augmentation=False)
+ data_sampler = None
+ dataloader = self.make_dataloader(dataset=dataset, data_sampler=data_sampler,
+ batch_size=batch_size)
+ return dataloader
+
+ @staticmethod
+ def load_pts(pts_file: str):
+ # Supported file formats are:
+ # - PLY, STL, OBJ and other mesh files loaded by [trimesh](https://github.com/mikedh/trimesh).
+ # - XYZ as whitespace-separated text file, read by [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.loadtxt.html).
+ # Load first 3 columns as XYZ coordinates. All other columns will be ignored.
+ # - NPY and NPZ, read by [NumPy](https://numpy.org/doc/stable/reference/generated/numpy.load.html).
+ # NPZ assumes default key='arr_0'. All columns after the first 3 columns will be ignored.
+ # - LAS and LAZ (version 1.0-1.4), COPC and CRS loaded by [Laspy](https://github.com/laspy/laspy).
+ # You may want to sub-sample large point clouds to ~250k points to avoid speed and memory issues.
+ # For detailed reconstruction, you'll need to extract parts of large point clouds.
+
+ import os
+
+ file_name, file_ext = os.path.splitext(pts_file)
+ file_ext = file_ext.lower()
+ if file_ext == '.npy':
+ pts = np.load(pts_file)
+ elif file_ext == '.npy':
+ arrs = np.load(pts_file)
+ pts = arrs['arr_0']
+ elif file_ext == '.xyz':
+ from source.base.point_cloud import load_xyz
+ pts = load_xyz(pts_file)
+ elif file_ext in ['.stl', '.ply', '.obj', 'gltf', '.glb', '.dae', '.off', '.ctm', '.3dxml']:
+ import trimesh
+ trimesh_obj: typing.Union[trimesh.Scene, trimesh.Trimesh] = trimesh.load_mesh(file_obj=pts_file)
+ if isinstance(trimesh_obj, trimesh.Scene):
+ mesh: trimesh.Trimesh = trimesh_obj.geometry.items()[0]
+ elif isinstance(trimesh_obj, trimesh.Trimesh):
+ mesh: trimesh.Trimesh = trimesh_obj
+ elif isinstance(trimesh_obj, trimesh.PointCloud):
+ mesh: trimesh.Trimesh = trimesh_obj
+ else:
+ raise ValueError('Unknown trimesh object type: {}'.format(type(trimesh_obj)))
+ pts = np.array(mesh.vertices)
+ elif file_ext in ['.las', '.laz', '.copc', '.crs']:
+ import laspy
+ las = laspy.read(pts_file)
+ pts = las.xyz
+ else:
+ raise ValueError('Unknown point cloud type: {}'.format(pts_file))
+ return pts
+
+ @staticmethod
+ def pre_process_pts(pts: np.ndarray):
+ if pts.shape[1] > 3:
+ normals = source.base.math.normalize_vectors(pts[:, 3:6])
+ pts = pts[:, 0:3]
+ else:
+ normals = np.zeros_like(pts)
+ return pts, normals
+
+ @staticmethod
+ def load_shape_data_pc(in_file, padding_factor, shape_name: str, normalize=False, return_kdtree=True):
+ from source.base import container
+
+ pts_file = get_pc_file(in_file, shape_name)
+ pts_np = OccupancyDataModule.load_pts(pts_file=pts_file)
+ pts_np, normals_np = OccupancyDataModule.pre_process_pts(pts=pts_np)
+
+ if normalize:
+ bb_center, scale = source.base.math.get_points_normalization_info(
+ pts=pts_np, padding_factor=padding_factor)
+ pts_np = source.base.math.normalize_points_with_info(pts=pts_np, bb_center=bb_center, scale=scale)
+
+ # convert only after normalization
+ if pts_np.dtype != np.float32:
+ pts_np = pts_np.astype(np.float32)
+
+ # debug output
+ from source.base.point_cloud import write_ply
+ write_ply('debug/pts_ms.ply', pts_np, normals_np)
+
+ shape_data = {'pts_ms': pts_np, 'normals_ms': normals_np, 'pc_file_in': pts_file}
+ if return_kdtree:
+ kdtree = make_kdtree(pts_np)
+ shape_data['kdtree'] = kdtree
+
+ return shape_data
diff --git a/ppsurf/source/poco_data_loader.py b/ppsurf/source/poco_data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..01503bc32ef7e80777f93fed4c0aaf961b49ede0
--- /dev/null
+++ b/ppsurf/source/poco_data_loader.py
@@ -0,0 +1,412 @@
+import os
+import os.path
+import typing
+
+# avoid heavy imports at the top level, so they don't get imported in the workers
+if typing.TYPE_CHECKING:
+ import torch
+
+import torch.utils.data as torch_data
+import numpy as np
+import trimesh
+from overrides import EnforceOverrides
+
+from source.occupancy_data_module import OccupancyDataModule, get_training_data_dir
+from source.base.container import dict_np_to_torch
+
+
+# Adapted from POCO: https://github.com/valeoai/POCO
+# which is published under Apache 2.0: https://github.com/valeoai/POCO/blob/main/LICENSE
+
+
+class PocoDataModule(OccupancyDataModule):
+
+ def __init__(self, in_file, workers, use_ddp, padding_factor, seed, manifold_points,
+ patches_per_shape: typing.Optional[int], do_data_augmentation: bool, batch_size: int):
+ super(PocoDataModule, self).__init__(
+ use_ddp=use_ddp, workers=workers, in_file=in_file, patches_per_shape=patches_per_shape,
+ do_data_augmentation=do_data_augmentation, batch_size=batch_size)
+ self.in_file = in_file
+ self.padding_factor = padding_factor
+ self.seed = seed
+ self.manifold_points = manifold_points
+ self.patches_per_shape = patches_per_shape
+ self.do_data_augmentation = do_data_augmentation
+
+ def make_dataset(
+ self, in_file: typing.Union[str, list], reconstruction: bool, patches_per_shape: typing.Optional[int],
+ do_data_augmentation: bool):
+ if reconstruction:
+ dataset = PocoReconstructionDataset(
+ in_file=in_file,
+ padding_factor=self.padding_factor,
+ seed=self.seed,
+ use_ddp=self.use_ddp,
+ )
+ else:
+ dataset = PocoDataset(
+ in_file=in_file,
+ padding_factor=self.padding_factor,
+ seed=self.seed,
+ patches_per_shape=self.patches_per_shape,
+ do_data_augmentation=do_data_augmentation,
+ use_ddp=self.use_ddp,
+ manifold_points=self.manifold_points,
+ )
+ return dataset
+
+
+def sampling_quantized(pts_batch, ratio=None, n_support=None, support_points=None, support_points_ids=None):
+ # TODO: try without importing torch
+ import math
+ import torch
+ from torch_geometric.transforms import RandomRotate
+ from torch_geometric.data import Data
+ from torch_geometric.nn import voxel_grid
+ from torch_geometric.nn.pool.consecutive import consecutive_cluster
+ from source.base.nn import batch_gather
+
+ if support_points is not None:
+ return support_points, support_points_ids
+
+ assert ((ratio is None) != (n_support is None))
+
+ if ratio is not None:
+ support_point_number = max(1, int(pts_batch.shape[2] * ratio))
+ else:
+ support_point_number = n_support
+
+ if support_point_number == pts_batch.shape[2]:
+ support_points_ids = torch.arange(pts_batch.shape[2], dtype=torch.long, device=pts_batch.device)
+ support_points_ids = support_points_ids.unsqueeze(0).expand(pts_batch.shape[0], pts_batch.shape[2])
+ return pts_batch, support_points_ids
+ elif 0 < support_point_number < pts_batch.shape[2]:
+
+ # voxel_size
+ maxi, _ = torch.max(pts_batch, dim=2)
+ mini, _ = torch.min(pts_batch, dim=2)
+ vox_size = (maxi - mini).norm(2, dim=1) / math.sqrt(support_point_number)
+
+ rot_x = RandomRotate(180, axis=0)
+ rot_y = RandomRotate(180, axis=1)
+ rot_z = RandomRotate(180, axis=2)
+
+ support_points_ids = []
+ for i in range(pts_batch.shape[0]):
+ pts = pts_batch[i].clone().transpose(0, 1)
+ ids = torch.arange(pts.shape[0], device=pts.device)
+ sampled_count = 0
+ sampled = []
+ vox = vox_size[i]
+ while True:
+ # TODO: optimize to one call to one linear transformation
+ pts_rot = rot_z(rot_y(rot_x(Data(pos=pts)))).pos.to(pts.dtype)
+
+ batch = torch.zeros(pts_rot.shape[0], device=pts.device, dtype=pts.dtype)
+ c = voxel_grid(pts_rot, batch=batch, size=vox)
+ _, perm = consecutive_cluster(c)
+
+ if sampled_count + perm.shape[0] < support_point_number:
+ sampled.append(ids[perm])
+ sampled_count += perm.shape[0]
+
+ tmp = torch.ones_like(ids)
+ tmp[perm] = 0
+ tmp = (tmp > 0)
+ pts = pts[tmp]
+ ids = ids[tmp]
+ vox = vox / 2
+ # pts = pts[perm]
+ # ids = ids[perm]
+ else:
+ n_to_select = support_point_number - sampled_count
+ perm = perm[torch.randperm(perm.shape[0])[:n_to_select]]
+ sampled.append(ids[perm])
+ break
+ sampled = torch.cat(sampled)
+ support_points_ids.append(sampled)
+
+ support_points_ids = torch.stack(support_points_ids, dim=0)
+ support_points_ids = support_points_ids.to(pts_batch.device)
+ support_points = batch_gather(pts_batch, dim=2, index=support_points_ids)
+ return support_points, support_points_ids
+ else:
+ raise ValueError(f'Search Quantized - ratio value error {ratio} should be in ]0,1]')
+
+
+def get_fkaconv_ids(data: typing.Dict[str, 'torch.Tensor'], segmentation: bool = True) \
+ -> typing.Dict[str, 'torch.Tensor']:
+ from source.poco_utils import knn
+
+ pts = data['pts'].clone()
+
+ add_batch_dimension = False
+ if len(pts.shape) == 2:
+ pts = pts.unsqueeze(0)
+ add_batch_dimension = True
+
+ support1, _ = sampling_quantized(pts, 0.25)
+ support2, _ = sampling_quantized(support1, 0.25)
+ support3, _ = sampling_quantized(support2, 0.25)
+ support4, _ = sampling_quantized(support3, 0.25)
+
+ # compute the ids
+ ret_data = {}
+ ids00 = knn(pts, pts, 16)
+ ids01 = knn(pts, support1, 16)
+ ids11 = knn(support1, support1, 16)
+ ids12 = knn(support1, support2, 16)
+ ids22 = knn(support2, support2, 16)
+ ids23 = knn(support2, support3, 16)
+ ids33 = knn(support3, support3, 16)
+ ids34 = knn(support3, support4, 16)
+ ids44 = knn(support4, support4, 16)
+ if segmentation:
+ ids43 = knn(support4, support3, 1)
+ ids32 = knn(support3, support2, 1)
+ ids21 = knn(support2, support1, 1)
+ ids10 = knn(support1, pts, 1)
+ if add_batch_dimension:
+ ids43 = ids43.squeeze(0)
+ ids32 = ids32.squeeze(0)
+ ids21 = ids21.squeeze(0)
+ ids10 = ids10.squeeze(0)
+
+ ret_data['ids43'] = ids43
+ ret_data['ids32'] = ids32
+ ret_data['ids21'] = ids21
+ ret_data['ids10'] = ids10
+
+ if add_batch_dimension:
+ support1 = support1.squeeze(0)
+ support2 = support2.squeeze(0)
+ support3 = support3.squeeze(0)
+ support4 = support4.squeeze(0)
+ ids00 = ids00.squeeze(0)
+ ids01 = ids01.squeeze(0)
+ ids11 = ids11.squeeze(0)
+ ids12 = ids12.squeeze(0)
+ ids22 = ids22.squeeze(0)
+ ids23 = ids23.squeeze(0)
+ ids33 = ids33.squeeze(0)
+ ids34 = ids34.squeeze(0)
+ ids44 = ids44.squeeze(0)
+
+ ret_data['support1'] = support1
+ ret_data['support2'] = support2
+ ret_data['support3'] = support3
+ ret_data['support4'] = support4
+
+ ret_data['ids00'] = ids00
+ ret_data['ids01'] = ids01
+ ret_data['ids11'] = ids11
+ ret_data['ids12'] = ids12
+ ret_data['ids22'] = ids22
+ ret_data['ids23'] = ids23
+ ret_data['ids33'] = ids33
+ ret_data['ids34'] = ids34
+ ret_data['ids44'] = ids44
+ return ret_data
+
+
+def get_proj_ids(data: typing.Dict[str, 'torch.Tensor'], k: int) -> typing.Dict[str, 'torch.Tensor']:
+ from source.poco_utils import knn
+
+ pts = data['pts']
+ pts_query = data['pts_query']
+
+ add_batch_dimension_pos = False
+ if len(pts.shape) == 2:
+ pts = pts.unsqueeze(0)
+ add_batch_dimension_pos = True
+
+ add_batch_dimension_non_manifold = False
+ if len(pts_query.shape) == 2:
+ pts_query = pts_query.unsqueeze(0)
+ add_batch_dimension_non_manifold = True
+
+ if pts.shape[1] != 3:
+ pts = pts.transpose(1, 2)
+
+ if pts_query.shape[1] != 3:
+ pts_query = pts_query.transpose(1, 2)
+
+ indices = knn(pts, pts_query, k, -1)
+
+ if add_batch_dimension_non_manifold or add_batch_dimension_pos:
+ indices = indices.squeeze(0)
+
+ ret_data = {'proj_ids': indices}
+ return ret_data
+
+
+def get_data_poco(batch_data: dict):
+ import torch
+
+ fkaconv_data = {
+ 'pts': torch.transpose(batch_data['pts_ms'], -1, -2),
+ 'pts_query': torch.transpose(batch_data['pts_query_ms'], -1, -2),
+ }
+
+ if 'imp_surf_dist_ms' in batch_data.keys():
+ occ_sign = torch.sign(batch_data['imp_surf_dist_ms'])
+ occ = torch.zeros_like(occ_sign, dtype=torch.int64)
+ occ[occ_sign > 0.0] = 1
+ fkaconv_data['occ'] = occ
+ else:
+ fkaconv_data['occ'] = torch.zeros(fkaconv_data['pts_query'].shape[:1])
+
+ with torch.no_grad():
+ net_data = get_fkaconv_ids(fkaconv_data)
+ proj_data = get_proj_ids(fkaconv_data, k=64) # TODO: put k in param
+ net_data['proj_ids'] = proj_data['proj_ids']
+
+ # need points also for poco ids
+ for k in fkaconv_data.keys():
+ batch_data[k] = fkaconv_data[k]
+ for k in net_data.keys():
+ batch_data[k] = net_data[k]
+
+ return batch_data
+
+
+class PocoDataset(torch_data.Dataset, EnforceOverrides):
+
+ def __init__(self, in_file: str, padding_factor: float, seed, use_ddp: bool, manifold_points: typing.Optional[int],
+ patches_per_shape: typing.Optional[int], do_data_augmentation=True):
+
+ super(PocoDataset, self).__init__()
+
+ self.in_file = in_file
+ self.seed = seed
+ self.patches_per_shape = patches_per_shape
+ self.do_data_augmentation = do_data_augmentation
+ self.padding_factor = padding_factor
+ self.use_ddp = use_ddp
+ self.manifold_points = manifold_points
+
+ # initialize rng for picking points in the local subsample of a patch
+ if self.seed is None:
+ self.seed = np.random.random_integers(0, 2 ** 32 - 1, 1)[0]
+
+ from torch.cuda import device_count
+ if bool(self.use_ddp) and device_count() > 1:
+ import torch.distributed as dist
+ if not dist.is_available():
+ raise RuntimeError('Requires distributed package to be available')
+ rank = dist.get_rank()
+ self.seed += rank
+ self.rng = np.random.RandomState(self.seed)
+
+ # get all shape names in the dataset
+ if isinstance(self.in_file, str):
+ # assume .txt files contain a list of shapes
+ if os.path.splitext(self.in_file)[1].lower() == '.txt':
+ self.shape_names = []
+ with open(os.path.join(in_file)) as f:
+ self.shape_names = f.readlines()
+ self.shape_names = [x.strip() for x in self.shape_names]
+ self.shape_names = list(filter(None, self.shape_names))
+ else: # all other single files are just one shape to be reconstructed
+ self.shape_names = [self.in_file]
+ else:
+ raise NotImplementedError()
+
+ def __len__(self):
+ return len(self.shape_names)
+
+ def augment_shape(self, shape_data: dict, rand_rot: np.ndarray) -> dict:
+ import trimesh.transformations as trafo
+
+ def rot_arr(arr, rot):
+ return trafo.transform_points(arr, rot).astype(np.float32)
+
+ shape_data['pts_ms'] = rot_arr(shape_data['pts_ms'], rand_rot)
+ shape_data['normals_ms'] = rot_arr(shape_data['normals_ms'], rand_rot)
+ shape_data['pts_query_ms'] = rot_arr(shape_data['pts_query_ms'], rand_rot)
+ return shape_data
+
+ # returns a patch centered at the point with the given global index
+ # and the ground truth normal the patch center
+ def __getitem__(self, shape_id):
+ shape_data, pts_ms_raw = self.load_shape_by_index(shape_id, return_kdtree=False)
+
+ if self.do_data_augmentation:
+ # self.rng.seed(42) # always pick the same points for debugging
+ rand_rot = trimesh.transformations.random_rotation_matrix(self.rng.rand(3))
+ shape_data = self.augment_shape(shape_data, rand_rot)
+
+ shape_data = dict_np_to_torch(shape_data)
+ shape_data = get_data_poco(shape_data)
+ return shape_data
+
+ # load shape from a given shape index
+ def load_shape_by_index(self, shape_ind, return_kdtree=True):
+ # assume that datasets are already normalized
+ from source.occupancy_data_module import in_file_is_dataset
+ normalize = not in_file_is_dataset(self.in_file)
+
+ shape_data = OccupancyDataModule.load_shape_data_pc(
+ in_file=self.in_file, padding_factor=self.padding_factor,
+ shape_name=self.shape_names[shape_ind], normalize=normalize, return_kdtree=return_kdtree)
+ pts_ms_raw = shape_data['pts_ms']
+
+ def sub_sample_point_cloud(pts: np.ndarray, normals: np.ndarray, num_target_pts: int):
+ if num_target_pts is None:
+ return pts, normals
+ replace = True if pts.shape[0] < num_target_pts else False
+ choice_ids = self.rng.choice(np.arange(pts.shape[0]), size=num_target_pts, replace=replace)
+ return pts[choice_ids], normals[choice_ids]
+
+ pts_sub_sample, normals_sub_sample = sub_sample_point_cloud(
+ pts=shape_data['pts_ms'], normals=shape_data['normals_ms'], num_target_pts=self.manifold_points)
+ shape_data['pts_ms'] = pts_sub_sample
+ shape_data['normals_ms'] = normals_sub_sample
+
+ query_pts_dir, query_dist_dir = get_training_data_dir(self.in_file)
+ imp_surf_query_filename = os.path.join(query_pts_dir, self.shape_names[shape_ind] + '.ply.npy')
+ imp_surf_dist_filename = os.path.join(query_dist_dir, self.shape_names[shape_ind] + '.ply.npy')
+
+ if os.path.isfile(imp_surf_query_filename): # if GT data exists
+ pts_query_ms = np.load(imp_surf_query_filename)
+ if pts_query_ms.dtype != np.float32:
+ pts_query_ms = pts_query_ms.astype(np.float32)
+
+ imp_surf_dist_ms = np.load(imp_surf_dist_filename)
+ if imp_surf_dist_ms.dtype != np.float32:
+ imp_surf_dist_ms = imp_surf_dist_ms.astype(np.float32)
+ else: # if no GT data
+ pts_query_ms = np.empty((0, 3), dtype=np.float32)
+ imp_surf_dist_ms = np.empty((0, 3), dtype=np.float32)
+
+ # DDP sampler can't handle patches_per_shape, so we do it here
+ from torch.cuda import device_count
+ if bool(self.use_ddp) and device_count() > 1 and \
+ self.patches_per_shape is not None and self.patches_per_shape > 0:
+ query_pts_ids = self.rng.choice(np.arange(pts_query_ms.shape[0]), self.patches_per_shape)
+ pts_query_ms = pts_query_ms[query_pts_ids]
+ imp_surf_dist_ms = imp_surf_dist_ms[query_pts_ids]
+
+ shape_data['pts_query_ms'] = pts_query_ms
+ shape_data['imp_surf_dist_ms'] = imp_surf_dist_ms
+ shape_data['shape_id'] = shape_ind
+
+ # print('PID={}: loaded shape {}'.format(os.getpid(), shape_id)) # debug multi-processing cache
+
+ return shape_data, pts_ms_raw
+
+
+class PocoReconstructionDataset(PocoDataset):
+
+ def __init__(self, in_file, padding_factor, seed, use_ddp):
+ super(PocoReconstructionDataset, self).__init__(
+ in_file=in_file, padding_factor=padding_factor, seed=seed,
+ use_ddp=use_ddp, manifold_points=None,
+ patches_per_shape=None, do_data_augmentation=False)
+
+ # returns a patch centered at the point with the given global index
+ # and the ground truth normal the patch center
+ def __getitem__(self, shape_id):
+ shape_data, pts_ms_raw = self.load_shape_by_index(shape_id, return_kdtree=False)
+ shape_data = dict_np_to_torch(shape_data)
+ return shape_data
diff --git a/ppsurf/source/poco_model.py b/ppsurf/source/poco_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e639d7113685475ddc78231a0c3d9283adae5b3
--- /dev/null
+++ b/ppsurf/source/poco_model.py
@@ -0,0 +1,419 @@
+import typing
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+import pytorch_lightning as pl
+
+from source.base.nn import FKAConvNetwork, batch_gather, count_parameters
+from source.base import fs
+from source.base.metrics import compare_predictions_binary_tensors
+
+from source.poco_data_loader import get_proj_ids, get_data_poco
+
+
+# Adapted from POCO: https://github.com/valeoai/POCO
+# which is published under Apache 2.0: https://github.com/valeoai/POCO/blob/main/LICENSE
+
+class PocoModel(pl.LightningModule):
+
+ def __init__(self, output_names, in_channels, out_channels, k,
+ lambda_l1, debug, in_file, results_dir, padding_factor, name, network_latent_size,
+ gen_subsample_manifold_iter, gen_subsample_manifold, gen_resolution_global,
+ rec_batch_size, gen_refine_iter, workers):
+ super().__init__()
+
+ self.output_names = output_names
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.k = k
+
+ self.lambda_l1 = lambda_l1
+ self.network_latent_size = network_latent_size
+ self.gen_subsample_manifold_iter = gen_subsample_manifold_iter
+ self.gen_subsample_manifold = gen_subsample_manifold
+ self.gen_resolution_global = gen_resolution_global
+ self.gen_resolution_metric = None
+ self.num_pts_local = None
+ self.rec_batch_size = rec_batch_size
+ self.gen_refine_iter = gen_refine_iter
+ self.workers = workers
+
+ self.in_file = in_file
+ self.results_dir = results_dir
+ self.padding_factor = padding_factor
+
+ self.debug = debug
+ self.show_unused_params = debug
+ self.name = name
+
+ self.network = PocoNetwork(in_channels=self.in_channels, latent_size=self.network_latent_size,
+ out_channels=self.out_channels, k=self.k)
+
+ self.test_step_outputs = []
+
+ def on_after_backward(self):
+ # for finding disconnected parts
+ # DDP won't run by default if such parameters exist
+ # find_unused_parameters makes it run but is slower
+ if self.show_unused_params:
+ for name, param in self.named_parameters():
+ if param.grad is None:
+ print('Unused param {}'.format(name))
+ self.show_unused_params = False # print once is enough
+
+ def get_prog_bar(self):
+ from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar
+ prog_bar = self.trainer.progress_bar_callback
+ if prog_bar is not None and not isinstance(prog_bar, TQDMProgressBar):
+ print('Warning: invalid progress bar type: {}'.format(type(prog_bar)))
+ else:
+ prog_bar = typing.cast(typing.Optional[TQDMProgressBar], prog_bar)
+ return prog_bar
+
+ def compute_loss(self, pred, batch_data):
+ loss_components = []
+
+ occ_target = batch_data['occ']
+ occ_loss = nn.functional.cross_entropy(input=pred, target=occ_target, reduction='none')
+ loss_components.append(occ_loss)
+
+ loss_components_mean = [torch.mean(l) for l in loss_components]
+
+ loss_components = torch.stack(loss_components)
+ loss_components_mean = torch.stack(loss_components_mean)
+ loss_tensor = loss_components_mean.mean()
+
+ return loss_tensor, loss_components_mean, loss_components
+
+ def calc_metrics(self, pred, gt_data):
+
+ def compare_classification(pred, gt):
+ pred_labels = torch.argmax(pred, dim=1).to(torch.float32)
+
+ eval_dict = compare_predictions_binary_tensors(
+ ground_truth=gt.squeeze(), predicted=pred_labels.squeeze(), prediction_name=None)
+ return eval_dict
+
+ eval_dict = compare_classification(pred=pred, gt=gt_data['occ'])
+ eval_dict['abs_dist_rms'] = np.nan
+ return eval_dict
+
+ def get_loss_and_metrics(self, pred, batch):
+ loss, loss_components_mean, loss_components = self.compute_loss(pred=pred, batch_data=batch)
+ metrics_dict = self.calc_metrics(pred=pred, gt_data=batch)
+ return loss, loss_components_mean, loss_components, metrics_dict
+
+ def default_step_dict(self, batch):
+ pred = self.network.forward(batch)
+ loss, loss_components_mean, loss_components, metrics_dict = self.get_loss_and_metrics(pred, batch)
+
+ if self.lambda_l1 != 0.0:
+ loss = self.regularize(loss)
+
+ if bool(self.debug):
+ self.visualize_step_results(batch_data=batch, predictions=pred,
+ losses=loss_components, metrics=metrics_dict)
+ return loss, loss_components_mean, loss_components, metrics_dict
+
+ def training_step(self, batch, batch_idx):
+ loss, loss_components_mean, loss_components, metrics_dict = self.default_step_dict(batch=batch)
+ self.do_logging(loss, loss_components_mean, log_type='train',
+ output_names=self.output_names, metrics_dict=metrics_dict, f1_in_prog_bar=False,
+ keys_to_log=frozenset({'accuracy', 'precision', 'recall', 'f1_score'}))
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ loss, loss_components_mean, loss_components, metrics_dict = self.default_step_dict(batch=batch)
+ self.do_logging(loss, loss_components_mean, log_type='val',
+ output_names=self.output_names, metrics_dict=metrics_dict, f1_in_prog_bar=True,
+ keys_to_log=frozenset({'accuracy', 'precision', 'recall', 'f1_score'}))
+ return loss
+
+ def test_step(self, batch, batch_idx):
+ pred = self.network.forward(batch)
+
+ # assume batch size is 1
+ if batch['shape_id'].shape[0] != 1:
+ raise NotImplementedError('batch size > 1 not supported')
+
+ shape_id = batch['shape_id']
+
+ loss, loss_components_mean, loss_components = self.compute_loss(pred=pred, batch_data=batch)
+ metrics_dict = self.calc_metrics(pred=pred, gt_data=batch)
+
+ if bool(self.debug):
+ self.visualize_step_results(batch_data=batch, predictions=pred,
+ losses=loss_components, metrics=metrics_dict)
+
+ shape_id = shape_id.squeeze(0)
+ loss_components_mean = loss_components_mean.squeeze(0)
+ loss_components = loss_components.squeeze(0)
+ pc_file_in = batch['pc_file_in'][0]
+
+ results = {'shape_id': shape_id, 'pc_file_in': pc_file_in, 'loss': loss,
+ 'loss_components_mean': loss_components_mean,
+ 'loss_components': loss_components, 'metrics_dict': metrics_dict}
+ self.test_step_outputs.append(results)
+
+ prog_bar = self.get_prog_bar()
+ prog_bar.test_progress_bar.set_postfix_str('pc_file: {}'.format(os.path.basename(pc_file_in)), refresh=True)
+ return results
+
+ def on_test_epoch_end(self):
+
+ from source.base.evaluation import make_test_report
+ from source.base.container import aggregate_dicts, flatten_dicts
+ from source.occupancy_data_module import read_shape_list, get_results_dir
+
+ shape_names = read_shape_list(self.in_file)
+ results_dir = get_results_dir(out_dir=self.results_dir, name=self.name, in_file=self.in_file)
+
+ outputs_flat = flatten_dicts(self.test_step_outputs)
+ metrics_dicts_stacked = aggregate_dicts(outputs_flat, method='stack')
+
+ output_file = os.path.join(results_dir, 'metrics_{}.xlsx'.format(self.name))
+ loss_total_mean, abs_dist_rms_mean, f1_mean = make_test_report(
+ shape_names=shape_names, results=metrics_dicts_stacked,
+ output_file=output_file, output_names=self.output_names, is_dict=True)
+
+ print('Test results (mean): Loss={}, RMSE={}, F1={}'.format(loss_total_mean, abs_dist_rms_mean, f1_mean))
+
+ def predict_step(self, batch: dict, batch_idx, dataloader_idx=0):
+ from source.occupancy_data_module import get_results_dir, in_file_is_dataset
+
+ shape_data_poco = get_data_poco(batch_data=batch)
+ prog_bar = self.get_prog_bar()
+
+ if batch['pts_ms'].shape[0] > 1:
+ raise NotImplementedError('batch size > 1 not supported')
+
+ pc_file_in = batch['pc_file_in'][0]
+ if in_file_is_dataset(self.in_file):
+ results_dir = get_results_dir(out_dir=self.results_dir, name=self.name, in_file=self.in_file)
+ out_file_rec = os.path.join(results_dir, 'meshes', os.path.basename(pc_file_in))
+ else:
+ # simple folder structure for single reconstruction
+ out_file_basename = os.path.basename(pc_file_in) + '.ply'
+ out_file_rec = os.path.join(self.results_dir, os.path.basename(pc_file_in), out_file_basename)
+ pts = shape_data_poco['pts'][0].transpose(0, 1)
+
+ # create the latent storage
+ latent = torch.zeros((pts.shape[0], self.network_latent_size),
+ dtype=torch.float, device=pts.device)
+ counts = torch.zeros((pts.shape[0],), dtype=torch.float, device=pts.device)
+
+ iteration = 0
+ for current_value in range(self.gen_subsample_manifold_iter):
+ while counts.min() < current_value + 1:
+ valid_ids = torch.argwhere(counts == current_value)[:, 0].clone().detach().long()
+
+ if pts.shape[0] >= self.gen_subsample_manifold:
+
+ ids = torch.randperm(valid_ids.shape[0])[:self.gen_subsample_manifold]
+ ids = valid_ids[ids]
+
+ if ids.shape[0] < self.gen_subsample_manifold:
+ ids = torch.cat(
+ [ids, torch.randperm(pts.shape[0], device=pts.device)[
+ :self.gen_subsample_manifold - ids.shape[0]]],
+ dim=0)
+ assert (ids.shape[0] == self.gen_subsample_manifold)
+ else:
+ ids = torch.arange(pts.shape[0])
+
+ data_partial = {'pts': shape_data_poco['pts'][0].transpose(1, 0)[ids].transpose(1, 0).unsqueeze(0)}
+ partial_latent = self.network.get_latent(data_partial)['latents']
+ latent[ids] += partial_latent[0].transpose(1, 0)
+ counts[ids] += 1
+
+ iteration += 1
+ prog_bar.predict_progress_bar.set_postfix_str('get_latent iter: {}'.format(iteration), refresh=True)
+
+ latent = latent / counts.unsqueeze(1)
+ latent = latent.transpose(1, 0).unsqueeze(0)
+ shape_data_poco['latents'] = latent
+ latent = shape_data_poco
+
+ from source.poco_utils import export_mesh_and_refine_vertices_region_growing_v3
+ mesh = export_mesh_and_refine_vertices_region_growing_v3(
+ network=self.network, latent=latent,
+ pts_raw_ms=batch['pts_raw_ms'] if 'pts_raw_ms' in batch.keys() else None,
+ resolution=self.gen_resolution_global,
+ padding=1,
+ mc_value=0,
+ num_pts=self.rec_batch_size,
+ num_pts_local=self.num_pts_local,
+ input_points=shape_data_poco['pts'][0].cpu().numpy().transpose(1, 0),
+ refine_iter=self.gen_refine_iter,
+ out_value=1,
+ prog_bar=prog_bar,
+ pc_file_in=pc_file_in,
+ # workers=self.workers,
+ )
+
+ if mesh is not None:
+ # de-normalize if not part of a dataset
+ from source.occupancy_data_module import in_file_is_dataset
+ if not in_file_is_dataset(self.in_file):
+ from source.base.math import get_points_normalization_info, denormalize_points_with_info
+ from source.occupancy_data_module import OccupancyDataModule
+ pts_np = OccupancyDataModule.load_pts(pts_file=pc_file_in)
+ pts_np, _ = OccupancyDataModule.pre_process_pts(pts=pts_np)
+ bb_center, scale = get_points_normalization_info(pts=pts_np, padding_factor=self.padding_factor)
+ mesh.vertices = denormalize_points_with_info(pts=mesh.vertices, bb_center=bb_center, scale=scale)
+
+ # print(out_file_rec)
+ fs.make_dir_for_file(out_file_rec)
+ mesh.export(file_obj=out_file_rec)
+ else:
+ print('No reconstruction for {}'.format(pc_file_in))
+
+ return 0 # return something to suppress warning
+
+ def on_predict_epoch_end(self):
+ from source.base.profiling import get_now_str
+ from source.occupancy_data_module import get_results_dir, read_shape_list, get_meshes_dir, in_file_is_dataset
+
+ if not in_file_is_dataset(self.in_file):
+ return # no dataset -> nothing to evaluate
+
+ print('{}: Evaluating {}'.format(get_now_str(), self.name))
+ from source.base import evaluation, fs
+
+ results_dir = get_results_dir(out_dir=self.results_dir, name=self.name, in_file=self.in_file)
+ shape_names = read_shape_list(self.in_file)
+ gt_meshes_dir = get_meshes_dir(in_file=self.in_file)
+ if not os.path.exists(gt_meshes_dir):
+ print('Warning: {} not found. Skipping evaluation.'.format(gt_meshes_dir))
+ else:
+ gt_meshes = [os.path.join(gt_meshes_dir, '{}.ply'.format(vs)) for vs in shape_names]
+ os.makedirs(results_dir, exist_ok=True)
+ result_headers = [self.name]
+ result_file_templates = [os.path.join(results_dir, 'meshes/{}.xyz.ply')]
+ _ = evaluation.make_quantitative_comparison(
+ shape_names=shape_names, gt_mesh_files=gt_meshes,
+ result_headers=result_headers, result_file_templates=result_file_templates,
+ comp_output_dir=results_dir, num_processes=self.workers, num_samples=100000)
+
+ print('{}: Evaluating {} finished'.format(get_now_str(), self.name))
+
+ def do_logging(self, loss_total, loss_components, log_type: str, output_names: list, metrics_dict: dict,
+ keys_to_log=frozenset({'abs_dist_rms', 'accuracy', 'precision', 'recall', 'f1_score'}),
+ f1_in_prog_bar=True, on_step=True, on_epoch=False):
+
+ import math
+ import numbers
+
+ self.log('loss/{}/00_all'.format(log_type), loss_total, on_step=on_step, on_epoch=on_epoch)
+ if len(loss_components) > 1:
+ for li, l in enumerate(loss_components):
+ self.log('loss/{}/{}_{}'.format(log_type, li, output_names[li]), l, on_step=on_step, on_epoch=on_epoch)
+
+ for key in metrics_dict.keys():
+ if key in keys_to_log and isinstance(metrics_dict[key], numbers.Number):
+ value = metrics_dict[key]
+ if math.isnan(value):
+ value = 0.0
+ self.log('metrics/{}/{}'.format(log_type, key), value, on_step=on_step, on_epoch=on_epoch)
+
+ self.log('metrics/{}/{}'.format(log_type, 'F1'), metrics_dict['f1_score'],
+ on_step=on_step, on_epoch=on_epoch, logger=False, prog_bar=f1_in_prog_bar)
+
+ def visualize_step_results(self, batch_data: dict, predictions, losses, metrics):
+ from source.base import visualization
+ query_pts_ms = batch_data['pts_query_ms'].detach().cpu().numpy()
+ occ_loss = losses[0].detach().cpu().numpy()
+ vis_to_eval_file = os.path.join('debug', 'occ_loss_vis', 'test' + '.ply')
+ visualization.plot_pts_scalar_data(query_pts_ms, occ_loss, vis_to_eval_file, prop_min=0.0, prop_max=1.0)
+
+
+class PocoNetwork(pl.LightningModule):
+
+ def __init__(self, in_channels, latent_size, out_channels, k):
+ super().__init__()
+
+ self.encoder = FKAConvNetwork(in_channels, latent_size, segmentation=True, dropout=0, x4d_bug_fixed=False)
+ self.projection = InterpAttentionKHeadsNet(latent_size, out_channels, k)
+
+ self.lcp_preprocess = True
+
+ print(f'Network -- backbone -- {count_parameters(self.encoder)} parameters')
+ print(f'Network -- projection -- {count_parameters(self.projection)} parameters')
+
+ def forward(self, data):
+ latents = self.encoder.forward(data, spectral_only=True)
+ data['latents'] = latents
+ ret_data = self.projection.forward(data, has_proj_ids=True)
+ return ret_data
+
+ def get_latent(self, data):
+ latents = self.encoder.forward(data, spectral_only=False)
+ data['latents'] = latents
+ data['proj_correction'] = None
+ return data
+
+ def from_latent(self, data: typing.Dict[str, torch.Tensor]):
+ data_proj = self.projection.forward(data)
+ return data_proj
+
+
+class InterpAttentionKHeadsNet(torch.nn.Module):
+
+ def __init__(self, latent_size, out_channels, k=16):
+ super().__init__()
+
+ print(f'InterpNet - Simple - K={k}')
+ self.fc1 = torch.nn.Conv2d(latent_size + 3, latent_size, 1)
+ self.fc2 = torch.nn.Conv2d(latent_size, latent_size, 1)
+ self.fc3 = torch.nn.Conv2d(latent_size, latent_size, 1)
+
+ self.fc8 = torch.nn.Conv1d(latent_size, out_channels, 1)
+
+ self.fc_query = torch.nn.Conv2d(latent_size, 64, 1)
+ self.fc_value = torch.nn.Conv2d(latent_size, latent_size, 1)
+
+ self.k = k
+
+ self.activation = torch.nn.ReLU()
+
+ def forward(self, data: typing.Dict[str, torch.Tensor], has_proj_ids: bool = False, last_layer: bool = True)\
+ -> torch.Tensor:
+
+ if not has_proj_ids:
+ spatial_data = get_proj_ids(data, self.k)
+ for key, value in spatial_data.items():
+ data[key] = value
+
+ x = data['latents']
+ indices = data['proj_ids']
+ pts = data['pts']
+ pts_query = data['pts_query'].to(pts.device)
+
+ if pts.shape[1] != 3:
+ pts = pts.transpose(1, 2)
+
+ if pts_query.shape[1] != 3:
+ pts_query = pts_query.transpose(1, 2)
+
+ x = batch_gather(x, 2, indices)
+ pts = batch_gather(pts, 2, indices)
+ pts = pts_query.unsqueeze(3) - pts
+
+ x = torch.cat([x, pts], dim=1)
+ x = self.activation(self.fc1(x))
+ x = self.activation(self.fc2(x))
+ x = self.activation(self.fc3(x))
+
+ query = self.fc_query(x)
+ value = self.fc_value(x)
+
+ attention = torch.nn.functional.softmax(query, dim=-1).mean(dim=1)
+ x = torch.matmul(attention.unsqueeze(-2), value.permute(0, 2, 3, 1)).squeeze(-2)
+ x = x.transpose(1, 2)
+
+ if last_layer:
+ x = self.fc8(x)
+
+ return x
diff --git a/ppsurf/source/poco_utils.py b/ppsurf/source/poco_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb8852c30a0ffd2e9206f15aa369108a875b1c1b
--- /dev/null
+++ b/ppsurf/source/poco_utils.py
@@ -0,0 +1,273 @@
+import typing
+import os
+
+import trimesh
+import numpy as np
+import torch
+import torch.nn.functional as func
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar
+
+import source.base.mesh
+
+# Adapted from POCO: https://github.com/valeoai/POCO
+# which is published under Apache 2.0: https://github.com/valeoai/POCO/blob/main/LICENSE
+
+
+def profile_from_latent(func, data: dict):
+ import time
+ start = time.time()
+ res: torch.Tensor = func(data)
+ end = time.time()
+ print('{} took: {}, shape: {}'.format('from_latent', end - start, res.shape))
+ return res
+
+
+def export_mesh_and_refine_vertices_region_growing_v3(
+ network: pl.LightningModule,
+ latent: dict,
+ pts_raw_ms: torch.Tensor,
+ resolution: int,
+ padding=0,
+ mc_value=0,
+ num_pts=50000,
+ num_pts_local=None,
+ refine_iter=10,
+ input_points=None,
+ out_value=np.nan,
+ dilation_size=2,
+ prog_bar: typing.Optional[TQDMProgressBar] = None,
+ pc_file_in: str = 'unknown',
+ # workers=1,
+) -> typing.Optional[trimesh.Trimesh]:
+ from tqdm import tqdm
+ from skimage import measure
+ from source.base.fs import make_dir_for_file
+ from source.base.proximity import make_kdtree, query_kdtree
+ from source.ppsurf_data_loader import PPSurfDataset
+
+ if latent['pts_ms'].shape[0] != 1:
+ raise ValueError('Reconstruction must be done with batch size = 0!')
+
+ bmin = input_points.min()
+ bmax = input_points.max()
+
+ step = (bmax - bmin) / (resolution - 1)
+
+ bmin_pad = bmin - padding * step
+ bmax_pad = bmax + padding * step
+
+ pts_ids = (input_points - bmin) / step + padding
+ pts_ids = pts_ids.astype(np.int32)
+
+ if num_pts_local is not None:
+ pts_raw_ms = pts_raw_ms[0].detach().cpu().numpy() # expect batch size = 1
+ kdtree = make_kdtree(pts=pts_raw_ms)
+
+ def _get_pts_local_ps(pts_query: np.ndarray):
+ _, patch_pts_ids = query_kdtree(kdtree=kdtree, pts_query=pts_query, k=num_pts_local, sqr_dists=True)
+ pts_local_ms = pts_raw_ms[patch_pts_ids.astype(np.int64)]
+ pts_local_ps_np = PPSurfDataset.normalize_patches(pts_local_ms=pts_local_ms, pts_query_ms=pts_query)
+ pts_local_ps = torch.from_numpy(pts_local_ps_np).to(latent['pts_ms'].device).unsqueeze(0)
+ return pts_local_ps
+
+ def _predict_from_latent(_latent: dict):
+ occ_hat = network.from_latent(_latent)
+ # occ_hat = profile_from_latent(network.from_latent, _latent)
+
+ # get class and non-class
+ occ_hat = func.softmax(occ_hat, dim=1)
+ occ_hat = occ_hat[:, 0] - occ_hat[:, 1]
+ occ_hat = occ_hat.squeeze(0).detach().cpu().numpy()
+ return occ_hat
+
+ volume = _create_volume(_get_pts_local_ps, _predict_from_latent, dilation_size, bmin_pad, latent, num_pts,
+ num_pts_local, out_value, padding, pc_file_in, prog_bar, pts_ids, resolution, step)
+
+ # volume[np.isnan(volume)] = out_value
+ maxi = volume[~np.isnan(volume)].max()
+ mini = volume[~np.isnan(volume)].min()
+
+ # occ doesn't cross zero-level set
+ if not (maxi > mc_value > mini):
+ return None
+
+ # compute the marching cubes
+ verts, faces, _, _ = measure.marching_cubes(volume=volume.copy(), level=mc_value)
+
+ # remove the nan values in the vertices
+ # values = verts.sum(axis=1)
+ # invalid_vertices_mask = np.isnan(values)
+ # verts = np.asarray(verts[invalid_vertices_mask])
+ # faces = np.asarray(faces)
+
+ # clean mesh
+ mesh = trimesh.Trimesh(vertices=verts, faces=faces)
+ source.base.mesh.clean_simple_inplace(mesh=mesh)
+ mesh = source.base.mesh.remove_small_connected_components(mesh=mesh, num_faces=6)
+
+ verts = np.asarray(mesh.vertices)
+ faces = np.asarray(mesh.faces)
+ if refine_iter > 0:
+ dirs = verts - np.floor(verts)
+ dirs = (dirs > 0).astype(dirs.dtype)
+
+ mask = np.logical_and(dirs.sum(axis=1) > 0, dirs.sum(axis=1) < 2)
+ v = verts[mask]
+ dirs = dirs[mask]
+
+ # initialize the two values (the two vertices for mc grid)
+ v1 = np.floor(v)
+ v2 = v1 + dirs
+
+ # get the predicted values for both set of points
+ v1 = v1.astype(int)
+ v2 = v2.astype(int)
+ preds1 = volume[v1[:, 0], v1[:, 1], v1[:, 2]]
+ preds2 = volume[v2[:, 0], v2[:, 1], v2[:, 2]]
+
+ # get the coordinates in the real coordinate system
+ v1 = v1.astype(np.float32) * step + bmin_pad
+ v2 = v2.astype(np.float32) * step + bmin_pad
+
+ # tmp mask
+ mask_tmp = np.logical_and(np.logical_not(np.isnan(preds1)), np.logical_not(np.isnan(preds2)))
+ v = v[mask_tmp]
+ # dirs = dirs[mask_tmp]
+ v1 = v1[mask_tmp]
+ v2 = v2[mask_tmp]
+ mask[mask] = mask_tmp
+
+ # initialize the vertices
+ verts = verts * step + bmin_pad
+ v = v * step + bmin_pad
+
+ # iterate for the refinement step
+ for iter_id in range(refine_iter):
+ preds = []
+ pnts_all = torch.tensor(v, dtype=torch.float)
+ for pnts in tqdm(torch.split(pnts_all, num_pts, dim=0), ncols=100, disable=True):
+ latent['pts_query'] = pnts.unsqueeze(0)
+ if num_pts_local is not None:
+ latent['pts_local_ps'] = _get_pts_local_ps(pts_query=pnts.detach().cpu().numpy())
+ preds.append(_predict_from_latent(latent))
+ preds = np.concatenate(preds, axis=0)
+
+ mask1 = (preds * preds1) > 0
+ v1[mask1] = v[mask1]
+ preds1[mask1] = preds[mask1]
+
+ mask2 = (preds * preds2) > 0
+ v2[mask2] = v[mask2]
+ preds2[mask2] = preds[mask2]
+
+ v = (v2 + v1) / 2
+ verts[mask] = v
+
+ prog_bar.predict_progress_bar.set_postfix_str('{}, refine iter {}'.format(
+ os.path.basename(pc_file_in)[:16], iter_id), refresh=True)
+ else:
+ verts = verts * step + bmin_pad
+
+ mesh = trimesh.Trimesh(vertices=verts, faces=faces)
+ source.base.mesh.clean_simple_inplace(mesh=mesh)
+ mesh = source.base.mesh.remove_small_connected_components(mesh=mesh, num_faces=6)
+ return mesh
+
+
+def _create_volume(_get_pts_local_ps, _predict_from_latent, dilation_size, bmin_pad, latent, num_pts, num_pts_local,
+ out_value, padding, pc_file_in, prog_bar, pts_ids, resolution, step):
+
+ def _dilate_binary(arr: np.ndarray, pts_int: np.ndarray):
+ # old POCO version actually dilates with a 4^3 kernel, 2 to lower, 1 to upper
+ # -> no out-of upper bounds with 2 dilation_size by default
+ # we make it symmetric (+1 to max)
+ pts_min = np.maximum(0, pts_int - dilation_size)
+ pts_max = np.minimum(arr.shape[0], pts_int + dilation_size + 1)
+
+ def _dilate_point(pt_min, pt_max):
+ arr[pt_min[0]:pt_max[0],
+ pt_min[1]:pt_max[1],
+ pt_min[2]:pt_max[2]] = True
+
+ # vectorizing slices is not possible? so we iterate over the points
+ # skimage.morphology and scipy.ndimage take longer, probably because of overhead
+ _ = [_dilate_point(pt_min=pts_min[i], pt_max=pts_max[i]) for i in range(pts_int.shape[0])]
+ return arr
+
+ res_x = resolution
+ res_y = resolution
+ res_z = resolution
+
+ volume_shape = (res_x + 2 * padding, res_y + 2 * padding, res_z + 2 * padding)
+ volume = np.full(volume_shape, np.nan, dtype=np.float64)
+ mask_to_see = np.full(volume_shape, True, dtype=bool)
+ while pts_ids.shape[0] > 0:
+ # create the mask
+ mask = np.full(volume_shape, False, dtype=bool)
+ mask[pts_ids[:, 0], pts_ids[:, 1], pts_ids[:, 2]] = True
+ mask = _dilate_binary(arr=mask, pts_int=pts_ids)
+
+ # get the valid points
+ valid_points_coord = np.argwhere(mask).astype(np.float32)
+ valid_points = valid_points_coord * step + bmin_pad
+
+ # get the prediction for each valid points
+ z = []
+ near_surface_samples_torch = torch.tensor(valid_points, dtype=torch.float)
+ for pnts in torch.split(near_surface_samples_torch, num_pts, dim=0):
+
+ latent['pts_query'] = pnts.unsqueeze(0)
+ if num_pts_local is not None:
+ latent['pts_local_ps'] = _get_pts_local_ps(pts_query=pnts.detach().cpu().numpy())
+ z.append(_predict_from_latent(latent))
+
+ prog_bar.predict_progress_bar.set_postfix_str(
+ '{}, occ_batch iter {}'.format(os.path.basename(pc_file_in), len(z)), refresh=True)
+
+ z = np.concatenate(z, axis=0)
+ z = z.astype(np.float64)
+
+ # update the volume
+ volume[mask] = z
+
+ # create the masks
+ mask_pos = np.full(volume_shape, False, dtype=bool)
+ mask_neg = np.full(volume_shape, False, dtype=bool)
+ mask_to_see[pts_ids[:, 0], pts_ids[:, 1], pts_ids[:, 2]] = False
+
+ # dilate
+ pts_ids_pos = pts_ids[volume[pts_ids[:, 0], pts_ids[:, 1], pts_ids[:, 2]] <= 0]
+ pts_ids_neg = pts_ids[volume[pts_ids[:, 0], pts_ids[:, 1], pts_ids[:, 2]] >= 0]
+ mask_neg = _dilate_binary(arr=mask_neg, pts_int=pts_ids_pos)
+ mask_pos = _dilate_binary(arr=mask_pos, pts_int=pts_ids_neg)
+
+ # get the new points
+ new_mask = (mask_neg & (volume >= 0) & mask_to_see) | (mask_pos & (volume <= 0) & mask_to_see)
+ pts_ids = np.argwhere(new_mask).astype(np.int64)
+ volume[0:padding, :, :] = out_value
+ volume[-padding:, :, :] = out_value
+ volume[:, 0:padding, :] = out_value
+ volume[:, -padding:, :] = out_value
+ volume[:, :, 0:padding] = out_value
+ volume[:, :, -padding:] = out_value
+ return volume
+
+
+@torch.jit.ignore
+def knn(points: torch.Tensor, support_points: torch.Tensor, k: int, workers: int = 1) -> torch.Tensor:
+ if k > points.shape[2]:
+ k = points.shape[2]
+ pts = points.cpu().detach().transpose(1, 2).numpy().copy()
+ s_pts = support_points.cpu().detach().transpose(1, 2).numpy().copy()
+
+ from source.base.proximity import kdtree_query_oneshot
+
+ indices: list = []
+ for i in range(pts.shape[0]):
+ _, ids = kdtree_query_oneshot(pts=pts[i], pts_query=s_pts[i], k=k, workers=workers)
+ indices.append(torch.from_numpy(ids.astype(np.int64)))
+ indices: torch.Tensor = torch.stack(indices, dim=0)
+ if k == 1:
+ indices = indices.unsqueeze(2)
+ return indices.to(points.device)
diff --git a/ppsurf/source/ppsurf_data_loader.py b/ppsurf/source/ppsurf_data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..9359f5d3dd14673be502c17d96c305ef48e92ddb
--- /dev/null
+++ b/ppsurf/source/ppsurf_data_loader.py
@@ -0,0 +1,141 @@
+import typing
+
+import numpy as np
+from overrides import overrides
+
+from source.poco_data_loader import PocoDataModule, PocoDataset, get_data_poco
+from source.base.container import dict_np_to_torch
+from source.base.proximity import query_kdtree
+
+
+class PPSurfDataModule(PocoDataModule):
+
+ def __init__(self, num_pts_local: int,
+ in_file, workers, use_ddp, padding_factor, seed, manifold_points,
+ patches_per_shape: typing.Optional[int], do_data_augmentation: bool, batch_size: int):
+ super(PPSurfDataModule, self).__init__(
+ use_ddp=use_ddp, workers=workers, in_file=in_file, patches_per_shape=patches_per_shape,
+ do_data_augmentation=do_data_augmentation, batch_size=batch_size,
+ padding_factor=padding_factor, seed=seed, manifold_points=manifold_points)
+ self.num_pts_local = num_pts_local
+
+ def make_dataset(
+ self, in_file: typing.Union[str, list], reconstruction: bool, patches_per_shape: typing.Optional[int],
+ do_data_augmentation: bool):
+
+ if reconstruction:
+ dataset = PPSurfReconstructionDataset(
+ in_file=in_file,
+ num_pts_local=self.num_pts_local,
+ padding_factor=self.padding_factor,
+ seed=self.seed,
+ use_ddp=self.use_ddp,
+ )
+ else:
+ dataset = PPSurfDataset(
+ in_file=in_file,
+ num_pts_local=self.num_pts_local,
+ padding_factor=self.padding_factor,
+ seed=self.seed,
+ patches_per_shape=self.patches_per_shape,
+ do_data_augmentation=do_data_augmentation,
+ use_ddp=self.use_ddp,
+ manifold_points=self.manifold_points,
+ )
+ return dataset
+
+
+class PPSurfDataset(PocoDataset):
+
+ def __init__(self, in_file, num_pts_local, padding_factor, seed, use_ddp,
+ manifold_points, patches_per_shape: typing.Optional[int], do_data_augmentation=True):
+ super(PPSurfDataset, self).__init__(
+ in_file=in_file, padding_factor=padding_factor, seed=seed,
+ use_ddp=use_ddp, manifold_points=manifold_points,
+ patches_per_shape=patches_per_shape, do_data_augmentation=do_data_augmentation)
+
+ self.num_pts_local = num_pts_local
+
+ # returns a patch centered at the point with the given global index
+ # and the ground truth normal the patch center
+ def __getitem__(self, shape_id):
+ shape_data, pts_ms_raw = self.load_shape_by_index(shape_id, return_kdtree=True)
+ kdtree = shape_data.pop('kdtree')
+
+ if self.do_data_augmentation:
+ import trimesh
+ # optionally always pick the same points for a given patch index (mainly for debugging)
+ # self.rng.seed(42)
+ rand_rot = trimesh.transformations.random_rotation_matrix(self.rng.rand(3))
+ shape_data = self.augment_shape(shape_data, rand_rot)
+
+ # must be after augmentation
+ shape_data = PPSurfDataset.get_local_subsamples(shape_data, kdtree, pts_ms_raw, self.num_pts_local)
+
+ pts_local_ps = self.normalize_patches(
+ pts_local_ms=shape_data['pts_local_ms'], pts_query_ms=shape_data['pts_query_ms'])
+
+ shape_data['pts_local_ps'] = pts_local_ps
+ shape_data = dict_np_to_torch(shape_data) # must be before poco part
+ shape_data = get_data_poco(shape_data)
+ return shape_data
+
+ @staticmethod
+ def get_local_subsamples(shape_data, kdtree, pts_raw_ms, num_pts_local):
+ _, patch_pts_ids = query_kdtree(kdtree=kdtree, pts_query=shape_data['pts_query_ms'],
+ k=num_pts_local, sqr_dists=True)
+ patch_pts_ids = patch_pts_ids.astype(np.int64)
+ shape_data['pts_local_ms'] = pts_raw_ms[patch_pts_ids]
+ return shape_data
+
+ @staticmethod
+ def normalize_patches(pts_local_ms, pts_query_ms):
+ patch_radius_ms = PPSurfDataset.get_patch_radii(pts_local_ms, pts_query_ms)
+ pts_local_ps = PPSurfDataset.model_space_to_patch_space(
+ pts_to_convert_ms=pts_local_ms, pts_patch_center_ms=pts_query_ms,
+ patch_radius_ms=patch_radius_ms)
+ return pts_local_ps
+
+ @staticmethod
+ def get_patch_radii(pts_patch: np.array, query_pts: np.array):
+ if pts_patch.shape[1] == 0:
+ patch_radius = 0.0
+ elif pts_patch.shape == query_pts.shape:
+ patch_radius = np.linalg.norm(pts_patch - query_pts, axis=0)
+ else:
+ from source.base.math import cartesian_dist
+ dist = cartesian_dist(np.repeat(
+ np.expand_dims(query_pts, axis=1), pts_patch.shape[1], axis=1), pts_patch, axis=2)
+ patch_radius = np.max(dist, axis=-1)
+ return patch_radius
+
+ @staticmethod
+ def model_space_to_patch_space(pts_to_convert_ms: np.array, pts_patch_center_ms: np.array,
+ patch_radius_ms: typing.Union[float, np.ndarray]):
+
+ pts_patch_center_ms_repeated = \
+ np.repeat(np.expand_dims(pts_patch_center_ms, axis=1), pts_to_convert_ms.shape[-2], axis=-2)
+ pts_patch_space = pts_to_convert_ms - pts_patch_center_ms_repeated
+ patch_radius_ms_expanded = np.expand_dims(np.expand_dims(patch_radius_ms, axis=1), axis=2)
+ patch_radius_ms_repeated = np.repeat(patch_radius_ms_expanded, pts_to_convert_ms.shape[-2], axis=-2)
+ patch_radius_ms_repeated = np.repeat(patch_radius_ms_repeated, pts_to_convert_ms.shape[-1], axis=-1)
+ pts_patch_space = pts_patch_space / patch_radius_ms_repeated
+ return pts_patch_space
+
+
+class PPSurfReconstructionDataset(PPSurfDataset):
+
+ def __init__(self, in_file, num_pts_local, padding_factor, seed, use_ddp):
+
+ super(PPSurfReconstructionDataset, self).__init__(
+ in_file=in_file, num_pts_local=num_pts_local, padding_factor=padding_factor, seed=seed,
+ use_ddp=use_ddp, manifold_points=None, patches_per_shape=None, do_data_augmentation=False)
+
+ # returns a patch centered at the point with the given global index
+ # and the ground truth normal the patch center
+ @overrides
+ def __getitem__(self, shape_id):
+ shape_data, pts_ms_raw = self.load_shape_by_index(shape_id, return_kdtree=False)
+ shape_data['pts_raw_ms'] = pts_ms_raw # collate issue for batch size > 1
+ shape_data = dict_np_to_torch(shape_data)
+ return shape_data
diff --git a/ppsurf/source/ppsurf_model.py b/ppsurf/source/ppsurf_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4acbcc2715596a86f368c4389b49c8a92e450acc
--- /dev/null
+++ b/ppsurf/source/ppsurf_model.py
@@ -0,0 +1,118 @@
+import typing
+
+import pytorch_lightning as pl
+import torch
+
+from source.poco_model import PocoModel
+from source.base.nn import count_parameters
+
+
+class PPSurfModel(PocoModel):
+
+ def __init__(self,
+ pointnet_latent_size,
+ output_names, in_channels, out_channels, k,
+ lambda_l1, debug, in_file, results_dir, padding_factor, name, network_latent_size,
+ gen_subsample_manifold_iter, gen_subsample_manifold, gen_resolution_global, num_pts_local,
+ rec_batch_size, gen_refine_iter, workers
+ ):
+ super(PPSurfModel, self).__init__(
+ output_names=output_names, in_channels=in_channels,
+ out_channels=out_channels, k=k,
+ lambda_l1=lambda_l1, debug=debug, in_file=in_file, results_dir=results_dir,
+ padding_factor=padding_factor, name=name, workers=workers, rec_batch_size=rec_batch_size,
+ gen_refine_iter=gen_refine_iter, gen_subsample_manifold=gen_subsample_manifold,
+ gen_resolution_global=gen_resolution_global,
+ gen_subsample_manifold_iter=gen_subsample_manifold_iter,
+ network_latent_size=network_latent_size
+ )
+
+ self.num_pts_local = num_pts_local
+ self.pointnet_latent_size = pointnet_latent_size
+
+ self.network = PPSurfNetwork(in_channels=self.in_channels, latent_size=self.network_latent_size,
+ out_channels=self.out_channels, k=self.k,
+ num_pts_local=self.num_pts_local,
+ pointnet_latent_size=self.pointnet_latent_size)
+
+
+class PPSurfNetwork(pl.LightningModule):
+
+ def __init__(self, in_channels, latent_size, out_channels, k, num_pts_local, pointnet_latent_size):
+ super().__init__()
+
+ from source.poco_model import InterpAttentionKHeadsNet
+ from source.base.nn import FKAConvNetwork
+ from source.base.nn import PointNetfeat, MLP
+
+ self.latent_size = latent_size
+ self.encoder = FKAConvNetwork(in_channels, latent_size, segmentation=True, dropout=0,
+ activation=torch.nn.SiLU(), x4d_bug_fixed=True)
+ self.projection = InterpAttentionKHeadsNet(latent_size, latent_size, k)
+ self.point_net = PointNetfeat(net_size_max=pointnet_latent_size, num_points=num_pts_local, use_point_stn=False,
+ use_feat_stn=True, output_size=latent_size, sym_op='att', dim=3)
+
+ # self.branch_att = AttentionPoco(latent_size, reduce=True) # attention ablation
+
+ # self.mlp = MLP(input_size=latent_size*2, output_size=out_channels, num_layers=3, # cat ablation
+ self.mlp = MLP(input_size=latent_size, output_size=out_channels, num_layers=3, # att and sum ablation
+ halving_size=False, dropout=0.3)
+
+ self.lcp_preprocess = True
+
+ self.activation = torch.nn.ReLU()
+
+ print(f'Network -- backbone -- {count_parameters(self.encoder)} parameters')
+ print(f'Network -- projection -- {count_parameters(self.projection)} parameters')
+ print(f'Network -- point_net -- {count_parameters(self.point_net)} parameters')
+ print(f'Network -- mlp -- {count_parameters(self.mlp)} parameters')
+
+ def forward(self, data):
+ latents = self.encoder.forward(data, spectral_only=True)
+ data['latents'] = latents
+ ret_data = self.from_latent(data)
+ return ret_data
+
+ def get_latent(self, data):
+ latents = self.encoder.forward(data, spectral_only=False)
+ data['latents'] = latents
+ data['proj_correction'] = None
+ return data
+
+ def from_latent(self, data: typing.Dict[str, torch.Tensor]):
+ feat_proj = self.projection.forward(data, has_proj_ids=False)
+
+ # zero tensor for debug
+ # feat_pn_shape = (data['proj_ids'].shape[0], data['proj_ids'].shape[2], data['proj_ids'].shape[1])
+ # feat_pointnet = torch.zeros(feat_pn_shape, dtype=torch.float32, device=self.device)
+
+ # PointNetFeat uses query points for batch dim -> need to flatten shape * query points dim
+ pts_local_shape = data['pts_local_ps'].shape
+ pts_local_flat_shape = (pts_local_shape[0] * pts_local_shape[1], pts_local_shape[2], pts_local_shape[3])
+ pts_local_ps_flat = data['pts_local_ps'].view(pts_local_flat_shape)
+ feat_pointnet_flat = self.point_net.forward(pts_local_ps_flat.transpose(1, 2), pts_weights=None)[0]
+ feat_pointnet = feat_pointnet_flat.view((pts_local_shape[0], pts_local_shape[1], feat_pointnet_flat.shape[1]))
+
+ # cat ablation
+ # feat_all = torch.cat((feat_proj.transpose(1, 2), feat_pointnet), dim=2)
+
+ # sum ablation -> vanilla
+ feat_all = torch.sum(torch.stack((feat_proj.transpose(1, 2), feat_pointnet), dim=0), dim=0)
+ # feat_all = feat_proj.transpose(1, 2) + feat_pointnet # result is non-contiguous
+
+ # # att: [batch, feat_len, num_feat] -> [batch, feat_len]
+ # feat_all = torch.stack((feat_proj.transpose(1, 2), feat_pointnet), dim=3)
+ # feat_all_shape = feat_all.shape
+ # feat_all = feat_all.view(feat_all_shape[0] * feat_all_shape[1], feat_all_shape[2], feat_all_shape[3])
+ # feat_all = self.branch_att.forward(feat_all)
+ # feat_all = self.activation(feat_all)
+ # feat_all = feat_all.view(feat_all_shape[0], feat_all_shape[1], feat_all_shape[2])
+
+ # PointNetFeat uses query points for batch dim -> need to flatten shape * query points dim
+ feat_all_flat = feat_all.view((feat_all.shape[0] * feat_all.shape[1], feat_all.shape[2]))
+ ret_data_flat = self.mlp(feat_all_flat)
+ ret_data = ret_data_flat.view((feat_all.shape[0], feat_all.shape[1], ret_data_flat.shape[1]))
+ ret_data = ret_data.transpose(1, 2)
+
+ return ret_data
+
diff --git a/ppsurf/start_tensorboard.sh b/ppsurf/start_tensorboard.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f1a781e7a629e262f86d3865a97dd5876828b266
--- /dev/null
+++ b/ppsurf/start_tensorboard.sh
@@ -0,0 +1 @@
+tensorboard --logdir="logs"
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 07ae9ece2150ec9e9db269cef24a147ecfe10ef2..40e165be0c5ff9ba8e9f1378db8fd48bb77ca957 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,26 +1,30 @@
-pytorch-lightning>=2.1
-torch_geometric>=2 -f https://data.pyg.org/whl/torch-2.1.0+cu124.html
-torch-cluster>=1.6.0
-scikit-learn-intelex>=2024.1.0
-numpy>=1.26.4
-scikit-image>=0.22.0
-scipy>=1.12.0
-pandas>=1.5.3
-openpyxl>=3.1.2
-overrides>=7.7.0
-pykdtree>=1.3.11
+lightning==2.4.0
+intel-openmp>=2024.0.2
+jsonargparse[signatures]==4.27.7
laspy[laszip,lazrs]>=2.5.3
+matplotlib==3.10.1
+numpy==2.2.3
+openpyxl==3.0.9
+overrides==7.7.0
+pandas==2.2.3
+pyg_lib==0.4.0+pt24cu124
pillow>=10.2.0
-intel-openmp>=2024.0.2
-tqdm>=4.66.2
pyglet>=1.5.28
-rtree>=1.2.0
-tensorboard>=2.16.2
-trimesh>=3.23.5
-pysdf>=0.1.9
-jsonargparse[signatures]>=4.27.5
+pykdtree==1.4.1
+pysdf==0.1.9
+pytorch-lightning==2.4.0
+rtree==1.4.0
+scikit-image==0.25.2
+scikit-learn==1.6.1
+scipy==1.15.2
+tensorboard==2.19.0
+torch==2.4.1+cu124
+torch-geometric==2.6.1
+torch_cluster==1.6.3+pt24cu124
+tqdm==4.67.1
+trimesh==4.6.4
spaces>=0.23
-gradio>=4.19
---index-url https://download.pytorch.org/whl/cu124
---extra-index-url=https://pypi.org/simple
\ No newline at end of file
+gradio==4.44.1
+-f https://data.pyg.org/whl/torch-2.4.0+cu124.html
+--extra-index-url https://download.pytorch.org/whl/cu124
\ No newline at end of file