FrozenBurning commited on
Commit
81ecb2b
·
1 Parent(s): 06ea84f

single view to 3D init release

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +4 -0
  2. README.md +2 -1
  3. app.py +209 -0
  4. assets/examples/blue_cat.png +0 -0
  5. assets/examples/bubble_mart_blue.png +0 -0
  6. assets/examples/bulldog.png +0 -0
  7. assets/examples/ceramic.png +0 -0
  8. assets/examples/chair_watermelon.png +0 -0
  9. assets/examples/cup_rgba.png +0 -0
  10. assets/examples/cute_horse.jpg +0 -0
  11. assets/examples/earphone.jpg +0 -0
  12. assets/examples/firedragon.png +0 -0
  13. assets/examples/fox.jpg +0 -0
  14. assets/examples/fruit_elephant.jpg +0 -0
  15. assets/examples/hatsune_miku.png +0 -0
  16. assets/examples/ikun_rgba.png +0 -0
  17. assets/examples/mailbox.png +0 -0
  18. assets/examples/mario.png +0 -0
  19. assets/examples/mei_ling_panda.png +0 -0
  20. assets/examples/mushroom_teapot.jpg +0 -0
  21. assets/examples/pikachu.png +0 -0
  22. assets/examples/potplant_rgba.png +0 -0
  23. assets/examples/seed_frog.png +0 -0
  24. assets/examples/shuai_panda_notail.png +0 -0
  25. assets/examples/yellow_duck.png +0 -0
  26. configs/inference_dit.yml +97 -0
  27. dva/__init__.py +5 -0
  28. dva/attr_dict.py +66 -0
  29. dva/geom.py +653 -0
  30. dva/io.py +56 -0
  31. dva/layers.py +157 -0
  32. dva/losses.py +239 -0
  33. dva/mvp/extensions/mvpraymarch/bvh.cu +292 -0
  34. dva/mvp/extensions/mvpraymarch/cudadispatch.h +104 -0
  35. dva/mvp/extensions/mvpraymarch/helper_math.h +1453 -0
  36. dva/mvp/extensions/mvpraymarch/makefile +2 -0
  37. dva/mvp/extensions/mvpraymarch/mvpraymarch.cpp +405 -0
  38. dva/mvp/extensions/mvpraymarch/mvpraymarch.py +559 -0
  39. dva/mvp/extensions/mvpraymarch/mvpraymarch_kernel.cu +208 -0
  40. dva/mvp/extensions/mvpraymarch/mvpraymarch_subset_kernel.h +218 -0
  41. dva/mvp/extensions/mvpraymarch/primaccum.h +101 -0
  42. dva/mvp/extensions/mvpraymarch/primsampler.h +94 -0
  43. dva/mvp/extensions/mvpraymarch/primtransf.h +182 -0
  44. dva/mvp/extensions/mvpraymarch/setup.py +30 -0
  45. dva/mvp/extensions/mvpraymarch/utils.h +847 -0
  46. dva/mvp/extensions/utils/helper_math.h +1453 -0
  47. dva/mvp/extensions/utils/makefile +2 -0
  48. dva/mvp/extensions/utils/setup.py +29 -0
  49. dva/mvp/extensions/utils/utils.cpp +137 -0
  50. dva/mvp/extensions/utils/utils.py +211 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ __pycache__
2
+ build
3
+ *.so
4
+ runs
README.md CHANGED
@@ -1,10 +1,11 @@
1
  ---
2
- title: 3DTopia XL
3
  emoji: 🌖
4
  colorFrom: green
5
  colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.41.0
 
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
+ title: 3DTopia-XL
3
  emoji: 🌖
4
  colorFrom: green
5
  colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.41.0
8
+ python_version: 3.9
9
  app_file: app.py
10
  pinned: false
11
  ---
app.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+
5
+ os.system("bash install.sh")
6
+
7
+ from omegaconf import OmegaConf
8
+ import tqdm
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torchvision.transforms.functional as TF
13
+ import rembg
14
+ import gradio as gr
15
+ from dva.io import load_from_config
16
+ from dva.ray_marcher import RayMarcher
17
+ from dva.visualize import visualize_primvolume, visualize_video_primvolume
18
+ from inference import remove_background, resize_foreground, extract_texmesh
19
+ from models.diffusion import create_diffusion
20
+ from huggingface_hub import hf_hub_download
21
+ ckpt_path = hf_hub_download(repo_id="frozenburning/3DTopia-XL", filename="model_sview_dit_fp16.pt")
22
+ vae_ckpt_path = hf_hub_download(repo_id="frozenburning/3DTopia-XL", filename="model_vae_fp16.pt")
23
+
24
+ GRADIO_PRIM_VIDEO_PATH = 'prim.mp4'
25
+ GRADIO_RGB_VIDEO_PATH = 'rgb.mp4'
26
+ GRADIO_MAT_VIDEO_PATH = 'mat.mp4'
27
+ GRADIO_GLB_PATH = 'pbr_mesh.glb'
28
+ CONFIG_PATH = "./configs/inference_dit.yml"
29
+
30
+ config = OmegaConf.load(CONFIG_PATH)
31
+ config.checkpoint_path = ckpt_path
32
+ config.model.vae_checkpoint_path = vae_ckpt_path
33
+ # model
34
+ model = load_from_config(config.model.generator)
35
+ state_dict = torch.load(config.checkpoint_path, map_location='cpu')
36
+ model.load_state_dict(state_dict['ema'])
37
+ vae = load_from_config(config.model.vae)
38
+ vae_state_dict = torch.load(config.model.vae_checkpoint_path, map_location='cpu')
39
+ vae.load_state_dict(vae_state_dict['model_state_dict'])
40
+ conditioner = load_from_config(config.model.conditioner)
41
+
42
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
43
+ vae = vae.to(device)
44
+ conditioner = conditioner.to(device)
45
+ model = model.to(device)
46
+ model.eval()
47
+
48
+ amp = True
49
+ precision_dtype = torch.float16
50
+
51
+ rm = RayMarcher(
52
+ config.image_height,
53
+ config.image_width,
54
+ **config.rm,
55
+ ).to(device)
56
+
57
+ perchannel_norm = False
58
+ if "latent_mean" in config.model:
59
+ latent_mean = torch.Tensor(config.model.latent_mean)[None, None, :].to(device)
60
+ latent_std = torch.Tensor(config.model.latent_std)[None, None, :].to(device)
61
+ assert latent_mean.shape[-1] == config.model.generator.in_channels
62
+ perchannel_norm = True
63
+
64
+ config.diffusion.pop("timestep_respacing")
65
+ config.model.pop("vae")
66
+ config.model.pop("vae_checkpoint_path")
67
+ config.model.pop("conditioner")
68
+ config.model.pop("generator")
69
+ config.model.pop("latent_nf")
70
+ config.model.pop("latent_mean")
71
+ config.model.pop("latent_std")
72
+ model_primx = load_from_config(config.model)
73
+ # load rembg
74
+ rembg_session = rembg.new_session()
75
+
76
+ # process function
77
+ def process(input_image, input_num_steps=25, input_seed=42, input_cfg=6.0):
78
+ # seed
79
+ torch.manual_seed(input_seed)
80
+
81
+ os.makedirs(config.output_dir, exist_ok=True)
82
+ output_rgb_video_path = os.path.join(config.output_dir, GRADIO_RGB_VIDEO_PATH)
83
+ output_prim_video_path = os.path.join(config.output_dir, GRADIO_PRIM_VIDEO_PATH)
84
+ output_mat_video_path = os.path.join(config.output_dir, GRADIO_MAT_VIDEO_PATH)
85
+ output_glb_path = os.path.join(config.output_dir, GRADIO_GLB_PATH)
86
+
87
+ diffusion = create_diffusion(timestep_respacing=respacing, **config.diffusion)
88
+ sample_fn = diffusion.ddim_sample_loop_progressive
89
+ fwd_fn = model.forward_with_cfg
90
+
91
+ # text-conditioned
92
+ if input_image is None:
93
+ raise NotImplementedError
94
+ # image-conditioned (may also input text, but no text usually works too)
95
+ else:
96
+ input_image = remove_background(input_image, rembg_session)
97
+ input_image = resize_foreground(input_image, 0.85)
98
+ raw_image = np.array(input_image)
99
+ mask = (raw_image[..., -1][..., None] > 0) * 1
100
+ raw_image = raw_image[..., :3] * mask
101
+ input_cond = torch.from_numpy(np.array(raw_image)[None, ...]).to(device)
102
+
103
+ with torch.no_grad():
104
+ latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4)
105
+ batch = {}
106
+ inf_bs = 1
107
+ inf_x = torch.randn(inf_bs, config.model.num_prims, 68).to(device)
108
+ y = conditioner.encoder(input_cond)
109
+ model_kwargs = dict(y=y[:inf_bs, ...], precision_dtype=precision_dtype, enable_amp=amp)
110
+ if input_cfg >= 0:
111
+ model_kwargs['cfg_scale'] = input_cfg
112
+ for samples in sample_fn(fwd_fn, inf_x.shape, inf_x, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device):
113
+ final_samples = samples
114
+ recon_param = final_samples["sample"].reshape(inf_bs, config.model.num_prims, -1)
115
+ if perchannel_norm:
116
+ recon_param = recon_param / config.model.latent_nf * latent_std + latent_mean
117
+ recon_srt_param = recon_param[:, :, 0:4]
118
+ recon_feat_param = recon_param[:, :, 4:] # [8, 2048, 64]
119
+ recon_feat_param_list = []
120
+ # one-by-one to avoid oom
121
+ for inf_bidx in range(inf_bs):
122
+ if not perchannel_norm:
123
+ decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]) / config.model.latent_nf)
124
+ else:
125
+ decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]))
126
+ recon_feat_param_list.append(decoded.detach())
127
+ recon_feat_param = torch.concat(recon_feat_param_list, dim=0)
128
+ # invert normalization
129
+ if not perchannel_norm:
130
+ recon_srt_param[:, :, 0:1] = (recon_srt_param[:, :, 0:1] / 10) + 0.05
131
+ recon_feat_param[:, 0:1, ...] /= 5.
132
+ recon_feat_param[:, 1:, ...] = (recon_feat_param[:, 1:, ...] + 1) / 2.
133
+ recon_feat_param = recon_feat_param.reshape(inf_bs, config.model.num_prims, -1)
134
+ recon_param = torch.concat([recon_srt_param, recon_feat_param], dim=-1)
135
+ visualize_video_primvolume(config.output_dir, batch, recon_param, 60, rm, device)
136
+ prim_params = {'srt_param': recon_srt_param[0].detach().cpu(), 'feat_param': recon_feat_param[0].detach().cpu()}
137
+ torch.save({'model_state_dict': prim_params}, "{}/denoised.pt".format(config.output_dir))
138
+
139
+ # exporting GLB mesh
140
+ denoise_param_path = os.path.join(config.output_dir, 'denoised.pt')
141
+ primx_ckpt_weight = torch.load(denoise_param_path, map_location='cpu')['model_state_dict']
142
+ model_primx.load_state_dict(ckpt_weight)
143
+ model_primx.to(device)
144
+ model_primx.eval()
145
+ with torch.no_grad():
146
+ model_primx.srt_param[:, 1:4] *= 0.85
147
+ extract_texmesh(config.inference, model_primx, output_glb_path, device)
148
+
149
+ return output_rgb_video_path, output_prim_video_path, output_mat_video_path, output_glb_path
150
+
151
+ # gradio UI
152
+ _TITLE = '''3DTopia-XL'''
153
+
154
+ _DESCRIPTION = '''
155
+ <div>
156
+ <a style="display:inline-block" href="https://frozenburning.github.io/projects/3DTopia-XL/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
157
+ <a style="display:inline-block; margin-left: .5em" href="https://github.com/3DTopia/3DTopia-XL"><img src='https://img.shields.io/github/stars/3DTopia/3DTopia-XL?style=social'/></a>
158
+ </div>
159
+
160
+ * Now we offer 1) single image conditioned model, we will release 2) multiview images conditioned model and 3) pure text conditioned model in the future!
161
+ * If you find the output unsatisfying, try using different seeds!
162
+ '''
163
+
164
+ block = gr.Blocks(title=_TITLE).queue()
165
+ with block:
166
+ with gr.Row():
167
+ with gr.Column(scale=1):
168
+ gr.Markdown('# ' + _TITLE)
169
+ gr.Markdown(_DESCRIPTION)
170
+
171
+ with gr.Row(variant='panel'):
172
+ with gr.Column(scale=1):
173
+ # input image
174
+ input_image = gr.Image(label="image", type='pil')
175
+ # inference steps
176
+ input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=25)
177
+ # random seed
178
+ input_cfg = gr.Slider(label="CFG scale", minimum=0, maximum=15, step=1, value=6)
179
+ # random seed
180
+ input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=42)
181
+ # gen button
182
+ button_gen = gr.Button("Generate")
183
+
184
+ with gr.Column(scale=1):
185
+ with gr.Tab("Video"):
186
+ # final video results
187
+ output_rgb_video = gr.Video(label="video")
188
+ output_prim_video = gr.Video(label="video")
189
+ output_mat_video = gr.Video(label="video")
190
+ with gr.Tab("GLB"):
191
+ # glb file
192
+ output_glb = gr.File(label="glb")
193
+
194
+ button_gen.click(process, inputs=[input_image, input_num_steps, input_seed, input_cfg], outputs=[output_rgb_video, output_prim_video, output_mat_video, output_glb])
195
+
196
+ gr.Examples(
197
+ examples=[
198
+ "assets/examples/fruit_elephant.jpg",
199
+ "assets/examples/mei_ling_panda.png",
200
+ "assets/examples/shuai_panda_notail.png",
201
+ ],
202
+ inputs=[input_image],
203
+ outputs=[output_rgb_video, output_prim_video, output_mat_video, output_glb],
204
+ fn=lambda x: process(input_image=x),
205
+ cache_examples=False,
206
+ label='Single Image to 3D PBR Asset'
207
+ )
208
+
209
+ block.launch(server_name="0.0.0.0", share=True)
assets/examples/blue_cat.png ADDED
assets/examples/bubble_mart_blue.png ADDED
assets/examples/bulldog.png ADDED
assets/examples/ceramic.png ADDED
assets/examples/chair_watermelon.png ADDED
assets/examples/cup_rgba.png ADDED
assets/examples/cute_horse.jpg ADDED
assets/examples/earphone.jpg ADDED
assets/examples/firedragon.png ADDED
assets/examples/fox.jpg ADDED
assets/examples/fruit_elephant.jpg ADDED
assets/examples/hatsune_miku.png ADDED
assets/examples/ikun_rgba.png ADDED
assets/examples/mailbox.png ADDED
assets/examples/mario.png ADDED
assets/examples/mei_ling_panda.png ADDED
assets/examples/mushroom_teapot.jpg ADDED
assets/examples/pikachu.png ADDED
assets/examples/potplant_rgba.png ADDED
assets/examples/seed_frog.png ADDED
assets/examples/shuai_panda_notail.png ADDED
assets/examples/yellow_duck.png ADDED
configs/inference_dit.yml ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ debug: False
2
+ root_data_dir: ./runs
3
+ checkpoint_path:
4
+ global_seed: 42
5
+
6
+ inference:
7
+ input_dir:
8
+ ddim: 25
9
+ cfg: 6
10
+ seed: ${global_seed}
11
+ precision: fp16
12
+ export_glb: True
13
+ decimate: 100000
14
+ mc_resolution: 256
15
+ batch_size: 4096
16
+ remesh: False
17
+
18
+ image_height: 518
19
+ image_width: 518
20
+
21
+ model:
22
+ class_name: models.primsdf.PrimSDF
23
+ num_prims: 2048
24
+ dim_feat: 6
25
+ prim_shape: 8
26
+ init_scale: 0.05 # useless if auto_scale_init == True
27
+ sdf2alpha_var: 0.005
28
+ auto_scale_init: True
29
+ init_sampling: uniform
30
+ vae:
31
+ class_name: models.vae3d_dib.VAE
32
+ in_channels: ${model.dim_feat}
33
+ latent_channels: 1
34
+ out_channels: ${model.vae.in_channels}
35
+ down_channels: [32, 256]
36
+ mid_attention: True
37
+ up_channels: [256, 32]
38
+ layers_per_block: 2
39
+ gradient_checkpointing: False
40
+ vae_checkpoint_path:
41
+ conditioner:
42
+ class_name: models.conditioner.image.ImageConditioner
43
+ num_prims: ${model.num_prims}
44
+ dim_feat: ${model.dim_feat}
45
+ prim_shape: ${model.prim_shape}
46
+ sample_view: False
47
+ encoder_config:
48
+ class_name: models.conditioner.image_dinov2.Dinov2Wrapper
49
+ model_name: dinov2_vitb14_reg
50
+ freeze: True
51
+ generator:
52
+ class_name: models.dit_crossattn.DiT
53
+ seq_length: ${model.num_prims}
54
+ in_channels: 68 # equals to model.vae.latent_channels * latent_dim^3
55
+ condition_channels: 768
56
+ hidden_size: 1152
57
+ depth: 28
58
+ num_heads: 16
59
+ attn_proj_bias: True
60
+ cond_drop_prob: 0.1
61
+ gradient_checkpointing: False
62
+ latent_nf: 1.0
63
+ latent_mean: [ 0.0442, -0.0029, -0.0425, -0.0043, -0.4086, -0.2906, -0.7002, -0.0852, -0.4446, -0.6896, -0.7344, -0.3524, -0.5488, -0.4313, -1.1715, -0.0875, -0.6131, -0.3924, -0.7335, -0.3749, 0.4658, -0.0236, 0.8362, 0.3388, 0.0188, 0.5988, -0.1853, 1.1579, 0.6240, 0.0758, 0.9641, 0.6586, 0.6260, 0.2384, 0.7798, 0.8297, -0.6543, -0.4441, -1.3887, -0.0393, -0.9008, -0.8616, -1.7434, -0.1328, -0.8119, -0.8225, -1.8533, -0.0444, -1.0510, -0.5158, -1.1907, -0.5265, 0.2832, 0.6037, 0.5981, 0.5461, 0.4366, 0.4144, 0.7219, 0.5722, 0.5937, 0.5598, 0.9414, 0.7419, 0.2102, 0.3388, 0.4501, 0.5166]
64
+ latent_std: [0.0219, 0.3707, 0.3911, 0.3610, 0.7549, 0.7909, 0.9691, 0.9193, 0.8218, 0.9389, 1.1785, 1.0254, 0.6376, 0.6568, 0.7892, 0.8468, 0.8775, 0.7920, 0.9037, 0.9329, 0.9196, 1.1123, 1.3041, 1.0955, 1.2727, 1.6565, 1.8502, 1.7006, 0.8973, 1.0408, 1.2034, 1.2703, 1.0373, 1.0486, 1.0716, 0.9746, 0.7088, 0.8685, 1.0030, 0.9504, 1.0410, 1.3033, 1.5368, 1.4386, 0.6142, 0.6887, 0.9085, 0.9903, 1.0190, 0.9302, 1.0121, 0.9964, 1.1474, 1.2729, 1.4627, 1.1404, 1.3713, 1.6692, 1.8424, 1.5047, 1.1356, 1.2369, 1.3554, 1.1848, 1.1319, 1.0822, 1.1972, 0.9916]
65
+
66
+ diffusion:
67
+ timestep_respacing:
68
+ noise_schedule: squaredcos_cap_v2
69
+ diffusion_steps: 1000
70
+ parameterization: v
71
+
72
+ rm:
73
+ volradius: 10000.0
74
+ dt: 1.0
75
+
76
+ optimizer:
77
+ class_name: torch.optim.AdamW
78
+ lr: 0.0001
79
+ weight_decay: 0
80
+
81
+ scheduler:
82
+ class_name: dva.scheduler.CosineWarmupScheduler
83
+ warmup_iters: 3000
84
+ max_iters: 200000
85
+
86
+ train:
87
+ batch_size: 8
88
+ n_workers: 4
89
+ n_epochs: 1000
90
+ log_every_n_steps: 50
91
+ summary_every_n_steps: 10000
92
+ ckpt_every_n_steps: 10000
93
+ amp: False
94
+ precision: tf32
95
+
96
+ tag: 3dtopia-xl-sview
97
+ output_dir: ${root_data_dir}/inference/${tag}
dva/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
dva/attr_dict.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+
9
+
10
+ class AttrDict:
11
+ def __init__(self, entries):
12
+ self.add_entries_(entries)
13
+
14
+ def keys(self):
15
+ return self.__dict__.keys()
16
+
17
+ def values(self):
18
+ return self.__dict__.values()
19
+
20
+ def __getitem__(self, key):
21
+ return self.__dict__[key]
22
+
23
+ def __setitem__(self, key, value):
24
+ self.__dict__[key] = value
25
+
26
+ def __delitem__(self, key):
27
+ return self.__dict__.__delitem__(key)
28
+
29
+ def __contains__(self, key):
30
+ return key in self.__dict__
31
+
32
+ def __repr__(self):
33
+ return self.__dict__.__repr__()
34
+
35
+ def __getattr__(self, attr):
36
+ if attr.startswith("__"):
37
+ return self.__getattribute__(attr)
38
+ return self.__dict__[attr]
39
+
40
+ def items(self):
41
+ return self.__dict__.items()
42
+
43
+ def __iter__(self):
44
+ return iter(self.items())
45
+
46
+ def add_entries_(self, entries, overwrite=True):
47
+ for key, value in entries.items():
48
+ if key not in self.__dict__:
49
+ if isinstance(value, dict):
50
+ self.__dict__[key] = AttrDict(value)
51
+ else:
52
+ self.__dict__[key] = value
53
+ else:
54
+ if isinstance(value, dict):
55
+ self.__dict__[key].add_entries_(entries=value, overwrite=overwrite)
56
+ elif overwrite or self.__dict__[key] is None:
57
+ self.__dict__[key] = value
58
+
59
+ def serialize(self):
60
+ return json.dumps(self, default=self.obj_to_dict, indent=4)
61
+
62
+ def obj_to_dict(self, obj):
63
+ return obj.__dict__
64
+
65
+ def get(self, key, default=None):
66
+ return self.__dict__.get(key, default)
dva/geom.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import numpy as np
3
+ import torch as th
4
+ import torch.nn.functional as F
5
+ import torch.nn as nn
6
+
7
+ from sklearn.neighbors import KDTree
8
+
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # NOTE: we need pytorch3d primarily for UV rasterization things
14
+ from pytorch3d.renderer.mesh.rasterize_meshes import rasterize_meshes
15
+ from pytorch3d.structures import Meshes
16
+ from typing import Union, Optional, Tuple
17
+ import trimesh
18
+ from trimesh import Trimesh
19
+ from trimesh.triangles import points_to_barycentric
20
+
21
+ try:
22
+ # pyre-fixme[21]: Could not find module `igl`.
23
+ from igl import point_mesh_squared_distance # @manual
24
+
25
+ # pyre-fixme[3]: Return type must be annotated.
26
+ # pyre-fixme[2]: Parameter must be annotated.
27
+ def closest_point(mesh, points):
28
+ """Helper function that mimics trimesh.proximity.closest_point but uses
29
+ IGL for faster queries."""
30
+ v = mesh.vertices
31
+ vi = mesh.faces
32
+ dist, face_idxs, p = point_mesh_squared_distance(points, v, vi)
33
+ return p, dist, face_idxs
34
+
35
+ except ImportError:
36
+ from trimesh.proximity import closest_point
37
+
38
+
39
+ def closest_point_barycentrics(v, vi, points):
40
+ """Given a 3D mesh and a set of query points, return closest point barycentrics
41
+ Args:
42
+ v: np.array (float)
43
+ [N, 3] mesh vertices
44
+
45
+ vi: np.array (int)
46
+ [N, 3] mesh triangle indices
47
+
48
+ points: np.array (float)
49
+ [M, 3] query points
50
+
51
+ Returns:
52
+ Tuple[approx, barys, interp_idxs, face_idxs]
53
+ approx: [M, 3] approximated (closest) points on the mesh
54
+ barys: [M, 3] barycentric weights that produce "approx"
55
+ interp_idxs: [M, 3] vertex indices for barycentric interpolation
56
+ face_idxs: [M] face indices for barycentric interpolation. interp_idxs = vi[face_idxs]
57
+ """
58
+ mesh = Trimesh(vertices=v, faces=vi, process=False)
59
+ p, _, face_idxs = closest_point(mesh, points)
60
+ p = p.reshape((points.shape[0], 3))
61
+ face_idxs = face_idxs.reshape((points.shape[0],))
62
+ barys = points_to_barycentric(mesh.triangles[face_idxs], p)
63
+ b0, b1, b2 = np.split(barys, 3, axis=1)
64
+
65
+ interp_idxs = vi[face_idxs]
66
+ v0 = v[interp_idxs[:, 0]]
67
+ v1 = v[interp_idxs[:, 1]]
68
+ v2 = v[interp_idxs[:, 2]]
69
+ approx = b0 * v0 + b1 * v1 + b2 * v2
70
+ return approx, barys, interp_idxs, face_idxs
71
+
72
+ def make_uv_face_index(
73
+ vt: th.Tensor,
74
+ vti: th.Tensor,
75
+ uv_shape: Union[Tuple[int, int], int],
76
+ flip_uv: bool = True,
77
+ device: Optional[Union[str, th.device]] = None,
78
+ ):
79
+ """Compute a UV-space face index map identifying which mesh face contains each
80
+ texel. For texels with no assigned triangle, the index will be -1."""
81
+
82
+ if isinstance(uv_shape, int):
83
+ uv_shape = (uv_shape, uv_shape)
84
+
85
+ uv_max_shape_ind = uv_shape.index(max(uv_shape))
86
+ uv_min_shape_ind = uv_shape.index(min(uv_shape))
87
+ uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind]
88
+
89
+ if device is not None:
90
+ if isinstance(device, str):
91
+ dev = th.device(device)
92
+ else:
93
+ dev = device
94
+ assert dev.type == "cuda"
95
+ else:
96
+ dev = th.device("cuda")
97
+
98
+ vt = 1.0 - vt.clone()
99
+
100
+ if flip_uv:
101
+ vt = vt.clone()
102
+ vt[:, 1] = 1 - vt[:, 1]
103
+ vt_pix = 2.0 * vt.to(dev) - 1.0
104
+ vt_pix = th.cat([vt_pix, th.ones_like(vt_pix[:, 0:1])], dim=1)
105
+
106
+ vt_pix[:, uv_min_shape_ind] *= uv_ratio
107
+ meshes = Meshes(vt_pix[np.newaxis], vti[np.newaxis].to(dev))
108
+ with th.no_grad():
109
+ face_index, _, _, _ = rasterize_meshes(
110
+ meshes, uv_shape, faces_per_pixel=1, z_clip_value=0.0, bin_size=0
111
+ )
112
+ face_index = face_index[0, ..., 0]
113
+ return face_index
114
+
115
+
116
+ def make_uv_vert_index(
117
+ vt: th.Tensor,
118
+ vi: th.Tensor,
119
+ vti: th.Tensor,
120
+ uv_shape: Union[Tuple[int, int], int],
121
+ flip_uv: bool = True,
122
+ ):
123
+ """Compute a UV-space vertex index map identifying which mesh vertices
124
+ comprise the triangle containing each texel. For texels with no assigned
125
+ triangle, all indices will be -1.
126
+ """
127
+ face_index_map = make_uv_face_index(vt, vti, uv_shape, flip_uv)
128
+ vert_index_map = vi[face_index_map.clamp(min=0)]
129
+ vert_index_map[face_index_map < 0] = -1
130
+ return vert_index_map.long()
131
+
132
+
133
+ def bary_coords(points: th.Tensor, triangles: th.Tensor, eps: float = 1.0e-6):
134
+ """Computes barycentric coordinates for a set of 2D query points given
135
+ coordintes for the 3 vertices of the enclosing triangle for each point."""
136
+ x = points[:, 0] - triangles[2, :, 0]
137
+ x1 = triangles[0, :, 0] - triangles[2, :, 0]
138
+ x2 = triangles[1, :, 0] - triangles[2, :, 0]
139
+ y = points[:, 1] - triangles[2, :, 1]
140
+ y1 = triangles[0, :, 1] - triangles[2, :, 1]
141
+ y2 = triangles[1, :, 1] - triangles[2, :, 1]
142
+ denom = y2 * x1 - y1 * x2
143
+ n0 = y2 * x - x2 * y
144
+ n1 = x1 * y - y1 * x
145
+
146
+ # Small epsilon to prevent divide-by-zero error.
147
+ denom = th.where(denom >= 0, denom.clamp(min=eps), denom.clamp(max=-eps))
148
+
149
+ bary_0 = n0 / denom
150
+ bary_1 = n1 / denom
151
+ bary_2 = 1.0 - bary_0 - bary_1
152
+
153
+ return th.stack((bary_0, bary_1, bary_2))
154
+
155
+
156
+ def make_uv_barys(
157
+ vt: th.Tensor,
158
+ vti: th.Tensor,
159
+ uv_shape: Union[Tuple[int, int], int],
160
+ flip_uv: bool = True,
161
+ ):
162
+ """Compute a UV-space barycentric map where each texel contains barycentric
163
+ coordinates for that texel within its enclosing UV triangle. For texels
164
+ with no assigned triangle, all 3 barycentric coordinates will be 0.
165
+ """
166
+ if isinstance(uv_shape, int):
167
+ uv_shape = (uv_shape, uv_shape)
168
+
169
+ if flip_uv:
170
+ # Flip here because texture coordinates in some of our topo files are
171
+ # stored in OpenGL convention with Y=0 on the bottom of the texture
172
+ # unlike numpy/torch arrays/tensors.
173
+ vt = vt.clone()
174
+ vt[:, 1] = 1 - vt[:, 1]
175
+
176
+ face_index_map = make_uv_face_index(vt, vti, uv_shape, flip_uv=False)
177
+ vti_map = vti.long()[face_index_map.clamp(min=0)]
178
+
179
+ uv_max_shape_ind = uv_shape.index(max(uv_shape))
180
+ uv_min_shape_ind = uv_shape.index(min(uv_shape))
181
+ uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind]
182
+ vt = vt.clone()
183
+ vt = vt * 2 - 1
184
+ vt[:, uv_min_shape_ind] *= uv_ratio
185
+ uv_tri_uvs = vt[vti_map].permute(2, 0, 1, 3)
186
+
187
+ uv_grid = th.meshgrid(
188
+ th.linspace(0.5, uv_shape[0] - 0.5, uv_shape[0]) / uv_shape[0],
189
+ th.linspace(0.5, uv_shape[1] - 0.5, uv_shape[1]) / uv_shape[1],
190
+ )
191
+ uv_grid = th.stack(uv_grid[::-1], dim=2).to(uv_tri_uvs)
192
+ uv_grid = uv_grid * 2 - 1
193
+ uv_grid[..., uv_min_shape_ind] *= uv_ratio
194
+
195
+ bary_map = bary_coords(uv_grid.view(-1, 2), uv_tri_uvs.view(3, -1, 2))
196
+ bary_map = bary_map.permute(1, 0).view(uv_shape[0], uv_shape[1], 3)
197
+ bary_map[face_index_map < 0] = 0
198
+ return face_index_map, bary_map
199
+
200
+
201
+ def index_image_impaint(
202
+ index_image: th.Tensor,
203
+ bary_image: Optional[th.Tensor] = None,
204
+ distance_threshold=100.0,
205
+ ):
206
+ # getting the mask around the indexes?
207
+ if len(index_image.shape) == 3:
208
+ valid_index = (index_image != -1).any(dim=-1)
209
+ elif len(index_image.shape) == 2:
210
+ valid_index = index_image != -1
211
+ else:
212
+ raise ValueError("`index_image` should be a [H,W] or [H,W,C] image")
213
+
214
+ invalid_index = ~valid_index
215
+
216
+ device = index_image.device
217
+
218
+ valid_ij = th.stack(th.where(valid_index), dim=-1)
219
+ invalid_ij = th.stack(th.where(invalid_index), dim=-1)
220
+ lookup_valid = KDTree(valid_ij.cpu().numpy())
221
+
222
+ dists, idxs = lookup_valid.query(invalid_ij.cpu())
223
+
224
+ # TODO: try average?
225
+ idxs = th.as_tensor(idxs, device=device)[..., 0]
226
+ dists = th.as_tensor(dists, device=device)[..., 0]
227
+
228
+ dist_mask = dists < distance_threshold
229
+
230
+ invalid_border = th.zeros_like(invalid_index)
231
+ invalid_border[invalid_index] = dist_mask
232
+
233
+ invalid_src_ij = valid_ij[idxs][dist_mask]
234
+ invalid_dst_ij = invalid_ij[dist_mask]
235
+
236
+ index_image_imp = index_image.clone()
237
+
238
+ index_image_imp[invalid_dst_ij[:, 0], invalid_dst_ij[:, 1]] = index_image[
239
+ invalid_src_ij[:, 0], invalid_src_ij[:, 1]
240
+ ]
241
+
242
+ if bary_image is not None:
243
+ bary_image_imp = bary_image.clone()
244
+
245
+ bary_image_imp[invalid_dst_ij[:, 0], invalid_dst_ij[:, 1]] = bary_image[
246
+ invalid_src_ij[:, 0], invalid_src_ij[:, 1]
247
+ ]
248
+
249
+ return index_image_imp, bary_image_imp
250
+ return index_image_imp
251
+
252
+
253
+ class GeometryModule(nn.Module):
254
+ def __init__(
255
+ self,
256
+ v,
257
+ vi,
258
+ vt,
259
+ vti,
260
+ uv_size,
261
+ v2uv: Optional[th.Tensor] = None,
262
+ flip_uv=False,
263
+ impaint=False,
264
+ impaint_threshold=100.0,
265
+ ):
266
+ super().__init__()
267
+
268
+ self.register_buffer("v", th.as_tensor(v))
269
+ self.register_buffer("vi", th.as_tensor(vi))
270
+ self.register_buffer("vt", th.as_tensor(vt))
271
+ self.register_buffer("vti", th.as_tensor(vti))
272
+ if v2uv is not None:
273
+ self.register_buffer("v2uv", th.as_tensor(v2uv, dtype=th.int64))
274
+
275
+ # TODO: should we just pass topology here?
276
+ # self.n_verts = v2uv.shape[0]
277
+ self.n_verts = vi.max() + 1
278
+
279
+ self.uv_size = uv_size
280
+
281
+ # TODO: can't we just index face_index?
282
+ index_image = make_uv_vert_index(
283
+ self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv
284
+ ).cpu()
285
+ face_index, bary_image = make_uv_barys(
286
+ self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv
287
+ )
288
+ if impaint:
289
+ if min(uv_size) >= 1024:
290
+ logger.info(
291
+ "impainting index image might take a while for sizes >= 1024"
292
+ )
293
+
294
+ index_image, bary_image = index_image_impaint(
295
+ index_image, bary_image, impaint_threshold
296
+ )
297
+ # TODO: we can avoid doing this 2x
298
+ face_index = index_image_impaint(
299
+ face_index, distance_threshold=impaint_threshold
300
+ )
301
+
302
+ self.register_buffer("index_image", index_image.cpu())
303
+ self.register_buffer("bary_image", bary_image.cpu())
304
+ self.register_buffer("face_index_image", face_index.cpu())
305
+
306
+ def render_index_images(self, uv_size, flip_uv=False, impaint=False):
307
+ index_image = make_uv_vert_index(
308
+ self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv
309
+ )
310
+ face_image, bary_image = make_uv_barys(
311
+ self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv
312
+ )
313
+
314
+ if impaint:
315
+ index_image, bary_image = index_image_impaint(
316
+ index_image,
317
+ bary_image,
318
+ )
319
+
320
+ return index_image, face_image, bary_image
321
+
322
+ def vn(self, verts):
323
+ return vert_normals(verts, self.vi[np.newaxis].to(th.long))
324
+
325
+ def to_uv(self, values):
326
+ return values_to_uv(values, self.index_image, self.bary_image)
327
+
328
+ def from_uv(self, values_uv):
329
+ # TODO: we need to sample this
330
+ return sample_uv(values_uv, self.vt, self.v2uv.to(th.long))
331
+
332
+ def rand_sample_3d_uv(self, count, uv_img):
333
+ """
334
+ Sample a set of 3D points on the surface of mesh, return corresponding interpolated values in UV space.
335
+
336
+ Args:
337
+ count - num of 3D points to be sampled
338
+
339
+ uv_img - the image in uv space to be sampled, e.g., texture
340
+ """
341
+ _mesh = Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.vi.detach().cpu().numpy(), process=False)
342
+ points, _ = trimesh.sample.sample_surface(_mesh, count)
343
+ return self.sample_uv_from_3dpts(points, uv_img)
344
+
345
+ def sample_uv_from_3dpts(self, points, uv_img):
346
+ num_pts = points.shape[0]
347
+ approx, barys, interp_idxs, face_idxs = closest_point_barycentrics(self.v.detach().cpu().numpy(), self.vi.detach().cpu().numpy(), points)
348
+ interp_uv_coords = self.vt[interp_idxs, :] # [N, 3, 2]
349
+ # do bary interp first to get interp_uv_coord in high-reso uv space
350
+ target_uv_coords = th.sum(interp_uv_coords * th.from_numpy(barys)[..., None], dim=1).float()
351
+ # then directly sample from uv space
352
+ sampled_values = sample_uv(values_uv=uv_img.permute(2, 0, 1)[None, ...], uv_coords=target_uv_coords) # [1, count, c]
353
+ approx_values = sampled_values[0].reshape(num_pts, uv_img.shape[2])
354
+ return approx_values.numpy(), points
355
+
356
+ def vert_sample_uv(self, uv_img):
357
+ count = self.v.shape[0]
358
+ points = self.v.detach().cpu().numpy()
359
+ approx_values, _ = self.sample_uv_from_3dpts(points, uv_img)
360
+ return approx_values
361
+
362
+
363
+ def sample_uv(
364
+ values_uv,
365
+ uv_coords,
366
+ v2uv: Optional[th.Tensor] = None,
367
+ mode: str = "bilinear",
368
+ align_corners: bool = True,
369
+ flip_uvs: bool = False,
370
+ ):
371
+ batch_size = values_uv.shape[0]
372
+
373
+ if flip_uvs:
374
+ uv_coords = uv_coords.clone()
375
+ uv_coords[:, 1] = 1.0 - uv_coords[:, 1]
376
+
377
+ # uv_coords_norm is [1, N, 1, 2] afterwards
378
+ uv_coords_norm = (uv_coords * 2.0 - 1.0)[np.newaxis, :, np.newaxis].expand(
379
+ batch_size, -1, -1, -1
380
+ )
381
+ # uv_shape = values_uv.shape[-2:]
382
+ # uv_max_shape_ind = uv_shape.index(max(uv_shape))
383
+ # uv_min_shape_ind = uv_shape.index(min(uv_shape))
384
+ # uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind]
385
+ # uv_coords_norm[..., uv_min_shape_ind] *= uv_ratio
386
+
387
+ values = (
388
+ F.grid_sample(values_uv, uv_coords_norm, align_corners=align_corners, mode=mode)
389
+ .squeeze(-1)
390
+ .permute((0, 2, 1))
391
+ )
392
+
393
+ if v2uv is not None:
394
+ values_duplicate = values[:, v2uv]
395
+ values = values_duplicate.mean(2)
396
+
397
+ return values
398
+
399
+
400
+ def values_to_uv(values, index_img, bary_img):
401
+ uv_size = index_img.shape
402
+ index_mask = th.all(index_img != -1, dim=-1)
403
+ idxs_flat = index_img[index_mask].to(th.int64)
404
+ bary_flat = bary_img[index_mask].to(th.float32)
405
+ # NOTE: here we assume
406
+ values_flat = th.sum(values[:, idxs_flat].permute(0, 3, 1, 2) * bary_flat, dim=-1)
407
+ values_uv = th.zeros(
408
+ values.shape[0],
409
+ values.shape[-1],
410
+ uv_size[0],
411
+ uv_size[1],
412
+ dtype=values.dtype,
413
+ device=values.device,
414
+ )
415
+ values_uv[:, :, index_mask] = values_flat
416
+ return values_uv
417
+
418
+
419
+ def face_normals(v, vi, eps: float = 1e-5):
420
+ pts = v[:, vi]
421
+ v0 = pts[:, :, 1] - pts[:, :, 0]
422
+ v1 = pts[:, :, 2] - pts[:, :, 0]
423
+ n = th.cross(v0, v1, dim=-1)
424
+ norm = th.norm(n, dim=-1, keepdim=True)
425
+ norm[norm < eps] = 1
426
+ n /= norm
427
+ return n
428
+
429
+
430
+ def vert_normals(v, vi, eps: float = 1.0e-5):
431
+ fnorms = face_normals(v, vi)
432
+ fnorms = fnorms[:, :, None].expand(-1, -1, 3, -1).reshape(fnorms.shape[0], -1, 3)
433
+ vi_flat = vi.view(1, -1).expand(v.shape[0], -1)
434
+ vnorms = th.zeros_like(v)
435
+ for j in range(3):
436
+ vnorms[..., j].scatter_add_(1, vi_flat, fnorms[..., j])
437
+ norm = th.norm(vnorms, dim=-1, keepdim=True)
438
+ norm[norm < eps] = 1
439
+ vnorms /= norm
440
+ return vnorms
441
+
442
+
443
+ def compute_view_cos(verts, faces, camera_pos):
444
+ vn = F.normalize(vert_normals(verts, faces), dim=-1)
445
+ v2c = F.normalize(verts - camera_pos[:, np.newaxis], dim=-1)
446
+ return th.einsum("bnd,bnd->bn", vn, v2c)
447
+
448
+
449
+ def compute_tbn(geom, vt, vi, vti):
450
+ """Computes tangent, bitangent, and normal vectors given a mesh.
451
+ Args:
452
+ geom: [N, n_verts, 3] th.Tensor
453
+ Vertex positions.
454
+ vt: [n_uv_coords, 2] th.Tensor
455
+ UV coordinates.
456
+ vi: [..., 3] th.Tensor
457
+ Face vertex indices.
458
+ vti: [..., 3] th.Tensor
459
+ Face UV indices.
460
+ Returns:
461
+ [..., 3] th.Tensors for T, B, N.
462
+ """
463
+
464
+ v0 = geom[:, vi[..., 0]]
465
+ v1 = geom[:, vi[..., 1]]
466
+ v2 = geom[:, vi[..., 2]]
467
+ vt0 = vt[vti[..., 0]]
468
+ vt1 = vt[vti[..., 1]]
469
+ vt2 = vt[vti[..., 2]]
470
+
471
+ v01 = v1 - v0
472
+ v02 = v2 - v0
473
+ vt01 = vt1 - vt0
474
+ vt02 = vt2 - vt0
475
+ f = 1.0 / (
476
+ vt01[None, ..., 0] * vt02[None, ..., 1]
477
+ - vt01[None, ..., 1] * vt02[None, ..., 0]
478
+ )
479
+ tangent = f[..., None] * th.stack(
480
+ [
481
+ v01[..., 0] * vt02[None, ..., 1] - v02[..., 0] * vt01[None, ..., 1],
482
+ v01[..., 1] * vt02[None, ..., 1] - v02[..., 1] * vt01[None, ..., 1],
483
+ v01[..., 2] * vt02[None, ..., 1] - v02[..., 2] * vt01[None, ..., 1],
484
+ ],
485
+ dim=-1,
486
+ )
487
+ tangent = F.normalize(tangent, dim=-1)
488
+ normal = F.normalize(th.cross(v01, v02, dim=3), dim=-1)
489
+ bitangent = F.normalize(th.cross(tangent, normal, dim=3), dim=-1)
490
+
491
+ return tangent, bitangent, normal
492
+
493
+
494
+ def compute_v2uv(n_verts, vi, vti, n_max=4):
495
+ """Computes mapping from vertex indices to texture indices.
496
+
497
+ Args:
498
+ vi: [F, 3], triangles
499
+ vti: [F, 3], texture triangles
500
+ n_max: int, max number of texture locations
501
+
502
+ Returns:
503
+ [n_verts, n_max], texture indices
504
+ """
505
+ v2uv_dict = {}
506
+ for i_v, i_uv in zip(vi.reshape(-1), vti.reshape(-1)):
507
+ v2uv_dict.setdefault(i_v, set()).add(i_uv)
508
+ assert len(v2uv_dict) == n_verts
509
+ v2uv = np.zeros((n_verts, n_max), dtype=np.int32)
510
+ for i in range(n_verts):
511
+ vals = sorted(list(v2uv_dict[i]))
512
+ v2uv[i, :] = vals[0]
513
+ v2uv[i, : len(vals)] = np.array(vals)
514
+ return v2uv
515
+
516
+
517
+ def compute_neighbours(n_verts, vi, n_max_values=10):
518
+ """Computes first-ring neighbours given vertices and faces."""
519
+ n_vi = vi.shape[0]
520
+
521
+ adj = {i: set() for i in range(n_verts)}
522
+ for i in range(n_vi):
523
+ for idx in vi[i]:
524
+ adj[idx] |= set(vi[i]) - set([idx])
525
+
526
+ nbs_idxs = np.tile(np.arange(n_verts)[:, np.newaxis], (1, n_max_values))
527
+ nbs_weights = np.zeros((n_verts, n_max_values), dtype=np.float32)
528
+
529
+ for idx in range(n_verts):
530
+ n_values = min(len(adj[idx]), n_max_values)
531
+ nbs_idxs[idx, :n_values] = np.array(list(adj[idx]))[:n_values]
532
+ nbs_weights[idx, :n_values] = -1.0 / n_values
533
+
534
+ return nbs_idxs, nbs_weights
535
+
536
+
537
+ def make_postex(v, idxim, barim):
538
+ return (
539
+ barim[None, :, :, 0, None] * v[:, idxim[:, :, 0]]
540
+ + barim[None, :, :, 1, None] * v[:, idxim[:, :, 1]]
541
+ + barim[None, :, :, 2, None] * v[:, idxim[:, :, 2]]
542
+ ).permute(0, 3, 1, 2)
543
+
544
+
545
+ def matrix_to_axisangle(r):
546
+ th = th.arccos(0.5 * (r[..., 0, 0] + r[..., 1, 1] + r[..., 2, 2] - 1.0))[..., None]
547
+ vec = (
548
+ 0.5
549
+ * th.stack(
550
+ [
551
+ r[..., 2, 1] - r[..., 1, 2],
552
+ r[..., 0, 2] - r[..., 2, 0],
553
+ r[..., 1, 0] - r[..., 0, 1],
554
+ ],
555
+ dim=-1,
556
+ )
557
+ / th.sin(th)
558
+ )
559
+ return th, vec
560
+
561
+
562
+ def axisangle_to_matrix(rvec):
563
+ theta = th.sqrt(1e-5 + th.sum(rvec**2, dim=-1))
564
+ rvec = rvec / theta[..., None]
565
+ costh = th.cos(theta)
566
+ sinth = th.sin(theta)
567
+ return th.stack(
568
+ (
569
+ th.stack(
570
+ (
571
+ rvec[..., 0] ** 2 + (1.0 - rvec[..., 0] ** 2) * costh,
572
+ rvec[..., 0] * rvec[..., 1] * (1.0 - costh) - rvec[..., 2] * sinth,
573
+ rvec[..., 0] * rvec[..., 2] * (1.0 - costh) + rvec[..., 1] * sinth,
574
+ ),
575
+ dim=-1,
576
+ ),
577
+ th.stack(
578
+ (
579
+ rvec[..., 0] * rvec[..., 1] * (1.0 - costh) + rvec[..., 2] * sinth,
580
+ rvec[..., 1] ** 2 + (1.0 - rvec[..., 1] ** 2) * costh,
581
+ rvec[..., 1] * rvec[..., 2] * (1.0 - costh) - rvec[..., 0] * sinth,
582
+ ),
583
+ dim=-1,
584
+ ),
585
+ th.stack(
586
+ (
587
+ rvec[..., 0] * rvec[..., 2] * (1.0 - costh) - rvec[..., 1] * sinth,
588
+ rvec[..., 1] * rvec[..., 2] * (1.0 - costh) + rvec[..., 0] * sinth,
589
+ rvec[..., 2] ** 2 + (1.0 - rvec[..., 2] ** 2) * costh,
590
+ ),
591
+ dim=-1,
592
+ ),
593
+ ),
594
+ dim=-2,
595
+ )
596
+
597
+
598
+ def rotation_interp(r0, r1, alpha):
599
+ r0a = r0.view(-1, 3, 3)
600
+ r1a = r1.view(-1, 3, 3)
601
+ r = th.bmm(r0a.permute(0, 2, 1), r1a).view_as(r0)
602
+
603
+ th, rvec = matrix_to_axisangle(r)
604
+ rvec = rvec * (alpha * th)
605
+
606
+ r = axisangle_to_matrix(rvec)
607
+ return th.bmm(r0a, r.view(-1, 3, 3)).view_as(r0)
608
+
609
+
610
+ def convert_camera_parameters(Rt, K):
611
+ R = Rt[:, :3, :3]
612
+ t = -R.permute(0, 2, 1).bmm(Rt[:, :3, 3].unsqueeze(2)).squeeze(2)
613
+ return dict(
614
+ campos=t,
615
+ camrot=R,
616
+ focal=K[:, :2, :2],
617
+ princpt=K[:, :2, 2],
618
+ )
619
+
620
+
621
+ def project_points_multi(p, Rt, K, normalize=False, size=None):
622
+ """Project a set of 3D points into multiple cameras with a pinhole model.
623
+ Args:
624
+ p: [B, N, 3], input 3D points in world coordinates
625
+ Rt: [B, NC, 3, 4], extrinsics (where NC is the number of cameras to project to)
626
+ K: [B, NC, 3, 3], intrinsics
627
+ normalize: bool, whether to normalize coordinates to [-1.0, 1.0]
628
+ Returns:
629
+ tuple:
630
+ - [B, NC, N, 2] - projected points
631
+ - [B, NC, N] - their
632
+ """
633
+ B, N = p.shape[:2]
634
+ NC = Rt.shape[1]
635
+
636
+ Rt = Rt.reshape(B * NC, 3, 4)
637
+ K = K.reshape(B * NC, 3, 3)
638
+
639
+ # [B, N, 3] -> [B * NC, N, 3]
640
+ p = p[:, np.newaxis].expand(-1, NC, -1, -1).reshape(B * NC, -1, 3)
641
+ p_cam = p @ Rt[:, :3, :3].transpose(-2, -1) + Rt[:, :3, 3][:, np.newaxis]
642
+ p_pix = p_cam @ K.transpose(-2, -1)
643
+ p_depth = p_pix[:, :, 2:]
644
+ p_pix = (p_pix[..., :2] / p_depth).reshape(B, NC, N, 2)
645
+ p_depth = p_depth.reshape(B, NC, N)
646
+
647
+ if normalize:
648
+ assert size is not None
649
+ h, w = size
650
+ p_pix = (
651
+ 2.0 * p_pix / th.as_tensor([w, h], dtype=th.float32, device=p.device) - 1.0
652
+ )
653
+ return p_pix, p_depth
dva/io.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import cv2
9
+ import numpy as np
10
+ import copy
11
+ import importlib
12
+ from typing import Any, Dict
13
+
14
+ def load_module(module_name, class_name=None, silent: bool = False):
15
+ module = importlib.import_module(module_name)
16
+ return getattr(module, class_name) if class_name else module
17
+
18
+
19
+ def load_class(class_name):
20
+ return load_module(*class_name.rsplit(".", 1))
21
+
22
+
23
+ def load_from_config(config, **kwargs):
24
+ """Instantiate an object given a config and arguments."""
25
+ assert "class_name" in config and "module_name" not in config
26
+ config = copy.deepcopy(config)
27
+ class_name = config.pop("class_name")
28
+ object_class = load_class(class_name)
29
+ return object_class(**config, **kwargs)
30
+
31
+
32
+ def load_opencv_calib(extrin_path, intrin_path):
33
+ cameras = {}
34
+
35
+ fse = cv2.FileStorage()
36
+ fse.open(extrin_path, cv2.FileStorage_READ)
37
+
38
+ fsi = cv2.FileStorage()
39
+ fsi.open(intrin_path, cv2.FileStorage_READ)
40
+
41
+ names = [
42
+ fse.getNode("names").at(c).string() for c in range(fse.getNode("names").size())
43
+ ]
44
+
45
+ for camera in names:
46
+ rot = fse.getNode(f"R_{camera}").mat()
47
+ R = fse.getNode(f"Rot_{camera}").mat()
48
+ T = fse.getNode(f"T_{camera}").mat()
49
+ R_pred = cv2.Rodrigues(rot)[0]
50
+ assert np.all(np.isclose(R_pred, R))
51
+ K = fsi.getNode(f"K_{camera}").mat()
52
+ cameras[camera] = {
53
+ "Rt": np.concatenate([R, T], axis=1).astype(np.float32),
54
+ "K": K.astype(np.float32),
55
+ }
56
+ return cameras
dva/layers.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch as th
8
+ import torch.nn as nn
9
+
10
+ import numpy as np
11
+
12
+ from dva.mvp.models.utils import Conv2dWN, Conv2dWNUB, ConvTranspose2dWNUB, initmod
13
+
14
+
15
+ class ConvBlock(nn.Module):
16
+ def __init__(
17
+ self,
18
+ in_channels,
19
+ out_channels,
20
+ size,
21
+ lrelu_slope=0.2,
22
+ kernel_size=3,
23
+ padding=1,
24
+ wnorm_dim=0,
25
+ ):
26
+ super().__init__()
27
+
28
+ self.conv_resize = Conv2dWN(in_channels, out_channels, kernel_size=1)
29
+ self.conv1 = Conv2dWNUB(
30
+ in_channels,
31
+ in_channels,
32
+ kernel_size=kernel_size,
33
+ padding=padding,
34
+ height=size,
35
+ width=size,
36
+ )
37
+
38
+ self.lrelu1 = nn.LeakyReLU(lrelu_slope)
39
+ self.conv2 = Conv2dWNUB(
40
+ in_channels,
41
+ out_channels,
42
+ kernel_size=kernel_size,
43
+ padding=padding,
44
+ height=size,
45
+ width=size,
46
+ )
47
+ self.lrelu2 = nn.LeakyReLU(lrelu_slope)
48
+
49
+ def forward(self, x):
50
+ x_skip = self.conv_resize(x)
51
+ x = self.conv1(x)
52
+ x = self.lrelu1(x)
53
+ x = self.conv2(x)
54
+ x = self.lrelu2(x)
55
+ return x + x_skip
56
+
57
+
58
+ def tile2d(x, size: int):
59
+ """Tile a given set of features into a convolutional map.
60
+
61
+ Args:
62
+ x: float tensor of shape [N, F]
63
+ size: int or a tuple
64
+
65
+ Returns:
66
+ a feature map [N, F, size[0], size[1]]
67
+ """
68
+ # size = size if isinstance(size, tuple) else (size, size)
69
+ # NOTE: expecting only int here (!!!)
70
+ return x[:, :, np.newaxis, np.newaxis].expand(-1, -1, size, size)
71
+
72
+
73
+ def weights_initializer(m, alpha: float = 1.0):
74
+ return initmod(m, nn.init.calculate_gain("leaky_relu", alpha))
75
+
76
+
77
+ class UNetWB(nn.Module):
78
+ def __init__(
79
+ self,
80
+ in_channels,
81
+ out_channels,
82
+ size,
83
+ n_init_ftrs=8,
84
+ out_scale=0.1,
85
+ ):
86
+ # super().__init__(*args, **kwargs)
87
+ super().__init__()
88
+
89
+ self.out_scale = 0.1
90
+
91
+ F = n_init_ftrs
92
+
93
+ # TODO: allow changing the size?
94
+ self.size = size
95
+
96
+ self.down1 = nn.Sequential(
97
+ Conv2dWNUB(in_channels, F, self.size // 2, self.size // 2, 4, 2, 1),
98
+ nn.LeakyReLU(0.2),
99
+ )
100
+ self.down2 = nn.Sequential(
101
+ Conv2dWNUB(F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1),
102
+ nn.LeakyReLU(0.2),
103
+ )
104
+ self.down3 = nn.Sequential(
105
+ Conv2dWNUB(2 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1),
106
+ nn.LeakyReLU(0.2),
107
+ )
108
+ self.down4 = nn.Sequential(
109
+ Conv2dWNUB(4 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1),
110
+ nn.LeakyReLU(0.2),
111
+ )
112
+ self.down5 = nn.Sequential(
113
+ Conv2dWNUB(8 * F, 16 * F, self.size // 32, self.size // 32, 4, 2, 1),
114
+ nn.LeakyReLU(0.2),
115
+ )
116
+ self.up1 = nn.Sequential(
117
+ ConvTranspose2dWNUB(
118
+ 16 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1
119
+ ),
120
+ nn.LeakyReLU(0.2),
121
+ )
122
+ self.up2 = nn.Sequential(
123
+ ConvTranspose2dWNUB(8 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1),
124
+ nn.LeakyReLU(0.2),
125
+ )
126
+ self.up3 = nn.Sequential(
127
+ ConvTranspose2dWNUB(4 * F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1),
128
+ nn.LeakyReLU(0.2),
129
+ )
130
+ self.up4 = nn.Sequential(
131
+ ConvTranspose2dWNUB(2 * F, F, self.size // 2, self.size // 2, 4, 2, 1),
132
+ nn.LeakyReLU(0.2),
133
+ )
134
+ self.up5 = nn.Sequential(
135
+ ConvTranspose2dWNUB(F, F, self.size, self.size, 4, 2, 1), nn.LeakyReLU(0.2)
136
+ )
137
+ self.out = Conv2dWNUB(
138
+ F + in_channels, out_channels, self.size, self.size, kernel_size=1
139
+ )
140
+ self.apply(lambda x: initmod(x, 0.2))
141
+ initmod(self.out, 1.0)
142
+
143
+ def forward(self, x):
144
+ x1 = x
145
+ x2 = self.down1(x1)
146
+ x3 = self.down2(x2)
147
+ x4 = self.down3(x3)
148
+ x5 = self.down4(x4)
149
+ x6 = self.down5(x5)
150
+ # TODO: switch to concat?
151
+ x = self.up1(x6) + x5
152
+ x = self.up2(x) + x4
153
+ x = self.up3(x) + x3
154
+ x = self.up4(x) + x2
155
+ x = self.up5(x)
156
+ x = th.cat([x, x1], dim=1)
157
+ return self.out(x) * self.out_scale
dva/losses.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch.nn as nn
8
+ import torch as th
9
+ import numpy as np
10
+
11
+ import logging
12
+
13
+ from .vgg import VGGLossMasked
14
+
15
+ logger = logging.getLogger("dva.{__name__}")
16
+
17
+ class DCTLoss(nn.Module):
18
+ def __init__(self, weights):
19
+ super().__init__()
20
+ self.weights = weights
21
+
22
+ def forward(self, inputs, preds, iteration=None):
23
+ loss_dict = {"loss_total": 0.0}
24
+ target = inputs['gt']
25
+ recon = preds['recon']
26
+ posterior = preds['posterior']
27
+ fft_gt = th.view_as_real(th.fft.fft(target.reshape(target.shape[0], -1)))
28
+ fft_recon = th.view_as_real(th.fft.fft(recon.reshape(recon.shape[0], -1)))
29
+ loss_recon_dct_l1 = th.mean(th.abs(fft_gt - fft_recon))
30
+ loss_recon_l1 = th.mean(th.abs(target - recon))
31
+ loss_kl = posterior.kl().mean()
32
+ loss_dict.update(loss_recon_l1=loss_recon_l1, loss_recon_dct_l1=loss_recon_dct_l1, loss_kl=loss_kl)
33
+ loss_total = self.weights.recon * loss_recon_dct_l1 + self.weights.kl * loss_kl
34
+
35
+ loss_dict["loss_total"] = loss_total
36
+ return loss_total, loss_dict
37
+
38
+ class VAESepL2Loss(nn.Module):
39
+ def __init__(self, weights):
40
+ super().__init__()
41
+ self.weights = weights
42
+
43
+ def forward(self, inputs, preds, iteration=None):
44
+ loss_dict = {"loss_total": 0.0}
45
+ target = inputs['gt']
46
+ recon = preds['recon']
47
+ posterior = preds['posterior']
48
+ recon_diff = (target - recon) ** 2
49
+ loss_recon_sdf_l1 = th.mean(recon_diff[:, 0:1, ...])
50
+ loss_recon_rgb_l1 = th.mean(recon_diff[:, 1:4, ...])
51
+ loss_recon_mat_l1 = th.mean(recon_diff[:, 4:6, ...])
52
+ loss_kl = posterior.kl().mean()
53
+ loss_dict.update(loss_sdf_l1=loss_recon_sdf_l1, loss_rgb_l1=loss_recon_rgb_l1, loss_mat_l1=loss_recon_mat_l1, loss_kl=loss_kl)
54
+ loss_total = self.weights.sdf * loss_recon_sdf_l1 + self.weights.rgb * loss_recon_rgb_l1 + self.weights.mat * loss_recon_mat_l1
55
+ if "kl" in self.weights:
56
+ loss_total += self.weights.kl * loss_kl
57
+
58
+ loss_dict["loss_total"] = loss_total
59
+ return loss_total, loss_dict
60
+
61
+ class VAESepLoss(nn.Module):
62
+ def __init__(self, weights):
63
+ super().__init__()
64
+ self.weights = weights
65
+
66
+ def forward(self, inputs, preds, iteration=None):
67
+ loss_dict = {"loss_total": 0.0}
68
+ target = inputs['gt']
69
+ recon = preds['recon']
70
+ posterior = preds['posterior']
71
+ recon_diff = th.abs(target - recon)
72
+ loss_recon_sdf_l1 = th.mean(recon_diff[:, 0:1, ...])
73
+ loss_recon_rgb_l1 = th.mean(recon_diff[:, 1:4, ...])
74
+ loss_recon_mat_l1 = th.mean(recon_diff[:, 4:6, ...])
75
+ loss_kl = posterior.kl().mean()
76
+ loss_dict.update(loss_sdf_l1=loss_recon_sdf_l1, loss_rgb_l1=loss_recon_rgb_l1, loss_mat_l1=loss_recon_mat_l1, loss_kl=loss_kl)
77
+ loss_total = self.weights.sdf * loss_recon_sdf_l1 + self.weights.rgb * loss_recon_rgb_l1 + self.weights.mat * loss_recon_mat_l1
78
+ if "kl" in self.weights:
79
+ loss_total += self.weights.kl * loss_kl
80
+
81
+ loss_dict["loss_total"] = loss_total
82
+ return loss_total, loss_dict
83
+
84
+ class VAELoss(nn.Module):
85
+ def __init__(self, weights):
86
+ super().__init__()
87
+ self.weights = weights
88
+
89
+ def forward(self, inputs, preds, iteration=None):
90
+ loss_dict = {"loss_total": 0.0}
91
+ target = inputs['gt']
92
+ recon = preds['recon']
93
+ posterior = preds['posterior']
94
+ loss_recon_l1 = th.mean(th.abs(target - recon))
95
+ loss_kl = posterior.kl().mean()
96
+ loss_dict.update(loss_recon_l1=loss_recon_l1, loss_kl=loss_kl)
97
+ loss_total = self.weights.recon * loss_recon_l1 + self.weights.kl * loss_kl
98
+
99
+ loss_dict["loss_total"] = loss_total
100
+ return loss_total, loss_dict
101
+
102
+ class PrimSDFLoss(nn.Module):
103
+ def __init__(self, weights, shape_opt_steps=2000, tex_opt_steps=6000):
104
+ super().__init__()
105
+ self.weights = weights
106
+ self.shape_opt_steps = shape_opt_steps
107
+ self.tex_opt_steps = tex_opt_steps
108
+
109
+ def forward(self, inputs, preds, iteration=None):
110
+ loss_dict = {"loss_total": 0.0}
111
+
112
+ if iteration < self.shape_opt_steps:
113
+ target_sdf = inputs['sdf']
114
+ sdf = preds['sdf']
115
+ loss_sdf_l1 = th.mean(th.abs(sdf - target_sdf))
116
+ loss_dict.update(loss_sdf_l1=loss_sdf_l1)
117
+ loss_total = self.weights.sdf_l1 * loss_sdf_l1
118
+
119
+ prim_scale = preds["prim_scale"]
120
+ # we use 1/scale instead of the original 100/scale as our scale is normalized to [-1, 1] cube
121
+ if "vol_sum" in self.weights:
122
+ loss_prim_vol_sum = th.mean(th.sum(th.prod(1 / prim_scale, dim=-1), dim=-1))
123
+ loss_dict.update(loss_prim_vol_sum=loss_prim_vol_sum)
124
+ loss_total += self.weights.vol_sum * loss_prim_vol_sum
125
+
126
+ if iteration >= self.shape_opt_steps and iteration < self.tex_opt_steps:
127
+ target_tex = inputs['tex']
128
+ tex = preds['tex']
129
+ loss_tex_l1 = th.mean(th.abs(tex - target_tex))
130
+ loss_dict.update(loss_tex_l1=loss_tex_l1)
131
+
132
+ loss_total = (
133
+ self.weights.rgb_l1 * loss_tex_l1
134
+ )
135
+ if "mat_l1" in self.weights:
136
+ target_mat = inputs['mat']
137
+ mat = preds['mat']
138
+ loss_mat_l1 = th.mean(th.abs(mat - target_mat))
139
+ loss_dict.update(loss_mat_l1=loss_mat_l1)
140
+ loss_total += self.weights.mat_l1 * loss_mat_l1
141
+
142
+ if "grad_l2" in self.weights:
143
+ loss_grad_l2 = th.mean((preds["grad"] - inputs["grad"]) ** 2)
144
+ loss_total += self.weights.grad_l2 * loss_grad_l2
145
+ loss_dict.update(loss_grad_l2=loss_grad_l2)
146
+
147
+ loss_dict["loss_total"] = loss_total
148
+ return loss_total, loss_dict
149
+
150
+
151
+ class TotalMVPLoss(nn.Module):
152
+ def __init__(self, weights, assets=None):
153
+ super().__init__()
154
+
155
+ self.weights = weights
156
+
157
+ if "vgg" in self.weights:
158
+ self.vgg_loss = VGGLossMasked()
159
+
160
+ def forward(self, inputs, preds, iteration=None):
161
+
162
+ loss_dict = {"loss_total": 0.0}
163
+
164
+ B = inputs["image"].shape
165
+
166
+ # rgb
167
+ target_rgb = inputs["image"].permute(0, 2, 3, 1)
168
+ # removing the mask
169
+ target_rgb = target_rgb * inputs["image_mask"][:, 0, :, :, np.newaxis]
170
+
171
+ rgb = preds["rgb"]
172
+ loss_rgb_mse = th.mean(((rgb - target_rgb) / 16.0) ** 2.0)
173
+ loss_dict.update(loss_rgb_mse=loss_rgb_mse)
174
+
175
+ alpha = preds["alpha"]
176
+
177
+ # mask loss
178
+ target_mask = inputs["image_mask"][:, 0].to(th.float32)
179
+ loss_mask_mae = th.mean((target_mask - alpha).abs())
180
+ loss_dict.update(loss_mask_mae=loss_mask_mae)
181
+
182
+ B = alpha.shape[0]
183
+
184
+ # beta prior on opacity
185
+ loss_alpha_prior = th.mean(
186
+ th.log(0.1 + alpha.reshape(B, -1))
187
+ + th.log(0.1 + 1.0 - alpha.reshape(B, -1))
188
+ - -2.20727
189
+ )
190
+ loss_dict.update(loss_alpha_prior=loss_alpha_prior)
191
+
192
+ prim_scale = preds["prim_scale"]
193
+ loss_prim_vol_sum = th.mean(th.sum(th.prod(100.0 / prim_scale, dim=-1), dim=-1))
194
+ loss_dict.update(loss_prim_vol_sum=loss_prim_vol_sum)
195
+
196
+ loss_total = (
197
+ self.weights.rgb_mse * loss_rgb_mse
198
+ + self.weights.mask_mae * loss_mask_mae
199
+ + self.weights.alpha_prior * loss_alpha_prior
200
+ + self.weights.prim_vol_sum * loss_prim_vol_sum
201
+ )
202
+
203
+ if "embs_l2" in self.weights:
204
+ loss_embs_l2 = th.sum(th.norm(preds["embs"], dim=1))
205
+ loss_total += self.weights.embs_l2 * loss_embs_l2
206
+ loss_dict.update(loss_embs_l2=loss_embs_l2)
207
+
208
+ if "vgg" in self.weights:
209
+ loss_vgg = self.vgg_loss(
210
+ rgb.permute(0, 3, 1, 2),
211
+ target_rgb.permute(0, 3, 1, 2),
212
+ inputs["image_mask"],
213
+ )
214
+ loss_total += self.weights.vgg * loss_vgg
215
+ loss_dict.update(loss_vgg=loss_vgg)
216
+
217
+ if "prim_scale_var" in self.weights:
218
+ log_prim_scale = th.log(prim_scale)
219
+ # NOTE: should we detach this?
220
+ log_prim_scale_mean = th.mean(log_prim_scale, dim=1, keepdim=True)
221
+ loss_prim_scale_var = th.mean((log_prim_scale - log_prim_scale_mean) ** 2.0)
222
+ loss_total += self.weights.prim_scale_var * loss_prim_scale_var
223
+ loss_dict.update(loss_prim_scale_var=loss_prim_scale_var)
224
+
225
+ loss_dict["loss_total"] = loss_total
226
+
227
+ return loss_total, loss_dict
228
+
229
+
230
+ def process_losses(loss_dict, reduce=True, detach=True):
231
+ """Preprocess the dict of losses outputs."""
232
+ result = {
233
+ k.replace("loss_", ""): v for k, v in loss_dict.items() if k.startswith("loss_")
234
+ }
235
+ if detach:
236
+ result = {k: v.detach() for k, v in result.items()}
237
+ if reduce:
238
+ result = {k: float(v.mean().item()) for k, v in result.items()}
239
+ return result
dva/mvp/extensions/mvpraymarch/bvh.cu ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #include <cmath>
8
+ #include <cstdio>
9
+ #include <functional>
10
+ #include <map>
11
+
12
+ #include "helper_math.h"
13
+
14
+ #include "cudadispatch.h"
15
+
16
+ #include "primtransf.h"
17
+
18
+ // Expands a 10-bit integer into 30 bits
19
+ // by inserting 2 zeros after each bit.
20
+ __device__ unsigned int expand_bits(unsigned int v) {
21
+ v = (v * 0x00010001u) & 0xFF0000FFu;
22
+ v = (v * 0x00000101u) & 0x0F00F00Fu;
23
+ v = (v * 0x00000011u) & 0xC30C30C3u;
24
+ v = (v * 0x00000005u) & 0x49249249u;
25
+ return v;
26
+ }
27
+
28
+ // Calculates a 30-bit Morton code for the
29
+ // given 3D point located within the unit cube [0,1].
30
+ __device__ unsigned int morton3D(float x, float y, float z) {
31
+ x = fminf(fmaxf(x * 1024.0f, 0.0f), 1023.0f);
32
+ y = fminf(fmaxf(y * 1024.0f, 0.0f), 1023.0f);
33
+ z = fminf(fmaxf(z * 1024.0f, 0.0f), 1023.0f);
34
+ unsigned int xx = expand_bits((unsigned int)x);
35
+ unsigned int yy = expand_bits((unsigned int)y);
36
+ unsigned int zz = expand_bits((unsigned int)z);
37
+ return xx * 4 + yy * 2 + zz;
38
+ }
39
+
40
+ template<typename PrimTransfT>
41
+ __global__ void compute_morton_kernel(
42
+ int N, int K,
43
+ typename PrimTransfT::Data data,
44
+ int * code
45
+ ) {
46
+ const int count = N * K;
47
+ for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x) {
48
+ const int k = index % K;
49
+ const int n = index / K;
50
+
51
+ //float4 c = center[n * K + k];
52
+ float3 c = data.get_center(n, k);
53
+ code[n * K + k] = morton3D(c.x, c.y, c.z);
54
+ }
55
+ }
56
+
57
+ __forceinline__ __device__ int delta(int* sortedcodes, int x, int y, int K) {
58
+ if (x >= 0 && x <= K - 1 && y >= 0 && y <= K - 1) {
59
+ return sortedcodes[x] == sortedcodes[y] ?
60
+ 32 + __clz(x ^ y) :
61
+ __clz(sortedcodes[x] ^ sortedcodes[y]);
62
+ }
63
+ return -1;
64
+ }
65
+
66
+ __forceinline__ __device__ int sign(int x) {
67
+ return (int)(x > 0) - (int)(x < 0);
68
+ }
69
+
70
+ __device__ int find_split(
71
+ int* sortedcodes,
72
+ int first,
73
+ int last,
74
+ int K) {
75
+ float commonPrefix = delta(sortedcodes, first, last, K);
76
+ int split = first;
77
+ int step = last - first;
78
+
79
+ do {
80
+ step = (step + 1) >> 1; // exponential decrease
81
+ int newSplit = split + step; // proposed new position
82
+
83
+ if (newSplit < last) {
84
+ int splitPrefix = delta(sortedcodes, first, newSplit, K);
85
+ if (splitPrefix > commonPrefix) {
86
+ split = newSplit; // accept proposal
87
+ }
88
+ }
89
+ } while (step > 1);
90
+
91
+ return split;
92
+ }
93
+
94
+ __device__ int2 determine_range(int* sortedcodes, int K, int idx) {
95
+ int d = sign(delta(sortedcodes, idx, idx + 1, K) - delta(sortedcodes, idx, idx - 1, K));
96
+ int dmin = delta(sortedcodes, idx, idx - d, K);
97
+ int lmax = 2;
98
+ while (delta(sortedcodes, idx, idx + lmax * d, K) > dmin) {
99
+ lmax = lmax * 2;
100
+ }
101
+
102
+ int l = 0;
103
+ for (int t = lmax / 2; t >= 1; t /= 2) {
104
+ if (delta(sortedcodes, idx, idx + (l + t)*d, K) > dmin) {
105
+ l += t;
106
+ }
107
+ }
108
+
109
+ int j = idx + l*d;
110
+ int2 range;
111
+ range.x = min(idx, j);
112
+ range.y = max(idx, j);
113
+
114
+ return range;
115
+ }
116
+
117
+ __global__ void build_tree_kernel(
118
+ int N, int K,
119
+ int * sortedcodes,
120
+ int2 * nodechildren,
121
+ int * nodeparent) {
122
+ const int count = N * (K + K - 1);
123
+ for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x) {
124
+ const int k = index % (K + K - 1);
125
+ const int n = index / (K + K - 1);
126
+
127
+ if (k >= K - 1) {
128
+ // leaf
129
+ nodechildren[n * (K + K - 1) + k] = make_int2(-(k - (K - 1)) - 1, -(k - (K - 1)) - 2);
130
+ } else {
131
+ // internal node
132
+
133
+ // find out which range of objects the node corresponds to
134
+ int2 range = determine_range(sortedcodes + n * K, K, k);
135
+ int first = range.x;
136
+ int last = range.y;
137
+
138
+ // determine where to split the range
139
+ int split = find_split(sortedcodes + n * K, first, last, K);
140
+
141
+ // select childA
142
+ int childa = split == first ? (K - 1) + split : split;
143
+
144
+ // select childB
145
+ int childb = split + 1 == last ? (K - 1) + split + 1 : split + 1;
146
+
147
+ // record parent-child relationships
148
+ nodechildren[n * (K + K - 1) + k] = make_int2(childa, childb);
149
+ nodeparent[n * (K + K - 1) + childa] = k;
150
+ nodeparent[n * (K + K - 1) + childb] = k;
151
+ }
152
+ }
153
+ }
154
+
155
+ template<typename PrimTransfT>
156
+ __global__ void compute_aabb_kernel(
157
+ int N, int K,
158
+ typename PrimTransfT::Data data,
159
+ int * sortedobjid,
160
+ int2 * nodechildren,
161
+ int * nodeparent,
162
+ float3 * nodeaabb,
163
+ int * atom) {
164
+ const int count = N * K;
165
+ for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x) {
166
+ const int k = index % K;
167
+ const int n = index / K;
168
+
169
+ // compute BBOX for leaf
170
+ int kk = sortedobjid[n * K + k];
171
+
172
+ float3 pmin;
173
+ float3 pmax;
174
+ data.compute_aabb(n, kk, pmin, pmax);
175
+
176
+ nodeaabb[n * (K + K - 1) * 2 + ((K - 1) + k) * 2 + 0] = pmin;
177
+ nodeaabb[n * (K + K - 1) * 2 + ((K - 1) + k) * 2 + 1] = pmax;
178
+
179
+ int node = nodeparent[n * (K + K - 1) + ((K - 1) + k)];
180
+
181
+ while (node != -1 && atomicCAS(&atom[n * (K - 1) + node], 0, 1) == 1) {
182
+ int2 children = nodechildren[n * (K + K - 1) + node];
183
+ float3 laabbmin = nodeaabb[n * (K + K - 1) * 2 + children.x * 2 + 0];
184
+ float3 laabbmax = nodeaabb[n * (K + K - 1) * 2 + children.x * 2 + 1];
185
+ float3 raabbmin = nodeaabb[n * (K + K - 1) * 2 + children.y * 2 + 0];
186
+ float3 raabbmax = nodeaabb[n * (K + K - 1) * 2 + children.y * 2 + 1];
187
+
188
+ float3 aabbmin = fminf(laabbmin, raabbmin);
189
+ float3 aabbmax = fmaxf(laabbmax, raabbmax);
190
+
191
+ nodeaabb[n * (K + K - 1) * 2 + node * 2 + 0] = aabbmin;
192
+ nodeaabb[n * (K + K - 1) * 2 + node * 2 + 1] = aabbmax;
193
+
194
+ node = nodeparent[n * (K + K - 1) + node];
195
+
196
+ __threadfence();
197
+ }
198
+ }
199
+ }
200
+
201
+ void compute_morton_cuda(
202
+ int N, int K,
203
+ float * primpos,
204
+ int * code,
205
+ int algorithm,
206
+ cudaStream_t stream) {
207
+ int count = N * K;
208
+ int blocksize = 512;
209
+ int gridsize = (count + blocksize - 1) / blocksize;
210
+
211
+ std::shared_ptr<PrimTransfDataBase> primtransf_data;
212
+ primtransf_data = std::make_shared<PrimTransfSRT::Data>(PrimTransfSRT::Data{
213
+ PrimTransfDataBase{},
214
+ K, (float3*)primpos, nullptr,
215
+ K * 3, nullptr, nullptr,
216
+ K, nullptr, nullptr});
217
+
218
+ std::map<int, std::function<void(dim3, dim3, cudaStream_t, int, int, std::shared_ptr<PrimTransfDataBase>, int*)>> dispatcher = {
219
+ { 0, make_cudacall(compute_morton_kernel<PrimTransfSRT>) }
220
+ };
221
+
222
+ auto iter = dispatcher.find(min(0, algorithm));
223
+ if (iter != dispatcher.end()) {
224
+ (iter->second)(
225
+ dim3(gridsize), dim3(blocksize), stream,
226
+ N, K,
227
+ primtransf_data,
228
+ code);
229
+ }
230
+ }
231
+
232
+ void build_tree_cuda(
233
+ int N, int K,
234
+ int * sortedcode,
235
+ int * nodechildren,
236
+ int * nodeparent,
237
+ cudaStream_t stream) {
238
+ int count = N * (K + K - 1);
239
+ int nthreads = 512;
240
+ int nblocks = (count + nthreads - 1) / nthreads;
241
+ build_tree_kernel<<<nblocks, nthreads, 0, stream>>>(
242
+ N, K,
243
+ sortedcode,
244
+ reinterpret_cast<int2 *>(nodechildren),
245
+ nodeparent);
246
+ }
247
+
248
+ void compute_aabb_cuda(
249
+ int N, int K,
250
+ float * primpos,
251
+ float * primrot,
252
+ float * primscale,
253
+ int * sortedobjid,
254
+ int * nodechildren,
255
+ int * nodeparent,
256
+ float * nodeaabb,
257
+ int algorithm,
258
+ cudaStream_t stream) {
259
+ int * atom;
260
+ cudaMalloc(&atom, N * (K - 1) * 4);
261
+ cudaMemset(atom, 0, N * (K - 1) * 4);
262
+
263
+ int count = N * K;
264
+ int blocksize = 512;
265
+ int gridsize = (count + blocksize - 1) / blocksize;
266
+
267
+ std::shared_ptr<PrimTransfDataBase> primtransf_data;
268
+ primtransf_data = std::make_shared<PrimTransfSRT::Data>(PrimTransfSRT::Data{
269
+ PrimTransfDataBase{},
270
+ K, (float3*)primpos, nullptr,
271
+ K * 3, (float3*)primrot, nullptr,
272
+ K, (float3*)primscale, nullptr});
273
+
274
+ std::map<int, std::function<void(dim3, dim3, cudaStream_t, int, int, std::shared_ptr<PrimTransfDataBase>, int*, int2*, int*, float3*, int*)>> dispatcher = {
275
+ { 0, make_cudacall(compute_aabb_kernel<PrimTransfSRT>) }
276
+ };
277
+
278
+ auto iter = dispatcher.find(min(0, algorithm));
279
+ if (iter != dispatcher.end()) {
280
+ (iter->second)(
281
+ dim3(gridsize), dim3(blocksize), stream,
282
+ N, K,
283
+ primtransf_data,
284
+ sortedobjid,
285
+ reinterpret_cast<int2 *>(nodechildren),
286
+ nodeparent,
287
+ reinterpret_cast<float3 *>(nodeaabb),
288
+ atom);
289
+ }
290
+
291
+ cudaFree(atom);
292
+ }
dva/mvp/extensions/mvpraymarch/cudadispatch.h ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #ifndef cudadispatch_h_
8
+ #define cudadispatch_h_
9
+
10
+ #include <functional>
11
+ #include <memory>
12
+ #include <type_traits>
13
+
14
+ template<typename T, typename = void>
15
+ struct get_base {
16
+ typedef T type;
17
+ };
18
+
19
+ template<typename T>
20
+ struct get_base<T, typename std::enable_if<std::is_base_of<typename T::base, T>::value>::type> {
21
+ typedef std::shared_ptr<typename T::base> type;
22
+ };
23
+
24
+ template<typename T> struct is_shared_ptr : std::false_type {};
25
+ template<typename T> struct is_shared_ptr<std::shared_ptr<T>> : std::true_type {};
26
+
27
+ template<typename OutT, typename T>
28
+ auto convert_shptr_impl2(std::shared_ptr<T> t) {
29
+ return *static_cast<OutT*>(t.get());
30
+ }
31
+
32
+ template<typename OutT, typename T>
33
+ auto convert_shptr_impl(T&& t, std::false_type) {
34
+ return convert_shptr_impl2<OutT>(t);
35
+ }
36
+
37
+ template<typename OutT, typename T>
38
+ auto convert_shptr_impl(T&& t, std::true_type) {
39
+ return std::forward<T>(t);
40
+ }
41
+
42
+ template<typename OutT, typename T>
43
+ auto convert_shptr(T&& t) {
44
+ return convert_shptr_impl<OutT>(std::forward<T>(t), std::is_same<OutT, T>{});
45
+ }
46
+
47
+ template<typename... ArgsIn>
48
+ struct cudacall {
49
+ struct functbase {
50
+ virtual ~functbase() {}
51
+ virtual void call(dim3, dim3, cudaStream_t, ArgsIn...) const = 0;
52
+ };
53
+
54
+ template<typename... ArgsOut>
55
+ struct funct : public functbase {
56
+ std::function<void(ArgsOut...)> fn;
57
+ funct(void(*fn_)(ArgsOut...)) : fn(fn_) { }
58
+ void call(dim3 gridsize, dim3 blocksize, cudaStream_t stream, ArgsIn... args) const {
59
+ void (*const*kfunc)(ArgsOut...) = fn.template target<void (*)(ArgsOut...)>();
60
+ (*kfunc)<<<gridsize, blocksize, 0, stream>>>(
61
+ std::forward<ArgsOut>(convert_shptr<ArgsOut>(std::forward<ArgsIn>(args)))...);
62
+ }
63
+ };
64
+
65
+ std::shared_ptr<functbase> fn;
66
+
67
+ template<typename... ArgsOut>
68
+ cudacall(void(*fn_)(ArgsOut...)) : fn(std::make_shared<funct<ArgsOut...>>(fn_)) { }
69
+
70
+ template<typename... ArgsTmp>
71
+ void call(dim3 gridsize, dim3 blocksize, cudaStream_t stream, ArgsTmp&&... args) const {
72
+ fn->call(gridsize, blocksize, stream, std::forward<ArgsIn>(args)...);
73
+ }
74
+ };
75
+
76
+ template <typename F, typename T>
77
+ struct binder {
78
+ F f; T t;
79
+ template <typename... Args>
80
+ auto operator()(Args&&... args) const
81
+ -> decltype(f(t, std::forward<Args>(args)...)) {
82
+ return f(t, std::forward<Args>(args)...);
83
+ }
84
+ };
85
+
86
+ template <typename F, typename T>
87
+ binder<typename std::decay<F>::type
88
+ , typename std::decay<T>::type> BindFirst(F&& f, T&& t) {
89
+ return { std::forward<F>(f), std::forward<T>(t) };
90
+ }
91
+
92
+ template<typename... ArgsOut>
93
+ auto make_cudacall_(void(*fn)(ArgsOut...)) {
94
+ return BindFirst(
95
+ std::mem_fn(&cudacall<typename get_base<ArgsOut>::type...>::template call<typename get_base<ArgsOut>::type...>),
96
+ cudacall<typename get_base<ArgsOut>::type...>(fn));
97
+ }
98
+
99
+ template<typename... ArgsOut>
100
+ std::function<void(dim3, dim3, cudaStream_t, typename get_base<ArgsOut>::type...)> make_cudacall(void(*fn)(ArgsOut...)) {
101
+ return std::function<void(dim3, dim3, cudaStream_t, typename get_base<ArgsOut>::type...)>(make_cudacall_(fn));
102
+ }
103
+
104
+ #endif
dva/mvp/extensions/mvpraymarch/helper_math.h ADDED
@@ -0,0 +1,1453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Copyright 1993-2013 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * Please refer to the NVIDIA end user license agreement (EULA) associated
5
+ * with this source code for terms and conditions that govern your use of
6
+ * this software. Any use, reproduction, disclosure, or distribution of
7
+ * this software and related documentation outside the terms of the EULA
8
+ * is strictly prohibited.
9
+ *
10
+ */
11
+
12
+ /*
13
+ * This file implements common mathematical operations on vector types
14
+ * (float3, float4 etc.) since these are not provided as standard by CUDA.
15
+ *
16
+ * The syntax is modeled on the Cg standard library.
17
+ *
18
+ * This is part of the Helper library includes
19
+ *
20
+ * Thanks to Linh Hah for additions and fixes.
21
+ */
22
+
23
+ #ifndef HELPER_MATH_H
24
+ #define HELPER_MATH_H
25
+
26
+ #include "cuda_runtime.h"
27
+
28
+ typedef unsigned int uint;
29
+ typedef unsigned short ushort;
30
+
31
+ #ifndef EXIT_WAIVED
32
+ #define EXIT_WAIVED 2
33
+ #endif
34
+
35
+ #ifndef __CUDACC__
36
+ #include <math.h>
37
+
38
+ ////////////////////////////////////////////////////////////////////////////////
39
+ // host implementations of CUDA functions
40
+ ////////////////////////////////////////////////////////////////////////////////
41
+
42
+ inline float fminf(float a, float b)
43
+ {
44
+ return a < b ? a : b;
45
+ }
46
+
47
+ inline float fmaxf(float a, float b)
48
+ {
49
+ return a > b ? a : b;
50
+ }
51
+
52
+ inline int max(int a, int b)
53
+ {
54
+ return a > b ? a : b;
55
+ }
56
+
57
+ inline int min(int a, int b)
58
+ {
59
+ return a < b ? a : b;
60
+ }
61
+
62
+ inline float rsqrtf(float x)
63
+ {
64
+ return 1.0f / sqrtf(x);
65
+ }
66
+ #endif
67
+
68
+ ////////////////////////////////////////////////////////////////////////////////
69
+ // constructors
70
+ ////////////////////////////////////////////////////////////////////////////////
71
+
72
+ inline __host__ __device__ float2 make_float2(float s)
73
+ {
74
+ return make_float2(s, s);
75
+ }
76
+ inline __host__ __device__ float2 make_float2(float3 a)
77
+ {
78
+ return make_float2(a.x, a.y);
79
+ }
80
+ inline __host__ __device__ float2 make_float2(int2 a)
81
+ {
82
+ return make_float2(float(a.x), float(a.y));
83
+ }
84
+ inline __host__ __device__ float2 make_float2(uint2 a)
85
+ {
86
+ return make_float2(float(a.x), float(a.y));
87
+ }
88
+
89
+ inline __host__ __device__ int2 make_int2(int s)
90
+ {
91
+ return make_int2(s, s);
92
+ }
93
+ inline __host__ __device__ int2 make_int2(int3 a)
94
+ {
95
+ return make_int2(a.x, a.y);
96
+ }
97
+ inline __host__ __device__ int2 make_int2(uint2 a)
98
+ {
99
+ return make_int2(int(a.x), int(a.y));
100
+ }
101
+ inline __host__ __device__ int2 make_int2(float2 a)
102
+ {
103
+ return make_int2(int(a.x), int(a.y));
104
+ }
105
+
106
+ inline __host__ __device__ uint2 make_uint2(uint s)
107
+ {
108
+ return make_uint2(s, s);
109
+ }
110
+ inline __host__ __device__ uint2 make_uint2(uint3 a)
111
+ {
112
+ return make_uint2(a.x, a.y);
113
+ }
114
+ inline __host__ __device__ uint2 make_uint2(int2 a)
115
+ {
116
+ return make_uint2(uint(a.x), uint(a.y));
117
+ }
118
+
119
+ inline __host__ __device__ float3 make_float3(float s)
120
+ {
121
+ return make_float3(s, s, s);
122
+ }
123
+ inline __host__ __device__ float3 make_float3(float2 a)
124
+ {
125
+ return make_float3(a.x, a.y, 0.0f);
126
+ }
127
+ inline __host__ __device__ float3 make_float3(float2 a, float s)
128
+ {
129
+ return make_float3(a.x, a.y, s);
130
+ }
131
+ inline __host__ __device__ float3 make_float3(float4 a)
132
+ {
133
+ return make_float3(a.x, a.y, a.z);
134
+ }
135
+ inline __host__ __device__ float3 make_float3(int3 a)
136
+ {
137
+ return make_float3(float(a.x), float(a.y), float(a.z));
138
+ }
139
+ inline __host__ __device__ float3 make_float3(uint3 a)
140
+ {
141
+ return make_float3(float(a.x), float(a.y), float(a.z));
142
+ }
143
+
144
+ inline __host__ __device__ int3 make_int3(int s)
145
+ {
146
+ return make_int3(s, s, s);
147
+ }
148
+ inline __host__ __device__ int3 make_int3(int2 a)
149
+ {
150
+ return make_int3(a.x, a.y, 0);
151
+ }
152
+ inline __host__ __device__ int3 make_int3(int2 a, int s)
153
+ {
154
+ return make_int3(a.x, a.y, s);
155
+ }
156
+ inline __host__ __device__ int3 make_int3(uint3 a)
157
+ {
158
+ return make_int3(int(a.x), int(a.y), int(a.z));
159
+ }
160
+ inline __host__ __device__ int3 make_int3(float3 a)
161
+ {
162
+ return make_int3(int(a.x), int(a.y), int(a.z));
163
+ }
164
+
165
+ inline __host__ __device__ uint3 make_uint3(uint s)
166
+ {
167
+ return make_uint3(s, s, s);
168
+ }
169
+ inline __host__ __device__ uint3 make_uint3(uint2 a)
170
+ {
171
+ return make_uint3(a.x, a.y, 0);
172
+ }
173
+ inline __host__ __device__ uint3 make_uint3(uint2 a, uint s)
174
+ {
175
+ return make_uint3(a.x, a.y, s);
176
+ }
177
+ inline __host__ __device__ uint3 make_uint3(uint4 a)
178
+ {
179
+ return make_uint3(a.x, a.y, a.z);
180
+ }
181
+ inline __host__ __device__ uint3 make_uint3(int3 a)
182
+ {
183
+ return make_uint3(uint(a.x), uint(a.y), uint(a.z));
184
+ }
185
+
186
+ inline __host__ __device__ float4 make_float4(float s)
187
+ {
188
+ return make_float4(s, s, s, s);
189
+ }
190
+ inline __host__ __device__ float4 make_float4(float3 a)
191
+ {
192
+ return make_float4(a.x, a.y, a.z, 0.0f);
193
+ }
194
+ inline __host__ __device__ float4 make_float4(float3 a, float w)
195
+ {
196
+ return make_float4(a.x, a.y, a.z, w);
197
+ }
198
+ inline __host__ __device__ float4 make_float4(int4 a)
199
+ {
200
+ return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
201
+ }
202
+ inline __host__ __device__ float4 make_float4(uint4 a)
203
+ {
204
+ return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
205
+ }
206
+
207
+ inline __host__ __device__ int4 make_int4(int s)
208
+ {
209
+ return make_int4(s, s, s, s);
210
+ }
211
+ inline __host__ __device__ int4 make_int4(int3 a)
212
+ {
213
+ return make_int4(a.x, a.y, a.z, 0);
214
+ }
215
+ inline __host__ __device__ int4 make_int4(int3 a, int w)
216
+ {
217
+ return make_int4(a.x, a.y, a.z, w);
218
+ }
219
+ inline __host__ __device__ int4 make_int4(uint4 a)
220
+ {
221
+ return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
222
+ }
223
+ inline __host__ __device__ int4 make_int4(float4 a)
224
+ {
225
+ return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
226
+ }
227
+
228
+
229
+ inline __host__ __device__ uint4 make_uint4(uint s)
230
+ {
231
+ return make_uint4(s, s, s, s);
232
+ }
233
+ inline __host__ __device__ uint4 make_uint4(uint3 a)
234
+ {
235
+ return make_uint4(a.x, a.y, a.z, 0);
236
+ }
237
+ inline __host__ __device__ uint4 make_uint4(uint3 a, uint w)
238
+ {
239
+ return make_uint4(a.x, a.y, a.z, w);
240
+ }
241
+ inline __host__ __device__ uint4 make_uint4(int4 a)
242
+ {
243
+ return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w));
244
+ }
245
+
246
+ ////////////////////////////////////////////////////////////////////////////////
247
+ // negate
248
+ ////////////////////////////////////////////////////////////////////////////////
249
+
250
+ inline __host__ __device__ float2 operator-(float2 &a)
251
+ {
252
+ return make_float2(-a.x, -a.y);
253
+ }
254
+ inline __host__ __device__ int2 operator-(int2 &a)
255
+ {
256
+ return make_int2(-a.x, -a.y);
257
+ }
258
+ inline __host__ __device__ float3 operator-(float3 &a)
259
+ {
260
+ return make_float3(-a.x, -a.y, -a.z);
261
+ }
262
+ inline __host__ __device__ int3 operator-(int3 &a)
263
+ {
264
+ return make_int3(-a.x, -a.y, -a.z);
265
+ }
266
+ inline __host__ __device__ float4 operator-(float4 &a)
267
+ {
268
+ return make_float4(-a.x, -a.y, -a.z, -a.w);
269
+ }
270
+ inline __host__ __device__ int4 operator-(int4 &a)
271
+ {
272
+ return make_int4(-a.x, -a.y, -a.z, -a.w);
273
+ }
274
+
275
+ ////////////////////////////////////////////////////////////////////////////////
276
+ // addition
277
+ ////////////////////////////////////////////////////////////////////////////////
278
+
279
+ inline __host__ __device__ float2 operator+(float2 a, float2 b)
280
+ {
281
+ return make_float2(a.x + b.x, a.y + b.y);
282
+ }
283
+ inline __host__ __device__ void operator+=(float2 &a, float2 b)
284
+ {
285
+ a.x += b.x;
286
+ a.y += b.y;
287
+ }
288
+ inline __host__ __device__ float2 operator+(float2 a, float b)
289
+ {
290
+ return make_float2(a.x + b, a.y + b);
291
+ }
292
+ inline __host__ __device__ float2 operator+(float b, float2 a)
293
+ {
294
+ return make_float2(a.x + b, a.y + b);
295
+ }
296
+ inline __host__ __device__ void operator+=(float2 &a, float b)
297
+ {
298
+ a.x += b;
299
+ a.y += b;
300
+ }
301
+
302
+ inline __host__ __device__ int2 operator+(int2 a, int2 b)
303
+ {
304
+ return make_int2(a.x + b.x, a.y + b.y);
305
+ }
306
+ inline __host__ __device__ void operator+=(int2 &a, int2 b)
307
+ {
308
+ a.x += b.x;
309
+ a.y += b.y;
310
+ }
311
+ inline __host__ __device__ int2 operator+(int2 a, int b)
312
+ {
313
+ return make_int2(a.x + b, a.y + b);
314
+ }
315
+ inline __host__ __device__ int2 operator+(int b, int2 a)
316
+ {
317
+ return make_int2(a.x + b, a.y + b);
318
+ }
319
+ inline __host__ __device__ void operator+=(int2 &a, int b)
320
+ {
321
+ a.x += b;
322
+ a.y += b;
323
+ }
324
+
325
+ inline __host__ __device__ uint2 operator+(uint2 a, uint2 b)
326
+ {
327
+ return make_uint2(a.x + b.x, a.y + b.y);
328
+ }
329
+ inline __host__ __device__ void operator+=(uint2 &a, uint2 b)
330
+ {
331
+ a.x += b.x;
332
+ a.y += b.y;
333
+ }
334
+ inline __host__ __device__ uint2 operator+(uint2 a, uint b)
335
+ {
336
+ return make_uint2(a.x + b, a.y + b);
337
+ }
338
+ inline __host__ __device__ uint2 operator+(uint b, uint2 a)
339
+ {
340
+ return make_uint2(a.x + b, a.y + b);
341
+ }
342
+ inline __host__ __device__ void operator+=(uint2 &a, uint b)
343
+ {
344
+ a.x += b;
345
+ a.y += b;
346
+ }
347
+
348
+
349
+ inline __host__ __device__ float3 operator+(float3 a, float3 b)
350
+ {
351
+ return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
352
+ }
353
+ inline __host__ __device__ void operator+=(float3 &a, float3 b)
354
+ {
355
+ a.x += b.x;
356
+ a.y += b.y;
357
+ a.z += b.z;
358
+ }
359
+ inline __host__ __device__ float3 operator+(float3 a, float b)
360
+ {
361
+ return make_float3(a.x + b, a.y + b, a.z + b);
362
+ }
363
+ inline __host__ __device__ void operator+=(float3 &a, float b)
364
+ {
365
+ a.x += b;
366
+ a.y += b;
367
+ a.z += b;
368
+ }
369
+
370
+ inline __host__ __device__ int3 operator+(int3 a, int3 b)
371
+ {
372
+ return make_int3(a.x + b.x, a.y + b.y, a.z + b.z);
373
+ }
374
+ inline __host__ __device__ void operator+=(int3 &a, int3 b)
375
+ {
376
+ a.x += b.x;
377
+ a.y += b.y;
378
+ a.z += b.z;
379
+ }
380
+ inline __host__ __device__ int3 operator+(int3 a, int b)
381
+ {
382
+ return make_int3(a.x + b, a.y + b, a.z + b);
383
+ }
384
+ inline __host__ __device__ void operator+=(int3 &a, int b)
385
+ {
386
+ a.x += b;
387
+ a.y += b;
388
+ a.z += b;
389
+ }
390
+
391
+ inline __host__ __device__ uint3 operator+(uint3 a, uint3 b)
392
+ {
393
+ return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z);
394
+ }
395
+ inline __host__ __device__ void operator+=(uint3 &a, uint3 b)
396
+ {
397
+ a.x += b.x;
398
+ a.y += b.y;
399
+ a.z += b.z;
400
+ }
401
+ inline __host__ __device__ uint3 operator+(uint3 a, uint b)
402
+ {
403
+ return make_uint3(a.x + b, a.y + b, a.z + b);
404
+ }
405
+ inline __host__ __device__ void operator+=(uint3 &a, uint b)
406
+ {
407
+ a.x += b;
408
+ a.y += b;
409
+ a.z += b;
410
+ }
411
+
412
+ inline __host__ __device__ int3 operator+(int b, int3 a)
413
+ {
414
+ return make_int3(a.x + b, a.y + b, a.z + b);
415
+ }
416
+ inline __host__ __device__ uint3 operator+(uint b, uint3 a)
417
+ {
418
+ return make_uint3(a.x + b, a.y + b, a.z + b);
419
+ }
420
+ inline __host__ __device__ float3 operator+(float b, float3 a)
421
+ {
422
+ return make_float3(a.x + b, a.y + b, a.z + b);
423
+ }
424
+
425
+ inline __host__ __device__ float4 operator+(float4 a, float4 b)
426
+ {
427
+ return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
428
+ }
429
+ inline __host__ __device__ void operator+=(float4 &a, float4 b)
430
+ {
431
+ a.x += b.x;
432
+ a.y += b.y;
433
+ a.z += b.z;
434
+ a.w += b.w;
435
+ }
436
+ inline __host__ __device__ float4 operator+(float4 a, float b)
437
+ {
438
+ return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
439
+ }
440
+ inline __host__ __device__ float4 operator+(float b, float4 a)
441
+ {
442
+ return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
443
+ }
444
+ inline __host__ __device__ void operator+=(float4 &a, float b)
445
+ {
446
+ a.x += b;
447
+ a.y += b;
448
+ a.z += b;
449
+ a.w += b;
450
+ }
451
+
452
+ inline __host__ __device__ int4 operator+(int4 a, int4 b)
453
+ {
454
+ return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
455
+ }
456
+ inline __host__ __device__ void operator+=(int4 &a, int4 b)
457
+ {
458
+ a.x += b.x;
459
+ a.y += b.y;
460
+ a.z += b.z;
461
+ a.w += b.w;
462
+ }
463
+ inline __host__ __device__ int4 operator+(int4 a, int b)
464
+ {
465
+ return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
466
+ }
467
+ inline __host__ __device__ int4 operator+(int b, int4 a)
468
+ {
469
+ return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
470
+ }
471
+ inline __host__ __device__ void operator+=(int4 &a, int b)
472
+ {
473
+ a.x += b;
474
+ a.y += b;
475
+ a.z += b;
476
+ a.w += b;
477
+ }
478
+
479
+ inline __host__ __device__ uint4 operator+(uint4 a, uint4 b)
480
+ {
481
+ return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
482
+ }
483
+ inline __host__ __device__ void operator+=(uint4 &a, uint4 b)
484
+ {
485
+ a.x += b.x;
486
+ a.y += b.y;
487
+ a.z += b.z;
488
+ a.w += b.w;
489
+ }
490
+ inline __host__ __device__ uint4 operator+(uint4 a, uint b)
491
+ {
492
+ return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
493
+ }
494
+ inline __host__ __device__ uint4 operator+(uint b, uint4 a)
495
+ {
496
+ return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
497
+ }
498
+ inline __host__ __device__ void operator+=(uint4 &a, uint b)
499
+ {
500
+ a.x += b;
501
+ a.y += b;
502
+ a.z += b;
503
+ a.w += b;
504
+ }
505
+
506
+ ////////////////////////////////////////////////////////////////////////////////
507
+ // subtract
508
+ ////////////////////////////////////////////////////////////////////////////////
509
+
510
+ inline __host__ __device__ float2 operator-(float2 a, float2 b)
511
+ {
512
+ return make_float2(a.x - b.x, a.y - b.y);
513
+ }
514
+ inline __host__ __device__ void operator-=(float2 &a, float2 b)
515
+ {
516
+ a.x -= b.x;
517
+ a.y -= b.y;
518
+ }
519
+ inline __host__ __device__ float2 operator-(float2 a, float b)
520
+ {
521
+ return make_float2(a.x - b, a.y - b);
522
+ }
523
+ inline __host__ __device__ float2 operator-(float b, float2 a)
524
+ {
525
+ return make_float2(b - a.x, b - a.y);
526
+ }
527
+ inline __host__ __device__ void operator-=(float2 &a, float b)
528
+ {
529
+ a.x -= b;
530
+ a.y -= b;
531
+ }
532
+
533
+ inline __host__ __device__ int2 operator-(int2 a, int2 b)
534
+ {
535
+ return make_int2(a.x - b.x, a.y - b.y);
536
+ }
537
+ inline __host__ __device__ void operator-=(int2 &a, int2 b)
538
+ {
539
+ a.x -= b.x;
540
+ a.y -= b.y;
541
+ }
542
+ inline __host__ __device__ int2 operator-(int2 a, int b)
543
+ {
544
+ return make_int2(a.x - b, a.y - b);
545
+ }
546
+ inline __host__ __device__ int2 operator-(int b, int2 a)
547
+ {
548
+ return make_int2(b - a.x, b - a.y);
549
+ }
550
+ inline __host__ __device__ void operator-=(int2 &a, int b)
551
+ {
552
+ a.x -= b;
553
+ a.y -= b;
554
+ }
555
+
556
+ inline __host__ __device__ uint2 operator-(uint2 a, uint2 b)
557
+ {
558
+ return make_uint2(a.x - b.x, a.y - b.y);
559
+ }
560
+ inline __host__ __device__ void operator-=(uint2 &a, uint2 b)
561
+ {
562
+ a.x -= b.x;
563
+ a.y -= b.y;
564
+ }
565
+ inline __host__ __device__ uint2 operator-(uint2 a, uint b)
566
+ {
567
+ return make_uint2(a.x - b, a.y - b);
568
+ }
569
+ inline __host__ __device__ uint2 operator-(uint b, uint2 a)
570
+ {
571
+ return make_uint2(b - a.x, b - a.y);
572
+ }
573
+ inline __host__ __device__ void operator-=(uint2 &a, uint b)
574
+ {
575
+ a.x -= b;
576
+ a.y -= b;
577
+ }
578
+
579
+ inline __host__ __device__ float3 operator-(float3 a, float3 b)
580
+ {
581
+ return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
582
+ }
583
+ inline __host__ __device__ void operator-=(float3 &a, float3 b)
584
+ {
585
+ a.x -= b.x;
586
+ a.y -= b.y;
587
+ a.z -= b.z;
588
+ }
589
+ inline __host__ __device__ float3 operator-(float3 a, float b)
590
+ {
591
+ return make_float3(a.x - b, a.y - b, a.z - b);
592
+ }
593
+ inline __host__ __device__ float3 operator-(float b, float3 a)
594
+ {
595
+ return make_float3(b - a.x, b - a.y, b - a.z);
596
+ }
597
+ inline __host__ __device__ void operator-=(float3 &a, float b)
598
+ {
599
+ a.x -= b;
600
+ a.y -= b;
601
+ a.z -= b;
602
+ }
603
+
604
+ inline __host__ __device__ int3 operator-(int3 a, int3 b)
605
+ {
606
+ return make_int3(a.x - b.x, a.y - b.y, a.z - b.z);
607
+ }
608
+ inline __host__ __device__ void operator-=(int3 &a, int3 b)
609
+ {
610
+ a.x -= b.x;
611
+ a.y -= b.y;
612
+ a.z -= b.z;
613
+ }
614
+ inline __host__ __device__ int3 operator-(int3 a, int b)
615
+ {
616
+ return make_int3(a.x - b, a.y - b, a.z - b);
617
+ }
618
+ inline __host__ __device__ int3 operator-(int b, int3 a)
619
+ {
620
+ return make_int3(b - a.x, b - a.y, b - a.z);
621
+ }
622
+ inline __host__ __device__ void operator-=(int3 &a, int b)
623
+ {
624
+ a.x -= b;
625
+ a.y -= b;
626
+ a.z -= b;
627
+ }
628
+
629
+ inline __host__ __device__ uint3 operator-(uint3 a, uint3 b)
630
+ {
631
+ return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z);
632
+ }
633
+ inline __host__ __device__ void operator-=(uint3 &a, uint3 b)
634
+ {
635
+ a.x -= b.x;
636
+ a.y -= b.y;
637
+ a.z -= b.z;
638
+ }
639
+ inline __host__ __device__ uint3 operator-(uint3 a, uint b)
640
+ {
641
+ return make_uint3(a.x - b, a.y - b, a.z - b);
642
+ }
643
+ inline __host__ __device__ uint3 operator-(uint b, uint3 a)
644
+ {
645
+ return make_uint3(b - a.x, b - a.y, b - a.z);
646
+ }
647
+ inline __host__ __device__ void operator-=(uint3 &a, uint b)
648
+ {
649
+ a.x -= b;
650
+ a.y -= b;
651
+ a.z -= b;
652
+ }
653
+
654
+ inline __host__ __device__ float4 operator-(float4 a, float4 b)
655
+ {
656
+ return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
657
+ }
658
+ inline __host__ __device__ void operator-=(float4 &a, float4 b)
659
+ {
660
+ a.x -= b.x;
661
+ a.y -= b.y;
662
+ a.z -= b.z;
663
+ a.w -= b.w;
664
+ }
665
+ inline __host__ __device__ float4 operator-(float4 a, float b)
666
+ {
667
+ return make_float4(a.x - b, a.y - b, a.z - b, a.w - b);
668
+ }
669
+ inline __host__ __device__ void operator-=(float4 &a, float b)
670
+ {
671
+ a.x -= b;
672
+ a.y -= b;
673
+ a.z -= b;
674
+ a.w -= b;
675
+ }
676
+
677
+ inline __host__ __device__ int4 operator-(int4 a, int4 b)
678
+ {
679
+ return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
680
+ }
681
+ inline __host__ __device__ void operator-=(int4 &a, int4 b)
682
+ {
683
+ a.x -= b.x;
684
+ a.y -= b.y;
685
+ a.z -= b.z;
686
+ a.w -= b.w;
687
+ }
688
+ inline __host__ __device__ int4 operator-(int4 a, int b)
689
+ {
690
+ return make_int4(a.x - b, a.y - b, a.z - b, a.w - b);
691
+ }
692
+ inline __host__ __device__ int4 operator-(int b, int4 a)
693
+ {
694
+ return make_int4(b - a.x, b - a.y, b - a.z, b - a.w);
695
+ }
696
+ inline __host__ __device__ void operator-=(int4 &a, int b)
697
+ {
698
+ a.x -= b;
699
+ a.y -= b;
700
+ a.z -= b;
701
+ a.w -= b;
702
+ }
703
+
704
+ inline __host__ __device__ uint4 operator-(uint4 a, uint4 b)
705
+ {
706
+ return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
707
+ }
708
+ inline __host__ __device__ void operator-=(uint4 &a, uint4 b)
709
+ {
710
+ a.x -= b.x;
711
+ a.y -= b.y;
712
+ a.z -= b.z;
713
+ a.w -= b.w;
714
+ }
715
+ inline __host__ __device__ uint4 operator-(uint4 a, uint b)
716
+ {
717
+ return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b);
718
+ }
719
+ inline __host__ __device__ uint4 operator-(uint b, uint4 a)
720
+ {
721
+ return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w);
722
+ }
723
+ inline __host__ __device__ void operator-=(uint4 &a, uint b)
724
+ {
725
+ a.x -= b;
726
+ a.y -= b;
727
+ a.z -= b;
728
+ a.w -= b;
729
+ }
730
+
731
+ ////////////////////////////////////////////////////////////////////////////////
732
+ // multiply
733
+ ////////////////////////////////////////////////////////////////////////////////
734
+
735
+ inline __host__ __device__ float2 operator*(float2 a, float2 b)
736
+ {
737
+ return make_float2(a.x * b.x, a.y * b.y);
738
+ }
739
+ inline __host__ __device__ void operator*=(float2 &a, float2 b)
740
+ {
741
+ a.x *= b.x;
742
+ a.y *= b.y;
743
+ }
744
+ inline __host__ __device__ float2 operator*(float2 a, float b)
745
+ {
746
+ return make_float2(a.x * b, a.y * b);
747
+ }
748
+ inline __host__ __device__ float2 operator*(float b, float2 a)
749
+ {
750
+ return make_float2(b * a.x, b * a.y);
751
+ }
752
+ inline __host__ __device__ void operator*=(float2 &a, float b)
753
+ {
754
+ a.x *= b;
755
+ a.y *= b;
756
+ }
757
+
758
+ inline __host__ __device__ int2 operator*(int2 a, int2 b)
759
+ {
760
+ return make_int2(a.x * b.x, a.y * b.y);
761
+ }
762
+ inline __host__ __device__ void operator*=(int2 &a, int2 b)
763
+ {
764
+ a.x *= b.x;
765
+ a.y *= b.y;
766
+ }
767
+ inline __host__ __device__ int2 operator*(int2 a, int b)
768
+ {
769
+ return make_int2(a.x * b, a.y * b);
770
+ }
771
+ inline __host__ __device__ int2 operator*(int b, int2 a)
772
+ {
773
+ return make_int2(b * a.x, b * a.y);
774
+ }
775
+ inline __host__ __device__ void operator*=(int2 &a, int b)
776
+ {
777
+ a.x *= b;
778
+ a.y *= b;
779
+ }
780
+
781
+ inline __host__ __device__ uint2 operator*(uint2 a, uint2 b)
782
+ {
783
+ return make_uint2(a.x * b.x, a.y * b.y);
784
+ }
785
+ inline __host__ __device__ void operator*=(uint2 &a, uint2 b)
786
+ {
787
+ a.x *= b.x;
788
+ a.y *= b.y;
789
+ }
790
+ inline __host__ __device__ uint2 operator*(uint2 a, uint b)
791
+ {
792
+ return make_uint2(a.x * b, a.y * b);
793
+ }
794
+ inline __host__ __device__ uint2 operator*(uint b, uint2 a)
795
+ {
796
+ return make_uint2(b * a.x, b * a.y);
797
+ }
798
+ inline __host__ __device__ void operator*=(uint2 &a, uint b)
799
+ {
800
+ a.x *= b;
801
+ a.y *= b;
802
+ }
803
+
804
+ inline __host__ __device__ float3 operator*(float3 a, float3 b)
805
+ {
806
+ return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
807
+ }
808
+ inline __host__ __device__ void operator*=(float3 &a, float3 b)
809
+ {
810
+ a.x *= b.x;
811
+ a.y *= b.y;
812
+ a.z *= b.z;
813
+ }
814
+ inline __host__ __device__ float3 operator*(float3 a, float b)
815
+ {
816
+ return make_float3(a.x * b, a.y * b, a.z * b);
817
+ }
818
+ inline __host__ __device__ float3 operator*(float b, float3 a)
819
+ {
820
+ return make_float3(b * a.x, b * a.y, b * a.z);
821
+ }
822
+ inline __host__ __device__ void operator*=(float3 &a, float b)
823
+ {
824
+ a.x *= b;
825
+ a.y *= b;
826
+ a.z *= b;
827
+ }
828
+
829
+ inline __host__ __device__ int3 operator*(int3 a, int3 b)
830
+ {
831
+ return make_int3(a.x * b.x, a.y * b.y, a.z * b.z);
832
+ }
833
+ inline __host__ __device__ void operator*=(int3 &a, int3 b)
834
+ {
835
+ a.x *= b.x;
836
+ a.y *= b.y;
837
+ a.z *= b.z;
838
+ }
839
+ inline __host__ __device__ int3 operator*(int3 a, int b)
840
+ {
841
+ return make_int3(a.x * b, a.y * b, a.z * b);
842
+ }
843
+ inline __host__ __device__ int3 operator*(int b, int3 a)
844
+ {
845
+ return make_int3(b * a.x, b * a.y, b * a.z);
846
+ }
847
+ inline __host__ __device__ void operator*=(int3 &a, int b)
848
+ {
849
+ a.x *= b;
850
+ a.y *= b;
851
+ a.z *= b;
852
+ }
853
+
854
+ inline __host__ __device__ uint3 operator*(uint3 a, uint3 b)
855
+ {
856
+ return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z);
857
+ }
858
+ inline __host__ __device__ void operator*=(uint3 &a, uint3 b)
859
+ {
860
+ a.x *= b.x;
861
+ a.y *= b.y;
862
+ a.z *= b.z;
863
+ }
864
+ inline __host__ __device__ uint3 operator*(uint3 a, uint b)
865
+ {
866
+ return make_uint3(a.x * b, a.y * b, a.z * b);
867
+ }
868
+ inline __host__ __device__ uint3 operator*(uint b, uint3 a)
869
+ {
870
+ return make_uint3(b * a.x, b * a.y, b * a.z);
871
+ }
872
+ inline __host__ __device__ void operator*=(uint3 &a, uint b)
873
+ {
874
+ a.x *= b;
875
+ a.y *= b;
876
+ a.z *= b;
877
+ }
878
+
879
+ inline __host__ __device__ float4 operator*(float4 a, float4 b)
880
+ {
881
+ return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
882
+ }
883
+ inline __host__ __device__ void operator*=(float4 &a, float4 b)
884
+ {
885
+ a.x *= b.x;
886
+ a.y *= b.y;
887
+ a.z *= b.z;
888
+ a.w *= b.w;
889
+ }
890
+ inline __host__ __device__ float4 operator*(float4 a, float b)
891
+ {
892
+ return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
893
+ }
894
+ inline __host__ __device__ float4 operator*(float b, float4 a)
895
+ {
896
+ return make_float4(b * a.x, b * a.y, b * a.z, b * a.w);
897
+ }
898
+ inline __host__ __device__ void operator*=(float4 &a, float b)
899
+ {
900
+ a.x *= b;
901
+ a.y *= b;
902
+ a.z *= b;
903
+ a.w *= b;
904
+ }
905
+
906
+ inline __host__ __device__ int4 operator*(int4 a, int4 b)
907
+ {
908
+ return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
909
+ }
910
+ inline __host__ __device__ void operator*=(int4 &a, int4 b)
911
+ {
912
+ a.x *= b.x;
913
+ a.y *= b.y;
914
+ a.z *= b.z;
915
+ a.w *= b.w;
916
+ }
917
+ inline __host__ __device__ int4 operator*(int4 a, int b)
918
+ {
919
+ return make_int4(a.x * b, a.y * b, a.z * b, a.w * b);
920
+ }
921
+ inline __host__ __device__ int4 operator*(int b, int4 a)
922
+ {
923
+ return make_int4(b * a.x, b * a.y, b * a.z, b * a.w);
924
+ }
925
+ inline __host__ __device__ void operator*=(int4 &a, int b)
926
+ {
927
+ a.x *= b;
928
+ a.y *= b;
929
+ a.z *= b;
930
+ a.w *= b;
931
+ }
932
+
933
+ inline __host__ __device__ uint4 operator*(uint4 a, uint4 b)
934
+ {
935
+ return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
936
+ }
937
+ inline __host__ __device__ void operator*=(uint4 &a, uint4 b)
938
+ {
939
+ a.x *= b.x;
940
+ a.y *= b.y;
941
+ a.z *= b.z;
942
+ a.w *= b.w;
943
+ }
944
+ inline __host__ __device__ uint4 operator*(uint4 a, uint b)
945
+ {
946
+ return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b);
947
+ }
948
+ inline __host__ __device__ uint4 operator*(uint b, uint4 a)
949
+ {
950
+ return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w);
951
+ }
952
+ inline __host__ __device__ void operator*=(uint4 &a, uint b)
953
+ {
954
+ a.x *= b;
955
+ a.y *= b;
956
+ a.z *= b;
957
+ a.w *= b;
958
+ }
959
+
960
+ ////////////////////////////////////////////////////////////////////////////////
961
+ // divide
962
+ ////////////////////////////////////////////////////////////////////////////////
963
+
964
+ inline __host__ __device__ float2 operator/(float2 a, float2 b)
965
+ {
966
+ return make_float2(a.x / b.x, a.y / b.y);
967
+ }
968
+ inline __host__ __device__ void operator/=(float2 &a, float2 b)
969
+ {
970
+ a.x /= b.x;
971
+ a.y /= b.y;
972
+ }
973
+ inline __host__ __device__ float2 operator/(float2 a, float b)
974
+ {
975
+ return make_float2(a.x / b, a.y / b);
976
+ }
977
+ inline __host__ __device__ void operator/=(float2 &a, float b)
978
+ {
979
+ a.x /= b;
980
+ a.y /= b;
981
+ }
982
+ inline __host__ __device__ float2 operator/(float b, float2 a)
983
+ {
984
+ return make_float2(b / a.x, b / a.y);
985
+ }
986
+
987
+ inline __host__ __device__ float3 operator/(float3 a, float3 b)
988
+ {
989
+ return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
990
+ }
991
+ inline __host__ __device__ void operator/=(float3 &a, float3 b)
992
+ {
993
+ a.x /= b.x;
994
+ a.y /= b.y;
995
+ a.z /= b.z;
996
+ }
997
+ inline __host__ __device__ float3 operator/(float3 a, float b)
998
+ {
999
+ return make_float3(a.x / b, a.y / b, a.z / b);
1000
+ }
1001
+ inline __host__ __device__ void operator/=(float3 &a, float b)
1002
+ {
1003
+ a.x /= b;
1004
+ a.y /= b;
1005
+ a.z /= b;
1006
+ }
1007
+ inline __host__ __device__ float3 operator/(float b, float3 a)
1008
+ {
1009
+ return make_float3(b / a.x, b / a.y, b / a.z);
1010
+ }
1011
+
1012
+ inline __host__ __device__ float4 operator/(float4 a, float4 b)
1013
+ {
1014
+ return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
1015
+ }
1016
+ inline __host__ __device__ void operator/=(float4 &a, float4 b)
1017
+ {
1018
+ a.x /= b.x;
1019
+ a.y /= b.y;
1020
+ a.z /= b.z;
1021
+ a.w /= b.w;
1022
+ }
1023
+ inline __host__ __device__ float4 operator/(float4 a, float b)
1024
+ {
1025
+ return make_float4(a.x / b, a.y / b, a.z / b, a.w / b);
1026
+ }
1027
+ inline __host__ __device__ void operator/=(float4 &a, float b)
1028
+ {
1029
+ a.x /= b;
1030
+ a.y /= b;
1031
+ a.z /= b;
1032
+ a.w /= b;
1033
+ }
1034
+ inline __host__ __device__ float4 operator/(float b, float4 a)
1035
+ {
1036
+ return make_float4(b / a.x, b / a.y, b / a.z, b / a.w);
1037
+ }
1038
+
1039
+ ////////////////////////////////////////////////////////////////////////////////
1040
+ // min
1041
+ ////////////////////////////////////////////////////////////////////////////////
1042
+
1043
+ inline __host__ __device__ float2 fminf(float2 a, float2 b)
1044
+ {
1045
+ return make_float2(fminf(a.x,b.x), fminf(a.y,b.y));
1046
+ }
1047
+ inline __host__ __device__ float3 fminf(float3 a, float3 b)
1048
+ {
1049
+ return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z));
1050
+ }
1051
+ inline __host__ __device__ float4 fminf(float4 a, float4 b)
1052
+ {
1053
+ return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), fminf(a.w,b.w));
1054
+ }
1055
+
1056
+ inline __host__ __device__ int2 min(int2 a, int2 b)
1057
+ {
1058
+ return make_int2(min(a.x,b.x), min(a.y,b.y));
1059
+ }
1060
+ inline __host__ __device__ int3 min(int3 a, int3 b)
1061
+ {
1062
+ return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
1063
+ }
1064
+ inline __host__ __device__ int4 min(int4 a, int4 b)
1065
+ {
1066
+ return make_int4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
1067
+ }
1068
+
1069
+ inline __host__ __device__ uint2 min(uint2 a, uint2 b)
1070
+ {
1071
+ return make_uint2(min(a.x,b.x), min(a.y,b.y));
1072
+ }
1073
+ inline __host__ __device__ uint3 min(uint3 a, uint3 b)
1074
+ {
1075
+ return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
1076
+ }
1077
+ inline __host__ __device__ uint4 min(uint4 a, uint4 b)
1078
+ {
1079
+ return make_uint4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
1080
+ }
1081
+
1082
+ ////////////////////////////////////////////////////////////////////////////////
1083
+ // max
1084
+ ////////////////////////////////////////////////////////////////////////////////
1085
+
1086
+ inline __host__ __device__ float2 fmaxf(float2 a, float2 b)
1087
+ {
1088
+ return make_float2(fmaxf(a.x,b.x), fmaxf(a.y,b.y));
1089
+ }
1090
+ inline __host__ __device__ float3 fmaxf(float3 a, float3 b)
1091
+ {
1092
+ return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z));
1093
+ }
1094
+ inline __host__ __device__ float4 fmaxf(float4 a, float4 b)
1095
+ {
1096
+ return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w));
1097
+ }
1098
+
1099
+ inline __host__ __device__ int2 max(int2 a, int2 b)
1100
+ {
1101
+ return make_int2(max(a.x,b.x), max(a.y,b.y));
1102
+ }
1103
+ inline __host__ __device__ int3 max(int3 a, int3 b)
1104
+ {
1105
+ return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
1106
+ }
1107
+ inline __host__ __device__ int4 max(int4 a, int4 b)
1108
+ {
1109
+ return make_int4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
1110
+ }
1111
+
1112
+ inline __host__ __device__ uint2 max(uint2 a, uint2 b)
1113
+ {
1114
+ return make_uint2(max(a.x,b.x), max(a.y,b.y));
1115
+ }
1116
+ inline __host__ __device__ uint3 max(uint3 a, uint3 b)
1117
+ {
1118
+ return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
1119
+ }
1120
+ inline __host__ __device__ uint4 max(uint4 a, uint4 b)
1121
+ {
1122
+ return make_uint4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
1123
+ }
1124
+
1125
+ ////////////////////////////////////////////////////////////////////////////////
1126
+ // lerp
1127
+ // - linear interpolation between a and b, based on value t in [0, 1] range
1128
+ ////////////////////////////////////////////////////////////////////////////////
1129
+
1130
+ inline __device__ __host__ float lerp(float a, float b, float t)
1131
+ {
1132
+ return a + t*(b-a);
1133
+ }
1134
+ inline __device__ __host__ float2 lerp(float2 a, float2 b, float t)
1135
+ {
1136
+ return a + t*(b-a);
1137
+ }
1138
+ inline __device__ __host__ float3 lerp(float3 a, float3 b, float t)
1139
+ {
1140
+ return a + t*(b-a);
1141
+ }
1142
+ inline __device__ __host__ float4 lerp(float4 a, float4 b, float t)
1143
+ {
1144
+ return a + t*(b-a);
1145
+ }
1146
+
1147
+ ////////////////////////////////////////////////////////////////////////////////
1148
+ // clamp
1149
+ // - clamp the value v to be in the range [a, b]
1150
+ ////////////////////////////////////////////////////////////////////////////////
1151
+
1152
+ inline __device__ __host__ float clamp(float f, float a, float b)
1153
+ {
1154
+ return fmaxf(a, fminf(f, b));
1155
+ }
1156
+ inline __device__ __host__ int clamp(int f, int a, int b)
1157
+ {
1158
+ return max(a, min(f, b));
1159
+ }
1160
+ inline __device__ __host__ uint clamp(uint f, uint a, uint b)
1161
+ {
1162
+ return max(a, min(f, b));
1163
+ }
1164
+
1165
+ inline __device__ __host__ float2 clamp(float2 v, float a, float b)
1166
+ {
1167
+ return make_float2(clamp(v.x, a, b), clamp(v.y, a, b));
1168
+ }
1169
+ inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b)
1170
+ {
1171
+ return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
1172
+ }
1173
+ inline __device__ __host__ float3 clamp(float3 v, float a, float b)
1174
+ {
1175
+ return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
1176
+ }
1177
+ inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b)
1178
+ {
1179
+ return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
1180
+ }
1181
+ inline __device__ __host__ float4 clamp(float4 v, float a, float b)
1182
+ {
1183
+ return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
1184
+ }
1185
+ inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b)
1186
+ {
1187
+ return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
1188
+ }
1189
+
1190
+ inline __device__ __host__ int2 clamp(int2 v, int a, int b)
1191
+ {
1192
+ return make_int2(clamp(v.x, a, b), clamp(v.y, a, b));
1193
+ }
1194
+ inline __device__ __host__ int2 clamp(int2 v, int2 a, int2 b)
1195
+ {
1196
+ return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
1197
+ }
1198
+ inline __device__ __host__ int3 clamp(int3 v, int a, int b)
1199
+ {
1200
+ return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
1201
+ }
1202
+ inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b)
1203
+ {
1204
+ return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
1205
+ }
1206
+ inline __device__ __host__ int4 clamp(int4 v, int a, int b)
1207
+ {
1208
+ return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
1209
+ }
1210
+ inline __device__ __host__ int4 clamp(int4 v, int4 a, int4 b)
1211
+ {
1212
+ return make_int4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
1213
+ }
1214
+
1215
+ inline __device__ __host__ uint2 clamp(uint2 v, uint a, uint b)
1216
+ {
1217
+ return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b));
1218
+ }
1219
+ inline __device__ __host__ uint2 clamp(uint2 v, uint2 a, uint2 b)
1220
+ {
1221
+ return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
1222
+ }
1223
+ inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b)
1224
+ {
1225
+ return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
1226
+ }
1227
+ inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b)
1228
+ {
1229
+ return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
1230
+ }
1231
+ inline __device__ __host__ uint4 clamp(uint4 v, uint a, uint b)
1232
+ {
1233
+ return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
1234
+ }
1235
+ inline __device__ __host__ uint4 clamp(uint4 v, uint4 a, uint4 b)
1236
+ {
1237
+ return make_uint4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
1238
+ }
1239
+
1240
+ ////////////////////////////////////////////////////////////////////////////////
1241
+ // dot product
1242
+ ////////////////////////////////////////////////////////////////////////////////
1243
+
1244
+ inline __host__ __device__ float dot(float2 a, float2 b)
1245
+ {
1246
+ return a.x * b.x + a.y * b.y;
1247
+ }
1248
+ inline __host__ __device__ float dot(float3 a, float3 b)
1249
+ {
1250
+ return a.x * b.x + a.y * b.y + a.z * b.z;
1251
+ }
1252
+ inline __host__ __device__ float dot(float4 a, float4 b)
1253
+ {
1254
+ return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
1255
+ }
1256
+
1257
+ inline __host__ __device__ int dot(int2 a, int2 b)
1258
+ {
1259
+ return a.x * b.x + a.y * b.y;
1260
+ }
1261
+ inline __host__ __device__ int dot(int3 a, int3 b)
1262
+ {
1263
+ return a.x * b.x + a.y * b.y + a.z * b.z;
1264
+ }
1265
+ inline __host__ __device__ int dot(int4 a, int4 b)
1266
+ {
1267
+ return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
1268
+ }
1269
+
1270
+ inline __host__ __device__ uint dot(uint2 a, uint2 b)
1271
+ {
1272
+ return a.x * b.x + a.y * b.y;
1273
+ }
1274
+ inline __host__ __device__ uint dot(uint3 a, uint3 b)
1275
+ {
1276
+ return a.x * b.x + a.y * b.y + a.z * b.z;
1277
+ }
1278
+ inline __host__ __device__ uint dot(uint4 a, uint4 b)
1279
+ {
1280
+ return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
1281
+ }
1282
+
1283
+ ////////////////////////////////////////////////////////////////////////////////
1284
+ // length
1285
+ ////////////////////////////////////////////////////////////////////////////////
1286
+
1287
+ inline __host__ __device__ float length(float2 v)
1288
+ {
1289
+ return sqrtf(dot(v, v));
1290
+ }
1291
+ inline __host__ __device__ float length(float3 v)
1292
+ {
1293
+ return sqrtf(dot(v, v));
1294
+ }
1295
+ inline __host__ __device__ float length(float4 v)
1296
+ {
1297
+ return sqrtf(dot(v, v));
1298
+ }
1299
+
1300
+ ////////////////////////////////////////////////////////////////////////////////
1301
+ // normalize
1302
+ ////////////////////////////////////////////////////////////////////////////////
1303
+
1304
+ inline __host__ __device__ float2 normalize(float2 v)
1305
+ {
1306
+ float invLen = rsqrtf(dot(v, v));
1307
+ return v * invLen;
1308
+ }
1309
+ inline __host__ __device__ float3 normalize(float3 v)
1310
+ {
1311
+ float invLen = rsqrtf(dot(v, v));
1312
+ return v * invLen;
1313
+ }
1314
+ inline __host__ __device__ float4 normalize(float4 v)
1315
+ {
1316
+ float invLen = rsqrtf(dot(v, v));
1317
+ return v * invLen;
1318
+ }
1319
+
1320
+ ////////////////////////////////////////////////////////////////////////////////
1321
+ // floor
1322
+ ////////////////////////////////////////////////////////////////////////////////
1323
+
1324
+ inline __host__ __device__ float2 floorf(float2 v)
1325
+ {
1326
+ return make_float2(floorf(v.x), floorf(v.y));
1327
+ }
1328
+ inline __host__ __device__ float3 floorf(float3 v)
1329
+ {
1330
+ return make_float3(floorf(v.x), floorf(v.y), floorf(v.z));
1331
+ }
1332
+ inline __host__ __device__ float4 floorf(float4 v)
1333
+ {
1334
+ return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w));
1335
+ }
1336
+
1337
+ ////////////////////////////////////////////////////////////////////////////////
1338
+ // frac - returns the fractional portion of a scalar or each vector component
1339
+ ////////////////////////////////////////////////////////////////////////////////
1340
+
1341
+ inline __host__ __device__ float fracf(float v)
1342
+ {
1343
+ return v - floorf(v);
1344
+ }
1345
+ inline __host__ __device__ float2 fracf(float2 v)
1346
+ {
1347
+ return make_float2(fracf(v.x), fracf(v.y));
1348
+ }
1349
+ inline __host__ __device__ float3 fracf(float3 v)
1350
+ {
1351
+ return make_float3(fracf(v.x), fracf(v.y), fracf(v.z));
1352
+ }
1353
+ inline __host__ __device__ float4 fracf(float4 v)
1354
+ {
1355
+ return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w));
1356
+ }
1357
+
1358
+ ////////////////////////////////////////////////////////////////////////////////
1359
+ // fmod
1360
+ ////////////////////////////////////////////////////////////////////////////////
1361
+
1362
+ inline __host__ __device__ float2 fmodf(float2 a, float2 b)
1363
+ {
1364
+ return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y));
1365
+ }
1366
+ inline __host__ __device__ float3 fmodf(float3 a, float3 b)
1367
+ {
1368
+ return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z));
1369
+ }
1370
+ inline __host__ __device__ float4 fmodf(float4 a, float4 b)
1371
+ {
1372
+ return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w));
1373
+ }
1374
+
1375
+ ////////////////////////////////////////////////////////////////////////////////
1376
+ // absolute value
1377
+ ////////////////////////////////////////////////////////////////////////////////
1378
+
1379
+ inline __host__ __device__ float2 fabs(float2 v)
1380
+ {
1381
+ return make_float2(fabs(v.x), fabs(v.y));
1382
+ }
1383
+ inline __host__ __device__ float3 fabs(float3 v)
1384
+ {
1385
+ return make_float3(fabs(v.x), fabs(v.y), fabs(v.z));
1386
+ }
1387
+ inline __host__ __device__ float4 fabs(float4 v)
1388
+ {
1389
+ return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w));
1390
+ }
1391
+
1392
+ inline __host__ __device__ int2 abs(int2 v)
1393
+ {
1394
+ return make_int2(abs(v.x), abs(v.y));
1395
+ }
1396
+ inline __host__ __device__ int3 abs(int3 v)
1397
+ {
1398
+ return make_int3(abs(v.x), abs(v.y), abs(v.z));
1399
+ }
1400
+ inline __host__ __device__ int4 abs(int4 v)
1401
+ {
1402
+ return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w));
1403
+ }
1404
+
1405
+ ////////////////////////////////////////////////////////////////////////////////
1406
+ // reflect
1407
+ // - returns reflection of incident ray I around surface normal N
1408
+ // - N should be normalized, reflected vector's length is equal to length of I
1409
+ ////////////////////////////////////////////////////////////////////////////////
1410
+
1411
+ inline __host__ __device__ float3 reflect(float3 i, float3 n)
1412
+ {
1413
+ return i - 2.0f * n * dot(n,i);
1414
+ }
1415
+
1416
+ ////////////////////////////////////////////////////////////////////////////////
1417
+ // cross product
1418
+ ////////////////////////////////////////////////////////////////////////////////
1419
+
1420
+ inline __host__ __device__ float3 cross(float3 a, float3 b)
1421
+ {
1422
+ return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);
1423
+ }
1424
+
1425
+ ////////////////////////////////////////////////////////////////////////////////
1426
+ // smoothstep
1427
+ // - returns 0 if x < a
1428
+ // - returns 1 if x > b
1429
+ // - otherwise returns smooth interpolation between 0 and 1 based on x
1430
+ ////////////////////////////////////////////////////////////////////////////////
1431
+
1432
+ inline __device__ __host__ float smoothstep(float a, float b, float x)
1433
+ {
1434
+ float y = clamp((x - a) / (b - a), 0.0f, 1.0f);
1435
+ return (y*y*(3.0f - (2.0f*y)));
1436
+ }
1437
+ inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x)
1438
+ {
1439
+ float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
1440
+ return (y*y*(make_float2(3.0f) - (make_float2(2.0f)*y)));
1441
+ }
1442
+ inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x)
1443
+ {
1444
+ float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
1445
+ return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y)));
1446
+ }
1447
+ inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x)
1448
+ {
1449
+ float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
1450
+ return (y*y*(make_float4(3.0f) - (make_float4(2.0f)*y)));
1451
+ }
1452
+
1453
+ #endif
dva/mvp/extensions/mvpraymarch/makefile ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ all:
2
+ python setup.py build_ext --inplace
dva/mvp/extensions/mvpraymarch/mvpraymarch.cpp ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #include <torch/extension.h>
8
+ #include <c10/cuda/CUDAStream.h>
9
+
10
+ #include <vector>
11
+
12
+ void compute_morton_cuda(
13
+ int N, int K,
14
+ float * primpos,
15
+ int * code,
16
+ int algorithm,
17
+ cudaStream_t stream);
18
+
19
+ void build_tree_cuda(
20
+ int N, int K,
21
+ int * sortedcode,
22
+ int * nodechildren,
23
+ int * nodeparent,
24
+ cudaStream_t stream);
25
+
26
+ void compute_aabb_cuda(
27
+ int N, int K,
28
+ float * primpos,
29
+ float * primrot,
30
+ float * primscale,
31
+ int * sortedobjid,
32
+ int * nodechildren,
33
+ int * nodeparent,
34
+ float * nodeaabb,
35
+ int algorithm,
36
+ cudaStream_t stream);
37
+
38
+ void raymarch_forward_cuda(
39
+ int N, int H, int W, int K,
40
+ float * rayposim,
41
+ float * raydirim,
42
+ float stepsize,
43
+ float * tminmaxim,
44
+
45
+ int * sortedobjid,
46
+ int * nodechildren,
47
+ float * nodeaabb,
48
+
49
+ float * primpos,
50
+ float * primrot,
51
+ float * primscale,
52
+
53
+ int TD, int TH, int TW,
54
+ float * tplate,
55
+ int WD, int WH, int WW,
56
+ float * warp,
57
+
58
+ float * rayrgbaim,
59
+ float * raysatim,
60
+ int * raytermim,
61
+
62
+ int algorithm, bool sortboxes, int maxhitboxes, bool synchitboxes,
63
+ bool chlast, float fadescale, float fadeexp, int accum, float termthresh,
64
+ int griddim, int blocksizex, int blocksizey,
65
+ cudaStream_t stream);
66
+
67
+ void raymarch_backward_cuda(
68
+ int N, int H, int W, int K,
69
+ float * rayposim,
70
+ float * raydirim,
71
+ float stepsize,
72
+ float * tminmaxim,
73
+
74
+ int * sortedobjid,
75
+ int * nodechildren,
76
+ float * nodeaabb,
77
+
78
+ float * primpos,
79
+ float * grad_primpos,
80
+ float * primrot,
81
+ float * grad_primrot,
82
+ float * primscale,
83
+ float * grad_primscale,
84
+
85
+ int TD, int TH, int TW,
86
+ float * tplate,
87
+ float * grad_tplate,
88
+ int WD, int WH, int WW,
89
+ float * warp,
90
+ float * grad_warp,
91
+
92
+ float * rayrgbaim,
93
+ float * grad_rayrgba,
94
+ float * raysatim,
95
+ int * raytermim,
96
+
97
+ int algorithm, bool sortboxes, int maxhitboxes, bool synchitboxes,
98
+ bool chlast, float fadescale, float fadeexp, int accum, float termthresh,
99
+ int griddim, int blocksizex, int blocksizey,
100
+ cudaStream_t stream);
101
+
102
+ #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
103
+ #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
104
+ #define CHECK_INPUT(x) CHECK_CUDA((x)); CHECK_CONTIGUOUS((x))
105
+
106
+ std::vector<torch::Tensor> compute_morton(
107
+ torch::Tensor primpos,
108
+ torch::Tensor code,
109
+ int algorithm) {
110
+ CHECK_INPUT(primpos);
111
+ CHECK_INPUT(code);
112
+
113
+ int N = primpos.size(0);
114
+ int K = primpos.size(1);
115
+
116
+ compute_morton_cuda(
117
+ N, K,
118
+ reinterpret_cast<float *>(primpos.data_ptr()),
119
+ reinterpret_cast<int *>(code.data_ptr()),
120
+ algorithm,
121
+ 0);
122
+
123
+ return {};
124
+ }
125
+
126
+ std::vector<torch::Tensor> build_tree(
127
+ torch::Tensor sortedcode,
128
+ torch::Tensor nodechildren,
129
+ torch::Tensor nodeparent) {
130
+ CHECK_INPUT(sortedcode);
131
+ CHECK_INPUT(nodechildren);
132
+ CHECK_INPUT(nodeparent);
133
+
134
+ int N = sortedcode.size(0);
135
+ int K = sortedcode.size(1);
136
+
137
+ build_tree_cuda(N, K,
138
+ reinterpret_cast<int *>(sortedcode.data_ptr()),
139
+ reinterpret_cast<int *>(nodechildren.data_ptr()),
140
+ reinterpret_cast<int *>(nodeparent.data_ptr()),
141
+ 0);
142
+
143
+ return {};
144
+ }
145
+
146
+ std::vector<torch::Tensor> compute_aabb(
147
+ torch::Tensor primpos,
148
+ torch::optional<torch::Tensor> primrot,
149
+ torch::optional<torch::Tensor> primscale,
150
+ torch::Tensor sortedobjid,
151
+ torch::Tensor nodechildren,
152
+ torch::Tensor nodeparent,
153
+ torch::Tensor nodeaabb,
154
+ int algorithm) {
155
+ CHECK_INPUT(sortedobjid);
156
+ CHECK_INPUT(primpos);
157
+ if (primrot) { CHECK_INPUT(*primrot); }
158
+ if (primscale) { CHECK_INPUT(*primscale); }
159
+ CHECK_INPUT(nodechildren);
160
+ CHECK_INPUT(nodeparent);
161
+ CHECK_INPUT(nodeaabb);
162
+
163
+ int N = primpos.size(0);
164
+ int K = primpos.size(1);
165
+
166
+ compute_aabb_cuda(N, K,
167
+ reinterpret_cast<float *>(primpos.data_ptr()),
168
+ primrot ? reinterpret_cast<float *>(primrot->data_ptr()) : nullptr,
169
+ primscale ? reinterpret_cast<float *>(primscale->data_ptr()) : nullptr,
170
+ reinterpret_cast<int *>(sortedobjid.data_ptr()),
171
+ reinterpret_cast<int *>(nodechildren.data_ptr()),
172
+ reinterpret_cast<int *>(nodeparent.data_ptr()),
173
+ reinterpret_cast<float *>(nodeaabb.data_ptr()),
174
+ algorithm,
175
+ 0);
176
+
177
+ return {};
178
+ }
179
+
180
+ std::vector<torch::Tensor> raymarch_forward(
181
+ torch::Tensor rayposim,
182
+ torch::Tensor raydirim,
183
+ float stepsize,
184
+ torch::Tensor tminmaxim,
185
+
186
+ torch::optional<torch::Tensor> sortedobjid,
187
+ torch::optional<torch::Tensor> nodechildren,
188
+ torch::optional<torch::Tensor> nodeaabb,
189
+
190
+ torch::Tensor primpos,
191
+ torch::optional<torch::Tensor> primrot,
192
+ torch::optional<torch::Tensor> primscale,
193
+
194
+ torch::Tensor tplate,
195
+ torch::optional<torch::Tensor> warp,
196
+
197
+ torch::Tensor rayrgbaim,
198
+ torch::optional<torch::Tensor> raysatim,
199
+ torch::optional<torch::Tensor> raytermim,
200
+
201
+ int algorithm=0,
202
+ bool sortboxes=true,
203
+ int maxhitboxes=512,
204
+ bool synchitboxes=false,
205
+ bool chlast=false,
206
+ float fadescale=8.f,
207
+ float fadeexp=8.f,
208
+ int accum=0,
209
+ float termthresh=0.f,
210
+ int griddim=3,
211
+ int blocksizex=8,
212
+ int blocksizey=16) {
213
+ CHECK_INPUT(rayposim);
214
+ CHECK_INPUT(raydirim);
215
+ CHECK_INPUT(tminmaxim);
216
+ if (sortedobjid) { CHECK_INPUT(*sortedobjid); }
217
+ if (nodechildren) { CHECK_INPUT(*nodechildren); }
218
+ if (nodeaabb) { CHECK_INPUT(*nodeaabb); }
219
+ CHECK_INPUT(tplate);
220
+ if (warp) { CHECK_INPUT(*warp); }
221
+ CHECK_INPUT(primpos);
222
+ if (primrot) { CHECK_INPUT(*primrot); }
223
+ if (primscale) { CHECK_INPUT(*primscale); }
224
+ CHECK_INPUT(rayrgbaim);
225
+ if (raysatim) { CHECK_INPUT(*raysatim); }
226
+ if (raytermim) { CHECK_INPUT(*raytermim); }
227
+
228
+ int N = rayposim.size(0);
229
+ int H = rayposim.size(1);
230
+ int W = rayposim.size(2);
231
+ int K = primpos.size(1);
232
+
233
+ int TD, TH, TW;
234
+ if (chlast) {
235
+ TD = tplate.size(2); TH = tplate.size(3); TW = tplate.size(4);
236
+ } else {
237
+ TD = tplate.size(3); TH = tplate.size(4); TW = tplate.size(5);
238
+ }
239
+
240
+ int WD = 0, WH = 0, WW = 0;
241
+ if (warp) {
242
+ if (chlast) {
243
+ WD = warp->size(2); WH = warp->size(3); WW = warp->size(4);
244
+ } else {
245
+ WD = warp->size(3); WH = warp->size(4); WW = warp->size(5);
246
+ }
247
+ }
248
+
249
+ raymarch_forward_cuda(N, H, W, K,
250
+ reinterpret_cast<float *>(rayposim.data_ptr()),
251
+ reinterpret_cast<float *>(raydirim.data_ptr()),
252
+ stepsize,
253
+ reinterpret_cast<float *>(tminmaxim.data_ptr()),
254
+ sortedobjid ? reinterpret_cast<int *>(sortedobjid->data_ptr()) : nullptr,
255
+ nodechildren ? reinterpret_cast<int *>(nodechildren->data_ptr()) : nullptr,
256
+ nodeaabb ? reinterpret_cast<float *>(nodeaabb->data_ptr()) : nullptr,
257
+
258
+ // prim transforms
259
+ reinterpret_cast<float *>(primpos.data_ptr()),
260
+ primrot ? reinterpret_cast<float *>(primrot->data_ptr()) : nullptr,
261
+ primscale ? reinterpret_cast<float *>(primscale->data_ptr()) : nullptr,
262
+
263
+ // prim sampler
264
+ TD, TH, TW,
265
+ reinterpret_cast<float *>(tplate.data_ptr()),
266
+ WD, WH, WW,
267
+ warp ? reinterpret_cast<float *>(warp->data_ptr()) : nullptr,
268
+
269
+ // prim accumulator
270
+ reinterpret_cast<float *>(rayrgbaim.data_ptr()),
271
+ raysatim ? reinterpret_cast<float *>(raysatim->data_ptr()) : nullptr,
272
+ raytermim ? reinterpret_cast<int *>(raytermim->data_ptr()) : nullptr,
273
+
274
+ // options
275
+ algorithm, sortboxes, maxhitboxes, synchitboxes, chlast, fadescale, fadeexp, accum, termthresh,
276
+ griddim, blocksizex, blocksizey,
277
+ 0);
278
+
279
+ return {};
280
+ }
281
+
282
+ std::vector<torch::Tensor> raymarch_backward(
283
+ torch::Tensor rayposim,
284
+ torch::Tensor raydirim,
285
+ float stepsize,
286
+ torch::Tensor tminmaxim,
287
+
288
+ torch::optional<torch::Tensor> sortedobjid,
289
+ torch::optional<torch::Tensor> nodechildren,
290
+ torch::optional<torch::Tensor> nodeaabb,
291
+
292
+ torch::Tensor primpos,
293
+ torch::Tensor grad_primpos,
294
+ torch::optional<torch::Tensor> primrot,
295
+ torch::optional<torch::Tensor> grad_primrot,
296
+ torch::optional<torch::Tensor> primscale,
297
+ torch::optional<torch::Tensor> grad_primscale,
298
+
299
+ torch::Tensor tplate,
300
+ torch::Tensor grad_tplate,
301
+ torch::optional<torch::Tensor> warp,
302
+ torch::optional<torch::Tensor> grad_warp,
303
+
304
+ torch::Tensor rayrgbaim,
305
+ torch::Tensor grad_rayrgba,
306
+ torch::optional<torch::Tensor> raysatim,
307
+ torch::optional<torch::Tensor> raytermim,
308
+
309
+ int algorithm=0,
310
+ bool sortboxes=true,
311
+ int maxhitboxes=512,
312
+ bool synchitboxes=false,
313
+ bool chlast=false,
314
+ float fadescale=8.f,
315
+ float fadeexp=8.f,
316
+ int accum=0,
317
+ float termthresh=0.f,
318
+ int griddim=3,
319
+ int blocksizex=8,
320
+ int blocksizey=16) {
321
+ CHECK_INPUT(rayposim);
322
+ CHECK_INPUT(raydirim);
323
+ CHECK_INPUT(tminmaxim);
324
+ if (sortedobjid) { CHECK_INPUT(*sortedobjid); }
325
+ if (nodechildren) { CHECK_INPUT(*nodechildren); }
326
+ if (nodeaabb) { CHECK_INPUT(*nodeaabb); }
327
+ CHECK_INPUT(tplate);
328
+ if (warp) { CHECK_INPUT(*warp); }
329
+ CHECK_INPUT(primpos);
330
+ if (primrot) { CHECK_INPUT(*primrot); }
331
+ if (primscale) { CHECK_INPUT(*primscale); }
332
+ CHECK_INPUT(rayrgbaim);
333
+ if (raysatim) { CHECK_INPUT(*raysatim); }
334
+ if (raytermim) { CHECK_INPUT(*raytermim); }
335
+ CHECK_INPUT(grad_rayrgba);
336
+ CHECK_INPUT(grad_tplate);
337
+ if (grad_warp) { CHECK_INPUT(*grad_warp); }
338
+ CHECK_INPUT(grad_primpos);
339
+ if (grad_primrot) { CHECK_INPUT(*grad_primrot); }
340
+ if (grad_primscale) { CHECK_INPUT(*grad_primscale); }
341
+
342
+ int N = rayposim.size(0);
343
+ int H = rayposim.size(1);
344
+ int W = rayposim.size(2);
345
+ int K = primpos.size(1);
346
+
347
+ int TD, TH, TW;
348
+ if (chlast) {
349
+ TD = tplate.size(2); TH = tplate.size(3); TW = tplate.size(4);
350
+ } else {
351
+ TD = tplate.size(3); TH = tplate.size(4); TW = tplate.size(5);
352
+ }
353
+
354
+ int WD = 0, WH = 0, WW = 0;
355
+ if (warp) {
356
+ if (chlast) {
357
+ WD = warp->size(2); WH = warp->size(3); WW = warp->size(4);
358
+ } else {
359
+ WD = warp->size(3); WH = warp->size(4); WW = warp->size(5);
360
+ }
361
+ }
362
+
363
+ raymarch_backward_cuda(N, H, W, K,
364
+ reinterpret_cast<float *>(rayposim.data_ptr()),
365
+ reinterpret_cast<float *>(raydirim.data_ptr()),
366
+ stepsize,
367
+ reinterpret_cast<float *>(tminmaxim.data_ptr()),
368
+ sortedobjid ? reinterpret_cast<int *>(sortedobjid->data_ptr()) : nullptr,
369
+ nodechildren ? reinterpret_cast<int *>(nodechildren->data_ptr()) : nullptr,
370
+ nodeaabb ? reinterpret_cast<float *>(nodeaabb->data_ptr()) : nullptr,
371
+
372
+ reinterpret_cast<float *>(primpos.data_ptr()),
373
+ reinterpret_cast<float *>(grad_primpos.data_ptr()),
374
+ primrot ? reinterpret_cast<float *>(primrot->data_ptr()) : nullptr,
375
+ grad_primrot ? reinterpret_cast<float *>(grad_primrot->data_ptr()) : nullptr,
376
+ primscale ? reinterpret_cast<float *>(primscale->data_ptr()) : nullptr,
377
+ grad_primscale ? reinterpret_cast<float *>(grad_primscale->data_ptr()) : nullptr,
378
+
379
+ TD, TH, TW,
380
+ reinterpret_cast<float *>(tplate.data_ptr()),
381
+ reinterpret_cast<float *>(grad_tplate.data_ptr()),
382
+ WD, WH, WW,
383
+ warp ? reinterpret_cast<float *>(warp->data_ptr()) : nullptr,
384
+ grad_warp ? reinterpret_cast<float *>(grad_warp->data_ptr()) : nullptr,
385
+
386
+ reinterpret_cast<float *>(rayrgbaim.data_ptr()),
387
+ reinterpret_cast<float *>(grad_rayrgba.data_ptr()),
388
+ raysatim ? reinterpret_cast<float *>(raysatim->data_ptr()) : nullptr,
389
+ raytermim ? reinterpret_cast<int *>(raytermim->data_ptr()) : nullptr,
390
+
391
+ algorithm, sortboxes, maxhitboxes, synchitboxes, chlast, fadescale, fadeexp, accum, termthresh,
392
+ griddim, blocksizex, blocksizey,
393
+ 0);
394
+
395
+ return {};
396
+ }
397
+
398
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
399
+ m.def("compute_morton", &compute_morton, "compute morton codes (CUDA)");
400
+ m.def("build_tree", &build_tree, "build BVH tree (CUDA)");
401
+ m.def("compute_aabb", &compute_aabb, "compute AABB sizes (CUDA)");
402
+
403
+ m.def("raymarch_forward", &raymarch_forward, "raymarch forward (CUDA)");
404
+ m.def("raymarch_backward", &raymarch_backward, "raymarch backward (CUDA)");
405
+ }
dva/mvp/extensions/mvpraymarch/mvpraymarch.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import time
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.autograd import Function
13
+ import torch.nn.functional as F
14
+
15
+ try:
16
+ from . import mvpraymarchlib
17
+ except:
18
+ import mvpraymarchlib
19
+
20
+ def build_accel(primtransfin, algo, fixedorder=False):
21
+ """build bvh structure given primitive centers and sizes
22
+
23
+ Parameters:
24
+ ----------
25
+ primtransfin : tuple[tensor, tensor, tensor]
26
+ primitive transform tensors
27
+ algo : int
28
+ raymarching algorithm
29
+ fixedorder : optional[str]
30
+ True means the bvh builder will not reorder primitives and will
31
+ use a trivial tree structure. Likely to be slow for arbitrary
32
+ configurations of primitives.
33
+
34
+ """
35
+ primpos, primrot, primscale = primtransfin
36
+
37
+ N = primpos.size(0)
38
+ K = primpos.size(1)
39
+
40
+ dev = primpos.device
41
+
42
+ # compute and sort morton codes
43
+ if fixedorder:
44
+ sortedobjid = (torch.arange(N*K, dtype=torch.int32, device=dev) % K).view(N, K)
45
+ else:
46
+ cmax = primpos.max(dim=1, keepdim=True)[0]
47
+ cmin = primpos.min(dim=1, keepdim=True)[0]
48
+
49
+ centers_norm = (primpos - cmin) / (cmax - cmin).clamp(min=1e-8)
50
+
51
+ mortoncode = torch.empty((N, K), dtype=torch.int32, device=dev)
52
+ mvpraymarchlib.compute_morton(centers_norm, mortoncode, algo)
53
+ sortedcode, sortedobjid_long = torch.sort(mortoncode, dim=-1)
54
+ sortedobjid = sortedobjid_long.int()
55
+
56
+ if fixedorder:
57
+ nodechildren = torch.cat([
58
+ torch.arange(1, (K - 1) * 2 + 1, dtype=torch.int32, device=dev),
59
+ torch.div(torch.arange(-2, -(K * 2 + 1) - 1, -1, dtype=torch.int32, device=dev), 2, rounding_mode="floor")],
60
+ dim=0).view(1, K + K - 1, 2).repeat(N, 1, 1)
61
+ nodeparent = (
62
+ torch.div(torch.arange(-1, K * 2 - 2, dtype=torch.int32, device=dev), 2, rounding_mode="floor")
63
+ .view(1, -1).repeat(N, 1))
64
+ else:
65
+ nodechildren = torch.empty((N, K + K - 1, 2), dtype=torch.int32, device=dev)
66
+ nodeparent = torch.full((N, K + K - 1), -1, dtype=torch.int32, device=dev)
67
+ mvpraymarchlib.build_tree(sortedcode, nodechildren, nodeparent)
68
+
69
+ nodeaabb = torch.empty((N, K + K - 1, 2, 3), dtype=torch.float32, device=dev)
70
+ mvpraymarchlib.compute_aabb(*primtransfin, sortedobjid, nodechildren, nodeparent, nodeaabb, algo)
71
+
72
+ return sortedobjid, nodechildren, nodeaabb
73
+
74
+ class MVPRaymarch(Function):
75
+ """Custom Function for raymarching Mixture of Volumetric Primitives."""
76
+ @staticmethod
77
+ def forward(self, raypos, raydir, stepsize, tminmax,
78
+ primpos, primrot, primscale,
79
+ template, warp,
80
+ rayterm, gradmode, options):
81
+ algo = options["algo"]
82
+ usebvh = options["usebvh"]
83
+ sortprims = options["sortprims"]
84
+ randomorder = options["randomorder"]
85
+ maxhitboxes = options["maxhitboxes"]
86
+ synchitboxes = options["synchitboxes"]
87
+ chlast = options["chlast"]
88
+ fadescale = options["fadescale"]
89
+ fadeexp = options["fadeexp"]
90
+ accum = options["accum"]
91
+ termthresh = options["termthresh"]
92
+ griddim = options["griddim"]
93
+ if isinstance(options["blocksize"], tuple):
94
+ blocksizex, blocksizey = options["blocksize"]
95
+ else:
96
+ blocksizex = options["blocksize"]
97
+ blocksizey = 1
98
+
99
+ assert raypos.is_contiguous() and raypos.size(3) == 3
100
+ assert raydir.is_contiguous() and raydir.size(3) == 3
101
+ assert tminmax.is_contiguous() and tminmax.size(3) == 2
102
+
103
+ assert primpos is None or primpos.is_contiguous() and primpos.size(2) == 3
104
+ assert primrot is None or primrot.is_contiguous() and primrot.size(2) == 3
105
+ assert primscale is None or primscale.is_contiguous() and primscale.size(2) == 3
106
+
107
+ if chlast:
108
+ assert template.is_contiguous() and len(template.size()) == 6 and template.size(-1) == 4
109
+ assert warp is None or (warp.is_contiguous() and warp.size(-1) == 3)
110
+ else:
111
+ assert template.is_contiguous() and len(template.size()) == 6 and template.size(2) == 4
112
+ assert warp is None or (warp.is_contiguous() and warp.size(2) == 3)
113
+
114
+ primtransfin = (primpos, primrot, primscale)
115
+
116
+ # Build bvh
117
+ if usebvh is not False:
118
+ # compute radius of primitives
119
+ sortedobjid, nodechildren, nodeaabb = build_accel(primtransfin,
120
+ algo, fixedorder=usebvh=="fixedorder")
121
+ assert sortedobjid.is_contiguous()
122
+ assert nodechildren.is_contiguous()
123
+ assert nodeaabb.is_contiguous()
124
+
125
+ if randomorder:
126
+ sortedobjid = sortedobjid[torch.randperm(len(sortedobjid))]
127
+ else:
128
+ _, sortedobjid, nodechildren, nodeaabb = None, None, None, None
129
+
130
+ # march through boxes
131
+ N, H, W = raypos.size(0), raypos.size(1), raypos.size(2)
132
+ rayrgba = torch.empty((N, H, W, 4), device=raypos.device)
133
+ if gradmode:
134
+ raysat = torch.full((N, H, W, 3), -1, dtype=torch.float32, device=raypos.device)
135
+ rayterm = None
136
+ else:
137
+ raysat = None
138
+ rayterm = None
139
+
140
+ mvpraymarchlib.raymarch_forward(
141
+ raypos, raydir, stepsize, tminmax,
142
+ sortedobjid, nodechildren, nodeaabb,
143
+ *primtransfin,
144
+ template, warp,
145
+ rayrgba, raysat, rayterm,
146
+ algo, sortprims, maxhitboxes, synchitboxes, chlast,
147
+ fadescale, fadeexp,
148
+ accum, termthresh,
149
+ griddim, blocksizex, blocksizey)
150
+
151
+ self.save_for_backward(
152
+ raypos, raydir, tminmax,
153
+ sortedobjid, nodechildren, nodeaabb,
154
+ primpos, primrot, primscale,
155
+ template, warp,
156
+ rayrgba, raysat, rayterm)
157
+ self.options = options
158
+ self.stepsize = stepsize
159
+
160
+ return rayrgba
161
+
162
+ @staticmethod
163
+ def backward(self, grad_rayrgba):
164
+ (raypos, raydir, tminmax,
165
+ sortedobjid, nodechildren, nodeaabb,
166
+ primpos, primrot, primscale,
167
+ template, warp,
168
+ rayrgba, raysat, rayterm) = self.saved_tensors
169
+ algo = self.options["algo"]
170
+ usebvh = self.options["usebvh"]
171
+ sortprims = self.options["sortprims"]
172
+ maxhitboxes = self.options["maxhitboxes"]
173
+ synchitboxes = self.options["synchitboxes"]
174
+ chlast = self.options["chlast"]
175
+ fadescale = self.options["fadescale"]
176
+ fadeexp = self.options["fadeexp"]
177
+ accum = self.options["accum"]
178
+ termthresh = self.options["termthresh"]
179
+ griddim = self.options["griddim"]
180
+ if isinstance(self.options["bwdblocksize"], tuple):
181
+ blocksizex, blocksizey = self.options["bwdblocksize"]
182
+ else:
183
+ blocksizex = self.options["bwdblocksize"]
184
+ blocksizey = 1
185
+
186
+ stepsize = self.stepsize
187
+
188
+ grad_primpos = torch.zeros_like(primpos)
189
+ grad_primrot = torch.zeros_like(primrot)
190
+ grad_primscale = torch.zeros_like(primscale)
191
+ primtransfin = (primpos, grad_primpos, primrot, grad_primrot, primscale, grad_primscale)
192
+
193
+ grad_template = torch.zeros_like(template)
194
+ grad_warp = torch.zeros_like(warp) if warp is not None else None
195
+
196
+ mvpraymarchlib.raymarch_backward(raypos, raydir, stepsize, tminmax,
197
+ sortedobjid, nodechildren, nodeaabb,
198
+
199
+ *primtransfin,
200
+
201
+ template, grad_template, warp, grad_warp,
202
+
203
+ rayrgba, grad_rayrgba.contiguous(), raysat, rayterm,
204
+
205
+ algo, sortprims, maxhitboxes, synchitboxes, chlast,
206
+ fadescale, fadeexp,
207
+ accum, termthresh,
208
+ griddim, blocksizex, blocksizey)
209
+
210
+ return (None, None, None, None,
211
+ grad_primpos, grad_primrot, grad_primscale,
212
+ grad_template, grad_warp,
213
+ None, None, None)
214
+
215
+ def mvpraymarch(raypos, raydir, stepsize, tminmax,
216
+ primtransf,
217
+ template, warp,
218
+ rayterm=None,
219
+ algo=0, usebvh="fixedorder",
220
+ sortprims=False, randomorder=False,
221
+ maxhitboxes=512, synchitboxes=True,
222
+ chlast=True, fadescale=8., fadeexp=8.,
223
+ accum=0, termthresh=0.,
224
+ griddim=3, blocksize=(8, 16), bwdblocksize=(8, 16)):
225
+ """Main entry point for raymarching MVP.
226
+
227
+ Parameters:
228
+ ----------
229
+ raypos: N x H x W x 3 tensor of ray origins
230
+ raydir: N x H x W x 3 tensor of ray directions
231
+ stepsize: raymarching step size
232
+ tminmax: N x H x W x 2 tensor of raymarching min/max bounds
233
+ template: N x K x 4 x TD x TH x TW tensor of K RGBA primitives
234
+ warp: N x K x 3 x TD x TH x TW tensor of K warp fields (optional)
235
+ primpos: N x K x 3 tensor of primitive centers
236
+ primrot: N x K x 3 x 3 tensor of primitive orientations
237
+ primscale: N x K x 3 tensor of primitive inverse dimension lengths
238
+ algo: algorithm for raymarching (valid values: 0, 1). algo=0 is the fastest.
239
+ Currently algo=0 has a limit of 512 primitives per ray, so problems can
240
+ occur if there are many more boxes. all sortprims=True options have
241
+ this limitation, but you can use (algo=1, sortprims=False,
242
+ usebvh="fixedorder") which works correctly and has no primitive number
243
+ limitation (but is slightly slower).
244
+ usebvh: True to use bvh, "fixedorder" for a simple BVH, False for no bvh
245
+ sortprims: True to sort overlapping primitives at a sample point. Must
246
+ be True for gradients to match the PyTorch gradients. Seems unstable
247
+ if False but also not a big performance bottleneck.
248
+ chlast: whether template is provided as channels last or not. True tends
249
+ to be faster.
250
+ fadescale: Opacity is faded at the borders of the primitives by the equation
251
+ exp(-fadescale * x ** fadeexp) where x is the normalized coordinates of
252
+ the primitive.
253
+ fadeexp: Opacity is faded at the borders of the primitives by the equation
254
+ exp(-fadescale * x ** fadeexp) where x is the normalized coordinates of
255
+ the primitive.
256
+ griddim: CUDA grid dimensionality.
257
+ blocksize: blocksize of CUDA kernels. Should be 2-element tuple if
258
+ griddim>1, or integer if griddim==1."""
259
+ if isinstance(primtransf, tuple):
260
+ primpos, primrot, primscale = primtransf
261
+ else:
262
+ primpos, primrot, primscale = (
263
+ primtransf[:, :, 0, :].contiguous(),
264
+ primtransf[:, :, 1:4, :].contiguous(),
265
+ primtransf[:, :, 4, :].contiguous())
266
+ primtransfin = (primpos, primrot, primscale)
267
+
268
+ out = MVPRaymarch.apply(raypos, raydir, stepsize, tminmax,
269
+ *primtransfin,
270
+ template, warp,
271
+ rayterm, torch.is_grad_enabled(),
272
+ {"algo": algo, "usebvh": usebvh, "sortprims": sortprims, "randomorder": randomorder,
273
+ "maxhitboxes": maxhitboxes, "synchitboxes": synchitboxes,
274
+ "chlast": chlast, "fadescale": fadescale, "fadeexp": fadeexp,
275
+ "accum": accum, "termthresh": termthresh,
276
+ "griddim": griddim, "blocksize": blocksize, "bwdblocksize": bwdblocksize})
277
+ return out
278
+
279
+ class Rodrigues(nn.Module):
280
+ def __init__(self):
281
+ super(Rodrigues, self).__init__()
282
+
283
+ def forward(self, rvec):
284
+ theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))
285
+ rvec = rvec / theta[:, None]
286
+ costh = torch.cos(theta)
287
+ sinth = torch.sin(theta)
288
+ return torch.stack((
289
+ rvec[:, 0] ** 2 + (1. - rvec[:, 0] ** 2) * costh,
290
+ rvec[:, 0] * rvec[:, 1] * (1. - costh) - rvec[:, 2] * sinth,
291
+ rvec[:, 0] * rvec[:, 2] * (1. - costh) + rvec[:, 1] * sinth,
292
+
293
+ rvec[:, 0] * rvec[:, 1] * (1. - costh) + rvec[:, 2] * sinth,
294
+ rvec[:, 1] ** 2 + (1. - rvec[:, 1] ** 2) * costh,
295
+ rvec[:, 1] * rvec[:, 2] * (1. - costh) - rvec[:, 0] * sinth,
296
+
297
+ rvec[:, 0] * rvec[:, 2] * (1. - costh) - rvec[:, 1] * sinth,
298
+ rvec[:, 1] * rvec[:, 2] * (1. - costh) + rvec[:, 0] * sinth,
299
+ rvec[:, 2] ** 2 + (1. - rvec[:, 2] ** 2) * costh), dim=1).view(-1, 3, 3)
300
+
301
+ def gradcheck(usebvh=True, sortprims=True, maxhitboxes=512, synchitboxes=False,
302
+ dowarp=False, chlast=False, fadescale=8., fadeexp=8.,
303
+ accum=0, termthresh=0., algo=0, griddim=2, blocksize=(8, 16), bwdblocksize=(8, 16)):
304
+ N = 2
305
+ H = 65
306
+ W = 65
307
+ k3 = 4
308
+ K = k3*k3*k3
309
+
310
+ M = 32
311
+
312
+ print("=================================================================")
313
+ print("usebvh={}, sortprims={}, maxhb={}, synchb={}, dowarp={}, chlast={}, "
314
+ "fadescale={}, fadeexp={}, accum={}, termthresh={}, algo={}, griddim={}, "
315
+ "blocksize={}, bwdblocksize={}".format(
316
+ usebvh, sortprims, maxhitboxes, synchitboxes, dowarp, chlast,
317
+ fadescale, fadeexp, accum, termthresh, algo, griddim, blocksize,
318
+ bwdblocksize))
319
+
320
+ # generate random inputs
321
+ torch.manual_seed(1112)
322
+
323
+ coherent_rays = True
324
+ if not coherent_rays:
325
+ _raypos = torch.randn(N, H, W, 3).to("cuda")
326
+ _raydir = torch.randn(N, H, W, 3).to("cuda")
327
+ _raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True))
328
+ else:
329
+ focal = torch.tensor([[W*4.0, W*4.0] for n in range(N)])
330
+ princpt = torch.tensor([[W*0.5, H*0.5] for n in range(N)])
331
+ pixely, pixelx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float())
332
+ pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1)
333
+
334
+ raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
335
+ raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
336
+ raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))
337
+
338
+ _raypos = torch.tensor([-0.0, 0.0, -4.])[None, None, None, :].repeat(N, H, W, 1).to("cuda")
339
+ _raydir = raydir.to("cuda")
340
+ _raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True))
341
+
342
+ max_len = 6.0
343
+ _stepsize = max_len / 15.386928
344
+ _tminmax = max_len*torch.arange(2, dtype=torch.float32)[None, None, None, :].repeat(N, H, W, 1).to("cuda") + \
345
+ torch.rand(N, H, W, 2, device="cuda") * 1.
346
+
347
+ _template = torch.randn(N, K, 4, M, M, M, requires_grad=True)
348
+ _template.data[:, :, -1, :, :, :] -= 3.5
349
+ _template = _template.contiguous().detach().clone()
350
+ _template.requires_grad = True
351
+ gridxyz = torch.stack(torch.meshgrid(
352
+ torch.linspace(-1., 1., M//2),
353
+ torch.linspace(-1., 1., M//2),
354
+ torch.linspace(-1., 1., M//2))[::-1], dim=0).contiguous()
355
+ _warp = (torch.randn(N, K, 3, M//2, M//2, M//2) * 0.01 + gridxyz[None, None, :, :, :, :]).contiguous().detach().clone()
356
+ _warp.requires_grad = True
357
+ _primpos = torch.randn(N, K, 3, requires_grad=True)
358
+ _primpos = torch.randn(N, K, 3, requires_grad=True)
359
+
360
+ coherent_centers = True
361
+ if coherent_centers:
362
+ ns = k3
363
+ #assert ns*ns*ns==K
364
+ grid3d = torch.stack(torch.meshgrid(
365
+ torch.linspace(-1., 1., ns),
366
+ torch.linspace(-1., 1., ns),
367
+ torch.linspace(-1., 1., K//(ns*ns)))[::-1], dim=0)[None]
368
+ _primpos = ((
369
+ grid3d.permute((0, 2, 3, 4, 1)).reshape(1, K, 3).expand(N, -1, -1) +
370
+ 0.1 * torch.randn(N, K, 3, requires_grad=True)
371
+ )).contiguous().detach().clone()
372
+ _primpos.requires_grad = True
373
+ scale_ws = 1.
374
+ _primrot = torch.randn(N, K, 3)
375
+ rodrigues = Rodrigues()
376
+ _primrot = rodrigues(_primrot.view(-1, 3)).view(N, K, 3, 3).contiguous().detach().clone()
377
+ _primrot.requires_grad = True
378
+
379
+ _primscale = torch.randn(N, K, 3, requires_grad=True)
380
+ _primscale.data *= 0.0
381
+
382
+ if dowarp:
383
+ params = [_template, _warp, _primscale, _primrot, _primpos]
384
+ paramnames = ["template", "warp", "primscale", "primrot", "primpos"]
385
+ else:
386
+ params = [_template, _primscale, _primrot, _primpos]
387
+ paramnames = ["template", "primscale", "primrot", "primpos"]
388
+
389
+ termthreshorig = termthresh
390
+
391
+ ########################### run pytorch version ###########################
392
+
393
+ raypos = _raypos
394
+ raydir = _raydir
395
+ stepsize = _stepsize
396
+ tminmax = _tminmax
397
+
398
+ #template = F.softplus(_template.to("cuda") * 1.5)
399
+ template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5
400
+ warp = _warp.to("cuda")
401
+ primpos = _primpos.to("cuda") * 0.3
402
+ primrot = _primrot.to("cuda")
403
+ primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda"))
404
+
405
+ # python raymarching implementation
406
+ rayrgba = torch.zeros((N, H, W, 4)).to("cuda")
407
+ raypos = raypos + raydir * tminmax[:, :, :, 0, None]
408
+ t = tminmax[:, :, :, 0]
409
+
410
+ step = 0
411
+ t0 = t.detach().clone()
412
+ raypos0 = raypos.detach().clone()
413
+
414
+ torch.cuda.synchronize()
415
+ time0 = time.time()
416
+
417
+ while (t < tminmax[:, :, :, 1]).any():
418
+ valid2 = torch.ones_like(rayrgba[:, :, :, 3:4])
419
+
420
+ for k in range(K):
421
+ y0 = torch.bmm(
422
+ (raypos - primpos[:, k, None, None, :]).view(raypos.size(0), -1, raypos.size(3)),
423
+ primrot[:, k, :, :]).view_as(raypos) * primscale[:, k, None, None, :]
424
+
425
+ fade = torch.exp(-fadescale * torch.sum(torch.abs(y0) ** fadeexp, dim=-1, keepdim=True))
426
+
427
+ if dowarp:
428
+ y1 = F.grid_sample(
429
+ warp[:, k, :, :, :, :],
430
+ y0[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1)
431
+ else:
432
+ y1 = y0
433
+
434
+ sample = F.grid_sample(
435
+ template[:, k, :, :, :, :],
436
+ y1[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1)
437
+
438
+ valid1 = (
439
+ torch.prod(y0[:, :, :, :] >= -1., dim=-1, keepdim=True) *
440
+ torch.prod(y0[:, :, :, :] <= 1., dim=-1, keepdim=True))
441
+
442
+ valid = ((t >= tminmax[:, :, :, 0]) & (t < tminmax[:, :, :, 1])).float()[:, :, :, None]
443
+
444
+ alpha0 = sample[:, :, :, 3:4]
445
+
446
+ rgb = sample[:, :, :, 0:3] * valid * valid1
447
+ alpha = alpha0 * fade * stepsize * valid * valid1
448
+
449
+ if accum == 0:
450
+ newalpha = rayrgba[:, :, :, 3:4] + alpha
451
+ contrib = (newalpha.clamp(max=1.0) - rayrgba[:, :, :, 3:4]) * valid * valid1
452
+ rayrgba = rayrgba + contrib * torch.cat([rgb, torch.ones_like(alpha)], dim=-1)
453
+ else:
454
+ raise
455
+
456
+ step += 1
457
+ t = t0 + stepsize * step
458
+ raypos = raypos0 + raydir * stepsize * step
459
+
460
+ print(rayrgba[..., -1].min().item(), rayrgba[..., -1].max().item())
461
+
462
+ sample0 = rayrgba
463
+
464
+ torch.cuda.synchronize()
465
+ time1 = time.time()
466
+
467
+ sample0.backward(torch.ones_like(sample0))
468
+
469
+ torch.cuda.synchronize()
470
+ time2 = time.time()
471
+
472
+ print("{:<10} {:>10} {:>10} {:>10}".format("", "fwd", "bwd", "total"))
473
+ print("{:<10} {:10.5} {:10.5} {:10.5}".format("pytime", time1 - time0, time2 - time1, time2 - time0))
474
+
475
+ grads0 = [p.grad.detach().clone() for p in params]
476
+
477
+ for p in params:
478
+ p.grad.detach_()
479
+ p.grad.zero_()
480
+
481
+ ############################## run cuda version ###########################
482
+
483
+ raypos = _raypos
484
+ raydir = _raydir
485
+ stepsize = _stepsize
486
+ tminmax = _tminmax
487
+
488
+ template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5
489
+ warp = _warp.to("cuda")
490
+ if chlast:
491
+ template = template.permute(0, 1, 3, 4, 5, 2).contiguous()
492
+ warp = warp.permute(0, 1, 3, 4, 5, 2).contiguous()
493
+ primpos = _primpos.to("cuda") * 0.3
494
+ primrot = _primrot.to("cuda")
495
+ primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda"))
496
+
497
+ niter = 1
498
+
499
+ tf, tb = 0., 0.
500
+ for i in range(niter):
501
+ for p in params:
502
+ try:
503
+ p.grad.detach_()
504
+ p.grad.zero_()
505
+ except:
506
+ pass
507
+ t0 = time.time()
508
+ torch.cuda.synchronize()
509
+ sample1 = mvpraymarch(raypos, raydir, stepsize, tminmax,
510
+ (primpos, primrot, primscale),
511
+ template, warp if dowarp else None,
512
+ algo=algo, usebvh=usebvh, sortprims=sortprims,
513
+ maxhitboxes=maxhitboxes, synchitboxes=synchitboxes,
514
+ chlast=chlast, fadescale=fadescale, fadeexp=fadeexp,
515
+ accum=accum, termthresh=termthreshorig,
516
+ griddim=griddim, blocksize=blocksize, bwdblocksize=bwdblocksize)
517
+ t1 = time.time()
518
+ torch.cuda.synchronize()
519
+ sample1.backward(torch.ones_like(sample1), retain_graph=True)
520
+ torch.cuda.synchronize()
521
+ t2 = time.time()
522
+ tf += t1 - t0
523
+ tb += t2 - t1
524
+
525
+ print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter))
526
+ grads1 = [p.grad.detach().clone() for p in params]
527
+
528
+ ############# compare results #############
529
+
530
+ print("-----------------------------------------------------------------")
531
+ print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format("", "maxabsdiff", "dp", "||py||", "||cuda||", "index", "py", "cuda"))
532
+ ind = torch.argmax(torch.abs(sample0 - sample1))
533
+ print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
534
+ "fwd",
535
+ torch.max(torch.abs(sample0 - sample1)).item(),
536
+ (torch.sum(sample0 * sample1) / torch.sqrt(torch.sum(sample0 * sample0) * torch.sum(sample1 * sample1))).item(),
537
+ torch.sqrt(torch.sum(sample0 * sample0)).item(),
538
+ torch.sqrt(torch.sum(sample1 * sample1)).item(),
539
+ ind.item(),
540
+ sample0.view(-1)[ind].item(),
541
+ sample1.view(-1)[ind].item()))
542
+
543
+ for p, g0, g1 in zip(paramnames, grads0, grads1):
544
+ ind = torch.argmax(torch.abs(g0 - g1))
545
+ print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
546
+ p,
547
+ torch.max(torch.abs(g0 - g1)).item(),
548
+ (torch.sum(g0 * g1) / torch.sqrt(torch.sum(g0 * g0) * torch.sum(g1 * g1))).item(),
549
+ torch.sqrt(torch.sum(g0 * g0)).item(),
550
+ torch.sqrt(torch.sum(g1 * g1)).item(),
551
+ ind.item(),
552
+ g0.view(-1)[ind].item(),
553
+ g1.view(-1)[ind].item()))
554
+
555
+ if __name__ == "__main__":
556
+ gradcheck(usebvh="fixedorder", sortprims=False, maxhitboxes=512, synchitboxes=True,
557
+ dowarp=False, chlast=True, fadescale=6.5, fadeexp=7.5, accum=0, algo=0, griddim=3)
558
+ gradcheck(usebvh="fixedorder", sortprims=False, maxhitboxes=512, synchitboxes=True,
559
+ dowarp=True, chlast=True, fadescale=6.5, fadeexp=7.5, accum=0, algo=1, griddim=3)
dva/mvp/extensions/mvpraymarch/mvpraymarch_kernel.cu ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #include <chrono>
8
+ #include <functional>
9
+ #include <iostream>
10
+ #include <map>
11
+ #include <memory>
12
+ #include <tuple>
13
+ #include <vector>
14
+
15
+ #include "helper_math.h"
16
+
17
+ #include "cudadispatch.h"
18
+
19
+ #include "utils.h"
20
+
21
+ #include "primtransf.h"
22
+ #include "primsampler.h"
23
+ #include "primaccum.h"
24
+
25
+ #include "mvpraymarch_subset_kernel.h"
26
+
27
+ typedef std::shared_ptr<PrimTransfDataBase> PrimTransfDataBase_ptr;
28
+ typedef std::shared_ptr<PrimSamplerDataBase> PrimSamplerDataBase_ptr;
29
+ typedef std::shared_ptr<PrimAccumDataBase> PrimAccumDataBase_ptr;
30
+ typedef std::function<void(dim3, dim3, cudaStream_t, int, int, int, int,
31
+ float3*, float3*, float, float2*, int*, int2*, float3*,
32
+ PrimTransfDataBase_ptr, PrimSamplerDataBase_ptr,
33
+ PrimAccumDataBase_ptr)> mapfn_t;
34
+ typedef RaySubsetFixedBVH<false, 512, true, PrimTransfSRT> raysubset_t;
35
+
36
+ void raymarch_forward_cuda(
37
+ int N, int H, int W, int K,
38
+ float * rayposim,
39
+ float * raydirim,
40
+ float stepsize,
41
+ float * tminmaxim,
42
+
43
+ int * sortedobjid,
44
+ int * nodechildren,
45
+ float * nodeaabb,
46
+ float * primpos,
47
+ float * primrot,
48
+ float * primscale,
49
+
50
+ int TD, int TH, int TW,
51
+ float * tplate,
52
+ int WD, int WH, int WW,
53
+ float * warp,
54
+
55
+ float * rayrgbaim,
56
+ float * raysatim,
57
+ int * raytermim,
58
+
59
+ int algorithm,
60
+ bool sortboxes,
61
+ int maxhitboxes,
62
+ bool synchitboxes,
63
+ bool chlast,
64
+ float fadescale,
65
+ float fadeexp,
66
+ int accum,
67
+ float termthresh,
68
+ int griddim, int blocksizex, int blocksizey,
69
+ cudaStream_t stream) {
70
+ dim3 blocksize(blocksizex, blocksizey);
71
+ dim3 gridsize;
72
+ gridsize = dim3(
73
+ (W + blocksize.x - 1) / blocksize.x,
74
+ (H + blocksize.y - 1) / blocksize.y,
75
+ N);
76
+
77
+ std::shared_ptr<PrimTransfDataBase> primtransf_data;
78
+ primtransf_data = std::make_shared<PrimTransfSRT::Data>(PrimTransfSRT::Data{
79
+ PrimTransfDataBase{},
80
+ K, (float3*)primpos, nullptr,
81
+ K * 3, (float3*)primrot, nullptr,
82
+ K, (float3*)primscale, nullptr});
83
+ std::shared_ptr<PrimSamplerDataBase> primsampler_data;
84
+ if (algorithm == 1) {
85
+ primsampler_data = std::make_shared<PrimSamplerTW<true>::Data>(PrimSamplerTW<true>::Data{
86
+ PrimSamplerDataBase{},
87
+ fadescale, fadeexp,
88
+ K * TD * TH * TW * 4, TD, TH, TW, tplate, nullptr,
89
+ K * WD * WH * WW * 3, WD, WH, WW, warp, nullptr});
90
+ } else {
91
+ primsampler_data = std::make_shared<PrimSamplerTW<false>::Data>(PrimSamplerTW<false>::Data{
92
+ PrimSamplerDataBase{},
93
+ fadescale, fadeexp,
94
+ K * TD * TH * TW * 4, TD, TH, TW, tplate, nullptr,
95
+ 0, 0, 0, 0, nullptr, nullptr});
96
+ }
97
+ std::shared_ptr<PrimAccumDataBase> primaccum_data = std::make_shared<PrimAccumAdditive::Data>(PrimAccumAdditive::Data{
98
+ PrimAccumDataBase{},
99
+ termthresh, H * W, W, 1, (float4*)rayrgbaim, nullptr, (float3*)raysatim});
100
+
101
+ std::map<int, mapfn_t> dispatcher = {
102
+ {0, make_cudacall(raymarch_subset_forward_kernel<512, 4, raysubset_t, PrimTransfSRT, PrimSamplerTW<false>, PrimAccumAdditive>)},
103
+ {1, make_cudacall(raymarch_subset_forward_kernel<512, 4, raysubset_t, PrimTransfSRT, PrimSamplerTW<true>, PrimAccumAdditive>)}};
104
+
105
+ auto iter = dispatcher.find(algorithm);
106
+ if (iter != dispatcher.end()) {
107
+ (iter->second)(
108
+ gridsize, blocksize, stream,
109
+ N, H, W, K,
110
+ reinterpret_cast<float3 *>(rayposim),
111
+ reinterpret_cast<float3 *>(raydirim),
112
+ stepsize,
113
+ reinterpret_cast<float2 *>(tminmaxim),
114
+ reinterpret_cast<int *>(sortedobjid),
115
+ reinterpret_cast<int2 *>(nodechildren),
116
+ reinterpret_cast<float3 *>(nodeaabb),
117
+ primtransf_data,
118
+ primsampler_data,
119
+ primaccum_data);
120
+ }
121
+ }
122
+
123
+ void raymarch_backward_cuda(
124
+ int N, int H, int W, int K,
125
+ float * rayposim,
126
+ float * raydirim,
127
+ float stepsize,
128
+ float * tminmaxim,
129
+ int * sortedobjid,
130
+ int * nodechildren,
131
+ float * nodeaabb,
132
+
133
+ float * primpos,
134
+ float * grad_primpos,
135
+ float * primrot,
136
+ float * grad_primrot,
137
+ float * primscale,
138
+ float * grad_primscale,
139
+
140
+ int TD, int TH, int TW,
141
+ float * tplate,
142
+ float * grad_tplate,
143
+ int WD, int WH, int WW,
144
+ float * warp,
145
+ float * grad_warp,
146
+
147
+ float * rayrgbaim,
148
+ float * grad_rayrgba,
149
+ float * raysatim,
150
+ int * raytermim,
151
+
152
+ int algorithm, bool sortboxes, int maxhitboxes, bool synchitboxes,
153
+ bool chlast, float fadescale, float fadeexp, int accum, float termthresh,
154
+ int griddim, int blocksizex, int blocksizey,
155
+
156
+ cudaStream_t stream) {
157
+ dim3 blocksize(blocksizex, blocksizey);
158
+ dim3 gridsize;
159
+ gridsize = dim3(
160
+ (W + blocksize.x - 1) / blocksize.x,
161
+ (H + blocksize.y - 1) / blocksize.y,
162
+ N);
163
+
164
+ std::shared_ptr<PrimTransfDataBase> primtransf_data;
165
+ primtransf_data = std::make_shared<PrimTransfSRT::Data>(PrimTransfSRT::Data{
166
+ PrimTransfDataBase{},
167
+ K, (float3*)primpos, (float3*)grad_primpos,
168
+ K * 3, (float3*)primrot, (float3*)grad_primrot,
169
+ K, (float3*)primscale, (float3*)grad_primscale});
170
+ std::shared_ptr<PrimSamplerDataBase> primsampler_data;
171
+ if (algorithm == 1) {
172
+ primsampler_data = std::make_shared<PrimSamplerTW<true>::Data>(PrimSamplerTW<true>::Data{
173
+ PrimSamplerDataBase{},
174
+ fadescale, fadeexp,
175
+ K * TD * TH * TW * 4, TD, TH, TW, tplate, grad_tplate,
176
+ K * WD * WH * WW * 3, WD, WH, WW, warp, grad_warp});
177
+ } else {
178
+ primsampler_data = std::make_shared<PrimSamplerTW<false>::Data>(PrimSamplerTW<false>::Data{
179
+ PrimSamplerDataBase{},
180
+ fadescale, fadeexp,
181
+ K * TD * TH * TW * 4, TD, TH, TW, tplate, grad_tplate,
182
+ 0, 0, 0, 0, nullptr, nullptr});
183
+ }
184
+ std::shared_ptr<PrimAccumDataBase> primaccum_data = std::make_shared<PrimAccumAdditive::Data>(PrimAccumAdditive::Data{
185
+ PrimAccumDataBase{},
186
+ termthresh, H * W, W, 1, (float4*)rayrgbaim, (float4*)grad_rayrgba, (float3*)raysatim});
187
+
188
+ std::map<int, mapfn_t> dispatcher = {
189
+ {0, make_cudacall(raymarch_subset_backward_kernel<true, 512, 4, raysubset_t, PrimTransfSRT, PrimSamplerTW<false>, PrimAccumAdditive>)},
190
+ {1, make_cudacall(raymarch_subset_backward_kernel<true, 512, 4, raysubset_t, PrimTransfSRT, PrimSamplerTW<true>, PrimAccumAdditive>)}};
191
+
192
+ auto iter = dispatcher.find(algorithm);
193
+ if (iter != dispatcher.end()) {
194
+ (iter->second)(
195
+ gridsize, blocksize, stream,
196
+ N, H, W, K,
197
+ reinterpret_cast<float3 *>(rayposim),
198
+ reinterpret_cast<float3 *>(raydirim),
199
+ stepsize,
200
+ reinterpret_cast<float2 *>(tminmaxim),
201
+ reinterpret_cast<int *>(sortedobjid),
202
+ reinterpret_cast<int2 *>(nodechildren),
203
+ reinterpret_cast<float3 *>(nodeaabb),
204
+ primtransf_data,
205
+ primsampler_data,
206
+ primaccum_data);
207
+ }
208
+ }
dva/mvp/extensions/mvpraymarch/mvpraymarch_subset_kernel.h ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ template<
8
+ int maxhitboxes,
9
+ int nwarps,
10
+ class RaySubsetT=RaySubsetFixedBVH<false, 512, true, PrimTransfSRT>,
11
+ class PrimTransfT=PrimTransfSRT,
12
+ class PrimSamplerT=PrimSamplerTW<false>,
13
+ class PrimAccumT=PrimAccumAdditive>
14
+ __global__ void raymarch_subset_forward_kernel(
15
+ int N, int H, int W, int K,
16
+ float3 * rayposim,
17
+ float3 * raydirim,
18
+ float stepsize,
19
+ float2 * tminmaxim,
20
+ int * sortedobjid,
21
+ int2 * nodechildren,
22
+ float3 * nodeaabb,
23
+ typename PrimTransfT::Data primtransf_data,
24
+ typename PrimSamplerT::Data primsampler_data,
25
+ typename PrimAccumT::Data primaccum_data
26
+ ) {
27
+ int w = blockIdx.x * blockDim.x + threadIdx.x;
28
+ int h = blockIdx.y * blockDim.y + threadIdx.y;
29
+ int n = blockIdx.z;
30
+ bool validthread = (w < W) && (h < H) && (n<N);
31
+
32
+ assert(nwarps == 0 || blockDim.x * blockDim.y / 32 <= nwarps);
33
+ const int warpid = __shfl_sync(0xffffffff, (threadIdx.y * blockDim.x + threadIdx.x) / 32, 0);
34
+ assert(__match_any_sync(0xffffffff, (threadIdx.y * blockDim.x + threadIdx.x) / 32) == 0xffffffff);
35
+
36
+ // warpmask contains the valid threads in the warp
37
+ unsigned warpmask = 0xffffffff;
38
+ n = min(N - 1, n);
39
+ h = min(H - 1, h);
40
+ w = min(W - 1, w);
41
+
42
+ sortedobjid += n * K;
43
+ nodechildren += n * (K + K - 1);
44
+ nodeaabb += n * (K + K - 1) * 2;
45
+
46
+ primtransf_data.n_stride(n);
47
+ primsampler_data.n_stride(n);
48
+ primaccum_data.n_stride(n, h, w);
49
+
50
+ float3 raypos = rayposim[n * H * W + h * W + w];
51
+ float3 raydir = raydirim[n * H * W + h * W + w];
52
+ float2 tminmax = tminmaxim[n * H * W + h * W + w];
53
+
54
+ int hitboxes[nwarps > 0 ? 1 : maxhitboxes];
55
+ __shared__ int hitboxes_sh[nwarps > 0 ? maxhitboxes * nwarps : 1];
56
+ int * hitboxes_ptr = nwarps > 0 ? hitboxes_sh + maxhitboxes * warpid : hitboxes;
57
+ int nhitboxes = 0;
58
+
59
+ // find raytminmax
60
+ float2 rtminmax = make_float2(std::numeric_limits<float>::infinity(), -std::numeric_limits<float>::infinity());
61
+ RaySubsetT::forward(warpmask, K, raypos, raydir, tminmax, rtminmax,
62
+ sortedobjid, nodechildren, nodeaabb,
63
+ primtransf_data, hitboxes_ptr, nhitboxes);
64
+ rtminmax.x = max(rtminmax.x, tminmax.x);
65
+ rtminmax.y = min(rtminmax.y, tminmax.y);
66
+ __syncwarp(warpmask);
67
+
68
+ float t = tminmax.x;
69
+ raypos = raypos + raydir * tminmax.x;
70
+
71
+ int incs = floor((rtminmax.x - t) / stepsize);
72
+ t += incs * stepsize;
73
+ raypos += raydir * incs * stepsize;
74
+
75
+ PrimAccumT pa;
76
+
77
+ while (!__all_sync(warpmask, t > rtminmax.y + 1e-5f || pa.is_done())) {
78
+ for (int ks = 0; ks < nhitboxes; ++ks) {
79
+ int k = hitboxes_ptr[ks];
80
+
81
+ // compute primitive-relative coordinate
82
+ PrimTransfT pt;
83
+ float3 samplepos = pt.forward(primtransf_data, k, raypos);
84
+
85
+ if (pt.valid(samplepos) && !pa.is_done() && t < rtminmax.y + 1e-5f) {
86
+ // sample
87
+ PrimSamplerT ps;
88
+ float4 sample = ps.forward(primsampler_data, k, samplepos);
89
+
90
+ // accumulate
91
+ pa.forward_prim(primaccum_data, sample, stepsize);
92
+ }
93
+ }
94
+
95
+ // update position
96
+ t += stepsize;
97
+ raypos += raydir * stepsize;
98
+ }
99
+
100
+ pa.write(primaccum_data);
101
+ }
102
+
103
+ template <
104
+ bool forwarddir,
105
+ int maxhitboxes,
106
+ int nwarps,
107
+ class RaySubsetT=RaySubsetFixedBVH<false, 512, true, PrimTransfSRT>,
108
+ class PrimTransfT=PrimTransfSRT,
109
+ class PrimSamplerT=PrimSamplerTW<false>,
110
+ class PrimAccumT=PrimAccumAdditive>
111
+ __global__ void raymarch_subset_backward_kernel(
112
+ int N, int H, int W, int K,
113
+ float3 * rayposim,
114
+ float3 * raydirim,
115
+ float stepsize,
116
+ float2 * tminmaxim,
117
+ int * sortedobjid,
118
+ int2 * nodechildren,
119
+ float3 * nodeaabb,
120
+ typename PrimTransfT::Data primtransf_data,
121
+ typename PrimSamplerT::Data primsampler_data,
122
+ typename PrimAccumT::Data primaccum_data
123
+ ) {
124
+ int w = blockIdx.x * blockDim.x + threadIdx.x;
125
+ int h = blockIdx.y * blockDim.y + threadIdx.y;
126
+ int n = blockIdx.z;
127
+ bool validthread = (w < W) && (h < H) && (n<N);
128
+
129
+ assert(nwarps == 0 || blockDim.x * blockDim.y / 32 <= nwarps);
130
+ const int warpid = __shfl_sync(0xffffffff, (threadIdx.y * blockDim.x + threadIdx.x) / 32, 0);
131
+ assert(__match_any_sync(0xffffffff, (threadIdx.y * blockDim.x + threadIdx.x) / 32) == 0xffffffff);
132
+
133
+ // warpmask contains the valid threads in the warp
134
+ unsigned warpmask = 0xffffffff;
135
+ n = min(N - 1, n);
136
+ h = min(H - 1, h);
137
+ w = min(W - 1, w);
138
+
139
+ sortedobjid += n * K;
140
+ nodechildren += n * (K + K - 1);
141
+ nodeaabb += n * (K + K - 1) * 2;
142
+
143
+ primtransf_data.n_stride(n);
144
+ primsampler_data.n_stride(n);
145
+ primaccum_data.n_stride(n, h, w);
146
+
147
+ float3 raypos = rayposim[n * H * W + h * W + w];
148
+ float3 raydir = raydirim[n * H * W + h * W + w];
149
+ float2 tminmax = tminmaxim[n * H * W + h * W + w];
150
+
151
+ PrimAccumT pa;
152
+ pa.read(primaccum_data);
153
+
154
+ int hitboxes[nwarps > 0 ? 1 : maxhitboxes];
155
+ __shared__ int hitboxes_sh[nwarps > 0 ? maxhitboxes * nwarps : 1];
156
+ int * hitboxes_ptr = nwarps > 0 ? hitboxes_sh + maxhitboxes * warpid : hitboxes;
157
+ int nhitboxes = 0;
158
+
159
+ // find raytminmax
160
+ float2 rtminmax = make_float2(std::numeric_limits<float>::infinity(), -std::numeric_limits<float>::infinity());
161
+ RaySubsetT::forward(warpmask, K, raypos, raydir, tminmax, rtminmax,
162
+ sortedobjid, nodechildren, nodeaabb,
163
+ primtransf_data, hitboxes_ptr, nhitboxes);
164
+ rtminmax.x = max(rtminmax.x, tminmax.x);
165
+ rtminmax.y = min(rtminmax.y, tminmax.y);
166
+ __syncwarp(warpmask);
167
+
168
+ // set up raymarching position
169
+ float t = tminmax.x;
170
+ raypos = raypos + raydir * tminmax.x;
171
+
172
+ int incs = floor((rtminmax.x - t) / stepsize);
173
+ t += incs * stepsize;
174
+ raypos += raydir * incs * stepsize;
175
+
176
+ if (!forwarddir) {
177
+ int nsteps = pa.get_nsteps();
178
+ t += nsteps * stepsize;
179
+ raypos += raydir * nsteps * stepsize;
180
+ }
181
+
182
+ while (__any_sync(warpmask, (
183
+ (forwarddir && t < rtminmax.y + 1e-5f ||
184
+ !forwarddir && t > rtminmax.x - 1e-5f) &&
185
+ !pa.is_done()))) {
186
+ for (int ks = 0; ks < nhitboxes; ++ks) {
187
+ int k = hitboxes_ptr[forwarddir ? ks : nhitboxes - ks - 1];
188
+
189
+ PrimTransfT pt;
190
+ float3 samplepos = pt.forward(primtransf_data, k, raypos);
191
+
192
+ bool evalprim = pt.valid(samplepos) && !pa.is_done() && t < rtminmax.y + 1e-5f;
193
+
194
+ float3 dL_samplepos = make_float3(0.f);
195
+ if (evalprim) {
196
+ PrimSamplerT ps;
197
+ float4 sample = ps.forward(primsampler_data, k, samplepos);
198
+
199
+ float4 dL_sample = pa.forwardbackward_prim(primaccum_data, sample, stepsize);
200
+
201
+ dL_samplepos = ps.backward(primsampler_data, k, samplepos, sample, dL_sample, validthread);
202
+ }
203
+
204
+ if (__any_sync(warpmask, evalprim)) {
205
+ pt.backward(primtransf_data, k, samplepos, dL_samplepos, validthread && evalprim);
206
+ }
207
+ }
208
+
209
+ if (forwarddir) {
210
+ t += stepsize;
211
+ raypos += raydir * stepsize;
212
+ } else {
213
+ t -= stepsize;
214
+ raypos -= raydir * stepsize;
215
+ }
216
+ }
217
+ }
218
+
dva/mvp/extensions/mvpraymarch/primaccum.h ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #ifndef MVPRAYMARCHER_PRIMACCUM_H_
8
+ #define MVPRAYMARCHER_PRIMACCUM_H_
9
+
10
+ struct PrimAccumDataBase {
11
+ typedef PrimAccumDataBase base;
12
+ };
13
+
14
+ struct PrimAccumAdditive {
15
+ struct Data : public PrimAccumDataBase {
16
+ float termthresh;
17
+
18
+ int nstride, hstride, wstride;
19
+ float4 * rayrgbaim;
20
+ float4 * grad_rayrgbaim;
21
+ float3 * raysatim;
22
+
23
+ __forceinline__ __device__ void n_stride(int n, int h, int w) {
24
+ rayrgbaim += n * nstride + h * hstride + w * wstride;
25
+ grad_rayrgbaim += n * nstride + h * hstride + w * wstride;
26
+ if (raysatim) {
27
+ raysatim += n * nstride + h * hstride + w * wstride;
28
+ }
29
+ }
30
+ };
31
+
32
+ float4 rayrgba;
33
+ float3 raysat;
34
+ bool sat;
35
+ float4 dL_rayrgba;
36
+
37
+ __forceinline__ __device__ PrimAccumAdditive() :
38
+ rayrgba(make_float4(0.f)),
39
+ raysat(make_float3(-1.f)),
40
+ sat(false) {
41
+ }
42
+
43
+ __forceinline__ __device__ bool is_done() const {
44
+ return sat;
45
+ }
46
+
47
+ __forceinline__ __device__ int get_nsteps() const {
48
+ return 0;
49
+ }
50
+
51
+ __forceinline__ __device__ void write(const Data & data) {
52
+ *data.rayrgbaim = rayrgba;
53
+ if (data.raysatim) {
54
+ *data.raysatim = raysat;
55
+ }
56
+ }
57
+
58
+ __forceinline__ __device__ void read(const Data & data) {
59
+ dL_rayrgba = *data.grad_rayrgbaim;
60
+ raysat = *data.raysatim;
61
+ }
62
+
63
+ __forceinline__ __device__ void forward_prim(const Data & data, float4 sample, float stepsize) {
64
+ // accumulate
65
+ float3 rgb = make_float3(sample);
66
+ float alpha = sample.w;
67
+ float newalpha = rayrgba.w + alpha * stepsize;
68
+ float contrib = fminf(newalpha, 1.f) - rayrgba.w;
69
+
70
+ rayrgba += make_float4(rgb, 1.f) * contrib;
71
+
72
+ if (newalpha >= 1.f) {
73
+ // save saturation point
74
+ if (!sat) {
75
+ raysat = rgb;
76
+ }
77
+ sat = true;
78
+ }
79
+ }
80
+
81
+ __forceinline__ __device__ float4 forwardbackward_prim(const Data & data, float4 sample, float stepsize) {
82
+ float3 rgb = make_float3(sample);
83
+ float4 rgb1 = make_float4(rgb, 1.f);
84
+ sample.w *= stepsize;
85
+
86
+ bool thissat = rayrgba.w + sample.w >= 1.f;
87
+ sat = sat || thissat;
88
+
89
+ float weight = sat ? (1.f - rayrgba.w) : sample.w;
90
+
91
+ float3 dL_rgb = weight * make_float3(dL_rayrgba);
92
+ float dL_alpha = sat ? 0.f :
93
+ stepsize * dot(rgb1 - (raysat.x > -1.f ? make_float4(raysat, 1.f) : make_float4(0.f)), dL_rayrgba);
94
+
95
+ rayrgba += make_float4(rgb, 1.f) * weight;
96
+
97
+ return make_float4(dL_rgb, dL_alpha);
98
+ }
99
+ };
100
+
101
+ #endif
dva/mvp/extensions/mvpraymarch/primsampler.h ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #ifndef MVPRAYMARCHER_PRIMSAMPLER_H_
8
+ #define MVPRAYMARCHER_PRIMSAMPLER_H_
9
+
10
+ struct PrimSamplerDataBase {
11
+ typedef PrimSamplerDataBase base;
12
+ };
13
+
14
+ template<
15
+ bool dowarp,
16
+ template<typename> class GridSamplerT=GridSamplerChlast>
17
+ struct PrimSamplerTW {
18
+ struct Data : public PrimSamplerDataBase {
19
+ float fadescale, fadeexp;
20
+
21
+ int tplate_nstride;
22
+ int TD, TH, TW;
23
+ float * tplate;
24
+ float * grad_tplate;
25
+
26
+ int warp_nstride;
27
+ int WD, WH, WW;
28
+ float * warp;
29
+ float * grad_warp;
30
+
31
+ __forceinline__ __device__ void n_stride(int n) {
32
+ tplate += n * tplate_nstride;
33
+ grad_tplate += n * tplate_nstride;
34
+ warp += n * warp_nstride;
35
+ grad_warp += n * warp_nstride;
36
+ }
37
+ };
38
+
39
+ float fade;
40
+ float * tplate_ptr;
41
+ float * warp_ptr;
42
+ float3 yy1;
43
+
44
+ __forceinline__ __device__ float4 forward(
45
+ const Data & data,
46
+ int k,
47
+ float3 y0) {
48
+ fade = __expf(-data.fadescale * (
49
+ __powf(abs(y0.x), data.fadeexp) +
50
+ __powf(abs(y0.y), data.fadeexp) +
51
+ __powf(abs(y0.z), data.fadeexp)));
52
+
53
+ if (dowarp) {
54
+ warp_ptr = data.warp + (k * 3 * data.WD * data.WH * data.WW);
55
+ yy1 = GridSamplerT<float3>::forward(3, data.WD, data.WH, data.WW, warp_ptr, y0, false);
56
+ } else {
57
+ yy1 = y0;
58
+ }
59
+
60
+ tplate_ptr = data.tplate + (k * 4 * data.TD * data.TH * data.TW);
61
+ float4 sample = GridSamplerT<float4>::forward(4, data.TD, data.TH, data.TW, tplate_ptr, yy1, false);
62
+
63
+ sample.w *= fade;
64
+
65
+ return sample;
66
+ }
67
+
68
+ __forceinline__ __device__ float3 backward(const Data & data, int k, float3 y0,
69
+ float4 sample, float4 dL_sample, bool validthread) {
70
+ float3 dfade_y0 = -(data.fadescale * data.fadeexp) * make_float3(
71
+ __powf(abs(y0.x), data.fadeexp - 1.f) * (y0.x > 0.f ? 1.f : -1.f),
72
+ __powf(abs(y0.y), data.fadeexp - 1.f) * (y0.y > 0.f ? 1.f : -1.f),
73
+ __powf(abs(y0.z), data.fadeexp - 1.f) * (y0.z > 0.f ? 1.f : -1.f));
74
+ float3 dL_y0 = dfade_y0 * sample.w * dL_sample.w;
75
+
76
+ dL_sample.w *= fade;
77
+
78
+ float * grad_tplate_ptr = data.grad_tplate + (k * 4 * data.TD * data.TH * data.TW);
79
+ float3 dL_y1 = GridSamplerT<float4>::backward(4, data.TD, data.TH, data.TW,
80
+ tplate_ptr, grad_tplate_ptr, yy1, validthread ? dL_sample : make_float4(0.f), false);
81
+
82
+ if (dowarp) {
83
+ float * grad_warp_ptr = data.grad_warp + (k * 3 * data.WD * data.WH * data.WW);
84
+ dL_y0 += GridSamplerT<float3>::backward(3, data.WD, data.WH, data.WW,
85
+ warp_ptr, grad_warp_ptr, y0, validthread ? dL_y1 : make_float3(0.f), false);
86
+ } else {
87
+ dL_y0 += dL_y1;
88
+ }
89
+
90
+ return dL_y0;
91
+ }
92
+ };
93
+
94
+ #endif
dva/mvp/extensions/mvpraymarch/primtransf.h ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #ifndef MVPRAYMARCHER_PRIMTRANSF_H_
8
+ #define MVPRAYMARCHER_PRIMTRANSF_H_
9
+
10
+ #include "utils.h"
11
+
12
+ __forceinline__ __device__ void compute_aabb_srt(
13
+ float3 pt, float3 pr0, float3 pr1, float3 pr2, float3 ps,
14
+ float3 & pmin, float3 & pmax) {
15
+ float3 p;
16
+ p = make_float3(-1.f, -1.f, -1.f) / ps;
17
+ p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;
18
+
19
+ pmin = p;
20
+ pmax = p;
21
+
22
+ p = make_float3(1.f, -1.f, -1.f) / ps;
23
+ p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;
24
+
25
+ pmin = fminf(pmin, p);
26
+ pmax = fmaxf(pmax, p);
27
+
28
+ p = make_float3(-1.f, 1.f, -1.f) / ps;
29
+ p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;
30
+
31
+ pmin = fminf(pmin, p);
32
+ pmax = fmaxf(pmax, p);
33
+
34
+ p = make_float3(1.f, 1.f, -1.f) / ps;
35
+ p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;
36
+
37
+ pmin = fminf(pmin, p);
38
+ pmax = fmaxf(pmax, p);
39
+
40
+ p = make_float3(-1.f, -1.f, 1.f) / ps;
41
+ p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;
42
+
43
+ pmin = fminf(pmin, p);
44
+ pmax = fmaxf(pmax, p);
45
+
46
+ p = make_float3(1.f, -1.f, 1.f) / ps;
47
+ p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;
48
+
49
+ pmin = fminf(pmin, p);
50
+ pmax = fmaxf(pmax, p);
51
+
52
+ p = make_float3(-1.f, 1.f, 1.f) / ps;
53
+ p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;
54
+
55
+ pmin = fminf(pmin, p);
56
+ pmax = fmaxf(pmax, p);
57
+
58
+ p = make_float3(1.f, 1.f, 1.f) / ps;
59
+ p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;
60
+
61
+ pmin = fminf(pmin, p);
62
+ pmax = fmaxf(pmax, p);
63
+ }
64
+
65
+ struct PrimTransfDataBase {
66
+ typedef PrimTransfDataBase base;
67
+ };
68
+
69
+ struct PrimTransfSRT {
70
+ struct Data : public PrimTransfDataBase {
71
+ int primpos_nstride;
72
+ float3 * primpos;
73
+ float3 * grad_primpos;
74
+ int primrot_nstride;
75
+ float3 * primrot;
76
+ float3 * grad_primrot;
77
+ int primscale_nstride;
78
+ float3 * primscale;
79
+ float3 * grad_primscale;
80
+
81
+ __forceinline__ __device__ void n_stride(int n) {
82
+ primpos += n * primpos_nstride;
83
+ grad_primpos += n * primpos_nstride;
84
+ primrot += n * primrot_nstride;
85
+ grad_primrot += n * primrot_nstride;
86
+ primscale += n * primscale_nstride;
87
+ grad_primscale += n * primscale_nstride;
88
+ }
89
+
90
+ __forceinline__ __device__ float3 get_center(int n, int k) {
91
+ return primpos[n * primpos_nstride + k];
92
+ }
93
+
94
+ __forceinline__ __device__ void compute_aabb(int n, int k, float3 & pmin, float3 & pmax) {
95
+ float3 pt = primpos[n * primpos_nstride + k];
96
+ float3 pr0 = primrot[n * primrot_nstride + k * 3 + 0];
97
+ float3 pr1 = primrot[n * primrot_nstride + k * 3 + 1];
98
+ float3 pr2 = primrot[n * primrot_nstride + k * 3 + 2];
99
+ float3 ps = primscale[n * primscale_nstride + k];
100
+
101
+ compute_aabb_srt(pt, pr0, pr1, pr2, ps, pmin, pmax);
102
+ }
103
+ };
104
+
105
+ float3 xmt;
106
+ float3 pr0;
107
+ float3 pr1;
108
+ float3 pr2;
109
+ float3 rxmt;
110
+ float3 ps;
111
+
112
+ static __forceinline__ __device__ bool valid(float3 pos) {
113
+ return (
114
+ pos.x > -1.f && pos.x < 1.f &&
115
+ pos.y > -1.f && pos.y < 1.f &&
116
+ pos.z > -1.f && pos.z < 1.f);
117
+ }
118
+
119
+ __forceinline__ __device__ float3 forward(
120
+ const Data & data,
121
+ int k,
122
+ float3 x) {
123
+ float3 pt = data.primpos[k];
124
+ pr0 = data.primrot[(k) * 3 + 0];
125
+ pr1 = data.primrot[(k) * 3 + 1];
126
+ pr2 = data.primrot[(k) * 3 + 2];
127
+ ps = data.primscale[k];
128
+ xmt = x - pt;
129
+ rxmt = pr0 * xmt.x + pr1 * xmt.y + pr2 * xmt.z;
130
+ float3 y0 = rxmt * ps;
131
+ return y0;
132
+ }
133
+
134
+ static __forceinline__ __device__ void forward2(
135
+ const Data & data,
136
+ int k,
137
+ float3 r, float3 d, float3 & rout, float3 & dout) {
138
+ float3 pt = data.primpos[k];
139
+ float3 pr0 = data.primrot[k * 3 + 0];
140
+ float3 pr1 = data.primrot[k * 3 + 1];
141
+ float3 pr2 = data.primrot[k * 3 + 2];
142
+ float3 ps = data.primscale[k];
143
+ float3 xmt = r - pt;
144
+ float3 dmt = d;
145
+ float3 rxmt = pr0 * xmt.x;
146
+ float3 rdmt = pr0 * dmt.x;
147
+ rxmt += pr1 * xmt.y;
148
+ rdmt += pr1 * dmt.y;
149
+ rxmt += pr2 * xmt.z;
150
+ rdmt += pr2 * dmt.z;
151
+ rout = rxmt * ps;
152
+ dout = rdmt * ps;
153
+ }
154
+
155
+ __forceinline__ __device__ void backward(const Data & data, int k, float3 x, float3 dL_y0, bool validthread) {
156
+ fastAtomicAdd((float*)data.grad_primscale + k * 3 + 0, validthread ? rxmt.x * dL_y0.x : 0.f);
157
+ fastAtomicAdd((float*)data.grad_primscale + k * 3 + 1, validthread ? rxmt.y * dL_y0.y : 0.f);
158
+ fastAtomicAdd((float*)data.grad_primscale + k * 3 + 2, validthread ? rxmt.z * dL_y0.z : 0.f);
159
+
160
+ dL_y0 *= ps;
161
+ float3 gpr0 = xmt.x * dL_y0;
162
+ fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 0, validthread ? gpr0.x : 0.f);
163
+ fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 1, validthread ? gpr0.y : 0.f);
164
+ fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 2, validthread ? gpr0.z : 0.f);
165
+
166
+ float3 gpr1 = xmt.y * dL_y0;
167
+ fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 0, validthread ? gpr1.x : 0.f);
168
+ fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 1, validthread ? gpr1.y : 0.f);
169
+ fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 2, validthread ? gpr1.z : 0.f);
170
+
171
+ float3 gpr2 = xmt.z * dL_y0;
172
+ fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 0, validthread ? gpr2.x : 0.f);
173
+ fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 1, validthread ? gpr2.y : 0.f);
174
+ fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 2, validthread ? gpr2.z : 0.f);
175
+
176
+ fastAtomicAdd((float*)data.grad_primpos + k * 3 + 0, validthread ? -dot(pr0, dL_y0) : 0.f);
177
+ fastAtomicAdd((float*)data.grad_primpos + k * 3 + 1, validthread ? -dot(pr1, dL_y0) : 0.f);
178
+ fastAtomicAdd((float*)data.grad_primpos + k * 3 + 2, validthread ? -dot(pr2, dL_y0) : 0.f);
179
+ }
180
+ };
181
+
182
+ #endif
dva/mvp/extensions/mvpraymarch/setup.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from setuptools import setup
8
+
9
+ from torch.utils.cpp_extension import CUDAExtension, BuildExtension
10
+
11
+ if __name__ == "__main__":
12
+ import torch
13
+ setup(
14
+ name="mvpraymarch",
15
+ ext_modules=[
16
+ CUDAExtension(
17
+ "mvpraymarchlib",
18
+ sources=["mvpraymarch.cpp", "mvpraymarch_kernel.cu", "bvh.cu"],
19
+ extra_compile_args={
20
+ "nvcc": [
21
+ "-use_fast_math",
22
+ "-arch=sm_70",
23
+ "-std=c++17",
24
+ "-lineinfo",
25
+ ]
26
+ }
27
+ )
28
+ ],
29
+ cmdclass={"build_ext": BuildExtension}
30
+ )
dva/mvp/extensions/mvpraymarch/utils.h ADDED
@@ -0,0 +1,847 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #ifndef MVPRAYMARCHER_UTILS_H_
8
+ #define MVPRAYMARCHER_UTILS_H_
9
+
10
+ #include <cassert>
11
+ #include <cmath>
12
+
13
+ #include <limits>
14
+
15
+ #include "helper_math.h"
16
+
17
+ static __forceinline__ __device__ float clock_diff(long long int end, long long int start) {
18
+ long long int max_clock = std::numeric_limits<long long int>::max();
19
+ return (end<start? (end + float(max_clock-start)) : float(end-start));
20
+ }
21
+
22
+ static __forceinline__ __device__
23
+ bool allgt(float3 a, float3 b) {
24
+ return a.x >= b.x && a.y >= b.y && a.z >= b.z;
25
+ }
26
+
27
+ static __forceinline__ __device__
28
+ bool alllt(float3 a, float3 b) {
29
+ return a.x <= b.x && a.y <= b.y && a.z <= b.z;
30
+ }
31
+
32
+ static __forceinline__ __device__
33
+ float4 softplus(float4 x) {
34
+ return make_float4(
35
+ x.x > 20.f ? x.x : logf(1.f + expf(x.x)),
36
+ x.y > 20.f ? x.y : logf(1.f + expf(x.y)),
37
+ x.z > 20.f ? x.z : logf(1.f + expf(x.z)),
38
+ x.w > 20.f ? x.w : logf(1.f + expf(x.w)));
39
+ }
40
+
41
+ static __forceinline__ __device__
42
+ float softplus(float x) {
43
+ // that's a neat trick
44
+ return __logf(1.f + __expf(-abs(x))) + max(x, 0.f);
45
+ }
46
+ static __forceinline__ __device__
47
+ float softplus_grad(float x) {
48
+ // that's a neat trick
49
+ float expnabsx = __expf(-abs(x));
50
+ return (0.5f - expnabsx / (1.f + expnabsx)) * copysign(1.f, x) + 0.5f;
51
+ }
52
+
53
+
54
+ static __forceinline__ __device__
55
+ float4 sigmoid(float4 x) {
56
+ return make_float4(
57
+ 1.f / (1.f + expf(-x.x)),
58
+ 1.f / (1.f + expf(-x.y)),
59
+ 1.f / (1.f + expf(-x.z)),
60
+ 1.f / (1.f + expf(-x.w)));
61
+ }
62
+
63
+ // perform reduction on warp, then call atomicAdd for only one lane
64
+ static __forceinline__ __device__ void fastAtomicAdd(float * ptr, float val) {
65
+ for (int offset = 16; offset > 0; offset /= 2) {
66
+ val += __shfl_down_sync(0xffffffff, val, offset);
67
+ }
68
+
69
+ const int laneid = (threadIdx.y * blockDim.x + threadIdx.x) % 32;
70
+ if (laneid == 0) {
71
+ atomicAdd(ptr, val);
72
+ }
73
+ }
74
+
75
+
76
+ static __forceinline__ __device__
77
+ bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
78
+ return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
79
+ }
80
+
81
+ static __forceinline__ __device__
82
+ void safe_add_3d(float *data, int d, int h, int w,
83
+ int sD, int sH, int sW, int D, int H, int W,
84
+ float delta) {
85
+ if (within_bounds_3d(d, h, w, D, H, W)) {
86
+ atomicAdd(data + d * sD + h * sH + w * sW, delta);
87
+ }
88
+ }
89
+
90
+ static __forceinline__ __device__
91
+ void safe_add_3d(float3 *data, int d, int h, int w,
92
+ int sD, int sH, int sW, int D, int H, int W,
93
+ float3 delta) {
94
+ if (within_bounds_3d(d, h, w, D, H, W)) {
95
+ atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 3 + 0, delta.x);
96
+ atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 3 + 1, delta.y);
97
+ atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 3 + 2, delta.z);
98
+ }
99
+ }
100
+
101
+ static __forceinline__ __device__
102
+ void safe_add_3d(float4 *data, int d, int h, int w,
103
+ int sD, int sH, int sW, int D, int H, int W,
104
+ float4 delta) {
105
+ if (within_bounds_3d(d, h, w, D, H, W)) {
106
+ atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 4 + 0, delta.x);
107
+ atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 4 + 1, delta.y);
108
+ atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 4 + 2, delta.z);
109
+ atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 4 + 3, delta.w);
110
+ }
111
+ }
112
+
113
+ static __forceinline__ __device__
114
+ float clip_coordinates(float in, int clip_limit) {
115
+ return ::min(static_cast<float>(clip_limit - 1), ::max(in, 0.f));
116
+ }
117
+
118
+ template <typename scalar_t>
119
+ static __forceinline__ __device__
120
+ float clip_coordinates_set_grad(float in, int clip_limit, scalar_t *grad_in) {
121
+ if (in < 0.f) {
122
+ *grad_in = static_cast<scalar_t>(0);
123
+ return 0.f;
124
+ } else {
125
+ float max = static_cast<float>(clip_limit - 1);
126
+ if (in > max) {
127
+ *grad_in = static_cast<scalar_t>(0);
128
+ return max;
129
+ } else {
130
+ *grad_in = static_cast<scalar_t>(1);
131
+ return in;
132
+ }
133
+ }
134
+ }
135
+
136
+ template<typename out_t>
137
+ static __device__ out_t grid_sample_forward(int C, int inp_D, int inp_H,
138
+ int inp_W, float* vals, float3 pos, bool border) {
139
+ int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H, inp_sC = inp_W * inp_H * inp_D;
140
+ int out_sC = 1;
141
+
142
+ // normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1]
143
+ float ix = max(-10.f, min(10.f, ((pos.x + 1.f) * 0.5f))) * (inp_W - 1);
144
+ float iy = max(-10.f, min(10.f, ((pos.y + 1.f) * 0.5f))) * (inp_H - 1);
145
+ float iz = max(-10.f, min(10.f, ((pos.z + 1.f) * 0.5f))) * (inp_D - 1);
146
+
147
+ if (border) {
148
+ // clip coordinates to image borders
149
+ ix = clip_coordinates(ix, inp_W);
150
+ iy = clip_coordinates(iy, inp_H);
151
+ iz = clip_coordinates(iz, inp_D);
152
+ }
153
+
154
+ // get corner pixel values from (x, y, z)
155
+ // for 4d, we used north-east-south-west
156
+ // for 5d, we add top-bottom
157
+ int ix_tnw = static_cast<int>(::floor(ix));
158
+ int iy_tnw = static_cast<int>(::floor(iy));
159
+ int iz_tnw = static_cast<int>(::floor(iz));
160
+
161
+ int ix_tne = ix_tnw + 1;
162
+ int iy_tne = iy_tnw;
163
+ int iz_tne = iz_tnw;
164
+
165
+ int ix_tsw = ix_tnw;
166
+ int iy_tsw = iy_tnw + 1;
167
+ int iz_tsw = iz_tnw;
168
+
169
+ int ix_tse = ix_tnw + 1;
170
+ int iy_tse = iy_tnw + 1;
171
+ int iz_tse = iz_tnw;
172
+
173
+ int ix_bnw = ix_tnw;
174
+ int iy_bnw = iy_tnw;
175
+ int iz_bnw = iz_tnw + 1;
176
+
177
+ int ix_bne = ix_tnw + 1;
178
+ int iy_bne = iy_tnw;
179
+ int iz_bne = iz_tnw + 1;
180
+
181
+ int ix_bsw = ix_tnw;
182
+ int iy_bsw = iy_tnw + 1;
183
+ int iz_bsw = iz_tnw + 1;
184
+
185
+ int ix_bse = ix_tnw + 1;
186
+ int iy_bse = iy_tnw + 1;
187
+ int iz_bse = iz_tnw + 1;
188
+
189
+ // get surfaces to each neighbor:
190
+ float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
191
+ float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
192
+ float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
193
+ float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
194
+ float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
195
+ float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
196
+ float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
197
+ float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
198
+
199
+ out_t result;
200
+ //auto inp_ptr_NC = input.data + n * inp_sN;
201
+ //auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
202
+ float * inp_ptr_NC = vals;
203
+ float * out_ptr_NCDHW = &result.x;
204
+ for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
205
+ // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne
206
+ // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse
207
+ // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne
208
+ // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse
209
+ *out_ptr_NCDHW = static_cast<float>(0);
210
+ if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
211
+ *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
212
+ }
213
+ if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
214
+ *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
215
+ }
216
+ if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
217
+ *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
218
+ }
219
+ if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
220
+ *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
221
+ }
222
+ if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
223
+ *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
224
+ }
225
+ if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
226
+ *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
227
+ }
228
+ if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
229
+ *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
230
+ }
231
+ if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
232
+ *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
233
+ }
234
+ }
235
+ return result;
236
+ }
237
+
238
+ template<typename out_t>
239
+ static __device__ float3 grid_sample_backward(int C, int inp_D, int inp_H,
240
+ int inp_W, float* vals, float* grad_vals, float3 pos, out_t grad_out,
241
+ bool border) {
242
+ int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H, inp_sC = inp_W * inp_H * inp_D;
243
+ int gInp_sW = 1, gInp_sH = inp_W, gInp_sD = inp_W * inp_H, gInp_sC = inp_W * inp_H * inp_D;
244
+ int gOut_sC = 1;
245
+
246
+ // normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1]
247
+ float ix = max(-10.f, min(10.f, ((pos.x + 1.f) * 0.5f))) * (inp_W - 1);
248
+ float iy = max(-10.f, min(10.f, ((pos.y + 1.f) * 0.5f))) * (inp_H - 1);
249
+ float iz = max(-10.f, min(10.f, ((pos.z + 1.f) * 0.5f))) * (inp_D - 1);
250
+
251
+ float gix_mult = (inp_W - 1.f) / 2;
252
+ float giy_mult = (inp_H - 1.f) / 2;
253
+ float giz_mult = (inp_D - 1.f) / 2;
254
+
255
+ if (border) {
256
+ // clip coordinates to image borders
257
+ ix = clip_coordinates_set_grad(ix, inp_W, &gix_mult);
258
+ iy = clip_coordinates_set_grad(iy, inp_H, &giy_mult);
259
+ iz = clip_coordinates_set_grad(iz, inp_D, &giz_mult);
260
+ }
261
+
262
+ // get corner pixel values from (x, y, z)
263
+ // for 4d, we used north-east-south-west
264
+ // for 5d, we add top-bottom
265
+ int ix_tnw = static_cast<int>(::floor(ix));
266
+ int iy_tnw = static_cast<int>(::floor(iy));
267
+ int iz_tnw = static_cast<int>(::floor(iz));
268
+
269
+ int ix_tne = ix_tnw + 1;
270
+ int iy_tne = iy_tnw;
271
+ int iz_tne = iz_tnw;
272
+
273
+ int ix_tsw = ix_tnw;
274
+ int iy_tsw = iy_tnw + 1;
275
+ int iz_tsw = iz_tnw;
276
+
277
+ int ix_tse = ix_tnw + 1;
278
+ int iy_tse = iy_tnw + 1;
279
+ int iz_tse = iz_tnw;
280
+
281
+ int ix_bnw = ix_tnw;
282
+ int iy_bnw = iy_tnw;
283
+ int iz_bnw = iz_tnw + 1;
284
+
285
+ int ix_bne = ix_tnw + 1;
286
+ int iy_bne = iy_tnw;
287
+ int iz_bne = iz_tnw + 1;
288
+
289
+ int ix_bsw = ix_tnw;
290
+ int iy_bsw = iy_tnw + 1;
291
+ int iz_bsw = iz_tnw + 1;
292
+
293
+ int ix_bse = ix_tnw + 1;
294
+ int iy_bse = iy_tnw + 1;
295
+ int iz_bse = iz_tnw + 1;
296
+
297
+ // get surfaces to each neighbor:
298
+ float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
299
+ float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
300
+ float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
301
+ float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
302
+ float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
303
+ float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
304
+ float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
305
+ float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
306
+
307
+ float gix = static_cast<float>(0), giy = static_cast<float>(0), giz = static_cast<float>(0);
308
+ //float *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
309
+ //float *gInp_ptr_NC = grad_input.data + n * gInp_sN;
310
+ //float *inp_ptr_NC = input.data + n * inp_sN;
311
+ float *gOut_ptr_NCDHW = &grad_out.x;
312
+ float *gInp_ptr_NC = grad_vals;
313
+ float *inp_ptr_NC = vals;
314
+ // calculate bilinear weighted pixel value and set output pixel
315
+ for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
316
+ float gOut = *gOut_ptr_NCDHW;
317
+
318
+ // calculate and set grad_input
319
+ safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut);
320
+ safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut);
321
+ safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut);
322
+ safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut);
323
+ safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut);
324
+ safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut);
325
+ safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut);
326
+ safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut);
327
+
328
+ // calculate grad_grid
329
+ if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
330
+ float tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
331
+ gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut;
332
+ giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut;
333
+ giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut;
334
+ }
335
+ if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
336
+ float tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
337
+ gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut;
338
+ giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut;
339
+ giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut;
340
+ }
341
+ if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
342
+ float tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
343
+ gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut;
344
+ giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut;
345
+ giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut;
346
+ }
347
+ if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
348
+ float tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
349
+ gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut;
350
+ giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut;
351
+ giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut;
352
+ }
353
+ if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
354
+ float bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
355
+ gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut;
356
+ giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut;
357
+ giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut;
358
+ }
359
+ if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
360
+ float bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
361
+ gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut;
362
+ giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut;
363
+ giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut;
364
+ }
365
+ if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
366
+ float bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
367
+ gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut;
368
+ giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut;
369
+ giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut;
370
+ }
371
+ if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
372
+ float bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
373
+ gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut;
374
+ giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut;
375
+ giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut;
376
+ }
377
+ }
378
+
379
+ return make_float3(gix_mult * gix, giy_mult * giy, giz_mult * giz);
380
+ }
381
+
382
+ // this dummy struct necessary because c++ is dumb
383
+ template<typename out_t>
384
+ struct GridSampler {
385
+ static __forceinline__ __device__ out_t forward(int C, int inp_D, int inp_H, int inp_W,
386
+ float* vals, float3 pos, bool border) {
387
+ return grid_sample_forward<out_t>(C, inp_D, inp_H, inp_W, vals, pos, border);
388
+ }
389
+
390
+ static __forceinline__ __device__ float3 backward(int C, int inp_D, int inp_H, int inp_W,
391
+ float* vals, float* grad_vals, float3 pos, out_t grad_out, bool border) {
392
+ return grid_sample_backward<out_t>(C, inp_D, inp_H, inp_W, vals, grad_vals, pos, grad_out, border);
393
+ }
394
+ };
395
+
396
+ //template <typename T>
397
+ //__device__ void cswap ( T& a, T& b ) {
398
+ // T c(a); a=b; b=c;
399
+ //}
400
+
401
+ static __forceinline__ __device__
402
+ int within_bounds_3d_ind(int d, int h, int w, int D, int H, int W) {
403
+ return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W ? ((d * H) + h) * W + w : -1;
404
+ }
405
+
406
+ template<class out_t>
407
+ static __device__ out_t grid_sample_chlast_forward(int, int inp_D, int inp_H,
408
+ int inp_W, float * vals, float3 pos, bool border) {
409
+ int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H;
410
+
411
+ // normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1]
412
+ float ix = max(-100.f, min(100.f, ((pos.x + 1.f) / 2))) * (inp_W - 1);
413
+ float iy = max(-100.f, min(100.f, ((pos.y + 1.f) / 2))) * (inp_H - 1);
414
+ float iz = max(-100.f, min(100.f, ((pos.z + 1.f) / 2))) * (inp_D - 1);
415
+
416
+ if (border) {
417
+ // clip coordinates to image borders
418
+ ix = clip_coordinates(ix, inp_W);
419
+ iy = clip_coordinates(iy, inp_H);
420
+ iz = clip_coordinates(iz, inp_D);
421
+ }
422
+
423
+ // get corner pixel values from (x, y, z)
424
+ // for 4d, we used north-east-south-west
425
+ // for 5d, we add top-bottom
426
+ int ix_tnw = static_cast<int>(::floor(ix));
427
+ int iy_tnw = static_cast<int>(::floor(iy));
428
+ int iz_tnw = static_cast<int>(::floor(iz));
429
+
430
+ int ix_tne = ix_tnw + 1;
431
+ int iy_tne = iy_tnw;
432
+ int iz_tne = iz_tnw;
433
+
434
+ int ix_tsw = ix_tnw;
435
+ int iy_tsw = iy_tnw + 1;
436
+ int iz_tsw = iz_tnw;
437
+
438
+ int ix_tse = ix_tnw + 1;
439
+ int iy_tse = iy_tnw + 1;
440
+ int iz_tse = iz_tnw;
441
+
442
+ int ix_bnw = ix_tnw;
443
+ int iy_bnw = iy_tnw;
444
+ int iz_bnw = iz_tnw + 1;
445
+
446
+ int ix_bne = ix_tnw + 1;
447
+ int iy_bne = iy_tnw;
448
+ int iz_bne = iz_tnw + 1;
449
+
450
+ int ix_bsw = ix_tnw;
451
+ int iy_bsw = iy_tnw + 1;
452
+ int iz_bsw = iz_tnw + 1;
453
+
454
+ int ix_bse = ix_tnw + 1;
455
+ int iy_bse = iy_tnw + 1;
456
+ int iz_bse = iz_tnw + 1;
457
+
458
+ // get surfaces to each neighbor:
459
+ float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
460
+ float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
461
+ float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
462
+ float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
463
+ float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
464
+ float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
465
+ float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
466
+ float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
467
+
468
+ out_t result;
469
+ memset(&result, 0, sizeof(out_t));
470
+ out_t * inp_ptr_NC = (out_t*)vals;
471
+ out_t * out_ptr_NCDHW = &result;
472
+ {
473
+ if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
474
+ *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
475
+ }
476
+ if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
477
+ *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
478
+ }
479
+ if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
480
+ *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
481
+ }
482
+ if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
483
+ *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
484
+ }
485
+ if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
486
+ *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
487
+ }
488
+ if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
489
+ *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
490
+ }
491
+ if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
492
+ *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
493
+ }
494
+ if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
495
+ *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
496
+ }
497
+ }
498
+
499
+ return result;
500
+ }
501
+
502
+ template<typename out_t>
503
+ static __device__ float3 grid_sample_chlast_backward(int, int inp_D, int inp_H,
504
+ int inp_W, float* vals, float* grad_vals, float3 pos, out_t grad_out,
505
+ bool border) {
506
+ int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H;
507
+ int gInp_sW = 1, gInp_sH = inp_W, gInp_sD = inp_W * inp_H;
508
+
509
+ // normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1]
510
+ float ix = max(-100.f, min(100.f, ((pos.x + 1.f) / 2))) * (inp_W - 1);
511
+ float iy = max(-100.f, min(100.f, ((pos.y + 1.f) / 2))) * (inp_H - 1);
512
+ float iz = max(-100.f, min(100.f, ((pos.z + 1.f) / 2))) * (inp_D - 1);
513
+
514
+ float gix_mult = (inp_W - 1.f) / 2;
515
+ float giy_mult = (inp_H - 1.f) / 2;
516
+ float giz_mult = (inp_D - 1.f) / 2;
517
+
518
+ if (border) {
519
+ // clip coordinates to image borders
520
+ ix = clip_coordinates_set_grad(ix, inp_W, &gix_mult);
521
+ iy = clip_coordinates_set_grad(iy, inp_H, &giy_mult);
522
+ iz = clip_coordinates_set_grad(iz, inp_D, &giz_mult);
523
+ }
524
+
525
+ // get corner pixel values from (x, y, z)
526
+ // for 4d, we used north-east-south-west
527
+ // for 5d, we add top-bottom
528
+ int ix_tnw = static_cast<int>(::floor(ix));
529
+ int iy_tnw = static_cast<int>(::floor(iy));
530
+ int iz_tnw = static_cast<int>(::floor(iz));
531
+
532
+ int ix_tne = ix_tnw + 1;
533
+ int iy_tne = iy_tnw;
534
+ int iz_tne = iz_tnw;
535
+
536
+ int ix_tsw = ix_tnw;
537
+ int iy_tsw = iy_tnw + 1;
538
+ int iz_tsw = iz_tnw;
539
+
540
+ int ix_tse = ix_tnw + 1;
541
+ int iy_tse = iy_tnw + 1;
542
+ int iz_tse = iz_tnw;
543
+
544
+ int ix_bnw = ix_tnw;
545
+ int iy_bnw = iy_tnw;
546
+ int iz_bnw = iz_tnw + 1;
547
+
548
+ int ix_bne = ix_tnw + 1;
549
+ int iy_bne = iy_tnw;
550
+ int iz_bne = iz_tnw + 1;
551
+
552
+ int ix_bsw = ix_tnw;
553
+ int iy_bsw = iy_tnw + 1;
554
+ int iz_bsw = iz_tnw + 1;
555
+
556
+ int ix_bse = ix_tnw + 1;
557
+ int iy_bse = iy_tnw + 1;
558
+ int iz_bse = iz_tnw + 1;
559
+
560
+ // get surfaces to each neighbor:
561
+ float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
562
+ float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
563
+ float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
564
+ float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
565
+ float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
566
+ float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
567
+ float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
568
+ float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
569
+
570
+ float gix = static_cast<float>(0), giy = static_cast<float>(0), giz = static_cast<float>(0);
571
+ out_t *gOut_ptr_NCDHW = &grad_out;
572
+ out_t *gInp_ptr_NC = (out_t*)grad_vals;
573
+ out_t *inp_ptr_NC = (out_t*)vals;
574
+
575
+ // calculate bilinear weighted pixel value and set output pixel
576
+ {
577
+ out_t gOut = *gOut_ptr_NCDHW;
578
+
579
+ // calculate and set grad_input
580
+ safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut);
581
+ safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut);
582
+ safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut);
583
+ safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut);
584
+ safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut);
585
+ safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut);
586
+ safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut);
587
+ safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut);
588
+
589
+ // calculate grad_grid
590
+ if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
591
+ out_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
592
+ gix -= (iy_bse - iy) * (iz_bse - iz) * dot(tnw_val, gOut);
593
+ giy -= (ix_bse - ix) * (iz_bse - iz) * dot(tnw_val, gOut);
594
+ giz -= (ix_bse - ix) * (iy_bse - iy) * dot(tnw_val, gOut);
595
+ }
596
+ if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
597
+ out_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
598
+ gix += (iy_bsw - iy) * (iz_bsw - iz) * dot(tne_val, gOut);
599
+ giy -= (ix - ix_bsw) * (iz_bsw - iz) * dot(tne_val, gOut);
600
+ giz -= (ix - ix_bsw) * (iy_bsw - iy) * dot(tne_val, gOut);
601
+ }
602
+ if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
603
+ out_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
604
+ gix -= (iy - iy_bne) * (iz_bne - iz) * dot(tsw_val, gOut);
605
+ giy += (ix_bne - ix) * (iz_bne - iz) * dot(tsw_val, gOut);
606
+ giz -= (ix_bne - ix) * (iy - iy_bne) * dot(tsw_val, gOut);
607
+ }
608
+ if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
609
+ out_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
610
+ gix += (iy - iy_bnw) * (iz_bnw - iz) * dot(tse_val, gOut);
611
+ giy += (ix - ix_bnw) * (iz_bnw - iz) * dot(tse_val, gOut);
612
+ giz -= (ix - ix_bnw) * (iy - iy_bnw) * dot(tse_val, gOut);
613
+ }
614
+ if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
615
+ out_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
616
+ gix -= (iy_tse - iy) * (iz - iz_tse) * dot(bnw_val, gOut);
617
+ giy -= (ix_tse - ix) * (iz - iz_tse) * dot(bnw_val, gOut);
618
+ giz += (ix_tse - ix) * (iy_tse - iy) * dot(bnw_val, gOut);
619
+ }
620
+ if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
621
+ out_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
622
+ gix += (iy_tsw - iy) * (iz - iz_tsw) * dot(bne_val, gOut);
623
+ giy -= (ix - ix_tsw) * (iz - iz_tsw) * dot(bne_val, gOut);
624
+ giz += (ix - ix_tsw) * (iy_tsw - iy) * dot(bne_val, gOut);
625
+ }
626
+ if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
627
+ out_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
628
+ gix -= (iy - iy_tne) * (iz - iz_tne) * dot(bsw_val, gOut);
629
+ giy += (ix_tne - ix) * (iz - iz_tne) * dot(bsw_val, gOut);
630
+ giz += (ix_tne - ix) * (iy - iy_tne) * dot(bsw_val, gOut);
631
+ }
632
+ if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
633
+ out_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
634
+ gix += (iy - iy_tnw) * (iz - iz_tnw) * dot(bse_val, gOut);
635
+ giy += (ix - ix_tnw) * (iz - iz_tnw) * dot(bse_val, gOut);
636
+ giz += (ix - ix_tnw) * (iy - iy_tnw) * dot(bse_val, gOut);
637
+ }
638
+ }
639
+
640
+ return make_float3(gix_mult * gix, giy_mult * giy, giz_mult * giz);
641
+ }
642
+
643
+ template<typename out_t>
644
+ struct GridSamplerChlast {
645
+ static __forceinline__ __device__ out_t forward(int C, int inp_D, int inp_H, int inp_W,
646
+ float* vals, float3 pos, bool border) {
647
+ return grid_sample_chlast_forward<out_t>(C, inp_D, inp_H, inp_W, vals, pos, border);
648
+ }
649
+
650
+ static __forceinline__ __device__ float3 backward(int C, int inp_D, int inp_H, int inp_W,
651
+ float* vals, float* grad_vals, float3 pos, out_t grad_out, bool border) {
652
+ return grid_sample_chlast_backward<out_t>(C, inp_D, inp_H, inp_W, vals, grad_vals, pos, grad_out, border);
653
+ }
654
+ };
655
+
656
+
657
+ inline __host__ __device__ float min_component(float3 a) {
658
+ return fminf(fminf(a.x,a.y),a.z);
659
+ }
660
+
661
+ inline __host__ __device__ float max_component(float3 a) {
662
+ return fmaxf(fmaxf(a.x,a.y),a.z);
663
+ }
664
+
665
+ inline __host__ __device__ float3 abs(float3 a) {
666
+ return make_float3(abs(a.x), abs(a.y), abs(a.z));
667
+ }
668
+
669
+ __forceinline__ __device__ bool ray_aabb_hit(float3 p0, float3 p1, float3 raypos, float3 raydir) {
670
+ float3 t0 = (p0 - raypos) / raydir;
671
+ float3 t1 = (p1 - raypos) / raydir;
672
+ float3 tmin = fminf(t0,t1), tmax = fmaxf(t0,t1);
673
+
674
+ return max_component(tmin) <= min_component(tmax);
675
+ }
676
+
677
+ __forceinline__ __device__ bool ray_aabb_hit_ird(float3 p0, float3 p1, float3 raypos, float3 ird) {
678
+ float3 t0 = (p0 - raypos) * ird;
679
+ float3 t1 = (p1 - raypos) * ird;
680
+ float3 tmin = fminf(t0,t1), tmax = fmaxf(t0,t1);
681
+
682
+ return max_component(tmin) <= min_component(tmax);
683
+
684
+ }
685
+ __forceinline__ __device__ void ray_aabb_hit_ird_tminmax(float3 p0, float3 p1,
686
+ float3 raypos, float3 ird, float &otmin, float &otmax) {
687
+ float3 t0 = (p0 - raypos) * ird;
688
+ float3 t1 = (p1 - raypos) * ird;
689
+ float3 tmin = fminf(t0,t1), tmax = fmaxf(t0,t1);
690
+ tmin = fminf(t0,t1);
691
+ tmax = fmaxf(t0,t1);
692
+ otmin = max_component(tmin);
693
+ otmax = min_component(tmax);
694
+ }
695
+
696
+ inline __device__ bool aabb_intersect(float3 p0, float3 p1, float3 r0, float3 rd, float &tmin, float &tmax) {
697
+ float tymin, tymax, tzmin, tzmax;
698
+ const float3 bounds[2] = {p0, p1};
699
+ float3 ird = 1.0f/rd;
700
+ int sx = (ird.x<0) ? 1 : 0;
701
+ int sy = (ird.y<0) ? 1 : 0;
702
+ int sz = (ird.z<0) ? 1 : 0;
703
+ tmin = (bounds[sx].x - r0.x) * ird.x;
704
+ tmax = (bounds[1-sx].x - r0.x) * ird.x;
705
+ tymin = (bounds[sy].y - r0.y) * ird.y;
706
+ tymax = (bounds[1-sy].y - r0.y) * ird.y;
707
+
708
+ if ((tmin > tymax) || (tymin > tmax))
709
+ return false;
710
+ if (tymin > tmin)
711
+ tmin = tymin;
712
+ if (tymax < tmax)
713
+ tmax = tymax;
714
+
715
+ tzmin = (bounds[sz].z - r0.z) * ird.z;
716
+ tzmax = (bounds[1-sz].z - r0.z) * ird.z;
717
+
718
+ if ((tmin > tzmax) || (tzmin > tmax))
719
+ return false;
720
+ if (tzmin > tmin)
721
+ tmin = tzmin;
722
+ if (tzmax < tmax)
723
+ tmax = tzmax;
724
+
725
+ return true;
726
+ }
727
+
728
+ template<bool sortboxes, int maxhitboxes, bool sync, class PrimTransfT>
729
+ static __forceinline__ __device__ void ray_subset_fixedbvh(
730
+ unsigned warpmask,
731
+ int K,
732
+ float3 raypos,
733
+ float3 raydir,
734
+ float2 tminmax,
735
+ float2 &rtminmax,
736
+ int * sortedobjid,
737
+ int2 * nodechildren,
738
+ float3 * nodeaabb,
739
+ const typename PrimTransfT::Data & primtransf_data,
740
+ int *hitboxes,
741
+ int & num) {
742
+ float3 iraydir = 1.0f/raydir;
743
+ int stack[64];
744
+ int* stack_ptr = stack;
745
+ *stack_ptr++ = -1;
746
+ int node = 0;
747
+ do {
748
+ // check if we're in a leaf
749
+ if (node >= (K - 1)) {
750
+ {
751
+ int k = node - (K - 1);
752
+
753
+ float3 r0, rd;
754
+ PrimTransfT::forward2(primtransf_data, k, raypos, raydir, r0, rd);
755
+
756
+ float3 ird = 1.0f/rd;
757
+ float3 t0 = (-1.f - r0) * ird;
758
+ float3 t1 = (1.f - r0) * ird;
759
+ float3 tmin = fminf(t0,t1), tmax = fmaxf(t0,t1);
760
+
761
+ float trmin = max_component(tmin);
762
+ float trmax = min_component(tmax);
763
+
764
+ bool intersection = trmin <= trmax;
765
+
766
+ if (intersection) {
767
+ // hit
768
+ rtminmax.x = fminf(rtminmax.x, trmin);
769
+ rtminmax.y = fmaxf(rtminmax.y, trmax);
770
+ }
771
+
772
+ if (sync) {
773
+ intersection = __any_sync(warpmask, intersection);
774
+ }
775
+
776
+ if (intersection) {
777
+ if (sortboxes) {
778
+ if (num < maxhitboxes) {
779
+ int j = num - 1;
780
+ while (j >= 0 && hitboxes[j] > k) {
781
+ hitboxes[j + 1] = hitboxes[j];
782
+ j = j - 1;
783
+ }
784
+ hitboxes[j + 1] = k;
785
+ num++;
786
+ }
787
+ } else {
788
+ if (num < maxhitboxes) {
789
+ hitboxes[num++] = k;
790
+ }
791
+ }
792
+ }
793
+ }
794
+
795
+ node = *--stack_ptr;
796
+ } else {
797
+ int2 children = make_int2(node * 2 + 1, node * 2 + 2);
798
+
799
+ // check if we're in each child's bbox
800
+ float3 * nodeaabb_ptr = nodeaabb + children.x * 2;
801
+ bool traverse_l = ray_aabb_hit_ird(nodeaabb_ptr[0], nodeaabb_ptr[1], raypos, iraydir);
802
+ bool traverse_r = ray_aabb_hit_ird(nodeaabb_ptr[2], nodeaabb_ptr[3], raypos, iraydir);
803
+
804
+ if (sync) {
805
+ traverse_l = __any_sync(warpmask, traverse_l);
806
+ traverse_r = __any_sync(warpmask, traverse_r);
807
+ }
808
+
809
+ // update stack
810
+ if (!traverse_l && !traverse_r) {
811
+ node = *--stack_ptr;
812
+ } else {
813
+ node = traverse_l ? children.x : children.y;
814
+ if (traverse_l && traverse_r) {
815
+ *stack_ptr++ = children.y;
816
+ }
817
+ }
818
+
819
+ if (sync) {
820
+ __syncwarp(warpmask);
821
+ }
822
+ }
823
+ } while (node != -1);
824
+ }
825
+
826
+ template<bool sortboxes, int maxhitboxes, bool sync, class PrimTransfT>
827
+ struct RaySubsetFixedBVH {
828
+ static __forceinline__ __device__ void forward(
829
+ unsigned warpmask,
830
+ int K,
831
+ float3 raypos,
832
+ float3 raydir,
833
+ float2 tminmax,
834
+ float2 &rtminmax,
835
+ int * sortedobjid,
836
+ int2 * nodechildren,
837
+ float3 * nodeaabb,
838
+ const typename PrimTransfT::Data & primtransf_data,
839
+ int *hitboxes,
840
+ int & num) {
841
+ ray_subset_fixedbvh<sortboxes, maxhitboxes, sync, PrimTransfT>(
842
+ warpmask, K, raypos, raydir, tminmax, rtminmax,
843
+ sortedobjid, nodechildren, nodeaabb, primtransf_data, hitboxes, num);
844
+ }
845
+ };
846
+
847
+ #endif
dva/mvp/extensions/utils/helper_math.h ADDED
@@ -0,0 +1,1453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Copyright 1993-2013 NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * Please refer to the NVIDIA end user license agreement (EULA) associated
5
+ * with this source code for terms and conditions that govern your use of
6
+ * this software. Any use, reproduction, disclosure, or distribution of
7
+ * this software and related documentation outside the terms of the EULA
8
+ * is strictly prohibited.
9
+ *
10
+ */
11
+
12
+ /*
13
+ * This file implements common mathematical operations on vector types
14
+ * (float3, float4 etc.) since these are not provided as standard by CUDA.
15
+ *
16
+ * The syntax is modeled on the Cg standard library.
17
+ *
18
+ * This is part of the Helper library includes
19
+ *
20
+ * Thanks to Linh Hah for additions and fixes.
21
+ */
22
+
23
+ #ifndef HELPER_MATH_H
24
+ #define HELPER_MATH_H
25
+
26
+ #include "cuda_runtime.h"
27
+
28
+ typedef unsigned int uint;
29
+ typedef unsigned short ushort;
30
+
31
+ #ifndef EXIT_WAIVED
32
+ #define EXIT_WAIVED 2
33
+ #endif
34
+
35
+ #ifndef __CUDACC__
36
+ #include <math.h>
37
+
38
+ ////////////////////////////////////////////////////////////////////////////////
39
+ // host implementations of CUDA functions
40
+ ////////////////////////////////////////////////////////////////////////////////
41
+
42
+ inline float fminf(float a, float b)
43
+ {
44
+ return a < b ? a : b;
45
+ }
46
+
47
+ inline float fmaxf(float a, float b)
48
+ {
49
+ return a > b ? a : b;
50
+ }
51
+
52
+ inline int max(int a, int b)
53
+ {
54
+ return a > b ? a : b;
55
+ }
56
+
57
+ inline int min(int a, int b)
58
+ {
59
+ return a < b ? a : b;
60
+ }
61
+
62
+ inline float rsqrtf(float x)
63
+ {
64
+ return 1.0f / sqrtf(x);
65
+ }
66
+ #endif
67
+
68
+ ////////////////////////////////////////////////////////////////////////////////
69
+ // constructors
70
+ ////////////////////////////////////////////////////////////////////////////////
71
+
72
+ inline __host__ __device__ float2 make_float2(float s)
73
+ {
74
+ return make_float2(s, s);
75
+ }
76
+ inline __host__ __device__ float2 make_float2(float3 a)
77
+ {
78
+ return make_float2(a.x, a.y);
79
+ }
80
+ inline __host__ __device__ float2 make_float2(int2 a)
81
+ {
82
+ return make_float2(float(a.x), float(a.y));
83
+ }
84
+ inline __host__ __device__ float2 make_float2(uint2 a)
85
+ {
86
+ return make_float2(float(a.x), float(a.y));
87
+ }
88
+
89
+ inline __host__ __device__ int2 make_int2(int s)
90
+ {
91
+ return make_int2(s, s);
92
+ }
93
+ inline __host__ __device__ int2 make_int2(int3 a)
94
+ {
95
+ return make_int2(a.x, a.y);
96
+ }
97
+ inline __host__ __device__ int2 make_int2(uint2 a)
98
+ {
99
+ return make_int2(int(a.x), int(a.y));
100
+ }
101
+ inline __host__ __device__ int2 make_int2(float2 a)
102
+ {
103
+ return make_int2(int(a.x), int(a.y));
104
+ }
105
+
106
+ inline __host__ __device__ uint2 make_uint2(uint s)
107
+ {
108
+ return make_uint2(s, s);
109
+ }
110
+ inline __host__ __device__ uint2 make_uint2(uint3 a)
111
+ {
112
+ return make_uint2(a.x, a.y);
113
+ }
114
+ inline __host__ __device__ uint2 make_uint2(int2 a)
115
+ {
116
+ return make_uint2(uint(a.x), uint(a.y));
117
+ }
118
+
119
+ inline __host__ __device__ float3 make_float3(float s)
120
+ {
121
+ return make_float3(s, s, s);
122
+ }
123
+ inline __host__ __device__ float3 make_float3(float2 a)
124
+ {
125
+ return make_float3(a.x, a.y, 0.0f);
126
+ }
127
+ inline __host__ __device__ float3 make_float3(float2 a, float s)
128
+ {
129
+ return make_float3(a.x, a.y, s);
130
+ }
131
+ inline __host__ __device__ float3 make_float3(float4 a)
132
+ {
133
+ return make_float3(a.x, a.y, a.z);
134
+ }
135
+ inline __host__ __device__ float3 make_float3(int3 a)
136
+ {
137
+ return make_float3(float(a.x), float(a.y), float(a.z));
138
+ }
139
+ inline __host__ __device__ float3 make_float3(uint3 a)
140
+ {
141
+ return make_float3(float(a.x), float(a.y), float(a.z));
142
+ }
143
+
144
+ inline __host__ __device__ int3 make_int3(int s)
145
+ {
146
+ return make_int3(s, s, s);
147
+ }
148
+ inline __host__ __device__ int3 make_int3(int2 a)
149
+ {
150
+ return make_int3(a.x, a.y, 0);
151
+ }
152
+ inline __host__ __device__ int3 make_int3(int2 a, int s)
153
+ {
154
+ return make_int3(a.x, a.y, s);
155
+ }
156
+ inline __host__ __device__ int3 make_int3(uint3 a)
157
+ {
158
+ return make_int3(int(a.x), int(a.y), int(a.z));
159
+ }
160
+ inline __host__ __device__ int3 make_int3(float3 a)
161
+ {
162
+ return make_int3(int(a.x), int(a.y), int(a.z));
163
+ }
164
+
165
+ inline __host__ __device__ uint3 make_uint3(uint s)
166
+ {
167
+ return make_uint3(s, s, s);
168
+ }
169
+ inline __host__ __device__ uint3 make_uint3(uint2 a)
170
+ {
171
+ return make_uint3(a.x, a.y, 0);
172
+ }
173
+ inline __host__ __device__ uint3 make_uint3(uint2 a, uint s)
174
+ {
175
+ return make_uint3(a.x, a.y, s);
176
+ }
177
+ inline __host__ __device__ uint3 make_uint3(uint4 a)
178
+ {
179
+ return make_uint3(a.x, a.y, a.z);
180
+ }
181
+ inline __host__ __device__ uint3 make_uint3(int3 a)
182
+ {
183
+ return make_uint3(uint(a.x), uint(a.y), uint(a.z));
184
+ }
185
+
186
+ inline __host__ __device__ float4 make_float4(float s)
187
+ {
188
+ return make_float4(s, s, s, s);
189
+ }
190
+ inline __host__ __device__ float4 make_float4(float3 a)
191
+ {
192
+ return make_float4(a.x, a.y, a.z, 0.0f);
193
+ }
194
+ inline __host__ __device__ float4 make_float4(float3 a, float w)
195
+ {
196
+ return make_float4(a.x, a.y, a.z, w);
197
+ }
198
+ inline __host__ __device__ float4 make_float4(int4 a)
199
+ {
200
+ return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
201
+ }
202
+ inline __host__ __device__ float4 make_float4(uint4 a)
203
+ {
204
+ return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
205
+ }
206
+
207
+ inline __host__ __device__ int4 make_int4(int s)
208
+ {
209
+ return make_int4(s, s, s, s);
210
+ }
211
+ inline __host__ __device__ int4 make_int4(int3 a)
212
+ {
213
+ return make_int4(a.x, a.y, a.z, 0);
214
+ }
215
+ inline __host__ __device__ int4 make_int4(int3 a, int w)
216
+ {
217
+ return make_int4(a.x, a.y, a.z, w);
218
+ }
219
+ inline __host__ __device__ int4 make_int4(uint4 a)
220
+ {
221
+ return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
222
+ }
223
+ inline __host__ __device__ int4 make_int4(float4 a)
224
+ {
225
+ return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
226
+ }
227
+
228
+
229
+ inline __host__ __device__ uint4 make_uint4(uint s)
230
+ {
231
+ return make_uint4(s, s, s, s);
232
+ }
233
+ inline __host__ __device__ uint4 make_uint4(uint3 a)
234
+ {
235
+ return make_uint4(a.x, a.y, a.z, 0);
236
+ }
237
+ inline __host__ __device__ uint4 make_uint4(uint3 a, uint w)
238
+ {
239
+ return make_uint4(a.x, a.y, a.z, w);
240
+ }
241
+ inline __host__ __device__ uint4 make_uint4(int4 a)
242
+ {
243
+ return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w));
244
+ }
245
+
246
+ ////////////////////////////////////////////////////////////////////////////////
247
+ // negate
248
+ ////////////////////////////////////////////////////////////////////////////////
249
+
250
+ inline __host__ __device__ float2 operator-(float2 &a)
251
+ {
252
+ return make_float2(-a.x, -a.y);
253
+ }
254
+ inline __host__ __device__ int2 operator-(int2 &a)
255
+ {
256
+ return make_int2(-a.x, -a.y);
257
+ }
258
+ inline __host__ __device__ float3 operator-(float3 &a)
259
+ {
260
+ return make_float3(-a.x, -a.y, -a.z);
261
+ }
262
+ inline __host__ __device__ int3 operator-(int3 &a)
263
+ {
264
+ return make_int3(-a.x, -a.y, -a.z);
265
+ }
266
+ inline __host__ __device__ float4 operator-(float4 &a)
267
+ {
268
+ return make_float4(-a.x, -a.y, -a.z, -a.w);
269
+ }
270
+ inline __host__ __device__ int4 operator-(int4 &a)
271
+ {
272
+ return make_int4(-a.x, -a.y, -a.z, -a.w);
273
+ }
274
+
275
+ ////////////////////////////////////////////////////////////////////////////////
276
+ // addition
277
+ ////////////////////////////////////////////////////////////////////////////////
278
+
279
+ inline __host__ __device__ float2 operator+(float2 a, float2 b)
280
+ {
281
+ return make_float2(a.x + b.x, a.y + b.y);
282
+ }
283
+ inline __host__ __device__ void operator+=(float2 &a, float2 b)
284
+ {
285
+ a.x += b.x;
286
+ a.y += b.y;
287
+ }
288
+ inline __host__ __device__ float2 operator+(float2 a, float b)
289
+ {
290
+ return make_float2(a.x + b, a.y + b);
291
+ }
292
+ inline __host__ __device__ float2 operator+(float b, float2 a)
293
+ {
294
+ return make_float2(a.x + b, a.y + b);
295
+ }
296
+ inline __host__ __device__ void operator+=(float2 &a, float b)
297
+ {
298
+ a.x += b;
299
+ a.y += b;
300
+ }
301
+
302
+ inline __host__ __device__ int2 operator+(int2 a, int2 b)
303
+ {
304
+ return make_int2(a.x + b.x, a.y + b.y);
305
+ }
306
+ inline __host__ __device__ void operator+=(int2 &a, int2 b)
307
+ {
308
+ a.x += b.x;
309
+ a.y += b.y;
310
+ }
311
+ inline __host__ __device__ int2 operator+(int2 a, int b)
312
+ {
313
+ return make_int2(a.x + b, a.y + b);
314
+ }
315
+ inline __host__ __device__ int2 operator+(int b, int2 a)
316
+ {
317
+ return make_int2(a.x + b, a.y + b);
318
+ }
319
+ inline __host__ __device__ void operator+=(int2 &a, int b)
320
+ {
321
+ a.x += b;
322
+ a.y += b;
323
+ }
324
+
325
+ inline __host__ __device__ uint2 operator+(uint2 a, uint2 b)
326
+ {
327
+ return make_uint2(a.x + b.x, a.y + b.y);
328
+ }
329
+ inline __host__ __device__ void operator+=(uint2 &a, uint2 b)
330
+ {
331
+ a.x += b.x;
332
+ a.y += b.y;
333
+ }
334
+ inline __host__ __device__ uint2 operator+(uint2 a, uint b)
335
+ {
336
+ return make_uint2(a.x + b, a.y + b);
337
+ }
338
+ inline __host__ __device__ uint2 operator+(uint b, uint2 a)
339
+ {
340
+ return make_uint2(a.x + b, a.y + b);
341
+ }
342
+ inline __host__ __device__ void operator+=(uint2 &a, uint b)
343
+ {
344
+ a.x += b;
345
+ a.y += b;
346
+ }
347
+
348
+
349
+ inline __host__ __device__ float3 operator+(float3 a, float3 b)
350
+ {
351
+ return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
352
+ }
353
+ inline __host__ __device__ void operator+=(float3 &a, float3 b)
354
+ {
355
+ a.x += b.x;
356
+ a.y += b.y;
357
+ a.z += b.z;
358
+ }
359
+ inline __host__ __device__ float3 operator+(float3 a, float b)
360
+ {
361
+ return make_float3(a.x + b, a.y + b, a.z + b);
362
+ }
363
+ inline __host__ __device__ void operator+=(float3 &a, float b)
364
+ {
365
+ a.x += b;
366
+ a.y += b;
367
+ a.z += b;
368
+ }
369
+
370
+ inline __host__ __device__ int3 operator+(int3 a, int3 b)
371
+ {
372
+ return make_int3(a.x + b.x, a.y + b.y, a.z + b.z);
373
+ }
374
+ inline __host__ __device__ void operator+=(int3 &a, int3 b)
375
+ {
376
+ a.x += b.x;
377
+ a.y += b.y;
378
+ a.z += b.z;
379
+ }
380
+ inline __host__ __device__ int3 operator+(int3 a, int b)
381
+ {
382
+ return make_int3(a.x + b, a.y + b, a.z + b);
383
+ }
384
+ inline __host__ __device__ void operator+=(int3 &a, int b)
385
+ {
386
+ a.x += b;
387
+ a.y += b;
388
+ a.z += b;
389
+ }
390
+
391
+ inline __host__ __device__ uint3 operator+(uint3 a, uint3 b)
392
+ {
393
+ return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z);
394
+ }
395
+ inline __host__ __device__ void operator+=(uint3 &a, uint3 b)
396
+ {
397
+ a.x += b.x;
398
+ a.y += b.y;
399
+ a.z += b.z;
400
+ }
401
+ inline __host__ __device__ uint3 operator+(uint3 a, uint b)
402
+ {
403
+ return make_uint3(a.x + b, a.y + b, a.z + b);
404
+ }
405
+ inline __host__ __device__ void operator+=(uint3 &a, uint b)
406
+ {
407
+ a.x += b;
408
+ a.y += b;
409
+ a.z += b;
410
+ }
411
+
412
+ inline __host__ __device__ int3 operator+(int b, int3 a)
413
+ {
414
+ return make_int3(a.x + b, a.y + b, a.z + b);
415
+ }
416
+ inline __host__ __device__ uint3 operator+(uint b, uint3 a)
417
+ {
418
+ return make_uint3(a.x + b, a.y + b, a.z + b);
419
+ }
420
+ inline __host__ __device__ float3 operator+(float b, float3 a)
421
+ {
422
+ return make_float3(a.x + b, a.y + b, a.z + b);
423
+ }
424
+
425
+ inline __host__ __device__ float4 operator+(float4 a, float4 b)
426
+ {
427
+ return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
428
+ }
429
+ inline __host__ __device__ void operator+=(float4 &a, float4 b)
430
+ {
431
+ a.x += b.x;
432
+ a.y += b.y;
433
+ a.z += b.z;
434
+ a.w += b.w;
435
+ }
436
+ inline __host__ __device__ float4 operator+(float4 a, float b)
437
+ {
438
+ return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
439
+ }
440
+ inline __host__ __device__ float4 operator+(float b, float4 a)
441
+ {
442
+ return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
443
+ }
444
+ inline __host__ __device__ void operator+=(float4 &a, float b)
445
+ {
446
+ a.x += b;
447
+ a.y += b;
448
+ a.z += b;
449
+ a.w += b;
450
+ }
451
+
452
+ inline __host__ __device__ int4 operator+(int4 a, int4 b)
453
+ {
454
+ return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
455
+ }
456
+ inline __host__ __device__ void operator+=(int4 &a, int4 b)
457
+ {
458
+ a.x += b.x;
459
+ a.y += b.y;
460
+ a.z += b.z;
461
+ a.w += b.w;
462
+ }
463
+ inline __host__ __device__ int4 operator+(int4 a, int b)
464
+ {
465
+ return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
466
+ }
467
+ inline __host__ __device__ int4 operator+(int b, int4 a)
468
+ {
469
+ return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
470
+ }
471
+ inline __host__ __device__ void operator+=(int4 &a, int b)
472
+ {
473
+ a.x += b;
474
+ a.y += b;
475
+ a.z += b;
476
+ a.w += b;
477
+ }
478
+
479
+ inline __host__ __device__ uint4 operator+(uint4 a, uint4 b)
480
+ {
481
+ return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
482
+ }
483
+ inline __host__ __device__ void operator+=(uint4 &a, uint4 b)
484
+ {
485
+ a.x += b.x;
486
+ a.y += b.y;
487
+ a.z += b.z;
488
+ a.w += b.w;
489
+ }
490
+ inline __host__ __device__ uint4 operator+(uint4 a, uint b)
491
+ {
492
+ return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
493
+ }
494
+ inline __host__ __device__ uint4 operator+(uint b, uint4 a)
495
+ {
496
+ return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
497
+ }
498
+ inline __host__ __device__ void operator+=(uint4 &a, uint b)
499
+ {
500
+ a.x += b;
501
+ a.y += b;
502
+ a.z += b;
503
+ a.w += b;
504
+ }
505
+
506
+ ////////////////////////////////////////////////////////////////////////////////
507
+ // subtract
508
+ ////////////////////////////////////////////////////////////////////////////////
509
+
510
+ inline __host__ __device__ float2 operator-(float2 a, float2 b)
511
+ {
512
+ return make_float2(a.x - b.x, a.y - b.y);
513
+ }
514
+ inline __host__ __device__ void operator-=(float2 &a, float2 b)
515
+ {
516
+ a.x -= b.x;
517
+ a.y -= b.y;
518
+ }
519
+ inline __host__ __device__ float2 operator-(float2 a, float b)
520
+ {
521
+ return make_float2(a.x - b, a.y - b);
522
+ }
523
+ inline __host__ __device__ float2 operator-(float b, float2 a)
524
+ {
525
+ return make_float2(b - a.x, b - a.y);
526
+ }
527
+ inline __host__ __device__ void operator-=(float2 &a, float b)
528
+ {
529
+ a.x -= b;
530
+ a.y -= b;
531
+ }
532
+
533
+ inline __host__ __device__ int2 operator-(int2 a, int2 b)
534
+ {
535
+ return make_int2(a.x - b.x, a.y - b.y);
536
+ }
537
+ inline __host__ __device__ void operator-=(int2 &a, int2 b)
538
+ {
539
+ a.x -= b.x;
540
+ a.y -= b.y;
541
+ }
542
+ inline __host__ __device__ int2 operator-(int2 a, int b)
543
+ {
544
+ return make_int2(a.x - b, a.y - b);
545
+ }
546
+ inline __host__ __device__ int2 operator-(int b, int2 a)
547
+ {
548
+ return make_int2(b - a.x, b - a.y);
549
+ }
550
+ inline __host__ __device__ void operator-=(int2 &a, int b)
551
+ {
552
+ a.x -= b;
553
+ a.y -= b;
554
+ }
555
+
556
+ inline __host__ __device__ uint2 operator-(uint2 a, uint2 b)
557
+ {
558
+ return make_uint2(a.x - b.x, a.y - b.y);
559
+ }
560
+ inline __host__ __device__ void operator-=(uint2 &a, uint2 b)
561
+ {
562
+ a.x -= b.x;
563
+ a.y -= b.y;
564
+ }
565
+ inline __host__ __device__ uint2 operator-(uint2 a, uint b)
566
+ {
567
+ return make_uint2(a.x - b, a.y - b);
568
+ }
569
+ inline __host__ __device__ uint2 operator-(uint b, uint2 a)
570
+ {
571
+ return make_uint2(b - a.x, b - a.y);
572
+ }
573
+ inline __host__ __device__ void operator-=(uint2 &a, uint b)
574
+ {
575
+ a.x -= b;
576
+ a.y -= b;
577
+ }
578
+
579
+ inline __host__ __device__ float3 operator-(float3 a, float3 b)
580
+ {
581
+ return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
582
+ }
583
+ inline __host__ __device__ void operator-=(float3 &a, float3 b)
584
+ {
585
+ a.x -= b.x;
586
+ a.y -= b.y;
587
+ a.z -= b.z;
588
+ }
589
+ inline __host__ __device__ float3 operator-(float3 a, float b)
590
+ {
591
+ return make_float3(a.x - b, a.y - b, a.z - b);
592
+ }
593
+ inline __host__ __device__ float3 operator-(float b, float3 a)
594
+ {
595
+ return make_float3(b - a.x, b - a.y, b - a.z);
596
+ }
597
+ inline __host__ __device__ void operator-=(float3 &a, float b)
598
+ {
599
+ a.x -= b;
600
+ a.y -= b;
601
+ a.z -= b;
602
+ }
603
+
604
+ inline __host__ __device__ int3 operator-(int3 a, int3 b)
605
+ {
606
+ return make_int3(a.x - b.x, a.y - b.y, a.z - b.z);
607
+ }
608
+ inline __host__ __device__ void operator-=(int3 &a, int3 b)
609
+ {
610
+ a.x -= b.x;
611
+ a.y -= b.y;
612
+ a.z -= b.z;
613
+ }
614
+ inline __host__ __device__ int3 operator-(int3 a, int b)
615
+ {
616
+ return make_int3(a.x - b, a.y - b, a.z - b);
617
+ }
618
+ inline __host__ __device__ int3 operator-(int b, int3 a)
619
+ {
620
+ return make_int3(b - a.x, b - a.y, b - a.z);
621
+ }
622
+ inline __host__ __device__ void operator-=(int3 &a, int b)
623
+ {
624
+ a.x -= b;
625
+ a.y -= b;
626
+ a.z -= b;
627
+ }
628
+
629
+ inline __host__ __device__ uint3 operator-(uint3 a, uint3 b)
630
+ {
631
+ return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z);
632
+ }
633
+ inline __host__ __device__ void operator-=(uint3 &a, uint3 b)
634
+ {
635
+ a.x -= b.x;
636
+ a.y -= b.y;
637
+ a.z -= b.z;
638
+ }
639
+ inline __host__ __device__ uint3 operator-(uint3 a, uint b)
640
+ {
641
+ return make_uint3(a.x - b, a.y - b, a.z - b);
642
+ }
643
+ inline __host__ __device__ uint3 operator-(uint b, uint3 a)
644
+ {
645
+ return make_uint3(b - a.x, b - a.y, b - a.z);
646
+ }
647
+ inline __host__ __device__ void operator-=(uint3 &a, uint b)
648
+ {
649
+ a.x -= b;
650
+ a.y -= b;
651
+ a.z -= b;
652
+ }
653
+
654
+ inline __host__ __device__ float4 operator-(float4 a, float4 b)
655
+ {
656
+ return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
657
+ }
658
+ inline __host__ __device__ void operator-=(float4 &a, float4 b)
659
+ {
660
+ a.x -= b.x;
661
+ a.y -= b.y;
662
+ a.z -= b.z;
663
+ a.w -= b.w;
664
+ }
665
+ inline __host__ __device__ float4 operator-(float4 a, float b)
666
+ {
667
+ return make_float4(a.x - b, a.y - b, a.z - b, a.w - b);
668
+ }
669
+ inline __host__ __device__ void operator-=(float4 &a, float b)
670
+ {
671
+ a.x -= b;
672
+ a.y -= b;
673
+ a.z -= b;
674
+ a.w -= b;
675
+ }
676
+
677
+ inline __host__ __device__ int4 operator-(int4 a, int4 b)
678
+ {
679
+ return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
680
+ }
681
+ inline __host__ __device__ void operator-=(int4 &a, int4 b)
682
+ {
683
+ a.x -= b.x;
684
+ a.y -= b.y;
685
+ a.z -= b.z;
686
+ a.w -= b.w;
687
+ }
688
+ inline __host__ __device__ int4 operator-(int4 a, int b)
689
+ {
690
+ return make_int4(a.x - b, a.y - b, a.z - b, a.w - b);
691
+ }
692
+ inline __host__ __device__ int4 operator-(int b, int4 a)
693
+ {
694
+ return make_int4(b - a.x, b - a.y, b - a.z, b - a.w);
695
+ }
696
+ inline __host__ __device__ void operator-=(int4 &a, int b)
697
+ {
698
+ a.x -= b;
699
+ a.y -= b;
700
+ a.z -= b;
701
+ a.w -= b;
702
+ }
703
+
704
+ inline __host__ __device__ uint4 operator-(uint4 a, uint4 b)
705
+ {
706
+ return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
707
+ }
708
+ inline __host__ __device__ void operator-=(uint4 &a, uint4 b)
709
+ {
710
+ a.x -= b.x;
711
+ a.y -= b.y;
712
+ a.z -= b.z;
713
+ a.w -= b.w;
714
+ }
715
+ inline __host__ __device__ uint4 operator-(uint4 a, uint b)
716
+ {
717
+ return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b);
718
+ }
719
+ inline __host__ __device__ uint4 operator-(uint b, uint4 a)
720
+ {
721
+ return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w);
722
+ }
723
+ inline __host__ __device__ void operator-=(uint4 &a, uint b)
724
+ {
725
+ a.x -= b;
726
+ a.y -= b;
727
+ a.z -= b;
728
+ a.w -= b;
729
+ }
730
+
731
+ ////////////////////////////////////////////////////////////////////////////////
732
+ // multiply
733
+ ////////////////////////////////////////////////////////////////////////////////
734
+
735
+ inline __host__ __device__ float2 operator*(float2 a, float2 b)
736
+ {
737
+ return make_float2(a.x * b.x, a.y * b.y);
738
+ }
739
+ inline __host__ __device__ void operator*=(float2 &a, float2 b)
740
+ {
741
+ a.x *= b.x;
742
+ a.y *= b.y;
743
+ }
744
+ inline __host__ __device__ float2 operator*(float2 a, float b)
745
+ {
746
+ return make_float2(a.x * b, a.y * b);
747
+ }
748
+ inline __host__ __device__ float2 operator*(float b, float2 a)
749
+ {
750
+ return make_float2(b * a.x, b * a.y);
751
+ }
752
+ inline __host__ __device__ void operator*=(float2 &a, float b)
753
+ {
754
+ a.x *= b;
755
+ a.y *= b;
756
+ }
757
+
758
+ inline __host__ __device__ int2 operator*(int2 a, int2 b)
759
+ {
760
+ return make_int2(a.x * b.x, a.y * b.y);
761
+ }
762
+ inline __host__ __device__ void operator*=(int2 &a, int2 b)
763
+ {
764
+ a.x *= b.x;
765
+ a.y *= b.y;
766
+ }
767
+ inline __host__ __device__ int2 operator*(int2 a, int b)
768
+ {
769
+ return make_int2(a.x * b, a.y * b);
770
+ }
771
+ inline __host__ __device__ int2 operator*(int b, int2 a)
772
+ {
773
+ return make_int2(b * a.x, b * a.y);
774
+ }
775
+ inline __host__ __device__ void operator*=(int2 &a, int b)
776
+ {
777
+ a.x *= b;
778
+ a.y *= b;
779
+ }
780
+
781
+ inline __host__ __device__ uint2 operator*(uint2 a, uint2 b)
782
+ {
783
+ return make_uint2(a.x * b.x, a.y * b.y);
784
+ }
785
+ inline __host__ __device__ void operator*=(uint2 &a, uint2 b)
786
+ {
787
+ a.x *= b.x;
788
+ a.y *= b.y;
789
+ }
790
+ inline __host__ __device__ uint2 operator*(uint2 a, uint b)
791
+ {
792
+ return make_uint2(a.x * b, a.y * b);
793
+ }
794
+ inline __host__ __device__ uint2 operator*(uint b, uint2 a)
795
+ {
796
+ return make_uint2(b * a.x, b * a.y);
797
+ }
798
+ inline __host__ __device__ void operator*=(uint2 &a, uint b)
799
+ {
800
+ a.x *= b;
801
+ a.y *= b;
802
+ }
803
+
804
+ inline __host__ __device__ float3 operator*(float3 a, float3 b)
805
+ {
806
+ return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
807
+ }
808
+ inline __host__ __device__ void operator*=(float3 &a, float3 b)
809
+ {
810
+ a.x *= b.x;
811
+ a.y *= b.y;
812
+ a.z *= b.z;
813
+ }
814
+ inline __host__ __device__ float3 operator*(float3 a, float b)
815
+ {
816
+ return make_float3(a.x * b, a.y * b, a.z * b);
817
+ }
818
+ inline __host__ __device__ float3 operator*(float b, float3 a)
819
+ {
820
+ return make_float3(b * a.x, b * a.y, b * a.z);
821
+ }
822
+ inline __host__ __device__ void operator*=(float3 &a, float b)
823
+ {
824
+ a.x *= b;
825
+ a.y *= b;
826
+ a.z *= b;
827
+ }
828
+
829
+ inline __host__ __device__ int3 operator*(int3 a, int3 b)
830
+ {
831
+ return make_int3(a.x * b.x, a.y * b.y, a.z * b.z);
832
+ }
833
+ inline __host__ __device__ void operator*=(int3 &a, int3 b)
834
+ {
835
+ a.x *= b.x;
836
+ a.y *= b.y;
837
+ a.z *= b.z;
838
+ }
839
+ inline __host__ __device__ int3 operator*(int3 a, int b)
840
+ {
841
+ return make_int3(a.x * b, a.y * b, a.z * b);
842
+ }
843
+ inline __host__ __device__ int3 operator*(int b, int3 a)
844
+ {
845
+ return make_int3(b * a.x, b * a.y, b * a.z);
846
+ }
847
+ inline __host__ __device__ void operator*=(int3 &a, int b)
848
+ {
849
+ a.x *= b;
850
+ a.y *= b;
851
+ a.z *= b;
852
+ }
853
+
854
+ inline __host__ __device__ uint3 operator*(uint3 a, uint3 b)
855
+ {
856
+ return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z);
857
+ }
858
+ inline __host__ __device__ void operator*=(uint3 &a, uint3 b)
859
+ {
860
+ a.x *= b.x;
861
+ a.y *= b.y;
862
+ a.z *= b.z;
863
+ }
864
+ inline __host__ __device__ uint3 operator*(uint3 a, uint b)
865
+ {
866
+ return make_uint3(a.x * b, a.y * b, a.z * b);
867
+ }
868
+ inline __host__ __device__ uint3 operator*(uint b, uint3 a)
869
+ {
870
+ return make_uint3(b * a.x, b * a.y, b * a.z);
871
+ }
872
+ inline __host__ __device__ void operator*=(uint3 &a, uint b)
873
+ {
874
+ a.x *= b;
875
+ a.y *= b;
876
+ a.z *= b;
877
+ }
878
+
879
+ inline __host__ __device__ float4 operator*(float4 a, float4 b)
880
+ {
881
+ return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
882
+ }
883
+ inline __host__ __device__ void operator*=(float4 &a, float4 b)
884
+ {
885
+ a.x *= b.x;
886
+ a.y *= b.y;
887
+ a.z *= b.z;
888
+ a.w *= b.w;
889
+ }
890
+ inline __host__ __device__ float4 operator*(float4 a, float b)
891
+ {
892
+ return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
893
+ }
894
+ inline __host__ __device__ float4 operator*(float b, float4 a)
895
+ {
896
+ return make_float4(b * a.x, b * a.y, b * a.z, b * a.w);
897
+ }
898
+ inline __host__ __device__ void operator*=(float4 &a, float b)
899
+ {
900
+ a.x *= b;
901
+ a.y *= b;
902
+ a.z *= b;
903
+ a.w *= b;
904
+ }
905
+
906
+ inline __host__ __device__ int4 operator*(int4 a, int4 b)
907
+ {
908
+ return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
909
+ }
910
+ inline __host__ __device__ void operator*=(int4 &a, int4 b)
911
+ {
912
+ a.x *= b.x;
913
+ a.y *= b.y;
914
+ a.z *= b.z;
915
+ a.w *= b.w;
916
+ }
917
+ inline __host__ __device__ int4 operator*(int4 a, int b)
918
+ {
919
+ return make_int4(a.x * b, a.y * b, a.z * b, a.w * b);
920
+ }
921
+ inline __host__ __device__ int4 operator*(int b, int4 a)
922
+ {
923
+ return make_int4(b * a.x, b * a.y, b * a.z, b * a.w);
924
+ }
925
+ inline __host__ __device__ void operator*=(int4 &a, int b)
926
+ {
927
+ a.x *= b;
928
+ a.y *= b;
929
+ a.z *= b;
930
+ a.w *= b;
931
+ }
932
+
933
+ inline __host__ __device__ uint4 operator*(uint4 a, uint4 b)
934
+ {
935
+ return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
936
+ }
937
+ inline __host__ __device__ void operator*=(uint4 &a, uint4 b)
938
+ {
939
+ a.x *= b.x;
940
+ a.y *= b.y;
941
+ a.z *= b.z;
942
+ a.w *= b.w;
943
+ }
944
+ inline __host__ __device__ uint4 operator*(uint4 a, uint b)
945
+ {
946
+ return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b);
947
+ }
948
+ inline __host__ __device__ uint4 operator*(uint b, uint4 a)
949
+ {
950
+ return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w);
951
+ }
952
+ inline __host__ __device__ void operator*=(uint4 &a, uint b)
953
+ {
954
+ a.x *= b;
955
+ a.y *= b;
956
+ a.z *= b;
957
+ a.w *= b;
958
+ }
959
+
960
+ ////////////////////////////////////////////////////////////////////////////////
961
+ // divide
962
+ ////////////////////////////////////////////////////////////////////////////////
963
+
964
+ inline __host__ __device__ float2 operator/(float2 a, float2 b)
965
+ {
966
+ return make_float2(a.x / b.x, a.y / b.y);
967
+ }
968
+ inline __host__ __device__ void operator/=(float2 &a, float2 b)
969
+ {
970
+ a.x /= b.x;
971
+ a.y /= b.y;
972
+ }
973
+ inline __host__ __device__ float2 operator/(float2 a, float b)
974
+ {
975
+ return make_float2(a.x / b, a.y / b);
976
+ }
977
+ inline __host__ __device__ void operator/=(float2 &a, float b)
978
+ {
979
+ a.x /= b;
980
+ a.y /= b;
981
+ }
982
+ inline __host__ __device__ float2 operator/(float b, float2 a)
983
+ {
984
+ return make_float2(b / a.x, b / a.y);
985
+ }
986
+
987
+ inline __host__ __device__ float3 operator/(float3 a, float3 b)
988
+ {
989
+ return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
990
+ }
991
+ inline __host__ __device__ void operator/=(float3 &a, float3 b)
992
+ {
993
+ a.x /= b.x;
994
+ a.y /= b.y;
995
+ a.z /= b.z;
996
+ }
997
+ inline __host__ __device__ float3 operator/(float3 a, float b)
998
+ {
999
+ return make_float3(a.x / b, a.y / b, a.z / b);
1000
+ }
1001
+ inline __host__ __device__ void operator/=(float3 &a, float b)
1002
+ {
1003
+ a.x /= b;
1004
+ a.y /= b;
1005
+ a.z /= b;
1006
+ }
1007
+ inline __host__ __device__ float3 operator/(float b, float3 a)
1008
+ {
1009
+ return make_float3(b / a.x, b / a.y, b / a.z);
1010
+ }
1011
+
1012
+ inline __host__ __device__ float4 operator/(float4 a, float4 b)
1013
+ {
1014
+ return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
1015
+ }
1016
+ inline __host__ __device__ void operator/=(float4 &a, float4 b)
1017
+ {
1018
+ a.x /= b.x;
1019
+ a.y /= b.y;
1020
+ a.z /= b.z;
1021
+ a.w /= b.w;
1022
+ }
1023
+ inline __host__ __device__ float4 operator/(float4 a, float b)
1024
+ {
1025
+ return make_float4(a.x / b, a.y / b, a.z / b, a.w / b);
1026
+ }
1027
+ inline __host__ __device__ void operator/=(float4 &a, float b)
1028
+ {
1029
+ a.x /= b;
1030
+ a.y /= b;
1031
+ a.z /= b;
1032
+ a.w /= b;
1033
+ }
1034
+ inline __host__ __device__ float4 operator/(float b, float4 a)
1035
+ {
1036
+ return make_float4(b / a.x, b / a.y, b / a.z, b / a.w);
1037
+ }
1038
+
1039
+ ////////////////////////////////////////////////////////////////////////////////
1040
+ // min
1041
+ ////////////////////////////////////////////////////////////////////////////////
1042
+
1043
+ inline __host__ __device__ float2 fminf(float2 a, float2 b)
1044
+ {
1045
+ return make_float2(fminf(a.x,b.x), fminf(a.y,b.y));
1046
+ }
1047
+ inline __host__ __device__ float3 fminf(float3 a, float3 b)
1048
+ {
1049
+ return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z));
1050
+ }
1051
+ inline __host__ __device__ float4 fminf(float4 a, float4 b)
1052
+ {
1053
+ return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), fminf(a.w,b.w));
1054
+ }
1055
+
1056
+ inline __host__ __device__ int2 min(int2 a, int2 b)
1057
+ {
1058
+ return make_int2(min(a.x,b.x), min(a.y,b.y));
1059
+ }
1060
+ inline __host__ __device__ int3 min(int3 a, int3 b)
1061
+ {
1062
+ return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
1063
+ }
1064
+ inline __host__ __device__ int4 min(int4 a, int4 b)
1065
+ {
1066
+ return make_int4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
1067
+ }
1068
+
1069
+ inline __host__ __device__ uint2 min(uint2 a, uint2 b)
1070
+ {
1071
+ return make_uint2(min(a.x,b.x), min(a.y,b.y));
1072
+ }
1073
+ inline __host__ __device__ uint3 min(uint3 a, uint3 b)
1074
+ {
1075
+ return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
1076
+ }
1077
+ inline __host__ __device__ uint4 min(uint4 a, uint4 b)
1078
+ {
1079
+ return make_uint4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
1080
+ }
1081
+
1082
+ ////////////////////////////////////////////////////////////////////////////////
1083
+ // max
1084
+ ////////////////////////////////////////////////////////////////////////////////
1085
+
1086
+ inline __host__ __device__ float2 fmaxf(float2 a, float2 b)
1087
+ {
1088
+ return make_float2(fmaxf(a.x,b.x), fmaxf(a.y,b.y));
1089
+ }
1090
+ inline __host__ __device__ float3 fmaxf(float3 a, float3 b)
1091
+ {
1092
+ return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z));
1093
+ }
1094
+ inline __host__ __device__ float4 fmaxf(float4 a, float4 b)
1095
+ {
1096
+ return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w));
1097
+ }
1098
+
1099
+ inline __host__ __device__ int2 max(int2 a, int2 b)
1100
+ {
1101
+ return make_int2(max(a.x,b.x), max(a.y,b.y));
1102
+ }
1103
+ inline __host__ __device__ int3 max(int3 a, int3 b)
1104
+ {
1105
+ return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
1106
+ }
1107
+ inline __host__ __device__ int4 max(int4 a, int4 b)
1108
+ {
1109
+ return make_int4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
1110
+ }
1111
+
1112
+ inline __host__ __device__ uint2 max(uint2 a, uint2 b)
1113
+ {
1114
+ return make_uint2(max(a.x,b.x), max(a.y,b.y));
1115
+ }
1116
+ inline __host__ __device__ uint3 max(uint3 a, uint3 b)
1117
+ {
1118
+ return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
1119
+ }
1120
+ inline __host__ __device__ uint4 max(uint4 a, uint4 b)
1121
+ {
1122
+ return make_uint4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
1123
+ }
1124
+
1125
+ ////////////////////////////////////////////////////////////////////////////////
1126
+ // lerp
1127
+ // - linear interpolation between a and b, based on value t in [0, 1] range
1128
+ ////////////////////////////////////////////////////////////////////////////////
1129
+
1130
+ inline __device__ __host__ float lerp(float a, float b, float t)
1131
+ {
1132
+ return a + t*(b-a);
1133
+ }
1134
+ inline __device__ __host__ float2 lerp(float2 a, float2 b, float t)
1135
+ {
1136
+ return a + t*(b-a);
1137
+ }
1138
+ inline __device__ __host__ float3 lerp(float3 a, float3 b, float t)
1139
+ {
1140
+ return a + t*(b-a);
1141
+ }
1142
+ inline __device__ __host__ float4 lerp(float4 a, float4 b, float t)
1143
+ {
1144
+ return a + t*(b-a);
1145
+ }
1146
+
1147
+ ////////////////////////////////////////////////////////////////////////////////
1148
+ // clamp
1149
+ // - clamp the value v to be in the range [a, b]
1150
+ ////////////////////////////////////////////////////////////////////////////////
1151
+
1152
+ inline __device__ __host__ float clamp(float f, float a, float b)
1153
+ {
1154
+ return fmaxf(a, fminf(f, b));
1155
+ }
1156
+ inline __device__ __host__ int clamp(int f, int a, int b)
1157
+ {
1158
+ return max(a, min(f, b));
1159
+ }
1160
+ inline __device__ __host__ uint clamp(uint f, uint a, uint b)
1161
+ {
1162
+ return max(a, min(f, b));
1163
+ }
1164
+
1165
+ inline __device__ __host__ float2 clamp(float2 v, float a, float b)
1166
+ {
1167
+ return make_float2(clamp(v.x, a, b), clamp(v.y, a, b));
1168
+ }
1169
+ inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b)
1170
+ {
1171
+ return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
1172
+ }
1173
+ inline __device__ __host__ float3 clamp(float3 v, float a, float b)
1174
+ {
1175
+ return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
1176
+ }
1177
+ inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b)
1178
+ {
1179
+ return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
1180
+ }
1181
+ inline __device__ __host__ float4 clamp(float4 v, float a, float b)
1182
+ {
1183
+ return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
1184
+ }
1185
+ inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b)
1186
+ {
1187
+ return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
1188
+ }
1189
+
1190
+ inline __device__ __host__ int2 clamp(int2 v, int a, int b)
1191
+ {
1192
+ return make_int2(clamp(v.x, a, b), clamp(v.y, a, b));
1193
+ }
1194
+ inline __device__ __host__ int2 clamp(int2 v, int2 a, int2 b)
1195
+ {
1196
+ return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
1197
+ }
1198
+ inline __device__ __host__ int3 clamp(int3 v, int a, int b)
1199
+ {
1200
+ return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
1201
+ }
1202
+ inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b)
1203
+ {
1204
+ return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
1205
+ }
1206
+ inline __device__ __host__ int4 clamp(int4 v, int a, int b)
1207
+ {
1208
+ return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
1209
+ }
1210
+ inline __device__ __host__ int4 clamp(int4 v, int4 a, int4 b)
1211
+ {
1212
+ return make_int4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
1213
+ }
1214
+
1215
+ inline __device__ __host__ uint2 clamp(uint2 v, uint a, uint b)
1216
+ {
1217
+ return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b));
1218
+ }
1219
+ inline __device__ __host__ uint2 clamp(uint2 v, uint2 a, uint2 b)
1220
+ {
1221
+ return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
1222
+ }
1223
+ inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b)
1224
+ {
1225
+ return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
1226
+ }
1227
+ inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b)
1228
+ {
1229
+ return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
1230
+ }
1231
+ inline __device__ __host__ uint4 clamp(uint4 v, uint a, uint b)
1232
+ {
1233
+ return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
1234
+ }
1235
+ inline __device__ __host__ uint4 clamp(uint4 v, uint4 a, uint4 b)
1236
+ {
1237
+ return make_uint4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
1238
+ }
1239
+
1240
+ ////////////////////////////////////////////////////////////////////////////////
1241
+ // dot product
1242
+ ////////////////////////////////////////////////////////////////////////////////
1243
+
1244
+ inline __host__ __device__ float dot(float2 a, float2 b)
1245
+ {
1246
+ return a.x * b.x + a.y * b.y;
1247
+ }
1248
+ inline __host__ __device__ float dot(float3 a, float3 b)
1249
+ {
1250
+ return a.x * b.x + a.y * b.y + a.z * b.z;
1251
+ }
1252
+ inline __host__ __device__ float dot(float4 a, float4 b)
1253
+ {
1254
+ return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
1255
+ }
1256
+
1257
+ inline __host__ __device__ int dot(int2 a, int2 b)
1258
+ {
1259
+ return a.x * b.x + a.y * b.y;
1260
+ }
1261
+ inline __host__ __device__ int dot(int3 a, int3 b)
1262
+ {
1263
+ return a.x * b.x + a.y * b.y + a.z * b.z;
1264
+ }
1265
+ inline __host__ __device__ int dot(int4 a, int4 b)
1266
+ {
1267
+ return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
1268
+ }
1269
+
1270
+ inline __host__ __device__ uint dot(uint2 a, uint2 b)
1271
+ {
1272
+ return a.x * b.x + a.y * b.y;
1273
+ }
1274
+ inline __host__ __device__ uint dot(uint3 a, uint3 b)
1275
+ {
1276
+ return a.x * b.x + a.y * b.y + a.z * b.z;
1277
+ }
1278
+ inline __host__ __device__ uint dot(uint4 a, uint4 b)
1279
+ {
1280
+ return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
1281
+ }
1282
+
1283
+ ////////////////////////////////////////////////////////////////////////////////
1284
+ // length
1285
+ ////////////////////////////////////////////////////////////////////////////////
1286
+
1287
+ inline __host__ __device__ float length(float2 v)
1288
+ {
1289
+ return sqrtf(dot(v, v));
1290
+ }
1291
+ inline __host__ __device__ float length(float3 v)
1292
+ {
1293
+ return sqrtf(dot(v, v));
1294
+ }
1295
+ inline __host__ __device__ float length(float4 v)
1296
+ {
1297
+ return sqrtf(dot(v, v));
1298
+ }
1299
+
1300
+ ////////////////////////////////////////////////////////////////////////////////
1301
+ // normalize
1302
+ ////////////////////////////////////////////////////////////////////////////////
1303
+
1304
+ inline __host__ __device__ float2 normalize(float2 v)
1305
+ {
1306
+ float invLen = rsqrtf(dot(v, v));
1307
+ return v * invLen;
1308
+ }
1309
+ inline __host__ __device__ float3 normalize(float3 v)
1310
+ {
1311
+ float invLen = rsqrtf(dot(v, v));
1312
+ return v * invLen;
1313
+ }
1314
+ inline __host__ __device__ float4 normalize(float4 v)
1315
+ {
1316
+ float invLen = rsqrtf(dot(v, v));
1317
+ return v * invLen;
1318
+ }
1319
+
1320
+ ////////////////////////////////////////////////////////////////////////////////
1321
+ // floor
1322
+ ////////////////////////////////////////////////////////////////////////////////
1323
+
1324
+ inline __host__ __device__ float2 floorf(float2 v)
1325
+ {
1326
+ return make_float2(floorf(v.x), floorf(v.y));
1327
+ }
1328
+ inline __host__ __device__ float3 floorf(float3 v)
1329
+ {
1330
+ return make_float3(floorf(v.x), floorf(v.y), floorf(v.z));
1331
+ }
1332
+ inline __host__ __device__ float4 floorf(float4 v)
1333
+ {
1334
+ return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w));
1335
+ }
1336
+
1337
+ ////////////////////////////////////////////////////////////////////////////////
1338
+ // frac - returns the fractional portion of a scalar or each vector component
1339
+ ////////////////////////////////////////////////////////////////////////////////
1340
+
1341
+ inline __host__ __device__ float fracf(float v)
1342
+ {
1343
+ return v - floorf(v);
1344
+ }
1345
+ inline __host__ __device__ float2 fracf(float2 v)
1346
+ {
1347
+ return make_float2(fracf(v.x), fracf(v.y));
1348
+ }
1349
+ inline __host__ __device__ float3 fracf(float3 v)
1350
+ {
1351
+ return make_float3(fracf(v.x), fracf(v.y), fracf(v.z));
1352
+ }
1353
+ inline __host__ __device__ float4 fracf(float4 v)
1354
+ {
1355
+ return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w));
1356
+ }
1357
+
1358
+ ////////////////////////////////////////////////////////////////////////////////
1359
+ // fmod
1360
+ ////////////////////////////////////////////////////////////////////////////////
1361
+
1362
+ inline __host__ __device__ float2 fmodf(float2 a, float2 b)
1363
+ {
1364
+ return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y));
1365
+ }
1366
+ inline __host__ __device__ float3 fmodf(float3 a, float3 b)
1367
+ {
1368
+ return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z));
1369
+ }
1370
+ inline __host__ __device__ float4 fmodf(float4 a, float4 b)
1371
+ {
1372
+ return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w));
1373
+ }
1374
+
1375
+ ////////////////////////////////////////////////////////////////////////////////
1376
+ // absolute value
1377
+ ////////////////////////////////////////////////////////////////////////////////
1378
+
1379
+ inline __host__ __device__ float2 fabs(float2 v)
1380
+ {
1381
+ return make_float2(fabs(v.x), fabs(v.y));
1382
+ }
1383
+ inline __host__ __device__ float3 fabs(float3 v)
1384
+ {
1385
+ return make_float3(fabs(v.x), fabs(v.y), fabs(v.z));
1386
+ }
1387
+ inline __host__ __device__ float4 fabs(float4 v)
1388
+ {
1389
+ return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w));
1390
+ }
1391
+
1392
+ inline __host__ __device__ int2 abs(int2 v)
1393
+ {
1394
+ return make_int2(abs(v.x), abs(v.y));
1395
+ }
1396
+ inline __host__ __device__ int3 abs(int3 v)
1397
+ {
1398
+ return make_int3(abs(v.x), abs(v.y), abs(v.z));
1399
+ }
1400
+ inline __host__ __device__ int4 abs(int4 v)
1401
+ {
1402
+ return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w));
1403
+ }
1404
+
1405
+ ////////////////////////////////////////////////////////////////////////////////
1406
+ // reflect
1407
+ // - returns reflection of incident ray I around surface normal N
1408
+ // - N should be normalized, reflected vector's length is equal to length of I
1409
+ ////////////////////////////////////////////////////////////////////////////////
1410
+
1411
+ inline __host__ __device__ float3 reflect(float3 i, float3 n)
1412
+ {
1413
+ return i - 2.0f * n * dot(n,i);
1414
+ }
1415
+
1416
+ ////////////////////////////////////////////////////////////////////////////////
1417
+ // cross product
1418
+ ////////////////////////////////////////////////////////////////////////////////
1419
+
1420
+ inline __host__ __device__ float3 cross(float3 a, float3 b)
1421
+ {
1422
+ return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);
1423
+ }
1424
+
1425
+ ////////////////////////////////////////////////////////////////////////////////
1426
+ // smoothstep
1427
+ // - returns 0 if x < a
1428
+ // - returns 1 if x > b
1429
+ // - otherwise returns smooth interpolation between 0 and 1 based on x
1430
+ ////////////////////////////////////////////////////////////////////////////////
1431
+
1432
+ inline __device__ __host__ float smoothstep(float a, float b, float x)
1433
+ {
1434
+ float y = clamp((x - a) / (b - a), 0.0f, 1.0f);
1435
+ return (y*y*(3.0f - (2.0f*y)));
1436
+ }
1437
+ inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x)
1438
+ {
1439
+ float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
1440
+ return (y*y*(make_float2(3.0f) - (make_float2(2.0f)*y)));
1441
+ }
1442
+ inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x)
1443
+ {
1444
+ float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
1445
+ return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y)));
1446
+ }
1447
+ inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x)
1448
+ {
1449
+ float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
1450
+ return (y*y*(make_float4(3.0f) - (make_float4(2.0f)*y)));
1451
+ }
1452
+
1453
+ #endif
dva/mvp/extensions/utils/makefile ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ all:
2
+ python setup.py build_ext --inplace
dva/mvp/extensions/utils/setup.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from setuptools import setup
8
+
9
+ from torch.utils.cpp_extension import CUDAExtension, BuildExtension
10
+
11
+ if __name__ == "__main__":
12
+ import torch
13
+ setup(
14
+ name="utils",
15
+ ext_modules=[
16
+ CUDAExtension(
17
+ "utilslib",
18
+ sources=["utils.cpp", "utils_kernel.cu"],
19
+ extra_compile_args={
20
+ "nvcc": [
21
+ "-arch=sm_70",
22
+ "-std=c++14",
23
+ "-lineinfo",
24
+ ]
25
+ }
26
+ )
27
+ ],
28
+ cmdclass={"build_ext": BuildExtension}
29
+ )
dva/mvp/extensions/utils/utils.cpp ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #include <torch/extension.h>
8
+ #include <c10/cuda/CUDAStream.h>
9
+
10
+ #include <vector>
11
+
12
+ void compute_raydirs_forward_cuda(
13
+ int N, int H, int W,
14
+ float * viewposim,
15
+ float * viewrotim,
16
+ float * focalim,
17
+ float * princptim,
18
+ float * pixelcoordsim,
19
+ float volradius,
20
+ float * raypos,
21
+ float * raydir,
22
+ float * tminmax,
23
+ cudaStream_t stream);
24
+
25
+ void compute_raydirs_backward_cuda(
26
+ int N, int H, int W,
27
+ float * viewposim,
28
+ float * viewrotim,
29
+ float * focalim,
30
+ float * princptim,
31
+ float * pixelcoordsim,
32
+ float volradius,
33
+ float * raypos,
34
+ float * raydir,
35
+ float * tminmax,
36
+ float * grad_viewposim,
37
+ float * grad_viewrotim,
38
+ float * grad_focalim,
39
+ float * grad_princptim,
40
+ cudaStream_t stream);
41
+
42
+ #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
43
+ #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
44
+ #define CHECK_INPUT(x) CHECK_CUDA((x)); CHECK_CONTIGUOUS((x))
45
+
46
+ std::vector<torch::Tensor> compute_raydirs_forward(
47
+ torch::Tensor viewposim,
48
+ torch::Tensor viewrotim,
49
+ torch::Tensor focalim,
50
+ torch::Tensor princptim,
51
+ torch::optional<torch::Tensor> pixelcoordsim,
52
+ int W, int H,
53
+ float volradius,
54
+ torch::Tensor rayposim,
55
+ torch::Tensor raydirim,
56
+ torch::Tensor tminmaxim) {
57
+ CHECK_INPUT(viewposim);
58
+ CHECK_INPUT(viewrotim);
59
+ CHECK_INPUT(focalim);
60
+ CHECK_INPUT(princptim);
61
+ if (pixelcoordsim) { CHECK_INPUT(*pixelcoordsim); }
62
+ CHECK_INPUT(rayposim);
63
+ CHECK_INPUT(raydirim);
64
+ CHECK_INPUT(tminmaxim);
65
+
66
+ int N = viewposim.size(0);
67
+ assert(!pixelcoordsim || (pixelcoordsim.size(1) == H && pixelcoordsim.size(2) == W));
68
+
69
+ compute_raydirs_forward_cuda(N, H, W,
70
+ reinterpret_cast<float *>(viewposim.data_ptr()),
71
+ reinterpret_cast<float *>(viewrotim.data_ptr()),
72
+ reinterpret_cast<float *>(focalim.data_ptr()),
73
+ reinterpret_cast<float *>(princptim.data_ptr()),
74
+ pixelcoordsim ? reinterpret_cast<float *>(pixelcoordsim->data_ptr()) : nullptr,
75
+ volradius,
76
+ reinterpret_cast<float *>(rayposim.data_ptr()),
77
+ reinterpret_cast<float *>(raydirim.data_ptr()),
78
+ reinterpret_cast<float *>(tminmaxim.data_ptr()),
79
+ 0);
80
+
81
+ return {};
82
+ }
83
+
84
+ std::vector<torch::Tensor> compute_raydirs_backward(
85
+ torch::Tensor viewposim,
86
+ torch::Tensor viewrotim,
87
+ torch::Tensor focalim,
88
+ torch::Tensor princptim,
89
+ torch::optional<torch::Tensor> pixelcoordsim,
90
+ int W, int H,
91
+ float volradius,
92
+ torch::Tensor rayposim,
93
+ torch::Tensor raydirim,
94
+ torch::Tensor tminmaxim,
95
+ torch::Tensor grad_viewpos,
96
+ torch::Tensor grad_viewrot,
97
+ torch::Tensor grad_focal,
98
+ torch::Tensor grad_princpt) {
99
+ CHECK_INPUT(viewposim);
100
+ CHECK_INPUT(viewrotim);
101
+ CHECK_INPUT(focalim);
102
+ CHECK_INPUT(princptim);
103
+ if (pixelcoordsim) { CHECK_INPUT(*pixelcoordsim); }
104
+ CHECK_INPUT(rayposim);
105
+ CHECK_INPUT(raydirim);
106
+ CHECK_INPUT(tminmaxim);
107
+ CHECK_INPUT(grad_viewpos);
108
+ CHECK_INPUT(grad_viewrot);
109
+ CHECK_INPUT(grad_focal);
110
+ CHECK_INPUT(grad_princpt);
111
+
112
+ int N = viewposim.size(0);
113
+ assert(!pixelcoordsim || (pixelcoordsim.size(1) == H && pixelcoordsim.size(2) == W));
114
+
115
+ compute_raydirs_backward_cuda(N, H, W,
116
+ reinterpret_cast<float *>(viewposim.data_ptr()),
117
+ reinterpret_cast<float *>(viewrotim.data_ptr()),
118
+ reinterpret_cast<float *>(focalim.data_ptr()),
119
+ reinterpret_cast<float *>(princptim.data_ptr()),
120
+ pixelcoordsim ? reinterpret_cast<float *>(pixelcoordsim->data_ptr()) : nullptr,
121
+ volradius,
122
+ reinterpret_cast<float *>(rayposim.data_ptr()),
123
+ reinterpret_cast<float *>(raydirim.data_ptr()),
124
+ reinterpret_cast<float *>(tminmaxim.data_ptr()),
125
+ reinterpret_cast<float *>(grad_viewpos.data_ptr()),
126
+ reinterpret_cast<float *>(grad_viewrot.data_ptr()),
127
+ reinterpret_cast<float *>(grad_focal.data_ptr()),
128
+ reinterpret_cast<float *>(grad_princpt.data_ptr()),
129
+ 0);
130
+
131
+ return {};
132
+ }
133
+
134
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
135
+ m.def("compute_raydirs_forward", &compute_raydirs_forward, "raydirs forward (CUDA)");
136
+ m.def("compute_raydirs_backward", &compute_raydirs_backward, "raydirs backward (CUDA)");
137
+ }
dva/mvp/extensions/utils/utils.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import time
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.autograd import Function
13
+ import torch.nn.functional as F
14
+
15
+ try:
16
+ from . import utilslib
17
+ except:
18
+ import utilslib
19
+
20
+ class ComputeRaydirs(Function):
21
+ @staticmethod
22
+ def forward(self, viewpos, viewrot, focal, princpt, pixelcoords, volradius):
23
+ for tensor in [viewpos, viewrot, focal, princpt, pixelcoords]:
24
+ assert tensor.is_contiguous()
25
+
26
+ N = viewpos.size(0)
27
+ if isinstance(pixelcoords, tuple):
28
+ W, H = pixelcoords
29
+ pixelcoords = None
30
+ else:
31
+ H = pixelcoords.size(1)
32
+ W = pixelcoords.size(2)
33
+
34
+ raypos = torch.empty((N, H, W, 3), device=viewpos.device)
35
+ raydirs = torch.empty((N, H, W, 3), device=viewpos.device)
36
+ tminmax = torch.empty((N, H, W, 2), device=viewpos.device)
37
+ utilslib.compute_raydirs_forward(viewpos, viewrot, focal, princpt,
38
+ pixelcoords, W, H, volradius, raypos, raydirs, tminmax)
39
+
40
+ return raypos, raydirs, tminmax
41
+
42
+ @staticmethod
43
+ def backward(self, grad_raydirs, grad_tminmax):
44
+ return None, None, None, None, None, None
45
+
46
+ def compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius):
47
+ raypos, raydirs, tminmax = ComputeRaydirs.apply(viewpos, viewrot, focal, princpt, pixelcoords, volradius)
48
+ return raypos, raydirs, tminmax
49
+
50
+ class Rodrigues(nn.Module):
51
+ def __init__(self):
52
+ super(Rodrigues, self).__init__()
53
+
54
+ def forward(self, rvec):
55
+ theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))
56
+ rvec = rvec / theta[:, None]
57
+ costh = torch.cos(theta)
58
+ sinth = torch.sin(theta)
59
+ return torch.stack((
60
+ rvec[:, 0] ** 2 + (1. - rvec[:, 0] ** 2) * costh,
61
+ rvec[:, 0] * rvec[:, 1] * (1. - costh) - rvec[:, 2] * sinth,
62
+ rvec[:, 0] * rvec[:, 2] * (1. - costh) + rvec[:, 1] * sinth,
63
+
64
+ rvec[:, 0] * rvec[:, 1] * (1. - costh) + rvec[:, 2] * sinth,
65
+ rvec[:, 1] ** 2 + (1. - rvec[:, 1] ** 2) * costh,
66
+ rvec[:, 1] * rvec[:, 2] * (1. - costh) - rvec[:, 0] * sinth,
67
+
68
+ rvec[:, 0] * rvec[:, 2] * (1. - costh) - rvec[:, 1] * sinth,
69
+ rvec[:, 1] * rvec[:, 2] * (1. - costh) + rvec[:, 0] * sinth,
70
+ rvec[:, 2] ** 2 + (1. - rvec[:, 2] ** 2) * costh), dim=1).view(-1, 3, 3)
71
+
72
+ def gradcheck():
73
+ N = 2
74
+ H = 64
75
+ W = 64
76
+ k3 = 4
77
+ K = k3*k3*k3
78
+
79
+ M = 32
80
+ volradius = 1.
81
+
82
+ # generate random inputs
83
+ torch.manual_seed(1113)
84
+
85
+ rodrigues = Rodrigues()
86
+
87
+ _viewpos = torch.tensor([[-0.0, 0.0, -4.] for n in range(N)], device="cuda") + torch.randn(N, 3, device="cuda") * 0.1
88
+ viewrvec = torch.randn(N, 3, device="cuda") * 0.01
89
+ _viewrot = rodrigues(viewrvec)
90
+
91
+ _focal = torch.tensor([[W*4.0, W*4.0] for n in range(N)], device="cuda")
92
+ _princpt = torch.tensor([[W*0.5, H*0.5] for n in range(N)], device="cuda")
93
+ pixely, pixelx = torch.meshgrid(torch.arange(H, device="cuda").float(), torch.arange(W, device="cuda").float())
94
+ _pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1)
95
+
96
+ _viewpos = _viewpos.contiguous().detach().clone()
97
+ _viewpos.requires_grad = True
98
+ _viewrot = _viewrot.contiguous().detach().clone()
99
+ _viewrot.requires_grad = True
100
+ _focal = _focal.contiguous().detach().clone()
101
+ _focal.requires_grad = True
102
+ _princpt = _princpt.contiguous().detach().clone()
103
+ _princpt.requires_grad = True
104
+ _pixelcoords = _pixelcoords.contiguous().detach().clone()
105
+ _pixelcoords.requires_grad = True
106
+
107
+ max_len = 6.0
108
+ _stepsize = max_len / 15.5
109
+
110
+ params = [_viewpos, _viewrot, _focal, _princpt]
111
+ paramnames = ["viewpos", "viewrot", "focal", "princpt"]
112
+
113
+ ########################### run pytorch version ###########################
114
+
115
+ viewpos = _viewpos
116
+ viewrot = _viewrot
117
+ focal = _focal
118
+ princpt = _princpt
119
+ pixelcoords = _pixelcoords
120
+
121
+ raypos = viewpos[:, None, None, :].repeat(1, H, W, 1)
122
+
123
+ raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
124
+ raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
125
+ raydir = torch.sum(viewrot[:, None, None, :, :] * raydir[:, :, :, :, None], dim=-2)
126
+ raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))
127
+
128
+ t1 = (-1. - viewpos[:, None, None, :]) / raydir
129
+ t2 = ( 1. - viewpos[:, None, None, :]) / raydir
130
+ tmin = torch.max(torch.min(t1[..., 0], t2[..., 0]),
131
+ torch.max(torch.min(t1[..., 1], t2[..., 1]),
132
+ torch.min(t1[..., 2], t2[..., 2]))).clamp(min=0.)
133
+ tmax = torch.min(torch.max(t1[..., 0], t2[..., 0]),
134
+ torch.min(torch.max(t1[..., 1], t2[..., 1]),
135
+ torch.max(t1[..., 2], t2[..., 2])))
136
+
137
+ tminmax = torch.stack([tmin, tmax], dim=-1)
138
+
139
+ sample0 = raydir
140
+
141
+ torch.cuda.synchronize()
142
+ time1 = time.time()
143
+
144
+ sample0.backward(torch.ones_like(sample0))
145
+
146
+ torch.cuda.synchronize()
147
+ time2 = time.time()
148
+
149
+ grads0 = [p.grad.detach().clone() if p.grad is not None else None for p in params]
150
+
151
+ for p in params:
152
+ if p.grad is not None:
153
+ p.grad.detach_()
154
+ p.grad.zero_()
155
+
156
+ ############################## run cuda version ###########################
157
+
158
+ viewpos = _viewpos
159
+ viewrot = _viewrot
160
+ focal = _focal
161
+ princpt = _princpt
162
+ pixelcoords = _pixelcoords
163
+
164
+ niter = 1
165
+
166
+ for p in params:
167
+ if p.grad is not None:
168
+ p.grad.detach_()
169
+ p.grad.zero_()
170
+ t0 = time.time()
171
+ torch.cuda.synchronize()
172
+
173
+ sample1 = compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius)[1]
174
+
175
+ t1 = time.time()
176
+ torch.cuda.synchronize()
177
+
178
+ print("-----------------------------------------------------------------")
179
+ print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format("", "maxabsdiff", "dp", "index", "py", "cuda"))
180
+ ind = torch.argmax(torch.abs(sample0 - sample1))
181
+ print("{:<10} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
182
+ "fwd",
183
+ torch.max(torch.abs(sample0 - sample1)).item(),
184
+ (torch.sum(sample0 * sample1) / torch.sqrt(torch.sum(sample0 * sample0) * torch.sum(sample1 * sample1))).item(),
185
+ ind.item(),
186
+ sample0.view(-1)[ind].item(),
187
+ sample1.view(-1)[ind].item()))
188
+
189
+ sample1.backward(torch.ones_like(sample1), retain_graph=True)
190
+
191
+ torch.cuda.synchronize()
192
+ t2 = time.time()
193
+
194
+
195
+ print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter))
196
+ grads1 = [p.grad.detach().clone() if p.grad is not None else None for p in params]
197
+
198
+ ############# compare results #############
199
+
200
+ for p, g0, g1 in zip(paramnames, grads0, grads1):
201
+ ind = torch.argmax(torch.abs(g0 - g1))
202
+ print("{:<10} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
203
+ p,
204
+ torch.max(torch.abs(g0 - g1)).item(),
205
+ (torch.sum(g0 * g1) / torch.sqrt(torch.sum(g0 * g0) * torch.sum(g1 * g1))).item(),
206
+ ind.item(),
207
+ g0.view(-1)[ind].item(),
208
+ g1.view(-1)[ind].item()))
209
+
210
+ if __name__ == "__main__":
211
+ gradcheck()