Thomas Male committed on
Commit
a5407e7
·
1 Parent(s): 4f9b047

Upload 98 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. .gitattributes +1 -0
  2. .gitignore +6 -0
  3. LICENSE +22 -0
  4. README.md +28 -3
  5. model-card.md +62 -0
  6. point_e.egg-info/PKG-INFO +5 -0
  7. point_e.egg-info/SOURCES.txt +35 -0
  8. point_e.egg-info/dependency_links.txt +1 -0
  9. point_e.egg-info/requires.txt +12 -0
  10. point_e.egg-info/top_level.txt +1 -0
  11. point_e/__init__.py +0 -0
  12. point_e/__pycache__/__init__.cpython-311.pyc +0 -0
  13. point_e/__pycache__/__init__.cpython-39.pyc +0 -0
  14. point_e/diffusion/__init__.py +0 -0
  15. point_e/diffusion/__pycache__/__init__.cpython-311.pyc +0 -0
  16. point_e/diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
  17. point_e/diffusion/__pycache__/configs.cpython-311.pyc +0 -0
  18. point_e/diffusion/__pycache__/configs.cpython-39.pyc +0 -0
  19. point_e/diffusion/__pycache__/gaussian_diffusion.cpython-311.pyc +0 -0
  20. point_e/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc +0 -0
  21. point_e/diffusion/__pycache__/k_diffusion.cpython-311.pyc +0 -0
  22. point_e/diffusion/__pycache__/k_diffusion.cpython-39.pyc +0 -0
  23. point_e/diffusion/__pycache__/sampler.cpython-311.pyc +0 -0
  24. point_e/diffusion/__pycache__/sampler.cpython-39.pyc +0 -0
  25. point_e/diffusion/configs.py +64 -0
  26. point_e/diffusion/gaussian_diffusion.py +1091 -0
  27. point_e/diffusion/k_diffusion.py +332 -0
  28. point_e/diffusion/sampler.py +263 -0
  29. point_e/evals/__init__.py +0 -0
  30. point_e/evals/feature_extractor.py +119 -0
  31. point_e/evals/fid_is.py +81 -0
  32. point_e/evals/npz_stream.py +270 -0
  33. point_e/evals/pointnet2_cls_ssg.py +101 -0
  34. point_e/evals/pointnet2_utils.py +356 -0
  35. point_e/evals/scripts/blender_script.py +533 -0
  36. point_e/evals/scripts/evaluate_pfid.py +40 -0
  37. point_e/evals/scripts/evaluate_pis.py +31 -0
  38. point_e/examples/.ipynb_checkpoints/Test-checkpoint.py +7 -0
  39. point_e/examples/.ipynb_checkpoints/pointcloud2mesh-checkpoint.ipynb +106 -0
  40. point_e/examples/.ipynb_checkpoints/text2pointcloud-checkpoint.ipynb +150 -0
  41. point_e/examples/GPUtest.py +56 -0
  42. point_e/examples/Saving Model Code.txt +6 -0
  43. point_e/examples/Test.py +7 -0
  44. point_e/examples/example_data/blue_bird.npz +3 -0
  45. point_e/examples/example_data/blue_bird.ply +0 -0
  46. point_e/examples/example_data/corgi.jpg +0 -0
  47. point_e/examples/example_data/corgi.ply +0 -0
  48. point_e/examples/example_data/cube_stack.jpg +0 -0
  49. point_e/examples/example_data/pc_corgi.npz +3 -0
  50. point_e/examples/example_data/pc_cube_stack.npz +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ point_e/examples/paper_banner.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
+ *.egg-info/
+ __pycache__/
+ point_e_model_cache/
+ .ipynb_checkpoints/
+ .DS_Store
+
LICENSE ADDED
@@ -0,0 +1,22 @@
+ MIT License
+
+ Copyright (c) 2022 OpenAI
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
README.md CHANGED
@@ -1,3 +1,28 @@
- ---
- license: mit
- ---
+ # Point·E
+
+ ![Animation of four 3D point clouds rotating](point_e/examples/paper_banner.gif)
+
+ This is the official code and model release for [Point-E: A System for Generating 3D Point Clouds from Complex Prompts](https://arxiv.org/abs/2212.08751).
+
+ # Usage
+
+ Install with `pip install -e .`.
+
+ To get started with examples, see the following notebooks:
+
+ * [image2pointcloud.ipynb](point_e/examples/image2pointcloud.ipynb) - sample a point cloud, conditioned on some example synthetic view images.
+ * [text2pointcloud.ipynb](point_e/examples/text2pointcloud.ipynb) - use our small, lower-quality pure text-to-3D model to produce 3D point clouds directly from text descriptions. This model's capabilities are limited, but it does understand some simple categories and colors.
+ * [pointcloud2mesh.ipynb](point_e/examples/pointcloud2mesh.ipynb) - try our SDF regression model for producing meshes from point clouds.
+
+ For our P-FID and P-IS evaluation scripts, see:
+
+ * [evaluate_pfid.py](point_e/evals/scripts/evaluate_pfid.py)
+ * [evaluate_pis.py](point_e/evals/scripts/evaluate_pis.py)
+
+ For our Blender rendering code, see [blender_script.py](point_e/evals/scripts/blender_script.py).
+
+ # Samples
+
+ You can download the seed images and point clouds corresponding to the paper banner images [here](https://openaipublic.azureedge.net/main/point-e/banner_pcs.zip).
+
+ You can download the seed images used for COCO CLIP R-Precision evaluations [here](https://openaipublic.azureedge.net/main/point-e/coco_images.zip).
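As a companion to the notebooks above, here is a condensed sketch of the text-to-3D pipeline, patterned on the `text2pointcloud.ipynb` notebook referenced in the README. The config keys (`base40M-textvec`, `upsample`) appear in this upload, but the exact sampler arguments here are assumptions drawn from the example notebooks rather than canonical values:

```python
import torch
from tqdm.auto import tqdm

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.models.download import load_checkpoint

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Base text-conditional model plus the 1024 -> 4096 point upsampler.
base_model = model_from_config(MODEL_CONFIGS['base40M-textvec'], device)
base_model.eval()
base_model.load_state_dict(load_checkpoint('base40M-textvec', device))
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['base40M-textvec'])

upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
upsampler_model.eval()
upsampler_model.load_state_dict(load_checkpoint('upsample', device))
upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])

sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[base_diffusion, upsampler_diffusion],
    num_points=[1024, 4096 - 1024],
    aux_channels=['R', 'G', 'B'],
    guidance_scale=[3.0, 0.0],
    model_kwargs_key_filter=('texts', ''),  # no text conditioning for the upsampler
)

# Progressive sampling yields intermediate batches; the last one is the final cloud.
samples = None
for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=['a red motorcycle']))):
    samples = x
pc = sampler.output_to_point_clouds(samples)[0]
```

The two-stage design (a 1024-point base sample upsampled to 4096 points) mirrors the `upsample` model description in the model card below.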
model-card.md ADDED
@@ -0,0 +1,62 @@
+ # Model Card: Point-E
+
+ This is the official codebase for running the point cloud diffusion models and SDF regression models described in [Point-E: A System for Generating 3D Point Clouds from Complex Prompts](https://arxiv.org/abs/2212.08751). These models were trained and released by OpenAI.
+ Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we're providing some information about how the models were trained and evaluated.
+
+ # Model Details
+
+ The Point-E models are trained for use as point cloud diffusion models and SDF regression models.
+ Our image-conditional models are often capable of producing coherent 3D point clouds, given a single rendering of a 3D object. However, the models sometimes fail to do so, either producing incorrect geometry where the rendering is occluded, or producing geometry that is inconsistent with visible parts of the rendering. The resulting point clouds are relatively low-resolution, and are often noisy and contain defects such as outliers or cracks.
+ Our text-conditional model is sometimes capable of producing 3D point clouds which can be recognized as the provided text description, especially when the text description is simple. However, we find that this model fails to generalize to complex prompts or unusual objects.
+
+ ## Model Date
+
+ December 2022
+
+ ## Model Versions
+
+ * `base40M-imagevec` - a 40 million parameter image to point cloud model that conditions on a single CLIP ViT-L/14 image vector. This model can be used to generate point clouds from rendered images, but does not perform as well as our other models for this task.
+ * `base40M-textvec` - a 40 million parameter text to point cloud model that conditions on a single CLIP ViT-L/14 text vector. This model can be used to directly generate point clouds from text descriptions, but only works for simple prompts.
+ * `base40M-uncond` - a 40 million parameter point cloud diffusion model that generates unconditional samples. This is included only as a baseline.
+ * `base40M` - a 40 million parameter image to point cloud diffusion model that conditions on the latent grid from a CLIP ViT-L/14 model. This model can be used to generate point clouds from rendered images, but is not as good as the larger models trained on the same task.
+ * `base300M` - a 300 million parameter image to point cloud diffusion model that conditions on the latent grid from a CLIP ViT-L/14 model. This model can be used to generate point clouds from rendered images, but it is slightly worse than `base1B`.
+ * `base1B` - a 1 billion parameter image to point cloud diffusion model that conditions on the latent grid from a CLIP ViT-L/14 model.
+ * `upsample` - a 40 million parameter point cloud upsampling model that can optionally condition on an image as well. This takes a point cloud of 1024 points and upsamples it to 4096 points.
+ * `sdf` - a small model for predicting signed distance functions from 3D point clouds. This can be used to predict meshes from point clouds.
+ * `pointnet` - a small point cloud classification model used for our P-FID and P-IS evaluation metrics.
+
+ ## Paper & samples
+
+ [Paper](https://arxiv.org/abs/2212.08751) / [Sample point clouds](point_e/examples/paper_banner.gif)
+
+ # Training data
+
+ These models were trained on a dataset of several million 3D models. We filtered the dataset to avoid flat objects, and used [CLIP](https://github.com/openai/CLIP/blob/main/model-card.md) to cluster the dataset and downweight clusters of 3D models which appeared to contain mostly unrecognizable objects. We additionally down-weighted clusters which appeared to consist of many similar-looking objects. We processed the resulting dataset into renders (RGB point clouds of 4K points each) and text captions from the associated metadata.
+ Our SDF regression model was trained on a subset of the above dataset. In particular, we only retained 3D meshes which were manifold (i.e. watertight and free of singularities).
+
+ # Evaluated Use
+
+ We release these models to help advance research in generative modeling. Due to the limitations and biases of our models, we do not currently recommend them for commercial use. We understand that our models may be used in ways we haven't anticipated, and that it is difficult to define clear boundaries around what constitutes appropriate "research" use. In particular, we caution against using these models in applications where precision is critical, as subtle flaws in the outputs could lead to errors or inaccuracies.
+ Functionally, these models are trained to be able to perform the following tasks for research purposes, and are evaluated on these tasks:
+
+ * Generate 3D point clouds conditioned on single rendered images
+ * Generate 3D point clouds conditioned on text
+ * Create 3D meshes from noisy 3D point clouds
+
+ Our image-conditional models are intended to produce coherent point clouds, given a representative rendering of a 3D object. However, at their current level of capabilities, the models sometimes fail to generate coherent output, either producing incorrect geometry where the rendering is occluded, or producing geometry that is inconsistent with visible parts of the rendering. The resulting point clouds are relatively low-resolution, and are often noisy and contain defects such as outliers or cracks.
+
+ Our text-conditional model is sometimes capable of producing 3D point clouds which can be recognized as the provided text description, especially when the text description is simple. However, we find that this model fails to generalize to complex prompts or unusual objects.
+
+ # Performance and Limitations
+
+ Our image-conditional models are limited by the text-to-image model that is used to produce synthetic views. If the text-to-image model contains a bias or fails to understand a particular concept, these limitations will be passed down to the image-conditional point cloud model through conditioning images.
+ While our main focus is on image-conditional models, we also experimented with a text-conditional model. We find that this model can sometimes produce 3D models of people that exhibit gender biases (for example, samples for "a man" tend to be wider than samples for "a woman"). We additionally find that this model is sometimes capable of producing violent objects such as guns or tanks, although these generations are always low-quality and unrealistic.
+
+ Since our dataset contains many simplistic, cartoonish 3D objects, our models are prone to mimicking this style.
+
+ While these models were developed for research purposes, they have potential implications if used more broadly. For example, the ability to generate 3D point clouds from single images could help advance research in computer graphics, virtual reality, and robotics. The text-conditional model could allow users to easily create 3D models from simple descriptions, which could be useful for rapid prototyping or 3D printing.
+
+ The combination of these models with 3D printing could potentially be harmful, for example if used to prototype dangerous objects or when parts created by the model are trusted without external validation.
+
+ Finally, point cloud models inherit many of the same risks and limitations as image-generation models, including the propensity to produce biased or otherwise harmful content or to carry dual-use risk. More research is needed on how these risks manifest themselves as capabilities improve.
+
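Each version name above doubles as a config/checkpoint key in the codebase. A minimal loading sketch, assuming the helpers in `point_e/models/configs.py` and `point_e/models/download.py` (listed in SOURCES.txt below) behave as in the example notebooks:

```python
import torch

from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.models.download import load_checkpoint

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Any version name from the list above ('base40M', 'base1B', 'sdf', ...) works the same way.
name = 'base300M'
model = model_from_config(MODEL_CONFIGS[name], device)
model.load_state_dict(load_checkpoint(name, device))
model.eval()
```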
point_e.egg-info/PKG-INFO ADDED
@@ -0,0 +1,5 @@
+ Metadata-Version: 2.1
+ Name: point-e
+ Version: 0.0.0
+ Author: OpenAI
+ License-File: LICENSE
point_e.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,35 @@
+ LICENSE
+ README.md
+ setup.py
+ point_e/__init__.py
+ point_e.egg-info/PKG-INFO
+ point_e.egg-info/SOURCES.txt
+ point_e.egg-info/dependency_links.txt
+ point_e.egg-info/requires.txt
+ point_e.egg-info/top_level.txt
+ point_e/diffusion/__init__.py
+ point_e/diffusion/configs.py
+ point_e/diffusion/gaussian_diffusion.py
+ point_e/diffusion/k_diffusion.py
+ point_e/diffusion/sampler.py
+ point_e/evals/__init__.py
+ point_e/evals/feature_extractor.py
+ point_e/evals/fid_is.py
+ point_e/evals/npz_stream.py
+ point_e/evals/pointnet2_cls_ssg.py
+ point_e/evals/pointnet2_utils.py
+ point_e/models/__init__.py
+ point_e/models/checkpoint.py
+ point_e/models/configs.py
+ point_e/models/download.py
+ point_e/models/perceiver.py
+ point_e/models/pretrained_clip.py
+ point_e/models/sdf.py
+ point_e/models/transformer.py
+ point_e/models/util.py
+ point_e/util/__init__.py
+ point_e/util/mesh.py
+ point_e/util/pc_to_mesh.py
+ point_e/util/plotting.py
+ point_e/util/ply_util.py
+ point_e/util/point_cloud.py
point_e.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
point_e.egg-info/requires.txt ADDED
@@ -0,0 +1,12 @@
+ filelock
+ Pillow
+ torch
+ fire
+ humanize
+ requests
+ tqdm
+ matplotlib
+ scikit-image
+ scipy
+ numpy
+ clip@ git+https://github.com/openai/CLIP.git
point_e.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ point_e
point_e/__init__.py ADDED
File without changes
point_e/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (173 Bytes)
point_e/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (155 Bytes)
point_e/diffusion/__init__.py ADDED
File without changes
point_e/diffusion/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (183 Bytes)
point_e/diffusion/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (165 Bytes)
point_e/diffusion/__pycache__/configs.cpython-311.pyc ADDED
Binary file (2.3 kB)
point_e/diffusion/__pycache__/configs.cpython-39.pyc ADDED
Binary file (1.55 kB)
point_e/diffusion/__pycache__/gaussian_diffusion.cpython-311.pyc ADDED
Binary file (52.9 kB)
point_e/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc ADDED
Binary file (32.5 kB)
point_e/diffusion/__pycache__/k_diffusion.cpython-311.pyc ADDED
Binary file (17.2 kB)
point_e/diffusion/__pycache__/k_diffusion.cpython-39.pyc ADDED
Binary file (10.3 kB)
point_e/diffusion/__pycache__/sampler.cpython-311.pyc ADDED
Binary file (15.6 kB)
point_e/diffusion/__pycache__/sampler.cpython-39.pyc ADDED
Binary file (8.59 kB)
point_e/diffusion/configs.py ADDED
@@ -0,0 +1,64 @@
+ """
+ Based on https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py
+ """
+
+ from typing import Any, Dict
+
+ import numpy as np
+
+ from .gaussian_diffusion import (
+     GaussianDiffusion,
+     SpacedDiffusion,
+     get_named_beta_schedule,
+     space_timesteps,
+ )
+
+ BASE_DIFFUSION_CONFIG = {
+     "channel_biases": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0],
+     "channel_scales": [2.0, 2.0, 2.0, 0.007843137255, 0.007843137255, 0.007843137255],
+     "mean_type": "epsilon",
+     "schedule": "cosine",
+     "timesteps": 1024,
+ }
+
+ DIFFUSION_CONFIGS = {
+     "base40M-imagevec": BASE_DIFFUSION_CONFIG,
+     "base40M-textvec": BASE_DIFFUSION_CONFIG,
+     "base40M-uncond": BASE_DIFFUSION_CONFIG,
+     "base40M": BASE_DIFFUSION_CONFIG,
+     "base300M": BASE_DIFFUSION_CONFIG,
+     "base1B": BASE_DIFFUSION_CONFIG,
+     "upsample": {
+         "channel_biases": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0],
+         "channel_scales": [2.0, 2.0, 2.0, 0.007843137255, 0.007843137255, 0.007843137255],
+         "mean_type": "epsilon",
+         "schedule": "linear",
+         "timesteps": 1024,
+     },
+ }
+
+
+ def diffusion_from_config(config: Dict[str, Any]) -> GaussianDiffusion:
+     schedule = config["schedule"]
+     steps = config["timesteps"]
+     respace = config.get("respacing", None)
+     mean_type = config.get("mean_type", "epsilon")
+     betas = get_named_beta_schedule(schedule, steps)
+     channel_scales = config.get("channel_scales", None)
+     channel_biases = config.get("channel_biases", None)
+     if channel_scales is not None:
+         channel_scales = np.array(channel_scales)
+     if channel_biases is not None:
+         channel_biases = np.array(channel_biases)
+     kwargs = dict(
+         betas=betas,
+         model_mean_type=mean_type,
+         model_var_type="learned_range",
+         loss_type="mse",
+         channel_scales=channel_scales,
+         channel_biases=channel_biases,
+     )
+     if respace is None:
+         return GaussianDiffusion(**kwargs)
+     else:
+         return SpacedDiffusion(use_timesteps=space_timesteps(steps, respace), **kwargs)
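A quick usage sketch for `diffusion_from_config`: the `respacing` key is read via `config.get(...)` above, so adding it to a copied config yields a step-skipping `SpacedDiffusion` (the `"64"` respacing value here is illustrative):

```python
from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config

# Full 1024-step cosine-schedule diffusion for the base 40M text-conditional model.
diffusion = diffusion_from_config(DIFFUSION_CONFIGS["base40M-textvec"])
print(diffusion.num_timesteps)  # 1024

# Same config, respaced to 64 evenly-strided steps for faster sampling.
fast_config = dict(DIFFUSION_CONFIGS["base40M-textvec"], respacing="64")
fast_diffusion = diffusion_from_config(fast_config)
print(fast_diffusion.num_timesteps)  # 64
```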
point_e/diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,1091 @@
+ """
+ Based on https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py
+ """
+
+ import math
+ from typing import Any, Dict, Iterable, Optional, Sequence, Union
+
+ import numpy as np
+ import torch as th
+
+
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+     """
+     This is the deprecated API for creating beta schedules.
+
+     See get_named_beta_schedule() for the new library of schedules.
+     """
+     if beta_schedule == "linear":
+         betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+     else:
+         raise NotImplementedError(beta_schedule)
+     assert betas.shape == (num_diffusion_timesteps,)
+     return betas
+
+
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+     """
+     Get a pre-defined beta schedule for the given name.
+
+     The beta schedule library consists of beta schedules which remain similar
+     in the limit of num_diffusion_timesteps.
+     Beta schedules may be added, but should not be removed or changed once
+     they are committed to maintain backwards compatibility.
+     """
+     if schedule_name == "linear":
+         # Linear schedule from Ho et al, extended to work for any number of
+         # diffusion steps.
+         scale = 1000 / num_diffusion_timesteps
+         return get_beta_schedule(
+             "linear",
+             beta_start=scale * 0.0001,
+             beta_end=scale * 0.02,
+             num_diffusion_timesteps=num_diffusion_timesteps,
+         )
+     elif schedule_name == "cosine":
+         return betas_for_alpha_bar(
+             num_diffusion_timesteps,
+             lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+         )
+     else:
+         raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+     """
+     Create a beta schedule that discretizes the given alpha_t_bar function,
+     which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+     :param num_diffusion_timesteps: the number of betas to produce.
+     :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                       produces the cumulative product of (1-beta) up to that
+                       part of the diffusion process.
+     :param max_beta: the maximum beta to use; use values lower than 1 to
+                      prevent singularities.
+     """
+     betas = []
+     for i in range(num_diffusion_timesteps):
+         t1 = i / num_diffusion_timesteps
+         t2 = (i + 1) / num_diffusion_timesteps
+         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+     return np.array(betas)
+
+
+ def space_timesteps(num_timesteps, section_counts):
+     """
+     Create a list of timesteps to use from an original diffusion process,
+     given the number of timesteps we want to take from equally-sized portions
+     of the original process.
+     For example, if there's 300 timesteps and the section counts are [10,15,20]
+     then the first 100 timesteps are strided to be 10 timesteps, the second 100
+     are strided to be 15 timesteps, and the final 100 are strided to be 20.
+     :param num_timesteps: the number of diffusion steps in the original
+                           process to divide up.
+     :param section_counts: either a list of numbers, or a string containing
+                            comma-separated numbers, indicating the step count
+                            per section. As a special case, use "ddimN" where N
+                            is a number of steps to use the striding from the
+                            DDIM paper.
+     :return: a set of diffusion steps from the original process to use.
+     """
+     if isinstance(section_counts, str):
+         if section_counts.startswith("ddim"):
+             desired_count = int(section_counts[len("ddim") :])
+             for i in range(1, num_timesteps):
+                 if len(range(0, num_timesteps, i)) == desired_count:
+                     return set(range(0, num_timesteps, i))
+             raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
+         elif section_counts.startswith("exact"):
+             res = set(int(x) for x in section_counts[len("exact") :].split(","))
+             for x in res:
+                 if x < 0 or x >= num_timesteps:
+                     raise ValueError(f"timestep out of bounds: {x}")
+             return res
+         section_counts = [int(x) for x in section_counts.split(",")]
+     size_per = num_timesteps // len(section_counts)
+     extra = num_timesteps % len(section_counts)
+     start_idx = 0
+     all_steps = []
+     for i, section_count in enumerate(section_counts):
+         size = size_per + (1 if i < extra else 0)
+         if size < section_count:
+             raise ValueError(f"cannot divide section of {size} steps into {section_count}")
+         if section_count <= 1:
+             frac_stride = 1
+         else:
+             frac_stride = (size - 1) / (section_count - 1)
+         cur_idx = 0.0
+         taken_steps = []
+         for _ in range(section_count):
+             taken_steps.append(start_idx + round(cur_idx))
+             cur_idx += frac_stride
+         all_steps += taken_steps
+         start_idx += size
+     return set(all_steps)
+
+
+ class GaussianDiffusion:
+     """
+     Utilities for training and sampling diffusion models.
+
+     Ported directly from here:
+     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+
+     :param betas: a 1-D array of betas for each diffusion timestep from T to 1.
+     :param model_mean_type: a string determining what the model outputs.
+     :param model_var_type: a string determining how variance is output.
+     :param loss_type: a string determining the loss function to use.
+     :param discretized_t0: if True, use discrete gaussian loss for t=0. Only
+                            makes sense for images.
+     :param channel_scales: a multiplier to apply to x_start in training_losses
+                            and sampling functions.
+     """
+
+     def __init__(
+         self,
+         *,
+         betas: Sequence[float],
+         model_mean_type: str,
+         model_var_type: str,
+         loss_type: str,
+         discretized_t0: bool = False,
+         channel_scales: Optional[np.ndarray] = None,
+         channel_biases: Optional[np.ndarray] = None,
+     ):
+         self.model_mean_type = model_mean_type
+         self.model_var_type = model_var_type
+         self.loss_type = loss_type
+         self.discretized_t0 = discretized_t0
+         self.channel_scales = channel_scales
+         self.channel_biases = channel_biases
+
+         # Use float64 for accuracy.
+         betas = np.array(betas, dtype=np.float64)
+         self.betas = betas
+         assert len(betas.shape) == 1, "betas must be 1-D"
+         assert (betas > 0).all() and (betas <= 1).all()
+
+         self.num_timesteps = int(betas.shape[0])
+
+         alphas = 1.0 - betas
+         self.alphas_cumprod = np.cumprod(alphas, axis=0)
+         self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+         self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+         assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+         # calculations for diffusion q(x_t | x_{t-1}) and others
+         self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+         self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+         self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+         self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+         self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+         # calculations for posterior q(x_{t-1} | x_t, x_0)
+         self.posterior_variance = (
+             betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+         )
+         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+         self.posterior_log_variance_clipped = np.log(
+             np.append(self.posterior_variance[1], self.posterior_variance[1:])
+         )
+         self.posterior_mean_coef1 = (
+             betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+         )
+         self.posterior_mean_coef2 = (
+             (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
+         )
+
+     def get_sigmas(self, t):
+         return _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, t.shape)
+
+     def q_mean_variance(self, x_start, t):
+         """
+         Get the distribution q(x_t | x_0).
+
+         :param x_start: the [N x C x ...] tensor of noiseless inputs.
+         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+         :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+         """
+         mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+         variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+         log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+         return mean, variance, log_variance
+
+     def q_sample(self, x_start, t, noise=None):
+         """
+         Diffuse the data for a given number of diffusion steps.
+
+         In other words, sample from q(x_t | x_0).
+
+         :param x_start: the initial data batch.
+         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+         :param noise: if specified, the split-out normal noise.
+         :return: A noisy version of x_start.
+         """
+         if noise is None:
+             noise = th.randn_like(x_start)
+         assert noise.shape == x_start.shape
+         return (
+             _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+             + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+         )
+
+     def q_posterior_mean_variance(self, x_start, x_t, t):
+         """
+         Compute the mean and variance of the diffusion posterior:
+
+             q(x_{t-1} | x_t, x_0)
+
+         """
+         assert x_start.shape == x_t.shape
+         posterior_mean = (
+             _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+             + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+         )
+         posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+         posterior_log_variance_clipped = _extract_into_tensor(
+             self.posterior_log_variance_clipped, t, x_t.shape
+         )
+         assert (
+             posterior_mean.shape[0]
+             == posterior_variance.shape[0]
+             == posterior_log_variance_clipped.shape[0]
+             == x_start.shape[0]
+         )
+         return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+     def p_mean_variance(
+         self, model, x, t, clip_denoised=False, denoised_fn=None, model_kwargs=None
+     ):
+         """
+         Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+         the initial x, x_0.
+
+         :param model: the model, which takes a signal and a batch of timesteps
+                       as input.
+         :param x: the [N x C x ...] tensor at time t.
+         :param t: a 1-D Tensor of timesteps.
+         :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+         :param denoised_fn: if not None, a function which applies to the
+             x_start prediction before it is used to sample. Applies before
+             clip_denoised.
+         :param model_kwargs: if not None, a dict of extra keyword arguments to
+             pass to the model. This can be used for conditioning.
+         :return: a dict with the following keys:
+                  - 'mean': the model mean output.
+                  - 'variance': the model variance output.
+                  - 'log_variance': the log of 'variance'.
+                  - 'pred_xstart': the prediction for x_0.
+         """
+         if model_kwargs is None:
+             model_kwargs = {}
+
+         B, C = x.shape[:2]
+         assert t.shape == (B,)
+         model_output = model(x, t, **model_kwargs)
+         if isinstance(model_output, tuple):
+             model_output, extra = model_output
+         else:
+             extra = None
+
+         if self.model_var_type in ["learned", "learned_range"]:
+             assert model_output.shape == (B, C * 2, *x.shape[2:])
+             model_output, model_var_values = th.split(model_output, C, dim=1)
+             if self.model_var_type == "learned":
+                 model_log_variance = model_var_values
+                 model_variance = th.exp(model_log_variance)
+             else:
+                 min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
+                 max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+                 # The model_var_values is [-1, 1] for [min_var, max_var].
+                 frac = (model_var_values + 1) / 2
+                 model_log_variance = frac * max_log + (1 - frac) * min_log
+                 model_variance = th.exp(model_log_variance)
+         else:
+             model_variance, model_log_variance = {
+                 # for fixedlarge, we set the initial (log-)variance like so
+                 # to get a better decoder log likelihood.
+                 "fixed_large": (
+                     np.append(self.posterior_variance[1], self.betas[1:]),
+                     np.log(np.append(self.posterior_variance[1], self.betas[1:])),
+                 ),
+                 "fixed_small": (
+                     self.posterior_variance,
+                     self.posterior_log_variance_clipped,
+                 ),
+             }[self.model_var_type]
+             model_variance = _extract_into_tensor(model_variance, t, x.shape)
+             model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+
+         def process_xstart(x):
+             if denoised_fn is not None:
+                 x = denoised_fn(x)
+             if clip_denoised:
+                 return x.clamp(-1, 1)
+             return x
+
+         if self.model_mean_type == "x_prev":
+             pred_xstart = process_xstart(
+                 self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
+             )
+             model_mean = model_output
+         elif self.model_mean_type in ["x_start", "epsilon"]:
+             if self.model_mean_type == "x_start":
+                 pred_xstart = process_xstart(model_output)
+             else:
+                 pred_xstart = process_xstart(
+                     self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+                 )
+             model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+         else:
+             raise NotImplementedError(self.model_mean_type)
+
+         assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+         return {
+             "mean": model_mean,
+             "variance": model_variance,
+             "log_variance": model_log_variance,
+             "pred_xstart": pred_xstart,
+             "extra": extra,
+         }
+
+     def _predict_xstart_from_eps(self, x_t, t, eps):
+         assert x_t.shape == eps.shape
+         return (
+             _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+             - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+         )
+
+     def _predict_xstart_from_xprev(self, x_t, t, xprev):
+         assert x_t.shape == xprev.shape
+         return (  # (xprev - coef2*x_t) / coef1
+             _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
+             - _extract_into_tensor(
+                 self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
+             )
+             * x_t
+         )
+
+     def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+         return (
+             _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
+         ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+     def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+         """
+         Compute the mean for the previous step, given a function cond_fn that
+         computes the gradient of a conditional log probability with respect to
+         x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+         condition on y.
+
+         This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+         """
+         gradient = cond_fn(x, t, **model_kwargs)
+         new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+         return new_mean
+
+     def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+         """
+         Compute what the p_mean_variance output would have been, should the
+         model's score function be conditioned by cond_fn.
+
+         See condition_mean() for details on cond_fn.
+
+         Unlike condition_mean(), this instead uses the conditioning strategy
+         from Song et al (2020).
+         """
+         alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+         eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+         eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
+
+         out = p_mean_var.copy()
+         out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+         out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
+         return out
+
+     def p_sample(
+         self,
+         model,
+         x,
+         t,
+         clip_denoised=False,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+     ):
+         """
+         Sample x_{t-1} from the model at the given timestep.
+
+         :param model: the model to sample from.
+         :param x: the current tensor at x_{t-1}.
+         :param t: the value of t, starting at 0 for the first diffusion step.
+         :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+         :param denoised_fn: if not None, a function which applies to the
+             x_start prediction before it is used to sample.
+         :param cond_fn: if not None, this is a gradient function that acts
+                         similarly to the model.
+         :param model_kwargs: if not None, a dict of extra keyword arguments to
+             pass to the model. This can be used for conditioning.
+         :return: a dict containing the following keys:
+                  - 'sample': a random sample from the model.
+                  - 'pred_xstart': a prediction of x_0.
+         """
+         out = self.p_mean_variance(
+             model,
+             x,
+             t,
+             clip_denoised=clip_denoised,
+             denoised_fn=denoised_fn,
+             model_kwargs=model_kwargs,
+         )
+         noise = th.randn_like(x)
+         nonzero_mask = (
+             (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+         )  # no noise when t == 0
+         if cond_fn is not None:
+             out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
+         sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+         return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+     def p_sample_loop(
+         self,
+         model,
+         shape,
+         noise=None,
+         clip_denoised=False,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         device=None,
+         progress=False,
+         temp=1.0,
+     ):
+         """
+         Generate samples from the model.
+
+         :param model: the model module.
+         :param shape: the shape of the samples, (N, C, H, W).
+         :param noise: if specified, the noise from the encoder to sample.
+             Should be of the same shape as `shape`.
+         :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+         :param denoised_fn: if not None, a function which applies to the
+             x_start prediction before it is used to sample.
+         :param cond_fn: if not None, this is a gradient function that acts
+                         similarly to the model.
+         :param model_kwargs: if not None, a dict of extra keyword arguments to
+             pass to the model. This can be used for conditioning.
+         :param device: if specified, the device to create the samples on.
+             If not specified, use a model parameter's device.
+         :param progress: if True, show a tqdm progress bar.
+         :return: a non-differentiable batch of samples.
+         """
+         final = None
+         for sample in self.p_sample_loop_progressive(
+             model,
+             shape,
+             noise=noise,
+             clip_denoised=clip_denoised,
+             denoised_fn=denoised_fn,
+             cond_fn=cond_fn,
+             model_kwargs=model_kwargs,
+             device=device,
+             progress=progress,
+             temp=temp,
+         ):
+             final = sample
+         return final["sample"]
+
+     def p_sample_loop_progressive(
+         self,
+         model,
+         shape,
+         noise=None,
+         clip_denoised=False,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         device=None,
+         progress=False,
+         temp=1.0,
+     ):
+         """
+         Generate samples from the model and yield intermediate samples from
+         each timestep of diffusion.
+
+         Arguments are the same as p_sample_loop().
+         Returns a generator over dicts, where each dict is the return value of
+         p_sample().
+         """
+         if device is None:
+             device = next(model.parameters()).device
+         assert isinstance(shape, (tuple, list))
+         if noise is not None:
+             img = noise
+         else:
+             img = th.randn(*shape, device=device) * temp
+         indices = list(range(self.num_timesteps))[::-1]
+
+         if progress:
+             # Lazy import so that we don't depend on tqdm.
+             from tqdm.auto import tqdm
+
+             indices = tqdm(indices)
+
+         for i in indices:
+             t = th.tensor([i] * shape[0], device=device)
+             with th.no_grad():
+                 out = self.p_sample(
+                     model,
+                     img,
+                     t,
+                     clip_denoised=clip_denoised,
+                     denoised_fn=denoised_fn,
+                     cond_fn=cond_fn,
+                     model_kwargs=model_kwargs,
+                 )
+                 yield self.unscale_out_dict(out)
+                 img = out["sample"]
+
+     def ddim_sample(
+         self,
+         model,
+         x,
+         t,
+         clip_denoised=False,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         eta=0.0,
+     ):
+         """
+         Sample x_{t-1} from the model using DDIM.
+
+         Same usage as p_sample().
+         """
+         out = self.p_mean_variance(
+             model,
+             x,
+             t,
+             clip_denoised=clip_denoised,
+             denoised_fn=denoised_fn,
+             model_kwargs=model_kwargs,
+         )
+         if cond_fn is not None:
+             out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+
+         # Usually our model outputs epsilon, but we re-derive it
+         # in case we used x_start or x_prev prediction.
+         eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+         alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+         alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+         sigma = (
+             eta
+             * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+             * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+         )
+         # Equation 12.
+         noise = th.randn_like(x)
+         mean_pred = (
+             out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+             + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
+         )
+         nonzero_mask = (
+             (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+         )  # no noise when t == 0
+         sample = mean_pred + nonzero_mask * sigma * noise
+         return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+     def ddim_reverse_sample(
+         self,
+         model,
+         x,
+         t,
+         clip_denoised=False,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         eta=0.0,
+     ):
+         """
+         Sample x_{t+1} from the model using DDIM reverse ODE.
+         """
+         assert eta == 0.0, "Reverse ODE only for deterministic path"
+         out = self.p_mean_variance(
+             model,
+             x,
+             t,
+             clip_denoised=clip_denoised,
+             denoised_fn=denoised_fn,
+             model_kwargs=model_kwargs,
+         )
+         if cond_fn is not None:
+             out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+         # Usually our model outputs epsilon, but we re-derive it
+         # in case we used x_start or x_prev prediction.
+         eps = (
+             _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+             - out["pred_xstart"]
+         ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+         alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+         # Equation 12. reversed
+         mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
+
+         return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+     def ddim_sample_loop(
+         self,
+         model,
+         shape,
+         noise=None,
+         clip_denoised=False,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         device=None,
+         progress=False,
+         eta=0.0,
+         temp=1.0,
+     ):
+         """
+         Generate samples from the model using DDIM.
+
+         Same usage as p_sample_loop().
+         """
+         final = None
+         for sample in self.ddim_sample_loop_progressive(
+             model,
+             shape,
+             noise=noise,
+             clip_denoised=clip_denoised,
+             denoised_fn=denoised_fn,
+             cond_fn=cond_fn,
+             model_kwargs=model_kwargs,
+             device=device,
+             progress=progress,
+             eta=eta,
+             temp=temp,
+         ):
+             final = sample
+         return final["sample"]
+
+     def ddim_sample_loop_progressive(
+         self,
+         model,
+         shape,
+         noise=None,
+         clip_denoised=False,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         device=None,
+         progress=False,
+         eta=0.0,
+         temp=1.0,
+     ):
+         """
+         Use DDIM to sample from the model and yield intermediate samples from
+         each timestep of DDIM.
+
+         Same usage as p_sample_loop_progressive().
+         """
+         if device is None:
+             device = next(model.parameters()).device
+         assert isinstance(shape, (tuple, list))
+         if noise is not None:
+             img = noise
+         else:
+             img = th.randn(*shape, device=device) * temp
+         indices = list(range(self.num_timesteps))[::-1]
+
+         if progress:
+             # Lazy import so that we don't depend on tqdm.
+             from tqdm.auto import tqdm
+
+             indices = tqdm(indices)
+
+         for i in indices:
+             t = th.tensor([i] * shape[0], device=device)
+             with th.no_grad():
+                 out = self.ddim_sample(
+                     model,
+                     img,
+                     t,
+                     clip_denoised=clip_denoised,
+                     denoised_fn=denoised_fn,
+                     cond_fn=cond_fn,
+                     model_kwargs=model_kwargs,
+                     eta=eta,
+                 )
+                 yield self.unscale_out_dict(out)
+                 img = out["sample"]
+
+     def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=False, model_kwargs=None):
+         """
+         Get a term for the variational lower-bound.
+
+         The resulting units are bits (rather than nats, as one might expect).
+         This allows for comparison to other papers.
+
+         :return: a dict with the following keys:
+                  - 'output': a shape [N] tensor of NLLs or KLs.
+                  - 'pred_xstart': the x_0 predictions.
+         """
+         true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+             x_start=x_start, x_t=x_t, t=t
+         )
+         out = self.p_mean_variance(
+             model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+         )
+         kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
+         kl = mean_flat(kl) / np.log(2.0)
+
+         decoder_nll = -discretized_gaussian_log_likelihood(
+             x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+         )
+         if not self.discretized_t0:
+             decoder_nll = th.zeros_like(decoder_nll)
+         assert decoder_nll.shape == x_start.shape
+         decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+         # At the first timestep return the decoder NLL,
+         # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+         output = th.where((t == 0), decoder_nll, kl)
+         return {
+             "output": output,
+             "pred_xstart": out["pred_xstart"],
+             "extra": out["extra"],
+         }
+
+     def training_losses(
+         self, model, x_start, t, model_kwargs=None, noise=None
+     ) -> Dict[str, th.Tensor]:
+         """
+         Compute training losses for a single timestep.
+
+         :param model: the model to evaluate loss on.
+         :param x_start: the [N x C x ...] tensor of inputs.
+         :param t: a batch of timestep indices.
+         :param model_kwargs: if not None, a dict of extra keyword arguments to
+             pass to the model. This can be used for conditioning.
+         :param noise: if specified, the specific Gaussian noise to try to remove.
+         :return: a dict with the key "loss" containing a tensor of shape [N].
+                  Some mean or variance settings may also have other keys.
+         """
+         x_start = self.scale_channels(x_start)
+         if model_kwargs is None:
+             model_kwargs = {}
+         if noise is None:
+             noise = th.randn_like(x_start)
+         x_t = self.q_sample(x_start, t, noise=noise)
+
+         terms = {}
+
+         if self.loss_type == "kl" or self.loss_type == "rescaled_kl":
+             vb_terms = self._vb_terms_bpd(
+                 model=model,
+                 x_start=x_start,
+                 x_t=x_t,
+                 t=t,
+                 clip_denoised=False,
+                 model_kwargs=model_kwargs,
+             )
+             terms["loss"] = vb_terms["output"]
+             if self.loss_type == "rescaled_kl":
+                 terms["loss"] *= self.num_timesteps
+             extra = vb_terms["extra"]
+         elif self.loss_type == "mse" or self.loss_type == "rescaled_mse":
+             model_output = model(x_t, t, **model_kwargs)
+             if isinstance(model_output, tuple):
+                 model_output, extra = model_output
+             else:
+                 extra = {}
+
+             if self.model_var_type in [
+                 "learned",
+                 "learned_range",
+             ]:
+                 B, C = x_t.shape[:2]
+                 assert model_output.shape == (B, C * 2, *x_t.shape[2:])
+                 model_output, model_var_values = th.split(model_output, C, dim=1)
+                 # Learn the variance using the variational bound, but don't let
+                 # it affect our mean prediction.
+                 frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
+                 terms["vb"] = self._vb_terms_bpd(
+                     model=lambda *args, r=frozen_out: r,
+                     x_start=x_start,
+                     x_t=x_t,
+                     t=t,
+                     clip_denoised=False,
+                 )["output"]
+                 if self.loss_type == "rescaled_mse":
+                     # Divide by 1000 for equivalence with initial implementation.
+                     # Without a factor of 1/1000, the VB term hurts the MSE term.
+                     terms["vb"] *= self.num_timesteps / 1000.0
+
+             target = {
+                 "x_prev": self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
+                 "x_start": x_start,
+                 "epsilon": noise,
+             }[self.model_mean_type]
+             assert model_output.shape == target.shape == x_start.shape
+             terms["mse"] = mean_flat((target - model_output) ** 2)
+             if "vb" in terms:
+                 terms["loss"] = terms["mse"] + terms["vb"]
+             else:
+                 terms["loss"] = terms["mse"]
+         else:
+             raise NotImplementedError(self.loss_type)
+
+         if "losses" in extra:
+             terms.update({k: loss for k, (loss, _scale) in extra["losses"].items()})
+             for loss, scale in extra["losses"].values():
+                 terms["loss"] = terms["loss"] + loss * scale
+
+         return terms
+
+     def _prior_bpd(self, x_start):
+         """
+         Get the prior KL term for the variational lower-bound, measured in
+         bits-per-dim.
+
+         This term can't be optimized, as it only depends on the encoder.
+
+         :param x_start: the [N x C x ...] tensor of inputs.
+         :return: a batch of [N] KL values (in bits), one per batch element.
+         """
+         batch_size = x_start.shape[0]
+         t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+         qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+         kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+         return mean_flat(kl_prior) / np.log(2.0)
+
+     def calc_bpd_loop(self, model, x_start, clip_denoised=False, model_kwargs=None):
+         """
+         Compute the entire variational lower-bound, measured in bits-per-dim,
+         as well as other related quantities.
+
+         :param model: the model to evaluate loss on.
+         :param x_start: the [N x C x ...] tensor of inputs.
+         :param clip_denoised: if True, clip denoised samples.
+         :param model_kwargs: if not None, a dict of extra keyword arguments to
+             pass to the model. This can be used for conditioning.
+
+         :return: a dict containing the following keys:
+                  - total_bpd: the total variational lower-bound, per batch element.
+                  - prior_bpd: the prior term in the lower-bound.
+                  - vb: an [N x T] tensor of terms in the lower-bound.
+                  - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+                  - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+         """
+         device = x_start.device
+         batch_size = x_start.shape[0]
+
+         vb = []
+         xstart_mse = []
+         mse = []
+         for t in list(range(self.num_timesteps))[::-1]:
+             t_batch = th.tensor([t] * batch_size, device=device)
+             noise = th.randn_like(x_start)
+             x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+             # Calculate VLB term at the current timestep
+             with th.no_grad():
+                 out = self._vb_terms_bpd(
+                     model,
+                     x_start=x_start,
+                     x_t=x_t,
+                     t=t_batch,
+                     clip_denoised=clip_denoised,
+                     model_kwargs=model_kwargs,
+                 )
+             vb.append(out["output"])
+             xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+             eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+             mse.append(mean_flat((eps - noise) ** 2))
+
+         vb = th.stack(vb, dim=1)
+         xstart_mse = th.stack(xstart_mse, dim=1)
+         mse = th.stack(mse, dim=1)
+
+         prior_bpd = self._prior_bpd(x_start)
+         total_bpd = vb.sum(dim=1) + prior_bpd
+         return {
+             "total_bpd": total_bpd,
+             "prior_bpd": prior_bpd,
+             "vb": vb,
+             "xstart_mse": xstart_mse,
+             "mse": mse,
+         }
+
+     def scale_channels(self, x: th.Tensor) -> th.Tensor:
+         if self.channel_scales is not None:
+             x = x * th.from_numpy(self.channel_scales).to(x).reshape(
+                 [1, -1, *([1] * (len(x.shape) - 2))]
+             )
+         if self.channel_biases is not None:
+             x = x + th.from_numpy(self.channel_biases).to(x).reshape(
+                 [1, -1, *([1] * (len(x.shape) - 2))]
+             )
+         return x
+
+     def unscale_channels(self, x: th.Tensor) -> th.Tensor:
+         if self.channel_biases is not None:
+             x = x - th.from_numpy(self.channel_biases).to(x).reshape(
+                 [1, -1, *([1] * (len(x.shape) - 2))]
+             )
+         if self.channel_scales is not None:
+             x = x / th.from_numpy(self.channel_scales).to(x).reshape(
+                 [1, -1, *([1] * (len(x.shape) - 2))]
+             )
+         return x
+
+     def unscale_out_dict(
+         self, out: Dict[str, Union[th.Tensor, Any]]
+     ) -> Dict[str, Union[th.Tensor, Any]]:
+         return {
+             k: (self.unscale_channels(v) if isinstance(v, th.Tensor) else v) for k, v in out.items()
+         }
+
+
+ class SpacedDiffusion(GaussianDiffusion):
+     """
+     A diffusion process which can skip steps in a base diffusion process.
+     :param use_timesteps: (unordered) timesteps from the original diffusion
+                           process to retain.
+     :param kwargs: the kwargs to create the base diffusion process.
+     """
+
+     def __init__(self, use_timesteps: Iterable[int], **kwargs):
+         self.use_timesteps = set(use_timesteps)
+         self.timestep_map = []
+         self.original_num_steps = len(kwargs["betas"])
+
+         base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
+         last_alpha_cumprod = 1.0
+         new_betas = []
+         for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+             if i in self.use_timesteps:
+                 new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+                 last_alpha_cumprod = alpha_cumprod
+                 self.timestep_map.append(i)
+         kwargs["betas"] = np.array(new_betas)
+         super().__init__(**kwargs)
+
+     def p_mean_variance(self, model, *args, **kwargs):
+         return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+     def training_losses(self, model, *args, **kwargs):
+         return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+     def condition_mean(self, cond_fn, *args, **kwargs):
+         return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+     def condition_score(self, cond_fn, *args, **kwargs):
+         return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+     def _wrap_model(self, model):
+         if isinstance(model, _WrappedModel):
+             return model
+         return _WrappedModel(model, self.timestep_map, self.original_num_steps)
+
+
+ class _WrappedModel:
+     def __init__(self, model, timestep_map, original_num_steps):
+         self.model = model
+         self.timestep_map = timestep_map
+         self.original_num_steps = original_num_steps
+
+     def __call__(self, x, ts, **kwargs):
+         map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+         new_ts = map_tensor[ts]
+         return self.model(x, new_ts, **kwargs)
+
+
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
+     """
+     Extract values from a 1-D numpy array for a batch of indices.
+
+     :param arr: the 1-D numpy array.
+     :param timesteps: a tensor of indices into the array to extract.
+     :param broadcast_shape: a larger shape of K dimensions with the batch
+                             dimension equal to the length of timesteps.
+     :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+     """
+     res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+     while len(res.shape) < len(broadcast_shape):
+         res = res[..., None]
+     return res + th.zeros(broadcast_shape, device=timesteps.device)
+
+
+ def normal_kl(mean1, logvar1, mean2, logvar2):
+     """
+     Compute the KL divergence between two gaussians.
+     Shapes are automatically broadcasted, so batches can be compared to
+     scalars, among other use cases.
+     """
+     tensor = None
+     for obj in (mean1, logvar1, mean2, logvar2):
+         if isinstance(obj, th.Tensor):
+             tensor = obj
+             break
+     assert tensor is not None, "at least one argument must be a Tensor"
+
+     # Force variances to be Tensors. Broadcasting helps convert scalars to
+     # Tensors, but it does not work for th.exp().
+     logvar1, logvar2 = [
+         x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)
+     ]
+
+     return 0.5 * (
+         -1.0
+         + logvar2
+         - logvar1
+         + th.exp(logvar1 - logvar2)
+         + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
+     )
+
+
+ def approx_standard_normal_cdf(x):
+     """
+     A fast approximation of the cumulative distribution function of the
+     standard normal.
+     """
+     return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+     """
+     Compute the log-likelihood of a Gaussian distribution discretizing to a
+     given image.
+     :param x: the target images. It is assumed that this was uint8 values,
+               rescaled to the range [-1, 1].
1064
+ :param means: the Gaussian mean Tensor.
1065
+ :param log_scales: the Gaussian log stddev Tensor.
1066
+ :return: a tensor like x of log probabilities (in nats).
1067
+ """
1068
+ assert x.shape == means.shape == log_scales.shape
1069
+ centered_x = x - means
1070
+ inv_stdv = th.exp(-log_scales)
1071
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
1072
+ cdf_plus = approx_standard_normal_cdf(plus_in)
1073
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
1074
+ cdf_min = approx_standard_normal_cdf(min_in)
1075
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
1076
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
1077
+ cdf_delta = cdf_plus - cdf_min
1078
+ log_probs = th.where(
1079
+ x < -0.999,
1080
+ log_cdf_plus,
1081
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
1082
+ )
1083
+ assert log_probs.shape == x.shape
1084
+ return log_probs
1085
+
1086
+
1087
+ def mean_flat(tensor):
1088
+ """
1089
+ Take the mean over all non-batch dimensions.
1090
+ """
1091
+ return tensor.flatten(1).mean(1)
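
A quick illustration of what the loop in SpacedDiffusion.__init__ computes: for the retained subset of timesteps, each new beta is chosen so the respaced process reproduces the base process's cumulative alphas. This is a minimal sketch with a hypothetical linear beta schedule and an evenly strided subset; neither value is taken from the configs in this upload.

import numpy as np

base_betas = np.linspace(1e-4, 0.02, 1024)   # hypothetical base schedule
alphas_cumprod = np.cumprod(1.0 - base_betas)
use_timesteps = set(range(0, 1024, 16))      # keep 64 of 1024 steps (illustrative)

last_alpha_cumprod = 1.0
new_betas = []
for i, ac in enumerate(alphas_cumprod):
    if i in use_timesteps:
        # Same recurrence as SpacedDiffusion.__init__ above: the retained
        # step's beta composes to the same cumulative noise level.
        new_betas.append(1 - ac / last_alpha_cumprod)
        last_alpha_cumprod = ac

assert len(new_betas) == 64  # a 64-step process matching the base schedule at those steps
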
point_e/diffusion/k_diffusion.py ADDED
@@ -0,0 +1,332 @@
+ """
+ Based on: https://github.com/crowsonkb/k-diffusion
+
+ Copyright (c) 2022 Katherine Crowson
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ THE SOFTWARE.
+ """
+
+ import numpy as np
+ import torch as th
+
+ from .gaussian_diffusion import GaussianDiffusion, mean_flat
+
+
+ class KarrasDenoiser:
+     def __init__(self, sigma_data: float = 0.5):
+         self.sigma_data = sigma_data
+
+     def get_snr(self, sigmas):
+         return sigmas**-2
+
+     def get_sigmas(self, sigmas):
+         return sigmas
+
+     def get_scalings(self, sigma):
+         c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+         c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
+         c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
+         return c_skip, c_out, c_in
+
+     def training_losses(self, model, x_start, sigmas, model_kwargs=None, noise=None):
+         if model_kwargs is None:
+             model_kwargs = {}
+         if noise is None:
+             noise = th.randn_like(x_start)
+
+         terms = {}
+
+         dims = x_start.ndim
+         x_t = x_start + noise * append_dims(sigmas, dims)
+         c_skip, c_out, _ = [append_dims(x, dims) for x in self.get_scalings(sigmas)]
+         model_output, denoised = self.denoise(model, x_t, sigmas, **model_kwargs)
+         target = (x_start - c_skip * x_t) / c_out
+
+         terms["mse"] = mean_flat((model_output - target) ** 2)
+         terms["xs_mse"] = mean_flat((denoised - x_start) ** 2)
+
+         if "vb" in terms:
+             terms["loss"] = terms["mse"] + terms["vb"]
+         else:
+             terms["loss"] = terms["mse"]
+
+         return terms
+
+     def denoise(self, model, x_t, sigmas, **model_kwargs):
+         c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)]
+         rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
+         model_output = model(c_in * x_t, rescaled_t, **model_kwargs)
+         denoised = c_out * model_output + c_skip * x_t
+         return model_output, denoised
+
+
+ class GaussianToKarrasDenoiser:
+     def __init__(self, model, diffusion):
+         from scipy import interpolate
+
+         self.model = model
+         self.diffusion = diffusion
+         self.alpha_cumprod_to_t = interpolate.interp1d(
+             diffusion.alphas_cumprod, np.arange(0, diffusion.num_timesteps)
+         )
+
+     def sigma_to_t(self, sigma):
+         alpha_cumprod = 1.0 / (sigma**2 + 1)
+         if alpha_cumprod > self.diffusion.alphas_cumprod[0]:
+             return 0
+         elif alpha_cumprod <= self.diffusion.alphas_cumprod[-1]:
+             return self.diffusion.num_timesteps - 1
+         else:
+             return float(self.alpha_cumprod_to_t(alpha_cumprod))
+
+     def denoise(self, x_t, sigmas, clip_denoised=True, model_kwargs=None):
+         t = th.tensor(
+             [self.sigma_to_t(sigma) for sigma in sigmas.cpu().numpy()],
+             dtype=th.long,
+             device=sigmas.device,
+         )
+         c_in = append_dims(1.0 / (sigmas**2 + 1) ** 0.5, x_t.ndim)
+         out = self.diffusion.p_mean_variance(
+             self.model, x_t * c_in, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+         )
+         return None, out["pred_xstart"]
+
+
+ def karras_sample(*args, **kwargs):
+     last = None
+     for x in karras_sample_progressive(*args, **kwargs):
+         last = x["x"]
+     return last
+
+
+ def karras_sample_progressive(
+     diffusion,
+     model,
+     shape,
+     steps,
+     clip_denoised=True,
+     progress=False,
+     model_kwargs=None,
+     device=None,
+     sigma_min=0.002,
+     sigma_max=80,  # higher for highres?
+     rho=7.0,
+     sampler="heun",
+     s_churn=0.0,
+     s_tmin=0.0,
+     s_tmax=float("inf"),
+     s_noise=1.0,
+     guidance_scale=0.0,
+ ):
+     sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)
+     x_T = th.randn(*shape, device=device) * sigma_max
+     sample_fn = {"heun": sample_heun, "dpm": sample_dpm, "ancestral": sample_euler_ancestral}[
+         sampler
+     ]
+
+     if sampler != "ancestral":
+         sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise)
+     else:
+         sampler_args = {}
+
+     if isinstance(diffusion, KarrasDenoiser):
+
+         def denoiser(x_t, sigma):
+             _, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs)
+             if clip_denoised:
+                 denoised = denoised.clamp(-1, 1)
+             return denoised
+
+     elif isinstance(diffusion, GaussianDiffusion):
+         model = GaussianToKarrasDenoiser(model, diffusion)
+
+         def denoiser(x_t, sigma):
+             _, denoised = model.denoise(
+                 x_t, sigma, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+             )
+             return denoised
+
+     else:
+         raise NotImplementedError
+
+     if guidance_scale != 0 and guidance_scale != 1:
+
+         def guided_denoiser(x_t, sigma):
+             x_t = th.cat([x_t, x_t], dim=0)
+             sigma = th.cat([sigma, sigma], dim=0)
+             x_0 = denoiser(x_t, sigma)
+             cond_x_0, uncond_x_0 = th.split(x_0, len(x_0) // 2, dim=0)
+             x_0 = uncond_x_0 + guidance_scale * (cond_x_0 - uncond_x_0)
+             return x_0
+
+     else:
+         guided_denoiser = denoiser
+
+     for obj in sample_fn(
+         guided_denoiser,
+         x_T,
+         sigmas,
+         progress=progress,
+         **sampler_args,
+     ):
+         if isinstance(diffusion, GaussianDiffusion):
+             yield diffusion.unscale_out_dict(obj)
+         else:
+             yield obj
+
+
+ def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
+     """Constructs the noise schedule of Karras et al. (2022)."""
+     ramp = th.linspace(0, 1, n)
+     min_inv_rho = sigma_min ** (1 / rho)
+     max_inv_rho = sigma_max ** (1 / rho)
+     sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+     return append_zero(sigmas).to(device)
+
+
+ def to_d(x, sigma, denoised):
+     """Converts a denoiser output to a Karras ODE derivative."""
+     return (x - denoised) / append_dims(sigma, x.ndim)
+
+
+ def get_ancestral_step(sigma_from, sigma_to):
+     """Calculates the noise level (sigma_down) to step down to and the amount
+     of noise to add (sigma_up) when doing an ancestral sampling step."""
+     sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
+     sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
+     return sigma_down, sigma_up
+
+
+ @th.no_grad()
+ def sample_euler_ancestral(model, x, sigmas, progress=False):
+     """Ancestral sampling with Euler method steps."""
+     s_in = x.new_ones([x.shape[0]])
+     indices = range(len(sigmas) - 1)
+     if progress:
+         from tqdm.auto import tqdm
+
+         indices = tqdm(indices)
+
+     for i in indices:
+         denoised = model(x, sigmas[i] * s_in)
+         sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
+         yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "pred_xstart": denoised}
+         d = to_d(x, sigmas[i], denoised)
+         # Euler method
+         dt = sigma_down - sigmas[i]
+         x = x + d * dt
+         x = x + th.randn_like(x) * sigma_up
+     yield {"x": x, "pred_xstart": x}
+
+
+ @th.no_grad()
+ def sample_heun(
+     denoiser,
+     x,
+     sigmas,
+     progress=False,
+     s_churn=0.0,
+     s_tmin=0.0,
+     s_tmax=float("inf"),
+     s_noise=1.0,
+ ):
+     """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
+     s_in = x.new_ones([x.shape[0]])
+     indices = range(len(sigmas) - 1)
+     if progress:
+         from tqdm.auto import tqdm
+
+         indices = tqdm(indices)
+
+     for i in indices:
+         gamma = (
+             min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
+         )
+         eps = th.randn_like(x) * s_noise
+         sigma_hat = sigmas[i] * (gamma + 1)
+         if gamma > 0:
+             x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
+         denoised = denoiser(x, sigma_hat * s_in)
+         d = to_d(x, sigma_hat, denoised)
+         yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "pred_xstart": denoised}
+         dt = sigmas[i + 1] - sigma_hat
+         if sigmas[i + 1] == 0:
+             # Euler method
+             x = x + d * dt
+         else:
+             # Heun's method
+             x_2 = x + d * dt
+             denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in)
+             d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
+             d_prime = (d + d_2) / 2
+             x = x + d_prime * dt
+     yield {"x": x, "pred_xstart": denoised}
+
+
+ @th.no_grad()
+ def sample_dpm(
+     denoiser,
+     x,
+     sigmas,
+     progress=False,
+     s_churn=0.0,
+     s_tmin=0.0,
+     s_tmax=float("inf"),
+     s_noise=1.0,
+ ):
+     """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
+     s_in = x.new_ones([x.shape[0]])
+     indices = range(len(sigmas) - 1)
+     if progress:
+         from tqdm.auto import tqdm
+
+         indices = tqdm(indices)
+
+     for i in indices:
+         gamma = (
+             min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
+         )
+         eps = th.randn_like(x) * s_noise
+         sigma_hat = sigmas[i] * (gamma + 1)
+         if gamma > 0:
+             x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
+         denoised = denoiser(x, sigma_hat * s_in)
+         d = to_d(x, sigma_hat, denoised)
+         yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised}
+         # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
+         sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
+         dt_1 = sigma_mid - sigma_hat
+         dt_2 = sigmas[i + 1] - sigma_hat
+         x_2 = x + d * dt_1
+         denoised_2 = denoiser(x_2, sigma_mid * s_in)
+         d_2 = to_d(x_2, sigma_mid, denoised_2)
+         x = x + d_2 * dt_2
+     yield {"x": x, "pred_xstart": denoised}
+
+
+ def append_dims(x, target_dims):
+     """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+     dims_to_append = target_dims - x.ndim
+     if dims_to_append < 0:
+         raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+     return x[(...,) + (None,) * dims_to_append]
+
+
+ def append_zero(x):
+     return th.cat([x, x.new_zeros([1])])
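
To see what get_sigmas_karras produces, here is the same rho-warped interpolation in plain numpy; the sigma range and step count below are illustrative defaults from the function signature above, not tuned values for any particular point-e model.

import numpy as np

def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # Interpolate between sigma_max and sigma_min in sigma^(1/rho) space,
    # which concentrates steps at low noise levels, then append a final 0.
    ramp = np.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return np.append(sigmas, 0.0)

print(karras_schedule(5))  # approximately [80.0, 17.5, 2.5, 0.17, 0.002, 0.0]
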
point_e/diffusion/sampler.py ADDED
@@ -0,0 +1,263 @@
+ """
+ Helpers for sampling from a single- or multi-stage point cloud diffusion model.
+ """
+
+ from typing import Any, Callable, Dict, Iterator, List, Sequence, Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from point_e.util.point_cloud import PointCloud
+
+ from .gaussian_diffusion import GaussianDiffusion
+ from .k_diffusion import karras_sample_progressive
+
+
+ class PointCloudSampler:
+     """
+     A wrapper around a model or stack of models that produces conditional or
+     unconditional sample tensors.
+
+     To modify the sampler arguments of an existing sampler, create a new one
+     with with_options().
+     """
+
+     def __init__(
+         self,
+         device: torch.device,
+         models: Sequence[nn.Module],
+         diffusions: Sequence[GaussianDiffusion],
+         num_points: Sequence[int],
+         aux_channels: Sequence[str],
+         model_kwargs_key_filter: Sequence[str] = ("*",),
+         guidance_scale: Sequence[float] = (3.0, 3.0),
+         clip_denoised: bool = True,
+         use_karras: Sequence[bool] = (True, True),
+         karras_steps: Sequence[int] = (64, 64),
+         sigma_min: Sequence[float] = (1e-3, 1e-3),
+         sigma_max: Sequence[float] = (120, 160),
+         s_churn: Sequence[float] = (3, 0),
+     ):
+         n = len(models)
+         assert n > 0
+
+         if n > 1:
+             if len(guidance_scale) == 1:
+                 # Don't guide the upsamplers by default.
+                 guidance_scale = list(guidance_scale) + [1.0] * (n - 1)
+             if len(use_karras) == 1:
+                 use_karras = use_karras * n
+             if len(karras_steps) == 1:
+                 karras_steps = karras_steps * n
+             if len(sigma_min) == 1:
+                 sigma_min = sigma_min * n
+             if len(sigma_max) == 1:
+                 sigma_max = sigma_max * n
+             if len(s_churn) == 1:
+                 s_churn = s_churn * n
+             if len(model_kwargs_key_filter) == 1:
+                 model_kwargs_key_filter = model_kwargs_key_filter * n
+         if len(model_kwargs_key_filter) == 0:
+             model_kwargs_key_filter = ["*"] * n
+         assert len(guidance_scale) == n
+         assert len(use_karras) == n
+         assert len(karras_steps) == n
+         assert len(sigma_min) == n
+         assert len(sigma_max) == n
+         assert len(s_churn) == n
+         assert len(model_kwargs_key_filter) == n
+
+         self.device = device
+         self.num_points = num_points
+         self.aux_channels = aux_channels
+         self.model_kwargs_key_filter = model_kwargs_key_filter
+         self.guidance_scale = guidance_scale
+         self.clip_denoised = clip_denoised
+         self.use_karras = use_karras
+         self.karras_steps = karras_steps
+         self.sigma_min = sigma_min
+         self.sigma_max = sigma_max
+         self.s_churn = s_churn
+
+         self.models = models
+         self.diffusions = diffusions
+
+     @property
+     def num_stages(self) -> int:
+         return len(self.models)
+
+     def sample_batch(self, batch_size: int, model_kwargs: Dict[str, Any]) -> torch.Tensor:
+         samples = None
+         for x in self.sample_batch_progressive(batch_size, model_kwargs):
+             samples = x
+         return samples
+
+     def sample_batch_progressive(
+         self, batch_size: int, model_kwargs: Dict[str, Any]
+     ) -> Iterator[torch.Tensor]:
+         samples = None
+         for (
+             model,
+             diffusion,
+             stage_num_points,
+             stage_guidance_scale,
+             stage_use_karras,
+             stage_karras_steps,
+             stage_sigma_min,
+             stage_sigma_max,
+             stage_s_churn,
+             stage_key_filter,
+         ) in zip(
+             self.models,
+             self.diffusions,
+             self.num_points,
+             self.guidance_scale,
+             self.use_karras,
+             self.karras_steps,
+             self.sigma_min,
+             self.sigma_max,
+             self.s_churn,
+             self.model_kwargs_key_filter,
+         ):
+             stage_model_kwargs = model_kwargs.copy()
+             if stage_key_filter != "*":
+                 use_keys = set(stage_key_filter.split(","))
+                 stage_model_kwargs = {k: v for k, v in stage_model_kwargs.items() if k in use_keys}
+             if samples is not None:
+                 stage_model_kwargs["low_res"] = samples
+             if hasattr(model, "cached_model_kwargs"):
+                 stage_model_kwargs = model.cached_model_kwargs(batch_size, stage_model_kwargs)
+             sample_shape = (batch_size, 3 + len(self.aux_channels), stage_num_points)
+
+             if stage_guidance_scale != 1 and stage_guidance_scale != 0:
+                 for k, v in stage_model_kwargs.copy().items():
+                     stage_model_kwargs[k] = torch.cat([v, torch.zeros_like(v)], dim=0)
+
+             if stage_use_karras:
+                 samples_it = karras_sample_progressive(
+                     diffusion=diffusion,
+                     model=model,
+                     shape=sample_shape,
+                     steps=stage_karras_steps,
+                     clip_denoised=self.clip_denoised,
+                     model_kwargs=stage_model_kwargs,
+                     device=self.device,
+                     sigma_min=stage_sigma_min,
+                     sigma_max=stage_sigma_max,
+                     s_churn=stage_s_churn,
+                     guidance_scale=stage_guidance_scale,
+                 )
+             else:
+                 internal_batch_size = batch_size
+                 if stage_guidance_scale:
+                     model = self._uncond_guide_model(model, stage_guidance_scale)
+                     internal_batch_size *= 2
+                 samples_it = diffusion.p_sample_loop_progressive(
+                     model,
+                     shape=(internal_batch_size, *sample_shape[1:]),
+                     model_kwargs=stage_model_kwargs,
+                     device=self.device,
+                     clip_denoised=self.clip_denoised,
+                 )
+             for x in samples_it:
+                 samples = x["pred_xstart"][:batch_size]
+                 if "low_res" in stage_model_kwargs:
+                     samples = torch.cat(
+                         [stage_model_kwargs["low_res"][: len(samples)], samples], dim=-1
+                     )
+                 yield samples
+
+     @classmethod
+     def combine(cls, *samplers: "PointCloudSampler") -> "PointCloudSampler":
+         assert all(x.device == samplers[0].device for x in samplers[1:])
+         assert all(x.aux_channels == samplers[0].aux_channels for x in samplers[1:])
+         assert all(x.clip_denoised == samplers[0].clip_denoised for x in samplers[1:])
+         return cls(
+             device=samplers[0].device,
+             models=[x for y in samplers for x in y.models],
+             diffusions=[x for y in samplers for x in y.diffusions],
+             num_points=[x for y in samplers for x in y.num_points],
+             aux_channels=samplers[0].aux_channels,
+             model_kwargs_key_filter=[x for y in samplers for x in y.model_kwargs_key_filter],
+             guidance_scale=[x for y in samplers for x in y.guidance_scale],
+             clip_denoised=samplers[0].clip_denoised,
+             use_karras=[x for y in samplers for x in y.use_karras],
+             karras_steps=[x for y in samplers for x in y.karras_steps],
+             sigma_min=[x for y in samplers for x in y.sigma_min],
+             sigma_max=[x for y in samplers for x in y.sigma_max],
+             s_churn=[x for y in samplers for x in y.s_churn],
+         )
+
+     def _uncond_guide_model(
+         self, model: Callable[..., torch.Tensor], scale: float
+     ) -> Callable[..., torch.Tensor]:
+         def model_fn(x_t, ts, **kwargs):
+             half = x_t[: len(x_t) // 2]
+             combined = torch.cat([half, half], dim=0)
+             model_out = model(combined, ts, **kwargs)
+             eps, rest = model_out[:, :3], model_out[:, 3:]
+             cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
+             half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
+             eps = torch.cat([half_eps, half_eps], dim=0)
+             return torch.cat([eps, rest], dim=1)
+
+         return model_fn
+
+     def split_model_output(
+         self,
+         output: torch.Tensor,
+         rescale_colors: bool = False,
+     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+         assert (
+             len(self.aux_channels) + 3 == output.shape[1]
+         ), "there must be three spatial channels before aux"
+         pos, joined_aux = output[:, :3], output[:, 3:]
+
+         aux = {}
+         for i, name in enumerate(self.aux_channels):
+             v = joined_aux[:, i]
+             if name in {"R", "G", "B", "A"}:
+                 v = v.clamp(0, 255).round()
+                 if rescale_colors:
+                     v = v / 255.0
+             aux[name] = v
+         return pos, aux
+
+     def output_to_point_clouds(self, output: torch.Tensor) -> List[PointCloud]:
+         res = []
+         for sample in output:
+             xyz, aux = self.split_model_output(sample[None], rescale_colors=True)
+             res.append(
+                 PointCloud(
+                     coords=xyz[0].t().cpu().numpy(),
+                     channels={k: v[0].cpu().numpy() for k, v in aux.items()},
+                 )
+             )
+         return res
+
+     def with_options(
+         self,
+         guidance_scale: float,
+         clip_denoised: bool,
+         use_karras: Sequence[bool] = (True, True),
+         karras_steps: Sequence[int] = (64, 64),
+         sigma_min: Sequence[float] = (1e-3, 1e-3),
+         sigma_max: Sequence[float] = (120, 160),
+         s_churn: Sequence[float] = (3, 0),
+     ) -> "PointCloudSampler":
+         return PointCloudSampler(
+             device=self.device,
+             models=self.models,
+             diffusions=self.diffusions,
+             num_points=self.num_points,
+             aux_channels=self.aux_channels,
+             model_kwargs_key_filter=self.model_kwargs_key_filter,
+             guidance_scale=guidance_scale,
+             clip_denoised=clip_denoised,
+             use_karras=use_karras,
+             karras_steps=karras_steps,
+             sigma_min=sigma_min,
+             sigma_max=sigma_max,
+             s_churn=s_churn,
+         )
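
For context, a sketch of how this sampler is driven end to end, mirroring the repository's text2pointcloud example. The config helpers and checkpoint names ('base40M-textvec', 'upsample') come from point_e.models and point_e.diffusion elsewhere in this upload and are assumed unchanged here.

import torch

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.models.download import load_checkpoint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

base_model = model_from_config(MODEL_CONFIGS["base40M-textvec"], device)
base_model.eval()
base_model.load_state_dict(load_checkpoint("base40M-textvec", device))

upsampler_model = model_from_config(MODEL_CONFIGS["upsample"], device)
upsampler_model.eval()
upsampler_model.load_state_dict(load_checkpoint("upsample", device))

sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[
        diffusion_from_config(DIFFUSION_CONFIGS["base40M-textvec"]),
        diffusion_from_config(DIFFUSION_CONFIGS["upsample"]),
    ],
    num_points=[1024, 4096 - 1024],   # base cloud, then the upsampled remainder
    aux_channels=["R", "G", "B"],
    guidance_scale=[3.0, 0.0],        # guide the text-conditioned base stage only
    model_kwargs_key_filter=("texts", ""),  # the upsampler ignores the text prompt
)

samples = sampler.sample_batch(batch_size=1, model_kwargs=dict(texts=["a red motorcycle"]))
pc = sampler.output_to_point_clouds(samples)[0]
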
point_e/evals/__init__.py ADDED
File without changes
point_e/evals/feature_extractor.py ADDED
@@ -0,0 +1,119 @@
+ from abc import ABC, abstractmethod
+ from multiprocessing.pool import ThreadPool
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from point_e.models.download import load_checkpoint
+
+ from .npz_stream import NpzStreamer
+ from .pointnet2_cls_ssg import get_model
+
+
+ def get_torch_devices() -> List[Union[str, torch.device]]:
+     if torch.cuda.is_available():
+         return [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
+     else:
+         return ["cpu"]
+
+
+ class FeatureExtractor(ABC):
+     @property
+     @abstractmethod
+     def supports_predictions(self) -> bool:
+         pass
+
+     @property
+     @abstractmethod
+     def feature_dim(self) -> int:
+         pass
+
+     @property
+     @abstractmethod
+     def num_classes(self) -> int:
+         pass
+
+     @abstractmethod
+     def features_and_preds(self, streamer: NpzStreamer) -> Tuple[np.ndarray, np.ndarray]:
+         """
+         For a stream of point cloud batches, compute feature vectors and class
+         predictions.
+
+         :param streamer: a streamer for a sample batch. Typically, arr_0
+                          will contain the XYZ coordinates.
+         :return: a tuple (features, predictions)
+                  - features: a [B x feature_dim] array of feature vectors.
+                  - predictions: a [B x num_classes] array of probabilities.
+         """
+
+
+ class PointNetClassifier(FeatureExtractor):
+     def __init__(
+         self,
+         devices: List[Union[str, torch.device]],
+         device_batch_size: int = 64,
+         cache_dir: Optional[str] = None,
+     ):
+         state_dict = load_checkpoint("pointnet", device=torch.device("cpu"), cache_dir=cache_dir)[
+             "model_state_dict"
+         ]
+
+         self.device_batch_size = device_batch_size
+         self.devices = devices
+         self.models = []
+         for device in devices:
+             model = get_model(num_class=40, normal_channel=False, width_mult=2)
+             model.load_state_dict(state_dict)
+             model.to(device)
+             model.eval()
+             self.models.append(model)
+
+     @property
+     def supports_predictions(self) -> bool:
+         return True
+
+     @property
+     def feature_dim(self) -> int:
+         return 256
+
+     @property
+     def num_classes(self) -> int:
+         return 40
+
+     def features_and_preds(self, streamer: NpzStreamer) -> Tuple[np.ndarray, np.ndarray]:
+         batch_size = self.device_batch_size * len(self.devices)
+         point_clouds = (x["arr_0"] for x in streamer.stream(batch_size, ["arr_0"]))
+
+         output_features = []
+         output_predictions = []
+
+         with ThreadPool(len(self.devices)) as pool:
+             for batch in point_clouds:
+                 batch = normalize_point_clouds(batch)
+                 batches = []
+                 for i, device in zip(range(0, len(batch), self.device_batch_size), self.devices):
+                     batches.append(
+                         torch.from_numpy(batch[i : i + self.device_batch_size])
+                         .permute(0, 2, 1)
+                         .to(dtype=torch.float32, device=device)
+                     )
+
+                 def compute_features(i_batch):
+                     i, batch = i_batch
+                     with torch.no_grad():
+                         return self.models[i](batch, features=True)
+
+                 for logits, _, features in pool.imap(compute_features, enumerate(batches)):
+                     output_features.append(features.cpu().numpy())
+                     output_predictions.append(logits.exp().cpu().numpy())
+
+         return np.concatenate(output_features, axis=0), np.concatenate(output_predictions, axis=0)
+
+
+ def normalize_point_clouds(pc: np.ndarray) -> np.ndarray:
+     centroids = np.mean(pc, axis=1, keepdims=True)
+     pc = pc - centroids
+     m = np.max(np.sqrt(np.sum(pc**2, axis=-1, keepdims=True)), axis=1, keepdims=True)
+     pc = pc / m
+     return pc
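
normalize_point_clouds recenters each cloud on its centroid and scales it into the unit sphere, so the classifier sees size- and position-invariant input. A quick check of that invariant on random data:

import numpy as np

pc = np.random.randn(2, 4096, 3) * 5.0 + 1.0  # two clouds with arbitrary scale and offset
pc = normalize_point_clouds(pc)

radii = np.sqrt((pc**2).sum(-1))
print(np.allclose(pc.mean(axis=1), 0.0, atol=1e-6), radii.max())  # True, 1.0
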
point_e/evals/fid_is.py ADDED
@@ -0,0 +1,81 @@
+ """
+ Adapted from https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/evaluations/evaluator.py
+ """
+
+
+ import warnings
+
+ import numpy as np
+ from scipy import linalg
+
+
+ class InvalidFIDException(Exception):
+     pass
+
+
+ class FIDStatistics:
+     def __init__(self, mu: np.ndarray, sigma: np.ndarray):
+         self.mu = mu
+         self.sigma = sigma
+
+     def frechet_distance(self, other, eps=1e-6):
+         """
+         Compute the Frechet distance between two sets of statistics.
+         """
+         # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
+         mu1, sigma1 = self.mu, self.sigma
+         mu2, sigma2 = other.mu, other.sigma
+
+         mu1 = np.atleast_1d(mu1)
+         mu2 = np.atleast_1d(mu2)
+
+         sigma1 = np.atleast_2d(sigma1)
+         sigma2 = np.atleast_2d(sigma2)
+
+         assert (
+             mu1.shape == mu2.shape
+         ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
+         assert (
+             sigma1.shape == sigma2.shape
+         ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
+
+         diff = mu1 - mu2
+
+         # product might be almost singular
+         covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+         if not np.isfinite(covmean).all():
+             msg = (
+                 "fid calculation produces singular product; adding %s to diagonal of cov estimates"
+                 % eps
+             )
+             warnings.warn(msg)
+             offset = np.eye(sigma1.shape[0]) * eps
+             covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+         # numerical error might give slight imaginary component
+         if np.iscomplexobj(covmean):
+             if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+                 m = np.max(np.abs(covmean.imag))
+                 raise ValueError("Imaginary component {}".format(m))
+             covmean = covmean.real
+
+         tr_covmean = np.trace(covmean)
+
+         return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
+
+
+ def compute_statistics(feats: np.ndarray) -> FIDStatistics:
+     mu = np.mean(feats, axis=0)
+     sigma = np.cov(feats, rowvar=False)
+     return FIDStatistics(mu, sigma)
+
+
+ def compute_inception_score(preds: np.ndarray, split_size: int = 5000) -> float:
+     # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
+     scores = []
+     for i in range(0, len(preds), split_size):
+         part = preds[i : i + split_size]
+         kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
+         kl = np.mean(np.sum(kl, 1))
+         scores.append(np.exp(kl))
+     return float(np.mean(scores))
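
Together, compute_statistics and frechet_distance implement the P-FID used by the evaluation scripts: fit a Gaussian (mu, sigma) to each set of PointNet++ features, then measure the Fréchet distance between the two Gaussians. A toy run on synthetic features (the shapes here are illustrative):

import numpy as np

rng = np.random.default_rng(0)
feats_real = rng.normal(size=(5000, 256))
feats_fake = rng.normal(loc=0.1, size=(5000, 256))  # slightly shifted distribution

stats_real = compute_statistics(feats_real)  # mu: [256], sigma: [256 x 256]
stats_fake = compute_statistics(feats_fake)
print(stats_real.frechet_distance(stats_fake))  # grows with the distribution shift
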
point_e/evals/npz_stream.py ADDED
@@ -0,0 +1,270 @@
+ import glob
+ import io
+ import os
+ import re
+ import zipfile
+ from abc import ABC, abstractmethod
+ from contextlib import contextmanager
+ from dataclasses import dataclass
+ from typing import Dict, Iterator, List, Optional, Sequence, Tuple
+
+ import numpy as np
+
+
+ @dataclass
+ class NumpyArrayInfo:
+     """
+     Information about an array in an npz file.
+     """
+
+     name: str
+     dtype: np.dtype
+     shape: Tuple[int]
+
+     @classmethod
+     def infos_from_first_file(cls, glob_path: str) -> Dict[str, "NumpyArrayInfo"]:
+         paths, _ = _npz_paths_and_length(glob_path)
+         return cls.infos_from_file(paths[0])
+
+     @classmethod
+     def infos_from_file(cls, npz_path: str) -> Dict[str, "NumpyArrayInfo"]:
+         """
+         Extract the info of every array in an npz file.
+         """
+         if not os.path.exists(npz_path):
+             raise FileNotFoundError(f"batch of samples was not found: {npz_path}")
+         results = {}
+         with open(npz_path, "rb") as f:
+             with zipfile.ZipFile(f, "r") as zip_f:
+                 for name in zip_f.namelist():
+                     if not name.endswith(".npy"):
+                         continue
+                     key_name = name[: -len(".npy")]
+                     with zip_f.open(name, "r") as arr_f:
+                         version = np.lib.format.read_magic(arr_f)
+                         if version == (1, 0):
+                             header = np.lib.format.read_array_header_1_0(arr_f)
+                         elif version == (2, 0):
+                             header = np.lib.format.read_array_header_2_0(arr_f)
+                         else:
+                             raise ValueError(f"unknown numpy array version: {version}")
+                         shape, _, dtype = header
+                         results[key_name] = cls(name=key_name, dtype=dtype, shape=shape)
+         return results
+
+     @property
+     def elem_shape(self) -> Tuple[int]:
+         return self.shape[1:]
+
+     def validate(self):
+         if self.name in {"R", "G", "B"}:
+             if len(self.shape) != 2:
+                 raise ValueError(
+                     f"expecting exactly 2-D shape for '{self.name}' but got: {self.shape}"
+                 )
+         elif self.name == "arr_0":
+             if len(self.shape) < 2:
+                 raise ValueError(f"expecting at least 2-D shape but got: {self.shape}")
+             elif len(self.shape) == 3:
+                 # For audio, we require continuous samples.
+                 if not np.issubdtype(self.dtype, np.floating):
+                     raise ValueError(
+                         f"invalid dtype for audio batch: {self.dtype} (expected float)"
+                     )
+             elif self.dtype != np.uint8:
+                 raise ValueError(f"invalid dtype for image batch: {self.dtype} (expected uint8)")
+
+
+ class NpzStreamer:
+     def __init__(self, glob_path: str):
+         self.paths, self.trunc_length = _npz_paths_and_length(glob_path)
+         self.infos = NumpyArrayInfo.infos_from_file(self.paths[0])
+
+     def keys(self) -> List[str]:
+         return list(self.infos.keys())
+
+     def stream(self, batch_size: int, keys: Sequence[str]) -> Iterator[Dict[str, np.ndarray]]:
+         cur_batch = None
+         num_remaining = self.trunc_length
+         for path in self.paths:
+             if num_remaining is not None and num_remaining <= 0:
+                 break
+             with open_npz_arrays(path, keys) as readers:
+                 combined_reader = CombinedReader(keys, readers)
+                 while num_remaining is None or num_remaining > 0:
+                     read_bs = batch_size
+                     if cur_batch is not None:
+                         read_bs -= _dict_batch_size(cur_batch)
+                     if num_remaining is not None:
+                         read_bs = min(read_bs, num_remaining)
+
+                     batch = combined_reader.read_batch(read_bs)
+                     if batch is None:
+                         break
+                     if num_remaining is not None:
+                         num_remaining -= _dict_batch_size(batch)
+                     if cur_batch is None:
+                         cur_batch = batch
+                     else:
+                         cur_batch = {
+                             # pylint: disable=unsubscriptable-object
+                             k: np.concatenate([cur_batch[k], v], axis=0)
+                             for k, v in batch.items()
+                         }
+                     if _dict_batch_size(cur_batch) == batch_size:
+                         yield cur_batch
+                         cur_batch = None
+         if cur_batch is not None:
+             yield cur_batch
+
+
+ def _npz_paths_and_length(glob_path: str) -> Tuple[List[str], Optional[int]]:
+     # Match slice syntax like path[:100].
+     count_match = re.match("^(.*)\\[:([0-9]*)\\]$", glob_path)
+     if count_match:
+         raw_path = count_match[1]
+         max_count = int(count_match[2])
+     else:
+         raw_path = glob_path
+         max_count = None
+     paths = sorted(glob.glob(raw_path))
+     if not len(paths):
+         raise ValueError(f"no paths found matching: {glob_path}")
+     return paths, max_count
+
+
+ class NpzArrayReader(ABC):
+     @abstractmethod
+     def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
+         pass
+
+
+ class StreamingNpzArrayReader(NpzArrayReader):
+     def __init__(self, arr_f, shape, dtype):
+         self.arr_f = arr_f
+         self.shape = shape
+         self.dtype = dtype
+         self.idx = 0
+
+     def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
+         if self.idx >= self.shape[0]:
+             return None
+
+         bs = min(batch_size, self.shape[0] - self.idx)
+         self.idx += bs
+
+         if self.dtype.itemsize == 0:
+             return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
+
+         read_count = bs * np.prod(self.shape[1:])
+         read_size = int(read_count * self.dtype.itemsize)
+         data = _read_bytes(self.arr_f, read_size, "array data")
+         return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
+
+
+ class MemoryNpzArrayReader(NpzArrayReader):
+     def __init__(self, arr):
+         self.arr = arr
+         self.idx = 0
+
+     @classmethod
+     def load(cls, path: str, arr_name: str):
+         with open(path, "rb") as f:
+             arr = np.load(f)[arr_name]
+         return cls(arr)
+
+     def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
+         if self.idx >= self.arr.shape[0]:
+             return None
+
+         res = self.arr[self.idx : self.idx + batch_size]
+         self.idx += batch_size
+         return res
+
+
+ @contextmanager
+ def open_npz_arrays(path: str, arr_names: Sequence[str]) -> List[NpzArrayReader]:
+     if not len(arr_names):
+         yield []
+         return
+     arr_name = arr_names[0]
+     with open_array(path, arr_name) as arr_f:
+         version = np.lib.format.read_magic(arr_f)
+         header = None
+         if version == (1, 0):
+             header = np.lib.format.read_array_header_1_0(arr_f)
+         elif version == (2, 0):
+             header = np.lib.format.read_array_header_2_0(arr_f)
+
+         if header is None:
+             reader = MemoryNpzArrayReader.load(path, arr_name)
+         else:
+             shape, fortran, dtype = header
+             if fortran or dtype.hasobject:
+                 reader = MemoryNpzArrayReader.load(path, arr_name)
+             else:
+                 reader = StreamingNpzArrayReader(arr_f, shape, dtype)
+
+         with open_npz_arrays(path, arr_names[1:]) as next_readers:
+             yield [reader] + next_readers
+
+
+ class CombinedReader:
+     def __init__(self, keys: List[str], readers: List[NpzArrayReader]):
+         self.keys = keys
+         self.readers = readers
+
+     def read_batch(self, batch_size: int) -> Optional[Dict[str, np.ndarray]]:
+         batches = [r.read_batch(batch_size) for r in self.readers]
+         any_none = any(x is None for x in batches)
+         all_none = all(x is None for x in batches)
+         if any_none != all_none:
+             raise RuntimeError("different keys had different numbers of elements")
+         if any_none:
+             return None
+         if any(len(x) != len(batches[0]) for x in batches):
+             raise RuntimeError("different keys had different numbers of elements")
+         return dict(zip(self.keys, batches))
+
+
+ def _read_bytes(fp, size, error_template="ran out of data"):
+     """
+     Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
+
+     Read from file-like object until size bytes are read.
+     Raises ValueError if EOF is encountered before size bytes are read.
+     Non-blocking objects only supported if they derive from io objects.
+     Required as e.g. ZipExtFile in python 2.6 can return less data than
+     requested.
+     """
+     data = bytes()
+     while True:
+         # io files (default in python3) return None or raise on
+         # would-block, python2 file will truncate, probably nothing can be
+         # done about that. note that regular files can't be non-blocking
+         try:
+             r = fp.read(size - len(data))
+             data += r
+             if len(r) == 0 or len(data) == size:
+                 break
+         except io.BlockingIOError:
+             pass
+     if len(data) != size:
+         msg = "EOF: reading %s, expected %d bytes got %d"
+         raise ValueError(msg % (error_template, size, len(data)))
+     else:
+         return data
+
+
+ @contextmanager
+ def open_array(path: str, arr_name: str):
+     with open(path, "rb") as f:
+         with zipfile.ZipFile(f, "r") as zip_f:
+             if f"{arr_name}.npy" not in zip_f.namelist():
+                 raise ValueError(f"missing {arr_name} in npz file")
+             with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
+                 yield arr_f
+
+
+ def _dict_batch_size(objs: Dict[str, np.ndarray]) -> int:
+     return len(next(iter(objs.values())))
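
A small usage sketch for NpzStreamer; the file name here is hypothetical. Note the slice syntax handled by _npz_paths_and_length, e.g. "samples.npz[:100]", which truncates the stream to the first 100 elements.

import numpy as np

np.savez("samples.npz", arr_0=np.random.rand(10, 4096, 3).astype(np.float32))

streamer = NpzStreamer("samples.npz")
print(streamer.keys())  # ['arr_0']
for batch in streamer.stream(batch_size=4, keys=["arr_0"]):
    # Batches are filled across file boundaries; a final partial batch remains.
    print(batch["arr_0"].shape)  # (4, 4096, 3), (4, 4096, 3), (2, 4096, 3)
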
point_e/evals/pointnet2_cls_ssg.py ADDED
@@ -0,0 +1,101 @@
+ """
+ Based on: https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/eb64fe0b4c24055559cea26299cb485dcb43d8dd/models/pointnet2_cls_ssg.py
+
+ MIT License
+
+ Copyright (c) 2019 benny
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .pointnet2_utils import PointNetSetAbstraction
+
+
+ class get_model(nn.Module):
+     def __init__(self, num_class, normal_channel=True, width_mult=1):
+         super(get_model, self).__init__()
+         self.width_mult = width_mult
+         in_channel = 6 if normal_channel else 3
+         self.normal_channel = normal_channel
+         self.sa1 = PointNetSetAbstraction(
+             npoint=512,
+             radius=0.2,
+             nsample=32,
+             in_channel=in_channel,
+             mlp=[64 * width_mult, 64 * width_mult, 128 * width_mult],
+             group_all=False,
+         )
+         self.sa2 = PointNetSetAbstraction(
+             npoint=128,
+             radius=0.4,
+             nsample=64,
+             in_channel=128 * width_mult + 3,
+             mlp=[128 * width_mult, 128 * width_mult, 256 * width_mult],
+             group_all=False,
+         )
+         self.sa3 = PointNetSetAbstraction(
+             npoint=None,
+             radius=None,
+             nsample=None,
+             in_channel=256 * width_mult + 3,
+             mlp=[256 * width_mult, 512 * width_mult, 1024 * width_mult],
+             group_all=True,
+         )
+         self.fc1 = nn.Linear(1024 * width_mult, 512 * width_mult)
+         self.bn1 = nn.BatchNorm1d(512 * width_mult)
+         self.drop1 = nn.Dropout(0.4)
+         self.fc2 = nn.Linear(512 * width_mult, 256 * width_mult)
+         self.bn2 = nn.BatchNorm1d(256 * width_mult)
+         self.drop2 = nn.Dropout(0.4)
+         self.fc3 = nn.Linear(256 * width_mult, num_class)
+
+     def forward(self, xyz, features=False):
+         B, _, _ = xyz.shape
+         if self.normal_channel:
+             norm = xyz[:, 3:, :]
+             xyz = xyz[:, :3, :]
+         else:
+             norm = None
+         l1_xyz, l1_points = self.sa1(xyz, norm)
+         l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
+         l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
+         x = l3_points.view(B, 1024 * self.width_mult)
+         x = self.drop1(F.relu(self.bn1(self.fc1(x))))
+         result_features = self.bn2(self.fc2(x))
+         x = self.drop2(F.relu(result_features))
+         x = self.fc3(x)
+         x = F.log_softmax(x, -1)
+
+         if features:
+             return x, l3_points, result_features
+         else:
+             return x, l3_points
+
+
+ class get_loss(nn.Module):
+     def __init__(self):
+         super(get_loss, self).__init__()
+
+     def forward(self, pred, target, trans_feat):
+         total_loss = F.nll_loss(pred, target)
+
+         return total_loss
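
A shape check for the classifier, under the same settings PointNetClassifier uses (num_class=40, no normals, width_mult=2). Calling eval() matters: the BatchNorm layers then use running statistics, and farthest-point sampling in the utils file becomes deterministic outside training mode.

import torch

model = get_model(num_class=40, normal_channel=False, width_mult=2)
model.eval()

points = torch.randn(2, 3, 4096)  # [B, C, N], xyz only
with torch.no_grad():
    logits, l3_points, features = model(points, features=True)

print(logits.shape)    # torch.Size([2, 40]) log-probabilities over the 40 classes
print(features.shape)  # torch.Size([2, 512]) since fc2 outputs 256 * width_mult
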
point_e/evals/pointnet2_utils.py ADDED
@@ -0,0 +1,356 @@
1
+ """
2
+ Based on: https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/eb64fe0b4c24055559cea26299cb485dcb43d8dd/models/pointnet_utils.py
3
+
4
+ MIT License
5
+
6
+ Copyright (c) 2019 benny
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ of this software and associated documentation files (the "Software"), to deal
10
+ in the Software without restriction, including without limitation the rights
11
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ copies of the Software, and to permit persons to whom the Software is
13
+ furnished to do so, subject to the following conditions:
14
+
15
+ The above copyright notice and this permission notice shall be included in all
16
+ copies or substantial portions of the Software.
17
+
18
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ SOFTWARE.
25
+ """
26
+
27
+ from time import time
28
+
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+
34
+
35
+ def timeit(tag, t):
36
+ print("{}: {}s".format(tag, time() - t))
37
+ return time()
38
+
39
+
40
+ def pc_normalize(pc):
41
+ l = pc.shape[0]
42
+ centroid = np.mean(pc, axis=0)
43
+ pc = pc - centroid
44
+ m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
45
+ pc = pc / m
46
+ return pc
47
+
48
+
49
+ def square_distance(src, dst):
50
+ """
51
+ Calculate Euclid distance between each two points.
52
+
53
+ src^T * dst = xn * xm + yn * ym + zn * zm;
54
+ sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
55
+ sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
56
+ dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
57
+ = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
58
+
59
+ Input:
60
+ src: source points, [B, N, C]
61
+ dst: target points, [B, M, C]
62
+ Output:
63
+ dist: per-point square distance, [B, N, M]
64
+ """
65
+ B, N, _ = src.shape
66
+ _, M, _ = dst.shape
67
+ dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
68
+ dist += torch.sum(src**2, -1).view(B, N, 1)
69
+ dist += torch.sum(dst**2, -1).view(B, 1, M)
70
+ return dist
71
+
72
+
73
+ def index_points(points, idx):
74
+ """
75
+
76
+ Input:
77
+ points: input points data, [B, N, C]
78
+ idx: sample index data, [B, S]
79
+ Return:
80
+ new_points:, indexed points data, [B, S, C]
81
+ """
82
+ device = points.device
83
+ B = points.shape[0]
84
+ view_shape = list(idx.shape)
85
+ view_shape[1:] = [1] * (len(view_shape) - 1)
86
+ repeat_shape = list(idx.shape)
87
+ repeat_shape[0] = 1
88
+ batch_indices = (
89
+ torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
90
+ )
91
+ new_points = points[batch_indices, idx, :]
92
+ return new_points
93
+
94
+
95
+ def farthest_point_sample(xyz, npoint, deterministic=False):
96
+ """
97
+ Input:
98
+ xyz: pointcloud data, [B, N, 3]
99
+ npoint: number of samples
100
+ Return:
101
+ centroids: sampled pointcloud index, [B, npoint]
102
+ """
103
+ device = xyz.device
104
+ B, N, C = xyz.shape
105
+ centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
106
+ distance = torch.ones(B, N).to(device) * 1e10
107
+ if deterministic:
108
+ farthest = torch.arange(0, B, dtype=torch.long).to(device)
109
+ else:
110
+ farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
111
+ batch_indices = torch.arange(B, dtype=torch.long).to(device)
112
+ for i in range(npoint):
113
+ centroids[:, i] = farthest
114
+ centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
115
+ dist = torch.sum((xyz - centroid) ** 2, -1)
116
+ mask = dist < distance
117
+ distance[mask] = dist[mask]
118
+ farthest = torch.max(distance, -1)[1]
119
+ return centroids
120
+
121
+
122
+ def query_ball_point(radius, nsample, xyz, new_xyz):
123
+ """
124
+ Input:
125
+ radius: local region radius
126
+ nsample: max sample number in local region
127
+ xyz: all points, [B, N, 3]
128
+ new_xyz: query points, [B, S, 3]
129
+ Return:
130
+ group_idx: grouped points index, [B, S, nsample]
131
+ """
132
+ device = xyz.device
133
+ B, N, C = xyz.shape
134
+ _, S, _ = new_xyz.shape
135
+ group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
136
+ sqrdists = square_distance(new_xyz, xyz)
137
+ group_idx[sqrdists > radius**2] = N
138
+ group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
139
+ group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
140
+ mask = group_idx == N
141
+ group_idx[mask] = group_first[mask]
142
+ return group_idx
143
+
144
+
145
+ def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, deterministic=False):
146
+ """
147
+ Input:
148
+ npoint:
149
+ radius:
150
+ nsample:
151
+ xyz: input points position data, [B, N, 3]
152
+ points: input points data, [B, N, D]
153
+ Return:
154
+ new_xyz: sampled points position data, [B, npoint, nsample, 3]
155
+ new_points: sampled points data, [B, npoint, nsample, 3+D]
156
+ """
157
+ B, N, C = xyz.shape
158
+ S = npoint
159
+ fps_idx = farthest_point_sample(xyz, npoint, deterministic=deterministic) # [B, npoint, C]
160
+ new_xyz = index_points(xyz, fps_idx)
161
+ idx = query_ball_point(radius, nsample, xyz, new_xyz)
162
+ grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
163
+ grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
164
+
165
+ if points is not None:
166
+ grouped_points = index_points(points, idx)
167
+ new_points = torch.cat(
168
+ [grouped_xyz_norm, grouped_points], dim=-1
169
+ ) # [B, npoint, nsample, C+D]
170
+ else:
171
+ new_points = grouped_xyz_norm
172
+ if returnfps:
173
+ return new_xyz, new_points, grouped_xyz, fps_idx
174
+ else:
175
+ return new_xyz, new_points
176
+
177
+
178
+ def sample_and_group_all(xyz, points):
179
+ """
180
+ Input:
181
+ xyz: input points position data, [B, N, 3]
182
+ points: input points data, [B, N, D]
183
+ Return:
184
+ new_xyz: sampled points position data, [B, 1, 3]
185
+ new_points: sampled points data, [B, 1, N, 3+D]
186
+ """
187
+ device = xyz.device
188
+ B, N, C = xyz.shape
189
+ new_xyz = torch.zeros(B, 1, C).to(device)
190
+ grouped_xyz = xyz.view(B, 1, N, C)
191
+ if points is not None:
192
+ new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
193
+ else:
194
+ new_points = grouped_xyz
195
+ return new_xyz, new_points
196
+
197
+
198
+ class PointNetSetAbstraction(nn.Module):
199
+ def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
200
+ super(PointNetSetAbstraction, self).__init__()
201
+ self.npoint = npoint
202
+ self.radius = radius
203
+ self.nsample = nsample
204
+ self.mlp_convs = nn.ModuleList()
205
+ self.mlp_bns = nn.ModuleList()
206
+ last_channel = in_channel
207
+ for out_channel in mlp:
208
+ self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
209
+ self.mlp_bns.append(nn.BatchNorm2d(out_channel))
210
+ last_channel = out_channel
211
+ self.group_all = group_all
212
+
213
+ def forward(self, xyz, points):
214
+ """
215
+ Input:
216
+ xyz: input points position data, [B, C, N]
217
+ points: input points data, [B, D, N]
218
+ Return:
219
+ new_xyz: sampled points position data, [B, C, S]
220
+ new_points_concat: sample points feature data, [B, D', S]
221
+ """
222
+ xyz = xyz.permute(0, 2, 1)
223
+ if points is not None:
224
+ points = points.permute(0, 2, 1)
225
+
226
+ if self.group_all:
227
+ new_xyz, new_points = sample_and_group_all(xyz, points)
228
+ else:
229
+ new_xyz, new_points = sample_and_group(
230
+ self.npoint, self.radius, self.nsample, xyz, points, deterministic=not self.training
231
+ )
232
+ # new_xyz: sampled points position data, [B, npoint, C]
233
+ # new_points: sampled points data, [B, npoint, nsample, C+D]
234
+ new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
235
+ for i, conv in enumerate(self.mlp_convs):
236
+ bn = self.mlp_bns[i]
237
+ new_points = F.relu(bn(conv(new_points)))
238
+
239
+ new_points = torch.max(new_points, 2)[0]
240
+ new_xyz = new_xyz.permute(0, 2, 1)
241
+ return new_xyz, new_points
242
+
243
+
244
+ class PointNetSetAbstractionMsg(nn.Module):
245
+ def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
246
+ super(PointNetSetAbstractionMsg, self).__init__()
247
+ self.npoint = npoint
248
+ self.radius_list = radius_list
249
+ self.nsample_list = nsample_list
250
+ self.conv_blocks = nn.ModuleList()
251
+ self.bn_blocks = nn.ModuleList()
252
+ for i in range(len(mlp_list)):
253
+ convs = nn.ModuleList()
254
+ bns = nn.ModuleList()
255
+ last_channel = in_channel + 3
256
+ for out_channel in mlp_list[i]:
257
+ convs.append(nn.Conv2d(last_channel, out_channel, 1))
258
+ bns.append(nn.BatchNorm2d(out_channel))
259
+ last_channel = out_channel
260
+ self.conv_blocks.append(convs)
261
+ self.bn_blocks.append(bns)
262
+
263
+ def forward(self, xyz, points):
264
+ """
265
+ Input:
266
+ xyz: input points position data, [B, C, N]
267
+ points: input points data, [B, D, N]
268
+ Return:
269
+ new_xyz: sampled points position data, [B, C, S]
270
+ new_points_concat: sample points feature data, [B, D', S]
271
+ """
272
+ xyz = xyz.permute(0, 2, 1)
273
+ if points is not None:
274
+ points = points.permute(0, 2, 1)
275
+
276
+ B, N, C = xyz.shape
277
+ S = self.npoint
278
+ new_xyz = index_points(xyz, farthest_point_sample(xyz, S, deterministic=not self.training))
279
+ new_points_list = []
280
+ for i, radius in enumerate(self.radius_list):
281
+ K = self.nsample_list[i]
282
+ group_idx = query_ball_point(radius, K, xyz, new_xyz)
283
+ grouped_xyz = index_points(xyz, group_idx)
284
+ grouped_xyz -= new_xyz.view(B, S, 1, C)
285
+ if points is not None:
286
+ grouped_points = index_points(points, group_idx)
287
+ grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
288
+ else:
289
+ grouped_points = grouped_xyz
290
+
291
+ grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
292
+ for j in range(len(self.conv_blocks[i])):
293
+ conv = self.conv_blocks[i][j]
294
+ bn = self.bn_blocks[i][j]
295
+ grouped_points = F.relu(bn(conv(grouped_points)))
296
+ new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
297
+ new_points_list.append(new_points)
298
+
299
+ new_xyz = new_xyz.permute(0, 2, 1)
300
+ new_points_concat = torch.cat(new_points_list, dim=1)
301
+ return new_xyz, new_points_concat
302
+
303
+
304
+ class PointNetFeaturePropagation(nn.Module):
305
+ def __init__(self, in_channel, mlp):
306
+ super(PointNetFeaturePropagation, self).__init__()
307
+ self.mlp_convs = nn.ModuleList()
308
+ self.mlp_bns = nn.ModuleList()
309
+ last_channel = in_channel
310
+ for out_channel in mlp:
311
+ self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
312
+ self.mlp_bns.append(nn.BatchNorm1d(out_channel))
313
+ last_channel = out_channel
314
+
315
+ def forward(self, xyz1, xyz2, points1, points2):
316
+ """
317
+ Input:
318
+ xyz1: input points position data, [B, C, N]
319
+ xyz2: sampled input points position data, [B, C, S]
320
+ points1: input points data, [B, D, N]
321
+ points2: sampled points data, [B, D, S]
322
+ Return:
323
+ new_points: upsampled points data, [B, D', N]
324
+ """
325
+ xyz1 = xyz1.permute(0, 2, 1)
326
+ xyz2 = xyz2.permute(0, 2, 1)
327
+
328
+ points2 = points2.permute(0, 2, 1)
329
+ B, N, C = xyz1.shape
330
+ _, S, _ = xyz2.shape
331
+
332
+ if S == 1:
333
+ interpolated_points = points2.repeat(1, N, 1)
334
+ else:
335
+ dists = square_distance(xyz1, xyz2)
336
+ dists, idx = dists.sort(dim=-1)
337
+ dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
338
+
339
+ dist_recip = 1.0 / (dists + 1e-8)
340
+ norm = torch.sum(dist_recip, dim=2, keepdim=True)
341
+ weight = dist_recip / norm
342
+ interpolated_points = torch.sum(
343
+ index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2
344
+ )
345
+
346
+ if points1 is not None:
347
+ points1 = points1.permute(0, 2, 1)
348
+ new_points = torch.cat([points1, interpolated_points], dim=-1)
349
+ else:
350
+ new_points = interpolated_points
351
+
352
+ new_points = new_points.permute(0, 2, 1)
353
+ for i, conv in enumerate(self.mlp_convs):
354
+ bn = self.mlp_bns[i]
355
+ new_points = F.relu(bn(conv(new_points)))
356
+ return new_points
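The modules above compose into the PointNet++ encoder/decoder used by the evaluation classifier: set abstraction downsamples via farthest point sampling plus ball-query grouping, the MSG variant concatenates features pooled at several radii, and feature propagation upsamples by inverse-distance interpolation over the three nearest coarse points (weights 1/(d_i + 1e-8), normalized to sum to one). A minimal usage sketch for PointNetSetAbstraction follows; the import path matches this file, but the batch size, point counts, and hyperparameters are illustrative only.

import torch
from point_e.evals.pointnet2_utils import PointNetSetAbstraction

# Grouping concatenates the 3-d xyz offsets onto the D=3 input features,
# so the first shared-MLP layer sees in_channel = 3 + 3 = 6 channels.
sa = PointNetSetAbstraction(
    npoint=512, radius=0.2, nsample=32,
    in_channel=3 + 3, mlp=[64, 64, 128], group_all=False,
)
sa.eval()  # eval mode switches to deterministic farthest point sampling

xyz = torch.randn(2, 3, 1024)    # [B, C, N] point positions
feats = torch.randn(2, 3, 1024)  # [B, D, N] per-point features
new_xyz, new_feats = sa(xyz, feats)
print(new_xyz.shape, new_feats.shape)  # [2, 3, 512] and [2, 128, 512]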
point_e/evals/scripts/blender_script.py ADDED
@@ -0,0 +1,533 @@
1
+ """
2
+ Script to run within Blender to render a 3D model as RGBAD images.
3
+
4
+ Example usage
5
+
6
+ blender -b -P blender_script.py -- \
7
+ --input_path ../../examples/example_data/corgi.ply \
8
+ --output_path render_out
9
+
10
+ Pass `--camera_pose z-circular-elevated` for the rendering used to compute
11
+ CLIP R-Precision results.
12
+
13
+ The output directory will include metadata json files for each rendered view,
14
+ as well as a global metadata file for the render. Each image will be saved as
15
+ a collection of 16-bit PNG files for each channel (rgbad), as well as a full
16
+ grayscale render of the view.
17
+ """
18
+
19
+ import argparse
20
+ import json
21
+ import math
22
+ import os
23
+ import random
24
+ import sys
25
+
26
+ import bpy
27
+ from mathutils import Vector
28
+ from mathutils.noise import random_unit_vector
29
+
30
+ MAX_DEPTH = 5.0
31
+ FORMAT_VERSION = 6
32
+ UNIFORM_LIGHT_DIRECTION = [0.09387503, -0.63953443, -0.7630093]
33
+
34
+
35
+ def clear_scene():
36
+ bpy.ops.object.select_all(action="SELECT")
37
+ bpy.ops.object.delete()
38
+
39
+
40
+ def clear_lights():
41
+ bpy.ops.object.select_all(action="DESELECT")
42
+ for obj in bpy.context.scene.objects.values():
43
+ if isinstance(obj.data, bpy.types.Light):
44
+ obj.select_set(True)
45
+ bpy.ops.object.delete()
46
+
47
+
48
+ def import_model(path):
49
+ clear_scene()
50
+ _, ext = os.path.splitext(path)
51
+ ext = ext.lower()
52
+ if ext == ".obj":
53
+ bpy.ops.import_scene.obj(filepath=path)
54
+ elif ext in [".glb", ".gltf"]:
55
+ bpy.ops.import_scene.gltf(filepath=path)
56
+ elif ext == ".stl":
57
+ bpy.ops.import_mesh.stl(filepath=path)
58
+ elif ext == ".fbx":
59
+ bpy.ops.import_scene.fbx(filepath=path)
60
+ elif ext == ".dae":
61
+ bpy.ops.wm.collada_import(filepath=path)
62
+ elif ext == ".ply":
63
+ bpy.ops.import_mesh.ply(filepath=path)
64
+ else:
65
+ raise RuntimeError(f"unexpected extension: {ext}")
66
+
67
+
68
+ def scene_root_objects():
69
+ for obj in bpy.context.scene.objects.values():
70
+ if not obj.parent:
71
+ yield obj
72
+
73
+
74
+ def scene_bbox(single_obj=None, ignore_matrix=False):
75
+ bbox_min = (math.inf,) * 3
76
+ bbox_max = (-math.inf,) * 3
77
+ found = False
78
+ for obj in scene_meshes() if single_obj is None else [single_obj]:
79
+ found = True
80
+ for coord in obj.bound_box:
81
+ coord = Vector(coord)
82
+ if not ignore_matrix:
83
+ coord = obj.matrix_world @ coord
84
+ bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
85
+ bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
86
+ if not found:
87
+ raise RuntimeError("no objects in scene to compute bounding box for")
88
+ return Vector(bbox_min), Vector(bbox_max)
89
+
90
+
91
+ def scene_meshes():
92
+ for obj in bpy.context.scene.objects.values():
93
+ if isinstance(obj.data, (bpy.types.Mesh)):
94
+ yield obj
95
+
96
+
97
+ def normalize_scene():
98
+ bbox_min, bbox_max = scene_bbox()
99
+ scale = 1 / max(bbox_max - bbox_min)
100
+
101
+ for obj in scene_root_objects():
102
+ obj.scale = obj.scale * scale
103
+
104
+ # Apply scale to matrix_world.
105
+ bpy.context.view_layer.update()
106
+
107
+ bbox_min, bbox_max = scene_bbox()
108
+ offset = -(bbox_min + bbox_max) / 2
109
+ for obj in scene_root_objects():
110
+ obj.matrix_world.translation += offset
111
+
112
+ bpy.ops.object.select_all(action="DESELECT")
113
+
114
+
115
+ def create_camera():
116
+ # https://b3d.interplanety.org/en/how-to-create-camera-through-the-blender-python-api/
117
+ camera_data = bpy.data.cameras.new(name="Camera")
118
+ camera_object = bpy.data.objects.new("Camera", camera_data)
119
+ bpy.context.scene.collection.objects.link(camera_object)
120
+ bpy.context.scene.camera = camera_object
121
+
122
+
123
+ def set_camera(direction, camera_dist=2.0):
124
+ camera_pos = -camera_dist * direction
125
+ bpy.context.scene.camera.location = camera_pos
126
+
127
+ # https://blender.stackexchange.com/questions/5210/pointing-the-camera-in-a-particular-direction-programmatically
128
+ rot_quat = direction.to_track_quat("-Z", "Y")
129
+ bpy.context.scene.camera.rotation_euler = rot_quat.to_euler()
130
+
131
+ bpy.context.view_layer.update()
132
+
133
+
134
+ def randomize_camera(camera_dist=2.0):
135
+ direction = random_unit_vector()
136
+ set_camera(direction, camera_dist=camera_dist)
137
+
138
+
139
+ def pan_camera(time, axis="Z", camera_dist=2.0, elevation=-0.1):
140
+ angle = time * math.pi * 2
141
+ direction = [-math.cos(angle), -math.sin(angle), -elevation]
142
+ assert axis in ["X", "Y", "Z"]
143
+ if axis == "X":
144
+ direction = [direction[2], *direction[:2]]
145
+ elif axis == "Y":
146
+ direction = [direction[0], -elevation, direction[1]]
147
+ direction = Vector(direction).normalized()
148
+ set_camera(direction, camera_dist=camera_dist)
149
+
150
+
151
+ def place_camera(time, camera_pose_mode="random", camera_dist_min=2.0, camera_dist_max=2.0):
152
+ camera_dist = random.uniform(camera_dist_min, camera_dist_max)
153
+ if camera_pose_mode == "random":
154
+ randomize_camera(camera_dist=camera_dist)
155
+ elif camera_pose_mode == "z-circular":
156
+ pan_camera(time, axis="Z", camera_dist=camera_dist)
157
+ elif camera_pose_mode == "z-circular-elevated":
158
+ pan_camera(time, axis="Z", camera_dist=camera_dist, elevation=0.2617993878)
159
+ else:
160
+ raise ValueError(f"Unknown camera pose mode: {camera_pose_mode}")
161
+
162
+
163
+ def create_light(location, energy=1.0, angle=0.5 * math.pi / 180):
164
+ # https://blender.stackexchange.com/questions/215624/how-to-create-a-light-with-the-python-api-in-blender-2-92
165
+ light_data = bpy.data.lights.new(name="Light", type="SUN")
166
+ light_data.energy = energy
167
+ light_data.angle = angle
168
+ light_object = bpy.data.objects.new(name="Light", object_data=light_data)
169
+
170
+ direction = -location
171
+ rot_quat = direction.to_track_quat("-Z", "Y")
172
+ light_object.rotation_euler = rot_quat.to_euler()
173
+ bpy.context.view_layer.update()
174
+
175
+ bpy.context.collection.objects.link(light_object)
176
+ light_object.location = location
177
+
178
+
179
+ def create_random_lights(count=4, distance=2.0, energy=1.5):
180
+ clear_lights()
181
+ for _ in range(count):
182
+ create_light(random_unit_vector() * distance, energy=energy)
183
+
184
+
185
+ def create_camera_light():
186
+ clear_lights()
187
+ create_light(bpy.context.scene.camera.location, energy=5.0)
188
+
189
+
190
+ def create_uniform_light(backend):
191
+ clear_lights()
192
+ # A fixed, arbitrary direction so that axis-aligned faces are decorrelated.
193
+ pos = Vector(UNIFORM_LIGHT_DIRECTION)
194
+ angle = 0.0092 if backend == "CYCLES" else math.pi
195
+ create_light(pos, energy=5.0, angle=angle)
196
+ create_light(-pos, energy=5.0, angle=angle)
197
+
198
+
199
+ def create_vertex_color_shaders():
200
+ # By default, Blender will ignore vertex colors in both the
201
+ # Eevee and Cycles backends, since these colors aren't
202
+ # associated with a material.
203
+ #
204
+ # What we do here is create a simple material shader and link
205
+ # the vertex color to the material color.
206
+ for obj in bpy.context.scene.objects.values():
207
+ if not isinstance(obj.data, (bpy.types.Mesh)):
208
+ continue
209
+
210
+ if len(obj.data.materials):
211
+ # We don't want to override any existing materials.
212
+ continue
213
+
214
+ color_keys = (obj.data.vertex_colors or {}).keys()
215
+ if not len(color_keys):
216
+ # Many objects will have no materials *or* vertex colors.
217
+ continue
218
+
219
+ mat = bpy.data.materials.new(name="VertexColored")
220
+ mat.use_nodes = True
221
+
222
+ # There should be a Principled BSDF by default.
223
+ bsdf_node = None
224
+ for node in mat.node_tree.nodes:
225
+ if node.type == "BSDF_PRINCIPLED":
226
+ bsdf_node = node
227
+ assert bsdf_node is not None, "material has no Principled BSDF node to modify"
228
+
229
+ socket_map = {}
230
+ for input in bsdf_node.inputs:
231
+ socket_map[input.name] = input
232
+
233
+ # Make sure nothing lights the object except for the diffuse color.
234
+ socket_map["Specular"].default_value = 0.0
235
+ socket_map["Roughness"].default_value = 1.0
236
+
237
+ v_color = mat.node_tree.nodes.new("ShaderNodeVertexColor")
238
+ v_color.layer_name = color_keys[0]
239
+
240
+ mat.node_tree.links.new(v_color.outputs[0], socket_map["Base Color"])
241
+
242
+ obj.data.materials.append(mat)
243
+
244
+
245
+ def create_default_materials():
246
+ for obj in bpy.context.scene.objects.values():
247
+ if isinstance(obj.data, (bpy.types.Mesh)):
248
+ if not len(obj.data.materials):
249
+ mat = bpy.data.materials.new(name="DefaultMaterial")
250
+ mat.use_nodes = True
251
+ obj.data.materials.append(mat)
252
+
253
+
254
+ def find_materials():
255
+ all_materials = set()
256
+ for obj in bpy.context.scene.objects.values():
257
+ if not isinstance(obj.data, (bpy.types.Mesh)):
258
+ continue
259
+ for mat in obj.data.materials:
260
+ all_materials.add(mat)
261
+ return all_materials
262
+
263
+
264
+ def get_socket_value(tree, socket):
265
+ default = socket.default_value
266
+ if not isinstance(default, float):
267
+ default = list(default)
268
+ for link in tree.links:
269
+ if link.to_socket == socket:
270
+ return (link.from_socket, default)
271
+ return (None, default)
272
+
273
+
274
+ def clear_socket_input(tree, socket):
275
+ for link in list(tree.links):
276
+ if link.to_socket == socket:
277
+ tree.links.remove(link)
278
+
279
+
280
+ def set_socket_value(tree, socket, socket_and_default):
281
+ clear_socket_input(tree, socket)
282
+ old_source_socket, default = socket_and_default
283
+ if isinstance(default, float) and not isinstance(socket.default_value, float):
284
+ # Codepath for setting Emission to a previous alpha value.
285
+ socket.default_value = [default] * 3 + [1.0]
286
+ else:
287
+ socket.default_value = default
288
+ if old_source_socket is not None:
289
+ tree.links.new(old_source_socket, socket)
290
+
291
+
292
+ def setup_nodes(output_path, capturing_material_alpha: bool = False):
293
+ tree = bpy.context.scene.node_tree
294
+ links = tree.links
295
+
296
+ for node in tree.nodes:
297
+ tree.nodes.remove(node)
298
+
299
+ # Helpers to perform math on links and constants.
300
+ def node_op(op: str, *args, clamp=False):
301
+ node = tree.nodes.new(type="CompositorNodeMath")
302
+ node.operation = op
303
+ if clamp:
304
+ node.use_clamp = True
305
+ for i, arg in enumerate(args):
306
+ if isinstance(arg, (int, float)):
307
+ node.inputs[i].default_value = arg
308
+ else:
309
+ links.new(arg, node.inputs[i])
310
+ return node.outputs[0]
311
+
312
+ def node_clamp(x, maximum=1.0):
313
+ return node_op("MINIMUM", x, maximum)
314
+
315
+ def node_mul(x, y, **kwargs):
316
+ return node_op("MULTIPLY", x, y, **kwargs)
317
+
318
+ input_node = tree.nodes.new(type="CompositorNodeRLayers")
319
+ input_node.scene = bpy.context.scene
320
+
321
+ input_sockets = {}
322
+ for output in input_node.outputs:
323
+ input_sockets[output.name] = output
324
+
325
+ if capturing_material_alpha:
326
+ color_socket = input_sockets["Image"]
327
+ else:
328
+ raw_color_socket = input_sockets["Image"]
329
+
330
+ # We apply sRGB here so that our fixed-point depth map and material
331
+ # alpha values are not sRGB, and so that we perform ambient+diffuse
332
+ # lighting in linear RGB space.
333
+ color_node = tree.nodes.new(type="CompositorNodeConvertColorSpace")
334
+ color_node.from_color_space = "Linear"
335
+ color_node.to_color_space = "sRGB"
336
+ tree.links.new(raw_color_socket, color_node.inputs[0])
337
+ color_socket = color_node.outputs[0]
338
+ split_node = tree.nodes.new(type="CompositorNodeSepRGBA")
339
+ tree.links.new(color_socket, split_node.inputs[0])
340
+ # Create separate file output nodes for every channel we care about.
341
+ # The process calling this script must decide how to recombine these
342
+ # channels, possibly into a single image.
343
+ for i, channel in enumerate("rgba") if not capturing_material_alpha else [(0, "MatAlpha")]:
344
+ output_node = tree.nodes.new(type="CompositorNodeOutputFile")
345
+ output_node.base_path = f"{output_path}_{channel}"
346
+ links.new(split_node.outputs[i], output_node.inputs[0])
347
+
348
+ if capturing_material_alpha:
349
+ # No need to re-write depth here.
350
+ return
351
+
352
+ depth_out = node_clamp(node_mul(input_sockets["Depth"], 1 / MAX_DEPTH))
353
+ output_node = tree.nodes.new(type="CompositorNodeOutputFile")
354
+ output_node.base_path = f"{output_path}_depth"
355
+ links.new(depth_out, output_node.inputs[0])
356
+
357
+
358
+ def render_scene(output_path, fast_mode: bool):
359
+ use_workbench = bpy.context.scene.render.engine == "BLENDER_WORKBENCH"
360
+ if use_workbench:
361
+ # We must use a different engine to compute depth maps.
362
+ bpy.context.scene.render.engine = "BLENDER_EEVEE"
363
+ bpy.context.scene.eevee.taa_render_samples = 1 # faster, since we discard image.
364
+ if fast_mode:
365
+ if bpy.context.scene.render.engine == "BLENDER_EEVEE":
366
+ bpy.context.scene.eevee.taa_render_samples = 1
367
+ elif bpy.context.scene.render.engine == "CYCLES":
368
+ bpy.context.scene.cycles.samples = 256
369
+ else:
370
+ if bpy.context.scene.render.engine == "CYCLES":
371
+ # We should still impose a per-frame time limit
372
+ # so that we don't timeout completely.
373
+ bpy.context.scene.cycles.time_limit = 40
374
+ bpy.context.view_layer.update()
375
+ bpy.context.scene.use_nodes = True
376
+ bpy.context.scene.view_layers["ViewLayer"].use_pass_z = True
377
+ bpy.context.scene.view_settings.view_transform = "Raw" # sRGB done in graph nodes
378
+ bpy.context.scene.render.film_transparent = True
379
+ bpy.context.scene.render.resolution_x = 512
380
+ bpy.context.scene.render.resolution_y = 512
381
+ bpy.context.scene.render.image_settings.file_format = "PNG"
382
+ bpy.context.scene.render.image_settings.color_mode = "BW"
383
+ bpy.context.scene.render.image_settings.color_depth = "16"
384
+ bpy.context.scene.render.filepath = output_path
385
+ setup_nodes(output_path)
386
+ bpy.ops.render.render(write_still=True)
387
+
388
+ # The output images must be moved from their own sub-directories, or
389
+ # discarded if we are using workbench for the color.
390
+ for channel_name in ["r", "g", "b", "a", "depth"]:
391
+ sub_dir = f"{output_path}_{channel_name}"
392
+ image_path = os.path.join(sub_dir, os.listdir(sub_dir)[0])
393
+ name, ext = os.path.splitext(output_path)
394
+ if channel_name == "depth" or not use_workbench:
395
+ os.rename(image_path, f"{name}_{channel_name}{ext}")
396
+ else:
397
+ os.remove(image_path)
398
+ os.removedirs(sub_dir)
399
+
400
+ if use_workbench:
401
+ # Re-render RGBA using workbench with texture mode, since this seems
402
+ # to show the most reasonable colors when lighting is broken.
403
+ bpy.context.scene.use_nodes = False
404
+ bpy.context.scene.render.engine = "BLENDER_WORKBENCH"
405
+ bpy.context.scene.render.image_settings.color_mode = "RGBA"
406
+ bpy.context.scene.render.image_settings.color_depth = "8"
407
+ bpy.context.scene.display.shading.color_type = "TEXTURE"
408
+ bpy.context.scene.display.shading.light = "FLAT"
409
+ if fast_mode:
410
+ # Single pass anti-aliasing.
411
+ bpy.context.scene.display.render_aa = "FXAA"
412
+ os.remove(output_path)
413
+ bpy.ops.render.render(write_still=True)
414
+ bpy.context.scene.render.image_settings.color_mode = "BW"
415
+ bpy.context.scene.render.image_settings.color_depth = "16"
416
+
417
+
418
+ def scene_fov():
419
+ x_fov = bpy.context.scene.camera.data.angle_x
420
+ y_fov = bpy.context.scene.camera.data.angle_y
421
+ width = bpy.context.scene.render.resolution_x
422
+ height = bpy.context.scene.render.resolution_y
423
+ if bpy.context.scene.camera.data.angle == x_fov:
424
+ y_fov = 2 * math.atan(math.tan(x_fov / 2) * height / width)
425
+ else:
426
+ x_fov = 2 * math.atan(math.tan(y_fov / 2) * width / height)
427
+ return x_fov, y_fov
428
+
429
+
430
+ def write_camera_metadata(path):
431
+ x_fov, y_fov = scene_fov()
432
+ bbox_min, bbox_max = scene_bbox()
433
+ matrix = bpy.context.scene.camera.matrix_world
434
+ with open(path, "w") as f:
435
+ json.dump(
436
+ dict(
437
+ format_version=FORMAT_VERSION,
438
+ max_depth=MAX_DEPTH,
439
+ bbox=[list(bbox_min), list(bbox_max)],
440
+ origin=list(matrix.col[3])[:3],
441
+ x_fov=x_fov,
442
+ y_fov=y_fov,
443
+ x=list(matrix.col[0])[:3],
444
+ y=list(-matrix.col[1])[:3],
445
+ z=list(-matrix.col[2])[:3],
446
+ ),
447
+ f,
448
+ )
449
+
450
+
451
+ def save_rendering_dataset(
452
+ input_path: str,
453
+ output_path: str,
454
+ num_images: int,
455
+ backend: str,
456
+ light_mode: str,
457
+ camera_pose: str,
458
+ camera_dist_min: float,
459
+ camera_dist_max: float,
460
+ fast_mode: bool,
461
+ ):
462
+ assert light_mode in ["random", "uniform", "camera"]
463
+ assert camera_pose in ["random", "z-circular", "z-circular-elevated"]
464
+
465
+ import_model(input_path)
466
+ bpy.context.scene.render.engine = backend
467
+ normalize_scene()
468
+ if light_mode == "random":
469
+ create_random_lights()
470
+ elif light_mode == "uniform":
471
+ create_uniform_light(backend)
472
+ create_camera()
473
+ create_vertex_color_shaders()
474
+ for i in range(num_images):
475
+ t = i / max(num_images - 1, 1) # same as np.linspace(0, 1, num_images)
476
+ place_camera(
477
+ t,
478
+ camera_pose_mode=camera_pose,
479
+ camera_dist_min=camera_dist_min,
480
+ camera_dist_max=camera_dist_max,
481
+ )
482
+ if light_mode == "camera":
483
+ create_camera_light()
484
+ render_scene(
485
+ os.path.join(output_path, f"{i:05}.png"),
486
+ fast_mode=fast_mode,
487
+ )
488
+ write_camera_metadata(os.path.join(output_path, f"{i:05}.json"))
489
+ with open(os.path.join(output_path, "info.json"), "w") as f:
490
+ info = dict(
491
+ backend=backend,
492
+ light_mode=light_mode,
493
+ fast_mode=fast_mode,
494
+ format_version=FORMAT_VERSION,
495
+ channels=["R", "G", "B", "A", "D"],
496
+ scale=0.5, # The scene is bounded by [-scale, scale].
497
+ )
498
+ json.dump(info, f)
499
+
500
+
501
+ def main():
502
+ try:
503
+ dash_index = sys.argv.index("--")
504
+ except ValueError as exc:
505
+ raise ValueError("arguments must be preceded by '--'") from exc
506
+
507
+ raw_args = sys.argv[dash_index + 1 :]
508
+ parser = argparse.ArgumentParser()
509
+ parser.add_argument("--input_path", required=True, type=str)
510
+ parser.add_argument("--output_path", required=True, type=str)
511
+ parser.add_argument("--num_images", type=int, default=20)
512
+ parser.add_argument("--backend", type=str, default="BLENDER_EEVEE")
513
+ parser.add_argument("--light_mode", type=str, default="uniform")
514
+ parser.add_argument("--camera_pose", type=str, default="random")
515
+ parser.add_argument("--camera_dist_min", type=float, default=2.0)
516
+ parser.add_argument("--camera_dist_max", type=float, default=2.0)
517
+ parser.add_argument("--fast_mode", action="store_true")
518
+ args = parser.parse_args(raw_args)
519
+
520
+ save_rendering_dataset(
521
+ input_path=args.input_path,
522
+ output_path=args.output_path,
523
+ num_images=args.num_images,
524
+ backend=args.backend,
525
+ light_mode=args.light_mode,
526
+ camera_pose=args.camera_pose,
527
+ camera_dist_min=args.camera_dist_min,
528
+ camera_dist_max=args.camera_dist_max,
529
+ fast_mode=args.fast_mode,
530
+ )
531
+
532
+
533
+ main()
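Because setup_nodes writes each channel through its own CompositorNodeOutputFile, a caller has to reassemble the per-view 16-bit PNGs itself. A minimal sketch of that recombination, assuming numpy and Pillow and the {index:05}_{channel}.png names produced by render_scene:

import numpy as np
from PIL import Image

def load_rgbad(prefix, max_depth=5.0):
    # Read the five channel files (e.g. render_out/00000_r.png) and stack
    # them into one [H, W, 5] float array in RGBAD order.
    channels = []
    for name in ["r", "g", "b", "a", "depth"]:
        img = np.array(Image.open(f"{prefix}_{name}.png"))
        channels.append(img.astype(np.float32) / 65535.0)  # 16-bit -> [0, 1]
    rgbad = np.stack(channels, axis=-1)
    rgbad[..., 4] *= max_depth  # setup_nodes stored clamp(depth / MAX_DEPTH)
    return rgbad

view = load_rgbad("render_out/00000")
print(view.shape)  # (512, 512, 5) at the default render resolution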
point_e/evals/scripts/evaluate_pfid.py ADDED
@@ -0,0 +1,40 @@
1
+ """
2
+ Evaluate P-FID between two batches of point clouds.
3
+
4
+ The point cloud batches should be saved to two npz files, where there
5
+ is an arr_0 key of shape [N x K x 3], where N is the number of point
6
+ clouds and K is the number of points in each cloud.
7
+ """
8
+
9
+ import argparse
10
+
11
+ from point_e.evals.feature_extractor import PointNetClassifier, get_torch_devices
12
+ from point_e.evals.fid_is import compute_statistics
13
+ from point_e.evals.npz_stream import NpzStreamer
14
+
15
+
16
+ def main():
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("--cache_dir", type=str, default=None)
19
+ parser.add_argument("batch_1", type=str)
20
+ parser.add_argument("batch_2", type=str)
21
+ args = parser.parse_args()
22
+
23
+ print("creating classifier...")
24
+ clf = PointNetClassifier(devices=get_torch_devices(), cache_dir=args.cache_dir)
25
+
26
+ print("computing first batch activations")
27
+
28
+ features_1, _ = clf.features_and_preds(NpzStreamer(args.batch_1))
29
+ stats_1 = compute_statistics(features_1)
30
+ del features_1
31
+
32
+ features_2, _ = clf.features_and_preds(NpzStreamer(args.batch_2))
33
+ stats_2 = compute_statistics(features_2)
34
+ del features_2
35
+
36
+ print(f"P-FID: {stats_1.frechet_distance(stats_2)}")
37
+
38
+
39
+ if __name__ == "__main__":
40
+ main()
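For reference, a minimal sketch of producing the expected inputs (file names and sizes here are placeholders):

import numpy as np

# Stack N point clouds of K points each into an [N, K, 3] array; passing it
# positionally to np.savez stores it under the required arr_0 key.
clouds = [np.random.rand(1024, 3).astype(np.float32) for _ in range(8)]
np.savez("batch_1.npz", np.stack(clouds))

The script is then run as: python point_e/evals/scripts/evaluate_pfid.py batch_1.npz batch_2.npz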
point_e/evals/scripts/evaluate_pis.py ADDED
@@ -0,0 +1,31 @@
1
+ """
2
+ Evaluate P-IS of a batch of point clouds.
3
+
4
+ The point cloud batch should be saved to an npz file, where there is an
5
+ arr_0 key of shape [N x K x 3], where N is the number of point clouds
6
+ and K is the number of points in each cloud.
7
+ """
8
+
9
+ import argparse
10
+
11
+ from point_e.evals.feature_extractor import PointNetClassifier, get_torch_devices
12
+ from point_e.evals.fid_is import compute_inception_score
13
+ from point_e.evals.npz_stream import NpzStreamer
14
+
15
+
16
+ def main():
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("--cache_dir", type=str, default=None)
19
+ parser.add_argument("batch", type=str)
20
+ args = parser.parse_args()
21
+
22
+ print("creating classifier...")
23
+ clf = PointNetClassifier(devices=get_torch_devices(), cache_dir=args.cache_dir)
24
+
25
+ print("computing batch predictions")
26
+ _, preds = clf.features_and_preds(NpzStreamer(args.batch))
27
+ print(f"P-IS: {compute_inception_score(preds)}")
28
+
29
+
30
+ if __name__ == "__main__":
31
+ main()
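For intuition, the standard Inception Score over a matrix of class probabilities preds[i, c] = p(y=c | x_i) is the exponential of the mean KL divergence between each conditional and the marginal; whether compute_inception_score matches this formulation exactly is an assumption here, since fid_is.py is not shown in this hunk.

import numpy as np

def inception_score(preds, eps=1e-12):
    marginal = preds.mean(axis=0)  # p(y), averaged over the batch
    kl = preds * (np.log(preds + eps) - np.log(marginal + eps))
    return float(np.exp(kl.sum(axis=1).mean()))  # exp E_x[KL(p(y|x) || p(y))]

preds = np.random.dirichlet(np.ones(40), size=128)  # stand-in softmax outputs
print(inception_score(preds))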
point_e/examples/.ipynb_checkpoints/Test-checkpoint.py ADDED
@@ -0,0 +1,7 @@
1
+ import tensorflow as tf
2
+ from tensorflow import keras
3
+ print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
4
+ print("Built with CUDA:", tf.test.is_built_with_cuda())
5
+ print(tf.version.VERSION)
6
+ import sys
7
+ print(sys.version)
point_e/examples/.ipynb_checkpoints/pointcloud2mesh-checkpoint.ipynb ADDED
@@ -0,0 +1,106 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from PIL import Image\n",
10
+ "import torch\n",
11
+ "import matplotlib.pyplot as plt\n",
12
+ "from tqdm.auto import tqdm\n",
13
+ "\n",
14
+ "from point_e.models.download import load_checkpoint\n",
15
+ "from point_e.models.configs import MODEL_CONFIGS, model_from_config\n",
16
+ "from point_e.util.pc_to_mesh import marching_cubes_mesh\n",
17
+ "from point_e.util.plotting import plot_point_cloud\n",
18
+ "from point_e.util.point_cloud import PointCloud"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
28
+ "\n",
29
+ "print('creating SDF model...')\n",
30
+ "name = 'sdf'\n",
31
+ "model = model_from_config(MODEL_CONFIGS[name], device)\n",
32
+ "model.eval()\n",
33
+ "\n",
34
+ "print('loading SDF model...')\n",
35
+ "model.load_state_dict(load_checkpoint(name, device))"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "# Load a point cloud we want to convert into a mesh.\n",
45
+ "pc = PointCloud.load('example_data/pc_corgi.npz')\n",
46
+ "\n",
47
+ "# Plot the point cloud as a sanity check.\n",
48
+ "fig = plot_point_cloud(pc, grid_size=2)"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": null,
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "# Produce a mesh (with vertex colors)\n",
58
+ "mesh = marching_cubes_mesh(\n",
59
+ " pc=pc,\n",
60
+ " model=model,\n",
61
+ " batch_size=4096,\n",
62
+ " grid_size=32, # increase to 128 for resolution used in evals\n",
63
+ " progress=True,\n",
64
+ ")"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": null,
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "# Write the mesh to a PLY file to import into some other program.\n",
74
+ "with open('mesh.ply', 'wb') as f:\n",
75
+ " mesh.write_ply(f)"
76
+ ]
77
+ }
78
+ ],
79
+ "metadata": {
80
+ "kernelspec": {
81
+ "display_name": "Python 3.9.9 64-bit ('3.9.9')",
82
+ "language": "python",
83
+ "name": "python3"
84
+ },
85
+ "language_info": {
86
+ "codemirror_mode": {
87
+ "name": "ipython",
88
+ "version": 3
89
+ },
90
+ "file_extension": ".py",
91
+ "mimetype": "text/x-python",
92
+ "name": "python",
93
+ "nbconvert_exporter": "python",
94
+ "pygments_lexer": "ipython3",
95
+ "version": "3.9.9"
96
+ },
97
+ "orig_nbformat": 4,
98
+ "vscode": {
99
+ "interpreter": {
100
+ "hash": "b270b0f43bc427bcab7703c037711644cc480aac7c1cc8d2940cfaf0b447ee2e"
101
+ }
102
+ }
103
+ },
104
+ "nbformat": 4,
105
+ "nbformat_minor": 2
106
+ }
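To sanity-check the exported mesh.ply outside the notebook, something like the following works (trimesh is a third-party assumption, not a dependency of this repo):

import trimesh

mesh = trimesh.load("mesh.ply")
print(mesh.vertices.shape, mesh.faces.shape)  # (V, 3) vertices, (F, 3) faces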
point_e/examples/.ipynb_checkpoints/text2pointcloud-checkpoint.ipynb ADDED
@@ -0,0 +1,150 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "from tqdm.auto import tqdm\n",
11
+ "\n",
12
+ "from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config\n",
13
+ "from point_e.diffusion.sampler import PointCloudSampler\n",
14
+ "from point_e.models.download import load_checkpoint\n",
15
+ "from point_e.models.configs import MODEL_CONFIGS, model_from_config\n",
16
+ "from point_e.util.plotting import plot_point_cloud"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 2,
22
+ "metadata": {},
23
+ "outputs": [
24
+ {
25
+ "name": "stdout",
26
+ "output_type": "stream",
27
+ "text": [
28
+ "creating base model...\n",
29
+ "creating upsample model...\n",
30
+ "downloading base checkpoint...\n",
31
+ "downloading upsampler checkpoint...\n"
32
+ ]
33
+ },
34
+ {
35
+ "data": {
36
+ "text/plain": [
37
+ "<All keys matched successfully>"
38
+ ]
39
+ },
40
+ "execution_count": 2,
41
+ "metadata": {},
42
+ "output_type": "execute_result"
43
+ }
44
+ ],
45
+ "source": [
46
+ "device = torch.device('cuda')\n",
47
+ "\n",
48
+ "print('creating base model...')\n",
49
+ "base_name = 'base40M-textvec'\n",
50
+ "base_model = model_from_config(MODEL_CONFIGS[base_name], device)\n",
51
+ "base_model.eval()\n",
52
+ "base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])\n",
53
+ "\n",
54
+ "print('creating upsample model...')\n",
55
+ "upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)\n",
56
+ "upsampler_model.eval()\n",
57
+ "upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])\n",
58
+ "\n",
59
+ "print('downloading base checkpoint...')\n",
60
+ "base_model.load_state_dict(load_checkpoint(base_name, device))\n",
61
+ "\n",
62
+ "print('downloading upsampler checkpoint...')\n",
63
+ "upsampler_model.load_state_dict(load_checkpoint('upsample', device))"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": 3,
69
+ "metadata": {},
70
+ "outputs": [],
71
+ "source": [
72
+ "sampler = PointCloudSampler(\n",
73
+ " device=device,\n",
74
+ " models=[base_model, upsampler_model],\n",
75
+ " diffusions=[base_diffusion, upsampler_diffusion],\n",
76
+ " num_points=[1024, 4096 - 1024],\n",
77
+ " aux_channels=['R', 'G', 'B'],\n",
78
+ " guidance_scale=[3.0, 0.0],\n",
79
+ " model_kwargs_key_filter=('texts', ''), # Do not condition the upsampler at all\n",
80
+ ")"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": null,
86
+ "metadata": {},
87
+ "outputs": [
88
+ {
89
+ "data": {
90
+ "application/vnd.jupyter.widget-view+json": {
91
+ "model_id": "2777bd89bbef428aaae750480cbdf123",
92
+ "version_major": 2,
93
+ "version_minor": 0
94
+ },
95
+ "text/plain": [
96
+ "0it [00:00, ?it/s]"
97
+ ]
98
+ },
99
+ "metadata": {},
100
+ "output_type": "display_data"
101
+ }
102
+ ],
103
+ "source": [
104
+ "# Set a prompt to condition on.\n",
105
+ "prompt = 'a yellow dinosaur'\n",
106
+ "\n",
107
+ "# Produce a sample from the model.\n",
108
+ "samples = None\n",
109
+ "for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[prompt]))):\n",
110
+ " samples = x"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": null,
116
+ "metadata": {},
117
+ "outputs": [],
118
+ "source": [
119
+ "pc = sampler.output_to_point_clouds(samples)[0]\n",
120
+ "fig = plot_point_cloud(pc, grid_size=3, fixed_bounds=((-0.75, -0.75, -0.75),(0.75, 0.75, 0.75)))"
121
+ ]
122
+ }
123
+ ],
124
+ "metadata": {
125
+ "kernelspec": {
126
+ "display_name": "Python (GPU)",
127
+ "language": "python",
128
+ "name": "gpu_env"
129
+ },
130
+ "language_info": {
131
+ "codemirror_mode": {
132
+ "name": "ipython",
133
+ "version": 3
134
+ },
135
+ "file_extension": ".py",
136
+ "mimetype": "text/x-python",
137
+ "name": "python",
138
+ "nbconvert_exporter": "python",
139
+ "pygments_lexer": "ipython3",
140
+ "version": "3.11.5"
141
+ },
142
+ "vscode": {
143
+ "interpreter": {
144
+ "hash": "b270b0f43bc427bcab7703c037711644cc480aac7c1cc8d2940cfaf0b447ee2e"
145
+ }
146
+ }
147
+ },
148
+ "nbformat": 4,
149
+ "nbformat_minor": 4
150
+ }
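The sampled cloud can also be persisted as input for the pointcloud2mesh notebook; PointCloud.save writes the same npz layout that PointCloud.load reads (the file name below is illustrative):

pc = sampler.output_to_point_clouds(samples)[0]
pc.save('example_data/pc_yellow_dinosaur.npz')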
point_e/examples/GPUtest.py ADDED
@@ -0,0 +1,56 @@
1
+ import numpy as np
2
+ import tensorflow.compat.v1 as tf
3
+ tf.disable_v2_behavior()
4
+ from datetime import datetime
5
+
6
+ # Choose which device you want to test on: either 'cpu' or 'gpu'
7
+ devices = ['cpu', 'gpu']
8
+
9
+ # Choose size of the matrix to be used.
10
+ # Make it bigger to see bigger benefits of parallel computation
11
+ shapes = [(50, 50), (100, 100), (500, 500), (1000, 1000)]
12
+
13
+
14
+ def compute_operations(device, shape):
15
+ """Run a simple set of operations on a matrix of given shape on given device
16
+
17
+ Parameters
18
+ ----------
19
+ device : the type of device to use, either 'cpu' or 'gpu'
20
+ shape : a tuple for the shape of a 2d tensor, e.g. (10, 10)
21
+
22
+ Returns
23
+ -------
24
+ out : results of the operations as the time taken
25
+ """
26
+
27
+ # Define operations to be computed on selected device
28
+ with tf.device(device):
29
+ random_matrix = tf.random_uniform(shape=shape, minval=0, maxval=1)
30
+ dot_operation = tf.matmul(random_matrix, tf.transpose(random_matrix))
31
+ sum_operation = tf.reduce_sum(dot_operation)
32
+
33
+ # Time the actual runtime of the operations
34
+ start_time = datetime.now()
35
+ with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as session:
36
+ result = session.run(sum_operation)
37
+ elapsed_time = datetime.now() - start_time
38
+
39
+ return result, elapsed_time
40
+
41
+
42
+
43
+ if __name__ == '__main__':
44
+
45
+ # Run the computations and print summary of each run
46
+ for device in devices:
47
+ print("--" * 20)
48
+
49
+ for shape in shapes:
50
+ _, time_taken = compute_operations(device, shape)
51
+
52
+ # Print the result and also the time taken on the selected device
53
+ print("Input shape:", shape, "using Device:", device, "took: {:.2f}".format(time_taken.seconds + time_taken.microseconds/1e6))
54
+ #print("Computation on shape:", shape, "using Device:", device, "took:")
55
+
56
+ print("--" * 20)
point_e/examples/Saving Model Code.txt ADDED
@@ -0,0 +1,6 @@
1
+ # Attempts to save the point cloud (pc) to a PLY file
2
+
3
+ pc = sampler.output_to_point_clouds(samples)[0]
4
+ with open('example_data/blue_bird.ply','wb') as f:
5
+ pc.write_ply(f)
6
+ #pc.save('example_data/blue_bird.npz')
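A read-back check for the saved file, assuming the third-party plyfile package:

from plyfile import PlyData

ply = PlyData.read('example_data/blue_bird.ply')
print(ply['vertex'].count)  # number of points written by write_ply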
point_e/examples/Test.py ADDED
@@ -0,0 +1,7 @@
1
+ import tensorflow as tf
2
+ from tensorflow import keras
3
+ print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
4
+ print("Built with CUDA:", tf.test.is_built_with_cuda())
5
+ print(tf.version.VERSION)
6
+ import sys
7
+ print(sys.version)
point_e/examples/example_data/blue_bird.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2bb37522cb2781a44b6272f7c826485f5e218f6b8a16c3445ae1573e4118ca8e
3
+ size 99272
point_e/examples/example_data/blue_bird.ply ADDED
Binary file (61.6 kB).
point_e/examples/example_data/corgi.jpg ADDED
point_e/examples/example_data/corgi.ply ADDED
Binary file (65.2 kB).
point_e/examples/example_data/cube_stack.jpg ADDED
point_e/examples/example_data/pc_corgi.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8adfe52905e03d6e090e52d500f25afb59722153325e05decf4c80002ef76180
3
+ size 99272
point_e/examples/example_data/pc_cube_stack.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41126e73e3a6eaf7cb68d35f6594a6ee221c074ee51b921ce2e8ccce52009c22
3
+ size 99272