kyleleey commited on
Commit
98a77e0
·
1 Parent(s): 9df3c71

first commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +184 -0
  2. ckpts/configs.yml +354 -0
  3. ckpts/iter0800000.pth +3 -0
  4. video3d/__init__.py +6 -0
  5. video3d/cages/cages.py +218 -0
  6. video3d/cub_dataloaders.py +404 -0
  7. video3d/cub_dataloaders_ddp.py +434 -0
  8. video3d/dataloaders.py +375 -0
  9. video3d/dataloaders_ddp.py +1210 -0
  10. video3d/diffusion/sd.py +252 -0
  11. video3d/diffusion/sd_utils.py +123 -0
  12. video3d/diffusion/vsd.py +323 -0
  13. video3d/discriminator_architecture.py +83 -0
  14. video3d/flow/__init__.py +0 -0
  15. video3d/flow/flow.py +51 -0
  16. video3d/flow/utils.py +23 -0
  17. video3d/geometry/dlmesh.py +85 -0
  18. video3d/geometry/dmtet.py +361 -0
  19. video3d/model.py +1526 -0
  20. video3d/model_ddp.py +0 -0
  21. video3d/networks.py +1724 -0
  22. video3d/render/light.py +191 -0
  23. video3d/render/material.py +282 -0
  24. video3d/render/mesh.py +377 -0
  25. video3d/render/mlptexture.py +122 -0
  26. video3d/render/obj.py +288 -0
  27. video3d/render/regularizer.py +93 -0
  28. video3d/render/render.py +369 -0
  29. video3d/render/renderutils/__init__.py +11 -0
  30. video3d/render/renderutils/bsdf.py +151 -0
  31. video3d/render/renderutils/c_src/bsdf.cu +710 -0
  32. video3d/render/renderutils/c_src/bsdf.h +84 -0
  33. video3d/render/renderutils/c_src/common.cpp +74 -0
  34. video3d/render/renderutils/c_src/common.h +41 -0
  35. video3d/render/renderutils/c_src/cubemap.cu +350 -0
  36. video3d/render/renderutils/c_src/cubemap.h +38 -0
  37. video3d/render/renderutils/c_src/loss.cu +210 -0
  38. video3d/render/renderutils/c_src/loss.h +38 -0
  39. video3d/render/renderutils/c_src/mesh.cu +94 -0
  40. video3d/render/renderutils/c_src/mesh.h +23 -0
  41. video3d/render/renderutils/c_src/normal.cu +182 -0
  42. video3d/render/renderutils/c_src/normal.h +27 -0
  43. video3d/render/renderutils/c_src/tensor.h +92 -0
  44. video3d/render/renderutils/c_src/torch_bindings.cpp +1062 -0
  45. video3d/render/renderutils/c_src/vec3f.h +109 -0
  46. video3d/render/renderutils/c_src/vec4f.h +25 -0
  47. video3d/render/renderutils/loss.py +41 -0
  48. video3d/render/renderutils/ops.py +554 -0
  49. video3d/render/renderutils/tests/test_bsdf.py +296 -0
  50. video3d/render/renderutils/tests/test_cubemap.py +47 -0
.gitignore ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+ data
3
+ data/*/
4
+ data/*/*
5
+ !data/preprocessing/
6
+ pretrained/*/
7
+ results
8
+ neural_renderer
9
+ *.zip
10
+ unchanged/
11
+ cvpr23_results/
12
+ # slurm.bash
13
+ results
14
+ results/*/
15
+ results/*
16
+ results/*/*
17
+ results/dor_checkpoints/*
18
+ results/dor_checkpoints/*/*
19
+ results/dor_checkpoints/*/*/*
20
+
21
+
22
+ .vscode
23
+ .vscode/
24
+
25
+ dor_bash_files/
26
+ zzli_bash_files/
27
+ ray_bash_files/
28
+
29
+ config/dor_exp/
30
+ config/zzli_exp/
31
+ config/ray_exp/
32
+
33
+ wandb
34
+ wandb/*/
35
+ wandb/*/*
36
+ wandb/*/*/*
37
+ canon/out/*
38
+ canon/out/
39
+ # Byte-compiled / optimized / DLL files
40
+ __pycache__/
41
+ *.py[cod]
42
+ *$py.class
43
+
44
+ # C extensions
45
+ *.so
46
+
47
+ # Distribution / packaging
48
+ .Python
49
+ build/
50
+ develop-eggs/
51
+ dist/
52
+ downloads/
53
+ eggs/
54
+ .eggs/
55
+ lib/
56
+ lib64/
57
+ parts/
58
+ sdist/
59
+ var/
60
+ wheels/
61
+ pip-wheel-metadata/
62
+ share/python-wheels/
63
+ *.egg-info/
64
+ .installed.cfg
65
+ *.egg
66
+ MANIFEST
67
+
68
+ # PyInstaller
69
+ # Usually these files are written by a python script from a template
70
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
71
+ *.manifest
72
+ *.spec
73
+
74
+ # Installer logs
75
+ pip-log.txt
76
+ pip-delete-this-directory.txt
77
+
78
+ # Unit test / coverage reports
79
+ htmlcov/
80
+ .tox/
81
+ .nox/
82
+ .coverage
83
+ .coverage.*
84
+ .cache
85
+ nosetests.xml
86
+ coverage.xml
87
+ *.cover
88
+ *.py,cover
89
+ .hypothesis/
90
+ .pytest_cache/
91
+
92
+ # Translations
93
+ *.mo
94
+ *.pot
95
+
96
+ # Django stuff:
97
+ *.log
98
+ local_settings.py
99
+ db.sqlite3
100
+ db.sqlite3-journal
101
+
102
+ # Flask stuff:
103
+ instance/
104
+ .webassets-cache
105
+
106
+ # Scrapy stuff:
107
+ .scrapy
108
+
109
+ # Sphinx documentation
110
+ docs/_build/
111
+
112
+ # PyBuilder
113
+ target/
114
+
115
+ # Jupyter Notebook
116
+ .ipynb_checkpoints
117
+
118
+ # IPython
119
+ profile_default/
120
+ ipython_config.py
121
+
122
+ # pyenv
123
+ .python-version
124
+
125
+ # pipenv
126
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
127
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
128
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
129
+ # install all needed dependencies.
130
+ #Pipfile.lock
131
+
132
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
133
+ __pypackages__/
134
+
135
+ # Celery stuff
136
+ celerybeat-schedule
137
+ celerybeat.pid
138
+
139
+ # SageMath parsed files
140
+ *.sage.py
141
+
142
+ # Environments
143
+ .env
144
+ .venv
145
+ env/
146
+ venv/
147
+ ENV/
148
+ env.bak/
149
+ venv.bak/
150
+
151
+ # Spyder project settings
152
+ .spyderproject
153
+ .spyproject
154
+
155
+ # Rope project settings
156
+ .ropeproject
157
+
158
+ # mkdocs documentation
159
+ /site
160
+
161
+ # mypy
162
+ .mypy_cache/
163
+ .dmypy.json
164
+ dmypy.json
165
+
166
+ # Pyre type checker
167
+ .pyre/
168
+ /.idea
169
+
170
+ # dependencies
171
+ # nvdiffrast/
172
+ data/preprocessing/videos/RAFT/
173
+ preprocessing_data/RAFT/
174
+ preprocessing_data/RAFT/*
175
+ preprocessing_data/preprocessing/videos/RAFT/
176
+ # debug
177
+
178
+
179
+ DINO_v2_check/out_dor
180
+ DINO_v2_check/out_dor/*
181
+
182
+ eval/*/
183
+ scripts/vis/
184
+ eval/
ckpts/configs.yml ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ amb_diff_max:
2
+ - 1.0
3
+ - 1.0
4
+ amb_diff_min:
5
+ - 0.0
6
+ - 0.5
7
+ arti_reg_loss_epochs:
8
+ - 8
9
+ - 276
10
+ arti_reg_loss_weight: 0.2
11
+ articulation_arch: attention
12
+ articulation_epochs:
13
+ - 2
14
+ - 276
15
+ articulation_feature_mode: sample+global
16
+ articulation_multiplier: 0.1
17
+ attach_legs_to_body_epochs:
18
+ - 8
19
+ - 276
20
+ avg_seqshape_epochs:
21
+ - 0
22
+ - 0
23
+ avg_texture_epochs:
24
+ - 0
25
+ - 0
26
+ background_mode: none
27
+ backward_prior: true
28
+ bank_mean_dist_loss_weight: 0.0
29
+ batch_size: 6
30
+ best_pose_start_iter: 10000
31
+ blur_mask: false
32
+ body_bone_idx_preset:
33
+ 0:
34
+ - 0
35
+ - 0
36
+ - 0
37
+ - 0
38
+ 500000:
39
+ - 0
40
+ - 0
41
+ - 0
42
+ - 0
43
+ body_bones_type: z_minmax_y+
44
+ body_rotate_reg_mode: all-bones
45
+ bone_y_thresh: 0.4
46
+ bsdf: diffuse
47
+ cam_pos_z_offset: 10
48
+ checkpoint_dir: /viscam/u/zzli/workspace/4DAnimalKingdom_dev/results/paper_exp/same_dino_1109/mb_all_data_1k_artiID_r500k
49
+ clip_tex: false
50
+ clip_tex_loss_weight: 0.0
51
+ combine_dataset: true
52
+ config: config/zzli_exp/same_dino_1109/mb_data1k_artiID_r500k.yml
53
+ constrain_legs: false
54
+ crop_fov_approx: 25
55
+ data_loader_mode: n_frame
56
+ dataset: video
57
+ debug_seq: false
58
+ deform_epochs:
59
+ - 0
60
+ - 276
61
+ deformation_reg_loss_weight: 10.0
62
+ device: cuda:0
63
+ diffusion_albedo_ratio: 0.2
64
+ diffusion_angle_front: 60
65
+ diffusion_angle_overhead: 30
66
+ diffusion_append_prompt_directions: true
67
+ diffusion_guidance_scale: 100
68
+ diffusion_light_ambient: 0.5
69
+ diffusion_light_diffuse: 0.8
70
+ diffusion_loss_weight: 0.0001
71
+ diffusion_max_step: 0.6
72
+ diffusion_num_random_cameras: 1
73
+ diffusion_phi_offset: 180
74
+ diffusion_precision: float16
75
+ diffusion_prompt: an elephant
76
+ diffusion_radius_range:
77
+ - 9
78
+ - 11
79
+ diffusion_random_light: true
80
+ diffusion_resolution: 256
81
+ diffusion_shading_ratio: 0.4
82
+ diffusion_theta_range:
83
+ - 0
84
+ - 100
85
+ diffusion_uniform_sphere_rate: 1
86
+ dim_of_classes: 128
87
+ dino_feat_im_loss_weight:
88
+ 0: 10.0
89
+ 300000: 1.0
90
+ dino_feature_dim: 16
91
+ dino_feature_input: false
92
+ dino_feature_recon_dim: 16
93
+ dino_max: 1.0
94
+ dino_min: 0.0
95
+ disable_fewshot: false
96
+ disc_gt: false
97
+ disc_iv: true
98
+ disc_iv_label: Real
99
+ disc_reg_mul: 10.0
100
+ discriminator_loss_weight: 1.0
101
+ dmtet_grid: 256
102
+ dmtet_grid_smaller: 256
103
+ dmtet_grid_smaller_epoch: 1
104
+ embed_concat_pts: true
105
+ embedder_freq_arti: 8
106
+ embedder_freq_deform: 10
107
+ embedder_freq_dino: 8
108
+ embedder_freq_shape: 8
109
+ embedder_freq_tex: 10
110
+ enable_articulation: true
111
+ enable_articulation_bone_threshold: true
112
+ enable_articulation_idadd: true
113
+ enable_deform: true
114
+ enable_disc: true
115
+ enable_encoder: true
116
+ enable_lighting: true
117
+ enable_mask_distribution: true
118
+ enable_memory_bank: true
119
+ enable_pose: true
120
+ enable_prior: true
121
+ enable_sds: false
122
+ encoder_arch: vit
123
+ encoder_frozen: true
124
+ encoder_pretrained: true
125
+ enhance_back_view: true
126
+ enhance_back_view_path: /viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/segmented_back_view_data
127
+ extra_renders:
128
+ instance:
129
+ - geo_normal
130
+ - diffuse
131
+ - gray
132
+ faces_per_pixel: 10
133
+ few_shot_category_num: -1
134
+ few_shot_class_vector_init: copy
135
+ few_shot_data_dir:
136
+ - /viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/few_shot_data_all
137
+ - /viscam/projects/articulated/dor/Animal-Data-Engine/data/data_resize_update/train_with_classes_filtered
138
+ few_shot_iteration_save: true
139
+ few_shot_iteration_save_freq: 2000
140
+ few_shot_lr: 0.0001
141
+ few_shot_optimize: exp
142
+ few_shot_optimize_bank: all
143
+ few_shot_original_classes_num: 7
144
+ few_shot_resume: true
145
+ few_shot_test_category_names:
146
+ - caracal
147
+ - impala
148
+ - ox
149
+ - squirrel
150
+ - wolf
151
+ few_shot_test_category_num: 5
152
+ few_shot_val_image_num: 5
153
+ fix_viz_batch: false
154
+ flow_loss_epochs:
155
+ - 0
156
+ - 0
157
+ flow_loss_weight: 0.0
158
+ forbid_leg_rotate: true
159
+ fov_w: 60
160
+ full_size_h: 1080
161
+ full_size_w: 1920
162
+ gamma: 1e-6
163
+ gan_tex: false
164
+ grid_scale: 7
165
+ hidden_size: 256
166
+ in_image_size: 256
167
+ init_sdf: ellipsoid
168
+ is_dry_run: false
169
+ iter_arti_reg_loss_start: 60000
170
+ iter_articulation_start: 20000
171
+ iter_attach_leg_to_body_start: 60000
172
+ iter_deformation_start: 500000
173
+ iter_leg_rotation_start: 300000
174
+ iter_nozeroy_start: 20000
175
+ jitter_grid: 0.05
176
+ kd_max:
177
+ - 1.0
178
+ - 1.0
179
+ - 1.0
180
+ - 1.0
181
+ kd_min:
182
+ - 0.0
183
+ - 0.0
184
+ - 0.0
185
+ - 0.0
186
+ keep_num_checkpoint: 1
187
+ ks_max:
188
+ - 0.0
189
+ - 0.0
190
+ - 0.0
191
+ ks_min:
192
+ - 0.0
193
+ - 0.0
194
+ - 0.0
195
+ latent_dim: 256
196
+ load_dino_cluster: false
197
+ load_dino_feature: true
198
+ log_freq_images: 501
199
+ log_freq_losses: 50
200
+ log_train_images: true
201
+ logit_loss_dino_feat_im_loss_multiplier:
202
+ 0: 50.0
203
+ 300000: 500.0
204
+ logit_loss_weight: 1.0
205
+ lookat_init:
206
+ - 0.0
207
+ - 0.0
208
+ - 0.0
209
+ lookat_zeroy: true
210
+ lr: 6.0e-05
211
+ mask_disc_loss_feat_condition: true
212
+ mask_disc_loss_weight: 0.1
213
+ mask_discriminator_iter:
214
+ - 80000
215
+ - 300000
216
+ mask_distribution_loss_freq: 1
217
+ mask_distribution_loss_weight: 0.0
218
+ mask_distribution_path: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/mask_distribution
219
+ max_arti_angle: 60
220
+ max_trans_xy_range_ratio: 0.5
221
+ max_trans_z_range_ratio: 0.5
222
+ memory_bank_init: copy
223
+ memory_bank_size: 60
224
+ memory_bank_topk: 10
225
+ memory_encoder: DINO
226
+ memory_retrieve: cos-linear
227
+ mesh_edge_length_loss_weight: 0.0
228
+ mesh_normal_consistency_loss_weight: 0.0
229
+ min_seq_len: 1
230
+ nrm_max:
231
+ - 1.0
232
+ - 1.0
233
+ - 1.0
234
+ nrm_min:
235
+ - -1.0
236
+ - -1.0
237
+ - 0.0
238
+ num_body_bones: 8
239
+ num_epochs: 1375
240
+ num_iterations: 10000000
241
+ num_layers_arti: 4
242
+ num_layers_deform: 5
243
+ num_layers_dino: 5
244
+ num_layers_light: 5
245
+ num_layers_tex: 8
246
+ num_leg_bones: 3
247
+ num_legs: 4
248
+ num_sample_frames: 1
249
+ num_workers: 8
250
+ out_image_size: 256
251
+ perturb_articulation_epochs:
252
+ - 0
253
+ - 0
254
+ perturb_normal: false
255
+ perturb_sdf: false
256
+ pose_arch: encoder_dino_patch_key
257
+ pose_entropy_loss_weight: 0.0
258
+ pose_epochs:
259
+ - 0
260
+ - 0
261
+ pose_xflip_recon_epochs:
262
+ - 0
263
+ - 0
264
+ pose_xflip_reg_loss_weight: 0.0
265
+ prior_condition_choice: mod
266
+ prior_lr: 0.0006
267
+ prior_sdf_mode: mlp
268
+ pyplot_metrics: false
269
+ random_flip_train: true
270
+ random_mask_law: random_azimuth
271
+ random_sample_train_frames: false
272
+ random_sample_val_frames: true
273
+ rank: 0
274
+ reg_body_rotate_mult: 0.1
275
+ render_dino_mode: feature_mlp
276
+ renderer_spp: 4
277
+ resume: true
278
+ resume_prior_optim: true
279
+ rgb_loss_weight: 1.0
280
+ rgb_suffix: .png
281
+ root_dir: /viscam/u/zzli
282
+ rot_all_quad_epochs:
283
+ - 0
284
+ - 276
285
+ rot_rand_quad_epochs:
286
+ - 0
287
+ - 0
288
+ rot_rep: quadlookat
289
+ rot_temp_scalar: 1.0
290
+ run_few_shot: true
291
+ run_train: true
292
+ save_checkpoint_freq: 1
293
+ save_result_freq: 501
294
+ sdf_bce_reg_loss_min_weight: 0
295
+ sdf_bce_reg_loss_weight: 0
296
+ sdf_gradient_reg_loss_min_weight: 0.1
297
+ sdf_gradient_reg_loss_weight: 0.1
298
+ sdf_inflate_reg_loss_epochs:
299
+ - 0
300
+ - 0
301
+ sdf_reg_decay_start_iter: 10000
302
+ seed: 0
303
+ seqshape_epochs:
304
+ - 0
305
+ - 0
306
+ shuffle_train_seqs: true
307
+ sigma: 1e-6
308
+ silhouette_dt_loss_weight: 0.0
309
+ silhouette_inv_dt_loss_weight: 50.0
310
+ silhouette_loss_weight: 5.0
311
+ skinning_temperature: 0.05
312
+ skip_beginning: 0
313
+ skip_end: 0
314
+ small_leg_angle: true
315
+ smooth_deformation_loss_weight: 10.0
316
+ static_root_bones: false
317
+ sym_deform: true
318
+ sym_dino: false
319
+ sym_prior_shape: true
320
+ sym_texture: true
321
+ temp_clip_high: 10.0
322
+ temp_clip_low: 1.0
323
+ tex_im_size: 256
324
+ texture_epochs:
325
+ - 0
326
+ - 276
327
+ texture_mode: mlp
328
+ train_data_dir:
329
+ bear: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/bear_comb_dinov2_new/train
330
+ cow: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/cow_comb_dinov2_new/train
331
+ elephant: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/elephant_comb_dinov2_new/train
332
+ giraffe: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/giraffe_comb_dinov2_new/train
333
+ horse: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/horse_comb_dinov2_new/train
334
+ sheep: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/sheep_comb_dinov2_new/train
335
+ zebra: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/zebra_comb_dinov2_new/train
336
+ train_with_cub: false
337
+ use_logger: true
338
+ use_scheduler: false
339
+ use_wandb: false
340
+ val_data_dir:
341
+ bear: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/bear_comb_dinov2_new/val
342
+ cow: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/cow_comb_dinov2_new/val
343
+ elephant: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/elephant_comb_dinov2_new/val
344
+ giraffe: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/giraffe_comb_dinov2_new/val
345
+ horse: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/horse_comb_dinov2_new/val
346
+ sheep: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/sheep_comb_dinov2_new/val
347
+ zebra: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/zebra_comb_dinov2_new/val
348
+ visualize_validation: true
349
+ vit_final_layer_type: conv
350
+ which_vit: dino_vits8
351
+ world_size: 1
352
+ zflip_epochs:
353
+ - 0
354
+ - 0
ckpts/iter0800000.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c7b090f1ff3e76e2ba608a25a2bd79af2892d6bb307132c9d038082395c1d57
3
+ size 306560367
video3d/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .utils.misc import setup_runtime
2
+ from .trainer import Trainer
3
+ from .trainer_ddp import TrainerDDP
4
+ from .model import Unsup3D
5
+ from .model_ddp import Unsup3DDDP
6
+ from .trainer_few_shot import Fewshot_Trainer
video3d/cages/cages.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Cages code used from https://github.com/yifita/deep_cage
2
+ import torch
3
+ import numpy as np
4
+ import trimesh
5
+
6
+
7
+
8
+ def deform_with_MVC(cage, cage_deformed, cage_face, query, verbose=False):
9
+ """
10
+ cage (B,C,3)
11
+ cage_deformed (B,C,3)
12
+ cage_face (B,F,3) int64
13
+ query (B,Q,3)
14
+ """
15
+ weights, weights_unnormed = mean_value_coordinates_3D(query, cage, cage_face, verbose=True)
16
+ # weights = weights.detach()
17
+ deformed = torch.sum(weights.unsqueeze(-1)*cage_deformed.unsqueeze(1), dim=2)
18
+ if verbose:
19
+ return deformed, weights, weights_unnormed
20
+ return deformed
21
+
22
+
23
+ def loadInitCage(template):
24
+ init_cage_V, init_cage_F = read_trimesh(template)
25
+ init_cage_V = torch.from_numpy(init_cage_V[:,:3].astype(np.float32)).unsqueeze(0)*2.0
26
+ init_cage_F = torch.from_numpy(init_cage_F[:,:3].astype(np.int64)).unsqueeze(0)
27
+ return init_cage_V, init_cage_F
28
+
29
+
30
+ def read_trimesh(path):
31
+ mesh = trimesh.load(path)
32
+ return mesh.vertices, mesh.faces
33
+
34
+
35
+ # util functions from pytorch_points
36
+ PI = 3.1415927
37
+
38
+ def normalize_to_box(input):
39
+ """
40
+ normalize point cloud to unit bounding box
41
+ center = (max - min)/2
42
+ scale = max(abs(x))
43
+ input: pc [N, P, dim] or [P, dim]
44
+ output: pc, centroid, furthest_distance
45
+ """
46
+ if len(input.shape) == 2:
47
+ axis = 0
48
+ P = input.shape[0]
49
+ D = input.shape[1]
50
+ elif len(input.shape) == 3:
51
+ axis = 1
52
+ P = input.shape[1]
53
+ D = input.shape[2]
54
+ if isinstance(input, np.ndarray):
55
+ maxP = np.amax(input, axis=axis, keepdims=True)
56
+ minP = np.amin(input, axis=axis, keepdims=True)
57
+ centroid = (maxP+minP)/2
58
+ input = input - centroid
59
+ furthest_distance = np.amax(np.abs(input), axis=(axis, -1), keepdims=True)
60
+ input = input / furthest_distance
61
+ elif isinstance(input, torch.Tensor):
62
+ maxP = torch.max(input, dim=axis, keepdim=True)[0]
63
+ minP = torch.min(input, dim=axis, keepdim=True)[0]
64
+ centroid = (maxP+minP)/2
65
+ input = input - centroid
66
+ in_shape = list(input.shape[:axis])+[P*D]
67
+ furthest_distance = torch.max(torch.abs(input).view(in_shape), dim=axis, keepdim=True)[0]
68
+ furthest_distance = furthest_distance.unsqueeze(-1)
69
+ input = input / furthest_distance
70
+
71
+ return input, centroid, furthest_distance
72
+
73
+ def normalize(tensor, dim=-1):
74
+ """normalize tensor in specified dimension"""
75
+ return torch.nn.functional.normalize(tensor, p=2, dim=dim, eps=1e-12, out=None)
76
+
77
+
78
+ def check_values(tensor):
79
+ """return true if tensor doesn't contain NaN or Inf"""
80
+ return not (torch.any(torch.isnan(tensor)).item() or torch.any(torch.isinf(tensor)).item())
81
+
82
+
83
+ class ScatterAdd(torch.autograd.Function):
84
+ @staticmethod
85
+ def forward(ctx, src, idx, dim, out_size, fill=0.0):
86
+ out = torch.full(out_size, fill, device=src.device, dtype=src.dtype)
87
+ ctx.save_for_backward(idx)
88
+ out.scatter_add_(dim, idx, src)
89
+ ctx.mark_non_differentiable(idx)
90
+ ctx.dim = dim
91
+ return out
92
+
93
+ @staticmethod
94
+ def backward(ctx, ograd):
95
+ idx, = ctx.saved_tensors
96
+ grad = torch.gather(ograd, ctx.dim, idx)
97
+ return grad, None, None, None, None
98
+
99
+
100
+ _scatter_add = ScatterAdd.apply
101
+
102
+
103
+ def scatter_add(src, idx, dim, out_size=None, fill=0.0):
104
+ if out_size is None:
105
+ out_size = list(src.size())
106
+ dim_size = idx.max().item()+1
107
+ out_size[dim] = dim_size
108
+ return _scatter_add(src, idx, dim, out_size, fill)
109
+
110
+
111
+ def mean_value_coordinates_3D(query, vertices, faces, verbose=False):
112
+ """
113
+ Tao Ju et.al. MVC for 3D triangle meshes
114
+ params:
115
+ query (B,P,3)
116
+ vertices (B,N,3)
117
+ faces (B,F,3)
118
+ return:
119
+ wj (B,P,N)
120
+ """
121
+ B, F, _ = faces.shape
122
+ _, P, _ = query.shape
123
+ _, N, _ = vertices.shape
124
+ # u_i = p_i - x (B,P,N,3)
125
+ uj = vertices.unsqueeze(1) - query.unsqueeze(2)
126
+ # \|u_i\| (B,P,N,1)
127
+ dj = torch.norm(uj, dim=-1, p=2, keepdim=True)
128
+ uj = normalize(uj, dim=-1)
129
+ # gather triangle B,P,F,3,3
130
+ ui = torch.gather(uj.unsqueeze(2).expand(-1,-1,F,-1,-1),
131
+ 3,
132
+ faces.unsqueeze(1).unsqueeze(-1).expand(-1,P,-1,-1,3))
133
+ # li = \|u_{i+1}-u_{i-1}\| (B,P,F,3)
134
+ li = torch.norm(ui[:,:,:,[1, 2, 0],:] - ui[:, :, :,[2, 0, 1],:], dim=-1, p=2)
135
+ eps = 2e-5
136
+ li = torch.where(li>=2, li-(li.detach()-(2-eps)), li)
137
+ li = torch.where(li<=-2, li-(li.detach()+(2-eps)), li)
138
+ # asin(x) is inf at +/-1
139
+ # θi = 2arcsin[li/2] (B,P,F,3)
140
+ theta_i = 2*torch.asin(li/2)
141
+ assert(check_values(theta_i))
142
+ # B,P,F,1
143
+ h = torch.sum(theta_i, dim=-1, keepdim=True)/2
144
+ # wi← sin[θi]d{i−1}d{i+1}
145
+ # (B,P,F,3) ci ← (2sin[h]sin[h−θi])/(sin[θ_{i+1}]sin[θ_{i−1}])−1
146
+ ci = 2*torch.sin(h)*torch.sin(h-theta_i)/(torch.sin(theta_i[:,:,:,[1, 2, 0]])*torch.sin(theta_i[:,:,:,[2, 0, 1]]))-1
147
+
148
+ # NOTE: because of floating point ci can be slightly larger than 1, causing problem with sqrt(1-ci^2)
149
+ # NOTE: sqrt(x)' is nan for x=0, hence use eps
150
+ eps = 1e-5
151
+ ci = torch.where(ci>=1, ci-(ci.detach()-(1-eps)), ci)
152
+ ci = torch.where(ci<=-1, ci-(ci.detach()+(1-eps)), ci)
153
+ # si← sign[det[u1,u2,u3]]sqrt(1-ci^2)
154
+ # (B,P,F)*(B,P,F,3)
155
+
156
+ si = torch.sign(torch.det(ui)).unsqueeze(-1)*torch.sqrt(1-ci**2) # sqrt gradient nan for 0
157
+ assert(check_values(si))
158
+ # (B,P,F,3)
159
+ di = torch.gather(dj.unsqueeze(2).squeeze(-1).expand(-1,-1,F,-1), 3,
160
+ faces.unsqueeze(1).expand(-1,P,-1,-1))
161
+ assert(check_values(di))
162
+ # if si.requires_grad:
163
+ # vertices.register_hook(save_grad("mvc/dv"))
164
+ # li.register_hook(save_grad("mvc/dli"))
165
+ # theta_i.register_hook(save_grad("mvc/dtheta"))
166
+ # ci.register_hook(save_grad("mvc/dci"))
167
+ # si.register_hook(save_grad("mvc/dsi"))
168
+ # di.register_hook(save_grad("mvc/ddi"))
169
+
170
+ # wi← (θi −c[i+1]θ[i−1] −c[i−1]θ[i+1])/(disin[θi+1]s[i−1])
171
+ # B,P,F,3
172
+ # CHECK is there a 2* in the denominator
173
+ wi = (theta_i-ci[:,:,:,[1,2,0]]*theta_i[:,:,:,[2,0,1]]-ci[:,:,:,[2,0,1]]*theta_i[:,:,:,[1,2,0]])/(di*torch.sin(theta_i[:,:,:,[1,2,0]])*si[:,:,:,[2,0,1]])
174
+ # if ∃i,|si| ≤ ε, set wi to 0. coplaner with T but outside
175
+ # ignore coplaner outside triangle
176
+ # alternative check
177
+ # (B,F,3,3)
178
+ # triangle_points = torch.gather(vertices.unsqueeze(1).expand(-1,F,-1,-1), 2, faces.unsqueeze(-1).expand(-1,-1,-1,3))
179
+ # # (B,P,F,3), (B,1,F,3) -> (B,P,F,1)
180
+ # determinant = dot_product(triangle_points[:,:,:,0].unsqueeze(1)-query.unsqueeze(2),
181
+ # torch.cross(triangle_points[:,:,:,1]-triangle_points[:,:,:,0],
182
+ # triangle_points[:,:,:,2]-triangle_points[:,:,:,0], dim=-1).unsqueeze(1), dim=-1, keepdim=True).detach()
183
+ # # (B,P,F,1)
184
+ # sqrdist = determinant*determinant / (4 * sqrNorm(torch.cross(triangle_points[:,:,:,1]-triangle_points[:,:,:,0], triangle_points[:,:,:,2]-triangle_points[:,:,:,0], dim=-1), keepdim=True))
185
+
186
+ wi = torch.where(torch.any(torch.abs(si) <= 1e-5, keepdim=True, dim=-1), torch.zeros_like(wi), wi)
187
+ # wi = torch.where(sqrdist <= 1e-5, torch.zeros_like(wi), wi)
188
+
189
+ # if π −h < ε, x lies on t, use 2D barycentric coordinates
190
+ # inside triangle
191
+ inside_triangle = (PI-h).squeeze(-1)<1e-4
192
+ # set all F for this P to zero
193
+ wi = torch.where(torch.any(inside_triangle, dim=-1, keepdim=True).unsqueeze(-1), torch.zeros_like(wi), wi)
194
+ # CHECK is it di https://www.cse.wustl.edu/~taoju/research/meanvalue.pdf or li http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.516.1856&rep=rep1&type=pdf
195
+ wi = torch.where(inside_triangle.unsqueeze(-1).expand(-1,-1,-1,wi.shape[-1]), torch.sin(theta_i)*di[:,:,:,[2,0,1]]*di[:,:,:,[1,2,0]], wi)
196
+
197
+ # sum over all faces face -> vertex (B,P,F*3) -> (B,P,N)
198
+ wj = scatter_add(wi.reshape(B,P,-1).contiguous(), faces.unsqueeze(1).expand(-1,P,-1,-1).reshape(B,P,-1), 2, out_size=(B,P,N))
199
+
200
+ # close to vertex (B,P,N)
201
+ close_to_point = dj.squeeze(-1) < 1e-8
202
+ # set all F for this P to zero
203
+ wj = torch.where(torch.any(close_to_point, dim=-1, keepdim=True), torch.zeros_like(wj), wj)
204
+ wj = torch.where(close_to_point, torch.ones_like(wj), wj)
205
+
206
+ # (B,P,1)
207
+ sumWj = torch.sum(wj, dim=-1, keepdim=True)
208
+ sumWj = torch.where(sumWj==0, torch.ones_like(sumWj), sumWj)
209
+
210
+ wj_normalised = wj / sumWj
211
+ # if wj.requires_grad:
212
+ # saved_variables["mvc/wi"] = wi
213
+ # wi.register_hook(save_grad("mvc/dwi"))
214
+ # wj.register_hook(save_grad("mvc/dwj"))
215
+ if verbose:
216
+ return wj_normalised, wi
217
+ else:
218
+ return wj_normalised
video3d/cub_dataloaders.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import cv2
3
+ import numpy as np
4
+ import scipy.io as sio
5
+ import torch
6
+ from PIL import Image
7
+ from torch.utils.data import Dataset
8
+ from types import SimpleNamespace
9
+
10
+
11
+ def get_cub_loader(data_dir, split='test', is_validation=False, batch_size=256, num_workers=4, image_size=256):
12
+ opts = SimpleNamespace()
13
+ opts.data_dir = data_dir
14
+ opts.padding_frac = 0.05
15
+ opts.jitter_frac = 0.05
16
+ opts.input_size = image_size
17
+ opts.split = split
18
+
19
+ dataset = CUBDataset(opts)
20
+ loader = torch.utils.data.DataLoader(
21
+ dataset,
22
+ batch_size=batch_size,
23
+ shuffle=not is_validation,
24
+ num_workers=num_workers,
25
+ pin_memory=True
26
+ )
27
+ return loader
28
+
29
+
30
+ class CUBDataset(Dataset):
31
+ def __init__(self, opts):
32
+ super().__init__()
33
+
34
+ self.opts = opts
35
+ self.img_size = opts.input_size
36
+ self.jitter_frac = opts.jitter_frac
37
+ self.padding_frac = opts.padding_frac
38
+ self.split = opts.split
39
+ self.data_dir = opts.data_dir
40
+ self.data_cache_dir = osp.join(self.data_dir, 'cachedir/cub')
41
+ self.img_dir = osp.join(self.data_dir, 'images')
42
+
43
+ self.anno_path = osp.join(self.data_cache_dir, 'data', '%s_cub_cleaned.mat' % self.split)
44
+ self.anno_sfm_path = osp.join(self.data_cache_dir, 'sfm', 'anno_%s.mat' % self.split)
45
+
46
+ if not osp.exists(self.anno_path):
47
+ print('%s doesnt exist!' % self.anno_path)
48
+ import pdb; pdb.set_trace()
49
+
50
+ # Load the annotation file.
51
+ print('loading %s' % self.anno_path)
52
+ self.anno = sio.loadmat(
53
+ self.anno_path, struct_as_record=False, squeeze_me=True)['images']
54
+ self.anno_sfm = sio.loadmat(
55
+ self.anno_sfm_path, struct_as_record=False, squeeze_me=True)['sfm_anno']
56
+
57
+ self.kp_perm = np.array([1, 2, 3, 4, 5, 6, 11, 12, 13, 10, 7, 8, 9, 14, 15]) - 1;
58
+
59
+ self.num_imgs = len(self.anno)
60
+ print('%d images' % self.num_imgs)
61
+
62
+ def forward_img(self, index):
63
+ data = self.anno[index]
64
+ data_sfm = self.anno_sfm[0]
65
+
66
+ # sfm_pose = (sfm_c, sfm_t, sfm_r)
67
+ sfm_pose = [np.copy(data_sfm.scale), np.copy(data_sfm.trans), np.copy(data_sfm.rot)]
68
+
69
+ sfm_rot = np.pad(sfm_pose[2], (0,1), 'constant')
70
+ sfm_rot[3, 3] = 1
71
+ sfm_pose[2] = quaternion_from_matrix(sfm_rot, isprecise=True)
72
+
73
+ img_path = osp.join(self.img_dir, str(data.rel_path))
74
+ #img_path = img_path.replace("JPEG", "jpg")
75
+ img = np.array(Image.open(img_path))
76
+
77
+ # Some are grayscale:
78
+ if len(img.shape) == 2:
79
+ img = np.repeat(np.expand_dims(img, 2), 3, axis=2)
80
+ mask = data.mask
81
+ mask = np.expand_dims(mask, 2)
82
+ h,w,_ = mask.shape
83
+
84
+ # Adjust to 0 indexing
85
+ bbox = np.array(
86
+ [data.bbox.x1, data.bbox.y1, data.bbox.x2, data.bbox.y2],
87
+ float) - 1
88
+
89
+ parts = data.parts.T.astype(float)
90
+ kp = np.copy(parts)
91
+ vis = kp[:, 2] > 0
92
+ kp[vis, :2] -= 1
93
+
94
+ # Peturb bbox
95
+ if self.split == 'train':
96
+ bbox = peturb_bbox(
97
+ bbox, pf=self.padding_frac, jf=self.jitter_frac)
98
+ else:
99
+ bbox = peturb_bbox(
100
+ bbox, pf=self.padding_frac, jf=0)
101
+ bbox = square_bbox(bbox)
102
+
103
+ # crop image around bbox, translate kps
104
+ img, mask, kp, sfm_pose = self.crop_image(img, mask, bbox, kp, vis, sfm_pose)
105
+
106
+ # scale image, and mask. And scale kps.
107
+ img, mask, kp, sfm_pose = self.scale_image(img, mask, kp, vis, sfm_pose)
108
+
109
+ # Mirror image on random.
110
+ if self.split == 'train':
111
+ img, mask, kp, sfm_pose = self.mirror_image(img, mask, kp, sfm_pose)
112
+
113
+ # Normalize kp to be [-1, 1]
114
+ img_h, img_w = img.shape[:2]
115
+ kp_norm, sfm_pose = self.normalize_kp(kp, sfm_pose, img_h, img_w)
116
+
117
+ # img = Image.fromarray(np.asarray(img, np.uint8))
118
+ mask = np.asarray(mask, np.float32)
119
+ return img, kp_norm, mask, sfm_pose, img_path
120
+
121
+ def normalize_kp(self, kp, sfm_pose, img_h, img_w):
122
+ vis = kp[:, 2, None] > 0
123
+ new_kp = np.stack([2 * (kp[:, 0] / img_w) - 1,
124
+ 2 * (kp[:, 1] / img_h) - 1,
125
+ kp[:, 2]]).T
126
+ sfm_pose[0] *= (1.0/img_w + 1.0/img_h)
127
+ sfm_pose[1][0] = 2.0 * (sfm_pose[1][0] / img_w) - 1
128
+ sfm_pose[1][1] = 2.0 * (sfm_pose[1][1] / img_h) - 1
129
+ new_kp = vis * new_kp
130
+
131
+ return new_kp, sfm_pose
132
+
133
+ def crop_image(self, img, mask, bbox, kp, vis, sfm_pose):
134
+ # crop image and mask and translate kps
135
+ img = crop(img, bbox, bgval=1)
136
+ mask = crop(mask, bbox, bgval=0)
137
+ kp[vis, 0] -= bbox[0]
138
+ kp[vis, 1] -= bbox[1]
139
+ sfm_pose[1][0] -= bbox[0]
140
+ sfm_pose[1][1] -= bbox[1]
141
+ return img, mask, kp, sfm_pose
142
+
143
+ def scale_image(self, img, mask, kp, vis, sfm_pose):
144
+ # Scale image so largest bbox size is img_size
145
+ bwidth = np.shape(img)[0]
146
+ bheight = np.shape(img)[1]
147
+ scale = self.img_size / float(max(bwidth, bheight))
148
+ img_scale, _ = resize_img(img, scale)
149
+ # if img_scale.shape[0] != self.img_size:
150
+ # print('bad!')
151
+ # import ipdb; ipdb.set_trace()
152
+ # mask_scale, _ = resize_img(mask, scale)
153
+ # mask_scale, _ = resize_img(mask, scale, interpolation=cv2.INTER_NEAREST)
154
+ mask_scale, _ = resize_img(mask, scale)
155
+ kp[vis, :2] *= scale
156
+ sfm_pose[0] *= scale
157
+ sfm_pose[1] *= scale
158
+
159
+ return img_scale, mask_scale, kp, sfm_pose
160
+
161
+ def mirror_image(self, img, mask, kp, sfm_pose):
162
+ kp_perm = self.kp_perm
163
+ if np.random.rand(1) > 0.5:
164
+ # Need copy bc torch collate doesnt like neg strides
165
+ img_flip = img[:, ::-1, :].copy()
166
+ mask_flip = mask[:, ::-1].copy()
167
+
168
+ # Flip kps.
169
+ new_x = img.shape[1] - kp[:, 0] - 1
170
+ kp_flip = np.hstack((new_x[:, None], kp[:, 1:]))
171
+ kp_flip = kp_flip[kp_perm, :]
172
+ # Flip sfm_pose Rot.
173
+ R = quaternion_matrix(sfm_pose[2])
174
+ flip_R = np.diag([-1, 1, 1, 1]).dot(R.dot(np.diag([-1, 1, 1, 1])))
175
+ sfm_pose[2] = quaternion_from_matrix(flip_R, isprecise=True)
176
+ # Flip tx
177
+ tx = img.shape[1] - sfm_pose[1][0] - 1
178
+ sfm_pose[1][0] = tx
179
+ return img_flip, mask_flip, kp_flip, sfm_pose
180
+ else:
181
+ return img, mask, kp, sfm_pose
182
+
183
+ def __len__(self):
184
+ return self.num_imgs
185
+
186
+ def __getitem__(self, index):
187
+ img, kp, mask, sfm_pose, img_path = self.forward_img(index)
188
+ sfm_pose[0].shape = 1
189
+ mask = np.expand_dims(mask, 2)
190
+
191
+ images = torch.FloatTensor(img /255.).permute(2,0,1).unsqueeze(0)
192
+ masks = torch.FloatTensor(mask).permute(2,0,1).repeat(1,3,1,1)
193
+ mask_dt = compute_distance_transform(masks)
194
+ # flows = torch.zeros(1,2, self.img_size, self.img_size)
195
+ flows = torch.zeros(1)
196
+ bboxs = torch.FloatTensor([0, 0, 0, self.img_size, self.img_size, 1, 1, 0]).unsqueeze(0) # frame_id, crop_x0, crop_y0, crop_w, crop_h, resize_sx, resize_sy, sharpness
197
+ bg_image = images[0]
198
+ seq_idx = torch.LongTensor([index])
199
+ frame_idx = torch.LongTensor([0])
200
+ return images, masks, mask_dt, flows, bboxs, bg_image, seq_idx, frame_idx
201
+
202
+
203
+ def compute_distance_transform(mask):
204
+ mask_dt = []
205
+ for m in mask:
206
+ dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
207
+ inv_dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(1 - m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
208
+ mask_dt += [torch.stack([dt, inv_dt], 0)]
209
+ return torch.stack(mask_dt, 0) # Bx2xHxW
210
+
211
+
212
+ def resize_img(img, scale_factor):
213
+ new_size = (np.round(np.array(img.shape[:2]) * scale_factor)).astype(int)
214
+ new_img = cv2.resize(img, (new_size[1], new_size[0]))
215
+ # This is scale factor of [height, width] i.e. [y, x]
216
+ actual_factor = [new_size[0] / float(img.shape[0]),
217
+ new_size[1] / float(img.shape[1])]
218
+ return new_img, actual_factor
219
+
220
+
221
+ def peturb_bbox(bbox, pf=0, jf=0):
222
+ '''
223
+ Jitters and pads the input bbox.
224
+ Args:
225
+ bbox: Zero-indexed tight bbox.
226
+ pf: padding fraction.
227
+ jf: jittering fraction.
228
+ Returns:
229
+ pet_bbox: Jittered and padded box. Might have -ve or out-of-image coordinates
230
+ '''
231
+ pet_bbox = [coord for coord in bbox]
232
+ bwidth = bbox[2] - bbox[0] + 1
233
+ bheight = bbox[3] - bbox[1] + 1
234
+
235
+ pet_bbox[0] -= (pf*bwidth) + (1-2*np.random.random())*jf*bwidth
236
+ pet_bbox[1] -= (pf*bheight) + (1-2*np.random.random())*jf*bheight
237
+ pet_bbox[2] += (pf*bwidth) + (1-2*np.random.random())*jf*bwidth
238
+ pet_bbox[3] += (pf*bheight) + (1-2*np.random.random())*jf*bheight
239
+
240
+ return pet_bbox
241
+
242
+
243
+ def square_bbox(bbox):
244
+ '''
245
+ Converts a bbox to have a square shape by increasing size along non-max dimension.
246
+ '''
247
+ sq_bbox = [int(round(coord)) for coord in bbox]
248
+ bwidth = sq_bbox[2] - sq_bbox[0] + 1
249
+ bheight = sq_bbox[3] - sq_bbox[1] + 1
250
+ maxdim = float(max(bwidth, bheight))
251
+
252
+ dw_b_2 = int(round((maxdim-bwidth)/2.0))
253
+ dh_b_2 = int(round((maxdim-bheight)/2.0))
254
+
255
+ sq_bbox[0] -= dw_b_2
256
+ sq_bbox[1] -= dh_b_2
257
+ sq_bbox[2] = sq_bbox[0] + maxdim - 1
258
+ sq_bbox[3] = sq_bbox[1] + maxdim - 1
259
+
260
+ return sq_bbox
261
+
262
+
263
+ def crop(img, bbox, bgval=0):
264
+ '''
265
+ Crops a region from the image corresponding to the bbox.
266
+ If some regions specified go outside the image boundaries, the pixel values are set to bgval.
267
+ Args:
268
+ img: image to crop
269
+ bbox: bounding box to crop
270
+ bgval: default background for regions outside image
271
+ '''
272
+ bbox = [int(round(c)) for c in bbox]
273
+ bwidth = bbox[2] - bbox[0] + 1
274
+ bheight = bbox[3] - bbox[1] + 1
275
+
276
+ im_shape = np.shape(img)
277
+ im_h, im_w = im_shape[0], im_shape[1]
278
+
279
+ nc = 1 if len(im_shape) < 3 else im_shape[2]
280
+
281
+ img_out = np.ones((bheight, bwidth, nc))*bgval
282
+ x_min_src = max(0, bbox[0])
283
+ x_max_src = min(im_w, bbox[2]+1)
284
+ y_min_src = max(0, bbox[1])
285
+ y_max_src = min(im_h, bbox[3]+1)
286
+
287
+ x_min_trg = x_min_src - bbox[0]
288
+ x_max_trg = x_max_src - x_min_src + x_min_trg
289
+ y_min_trg = y_min_src - bbox[1]
290
+ y_max_trg = y_max_src - y_min_src + y_min_trg
291
+
292
+ img_out[y_min_trg:y_max_trg, x_min_trg:x_max_trg, :] = img[y_min_src:y_max_src, x_min_src:x_max_src, :]
293
+ return img_out
294
+
295
+
296
+ # https://github.com/akanazawa/cmr/blob/master/utils/transformations.py
297
+ import math
298
+ import numpy
299
+ _EPS = numpy.finfo(float).eps * 4.0
300
+
301
+ def quaternion_matrix(quaternion):
302
+ """Return homogeneous rotation matrix from quaternion.
303
+ >>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0])
304
+ >>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0]))
305
+ True
306
+ >>> M = quaternion_matrix([1, 0, 0, 0])
307
+ >>> numpy.allclose(M, numpy.identity(4))
308
+ True
309
+ >>> M = quaternion_matrix([0, 1, 0, 0])
310
+ >>> numpy.allclose(M, numpy.diag([1, -1, -1, 1]))
311
+ True
312
+ """
313
+ q = numpy.array(quaternion, dtype=numpy.float64, copy=True)
314
+ n = numpy.dot(q, q)
315
+ if n < _EPS:
316
+ return numpy.identity(4)
317
+ q *= math.sqrt(2.0 / n)
318
+ q = numpy.outer(q, q)
319
+ return numpy.array([
320
+ [1.0-q[2, 2]-q[3, 3], q[1, 2]-q[3, 0], q[1, 3]+q[2, 0], 0.0],
321
+ [ q[1, 2]+q[3, 0], 1.0-q[1, 1]-q[3, 3], q[2, 3]-q[1, 0], 0.0],
322
+ [ q[1, 3]-q[2, 0], q[2, 3]+q[1, 0], 1.0-q[1, 1]-q[2, 2], 0.0],
323
+ [ 0.0, 0.0, 0.0, 1.0]])
324
+
325
+ def quaternion_from_matrix(matrix, isprecise=False):
326
+ """Return quaternion from rotation matrix.
327
+ If isprecise is True, the input matrix is assumed to be a precise rotation
328
+ matrix and a faster algorithm is used.
329
+ >>> q = quaternion_from_matrix(numpy.identity(4), True)
330
+ >>> numpy.allclose(q, [1, 0, 0, 0])
331
+ True
332
+ >>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1]))
333
+ >>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0])
334
+ True
335
+ >>> R = rotation_matrix(0.123, (1, 2, 3))
336
+ >>> q = quaternion_from_matrix(R, True)
337
+ >>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786])
338
+ True
339
+ >>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0],
340
+ ... [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]]
341
+ >>> q = quaternion_from_matrix(R)
342
+ >>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611])
343
+ True
344
+ >>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0],
345
+ ... [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]]
346
+ >>> q = quaternion_from_matrix(R)
347
+ >>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603])
348
+ True
349
+ >>> R = random_rotation_matrix()
350
+ >>> q = quaternion_from_matrix(R)
351
+ >>> is_same_transform(R, quaternion_matrix(q))
352
+ True
353
+ >>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False),
354
+ ... quaternion_from_matrix(R, isprecise=True))
355
+ True
356
+ >>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0)
357
+ >>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False),
358
+ ... quaternion_from_matrix(R, isprecise=True))
359
+ True
360
+ """
361
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4]
362
+ if isprecise:
363
+ q = numpy.empty((4, ))
364
+ t = numpy.trace(M)
365
+ if t > M[3, 3]:
366
+ q[0] = t
367
+ q[3] = M[1, 0] - M[0, 1]
368
+ q[2] = M[0, 2] - M[2, 0]
369
+ q[1] = M[2, 1] - M[1, 2]
370
+ else:
371
+ i, j, k = 0, 1, 2
372
+ if M[1, 1] > M[0, 0]:
373
+ i, j, k = 1, 2, 0
374
+ if M[2, 2] > M[i, i]:
375
+ i, j, k = 2, 0, 1
376
+ t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
377
+ q[i] = t
378
+ q[j] = M[i, j] + M[j, i]
379
+ q[k] = M[k, i] + M[i, k]
380
+ q[3] = M[k, j] - M[j, k]
381
+ q = q[[3, 0, 1, 2]]
382
+ q *= 0.5 / math.sqrt(t * M[3, 3])
383
+ else:
384
+ m00 = M[0, 0]
385
+ m01 = M[0, 1]
386
+ m02 = M[0, 2]
387
+ m10 = M[1, 0]
388
+ m11 = M[1, 1]
389
+ m12 = M[1, 2]
390
+ m20 = M[2, 0]
391
+ m21 = M[2, 1]
392
+ m22 = M[2, 2]
393
+ # symmetric matrix K
394
+ K = numpy.array([[m00-m11-m22, 0.0, 0.0, 0.0],
395
+ [m01+m10, m11-m00-m22, 0.0, 0.0],
396
+ [m02+m20, m12+m21, m22-m00-m11, 0.0],
397
+ [m21-m12, m02-m20, m10-m01, m00+m11+m22]])
398
+ K /= 3.0
399
+ # quaternion is eigenvector of K that corresponds to largest eigenvalue
400
+ w, V = numpy.linalg.eigh(K)
401
+ q = V[[3, 0, 1, 2], numpy.argmax(w)]
402
+ if q[0] < 0.0:
403
+ numpy.negative(q, q)
404
+ return q
video3d/cub_dataloaders_ddp.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import cv2
3
+ import numpy as np
4
+ import scipy.io as sio
5
+ import torch
6
+ from PIL import Image
7
+ from torch.utils.data import Dataset
8
+ from types import SimpleNamespace
9
+
10
+
11
+ def get_cub_loader(data_dir, split='test', is_validation=False, batch_size=256, num_workers=4, image_size=256):
12
+ opts = SimpleNamespace()
13
+ opts.data_dir = data_dir
14
+ opts.padding_frac = 0.05
15
+ opts.jitter_frac = 0.05
16
+ opts.input_size = image_size
17
+ opts.split = split
18
+
19
+ dataset = CUBDataset(opts)
20
+ loader = torch.utils.data.DataLoader(
21
+ dataset,
22
+ batch_size=batch_size,
23
+ shuffle=not is_validation,
24
+ num_workers=num_workers,
25
+ pin_memory=True
26
+ )
27
+ return loader
28
+
29
+
30
+ def get_cub_loader_ddp(data_dir, world_size, rank, split='test', is_validation=False, batch_size=256, num_workers=4, image_size=256):
31
+ opts = SimpleNamespace()
32
+ opts.data_dir = data_dir
33
+ opts.padding_frac = 0.05
34
+ opts.jitter_frac = 0.05
35
+ opts.input_size = image_size
36
+ opts.split = split
37
+
38
+ dataset = CUBDataset(opts)
39
+
40
+ sampler = torch.utils.data.distributed.DistributedSampler(
41
+ dataset,
42
+ num_replicas=world_size,
43
+ rank=rank,
44
+ )
45
+
46
+ loader = torch.utils.data.DataLoader(
47
+ dataset,
48
+ sampler=sampler,
49
+ batch_size=batch_size,
50
+ shuffle=not is_validation,
51
+ drop_last=True,
52
+ num_workers=num_workers,
53
+ pin_memory=True
54
+ )
55
+ return loader
56
+
57
+
58
+ class CUBDataset(Dataset):
59
+ def __init__(self, opts):
60
+ super().__init__()
61
+
62
+ self.opts = opts
63
+ self.img_size = opts.input_size
64
+ self.jitter_frac = opts.jitter_frac
65
+ self.padding_frac = opts.padding_frac
66
+ self.split = opts.split
67
+ self.data_dir = opts.data_dir
68
+ self.data_cache_dir = osp.join(self.data_dir, 'cachedir/cub')
69
+ self.img_dir = osp.join(self.data_dir, 'images')
70
+
71
+ self.anno_path = osp.join(self.data_cache_dir, 'data', '%s_cub_cleaned.mat' % self.split)
72
+ self.anno_sfm_path = osp.join(self.data_cache_dir, 'sfm', 'anno_%s.mat' % self.split)
73
+
74
+ if not osp.exists(self.anno_path):
75
+ print('%s doesnt exist!' % self.anno_path)
76
+ import pdb; pdb.set_trace()
77
+
78
+ # Load the annotation file.
79
+ print('loading %s' % self.anno_path)
80
+ self.anno = sio.loadmat(
81
+ self.anno_path, struct_as_record=False, squeeze_me=True)['images']
82
+ self.anno_sfm = sio.loadmat(
83
+ self.anno_sfm_path, struct_as_record=False, squeeze_me=True)['sfm_anno']
84
+
85
+ self.kp_perm = np.array([1, 2, 3, 4, 5, 6, 11, 12, 13, 10, 7, 8, 9, 14, 15]) - 1;
86
+
87
+ self.num_imgs = len(self.anno)
88
+ print('%d images' % self.num_imgs)
89
+
90
+ def forward_img(self, index):
91
+ data = self.anno[index]
92
+ data_sfm = self.anno_sfm[0]
93
+
94
+ # sfm_pose = (sfm_c, sfm_t, sfm_r)
95
+ sfm_pose = [np.copy(data_sfm.scale), np.copy(data_sfm.trans), np.copy(data_sfm.rot)]
96
+
97
+ sfm_rot = np.pad(sfm_pose[2], (0,1), 'constant')
98
+ sfm_rot[3, 3] = 1
99
+ sfm_pose[2] = quaternion_from_matrix(sfm_rot, isprecise=True)
100
+
101
+ img_path = osp.join(self.img_dir, str(data.rel_path))
102
+ #img_path = img_path.replace("JPEG", "jpg")
103
+ img = np.array(Image.open(img_path))
104
+
105
+ # Some are grayscale:
106
+ if len(img.shape) == 2:
107
+ img = np.repeat(np.expand_dims(img, 2), 3, axis=2)
108
+ mask = data.mask
109
+ mask = np.expand_dims(mask, 2)
110
+ h,w,_ = mask.shape
111
+
112
+ # Adjust to 0 indexing
113
+ bbox = np.array(
114
+ [data.bbox.x1, data.bbox.y1, data.bbox.x2, data.bbox.y2],
115
+ float) - 1
116
+
117
+ parts = data.parts.T.astype(float)
118
+ kp = np.copy(parts)
119
+ vis = kp[:, 2] > 0
120
+ kp[vis, :2] -= 1
121
+
122
+ # Peturb bbox
123
+ if self.split == 'train':
124
+ bbox = peturb_bbox(
125
+ bbox, pf=self.padding_frac, jf=self.jitter_frac)
126
+ else:
127
+ bbox = peturb_bbox(
128
+ bbox, pf=self.padding_frac, jf=0)
129
+ bbox = square_bbox(bbox)
130
+
131
+ # crop image around bbox, translate kps
132
+ img, mask, kp, sfm_pose = self.crop_image(img, mask, bbox, kp, vis, sfm_pose)
133
+
134
+ # scale image, and mask. And scale kps.
135
+ img, mask, kp, sfm_pose = self.scale_image(img, mask, kp, vis, sfm_pose)
136
+
137
+ # Mirror image on random.
138
+ if self.split == 'train':
139
+ img, mask, kp, sfm_pose = self.mirror_image(img, mask, kp, sfm_pose)
140
+
141
+ # Normalize kp to be [-1, 1]
142
+ img_h, img_w = img.shape[:2]
143
+ kp_norm, sfm_pose = self.normalize_kp(kp, sfm_pose, img_h, img_w)
144
+
145
+ # img = Image.fromarray(np.asarray(img, np.uint8))
146
+ mask = np.asarray(mask, np.float32)
147
+ return img, kp_norm, mask, sfm_pose, img_path
148
+
149
+ def normalize_kp(self, kp, sfm_pose, img_h, img_w):
150
+ vis = kp[:, 2, None] > 0
151
+ new_kp = np.stack([2 * (kp[:, 0] / img_w) - 1,
152
+ 2 * (kp[:, 1] / img_h) - 1,
153
+ kp[:, 2]]).T
154
+ sfm_pose[0] *= (1.0/img_w + 1.0/img_h)
155
+ sfm_pose[1][0] = 2.0 * (sfm_pose[1][0] / img_w) - 1
156
+ sfm_pose[1][1] = 2.0 * (sfm_pose[1][1] / img_h) - 1
157
+ new_kp = vis * new_kp
158
+
159
+ return new_kp, sfm_pose
160
+
161
+ def crop_image(self, img, mask, bbox, kp, vis, sfm_pose):
162
+ # crop image and mask and translate kps
163
+ img = crop(img, bbox, bgval=1)
164
+ mask = crop(mask, bbox, bgval=0)
165
+ kp[vis, 0] -= bbox[0]
166
+ kp[vis, 1] -= bbox[1]
167
+ sfm_pose[1][0] -= bbox[0]
168
+ sfm_pose[1][1] -= bbox[1]
169
+ return img, mask, kp, sfm_pose
170
+
171
+ def scale_image(self, img, mask, kp, vis, sfm_pose):
172
+ # Scale image so largest bbox size is img_size
173
+ bwidth = np.shape(img)[0]
174
+ bheight = np.shape(img)[1]
175
+ scale = self.img_size / float(max(bwidth, bheight))
176
+ img_scale, _ = resize_img(img, scale)
177
+ # if img_scale.shape[0] != self.img_size:
178
+ # print('bad!')
179
+ # import ipdb; ipdb.set_trace()
180
+ # mask_scale, _ = resize_img(mask, scale)
181
+ # mask_scale, _ = resize_img(mask, scale, interpolation=cv2.INTER_NEAREST)
182
+ mask_scale, _ = resize_img(mask, scale)
183
+ kp[vis, :2] *= scale
184
+ sfm_pose[0] *= scale
185
+ sfm_pose[1] *= scale
186
+
187
+ return img_scale, mask_scale, kp, sfm_pose
188
+
189
+ def mirror_image(self, img, mask, kp, sfm_pose):
190
+ kp_perm = self.kp_perm
191
+ if np.random.rand(1) > 0.5:
192
+ # Need copy bc torch collate doesnt like neg strides
193
+ img_flip = img[:, ::-1, :].copy()
194
+ mask_flip = mask[:, ::-1].copy()
195
+
196
+ # Flip kps.
197
+ new_x = img.shape[1] - kp[:, 0] - 1
198
+ kp_flip = np.hstack((new_x[:, None], kp[:, 1:]))
199
+ kp_flip = kp_flip[kp_perm, :]
200
+ # Flip sfm_pose Rot.
201
+ R = quaternion_matrix(sfm_pose[2])
202
+ flip_R = np.diag([-1, 1, 1, 1]).dot(R.dot(np.diag([-1, 1, 1, 1])))
203
+ sfm_pose[2] = quaternion_from_matrix(flip_R, isprecise=True)
204
+ # Flip tx
205
+ tx = img.shape[1] - sfm_pose[1][0] - 1
206
+ sfm_pose[1][0] = tx
207
+ return img_flip, mask_flip, kp_flip, sfm_pose
208
+ else:
209
+ return img, mask, kp, sfm_pose
210
+
211
+ def __len__(self):
212
+ return self.num_imgs
213
+
214
+ def __getitem__(self, index):
215
+ img, kp, mask, sfm_pose, img_path = self.forward_img(index)
216
+ sfm_pose[0].shape = 1
217
+ mask = np.expand_dims(mask, 2)
218
+
219
+ images = torch.FloatTensor(img /255.).permute(2,0,1).unsqueeze(0)
220
+ masks = torch.FloatTensor(mask).permute(2,0,1).repeat(1,3,1,1)
221
+ mask_dt = compute_distance_transform(masks)
222
+ # flows = torch.zeros(1,2, self.img_size, self.img_size)
223
+ flows = torch.zeros(1)
224
+ bboxs = torch.FloatTensor([0, 0, 0, self.img_size, self.img_size, 1, 1, 0]).unsqueeze(0) # frame_id, crop_x0, crop_y0, crop_w, crop_h, resize_sx, resize_sy, sharpness
225
+ bg_image = images[0]
226
+ seq_idx = torch.LongTensor([index])
227
+ frame_idx = torch.LongTensor([0])
228
+ return images, masks, mask_dt, flows, bboxs, bg_image, seq_idx, frame_idx
229
+
230
+
231
+ def compute_distance_transform(mask):
232
+ mask_dt = []
233
+ for m in mask:
234
+ dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
235
+ inv_dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(1 - m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
236
+ mask_dt += [torch.stack([dt, inv_dt], 0)]
237
+ return torch.stack(mask_dt, 0) # Bx2xHxW
238
+
239
+
240
+ def resize_img(img, scale_factor):
241
+ new_size = (np.round(np.array(img.shape[:2]) * scale_factor)).astype(int)
242
+ new_img = cv2.resize(img, (new_size[1], new_size[0]))
243
+ # This is scale factor of [height, width] i.e. [y, x]
244
+ actual_factor = [new_size[0] / float(img.shape[0]),
245
+ new_size[1] / float(img.shape[1])]
246
+ return new_img, actual_factor
247
+
248
+
249
+ def peturb_bbox(bbox, pf=0, jf=0):
250
+ '''
251
+ Jitters and pads the input bbox.
252
+ Args:
253
+ bbox: Zero-indexed tight bbox.
254
+ pf: padding fraction.
255
+ jf: jittering fraction.
256
+ Returns:
257
+ pet_bbox: Jittered and padded box. Might have -ve or out-of-image coordinates
258
+ '''
259
+ pet_bbox = [coord for coord in bbox]
260
+ bwidth = bbox[2] - bbox[0] + 1
261
+ bheight = bbox[3] - bbox[1] + 1
262
+
263
+ pet_bbox[0] -= (pf*bwidth) + (1-2*np.random.random())*jf*bwidth
264
+ pet_bbox[1] -= (pf*bheight) + (1-2*np.random.random())*jf*bheight
265
+ pet_bbox[2] += (pf*bwidth) + (1-2*np.random.random())*jf*bwidth
266
+ pet_bbox[3] += (pf*bheight) + (1-2*np.random.random())*jf*bheight
267
+
268
+ return pet_bbox
269
+
270
+
271
+ def square_bbox(bbox):
272
+ '''
273
+ Converts a bbox to have a square shape by increasing size along non-max dimension.
274
+ '''
275
+ sq_bbox = [int(round(coord)) for coord in bbox]
276
+ bwidth = sq_bbox[2] - sq_bbox[0] + 1
277
+ bheight = sq_bbox[3] - sq_bbox[1] + 1
278
+ maxdim = float(max(bwidth, bheight))
279
+
280
+ dw_b_2 = int(round((maxdim-bwidth)/2.0))
281
+ dh_b_2 = int(round((maxdim-bheight)/2.0))
282
+
283
+ sq_bbox[0] -= dw_b_2
284
+ sq_bbox[1] -= dh_b_2
285
+ sq_bbox[2] = sq_bbox[0] + maxdim - 1
286
+ sq_bbox[3] = sq_bbox[1] + maxdim - 1
287
+
288
+ return sq_bbox
289
+
290
+
291
+ def crop(img, bbox, bgval=0):
292
+ '''
293
+ Crops a region from the image corresponding to the bbox.
294
+ If some regions specified go outside the image boundaries, the pixel values are set to bgval.
295
+ Args:
296
+ img: image to crop
297
+ bbox: bounding box to crop
298
+ bgval: default background for regions outside image
299
+ '''
300
+ bbox = [int(round(c)) for c in bbox]
301
+ bwidth = bbox[2] - bbox[0] + 1
302
+ bheight = bbox[3] - bbox[1] + 1
303
+
304
+ im_shape = np.shape(img)
305
+ im_h, im_w = im_shape[0], im_shape[1]
306
+
307
+ nc = 1 if len(im_shape) < 3 else im_shape[2]
308
+
309
+ img_out = np.ones((bheight, bwidth, nc))*bgval
310
+ x_min_src = max(0, bbox[0])
311
+ x_max_src = min(im_w, bbox[2]+1)
312
+ y_min_src = max(0, bbox[1])
313
+ y_max_src = min(im_h, bbox[3]+1)
314
+
315
+ x_min_trg = x_min_src - bbox[0]
316
+ x_max_trg = x_max_src - x_min_src + x_min_trg
317
+ y_min_trg = y_min_src - bbox[1]
318
+ y_max_trg = y_max_src - y_min_src + y_min_trg
319
+
320
+ img_out[y_min_trg:y_max_trg, x_min_trg:x_max_trg, :] = img[y_min_src:y_max_src, x_min_src:x_max_src, :]
321
+ return img_out
322
+
323
+
324
+ # https://github.com/akanazawa/cmr/blob/master/utils/transformations.py
325
+ import math
326
+ import numpy
327
+ _EPS = numpy.finfo(float).eps * 4.0
328
+
329
+
330
+ def quaternion_matrix(quaternion):
331
+ """Return homogeneous rotation matrix from quaternion.
332
+ >>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0])
333
+ >>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0]))
334
+ True
335
+ >>> M = quaternion_matrix([1, 0, 0, 0])
336
+ >>> numpy.allclose(M, numpy.identity(4))
337
+ True
338
+ >>> M = quaternion_matrix([0, 1, 0, 0])
339
+ >>> numpy.allclose(M, numpy.diag([1, -1, -1, 1]))
340
+ True
341
+ """
342
+ q = numpy.array(quaternion, dtype=numpy.float64, copy=True)
343
+ n = numpy.dot(q, q)
344
+ if n < _EPS:
345
+ return numpy.identity(4)
346
+ q *= math.sqrt(2.0 / n)
347
+ q = numpy.outer(q, q)
348
+ return numpy.array([
349
+ [1.0-q[2, 2]-q[3, 3], q[1, 2]-q[3, 0], q[1, 3]+q[2, 0], 0.0],
350
+ [ q[1, 2]+q[3, 0], 1.0-q[1, 1]-q[3, 3], q[2, 3]-q[1, 0], 0.0],
351
+ [ q[1, 3]-q[2, 0], q[2, 3]+q[1, 0], 1.0-q[1, 1]-q[2, 2], 0.0],
352
+ [ 0.0, 0.0, 0.0, 1.0]])
353
+
354
+
355
+ def quaternion_from_matrix(matrix, isprecise=False):
356
+ """Return quaternion from rotation matrix.
357
+ If isprecise is True, the input matrix is assumed to be a precise rotation
358
+ matrix and a faster algorithm is used.
359
+ >>> q = quaternion_from_matrix(numpy.identity(4), True)
360
+ >>> numpy.allclose(q, [1, 0, 0, 0])
361
+ True
362
+ >>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1]))
363
+ >>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0])
364
+ True
365
+ >>> R = rotation_matrix(0.123, (1, 2, 3))
366
+ >>> q = quaternion_from_matrix(R, True)
367
+ >>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786])
368
+ True
369
+ >>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0],
370
+ ... [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]]
371
+ >>> q = quaternion_from_matrix(R)
372
+ >>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611])
373
+ True
374
+ >>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0],
375
+ ... [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]]
376
+ >>> q = quaternion_from_matrix(R)
377
+ >>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603])
378
+ True
379
+ >>> R = random_rotation_matrix()
380
+ >>> q = quaternion_from_matrix(R)
381
+ >>> is_same_transform(R, quaternion_matrix(q))
382
+ True
383
+ >>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False),
384
+ ... quaternion_from_matrix(R, isprecise=True))
385
+ True
386
+ >>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0)
387
+ >>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False),
388
+ ... quaternion_from_matrix(R, isprecise=True))
389
+ True
390
+ """
391
+ M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4]
392
+ if isprecise:
393
+ q = numpy.empty((4, ))
394
+ t = numpy.trace(M)
395
+ if t > M[3, 3]:
396
+ q[0] = t
397
+ q[3] = M[1, 0] - M[0, 1]
398
+ q[2] = M[0, 2] - M[2, 0]
399
+ q[1] = M[2, 1] - M[1, 2]
400
+ else:
401
+ i, j, k = 0, 1, 2
402
+ if M[1, 1] > M[0, 0]:
403
+ i, j, k = 1, 2, 0
404
+ if M[2, 2] > M[i, i]:
405
+ i, j, k = 2, 0, 1
406
+ t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
407
+ q[i] = t
408
+ q[j] = M[i, j] + M[j, i]
409
+ q[k] = M[k, i] + M[i, k]
410
+ q[3] = M[k, j] - M[j, k]
411
+ q = q[[3, 0, 1, 2]]
412
+ q *= 0.5 / math.sqrt(t * M[3, 3])
413
+ else:
414
+ m00 = M[0, 0]
415
+ m01 = M[0, 1]
416
+ m02 = M[0, 2]
417
+ m10 = M[1, 0]
418
+ m11 = M[1, 1]
419
+ m12 = M[1, 2]
420
+ m20 = M[2, 0]
421
+ m21 = M[2, 1]
422
+ m22 = M[2, 2]
423
+ # symmetric matrix K
424
+ K = numpy.array([[m00-m11-m22, 0.0, 0.0, 0.0],
425
+ [m01+m10, m11-m00-m22, 0.0, 0.0],
426
+ [m02+m20, m12+m21, m22-m00-m11, 0.0],
427
+ [m21-m12, m02-m20, m10-m01, m00+m11+m22]])
428
+ K /= 3.0
429
+ # quaternion is eigenvector of K that corresponds to largest eigenvalue
430
+ w, V = numpy.linalg.eigh(K)
431
+ q = V[[3, 0, 1, 2], numpy.argmax(w)]
432
+ if q[0] < 0.0:
433
+ numpy.negative(q, q)
434
+ return q
video3d/dataloaders.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from glob import glob
3
+ import random
4
+ import numpy as np
5
+ from PIL import Image
6
+ import cv2
7
+ import torch
8
+ from torch.utils.data import Dataset
9
+ import torchvision.datasets.folder
10
+ import torchvision.transforms as transforms
11
+ from einops import rearrange
12
+
13
+
14
+ def compute_distance_transform(mask):
15
+ mask_dt = []
16
+ for m in mask:
17
+ dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
18
+ inv_dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(1 - m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
19
+ mask_dt += [torch.stack([dt, inv_dt], 0)]
20
+ return torch.stack(mask_dt, 0) # Bx2xHxW
21
+
22
+
23
+ def crop_image(image, boxs, size):
24
+ crops = []
25
+ for box in boxs:
26
+ crop_x0, crop_y0, crop_w, crop_h = box
27
+ crop = transforms.functional.resized_crop(image, crop_y0, crop_x0, crop_h, crop_w, size)
28
+ crop = transforms.functional.to_tensor(crop)
29
+ crops += [crop]
30
+ return torch.stack(crops, 0)
31
+
32
+
33
+ def box_loader(fpath):
34
+ box = np.loadtxt(fpath, 'str')
35
+ box[0] = box[0].split('_')[0]
36
+ return box.astype(np.float32)
37
+
38
+
39
+ def read_feat_from_img(path, n_channels):
40
+ feat = np.array(Image.open(path))
41
+ return dencode_feat_from_img(feat, n_channels)
42
+
43
+
44
+ def dencode_feat_from_img(img, n_channels):
45
+ n_addon_channels = int(np.ceil(n_channels / 3) * 3) - n_channels
46
+ n_tiles = int((n_channels + n_addon_channels) / 3)
47
+ feat = rearrange(img, 'h (t w) c -> h w (t c)', t=n_tiles, c=3)
48
+ feat = feat[:, :, :-n_addon_channels]
49
+ feat = feat.astype('float32') / 255
50
+ return feat.transpose(2, 0, 1)
51
+
52
+
53
+ def dino_loader(fpath, n_channels):
54
+ dino_map = read_feat_from_img(fpath, n_channels)
55
+ return dino_map
56
+
57
+
58
+ def get_valid_mask(boxs, image_size):
59
+ valid_masks = []
60
+ for box in boxs:
61
+ crop_x0, crop_y0, crop_w, crop_h, full_w, full_h = box[1:7].int().numpy()
62
+ # Discard a small margin near the boundary.
63
+ margin_w = int(crop_w * 0.02)
64
+ margin_h = int(crop_h * 0.02)
65
+ mask_full = torch.ones(full_h-margin_h*2, full_w-margin_w*2)
66
+ mask_full_pad = torch.nn.functional.pad(mask_full, (crop_w+margin_w, crop_w+margin_w, crop_h+margin_h, crop_h+margin_h), mode='constant', value=0.0)
67
+ mask_full_crop = mask_full_pad[crop_y0+crop_h:crop_y0+crop_h*2, crop_x0+crop_w:crop_x0+crop_w*2]
68
+ mask_crop = torch.nn.functional.interpolate(mask_full_crop[None, None, :, :], image_size, mode='nearest')[0,0]
69
+ valid_masks += [mask_crop]
70
+ return torch.stack(valid_masks, 0) # NxHxW
71
+
72
+
73
+ def horizontal_flip_box(box):
74
+ frame_id, crop_x0, crop_y0, crop_w, crop_h, full_w, full_h, sharpness, label = box.unbind(1)
75
+ box[:,1] = full_w - crop_x0 - crop_w # x0
76
+ return box
77
+
78
+
79
+ def horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features=None, dino_clusters=None):
80
+ images = images.flip(3) # NxCxHxW
81
+ masks = masks.flip(3) # NxCxHxW
82
+ mask_dt = mask_dt.flip(3) # NxCxHxW
83
+ mask_valid = mask_valid.flip(2) # NxHxW
84
+ if flows.dim() > 1:
85
+ flows = flows.flip(3) # (N-1)x(x,y)xHxW
86
+ flows[:,0] *= -1 # invert delta x
87
+ bboxs = horizontal_flip_box(bboxs) # NxK
88
+ bg_images = bg_images.flip(3) # NxCxHxW
89
+ if dino_features.dim() > 1:
90
+ dino_features = dino_features.flip(3)
91
+ if dino_clusters.dim() > 1:
92
+ dino_clusters = dino_clusters.flip(3)
93
+ return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters
94
+
95
+
96
+ class BaseSequenceDataset(Dataset):
97
+ def __init__(self, root, skip_beginning=4, skip_end=4, min_seq_len=10, debug_seq=False):
98
+ super().__init__()
99
+
100
+ self.skip_beginning = skip_beginning
101
+ self.skip_end = skip_end
102
+ self.min_seq_len = min_seq_len
103
+ # self.pattern = "{:07d}_{}"
104
+ self.sequences = self._make_sequences(root)
105
+
106
+ if debug_seq:
107
+ # self.sequences = [self.sequences[0][20:160]] * 100
108
+ seq_len = 0
109
+ while seq_len < min_seq_len:
110
+ i = np.random.randint(len(self.sequences))
111
+ rand_seq = self.sequences[i]
112
+ seq_len = len(rand_seq)
113
+ self.sequences = [rand_seq]
114
+
115
+ self.samples = []
116
+
117
+ def _make_sequences(self, path):
118
+ result = []
119
+ for d in sorted(os.scandir(path), key=lambda e: e.name):
120
+ if d.is_dir():
121
+ files = self._parse_folder(d)
122
+ if len(files) >= self.min_seq_len:
123
+ result.append(files)
124
+ return result
125
+
126
+ def _parse_folder(self, path):
127
+ result = sorted(glob(os.path.join(path, '*'+self.image_loaders[0][0])))
128
+ result = [p.replace(self.image_loaders[0][0], '{}') for p in result]
129
+
130
+ if len(result) <= self.skip_beginning + self.skip_end:
131
+ return []
132
+ if self.skip_end == 0:
133
+ return result[self.skip_beginning:]
134
+ return result[self.skip_beginning:-self.skip_end]
135
+
136
+ def _load_ids(self, path_patterns, loaders, transform=None):
137
+ result = []
138
+ for loader in loaders:
139
+ for p in path_patterns:
140
+ x = loader[1](p.format(loader[0]), *loader[2:])
141
+ if transform:
142
+ x = transform(x)
143
+ result.append(x)
144
+ return tuple(result)
145
+
146
+ def __len__(self):
147
+ return len(self.samples)
148
+
149
+ def __getitem__(self, index):
150
+ raise NotImplemented("This is a base class and should not be used directly")
151
+
152
+
153
+ class NFrameSequenceDataset(BaseSequenceDataset):
154
+ def __init__(self, root, cat_name=None, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, in_image_size=256, out_image_size=256, debug_seq=False, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64, **kwargs):
155
+ self.cat_name = cat_name
156
+ self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)]
157
+ self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)]
158
+ self.bbox_loaders = [("box.txt", box_loader)]
159
+ super().__init__(root, skip_beginning, skip_end, min_seq_len, debug_seq)
160
+ if num_sample_frames > 1:
161
+ self.flow_loaders = [("flow.png", cv2.imread, cv2.IMREAD_UNCHANGED)]
162
+ else:
163
+ self.flow_loaders = None
164
+
165
+ self.num_sample_frames = num_sample_frames
166
+ self.random_sample = random_sample
167
+ if self.random_sample:
168
+ if shuffle:
169
+ random.shuffle(self.sequences)
170
+ self.samples = self.sequences
171
+ else:
172
+ for i, s in enumerate(self.sequences):
173
+ stride = 1 if dense_sample else self.num_sample_frames
174
+ self.samples += [(i, k) for k in range(0, len(s), stride)]
175
+ if shuffle:
176
+ random.shuffle(self.samples)
177
+
178
+ self.in_image_size = in_image_size
179
+ self.out_image_size = out_image_size
180
+ self.load_background = load_background
181
+ self.color_jitter = color_jitter
182
+ self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()])
183
+ self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
184
+ if self.flow_loaders is not None:
185
+ self.flow_transform = lambda x: (torch.FloatTensor(x.astype(np.float32)).flip(2)[:,:,:2] / 65535. ) *2 -1
186
+ self.random_flip = random_flip
187
+ self.load_dino_feature = load_dino_feature
188
+ if load_dino_feature:
189
+ self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)]
190
+ self.load_dino_cluster = load_dino_cluster
191
+ if load_dino_cluster:
192
+ self.dino_cluster_loaders = [("clusters.png", torchvision.datasets.folder.default_loader)]
193
+
194
+ def __getitem__(self, index):
195
+ if self.random_sample:
196
+ seq_idx = index % len(self.sequences)
197
+ seq = self.sequences[seq_idx]
198
+ if len(seq) < self.num_sample_frames:
199
+ start_frame_idx = 0
200
+ else:
201
+ start_frame_idx = np.random.randint(len(seq)-self.num_sample_frames+1)
202
+ paths = seq[start_frame_idx:start_frame_idx+self.num_sample_frames]
203
+ else:
204
+ seq_idx, start_frame_idx = self.samples[index % len(self.samples)]
205
+ seq = self.sequences[seq_idx]
206
+ # Handle edge case: when only last frame is left, sample last two frames, except if the sequence only has one frame
207
+ if len(seq) <= start_frame_idx +1:
208
+ start_frame_idx = max(0, start_frame_idx-1)
209
+ paths = seq[start_frame_idx:start_frame_idx+self.num_sample_frames]
210
+
211
+ masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images
212
+ mask_dt = compute_distance_transform(masks)
213
+ jitter = False
214
+ if self.color_jitter is not None:
215
+ prob, b, h = self.color_jitter
216
+ if np.random.rand() < prob:
217
+ jitter = True
218
+ color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
219
+ image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()])
220
+ color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
221
+ image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()])
222
+ if jitter:
223
+ images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images
224
+ images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images
225
+ images = images_fg * masks + images_bg * (1-masks)
226
+ else:
227
+ images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images
228
+ if len(paths) > 1:
229
+ flows = torch.stack(self._load_ids(paths[:-1], self.flow_loaders, transform=self.flow_transform), 0).permute(0,3,1,2) # load flow for first image, (N-1)x(x,y)xHxW, -1~1
230
+ flows = torch.nn.functional.interpolate(flows, size=self.out_image_size, mode="bilinear")
231
+ else:
232
+ flows = torch.zeros(1)
233
+ bboxs = torch.stack(self._load_ids(paths, self.bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images
234
+ mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image
235
+ if self.load_background:
236
+ bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg'))
237
+ if jitter:
238
+ bg_image = color_jitter_tsf_bg(bg_image)
239
+ bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size))
240
+ else:
241
+ bg_images = torch.zeros_like(images)
242
+ if self.load_dino_feature:
243
+ dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) # BxFx64x224x224
244
+ else:
245
+ dino_features = torch.zeros(1)
246
+ if self.load_dino_cluster:
247
+ dino_clusters = torch.stack(self._load_ids(paths, self.dino_cluster_loaders, transform=transforms.ToTensor()), 0) # BxFx3x55x55
248
+ else:
249
+ dino_clusters = torch.zeros(1)
250
+ seq_idx = torch.LongTensor([seq_idx])
251
+ frame_idx = torch.arange(start_frame_idx, start_frame_idx+len(paths)).long()
252
+
253
+ if self.random_flip and np.random.rand() < 0.5:
254
+ images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters)
255
+
256
+ ## pad shorter sequence
257
+ if len(paths) < self.num_sample_frames:
258
+ num_pad = self.num_sample_frames - len(paths)
259
+ images = torch.cat([images[:1]] *num_pad + [images], 0)
260
+ masks = torch.cat([masks[:1]] *num_pad + [masks], 0)
261
+ mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0)
262
+ mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0)
263
+ if flows.dim() > 1:
264
+ flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0)
265
+ bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0)
266
+ bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0)
267
+ if dino_features.dim() > 1:
268
+ dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0)
269
+ if dino_clusters.dim() > 1:
270
+ dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0)
271
+ frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0)
272
+
273
+ return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name
274
+
275
+
276
+ def get_sequence_loader(data_dir, **kwargs):
277
+ if isinstance(data_dir, dict):
278
+ loaders = []
279
+ for k, v in data_dir.items():
280
+ dataset= NFrameSequenceDataset(v, cat_name=k, **kwargs)
281
+ loader = torch.utils.data.DataLoader(dataset, batch_size=kwargs['batch_size'], shuffle=kwargs['shuffle'], num_workers=kwargs['num_workers'], pin_memory=True)
282
+ loaders += [loader]
283
+ return loaders
284
+ else:
285
+ return [get_sequence_loader_single(data_dir, **kwargs)]
286
+
287
+
288
+ def get_sequence_loader_single(data_dir, mode='all_frame', is_validation=False, batch_size=256, num_workers=4, in_image_size=256, out_image_size=256, debug_seq=False, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, max_seq_len=256, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.jpg', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64):
289
+ if mode == 'n_frame':
290
+ dataset = NFrameSequenceDataset(data_dir, num_sample_frames=num_sample_frames, skip_beginning=skip_beginning, skip_end=skip_end, min_seq_len=min_seq_len, in_image_size=in_image_size, out_image_size=out_image_size, debug_seq=debug_seq, random_sample=random_sample, shuffle=shuffle, dense_sample=dense_sample, color_jitter=color_jitter, load_background=load_background, random_flip=random_flip, rgb_suffix=rgb_suffix, load_dino_feature=load_dino_feature, load_dino_cluster=load_dino_cluster, dino_feature_dim=dino_feature_dim)
291
+ else:
292
+ raise NotImplementedError
293
+ loader = torch.utils.data.DataLoader(
294
+ dataset,
295
+ batch_size=batch_size,
296
+ shuffle=not is_validation,
297
+ num_workers=num_workers,
298
+ pin_memory=True
299
+ )
300
+ return loader
301
+
302
+
303
+ class ImageDataset(Dataset):
304
+ def __init__(self, root, is_validation=False, image_size=256, color_jitter=None):
305
+ super().__init__()
306
+ self.image_loader = ("rgb.jpg", torchvision.datasets.folder.default_loader)
307
+ self.mask_loader = ("mask.png", torchvision.datasets.folder.default_loader)
308
+ self.bbox_loader = ("box.txt", np.loadtxt, 'str')
309
+ self.samples = self._parse_folder(root)
310
+ self.image_size = image_size
311
+ self.color_jitter = color_jitter
312
+ self.image_transform = transforms.Compose([transforms.Resize(self.image_size), transforms.ToTensor()])
313
+ self.mask_transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
314
+
315
+ def _parse_folder(self, path):
316
+ result = sorted(glob(os.path.join(path, '**/*'+self.image_loader[0]), recursive=True))
317
+ result = [p.replace(self.image_loader[0], '{}') for p in result]
318
+ return result
319
+
320
+ def _load_ids(self, path, loader, transform=None):
321
+ x = loader[1](path.format(loader[0]), *loader[2:])
322
+ if transform:
323
+ x = transform(x)
324
+ return x
325
+
326
+ def __len__(self):
327
+ return len(self.samples)
328
+
329
+ def __getitem__(self, index):
330
+ path = self.samples[index % len(self.samples)]
331
+ masks = self._load_ids(path, self.mask_loader, transform=self.mask_transform).unsqueeze(0)
332
+ mask_dt = compute_distance_transform(masks)
333
+ jitter = False
334
+ if self.color_jitter is not None:
335
+ prob, b, h = self.color_jitter
336
+ if np.random.rand() < prob:
337
+ jitter = True
338
+ color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
339
+ image_transform_fg = transforms.Compose([transforms.Resize(self.image_size), color_jitter_tsf_fg, transforms.ToTensor()])
340
+ color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
341
+ image_transform_bg = transforms.Compose([transforms.Resize(self.image_size), color_jitter_tsf_bg, transforms.ToTensor()])
342
+ if jitter:
343
+ images_fg = self._load_ids(path, self.image_loader, transform=image_transform_fg).unsqueeze(0)
344
+ images_bg = self._load_ids(path, self.image_loader, transform=image_transform_bg).unsqueeze(0)
345
+ images = images_fg * masks + images_bg * (1-masks)
346
+ else:
347
+ images = self._load_ids(path, self.image_loader, transform=self.image_transform).unsqueeze(0)
348
+ flows = torch.zeros(1)
349
+ bboxs = self._load_ids(path, self.bbox_loader, transform=None)
350
+ bboxs[0] = '0'
351
+ bboxs = torch.FloatTensor(bboxs.astype('float')).unsqueeze(0)
352
+ bg_fpath = os.path.join(os.path.dirname(path), 'background_frame.jpg')
353
+ if os.path.isfile(bg_fpath):
354
+ bg_image = torchvision.datasets.folder.default_loader(bg_fpath)
355
+ if jitter:
356
+ bg_image = color_jitter_tsf_bg(bg_image)
357
+ bg_image = transforms.ToTensor()(bg_image)
358
+ else:
359
+ bg_image = images[0]
360
+ seq_idx = torch.LongTensor([index])
361
+ frame_idx = torch.LongTensor([0])
362
+ return images, masks, mask_dt, flows, bboxs, bg_image, seq_idx, frame_idx
363
+
364
+
365
+ def get_image_loader(data_dir, is_validation=False, batch_size=256, num_workers=4, image_size=256, color_jitter=None):
366
+ dataset = ImageDataset(data_dir, is_validation=is_validation, image_size=image_size, color_jitter=color_jitter)
367
+
368
+ loader = torch.utils.data.DataLoader(
369
+ dataset,
370
+ batch_size=batch_size,
371
+ shuffle=False,
372
+ num_workers=num_workers,
373
+ pin_memory=True
374
+ )
375
+ return loader
video3d/dataloaders_ddp.py ADDED
@@ -0,0 +1,1210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from glob import glob
3
+ import random
4
+ import numpy as np
5
+ from PIL import Image
6
+ import cv2
7
+ import itertools
8
+ import torch
9
+ import copy
10
+ from torch.utils.data import Dataset
11
+ import torchvision.datasets.folder
12
+ import torchvision.transforms as transforms
13
+ from einops import rearrange
14
+
15
+
16
+ def compute_distance_transform(mask):
17
+ mask_dt = []
18
+ for m in mask:
19
+ dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
20
+ inv_dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(1 - m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
21
+ mask_dt += [torch.stack([dt, inv_dt], 0)]
22
+ return torch.stack(mask_dt, 0) # Bx2xHxW
23
+
24
+
25
+ def crop_image(image, boxs, size):
26
+ crops = []
27
+ for box in boxs:
28
+ crop_x0, crop_y0, crop_w, crop_h = box
29
+ crop = transforms.functional.resized_crop(image, crop_y0, crop_x0, crop_h, crop_w, size)
30
+ crop = transforms.functional.to_tensor(crop)
31
+ crops += [crop]
32
+ return torch.stack(crops, 0)
33
+
34
+
35
+ def box_loader(fpath):
36
+ box = np.loadtxt(fpath, 'str')
37
+ box[0] = box[0].split('_')[0]
38
+ return box.astype(np.float32)
39
+
40
+
41
+ def read_feat_from_img(path, n_channels):
42
+ feat = np.array(Image.open(path))
43
+ return dencode_feat_from_img(feat, n_channels)
44
+
45
+
46
+ def dencode_feat_from_img(img, n_channels):
47
+ n_addon_channels = int(np.ceil(n_channels / 3) * 3) - n_channels
48
+ n_tiles = int((n_channels + n_addon_channels) / 3)
49
+ feat = rearrange(img, 'h (t w) c -> h w (t c)', t=n_tiles, c=3)
50
+ if n_addon_channels != 0:
51
+ feat = feat[:, :, :-n_addon_channels]
52
+ feat = feat.astype('float32') / 255
53
+ return feat.transpose(2, 0, 1)
54
+
55
+
56
+ def dino_loader(fpath, n_channels):
57
+ dino_map = read_feat_from_img(fpath, n_channels)
58
+ return dino_map
59
+
60
+
61
+ def get_valid_mask(boxs, image_size):
62
+ valid_masks = []
63
+ for box in boxs:
64
+ crop_x0, crop_y0, crop_w, crop_h, full_w, full_h = box[1:7].int().numpy()
65
+ margin_w = int(crop_w * 0.02)
66
+ margin_h = int(crop_h * 0.02)
67
+ mask_full = torch.ones(full_h-margin_h*2, full_w-margin_w*2)
68
+ mask_full_pad = torch.nn.functional.pad(mask_full, (crop_w+margin_w, crop_w+margin_w, crop_h+margin_h, crop_h+margin_h), mode='constant', value=0.0)
69
+ mask_full_crop = mask_full_pad[(crop_y0+crop_h):crop_y0+(crop_h*2), (crop_x0+crop_w):crop_x0+(crop_w*2)]
70
+ mask_crop = torch.nn.functional.interpolate(mask_full_crop[None, None, :, :], image_size, mode='nearest')[0,0]
71
+ valid_masks += [mask_crop]
72
+ return torch.stack(valid_masks, 0) # NxHxW
73
+
74
+
75
+ def horizontal_flip_box(box):
76
+ frame_id, crop_x0, crop_y0, crop_w, crop_h, full_w, full_h, sharpness, label = box.unbind(1)
77
+ box[:,1] = full_w - crop_x0 - crop_w # x0
78
+ return box
79
+
80
+
81
+ def horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features=None, dino_clusters=None):
82
+ images = images.flip(3) # NxCxHxW
83
+ masks = masks.flip(3) # NxCxHxW
84
+ mask_dt = mask_dt.flip(3) # NxCxHxW
85
+ mask_valid = mask_valid.flip(2) # NxHxW
86
+ if flows.dim() > 1:
87
+ flows = flows.flip(3) # (N-1)x(x,y)xHxW
88
+ flows[:,0] *= -1 # invert delta x
89
+ bboxs = horizontal_flip_box(bboxs) # NxK
90
+ bg_images = bg_images.flip(3) # NxCxHxW
91
+ if dino_features.dim() > 1:
92
+ dino_features = dino_features.flip(3)
93
+ if dino_clusters.dim() > 1:
94
+ dino_clusters = dino_clusters.flip(3)
95
+ return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters
96
+
97
+
98
+ def none_to_nan(x):
99
+ return torch.FloatTensor([float('nan')]) if x is None else x
100
+
101
+
102
+ class BaseSequenceDataset(Dataset):
103
+ def __init__(self, root, skip_beginning=4, skip_end=4, min_seq_len=10, debug_seq=False):
104
+ super().__init__()
105
+
106
+ self.skip_beginning = skip_beginning
107
+ self.skip_end = skip_end
108
+ self.min_seq_len = min_seq_len
109
+ # self.pattern = "{:07d}_{}"
110
+ self.sequences = self._make_sequences(root)
111
+
112
+ if debug_seq:
113
+ # self.sequences = [self.sequences[0][20:160]] * 100
114
+ seq_len = 0
115
+ while seq_len < min_seq_len:
116
+ i = np.random.randint(len(self.sequences))
117
+ rand_seq = self.sequences[i]
118
+ seq_len = len(rand_seq)
119
+ self.sequences = [rand_seq]
120
+
121
+ self.samples = []
122
+
123
+ def _make_sequences(self, path):
124
+ result = []
125
+ for d in sorted(os.scandir(path), key=lambda e: e.name):
126
+ if d.is_dir():
127
+ files = self._parse_folder(d)
128
+ if len(files) >= self.min_seq_len:
129
+ result.append(files)
130
+ return result
131
+
132
+ def _parse_folder(self, path):
133
+ result = sorted(glob(os.path.join(path, '*'+self.image_loaders[0][0])))
134
+ result = [p.replace(self.image_loaders[0][0], '{}') for p in result]
135
+
136
+ if len(result) <= self.skip_beginning + self.skip_end:
137
+ return []
138
+ if self.skip_end == 0:
139
+ return result[self.skip_beginning:]
140
+ return result[self.skip_beginning:-self.skip_end]
141
+
142
+ def _load_ids(self, path_patterns, loaders, transform=None):
143
+ result = []
144
+ for loader in loaders:
145
+ for p in path_patterns:
146
+ x = loader[1](p.format(loader[0]), *loader[2:])
147
+ if transform:
148
+ x = transform(x)
149
+ result.append(x)
150
+ return tuple(result)
151
+
152
+ def __len__(self):
153
+ return len(self.samples)
154
+
155
+ def __getitem__(self, index):
156
+ raise NotImplemented("This is a base class and should not be used directly")
157
+
158
+
159
+ class NFrameSequenceDataset(BaseSequenceDataset):
160
+ def __init__(self, root, cat_name=None, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, in_image_size=256, out_image_size=256, debug_seq=False, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64, flow_bool=False, **kwargs):
161
+ self.cat_name = cat_name
162
+ self.flow_bool=flow_bool
163
+
164
+ self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)]
165
+ self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)]
166
+ self.bbox_loaders = [("box.txt", box_loader)]
167
+ super().__init__(root, skip_beginning, skip_end, min_seq_len, debug_seq)
168
+ # from IPython import embed; embed()
169
+ if flow_bool and num_sample_frames > 1:
170
+ self.flow_loaders = [("flow.png", cv2.imread, cv2.IMREAD_UNCHANGED)]
171
+ else:
172
+ self.flow_loaders = None
173
+
174
+ self.num_sample_frames = num_sample_frames
175
+ self.random_sample = random_sample
176
+ if self.random_sample:
177
+ if shuffle:
178
+ random.shuffle(self.sequences)
179
+ self.samples = self.sequences
180
+ else:
181
+
182
+ for i, s in enumerate(self.sequences):
183
+ stride = 1 if dense_sample else self.num_sample_frames
184
+ self.samples += [(i, k) for k in range(0, len(s), stride)]
185
+ if shuffle:
186
+ random.shuffle(self.samples)
187
+
188
+ self.in_image_size = in_image_size
189
+ self.out_image_size = out_image_size
190
+ self.load_background = load_background
191
+ self.color_jitter = color_jitter
192
+ self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()])
193
+ self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
194
+ if self.flow_loaders is not None:
195
+ self.flow_transform = lambda x: (torch.FloatTensor(x.astype(np.float32)).flip(2)[:,:,:2] / 65535. ) *2 -1
196
+ self.random_flip = random_flip
197
+ self.load_dino_feature = load_dino_feature
198
+ if load_dino_feature:
199
+ self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)]
200
+ self.load_dino_cluster = load_dino_cluster
201
+ if load_dino_cluster:
202
+ self.dino_cluster_loaders = [("clusters.png", torchvision.datasets.folder.default_loader)]
203
+
204
+ def __getitem__(self, index):
205
+ if self.random_sample:
206
+ seq_idx = index % len(self.sequences)
207
+ seq = self.sequences[seq_idx]
208
+ if len(seq) < self.num_sample_frames:
209
+ start_frame_idx = 0
210
+ else:
211
+ start_frame_idx = np.random.randint(len(seq)-self.num_sample_frames+1)
212
+ paths = seq[start_frame_idx:start_frame_idx+self.num_sample_frames]
213
+ else:
214
+ seq_idx, start_frame_idx = self.samples[index % len(self.samples)]
215
+ seq = self.sequences[seq_idx]
216
+ # Handle edge case: when only last frame is left, sample last two frames, except if the sequence only has one frame
217
+ if len(seq) <= start_frame_idx +1:
218
+ start_frame_idx = max(0, start_frame_idx-1)
219
+ paths = seq[start_frame_idx:start_frame_idx+self.num_sample_frames]
220
+
221
+ masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images
222
+ mask_dt = compute_distance_transform(masks)
223
+ jitter = False
224
+ if self.color_jitter is not None:
225
+ prob, b, h = self.color_jitter
226
+ if np.random.rand() < prob:
227
+ jitter = True
228
+ color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
229
+ image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()])
230
+ color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
231
+ image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()])
232
+ if jitter:
233
+ images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images
234
+ images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images
235
+ images = images_fg * masks + images_bg * (1-masks)
236
+ else:
237
+ images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images
238
+ if self.flow_bool==True and len(paths) > 1:
239
+ flows = torch.stack(self._load_ids(paths[:-1], self.flow_loaders, transform=self.flow_transform), 0).permute(0,3,1,2) # load flow for first image, (N-1)x(x,y)xHxW, -1~1
240
+ flows = torch.nn.functional.interpolate(flows, size=self.out_image_size, mode="bilinear")
241
+ else:
242
+ flows = torch.zeros(1)
243
+ bboxs = torch.stack(self._load_ids(paths, self.bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images
244
+ mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image
245
+ if self.load_background:
246
+ bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg'))
247
+ if jitter:
248
+ bg_image = color_jitter_tsf_bg(bg_image)
249
+ bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size))
250
+ else:
251
+ bg_images = torch.zeros_like(images)
252
+ if self.load_dino_feature:
253
+ dino_paths = [
254
+ x.replace(
255
+ "/viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new",
256
+ "/viscam/projects/articulated/zzli/data_dino_5000/7_cat"
257
+ )
258
+ for x in paths
259
+ ]
260
+ dino_features = torch.stack(self._load_ids(dino_paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0)
261
+ # dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) # BxFx64x224x224
262
+ else:
263
+ dino_features = torch.zeros(1)
264
+ if self.load_dino_cluster:
265
+ dino_clusters = torch.stack(self._load_ids(paths, self.dino_cluster_loaders, transform=transforms.ToTensor()), 0) # BxFx3x55x55
266
+ else:
267
+ dino_clusters = torch.zeros(1)
268
+ seq_idx = torch.LongTensor([seq_idx])
269
+ frame_idx = torch.arange(start_frame_idx, start_frame_idx+len(paths)).long()
270
+
271
+ if self.random_flip and np.random.rand() < 0.5:
272
+ images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters)
273
+
274
+ ## pad shorter sequence
275
+ if len(paths) < self.num_sample_frames:
276
+ num_pad = self.num_sample_frames - len(paths)
277
+ images = torch.cat([images[:1]] *num_pad + [images], 0)
278
+ masks = torch.cat([masks[:1]] *num_pad + [masks], 0)
279
+ mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0)
280
+ mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0)
281
+ if flows.dim() > 1:
282
+ flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0)
283
+ bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0)
284
+ bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0)
285
+ if dino_features.dim() > 1:
286
+ dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0)
287
+ if dino_clusters.dim() > 1:
288
+ dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0)
289
+ frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0)
290
+
291
+ out = (*map(none_to_nan, (images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name)), )
292
+ return out
293
+ # return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name
294
+
295
+
296
+ def few_shot_box_loader(fpath):
297
+ box = np.loadtxt(fpath, 'str')
298
+ # box[0] = box[0].split('_')[0]
299
+ return box.astype(np.float32)
300
+
301
+
302
+ class FewShotImageDataset(Dataset):
303
+ def __init__(self, root, cat_name=None, cat_num=0, num_sample_frames=2, in_image_size=256, out_image_size=256, shuffle=False, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, dino_feature_dim=64, flow_bool=False, **kwargs):
304
+ super().__init__()
305
+ self.cat_name = cat_name
306
+ self.cat_num = cat_num # this is actually useless
307
+ self.flow_bool=flow_bool
308
+
309
+ self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)]
310
+ self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)]
311
+ self.bbox_loaders = [("box.txt", few_shot_box_loader)]
312
+ self.flow_loaders = None
313
+
314
+ # get all the valid paths, since it's just image-wise, in get_item, we will make it like a len=1 sequence
315
+ result = sorted(glob(os.path.join(root, '*'+self.image_loaders[0][0])))
316
+ result = [p.replace(self.image_loaders[0][0], '{}') for p in result]
317
+ self.sequences = result
318
+
319
+ self.num_sample_frames = num_sample_frames
320
+ if shuffle:
321
+ random.shuffle(self.sequences)
322
+ self.samples = self.sequences
323
+
324
+ self.in_image_size = in_image_size
325
+ self.out_image_size = out_image_size
326
+ self.load_background = load_background
327
+ self.color_jitter = color_jitter
328
+ self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()])
329
+ self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
330
+ self.random_flip = random_flip
331
+ self.load_dino_feature = load_dino_feature
332
+ if load_dino_feature:
333
+ self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)]
334
+
335
+ def _load_ids(self, path_patterns, loaders, transform=None):
336
+ result = []
337
+ for loader in loaders:
338
+ for p in path_patterns:
339
+ x = loader[1](p.format(loader[0]), *loader[2:])
340
+ if transform:
341
+ x = transform(x)
342
+ result.append(x)
343
+ return tuple(result)
344
+
345
+ def __len__(self):
346
+ return len(self.samples)
347
+
348
+ def __getitem__(self, index):
349
+ paths = [self.samples[index]] # len 1 sequence
350
+
351
+ masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images
352
+ mask_dt = compute_distance_transform(masks)
353
+ jitter = False
354
+ if self.color_jitter is not None:
355
+ prob, b, h = self.color_jitter
356
+ if np.random.rand() < prob:
357
+ jitter = True
358
+ color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
359
+ image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()])
360
+ color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
361
+ image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()])
362
+ if jitter:
363
+ images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images
364
+ images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images
365
+ images = images_fg * masks + images_bg * (1-masks)
366
+ else:
367
+ images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images
368
+
369
+ flows = torch.zeros(1)
370
+ bboxs = torch.stack(self._load_ids(paths, self.bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images
371
+ bboxs=torch.cat([bboxs, torch.Tensor([[self.cat_num]]).float()],dim=-1) # pad a label number
372
+
373
+ mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image
374
+ if self.load_background:
375
+ bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg'))
376
+ if jitter:
377
+ bg_image = color_jitter_tsf_bg(bg_image)
378
+ bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size))
379
+ else:
380
+ bg_images = torch.zeros_like(images)
381
+ if self.load_dino_feature:
382
+ dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) # BxFx64x224x224
383
+ else:
384
+ dino_features = torch.zeros(1)
385
+
386
+ dino_clusters = torch.zeros(1)
387
+
388
+ # These are actually no use
389
+ seq_idx = 0
390
+ seq_idx = torch.LongTensor([seq_idx])
391
+ frame_idx = torch.arange(0, 1).long()
392
+
393
+ if self.random_flip and np.random.rand() < 0.5:
394
+ images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters)
395
+
396
+ ## pad shorter sequence
397
+ if len(paths) < self.num_sample_frames:
398
+ num_pad = self.num_sample_frames - len(paths)
399
+ images = torch.cat([images[:1]] *num_pad + [images], 0)
400
+ masks = torch.cat([masks[:1]] *num_pad + [masks], 0)
401
+ mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0)
402
+ mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0)
403
+ if flows.dim() > 1:
404
+ flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0)
405
+ bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0)
406
+ bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0)
407
+ if dino_features.dim() > 1:
408
+ dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0)
409
+ if dino_clusters.dim() > 1:
410
+ dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0)
411
+ frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0)
412
+
413
+ out = (*map(none_to_nan, (images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name)), )
414
+ return out
415
+ # return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name
416
+
417
+
418
+ class Quadrupeds_Image_Dataset(Dataset):
419
+ def __init__(self, original_data_dirs, few_shot_data_dirs, original_num=7, few_shot_num=93, num_sample_frames=2,
420
+ in_image_size=256, out_image_size=256, is_validation=False, val_image_num=5, shuffle=False, color_jitter=None,
421
+ load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, dino_feature_dim=64,
422
+ flow_bool=False, disable_fewshot=False, dataset_split_num=-1, **kwargs):
423
+ self.original_data_dirs = original_data_dirs
424
+ self.few_shot_data_dirs = few_shot_data_dirs
425
+ self.original_num = original_num
426
+ self.few_shot_num = few_shot_num
427
+
428
+ self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)]
429
+ self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)]
430
+ self.original_bbox_loaders = [("box.txt", box_loader)]
431
+ self.few_shot_bbox_loaders = [("box.txt", few_shot_box_loader)]
432
+
433
+ assert len(self.original_data_dirs.keys()) == self.original_num
434
+ assert len(self.few_shot_data_dirs.keys()) == self.few_shot_num
435
+ self.num_sample_frames = num_sample_frames
436
+
437
+ self.batch_size = kwargs['batch_size'] # a hack way here
438
+
439
+ # for debug, just use some categories
440
+ if "override_categories" in kwargs:
441
+ self.override_categories = kwargs["override_categories"]
442
+ else:
443
+ self.override_categories = None
444
+
445
+ # original dataset
446
+ original_data_paths = {}
447
+ for k,v in self.original_data_dirs.items():
448
+
449
+ # categories override
450
+ if self.override_categories is not None:
451
+ if k not in self.override_categories:
452
+ continue
453
+
454
+ sequences = self._make_sequences(v)
455
+ samples = []
456
+ for seq in sequences:
457
+ samples += seq
458
+ if shuffle:
459
+ random.shuffle(samples)
460
+ original_data_paths.update({k: samples})
461
+
462
+ # few-shot dataset
463
+ enhance_back_view = kwargs['enhance_back_view']
464
+ if enhance_back_view:
465
+ enhance_back_view_path = kwargs['enhance_back_view_path']
466
+
467
+ few_shot_data_paths = {}
468
+ for k,v in self.few_shot_data_dirs.items():
469
+
470
+ # categories override
471
+ if self.override_categories is not None:
472
+ if k not in self.override_categories:
473
+ continue
474
+ if k.startswith('_'):
475
+ # a boundary here for dealing with when in new data, we have same categories as in 7-cat
476
+ v = v.replace(k, k[1:])
477
+
478
+ if isinstance(v, str):
479
+ result = sorted(glob(os.path.join(v, '*'+self.image_loaders[0][0])))
480
+ elif isinstance(v, list):
481
+ result = []
482
+ for _v in v:
483
+ result = result + sorted(glob(os.path.join(_v, '*'+self.image_loaders[0][0])))
484
+ else:
485
+ raise NotImplementedError
486
+
487
+ # result = sorted(glob(os.path.join(v, '*'+self.image_loaders[0][0])))
488
+ result = [p.replace(self.image_loaders[0][0], '{}') for p in result]
489
+ sequences = result
490
+
491
+ # the original 7 categories are using pre-defined paths to separate train and test
492
+ # here the few-shot we use is_validation to decide if this dataset is train or test
493
+ # if use enhanced back view, we first pad the multiplied back view image paths at the front of seq
494
+ # i.e., we don't use back view images for validation
495
+ if enhance_back_view:
496
+ back_view_dir = os.path.join(enhance_back_view_path, k, 'train')
497
+ back_view_result = sorted(glob(os.path.join(back_view_dir, '*'+self.image_loaders[0][0])))
498
+ back_view_result = [p.replace(self.image_loaders[0][0], '{}') for p in back_view_result]
499
+ mul_bv_sequences = self._more_back_views(back_view_result, result)
500
+ sequences = mul_bv_sequences + sequences
501
+
502
+ if is_validation:
503
+ # sequences = sequences[-2:]
504
+ sequences = sequences[-val_image_num:]
505
+ else:
506
+ # sequences = sequences[:-2]
507
+ sequences = sequences[:-val_image_num]
508
+
509
+ if shuffle:
510
+ random.shuffle(sequences)
511
+ few_shot_data_paths.update({k: sequences})
512
+
513
+ # for visualization purpose
514
+ self.pure_ori_data_path = original_data_paths
515
+ self.pure_fs_data_path = few_shot_data_paths
516
+
517
+ self.few_shot_data_length = self._get_data_length(few_shot_data_paths) # get the original length of each few-shot category
518
+
519
+ if disable_fewshot:
520
+ few_shot_data_paths = {}
521
+
522
+ self.dataset_split_num = dataset_split_num # if -1 then pad to longest, otherwise follow this number to pad and split
523
+ if is_validation:
524
+ self.dataset_split_num = -1 # validation we don't split dataset
525
+
526
+ if self.dataset_split_num == -1:
527
+ self.all_data_paths, self.one_category_num = self._pad_paths(original_data_paths, few_shot_data_paths)
528
+ self.all_category_num = len(self.all_data_paths.keys())
529
+ self.all_category_names = list(self.all_data_paths.keys())
530
+ self.original_category_names = list(self.original_data_dirs.keys())
531
+ elif self.dataset_split_num > 0:
532
+ self.all_data_paths, self.one_category_num, self.original_category_names = self._pad_paths_withnum(original_data_paths, few_shot_data_paths, self.dataset_split_num)
533
+ self.all_category_num = len(self.all_data_paths.keys())
534
+ self.all_category_names = list(self.all_data_paths.keys())
535
+ else:
536
+ raise NotImplementedError
537
+
538
+ self.in_image_size = in_image_size
539
+ self.out_image_size = out_image_size
540
+ self.load_background = load_background
541
+ self.color_jitter = color_jitter
542
+ self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()])
543
+ self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
544
+ self.random_flip = random_flip
545
+ self.load_dino_feature = load_dino_feature
546
+ if load_dino_feature:
547
+ self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)]
548
+
549
+ def _more_back_views(self, back_view_seq, seq):
550
+ if len(back_view_seq) == 0:
551
+ # for category without back views
552
+ return []
553
+ factor = 5
554
+ # length = (len(seq) // factor) * factor
555
+ length = (len(seq) // factor) * (factor - 1)
556
+ mul_f = length // len(back_view_seq)
557
+ pad_f = length % len(back_view_seq)
558
+ new_seq = mul_f * back_view_seq + back_view_seq[:pad_f]
559
+ return new_seq
560
+
561
+ def _get_data_length(self, paths):
562
+ data_length = {}
563
+ for k,v in paths.items():
564
+ length = len(v)
565
+ data_length.update({k: length})
566
+ return data_length
567
+
568
+ def _make_sequences(self, path):
569
+ result = []
570
+ for d in sorted(os.scandir(path), key=lambda e: e.name):
571
+ if d.is_dir():
572
+ files = self._parse_folder(d)
573
+ if len(files) >= 1:
574
+ result.append(files)
575
+ return result
576
+
577
+ def _parse_folder(self, path):
578
+ result = sorted(glob(os.path.join(path, '*'+self.image_loaders[0][0])))
579
+ result = [p.replace(self.image_loaders[0][0], '{}') for p in result]
580
+
581
+ if len(result) <= 0:
582
+ return []
583
+ return result
584
+
585
+ def _pad_paths(self, ori_paths, fs_paths):
586
+ img_nums = []
587
+ all_paths = copy.deepcopy(ori_paths)
588
+ all_paths.update(fs_paths)
589
+ for _, v in all_paths.items():
590
+ img_nums.append(len(v))
591
+
592
+ img_num = max(img_nums)
593
+ img_num = (img_num // self.batch_size) * self.batch_size
594
+
595
+ for k,v in all_paths.items():
596
+ if len(v) < img_num:
597
+ mul_time = img_num // len(v)
598
+ pad_time = img_num % len(v)
599
+ # for each v, shuffle it
600
+ shuffle_v = copy.deepcopy(v)
601
+ new_v = []
602
+ for i in range(mul_time):
603
+ new_v = new_v + shuffle_v
604
+ random.shuffle(shuffle_v)
605
+ del shuffle_v
606
+ new_v = new_v + v[0:pad_time]
607
+ # new_v = mul_time * v + v[0:pad_time]
608
+ all_paths[k] = new_v
609
+ elif len(v) > img_num:
610
+ all_paths[k] = v[:img_num]
611
+ else:
612
+ continue
613
+
614
+ return all_paths, img_num
615
+
616
+ def _pad_paths_withnum(self, ori_paths, fs_paths, split_num=1000):
617
+ img_num = (split_num // self.batch_size) * self.batch_size
618
+ all_paths = {}
619
+ orig_cat_names = []
620
+
621
+ for k, v in ori_paths.items():
622
+ total_num = ((len(v) // img_num) + 1) * img_num
623
+ pad_num = total_num - len(v)
624
+ split_num = total_num // img_num
625
+
626
+ new_v = copy.deepcopy(v)
627
+ random.shuffle(new_v)
628
+ all_v = v + new_v[:pad_num]
629
+ del new_v
630
+
631
+ for sn in range(split_num):
632
+ split_cat_name = f'{k}_' + '%03d' % sn
633
+ all_paths.update({
634
+ split_cat_name: all_v[sn*img_num: (sn+1)*img_num]
635
+ })
636
+ orig_cat_names.append(split_cat_name)
637
+
638
+ for k, v in fs_paths.items():
639
+ if len(v) < img_num:
640
+ mul_time = img_num // len(v)
641
+ pad_time = img_num % len(v)
642
+ # for each v, shuffle it
643
+ shuffle_v = copy.deepcopy(v)
644
+ new_v = []
645
+ for i in range(mul_time):
646
+ new_v = new_v + shuffle_v
647
+ random.shuffle(shuffle_v)
648
+ del shuffle_v
649
+ new_v = new_v + v[0:pad_time]
650
+ # new_v = mul_time * v + v[0:pad_time]
651
+ all_paths.update({
652
+ k: new_v
653
+ })
654
+ elif len(v) > img_num:
655
+ all_paths.update({
656
+ k: v[:img_num]
657
+ })
658
+ else:
659
+ continue
660
+
661
+ return all_paths, img_num, orig_cat_names
662
+
663
+
664
+ def _load_ids(self, path_patterns, loaders, transform=None):
665
+ result = []
666
+ for loader in loaders:
667
+ for p in path_patterns:
668
+ x = loader[1](p.format(loader[0]), *loader[2:])
669
+ if transform:
670
+ x = transform(x)
671
+ result.append(x)
672
+ return tuple(result)
673
+
674
+ def _shuffle_all(self):
675
+ for k,v in self.all_data_paths.items():
676
+ new_v = copy.deepcopy(v)
677
+ random.shuffle(new_v)
678
+ self.all_data_paths[k] = new_v
679
+ return None
680
+
681
+ def __len__(self):
682
+ return self.all_category_num * self.one_category_num
683
+
684
+ def __getitem__(self, index):
685
+ '''
686
+ This dataset must have non-shuffled index!!
687
+ '''
688
+ category_idx = (index % (self.batch_size * self.all_category_num)) // self.batch_size
689
+ path_idx = (index // (self.batch_size * self.all_category_num)) * self.batch_size + (index % (self.batch_size * self.all_category_num)) - category_idx * self.batch_size
690
+ category_name = self.all_category_names[category_idx]
691
+ paths = [self.all_data_paths[category_name][path_idx]] # len 1 sequence
692
+
693
+ if category_name in self.original_category_names:
694
+ bbox_loaders = self.original_bbox_loaders
695
+ use_original_bbox = True
696
+ else:
697
+ bbox_loaders = self.few_shot_bbox_loaders
698
+ use_original_bbox = False
699
+
700
+ masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images
701
+ mask_dt = compute_distance_transform(masks)
702
+ jitter = False
703
+ if self.color_jitter is not None:
704
+ prob, b, h = self.color_jitter
705
+ if np.random.rand() < prob:
706
+ jitter = True
707
+ color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
708
+ image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()])
709
+ color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
710
+ image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()])
711
+ if jitter:
712
+ images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images
713
+ images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images
714
+ images = images_fg * masks + images_bg * (1-masks)
715
+ else:
716
+ images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images
717
+
718
+ flows = torch.zeros(1)
719
+ bboxs = torch.stack(self._load_ids(paths, bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images
720
+ if not use_original_bbox:
721
+ bboxs=torch.cat([bboxs, torch.Tensor([[category_idx]]).float()],dim=-1) # pad a label number
722
+
723
+ mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image
724
+ if self.load_background:
725
+ bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg'))
726
+ if jitter:
727
+ bg_image = color_jitter_tsf_bg(bg_image)
728
+ bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size))
729
+ else:
730
+ bg_images = torch.zeros_like(images)
731
+ if self.load_dino_feature:
732
+ # print(paths)
733
+ new_dino_data_name = "data_dino_5000"
734
+ new_dino_data_path = os.path.join("/viscam/projects/articulated/dor/combine_all_data_for_ablation_magicpony", new_dino_data_name)
735
+
736
+ # TODO: use another version of DINO here by changing the path
737
+ if paths[0].startswith("/viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new"):
738
+ # 7 cat data
739
+ new_dino_path = paths[0].replace(
740
+ "/viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new",
741
+ "/viscam/projects/articulated/zzli/data_dino_5000/7_cat"
742
+ )
743
+ dino_paths = [new_dino_path]
744
+ elif paths[0].startswith("/viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/few_shot_data_all"):
745
+ # 100 cat
746
+ dino_path = paths[0].replace(
747
+ "/viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/few_shot_data_all",
748
+ os.path.join(new_dino_data_path, "100_cat")
749
+ )
750
+ dino_path_list = dino_path.split("/")
751
+ new_dino_path = dino_path_list[:-2] + dino_path_list[-1:] # remove "/train/"
752
+ new_dino_path = '/'.join(new_dino_path)
753
+ dino_paths = [new_dino_path]
754
+
755
+ elif paths[0].startswith("/viscam/projects/articulated/zzli/fs_data/data_resize_update/few_shot_data_all"):
756
+ # 100 cat
757
+ dino_path = paths[0].replace(
758
+ "/viscam/projects/articulated/zzli/fs_data/data_resize_update/few_shot_data_all",
759
+ os.path.join(new_dino_data_path, "100_cat")
760
+ )
761
+ dino_path_list = dino_path.split("/")
762
+ new_dino_path = dino_path_list[:-2] + dino_path_list[-1:] # remove "/train/"
763
+ new_dino_path = '/'.join(new_dino_path)
764
+ dino_paths = [new_dino_path]
765
+
766
+ elif paths[0].startswith("/viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/segmented_back_view_data"):
767
+ # back 100 cat
768
+ dino_path = paths[0].replace(
769
+ "/viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/segmented_back_view_data",
770
+ os.path.join(new_dino_data_path, "back_100_cat")
771
+ )
772
+ dino_path_list = dino_path.split("/")
773
+ new_dino_path = dino_path_list[:-2] + dino_path_list[-1:] # remove "/train/"
774
+ new_dino_path = '/'.join(new_dino_path)
775
+ dino_paths = [new_dino_path]
776
+
777
+ elif paths[0].startswith("/viscam/projects/articulated/dor/Animal-Data-Engine/data/data_resize_update/train_with_classes_filtered"):
778
+ # animal3d
779
+ dino_path = paths[0].replace(
780
+ "/viscam/projects/articulated/dor/Animal-Data-Engine/data/data_resize_update/train_with_classes_filtered",
781
+ os.path.join(new_dino_data_path, "animal3D")
782
+ )
783
+ dino_path_list = dino_path.split("/")
784
+ new_dino_path = dino_path_list[:-2] + dino_path_list[-1:] # remove "/train/"
785
+ new_dino_path = '/'.join(new_dino_path)
786
+ dino_paths = [new_dino_path]
787
+ else:
788
+ raise NotImplementedError
789
+ dino_features = torch.stack(self._load_ids(dino_paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0)
790
+ # dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) # BxFx64x224x224
791
+ else:
792
+ dino_features = torch.zeros(1)
793
+
794
+ dino_clusters = torch.zeros(1)
795
+
796
+ # These are actually no use
797
+ seq_idx = 0
798
+ seq_idx = torch.LongTensor([seq_idx])
799
+ frame_idx = torch.arange(0, 1).long()
800
+
801
+ if self.random_flip and np.random.rand() < 0.5:
802
+ images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters)
803
+
804
+ ## pad shorter sequence
805
+ if len(paths) < self.num_sample_frames:
806
+ num_pad = self.num_sample_frames - len(paths)
807
+ images = torch.cat([images[:1]] *num_pad + [images], 0)
808
+ masks = torch.cat([masks[:1]] *num_pad + [masks], 0)
809
+ mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0)
810
+ mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0)
811
+ if flows.dim() > 1:
812
+ flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0)
813
+ bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0)
814
+ bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0)
815
+ if dino_features.dim() > 1:
816
+ dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0)
817
+ if dino_clusters.dim() > 1:
818
+ dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0)
819
+ frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0)
820
+
821
+ out = (*map(none_to_nan, (images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, category_name)), )
822
+ return out
823
+ # return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, category_name
824
+
825
+ def get_sequence_loader_quadrupeds(original_data_dirs, few_shot_data_dirs, original_num, few_shot_num, rank, world_size, **kwargs):
826
+ dataset = Quadrupeds_Image_Dataset(original_data_dirs, few_shot_data_dirs, original_num, few_shot_num, **kwargs)
827
+ sampler = torch.utils.data.distributed.DistributedSampler(
828
+ dataset,
829
+ num_replicas=world_size,
830
+ rank=rank,
831
+ shuffle=False
832
+ )
833
+ loaders = []
834
+ loaders += [torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=kwargs['batch_size'], shuffle=False, drop_last=True, num_workers=kwargs['num_workers'], pin_memory=True)]
835
+
836
+ return loaders
837
+
838
+
839
+ class Quadrupeds_Image_Test_Dataset(Dataset):
840
+ def __init__(self, test_data_dirs, num_sample_frames=2, in_image_size=256, out_image_size=256, shuffle=False, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, dino_feature_dim=64, flow_bool=False, **kwargs):
841
+ self.few_shot_data_dirs = test_data_dirs
842
+
843
+ self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)]
844
+ self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)]
845
+ self.original_bbox_loaders = [("box.txt", box_loader)]
846
+ self.few_shot_bbox_loaders = [("box.txt", few_shot_box_loader)]
847
+
848
+ self.num_sample_frames = num_sample_frames
849
+
850
+ self.batch_size = kwargs['batch_size'] # a hack way here
851
+
852
+ few_shot_data_paths = {}
853
+ for k,v in self.few_shot_data_dirs.items():
854
+
855
+ if k.startswith('_'):
856
+ # a boundary here for dealing with when in new data, we have same categories as in 7-cat
857
+ v = v.replace(k, k[1:])
858
+
859
+ if isinstance(v, str):
860
+ result = sorted(glob(os.path.join(v, '*'+self.image_loaders[0][0])))
861
+ elif isinstance(v, list):
862
+ result = []
863
+ for _v in v:
864
+ result = result + sorted(glob(os.path.join(_v, '*'+self.image_loaders[0][0])))
865
+ else:
866
+ raise NotImplementedError
867
+
868
+ # result = sorted(glob(os.path.join(v, '*'+self.image_loaders[0][0])))
869
+ result = [p.replace(self.image_loaders[0][0], '{}') for p in result]
870
+ sequences = result
871
+
872
+ if shuffle:
873
+ random.shuffle(sequences)
874
+ few_shot_data_paths.update({k: sequences})
875
+
876
+ # for visualization purpose
877
+ self.pure_fs_data_path = few_shot_data_paths
878
+
879
+ self.all_data_paths, self.one_category_num = self._pad_paths(few_shot_data_paths)
880
+ self.all_category_num = len(self.all_data_paths.keys())
881
+ self.all_category_names = list(self.all_data_paths.keys())
882
+
883
+ self.in_image_size = in_image_size
884
+ self.out_image_size = out_image_size
885
+ self.load_background = load_background
886
+ self.color_jitter = color_jitter
887
+ self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()])
888
+ self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
889
+ self.random_flip = random_flip
890
+ self.load_dino_feature = load_dino_feature
891
+ if load_dino_feature:
892
+ self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)]
893
+
894
+ def _pad_paths(self, fs_paths):
895
+ img_nums = []
896
+ all_paths = copy.deepcopy(fs_paths)
897
+ for _, v in all_paths.items():
898
+ img_nums.append(len(v))
899
+
900
+ img_num = max(img_nums)
901
+ img_num = (img_num // self.batch_size) * self.batch_size
902
+
903
+ for k,v in all_paths.items():
904
+ if len(v) < img_num:
905
+ mul_time = img_num // len(v)
906
+ pad_time = img_num % len(v)
907
+ # for each v, shuffle it
908
+ shuffle_v = copy.deepcopy(v)
909
+ new_v = []
910
+ for i in range(mul_time):
911
+ new_v = new_v + shuffle_v
912
+ random.shuffle(shuffle_v)
913
+ del shuffle_v
914
+ new_v = new_v + v[0:pad_time]
915
+ # new_v = mul_time * v + v[0:pad_time]
916
+ all_paths[k] = new_v
917
+ elif len(v) > img_num:
918
+ all_paths[k] = v[:img_num]
919
+ else:
920
+ continue
921
+
922
+ return all_paths, img_num
923
+
924
+ def _load_ids(self, path_patterns, loaders, transform=None):
925
+ result = []
926
+ for loader in loaders:
927
+ for p in path_patterns:
928
+ x = loader[1](p.format(loader[0]), *loader[2:])
929
+ if transform:
930
+ x = transform(x)
931
+ result.append(x)
932
+ return tuple(result)
933
+
934
+ def _shuffle_all(self):
935
+ for k,v in self.all_data_paths.items():
936
+ new_v = copy.deepcopy(v)
937
+ random.shuffle(new_v)
938
+ self.all_data_paths[k] = new_v
939
+ return None
940
+
941
+ def __len__(self):
942
+ return self.all_category_num * self.one_category_num
943
+
944
+ def __getitem__(self, index):
945
+ '''
946
+ This dataset must have non-shuffled index!!
947
+ '''
948
+ category_idx = (index % (self.batch_size * self.all_category_num)) // self.batch_size
949
+ path_idx = (index // (self.batch_size * self.all_category_num)) * self.batch_size + (index % (self.batch_size * self.all_category_num)) - category_idx * self.batch_size
950
+ category_name = self.all_category_names[category_idx]
951
+ paths = [self.all_data_paths[category_name][path_idx]] # len 1 sequence
952
+
953
+ # if category_name in self.original_category_names:
954
+ # bbox_loaders = self.original_bbox_loaders
955
+ # use_original_bbox = True
956
+ # else:
957
+ bbox_loaders = self.few_shot_bbox_loaders
958
+ use_original_bbox = False
959
+
960
+ masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images
961
+ mask_dt = compute_distance_transform(masks)
962
+ jitter = False
963
+ if self.color_jitter is not None:
964
+ prob, b, h = self.color_jitter
965
+ if np.random.rand() < prob:
966
+ jitter = True
967
+ color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
968
+ image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()])
969
+ color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
970
+ image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()])
971
+ if jitter:
972
+ images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images
973
+ images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images
974
+ images = images_fg * masks + images_bg * (1-masks)
975
+ else:
976
+ images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images
977
+
978
+ flows = torch.zeros(1)
979
+ bboxs = torch.stack(self._load_ids(paths, bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images
980
+ if not use_original_bbox:
981
+ bboxs=torch.cat([bboxs, torch.Tensor([[category_idx]]).float()],dim=-1) # pad a label number
982
+
983
+ mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image
984
+ if self.load_background:
985
+ bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg'))
986
+ if jitter:
987
+ bg_image = color_jitter_tsf_bg(bg_image)
988
+ bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size))
989
+ else:
990
+ bg_images = torch.zeros_like(images)
991
+ if self.load_dino_feature:
992
+ dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) # BxFx64x224x224
993
+ else:
994
+ dino_features = torch.zeros(1)
995
+
996
+ dino_clusters = torch.zeros(1)
997
+
998
+ # These are actually no use
999
+ seq_idx = 0
1000
+ seq_idx = torch.LongTensor([seq_idx])
1001
+ frame_idx = torch.arange(0, 1).long()
1002
+
1003
+ if self.random_flip and np.random.rand() < 0.5:
1004
+ images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters)
1005
+
1006
+ ## pad shorter sequence
1007
+ if len(paths) < self.num_sample_frames:
1008
+ num_pad = self.num_sample_frames - len(paths)
1009
+ images = torch.cat([images[:1]] *num_pad + [images], 0)
1010
+ masks = torch.cat([masks[:1]] *num_pad + [masks], 0)
1011
+ mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0)
1012
+ mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0)
1013
+ if flows.dim() > 1:
1014
+ flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0)
1015
+ bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0)
1016
+ bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0)
1017
+ if dino_features.dim() > 1:
1018
+ dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0)
1019
+ if dino_clusters.dim() > 1:
1020
+ dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0)
1021
+ frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0)
1022
+
1023
+ out = (*map(none_to_nan, (images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, category_name)), )
1024
+ return out
1025
+ # return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, category_name
1026
+
1027
+
1028
+
1029
+ def get_test_loader_quadrupeds(test_data_dirs, rank, world_size, **kwargs):
1030
+ dataset = Quadrupeds_Image_Test_Dataset(test_data_dirs, **kwargs)
1031
+ sampler = torch.utils.data.distributed.DistributedSampler(
1032
+ dataset,
1033
+ num_replicas=world_size,
1034
+ rank=rank,
1035
+ shuffle=False
1036
+ )
1037
+ loaders = []
1038
+ loaders += [torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=kwargs['batch_size'], shuffle=False, drop_last=True, num_workers=kwargs['num_workers'], pin_memory=True)]
1039
+
1040
+ return loaders
1041
+
1042
+ def get_sequence_loader(data_dir, **kwargs):
1043
+ if isinstance(data_dir, dict):
1044
+ loaders = []
1045
+ for k, v in data_dir.items():
1046
+ dataset= NFrameSequenceDataset(v, cat_name=k, **kwargs)
1047
+ loader = torch.utils.data.DataLoader(dataset, batch_size=kwargs['batch_size'], shuffle=kwargs['shuffle'], num_workers=kwargs['num_workers'], pin_memory=True)
1048
+ loaders += [loader]
1049
+ return loaders
1050
+ else:
1051
+ return [get_sequence_loader_single(data_dir, **kwargs)]
1052
+
1053
+
1054
+ def get_sequence_loader_single(data_dir, mode='all_frame', is_validation=False, batch_size=256, num_workers=4, in_image_size=256, out_image_size=256, debug_seq=False, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, max_seq_len=256, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.jpg', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64):
1055
+ if mode == 'n_frame':
1056
+ dataset = NFrameSequenceDataset(data_dir, num_sample_frames=num_sample_frames, skip_beginning=skip_beginning, skip_end=skip_end, min_seq_len=min_seq_len, in_image_size=in_image_size, out_image_size=out_image_size, debug_seq=debug_seq, random_sample=random_sample, shuffle=shuffle, dense_sample=dense_sample, color_jitter=color_jitter, load_background=load_background, random_flip=random_flip, rgb_suffix=rgb_suffix, load_dino_feature=load_dino_feature, load_dino_cluster=load_dino_cluster, dino_feature_dim=dino_feature_dim)
1057
+ else:
1058
+ raise NotImplementedError
1059
+ loader = torch.utils.data.DataLoader(
1060
+ dataset,
1061
+ batch_size=batch_size,
1062
+ shuffle=not is_validation,
1063
+ num_workers=num_workers,
1064
+ pin_memory=True
1065
+ )
1066
+ return loader
1067
+
1068
+
1069
+ def get_sequence_loader_ddp(data_dir, world_size, rank, use_few_shot=False, **kwargs):
1070
+ original_classes_num = 0
1071
+ use_few_shot = use_few_shot
1072
+ if isinstance(data_dir, list) and len(data_dir) == 2 and isinstance(data_dir[-1], dict):
1073
+ # a hack way for few shot experiment
1074
+ original_classes_num = data_dir[0]
1075
+ data_dir = data_dir[-1]
1076
+ if isinstance(data_dir, dict):
1077
+ loaders = []
1078
+ cnt = original_classes_num
1079
+ for k, v in data_dir.items():
1080
+ if use_few_shot:
1081
+ dataset = FewShotImageDataset(v, cat_name=k, cat_num=cnt, **kwargs)
1082
+ cnt += 1
1083
+ else:
1084
+ dataset = NFrameSequenceDataset(v, cat_name=k, **kwargs)
1085
+ sampler = torch.utils.data.distributed.DistributedSampler(
1086
+ dataset,
1087
+ num_replicas=world_size,
1088
+ rank=rank,
1089
+ )
1090
+ loaders += [torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=kwargs['batch_size'], shuffle=False, drop_last=True, num_workers=kwargs['num_workers'], pin_memory=True)]
1091
+ return loaders
1092
+ else:
1093
+ return [get_sequence_loader_single_ddp(data_dir, world_size, rank, **kwargs)]
1094
+
1095
+
1096
+ def get_sequence_loader_single_ddp(data_dir, world_size, rank, mode='all_frame', is_validation=False, batch_size=256, num_workers=4, in_image_size=256, out_image_size=256, debug_seq=False, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, max_seq_len=256, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.jpg', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64, flow_bool=False):
1097
+ if mode == 'n_frame':
1098
+ dataset = NFrameSequenceDataset(data_dir, num_sample_frames=num_sample_frames, skip_beginning=skip_beginning, skip_end=skip_end, min_seq_len=min_seq_len, in_image_size=in_image_size, out_image_size=out_image_size, debug_seq=debug_seq, random_sample=random_sample, shuffle=shuffle, dense_sample=dense_sample, color_jitter=color_jitter, load_background=load_background, random_flip=random_flip, rgb_suffix=rgb_suffix, load_dino_feature=load_dino_feature, load_dino_cluster=load_dino_cluster, dino_feature_dim=dino_feature_dim, flow_bool=flow_bool)
1099
+ else:
1100
+ raise NotImplementedError
1101
+ sampler = torch.utils.data.distributed.DistributedSampler(
1102
+ dataset,
1103
+ num_replicas=world_size,
1104
+ rank=rank,
1105
+ )
1106
+ loader = torch.utils.data.DataLoader(
1107
+ dataset,
1108
+ sampler=sampler,
1109
+ batch_size=batch_size,
1110
+ shuffle=False,
1111
+ drop_last=True,
1112
+ num_workers=num_workers,
1113
+ pin_memory=True
1114
+ )
1115
+ return loader
1116
+
1117
+
1118
+ class ImageDataset(Dataset):
1119
+ def __init__(self, root, is_validation=False, image_size=256, color_jitter=None):
1120
+ super().__init__()
1121
+ self.image_loader = ("rgb.jpg", torchvision.datasets.folder.default_loader)
1122
+ self.mask_loader = ("mask.png", torchvision.datasets.folder.default_loader)
1123
+ self.bbox_loader = ("box.txt", np.loadtxt, 'str')
1124
+ self.samples = self._parse_folder(root)
1125
+ self.image_size = image_size
1126
+ self.color_jitter = color_jitter
1127
+ self.image_transform = transforms.Compose([transforms.Resize(self.image_size), transforms.ToTensor()])
1128
+ self.mask_transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
1129
+
1130
+ def _parse_folder(self, path):
1131
+ result = sorted(glob(os.path.join(path, '**/*'+self.image_loader[0]), recursive=True))
1132
+ result = [p.replace(self.image_loader[0], '{}') for p in result]
1133
+ return result
1134
+
1135
+ def _load_ids(self, path, loader, transform=None):
1136
+ x = loader[1](path.format(loader[0]), *loader[2:])
1137
+ if transform:
1138
+ x = transform(x)
1139
+ return x
1140
+
1141
+ def __len__(self):
1142
+ return len(self.samples)
1143
+
1144
+ def __getitem__(self, index):
1145
+ path = self.samples[index % len(self.samples)]
1146
+ masks = self._load_ids(path, self.mask_loader, transform=self.mask_transform).unsqueeze(0)
1147
+ mask_dt = compute_distance_transform(masks)
1148
+ jitter = False
1149
+ if self.color_jitter is not None:
1150
+ prob, b, h = self.color_jitter
1151
+ if np.random.rand() < prob:
1152
+ jitter = True
1153
+ color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
1154
+ image_transform_fg = transforms.Compose([transforms.Resize(self.image_size), color_jitter_tsf_fg, transforms.ToTensor()])
1155
+ color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
1156
+ image_transform_bg = transforms.Compose([transforms.Resize(self.image_size), color_jitter_tsf_bg, transforms.ToTensor()])
1157
+ if jitter:
1158
+ images_fg = self._load_ids(path, self.image_loader, transform=image_transform_fg).unsqueeze(0)
1159
+ images_bg = self._load_ids(path, self.image_loader, transform=image_transform_bg).unsqueeze(0)
1160
+ images = images_fg * masks + images_bg * (1-masks)
1161
+ else:
1162
+ images = self._load_ids(path, self.image_loader, transform=self.image_transform).unsqueeze(0)
1163
+ flows = torch.zeros(1)
1164
+ bboxs = self._load_ids(path, self.bbox_loader, transform=None)
1165
+ bboxs[0] = '0'
1166
+ bboxs = torch.FloatTensor(bboxs.astype('float')).unsqueeze(0)
1167
+ bg_fpath = os.path.join(os.path.dirname(path), 'background_frame.jpg')
1168
+ if os.path.isfile(bg_fpath):
1169
+ bg_image = torchvision.datasets.folder.default_loader(bg_fpath)
1170
+ if jitter:
1171
+ bg_image = color_jitter_tsf_bg(bg_image)
1172
+ bg_image = transforms.ToTensor()(bg_image)
1173
+ else:
1174
+ bg_image = images[0]
1175
+ seq_idx = torch.LongTensor([index])
1176
+ frame_idx = torch.LongTensor([0])
1177
+ return images, masks, mask_dt, flows, bboxs, bg_image, seq_idx, frame_idx
1178
+
1179
+
1180
+ def get_image_loader(data_dir, is_validation=False, batch_size=256, num_workers=4, image_size=256, color_jitter=None):
1181
+ dataset = ImageDataset(data_dir, is_validation=is_validation, image_size=image_size, color_jitter=color_jitter)
1182
+
1183
+ loader = torch.utils.data.DataLoader(
1184
+ dataset,
1185
+ batch_size=batch_size,
1186
+ shuffle=False,
1187
+ num_workers=num_workers,
1188
+ pin_memory=True
1189
+ )
1190
+ return loader
1191
+
1192
+
1193
+ def get_image_loader_ddp(data_dir, world_size, rank, is_validation=False, batch_size=256, num_workers=4, image_size=256, color_jitter=None):
1194
+ dataset = ImageDataset(data_dir, is_validation=is_validation, image_size=image_size, color_jitter=color_jitter)
1195
+
1196
+ sampler = torch.utils.data.distributed.DistributedSampler(
1197
+ dataset,
1198
+ num_replicas=world_size,
1199
+ rank=rank,
1200
+ )
1201
+ loader = torch.utils.data.DataLoader(
1202
+ dataset,
1203
+ sampler=sampler,
1204
+ batch_size=batch_size,
1205
+ shuffle=False,
1206
+ drop_last=True,
1207
+ num_workers=num_workers,
1208
+ pin_memory=True
1209
+ )
1210
+ return loader
video3d/diffusion/sd.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ # os.environ['HUGGINGFACE_HUB_CACHE'] = '/work/tomj/cache/huggingface_hub'
3
+ # os.environ['HF_HOME'] = '/work/tomj/cache/huggingface_hub'
4
+ os.environ['HUGGINGFACE_HUB_CACHE'] = '/viscam/u/zzli'
5
+ os.environ['HF_HOME'] = '/viscam/u/zzli'
6
+
7
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
8
+ from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler
9
+
10
+ # Suppress partial model loading warning
11
+ logging.set_verbosity_error()
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ from torch.cuda.amp import custom_bwd, custom_fwd
18
+
19
+ class SpecifyGradient(torch.autograd.Function):
20
+ @staticmethod
21
+ @custom_fwd
22
+ def forward(ctx, input_tensor, gt_grad):
23
+ ctx.save_for_backward(gt_grad)
24
+ return torch.zeros([1], device=input_tensor.device, dtype=input_tensor.dtype) # Dummy loss value
25
+
26
+ @staticmethod
27
+ @custom_bwd
28
+ def backward(ctx, grad):
29
+ gt_grad, = ctx.saved_tensors
30
+ batch_size = len(gt_grad)
31
+ return gt_grad / batch_size, None
32
+
33
+ def seed_everything(seed):
34
+ torch.manual_seed(seed)
35
+ torch.cuda.manual_seed(seed)
36
+
37
+
38
+ class StableDiffusion(nn.Module):
39
+ def __init__(self, device, sd_version='2.1', hf_key=None, torch_dtype=torch.float32):
40
+ super().__init__()
41
+
42
+ self.device = device
43
+ self.sd_version = sd_version
44
+ self.torch_dtype = torch_dtype
45
+
46
+ print(f'[INFO] loading stable diffusion...')
47
+
48
+ if hf_key is not None:
49
+ print(f'[INFO] using hugging face custom model key: {hf_key}')
50
+ model_key = hf_key
51
+ elif self.sd_version == '2.1':
52
+ model_key = "stabilityai/stable-diffusion-2-1-base"
53
+ elif self.sd_version == '2.0':
54
+ model_key = "stabilityai/stable-diffusion-2-base"
55
+ elif self.sd_version == '1.5':
56
+ model_key = "runwayml/stable-diffusion-v1-5"
57
+ else:
58
+ raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')
59
+
60
+ # Create model
61
+ self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", torch_dtype=torch_dtype).to(self.device)
62
+ self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
63
+ self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
64
+ self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", torch_dtype=torch_dtype).to(self.device)
65
+
66
+ self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
67
+ # self.scheduler = PNDMScheduler.from_pretrained(model_key, subfolder="scheduler")
68
+
69
+ self.num_train_timesteps = self.scheduler.config.num_train_timesteps
70
+ self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
71
+
72
+ print(f'[INFO] loaded stable diffusion!')
73
+
74
+ def get_text_embeds(self, prompt, negative_prompt):
75
+ # prompt, negative_prompt: [str]
76
+
77
+ # Tokenize text and get embeddings
78
+ text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
79
+
80
+ with torch.no_grad():
81
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
82
+
83
+ # Do the same for unconditional embeddings
84
+ uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
85
+
86
+ with torch.no_grad():
87
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
88
+
89
+ # Cat for final embeddings
90
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
91
+ return text_embeddings
92
+
93
+ def train_step(self, text_embeddings, pred_rgb,
94
+ guidance_scale=100, loss_weight=1.0, min_step_pct=0.02, max_step_pct=0.98, return_aux=False):
95
+ pred_rgb = pred_rgb.to(self.torch_dtype)
96
+ text_embeddings = text_embeddings.to(self.torch_dtype)
97
+ b = pred_rgb.shape[0]
98
+
99
+ # interp to 512x512 to be fed into vae.
100
+
101
+ # _t = time.time()
102
+ pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
103
+ # torch.cuda.synchronize(); print(f'[TIME] guiding: interp {time.time() - _t:.4f}s')
104
+
105
+ # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
106
+ min_step = int(self.num_train_timesteps * min_step_pct)
107
+ max_step = int(self.num_train_timesteps * max_step_pct)
108
+ t = torch.randint(min_step, max_step + 1, [b], dtype=torch.long, device=self.device)
109
+
110
+ # encode image into latents with vae, requires grad!
111
+ # _t = time.time()
112
+ latents = self.encode_imgs(pred_rgb_512)
113
+ # torch.cuda.synchronize(); print(f'[TIME] guiding: vae enc {time.time() - _t:.4f}s')
114
+
115
+ # predict the noise residual with unet, NO grad!
116
+ # _t = time.time()
117
+ with torch.no_grad():
118
+ # add noise
119
+ noise = torch.randn_like(latents)
120
+ latents_noisy = self.scheduler.add_noise(latents, noise, t)
121
+ # pred noise
122
+ latent_model_input = torch.cat([latents_noisy] * 2)
123
+ t_input = torch.cat([t, t])
124
+ noise_pred = self.unet(latent_model_input, t_input, encoder_hidden_states=text_embeddings).sample
125
+ # torch.cuda.synchronize(); print(f'[TIME] guiding: unet {time.time() - _t:.4f}s')
126
+
127
+ # perform guidance (high scale from paper!)
128
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
129
+ # noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)
130
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
131
+
132
+ # w(t), sigma_t^2
133
+ w = (1 - self.alphas[t])
134
+ # w = self.alphas[t] ** 0.5 * (1 - self.alphas[t])
135
+ grad = loss_weight * w[:, None, None, None] * (noise_pred - noise)
136
+
137
+ # clip grad for stable training?
138
+ # grad = grad.clamp(-10, 10)
139
+ grad = torch.nan_to_num(grad)
140
+
141
+ # since we omitted an item in grad, we need to use the custom function to specify the gradient
142
+ # _t = time.time()
143
+ # loss = SpecifyGradient.apply(latents, grad)
144
+ # torch.cuda.synchronize(); print(f'[TIME] guiding: backward {time.time() - _t:.4f}s')
145
+
146
+ targets = (latents - grad).detach()
147
+ loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
148
+
149
+ if return_aux:
150
+ aux = {'grad': grad, 't': t, 'w': w}
151
+ return loss, aux
152
+ else:
153
+ return loss
154
+
155
+
156
+ def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
157
+
158
+ if latents is None:
159
+ latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.config.in_channels, height // 8, width // 8), device=self.device)
160
+
161
+ self.scheduler.set_timesteps(num_inference_steps)
162
+
163
+ with torch.autocast('cuda'):
164
+ for i, t in enumerate(self.scheduler.timesteps):
165
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
166
+ latent_model_input = torch.cat([latents] * 2)
167
+
168
+ # predict the noise residual
169
+ with torch.no_grad():
170
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
171
+
172
+ # perform guidance
173
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
174
+ noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)
175
+
176
+ # compute the previous noisy sample x_t -> x_t-1
177
+ latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']
178
+
179
+ return latents
180
+
181
+ def decode_latents(self, latents):
182
+
183
+ latents = 1 / self.vae.config.scaling_factor * latents
184
+
185
+ with torch.no_grad():
186
+ imgs = self.vae.decode(latents).sample
187
+
188
+ imgs = (imgs / 2 + 0.5).clamp(0, 1)
189
+
190
+ return imgs
191
+
192
+ def encode_imgs(self, imgs):
193
+ # imgs: [B, 3, H, W]
194
+
195
+ imgs = 2 * imgs - 1
196
+
197
+ posterior = self.vae.encode(imgs).latent_dist
198
+ latents = posterior.sample() * self.vae.config.scaling_factor
199
+
200
+ return latents
201
+
202
+ def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
203
+
204
+ if isinstance(prompts, str):
205
+ prompts = [prompts]
206
+
207
+ if isinstance(negative_prompts, str):
208
+ negative_prompts = [negative_prompts]
209
+
210
+ # Prompts -> text embeds
211
+ text_embeds = self.get_text_embeds(prompts, negative_prompts) # [2, 77, 768]
212
+
213
+ # Text embeds -> img latents
214
+ latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64]
215
+
216
+ # Img latents -> imgs
217
+ imgs = self.decode_latents(latents) # [1, 3, 512, 512]
218
+
219
+ # Img to Numpy
220
+ imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
221
+ imgs = (imgs * 255).round().astype('uint8')
222
+
223
+ return imgs
224
+
225
+
226
+ if __name__ == '__main__':
227
+ import argparse
228
+ import matplotlib.pyplot as plt
229
+
230
+ parser = argparse.ArgumentParser()
231
+ parser.add_argument('prompt', type=str)
232
+ parser.add_argument('--negative', default='', type=str)
233
+ parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
234
+ parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
235
+ parser.add_argument('-H', type=int, default=512)
236
+ parser.add_argument('-W', type=int, default=512)
237
+ parser.add_argument('--seed', type=int, default=0)
238
+ parser.add_argument('--steps', type=int, default=50)
239
+ opt = parser.parse_args()
240
+
241
+ seed_everything(opt.seed)
242
+
243
+ device = torch.device('cuda')
244
+
245
+ sd = StableDiffusion(device, opt.sd_version, opt.hf_key)
246
+
247
+ imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)
248
+
249
+ # visualize image
250
+ plt.imshow(imgs[0])
251
+ plt.show()
252
+ plt.savefig(f'{opt.prompt}.png')
video3d/diffusion/sd_utils.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import random
4
+ import torch.nn.functional as F
5
+
6
+ from ..render.light import DirectionalLight
7
+
8
+ def safe_normalize(x, eps=1e-20):
9
+ return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
10
+
11
+ def get_view_direction(thetas, phis, overhead, front, phi_offset=0):
12
+ # phis [B,]; thetas: [B,]
13
+ # front = 0 [360 - front / 2, front / 2)
14
+ # side (left) = 1 [front / 2, 180 - front / 2)
15
+ # back = 2 [180 - front / 2, 180 + front / 2)
16
+ # side (right) = 3 [180 + front / 2, 360 - front / 2)
17
+ # top = 4 [0, overhead]
18
+ # bottom = 5 [180-overhead, 180]
19
+ res = torch.zeros(thetas.shape[0], dtype=torch.long)
20
+
21
+ # first determine by phis
22
+ phi_offset = np.deg2rad(phi_offset)
23
+ phis = phis + phi_offset
24
+ phis = phis % (2 * np.pi)
25
+ half_front = front / 2
26
+
27
+ res[(phis >= (2*np.pi - half_front)) | (phis < half_front)] = 0
28
+ res[(phis >= half_front) & (phis < (np.pi - half_front))] = 1
29
+ res[(phis >= (np.pi - half_front)) & (phis < (np.pi + half_front))] = 2
30
+ res[(phis >= (np.pi + half_front)) & (phis < (2*np.pi - half_front))] = 3
31
+
32
+ # override by thetas
33
+ res[thetas <= overhead] = 4
34
+ res[thetas >= (np.pi - overhead)] = 5
35
+ return res
36
+
37
+
38
+ def view_direction_id_to_text(view_direction_id):
39
+ dir_texts = ['front', 'side', 'back', 'side', 'overhead', 'bottom']
40
+ return [dir_texts[i] for i in view_direction_id]
41
+
42
+
43
+ def append_text_direction(prompts, dir_texts):
44
+ return [f'{prompt}, {dir_text} view' for prompt, dir_text in zip(prompts, dir_texts)]
45
+
46
+
47
+ def rand_lights(camera_dir, fixed_ambient, fixed_diffuse):
48
+ size = camera_dir.shape[0]
49
+ device = camera_dir.device
50
+ random_fixed_dir = F.normalize(torch.randn_like(camera_dir) + camera_dir, dim=-1) # Centered around camera_dir
51
+ random_fixed_intensity = torch.tensor([fixed_ambient, fixed_diffuse], device=device)[None, :].repeat(size, 1) # ambient, diffuse
52
+ return DirectionalLight(mlp_in=1, mlp_layers=1, mlp_hidden_size=1, # Dummy values
53
+ intensity_min_max=[0.5, 1],fixed_dir=random_fixed_dir, fixed_intensity=random_fixed_intensity).to(device)
54
+
55
+ def rand_poses(size, device, radius_range=[1, 1], theta_range=[0, 120], phi_range=[0, 360], cam_z_offset=10, return_dirs=False, angle_overhead=30, angle_front=60, phi_offset=0, jitter=False, uniform_sphere_rate=0.5):
56
+ ''' generate random poses from an orbit camera
57
+ Args:
58
+ size: batch size of generated poses.
59
+ device: where to allocate the output.
60
+ radius_range: [min, max]
61
+ theta_range: [min, max], should be in [0, pi]
62
+ phi_range: [min, max], should be in [0, 2 * pi]
63
+ Return:
64
+ poses: [size, 4, 4]
65
+ '''
66
+
67
+ theta_range = np.deg2rad(theta_range)
68
+ phi_range = np.deg2rad(phi_range)
69
+ angle_overhead = np.deg2rad(angle_overhead)
70
+ angle_front = np.deg2rad(angle_front)
71
+
72
+ radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]
73
+
74
+ phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
75
+ if random.random() < uniform_sphere_rate:
76
+ # based on http://corysimon.github.io/articles/uniformdistn-on-sphere/
77
+ # acos takes in [-1, 1], first convert theta range to fit in [-1, 1]
78
+ theta_range = torch.from_numpy(np.array(theta_range)).to(device)
79
+ theta_amplitude_range = torch.cos(theta_range)
80
+ # sample uniformly in amplitude space range
81
+ thetas_amplitude = torch.rand(size, device=device) * (theta_amplitude_range[1] - theta_amplitude_range[0]) + theta_amplitude_range[0]
82
+ # convert back
83
+ thetas = torch.acos(thetas_amplitude)
84
+ else:
85
+ thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
86
+
87
+ centers = -torch.stack([
88
+ radius * torch.sin(thetas) * torch.sin(phis),
89
+ radius * torch.cos(thetas),
90
+ radius * torch.sin(thetas) * torch.cos(phis),
91
+ ], dim=-1) # [B, 3]
92
+
93
+ targets = 0
94
+
95
+ # jitters
96
+ if jitter:
97
+ centers = centers + (torch.rand_like(centers) * 0.2 - 0.1)
98
+ targets = targets + torch.randn_like(centers) * 0.2
99
+
100
+ # lookat
101
+ forward_vector = safe_normalize(targets - centers)
102
+ up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1)
103
+ right_vector = safe_normalize(torch.cross(up_vector, forward_vector, dim=-1))
104
+
105
+ if jitter:
106
+ up_noise = torch.randn_like(up_vector) * 0.02
107
+ else:
108
+ up_noise = 0
109
+
110
+ up_vector = safe_normalize(torch.cross(forward_vector, right_vector, dim=-1) + up_noise)
111
+
112
+ poses = torch.stack([right_vector, up_vector, forward_vector], dim=-1)
113
+ radius = radius[..., None] - cam_z_offset
114
+ translations = torch.cat([torch.zeros_like(radius), torch.zeros_like(radius), radius], dim=-1)
115
+ poses = torch.cat([poses.view(-1, 9), translations], dim=-1)
116
+
117
+ if return_dirs:
118
+ dirs = get_view_direction(thetas, phis, angle_overhead, angle_front, phi_offset=phi_offset)
119
+ dirs = view_direction_id_to_text(dirs)
120
+ else:
121
+ dirs = None
122
+
123
+ return poses, dirs
video3d/diffusion/vsd.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ['HUGGINGFACE_HUB_CACHE'] = '/viscam/u/zzli'
3
+ os.environ['HF_HOME'] = '/viscam/u/zzli'
4
+
5
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
6
+ from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler
7
+
8
+ from diffusers.loaders import AttnProcsLayers
9
+ from diffusers.models.attention_processor import LoRAAttnProcessor
10
+ from diffusers.models.embeddings import TimestepEmbedding
11
+ from diffusers.utils.import_utils import is_xformers_available
12
+
13
+ # Suppress partial model loading warning
14
+ logging.set_verbosity_error()
15
+
16
+ import gc
17
+ import random
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import tinycudann as tcnn
22
+ from video3d.diffusion.sd import StableDiffusion
23
+ from torch.cuda.amp import custom_bwd, custom_fwd
24
+
25
+
26
+ def seed_everything(seed):
27
+ torch.manual_seed(seed)
28
+ torch.cuda.manual_seed(seed)
29
+
30
+ def cleanup():
31
+ gc.collect()
32
+ torch.cuda.empty_cache()
33
+ tcnn.free_temporary_memory()
34
+
35
+ class StableDiffusion_VSD(StableDiffusion):
36
+ def __init__(self, device, sd_version='2.1', hf_key=None, torch_dtype=torch.float32, lora_n_timestamp_samples=1):
37
+ super().__init__(device, sd_version=sd_version, hf_key=hf_key, torch_dtype=torch_dtype)
38
+
39
+ # self.device = device
40
+ # self.sd_version = sd_version
41
+ # self.torch_dtype = torch_dtype
42
+
43
+ if hf_key is not None:
44
+ print(f'[INFO] using hugging face custom model key: {hf_key}')
45
+ model_key = hf_key
46
+ elif self.sd_version == '2.1':
47
+ model_key = "stabilityai/stable-diffusion-2-1-base"
48
+ elif self.sd_version == '2.0':
49
+ model_key = "stabilityai/stable-diffusion-2-base"
50
+ elif self.sd_version == '1.5':
51
+ model_key = "runwayml/stable-diffusion-v1-5"
52
+ else:
53
+ raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')
54
+
55
+ # # Create model
56
+ # self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", torch_dtype=torch_dtype).to(self.device)
57
+ # self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
58
+ # self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
59
+ # self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", torch_dtype=torch_dtype).to(self.device)
60
+
61
+ # self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
62
+ # # self.scheduler = PNDMScheduler.from_pretrained(model_key, subfolder="scheduler")
63
+
64
+ # self.num_train_timesteps = self.scheduler.config.num_train_timesteps
65
+ # self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
66
+
67
+ print(f'[INFO] loading stable diffusion VSD modules...')
68
+
69
+ self.unet_lora = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", torch_dtype=torch_dtype).to(self.device)
70
+ cleanup()
71
+
72
+ for p in self.vae.parameters():
73
+ p.requires_grad_(False)
74
+ for p in self.text_encoder.parameters():
75
+ p.requires_grad_(False)
76
+ for p in self.unet.parameters():
77
+ p.requires_grad_(False)
78
+ for p in self.unet_lora.parameters():
79
+ p.requires_grad_(False)
80
+
81
+ # set up LoRA layers
82
+ lora_attn_procs = {}
83
+ for name in self.unet_lora.attn_processors.keys():
84
+ cross_attention_dim = (
85
+ None
86
+ if name.endswith("attn1.processor")
87
+ else self.unet_lora.config.cross_attention_dim
88
+ )
89
+ if name.startswith("mid_block"):
90
+ hidden_size = self.unet_lora.config.block_out_channels[-1]
91
+ elif name.startswith("up_blocks"):
92
+ block_id = int(name[len("up_blocks.")])
93
+ hidden_size = list(reversed(self.unet_lora.config.block_out_channels))[
94
+ block_id
95
+ ]
96
+ elif name.startswith("down_blocks"):
97
+ block_id = int(name[len("down_blocks.")])
98
+ hidden_size = self.unet_lora.config.block_out_channels[block_id]
99
+
100
+ lora_attn_procs[name] = LoRAAttnProcessor(
101
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
102
+ )
103
+
104
+ self.unet_lora.set_attn_processor(lora_attn_procs)
105
+
106
+ self.lora_layers = AttnProcsLayers(self.unet_lora.attn_processors).to(
107
+ self.device
108
+ )
109
+ self.lora_layers._load_state_dict_pre_hooks.clear()
110
+ self.lora_layers._state_dict_hooks.clear()
111
+ self.lora_n_timestamp_samples = lora_n_timestamp_samples
112
+ self.scheduler_lora = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
113
+
114
+ print(f'[INFO] loaded stable diffusion VSD modules!')
115
+
116
+ def train_lora(
117
+ self,
118
+ latents,
119
+ text_embeddings,
120
+ camera_condition
121
+ ):
122
+ B = latents.shape[0]
123
+ lora_n_timestamp_samples = self.lora_n_timestamp_samples
124
+ latents = latents.detach().repeat(lora_n_timestamp_samples, 1, 1, 1)
125
+
126
+ t = torch.randint(
127
+ int(self.num_train_timesteps * 0.0),
128
+ int(self.num_train_timesteps * 1.0),
129
+ [B * lora_n_timestamp_samples],
130
+ dtype=torch.long,
131
+ device=self.device,
132
+ )
133
+
134
+ noise = torch.randn_like(latents)
135
+ noisy_latents = self.scheduler_lora.add_noise(latents, noise, t)
136
+ if self.scheduler_lora.config.prediction_type == "epsilon":
137
+ target = noise
138
+ elif self.scheduler_lora.config.prediction_type == "v_prediction":
139
+ target = self.scheduler_lora.get_velocity(latents, noise, t)
140
+ else:
141
+ raise ValueError(
142
+ f"Unknown prediction type {self.scheduler_lora.config.prediction_type}"
143
+ )
144
+
145
+ # use view-independent text embeddings in LoRA
146
+ _, text_embeddings_cond = text_embeddings.chunk(2)
147
+
148
+ if random.random() < 0.1:
149
+ camera_condition = torch.zeros_like(camera_condition)
150
+
151
+ noise_pred = self.unet_lora(
152
+ noisy_latents,
153
+ t,
154
+ encoder_hidden_states=text_embeddings_cond.repeat(
155
+ lora_n_timestamp_samples, 1, 1
156
+ ),
157
+ class_labels=camera_condition.reshape(B, -1).repeat(
158
+ lora_n_timestamp_samples, 1
159
+ ),
160
+ cross_attention_kwargs={"scale": 1.0}
161
+ ).sample
162
+
163
+ loss_lora = 0.5 * F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
164
+ return loss_lora
165
+
166
+
167
+ def train_step(
168
+ self,
169
+ text_embeddings,
170
+ text_embeddings_vd,
171
+ pred_rgb,
172
+ camera_condition,
173
+ im_features,
174
+ guidance_scale=7.5,
175
+ guidance_scale_lora=7.5,
176
+ loss_weight=1.0,
177
+ min_step_pct=0.02,
178
+ max_step_pct=0.98,
179
+ return_aux=False
180
+ ):
181
+ pred_rgb = pred_rgb.to(self.torch_dtype)
182
+ text_embeddings = text_embeddings.to(self.torch_dtype)
183
+ text_embeddings_vd = text_embeddings_vd.to(self.torch_dtype)
184
+ camera_condition = camera_condition.to(self.torch_dtype)
185
+ im_features = im_features.to(self.torch_dtype)
186
+
187
+ # condition_label = camera_condition
188
+ condition_label = im_features
189
+
190
+ b = pred_rgb.shape[0]
191
+
192
+ # interp to 512x512 to be fed into vae.
193
+ # _t = time.time()
194
+ pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
195
+ # torch.cuda.synchronize(); print(f'[TIME] guiding: interp {time.time() - _t:.4f}s')
196
+
197
+ # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
198
+ min_step = int(self.num_train_timesteps * min_step_pct)
199
+ max_step = int(self.num_train_timesteps * max_step_pct)
200
+ t = torch.randint(min_step, max_step + 1, [b], dtype=torch.long, device=self.device)
201
+
202
+ # encode image into latents with vae, requires grad!
203
+ # _t = time.time()
204
+ latents = self.encode_imgs(pred_rgb_512)
205
+ # torch.cuda.synchronize(); print(f'[TIME] guiding: vae enc {time.time() - _t:.4f}s')
206
+
207
+ # predict the noise residual with unet, NO grad!
208
+ # _t = time.time()
209
+ with torch.no_grad():
210
+ # add noise
211
+ noise = torch.randn_like(latents)
212
+ latents_noisy = self.scheduler.add_noise(latents, noise, t)
213
+ # pred noise
214
+ latent_model_input = torch.cat([latents_noisy] * 2)
215
+
216
+ # disable unet class embedding here
217
+ cls_embedding = self.unet.class_embedding
218
+ self.unet.class_embedding = None
219
+
220
+ cross_attention_kwargs = None
221
+ noise_pred_pretrain = self.unet(
222
+ latent_model_input,
223
+ torch.cat([t, t]),
224
+ encoder_hidden_states=text_embeddings_vd,
225
+ class_labels=None,
226
+ cross_attention_kwargs=cross_attention_kwargs
227
+ ).sample
228
+
229
+ self.unet.class_embedding = cls_embedding
230
+
231
+ # use view-independent text embeddings in LoRA
232
+ _, text_embeddings_cond = text_embeddings.chunk(2)
233
+
234
+ noise_pred_est = self.unet_lora(
235
+ latent_model_input,
236
+ torch.cat([t, t]),
237
+ encoder_hidden_states=torch.cat([text_embeddings_cond] * 2),
238
+ class_labels=torch.cat(
239
+ [
240
+ condition_label.reshape(b, -1),
241
+ torch.zeros_like(condition_label.reshape(b, -1)),
242
+ ],
243
+ dim=0,
244
+ ),
245
+ cross_attention_kwargs={"scale": 1.0},
246
+ ).sample
247
+
248
+ noise_pred_pretrain_uncond, noise_pred_pretrain_text = noise_pred_pretrain.chunk(2)
249
+
250
+ noise_pred_pretrain = noise_pred_pretrain_uncond + guidance_scale * (
251
+ noise_pred_pretrain_text - noise_pred_pretrain_uncond
252
+ )
253
+
254
+ assert self.scheduler.config.prediction_type == "epsilon"
255
+ if self.scheduler_lora.config.prediction_type == "v_prediction":
256
+ alphas_cumprod = self.scheduler_lora.alphas_cumprod.to(
257
+ device=latents_noisy.device, dtype=latents_noisy.dtype
258
+ )
259
+ alpha_t = alphas_cumprod[t] ** 0.5
260
+ sigma_t = (1 - alphas_cumprod[t]) ** 0.5
261
+
262
+ noise_pred_est = latent_model_input * torch.cat([sigma_t] * 2, dim=0).reshape(
263
+ -1, 1, 1, 1
264
+ ) + noise_pred_est * torch.cat([alpha_t] * 2, dim=0).reshape(-1, 1, 1, 1)
265
+
266
+ noise_pred_est_uncond, noise_pred_est_camera = noise_pred_est.chunk(2)
267
+
268
+ noise_pred_est = noise_pred_est_uncond + guidance_scale_lora * (
269
+ noise_pred_est_camera - noise_pred_est_uncond
270
+ )
271
+
272
+ # w(t), sigma_t^2
273
+ w = (1 - self.alphas[t])
274
+ # w = self.alphas[t] ** 0.5 * (1 - self.alphas[t])
275
+ grad = loss_weight * w[:, None, None, None] * (noise_pred_pretrain - noise_pred_est)
276
+
277
+ grad = torch.nan_to_num(grad)
278
+
279
+ targets = (latents - grad).detach()
280
+ loss_vsd = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
281
+
282
+ loss_lora = self.train_lora(latents, text_embeddings, condition_label)
283
+
284
+ loss = {
285
+ 'loss_vsd': loss_vsd,
286
+ 'loss_lora': loss_lora
287
+ }
288
+
289
+ if return_aux:
290
+ aux = {'grad': grad, 't': t, 'w': w}
291
+ return loss, aux
292
+ else:
293
+ return loss
294
+
295
+
296
+
297
+ if __name__ == '__main__':
298
+ import argparse
299
+ import matplotlib.pyplot as plt
300
+
301
+ parser = argparse.ArgumentParser()
302
+ parser.add_argument('prompt', type=str)
303
+ parser.add_argument('--negative', default='', type=str)
304
+ parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
305
+ parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
306
+ parser.add_argument('-H', type=int, default=512)
307
+ parser.add_argument('-W', type=int, default=512)
308
+ parser.add_argument('--seed', type=int, default=0)
309
+ parser.add_argument('--steps', type=int, default=50)
310
+ opt = parser.parse_args()
311
+
312
+ seed_everything(opt.seed)
313
+
314
+ device = torch.device('cuda')
315
+
316
+ sd = StableDiffusion_VSD(device, opt.sd_version, opt.hf_key)
317
+
318
+ imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)
319
+
320
+ # visualize image
321
+ plt.imshow(imgs[0])
322
+ plt.show()
323
+ plt.savefig(f'{opt.prompt}.png')
video3d/discriminator_architecture.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ from math import log2
4
+ import torch.nn.functional as F
5
+ from torch import autograd
6
+
7
+
8
+ class DCDiscriminator(nn.Module):
9
+ ''' DC Discriminator class.
10
+
11
+ Args:
12
+ in_dim (int): input dimension
13
+ n_feat (int): features of final hidden layer
14
+ img_size (int): input image size
15
+ '''
16
+ def __init__(self, in_dim=1, out_dim=1, n_feat=512, img_size=256, last_bias=False):
17
+ super().__init__()
18
+
19
+ self.in_dim = in_dim
20
+ self.out_dim = out_dim
21
+ n_layers = int(log2(img_size) - 2)
22
+ self.blocks = nn.ModuleList(
23
+ [nn.Conv2d(
24
+ in_dim,
25
+ int(n_feat / (2 ** (n_layers - 1))),
26
+ 4, 2, 1, bias=False)] + [nn.Conv2d(
27
+ int(n_feat / (2 ** (n_layers - i))),
28
+ int(n_feat / (2 ** (n_layers - 1 - i))),
29
+ 4, 2, 1, bias=False) for i in range(1, n_layers)])
30
+
31
+ self.conv_out = nn.Conv2d(n_feat, out_dim, 4, 1, 0, bias=last_bias)
32
+ self.actvn = nn.LeakyReLU(0.2, inplace=True)
33
+
34
+ def forward(self, x):
35
+ batch_size = x.shape[0]
36
+ if x.shape[1] != self.in_dim:
37
+ import ipdb; ipdb.set_trace()
38
+ x = x[:, :self.in_dim]
39
+ for layer in self.blocks:
40
+ x = self.actvn(layer(x))
41
+
42
+ out = self.conv_out(x)
43
+ out = out.reshape(batch_size, self.out_dim)
44
+ return out
45
+
46
+
47
+ # class ADADiscriminator(DCDiscriminator):
48
+ # def __init__(self, aug, aug_p, **kwargs):
49
+ # super().__init__(**kwargs)
50
+ # self.aug = build_from_config(aug)
51
+ # self.aug.p.copy_(torch.tensor(aug_p, dtype=torch.float32))
52
+ # self.resolution = kwargs['img_size']
53
+
54
+ # def get_resolution(self):
55
+ # return self.resolution
56
+
57
+ # def forward(self, x, **kwargs):
58
+ # x = self.aug(x)
59
+ # return super().forward(x, **kwargs)
60
+
61
+
62
+ # class ADADiscriminatorView(ADADiscriminator):
63
+ # def __init__(self, out_dim_position, out_dim_latent, **kwargs):
64
+ # self.out_dim_position = out_dim_position
65
+ # self.out_dim_latent = out_dim_latent
66
+
67
+ # super().__init__(**kwargs)
68
+
69
+ def bce_loss_target(d_out, target):
70
+ targets = d_out.new_full(size=d_out.size(), fill_value=target)
71
+ loss = F.binary_cross_entropy_with_logits(d_out, targets)
72
+ return loss.mean()
73
+
74
+ def compute_grad2(d_out, x_in):
75
+ batch_size = x_in.size(0)
76
+ grad_dout = autograd.grad(
77
+ outputs=d_out.sum(), inputs=x_in,
78
+ create_graph=True, retain_graph=True, only_inputs=True
79
+ )[0]
80
+ grad_dout2 = grad_dout.pow(2)
81
+ assert(grad_dout2.size() == x_in.size())
82
+ reg = grad_dout2.reshape(batch_size, -1).sum(1)
83
+ return reg.mean()
video3d/flow/__init__.py ADDED
File without changes
video3d/flow/flow.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from numpy.lib.npyio import load
2
+ from torch._C import device
3
+ import sys
4
+ sys.path.append('/scratch/shared/beegfs/szwu/projects/video3d/RAFT')
5
+ from core.raft import RAFT
6
+
7
+ from .utils import InputPadder
8
+ import torch
9
+
10
+
11
+ class AttrDict(dict):
12
+ def __init__(self, *args, **kwargs):
13
+ super(AttrDict, self).__init__(*args, **kwargs)
14
+ self.__dict__ = self
15
+
16
+
17
+
18
+ class FlowModel():
19
+ def __init__(self, model, device):
20
+ args = AttrDict({'model': model, 'small': False, 'mixed_precision': False, 'alternate_corr': False})
21
+ self.model = self.load_model(args, device)
22
+ self.device = device
23
+
24
+
25
+ @staticmethod
26
+ def load_model(args, device):
27
+ model = torch.nn.DataParallel(RAFT(args))
28
+ model.load_state_dict(torch.load(args.model))
29
+
30
+ model = model.module
31
+ model.to(device)
32
+ model.eval()
33
+ return model
34
+
35
+
36
+ def preprocess_image(self, image):
37
+ # image = image[:, :, ::-1].copy()
38
+ image = torch.from_numpy(image).permute(2, 0, 1).float()
39
+ image = image.to(self.device)
40
+ image = image[None]
41
+ # size = [540, 960]
42
+ # image = torch.nn.functional.interpolate(image, size=size, mode='bilinear', align_corners=False)
43
+ padder = InputPadder(image.shape)
44
+ return padder.pad(image)[0], padder
45
+
46
+
47
+ def compute_flow(self, frame, next_frame, iters=20):
48
+ frame, padder = self.preprocess_image(frame)
49
+ next_frame, padder = self.preprocess_image(next_frame)
50
+ _, flow = self.model(frame, next_frame, iters=iters, test_mode=True)
51
+ return padder.unpad(flow)[0].permute(1, 2, 0).cpu().numpy()
video3d/flow/utils.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from RAFT
2
+
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class InputPadder:
7
+ """ Pads images such that dimensions are divisible by 8 """
8
+ def __init__(self, dims, mode='sintel'):
9
+ self.ht, self.wd = dims[-2:]
10
+ pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
11
+ pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
12
+ if mode == 'sintel':
13
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
14
+ else:
15
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
16
+
17
+ def pad(self, *inputs):
18
+ return [F.pad(x, self._pad, mode='replicate') for x in inputs]
19
+
20
+ def unpad(self,x):
21
+ ht, wd = x.shape[-2:]
22
+ c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
23
+ return x[..., c[0]:c[1], c[2]:c[3]]
video3d/geometry/dlmesh.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ import torch
11
+
12
+ from ..render import mesh
13
+ from ..render import render
14
+ from ..render import regularizer
15
+
16
+ ###############################################################################
17
+ # Geometry interface
18
+ ###############################################################################
19
+
20
+ class DLMesh(torch.nn.Module):
21
+ def __init__(self, initial_guess, FLAGS):
22
+ super(DLMesh, self).__init__()
23
+
24
+ self.FLAGS = FLAGS
25
+
26
+ self.initial_guess = initial_guess
27
+ self.mesh = initial_guess.clone()
28
+ print("Base mesh has %d triangles and %d vertices." % (self.mesh.t_pos_idx.shape[0], self.mesh.v_pos.shape[0]))
29
+
30
+ self.mesh.v_pos = torch.nn.Parameter(self.mesh.v_pos, requires_grad=True)
31
+ self.register_parameter('vertex_pos', self.mesh.v_pos)
32
+
33
+ @torch.no_grad()
34
+ def getAABB(self):
35
+ return mesh.aabb(self.mesh)
36
+
37
+ def getMesh(self, material):
38
+ self.mesh.material = material
39
+
40
+ imesh = mesh.Mesh(base=self.mesh)
41
+ # Compute normals and tangent space
42
+ imesh = mesh.auto_normals(imesh)
43
+ imesh = mesh.compute_tangents(imesh)
44
+ return imesh
45
+
46
+ def render(self, glctx, target, lgt, opt_material, bsdf=None):
47
+ opt_mesh = self.getMesh(opt_material)
48
+ return render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'],
49
+ num_layers=self.FLAGS.layers, msaa=True, background=target['background'], bsdf=bsdf)
50
+
51
+ def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration):
52
+
53
+ # ==============================================================================================
54
+ # Render optimizable object with identical conditions
55
+ # ==============================================================================================
56
+ buffers = self.render(glctx, target, lgt, opt_material)
57
+
58
+ # ==============================================================================================
59
+ # Compute loss
60
+ # ==============================================================================================
61
+ t_iter = iteration / self.FLAGS.iter
62
+
63
+ # Image-space loss, split into a coverage component and a color component
64
+ color_ref = target['img']
65
+ img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:])
66
+ img_loss += loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:])
67
+
68
+ reg_loss = torch.tensor([0], dtype=torch.float32, device="cuda")
69
+
70
+ # Compute regularizer.
71
+ if self.FLAGS.laplace == "absolute":
72
+ reg_loss += regularizer.laplace_regularizer_const(self.mesh.v_pos, self.mesh.t_pos_idx) * self.FLAGS.laplace_scale * (1 - t_iter)
73
+ elif self.FLAGS.laplace == "relative":
74
+ reg_loss += regularizer.laplace_regularizer_const(self.mesh.v_pos - self.initial_guess.v_pos, self.mesh.t_pos_idx) * self.FLAGS.laplace_scale * (1 - t_iter)
75
+
76
+ # Albedo (k_d) smoothnesss regularizer
77
+ reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, iteration / 500)
78
+
79
+ # Visibility regularizer
80
+ reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 0.001 * min(1.0, iteration / 500)
81
+
82
+ # Light white balance regularizer
83
+ reg_loss = reg_loss + lgt.regularizer() * 0.005
84
+
85
+ return img_loss, reg_loss
video3d/geometry/dmtet.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ from multiprocessing.spawn import get_preparation_data
11
+ import numpy as np
12
+ import torch
13
+
14
+ from ..render import mesh
15
+ from ..render import render
16
+ from ..networks import MLPWithPositionalEncoding, MLPWithPositionalEncoding_Style
17
+
18
+ ###############################################################################
19
+ # Marching tetrahedrons implementation (differentiable), adapted from
20
+ # https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py
21
+ #
22
+ # Note this only supports batch size = 1.
23
+ ###############################################################################
24
+
25
+ class DMTet:
26
+ def __init__(self):
27
+ self.triangle_table = torch.tensor([
28
+ [-1, -1, -1, -1, -1, -1],
29
+ [ 1, 0, 2, -1, -1, -1],
30
+ [ 4, 0, 3, -1, -1, -1],
31
+ [ 1, 4, 2, 1, 3, 4],
32
+ [ 3, 1, 5, -1, -1, -1],
33
+ [ 2, 3, 0, 2, 5, 3],
34
+ [ 1, 4, 0, 1, 5, 4],
35
+ [ 4, 2, 5, -1, -1, -1],
36
+ [ 4, 5, 2, -1, -1, -1],
37
+ [ 4, 1, 0, 4, 5, 1],
38
+ [ 3, 2, 0, 3, 5, 2],
39
+ [ 1, 3, 5, -1, -1, -1],
40
+ [ 4, 1, 2, 4, 3, 1],
41
+ [ 3, 0, 4, -1, -1, -1],
42
+ [ 2, 0, 1, -1, -1, -1],
43
+ [-1, -1, -1, -1, -1, -1]
44
+ ], dtype=torch.long, device='cuda')
45
+
46
+ self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device='cuda')
47
+ self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device='cuda')
48
+
49
+ ###############################################################################
50
+ # Utility functions
51
+ ###############################################################################
52
+
53
+ def sort_edges(self, edges_ex2):
54
+ with torch.no_grad():
55
+ order = (edges_ex2[:,0] > edges_ex2[:,1]).long()
56
+ order = order.unsqueeze(dim=1)
57
+
58
+ a = torch.gather(input=edges_ex2, index=order, dim=1)
59
+ b = torch.gather(input=edges_ex2, index=1-order, dim=1)
60
+
61
+ return torch.stack([a, b],-1)
62
+
63
+ def map_uv(self, faces, face_gidx, max_idx):
64
+ N = int(np.ceil(np.sqrt((max_idx+1)//2)))
65
+ tex_y, tex_x = torch.meshgrid(
66
+ torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
67
+ torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
68
+ indexing='ij'
69
+ )
70
+
71
+ pad = 0.9 / N
72
+
73
+ uvs = torch.stack([
74
+ tex_x , tex_y,
75
+ tex_x + pad, tex_y,
76
+ tex_x + pad, tex_y + pad,
77
+ tex_x , tex_y + pad
78
+ ], dim=-1).view(-1, 2)
79
+
80
+ def _idx(tet_idx, N):
81
+ x = tet_idx % N
82
+ y = torch.div(tet_idx, N, rounding_mode='trunc')
83
+ return y * N + x
84
+
85
+ tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)
86
+ tri_idx = face_gidx % 2
87
+
88
+ uv_idx = torch.stack((
89
+ tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2
90
+ ), dim = -1). view(-1, 3)
91
+
92
+ return uvs, uv_idx
93
+
94
+ ###############################################################################
95
+ # Marching tets implementation
96
+ ###############################################################################
97
+
98
+ def __call__(self, pos_nx3, sdf_n, tet_fx4):
99
+ with torch.no_grad():
100
+ occ_n = sdf_n > 0
101
+ occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4)
102
+ occ_sum = torch.sum(occ_fx4, -1)
103
+ valid_tets = (occ_sum>0) & (occ_sum<4)
104
+ occ_sum = occ_sum[valid_tets]
105
+
106
+ # find all vertices
107
+ all_edges = tet_fx4[valid_tets][:,self.base_tet_edges].reshape(-1,2)
108
+ all_edges = self.sort_edges(all_edges)
109
+ unique_edges, idx_map = torch.unique(all_edges,dim=0, return_inverse=True)
110
+
111
+ unique_edges = unique_edges.long()
112
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1
113
+ mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device="cuda") * -1
114
+ mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long,device="cuda")
115
+ idx_map = mapping[idx_map] # map edges to verts
116
+
117
+ interp_v = unique_edges[mask_edges]
118
+ edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3)
119
+ edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1,2,1)
120
+ edges_to_interp_sdf[:,-1] *= -1
121
+
122
+ denominator = edges_to_interp_sdf.sum(1,keepdim = True)
123
+
124
+ edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1])/denominator
125
+ verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
126
+
127
+ idx_map = idx_map.reshape(-1,6)
128
+
129
+ v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda"))
130
+ tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
131
+ num_triangles = self.num_triangles_table[tetindex]
132
+
133
+ # Generate triangle indices
134
+ faces = torch.cat((
135
+ torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1,3),
136
+ torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1,3),
137
+ ), dim=0)
138
+
139
+ # Get global face index (static, does not depend on topology)
140
+ num_tets = tet_fx4.shape[0]
141
+ tet_gidx = torch.arange(num_tets, dtype=torch.long, device="cuda")[valid_tets]
142
+ face_gidx = torch.cat((
143
+ tet_gidx[num_triangles == 1]*2,
144
+ torch.stack((tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1)
145
+ ), dim=0)
146
+
147
+ uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets*2)
148
+
149
+ return verts, faces, uvs, uv_idx
150
+
151
+ ###############################################################################
152
+ # Regularizer
153
+ ###############################################################################
154
+
155
+ def sdf_bce_reg_loss(sdf, all_edges):
156
+ sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2)
157
+ mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1])
158
+ sdf_f1x6x2 = sdf_f1x6x2[mask]
159
+ sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \
160
+ torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float())
161
+ if torch.isnan(sdf_diff).any():
162
+ import ipdb; ipdb.set_trace()
163
+ return sdf_diff
164
+
165
+ ###############################################################################
166
+ # Geometry interface
167
+ ###############################################################################
168
+
169
+ class DMTetGeometry(torch.nn.Module):
170
+ def __init__(self, grid_res, scale, sdf_mode, num_layers=None, hidden_size=None, embedder_freq=None, embed_concat_pts=True, init_sdf=None, jitter_grid=0., perturb_sdf_iter=10000, sym_prior_shape=False, dim_of_classes=0, condition_choice='concat'):
171
+ super(DMTetGeometry, self).__init__()
172
+
173
+ self.sdf_mode = sdf_mode
174
+ self.grid_res = grid_res
175
+ self.marching_tets = DMTet()
176
+ self.grid_scale = scale
177
+ self.init_sdf = init_sdf
178
+ self.jitter_grid = jitter_grid
179
+ self.perturb_sdf_iter = perturb_sdf_iter
180
+ self.sym_prior_shape = sym_prior_shape
181
+ self.load_tets(self.grid_res, self.grid_scale)
182
+
183
+ if sdf_mode == "param":
184
+ sdf = torch.rand_like(self.verts[:,0]) - 0.1 # Random init.
185
+ self.sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
186
+ self.register_parameter('sdf', self.sdf)
187
+ self.deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True)
188
+ self.register_parameter('deform', self.deform)
189
+ else:
190
+ embedder_scaler = 2 * np.pi / self.grid_scale * 0.9 # originally (-0.5*s, 0.5*s) rescale to (-pi, pi) * 0.9
191
+
192
+ if dim_of_classes == 0 or (dim_of_classes != 0 and condition_choice == 'concat'):
193
+ self.mlp = MLPWithPositionalEncoding(
194
+ 3,
195
+ 1,
196
+ num_layers,
197
+ nf=hidden_size,
198
+ extra_dim=dim_of_classes,
199
+ dropout=0,
200
+ activation=None,
201
+ n_harmonic_functions=embedder_freq,
202
+ omega0=embedder_scaler,
203
+ embed_concat_pts=embed_concat_pts)
204
+
205
+ elif condition_choice == 'film' or condition_choice == 'mod':
206
+ self.mlp = MLPWithPositionalEncoding_Style(
207
+ 3,
208
+ 1,
209
+ num_layers,
210
+ nf=hidden_size,
211
+ extra_dim=dim_of_classes,
212
+ dropout=0,
213
+ activation=None,
214
+ n_harmonic_functions=embedder_freq,
215
+ omega0=embedder_scaler,
216
+ embed_concat_pts=embed_concat_pts,
217
+ style_choice=condition_choice)
218
+
219
+ else:
220
+ raise NotImplementedError
221
+
222
+ def load_tets(self, grid_res=None, scale=None):
223
+ if grid_res is None:
224
+ grid_res = self.grid_res
225
+ else:
226
+ self.grid_res = grid_res
227
+ if scale is None:
228
+ scale = self.grid_scale
229
+ else:
230
+ self.grid_scale = scale
231
+ tets = np.load('./data/tets/{}_tets.npz'.format(grid_res))
232
+ self.verts = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * scale # verts original scale (-0.5, 0.5)
233
+ self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda')
234
+ self.generate_edges()
235
+
236
+ def get_sdf(self, pts=None, perturb_sdf=False, total_iter=0, class_vector=None):
237
+ if self.sdf_mode == 'param':
238
+ sdf = self.sdf
239
+ else:
240
+ if pts is None:
241
+ pts = self.verts
242
+ if self.sym_prior_shape:
243
+ xs, ys, zs = pts.unbind(-1)
244
+ pts = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
245
+ feat = None
246
+ if class_vector is not None:
247
+ feat = class_vector.unsqueeze(0).repeat(pts.shape[0], 1)
248
+ sdf = self.mlp(pts, feat=feat)
249
+
250
+ if self.init_sdf is None:
251
+ pass
252
+ elif type(self.init_sdf) in [float, int]:
253
+ sdf = sdf + self.init_sdf
254
+ elif self.init_sdf == 'sphere':
255
+ init_radius = self.grid_scale * 0.25
256
+ init_sdf = init_radius - pts.norm(dim=-1, keepdim=True) # init sdf is a sphere centered at origin
257
+ sdf = sdf + init_sdf
258
+ elif self.init_sdf == 'ellipsoid':
259
+ rxy = self.grid_scale * 0.15
260
+ xs, ys, zs = pts.unbind(-1)[:3]
261
+ init_sdf = rxy - torch.stack([xs, ys, zs/2], -1).norm(dim=-1, keepdim=True) # init sdf is approximately an ellipsoid centered at origin
262
+ sdf = sdf + init_sdf
263
+ else:
264
+ raise NotImplementedError
265
+
266
+ if perturb_sdf:
267
+ sdf = sdf + torch.randn_like(sdf) * 0.1 * max(0, 1-total_iter/self.perturb_sdf_iter)
268
+ return sdf
269
+
270
+ def get_sdf_gradient(self, class_vector=None):
271
+ assert self.sdf_mode == 'mlp', "Only MLP supports gradient computation."
272
+ num_samples = 5000
273
+ sample_points = (torch.rand(num_samples, 3, device=self.verts.device) - 0.5) * self.grid_scale
274
+ mesh_verts = self.mesh_verts.detach() + (torch.rand_like(self.mesh_verts) -0.5) * 0.1 * self.grid_scale
275
+ rand_idx = torch.randperm(len(mesh_verts), device=mesh_verts.device)[:5000]
276
+ mesh_verts = mesh_verts[rand_idx]
277
+ sample_points = torch.cat([sample_points, mesh_verts], 0)
278
+ sample_points.requires_grad = True
279
+ y = self.get_sdf(pts=sample_points, perturb_sdf=False, class_vector=class_vector)
280
+ d_output = torch.ones_like(y, requires_grad=False, device=y.device)
281
+ try:
282
+ gradients = torch.autograd.grad(
283
+ outputs=[y],
284
+ inputs=sample_points,
285
+ grad_outputs=d_output,
286
+ create_graph=True,
287
+ retain_graph=True,
288
+ only_inputs=True)[0]
289
+ except RuntimeError: # For validation, we have disabled gradient calculation.
290
+ return torch.zeros_like(sample_points)
291
+ return gradients
292
+
293
+ def get_sdf_reg_loss(self, class_vector=None):
294
+ reg_loss = {"sdf_bce_reg_loss": sdf_bce_reg_loss(self.current_sdf, self.all_edges).mean()}
295
+ if self.sdf_mode == 'mlp':
296
+ reg_loss["sdf_gradient_reg_loss"] = ((self.get_sdf_gradient(class_vector=class_vector).norm(dim=-1) - 1) ** 2).mean()
297
+ reg_loss['sdf_inflate_reg_loss'] = -self.current_sdf.mean()
298
+ return reg_loss
299
+
300
+ def generate_edges(self):
301
+ with torch.no_grad():
302
+ edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype = torch.long, device = "cuda")
303
+ all_edges = self.indices[:,edges].reshape(-1,2)
304
+ all_edges_sorted = torch.sort(all_edges, dim=1)[0]
305
+ self.all_edges = torch.unique(all_edges_sorted, dim=0)
306
+
307
+ @torch.no_grad()
308
+ def getAABB(self):
309
+ return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
310
+
311
+ def getMesh(self, material=None, perturb_sdf=False, total_iter=0, jitter_grid=True, class_vector=None):
312
+ # Run DM tet to get a base mesh
313
+ v_deformed = self.verts
314
+
315
+ # if self.FLAGS.deform_grid:
316
+ # v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform)
317
+ # else:
318
+ # v_deformed = self.verts
319
+ if jitter_grid and self.jitter_grid > 0:
320
+ jitter = (torch.rand(1, device=v_deformed.device)*2-1) * self.jitter_grid * self.grid_scale
321
+ v_deformed = v_deformed + jitter
322
+
323
+ self.current_sdf = self.get_sdf(v_deformed, perturb_sdf=perturb_sdf, total_iter=total_iter, class_vector=class_vector)
324
+ verts, faces, uvs, uv_idx = self.marching_tets(v_deformed, self.current_sdf, self.indices)
325
+ self.mesh_verts = verts
326
+ return mesh.make_mesh(verts[None], faces[None], uvs[None], uv_idx[None], material)
327
+
328
+ def render(self, glctx, target, lgt, opt_material, bsdf=None):
329
+ opt_mesh = self.getMesh(opt_material)
330
+ return render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], msaa=True, background=target['background'], bsdf=bsdf)
331
+
332
+ def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration):
333
+ # ==============================================================================================
334
+ # Render optimizable object with identical conditions
335
+ # ==============================================================================================
336
+ buffers = self.render(glctx, target, lgt, opt_material)
337
+
338
+ # ==============================================================================================
339
+ # Compute loss
340
+ # ==============================================================================================
341
+ t_iter = iteration / 20000
342
+
343
+ # Image-space loss, split into a coverage component and a color component
344
+ color_ref = target['img']
345
+ img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:])
346
+ img_loss = img_loss + loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:])
347
+
348
+ # SDF regularizer
349
+ # sdf_weight = self.sdf_regularizer - (self.sdf_regularizer - 0.01) * min(1.0, 4.0 * t_iter) # Dropoff to 0.01
350
+ reg_loss = sum(self.get_sdf_reg_loss().values)
351
+
352
+ # Albedo (k_d) smoothnesss regularizer
353
+ reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, iteration / 500)
354
+
355
+ # Visibility regularizer
356
+ reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 0.001 * min(1.0, iteration / 500)
357
+
358
+ # Light white balance regularizer
359
+ reg_loss = reg_loss + lgt.regularizer() * 0.005
360
+
361
+ return img_loss, reg_loss
video3d/model.py ADDED
@@ -0,0 +1,1526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocessing.spawn import prepare
2
+ from turtle import forward
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torchvision.models as models
7
+ import nvdiffrast.torch as dr
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ import os
11
+ import os.path as osp
12
+
13
+ from video3d.render.regularizer import get_edge_length, normal_consistency
14
+ from . import networks
15
+ from .renderer import *
16
+ from .utils import misc, meters, flow_viz, arap, custom_loss
17
+ from .dataloaders import get_sequence_loader, get_image_loader
18
+ from .cub_dataloaders import get_cub_loader
19
+ from .utils.skinning_v4 import estimate_bones, skinning
20
+ import lpips
21
+ from einops import rearrange
22
+
23
+ from .geometry.dmtet import DMTetGeometry
24
+ from .geometry.dlmesh import DLMesh
25
+
26
+ from .render import renderutils as ru
27
+ from .render import material
28
+ from .render import mlptexture
29
+ from .render import util
30
+ from .render import mesh
31
+ from .render import light
32
+ from .render import render
33
+
34
+ EPS = 1e-7
35
+
36
+
37
+ def get_optimizer(model, lr=0.0001, betas=(0.9, 0.999), weight_decay=0):
38
+ return torch.optim.Adam(
39
+ filter(lambda p: p.requires_grad, model.parameters()),
40
+ lr=lr, betas=betas, weight_decay=weight_decay)
41
+
42
+
43
+ def set_requires_grad(model, requires_grad):
44
+ if model is not None:
45
+ for param in model.parameters():
46
+ param.requires_grad = requires_grad
47
+
48
+
49
+ def forward_to_matrix(vec_forward, up=[0,1,0]):
50
+ up = torch.FloatTensor(up).to(vec_forward.device)
51
+ # vec_forward = nn.functional.normalize(vec_forward, p=2, dim=-1) # x right, y up, z forward
52
+ vec_right = up.expand_as(vec_forward).cross(vec_forward, dim=-1)
53
+ vec_right = nn.functional.normalize(vec_right, p=2, dim=-1)
54
+ vec_up = vec_forward.cross(vec_right, dim=-1)
55
+ vec_up = nn.functional.normalize(vec_up, p=2, dim=-1)
56
+ rot_mat = torch.stack([vec_right, vec_up, vec_forward], -2)
57
+ return rot_mat
58
+
59
+
60
+ def sample_pose_hypothesis_from_quad_prediction(poses_raw, total_iter, batch_size, num_frames, pose_xflip_recon=False, input_image_xflip_flag=None, rot_temp_scalar=1., num_hypos=4, naive_probs_iter=2000, best_pose_start_iter=6000, random_sample=True):
61
+ rots_pred = poses_raw[..., :num_hypos*4].view(-1, num_hypos, 4)
62
+ rots_logits = rots_pred[..., 0] # Nx4
63
+ temp = 1 / np.clip(total_iter / 1000 / rot_temp_scalar, 1., 100.)
64
+
65
+ rots_probs = torch.nn.functional.softmax(-rots_logits / temp, dim=1) # N x K
66
+ # naive_probs = torch.FloatTensor([10] + [1] * (num_hypos - 1)).to(rots_logits.device)
67
+ naive_probs = torch.ones(num_hypos).to(rots_logits.device)
68
+ naive_probs = naive_probs / naive_probs.sum()
69
+ naive_probs_weight = np.clip(1 - (total_iter - naive_probs_iter) / 2000, 0, 1)
70
+ rots_probs = naive_probs.view(1, num_hypos) * naive_probs_weight + rots_probs * (1 - naive_probs_weight)
71
+
72
+ rots_pred = rots_pred[..., 1:4]
73
+ trans_pred = poses_raw[..., -3:]
74
+ best_rot_idx = torch.argmax(rots_probs, dim=1) # N
75
+ if random_sample:
76
+ # rand_rot_idx = torch.randint(0, 4, (batch_size * num_frames,), device=poses_raw.device) # N
77
+ rand_rot_idx = torch.randperm(batch_size * num_frames, device=poses_raw.device) % num_hypos # N
78
+ # rand_rot_idx = torch.randperm(batch_size, device=poses_raw.device)[:,None].repeat(1, num_frames).view(-1) % 4 # N
79
+ best_flag = (torch.randperm(batch_size * num_frames, device=poses_raw.device) / (batch_size * num_frames) < np.clip((total_iter - best_pose_start_iter)/2000, 0, 0.8)).long()
80
+ rand_flag = 1 - best_flag
81
+ # best_flag = torch.zeros_like(best_rot_idx)
82
+ rot_idx = best_rot_idx * best_flag + rand_rot_idx * (1 - best_flag)
83
+ else:
84
+ rand_flag = torch.zeros_like(best_rot_idx)
85
+ rot_idx = best_rot_idx
86
+ rot_pred = torch.gather(rots_pred, 1, rot_idx[:, None, None].expand(-1, 1, 3))[:, 0] # Nx3
87
+ pose_raw = torch.cat([rot_pred, trans_pred], -1)
88
+ rot_prob = torch.gather(rots_probs, 1, rot_idx[:, None].expand(-1, 1))[:, 0] # N
89
+ rot_logit = torch.gather(rots_logits, 1, rot_idx[:, None].expand(-1, 1))[:, 0] # N
90
+
91
+ if pose_xflip_recon:
92
+ raise NotImplementedError
93
+ rot_mat = forward_to_matrix(pose_raw[:, :3], up=[0, 1, 0])
94
+ pose = torch.cat([rot_mat.view(batch_size * num_frames, -1), pose_raw[:, 3:]], -1)
95
+ return pose_raw, pose, rot_idx, rot_prob, rot_logit, rots_probs, rand_flag
96
+
97
+
98
+ class PriorPredictor(nn.Module):
99
+ def __init__(self, cfgs):
100
+ super().__init__()
101
+ dmtet_grid = cfgs.get('dmtet_grid', 64)
102
+ grid_scale = cfgs.get('grid_scale', 5)
103
+ prior_sdf_mode = cfgs.get('prior_sdf_mode', 'mlp')
104
+ num_layers_shape = cfgs.get('num_layers_shape', 5)
105
+ hidden_size = cfgs.get('hidden_size', 64)
106
+ embedder_freq_shape = cfgs.get('embedder_freq_shape', 8)
107
+ embed_concat_pts = cfgs.get('embed_concat_pts', True)
108
+ init_sdf = cfgs.get('init_sdf', None)
109
+ jitter_grid = cfgs.get('jitter_grid', 0.)
110
+ perturb_sdf_iter = cfgs.get('perturb_sdf_iter', 10000)
111
+ sym_prior_shape = cfgs.get('sym_prior_shape', False)
112
+ self.netShape = DMTetGeometry(dmtet_grid, grid_scale, prior_sdf_mode, num_layers=num_layers_shape, hidden_size=hidden_size, embedder_freq=embedder_freq_shape, embed_concat_pts=embed_concat_pts, init_sdf=init_sdf, jitter_grid=jitter_grid, perturb_sdf_iter=perturb_sdf_iter, sym_prior_shape=sym_prior_shape)
113
+
114
+ mlp_hidden_size = cfgs.get('hidden_size', 64)
115
+ tet_bbox = self.netShape.getAABB()
116
+ self.render_dino_mode = cfgs.get('render_dino_mode', None)
117
+ num_layers_dino = cfgs.get("num_layers_dino", 5)
118
+ dino_feature_recon_dim = cfgs.get('dino_feature_recon_dim', 64)
119
+ sym_dino = cfgs.get("sym_dino", False)
120
+ dino_min = torch.zeros(dino_feature_recon_dim) + cfgs.get('dino_min', 0.)
121
+ dino_max = torch.zeros(dino_feature_recon_dim) + cfgs.get('dino_max', 1.)
122
+ min_max = torch.stack((dino_min, dino_max), dim=0)
123
+ if self.render_dino_mode is None:
124
+ pass
125
+ elif self.render_dino_mode == 'feature_mlpnv':
126
+ self.netDINO = mlptexture.MLPTexture3D(tet_bbox, channels=dino_feature_recon_dim, internal_dims=mlp_hidden_size, hidden=num_layers_dino-1, feat_dim=0, min_max=min_max, bsdf=None, perturb_normal=False, symmetrize=sym_dino)
127
+ elif self.render_dino_mode == 'feature_mlp':
128
+ embedder_scaler = 2 * np.pi / grid_scale * 0.9 # originally (-0.5*s, 0.5*s) rescale to (-pi, pi) * 0.9
129
+ embed_concat_pts = cfgs.get('embed_concat_pts', True)
130
+ self.netDINO = networks.MLPTextureSimple(
131
+ 3, # x, y, z coordinates
132
+ dino_feature_recon_dim,
133
+ num_layers_dino,
134
+ nf=mlp_hidden_size,
135
+ dropout=0,
136
+ activation="sigmoid",
137
+ min_max=min_max,
138
+ n_harmonic_functions=cfgs.get('embedder_freq_dino', 8),
139
+ omega0=embedder_scaler,
140
+ extra_dim=0,
141
+ embed_concat_pts=embed_concat_pts,
142
+ perturb_normal=False,
143
+ symmetrize=sym_dino
144
+ )
145
+ elif self.render_dino_mode == 'cluster':
146
+ num_layers_dino = cfgs.get("num_layers_dino", 5)
147
+ dino_cluster_dim = cfgs.get('dino_cluster_dim', 64)
148
+ self.netDINO = mlptexture.MLPTexture3D(tet_bbox, channels=dino_cluster_dim, internal_dims=mlp_hidden_size, hidden=num_layers_dino-1, feat_dim=0, min_max=None, bsdf=None, perturb_normal=False, symmetrize=sym_dino)
149
+ else:
150
+ raise NotImplementedError
151
+
152
+ def forward(self, perturb_sdf=False, total_iter=None, is_training=True):
153
+ prior_shape = self.netShape.getMesh(perturb_sdf=perturb_sdf, total_iter=total_iter, jitter_grid=is_training)
154
+ return prior_shape, self.netDINO
155
+
156
+
157
+ class InstancePredictor(nn.Module):
158
+ def __init__(self, cfgs, tet_bbox=None):
159
+ super().__init__()
160
+ self.cfgs = cfgs
161
+ self.grid_scale = cfgs.get('grid_scale', 5)
162
+
163
+ self.enable_encoder = cfgs.get('enable_encoder', False)
164
+ if self.enable_encoder:
165
+ encoder_latent_dim = cfgs.get('latent_dim', 256)
166
+ encoder_pretrained = cfgs.get('encoder_pretrained', False)
167
+ encoder_frozen = cfgs.get('encoder_frozen', False)
168
+ encoder_arch = cfgs.get('encoder_arch', 'simple')
169
+ in_image_size = cfgs.get('in_image_size', 256)
170
+ self.dino_feature_input = cfgs.get('dino_feature_input', False)
171
+ dino_feature_dim = cfgs.get('dino_feature_dim', 64)
172
+ if encoder_arch == 'simple':
173
+ if self.dino_feature_input:
174
+ self.netEncoder = networks.EncoderWithDINO(cin_rgb=3, cin_dino=dino_feature_dim, cout=encoder_latent_dim, in_size=in_image_size, zdim=None, nf=64, activation=None)
175
+ else:
176
+ self.netEncoder = networks.Encoder(cin=3, cout=encoder_latent_dim, in_size=in_image_size, zdim=None, nf=64, activation=None)
177
+ elif encoder_arch == 'vgg':
178
+ self.netEncoder = networks.VGGEncoder(cout=encoder_latent_dim, pretrained=encoder_pretrained)
179
+ elif encoder_arch == 'resnet':
180
+ self.netEncoder = networks.ResnetEncoder(cout=encoder_latent_dim, pretrained=encoder_pretrained)
181
+ elif encoder_arch == 'vit':
182
+ which_vit = cfgs.get('which_vit', 'dino_vits8')
183
+ vit_final_layer_type = cfgs.get('vit_final_layer_type', 'conv')
184
+ self.netEncoder = networks.ViTEncoder(cout=encoder_latent_dim, which_vit=which_vit, pretrained=encoder_pretrained, frozen=encoder_frozen, in_size=in_image_size, final_layer_type=vit_final_layer_type)
185
+ else:
186
+ raise NotImplementedError
187
+ else:
188
+ encoder_latent_dim = 0
189
+
190
+ mlp_hidden_size = cfgs.get('hidden_size', 64)
191
+
192
+ bsdf = cfgs.get("bsdf", 'diffuse')
193
+ num_layers_tex = cfgs.get("num_layers_tex", 5)
194
+ feat_dim = cfgs.get("latent_dim", 64) if self.enable_encoder else 0
195
+ perturb_normal = cfgs.get("perturb_normal", False)
196
+ sym_texture = cfgs.get("sym_texture", False)
197
+ kd_min = torch.FloatTensor(cfgs.get('kd_min', [0., 0., 0., 0.]))
198
+ kd_max = torch.FloatTensor(cfgs.get('kd_max', [1., 1., 1., 1.]))
199
+ ks_min = torch.FloatTensor(cfgs.get('ks_min', [0., 0., 0.]))
200
+ ks_max = torch.FloatTensor(cfgs.get('ks_max', [0., 0., 0.]))
201
+ nrm_min = torch.FloatTensor(cfgs.get('nrm_min', [-1., -1., 0.]))
202
+ nrm_max = torch.FloatTensor(cfgs.get('nrm_max', [1., 1., 1.]))
203
+ mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
204
+ mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
205
+ min_max = torch.stack((mlp_min, mlp_max), dim=0)
206
+ out_chn = 9
207
+ # TODO: if the tet verts are deforming, we need to recompute tet_bbox
208
+ texture_mode = cfgs.get("texture_mode", 'mlp')
209
+ if texture_mode == 'mlpnv':
210
+ self.netTexture = mlptexture.MLPTexture3D(tet_bbox, channels=out_chn, internal_dims=mlp_hidden_size, hidden=num_layers_tex-1, feat_dim=feat_dim, min_max=min_max, bsdf=bsdf, perturb_normal=perturb_normal, symmetrize=sym_texture)
211
+ elif texture_mode == 'mlp':
212
+ embedder_scaler = 2 * np.pi / self.grid_scale * 0.9 # originally (-0.5*s, 0.5*s) rescale to (-pi, pi) * 0.9
213
+ embed_concat_pts = cfgs.get('embed_concat_pts', True)
214
+ self.netTexture = networks.MLPTextureSimple(
215
+ 3, # x, y, z coordinates
216
+ out_chn,
217
+ num_layers_tex,
218
+ nf=mlp_hidden_size,
219
+ dropout=0,
220
+ activation="sigmoid",
221
+ min_max=min_max,
222
+ n_harmonic_functions=cfgs.get('embedder_freq_tex', 10),
223
+ omega0=embedder_scaler,
224
+ extra_dim=feat_dim,
225
+ embed_concat_pts=embed_concat_pts,
226
+ perturb_normal=perturb_normal,
227
+ symmetrize=sym_texture
228
+ )
229
+
230
+ self.rot_rep = cfgs.get('rot_rep', 'euler_angle')
231
+ self.enable_pose = cfgs.get('enable_pose', False)
232
+ if self.enable_pose:
233
+ cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.)
234
+ fov = cfgs.get('crop_fov_approx', 25)
235
+ half_range = np.tan(fov /2 /180 * np.pi) * cam_pos_z_offset # 2.22
236
+ self.max_trans_xy_range = half_range * cfgs.get('max_trans_xy_range_ratio', 1.)
237
+ self.max_trans_z_range = half_range * cfgs.get('max_trans_z_range_ratio', 1.)
238
+ self.lookat_init = cfgs.get('lookat_init', None)
239
+ self.lookat_zeroy = cfgs.get('lookat_zeroy', False)
240
+ self.rot_temp_scalar = cfgs.get('rot_temp_scalar', 1.)
241
+ self.naive_probs_iter = cfgs.get('naive_probs_iter', 2000)
242
+ self.best_pose_start_iter = cfgs.get('best_pose_start_iter', 6000)
243
+
244
+ if self.rot_rep == 'euler_angle':
245
+ pose_cout = 6
246
+ elif self.rot_rep == 'quaternion':
247
+ pose_cout = 7
248
+ elif self.rot_rep == 'lookat':
249
+ pose_cout = 6
250
+ elif self.rot_rep == 'quadlookat':
251
+ self.num_pose_hypos = 4
252
+ pose_cout = (3 + 1) * self.num_pose_hypos + 3 # 4 forward vectors for 4 quadrants, 4 quadrant classification logits, 3 for translation
253
+ self.orthant_signs = torch.FloatTensor([[1,1,1], [-1,1,1], [-1,1,-1], [1,1,-1]])
254
+ elif self.rot_rep == 'octlookat':
255
+ self.num_pose_hypos = 8
256
+ pose_cout = (3 + 1) * self.num_pose_hypos + 3 # 4 forward vectors for 8 octants, 8 octant classification logits, 3 for translation
257
+ self.orthant_signs = torch.stack(torch.meshgrid([torch.arange(1, -2, -2)] *3), -1).view(-1, 3) # 8x3
258
+ else:
259
+ raise NotImplementedError
260
+
261
+ self.pose_arch = cfgs.get('pose_arch', 'mlp')
262
+ if self.pose_arch == 'mlp':
263
+ num_layers_pose = cfgs.get('num_layers_pose', 5)
264
+ self.netPose = networks.MLP(
265
+ encoder_latent_dim,
266
+ pose_cout,
267
+ num_layers_pose,
268
+ nf=mlp_hidden_size,
269
+ dropout=0,
270
+ activation=None
271
+ )
272
+ elif self.pose_arch == 'encoder':
273
+ if self.dino_feature_input:
274
+ dino_feature_dim = cfgs.get('dino_feature_dim', 64)
275
+ self.netPose = networks.EncoderWithDINO(cin_rgb=3, cin_dino=dino_feature_dim, cout=pose_cout, in_size=in_image_size, zdim=None, nf=64, activation=None)
276
+ else:
277
+ self.netPose = networks.Encoder(cin=3, cout=pose_cout, in_size=in_image_size, zdim=None, nf=64, activation=None)
278
+ elif self.pose_arch in ['encoder_dino_patch_out', 'encoder_dino_patch_key']:
279
+ if which_vit == 'dino_vits8':
280
+ dino_feat_dim = 384
281
+ elif which_vit == 'dinov2_vits14':
282
+ dino_feat_dim = 384
283
+ elif which_vit == 'dino_vitb8':
284
+ dino_feat_dim = 768
285
+ self.netPose = networks.Encoder32(cin=dino_feat_dim, cout=pose_cout, nf=256, activation=None)
286
+ elif self.pose_arch == 'vit':
287
+ encoder_pretrained = cfgs.get('encoder_pretrained', False)
288
+ encoder_frozen = cfgs.get('encoder_frozen', False)
289
+ which_vit = cfgs.get('which_vit', 'dino_vits8')
290
+ vit_final_layer_type = cfgs.get('vit_final_layer_type', 'conv')
291
+ self.netPose = networks.ViTEncoder(cout=encoder_latent_dim, which_vit=which_vit, pretrained=encoder_pretrained, frozen=encoder_frozen, in_size=in_image_size, final_layer_type=vit_final_layer_type)
292
+ else:
293
+ raise NotImplementedError
294
+
295
+ self.enable_deform = cfgs.get('enable_deform', False)
296
+ if self.enable_deform:
297
+ embedder_scaler = 2 * np.pi / self.grid_scale * 0.9 # originally (-0.5*s, 0.5*s) rescale to (-pi, pi) * 0.9
298
+ embed_concat_pts = cfgs.get('embed_concat_pts', True)
299
+ num_layers_deform = cfgs.get('num_layers_deform', 5)
300
+ self.deform_epochs = np.arange(*cfgs.get('deform_epochs', [0, 0]))
301
+ sym_deform = cfgs.get("sym_deform", False)
302
+ self.netDeform = networks.MLPWithPositionalEncoding(
303
+ 3, # x, y, z coordinates
304
+ 3, # dx, dy, dz deformation
305
+ num_layers_deform,
306
+ nf=mlp_hidden_size,
307
+ dropout=0,
308
+ activation=None,
309
+ n_harmonic_functions=cfgs.get('embedder_freq_deform', 10),
310
+ omega0=embedder_scaler,
311
+ extra_dim=encoder_latent_dim,
312
+ embed_concat_pts=embed_concat_pts,
313
+ symmetrize=sym_deform
314
+ )
315
+
316
+ self.enable_articulation = cfgs.get('enable_articulation', False)
317
+ if self.enable_articulation:
318
+ self.num_body_bones = cfgs.get('num_body_bones', 4)
319
+ self.articulation_multiplier = cfgs.get('articulation_multiplier', 1)
320
+ self.static_root_bones = cfgs.get('static_root_bones', False)
321
+ self.skinning_temperature = cfgs.get('skinning_temperature', 1)
322
+ self.articulation_epochs = np.arange(*cfgs.get('articulation_epochs', [0, 0]))
323
+ self.num_legs = cfgs.get('num_legs', 0)
324
+ self.num_leg_bones = cfgs.get('num_leg_bones', 0)
325
+ self.body_bones_type = cfgs.get('body_bones_type', 'z_minmax')
326
+ self.perturb_articulation_epochs = np.arange(*cfgs.get('perturb_articulation_epochs', [0, 0]))
327
+ self.num_bones = self.num_body_bones + self.num_legs * self.num_leg_bones
328
+ self.constrain_legs = cfgs.get('constrain_legs', False)
329
+ self.attach_legs_to_body_epochs = np.arange(*cfgs.get('attach_legs_to_body_epochs', [0, 0]))
330
+ self.max_arti_angle = cfgs.get('max_arti_angle', 60)
331
+
332
+ num_layers_arti = cfgs.get('num_layers_arti', 5)
333
+ which_vit = cfgs.get('which_vit', 'dino_vits8')
334
+ if which_vit == 'dino_vits8':
335
+ dino_feat_dim = 384
336
+ elif which_vit == 'dino_vitb8':
337
+ dino_feat_dim = 768
338
+ self.articulation_arch = cfgs.get('articulation_arch', 'mlp')
339
+ self.articulation_feature_mode = cfgs.get('articulation_feature_mode', 'sample')
340
+ embedder_freq_arti = cfgs.get('embedder_freq_arti', 8)
341
+ if self.articulation_feature_mode == 'global':
342
+ feat_dim = encoder_latent_dim
343
+ elif self.articulation_feature_mode == 'sample':
344
+ feat_dim = dino_feat_dim
345
+ elif self.articulation_feature_mode == 'sample+global':
346
+ feat_dim = encoder_latent_dim + dino_feat_dim
347
+ if self.articulation_feature_mode == 'attention':
348
+ arti_feat_attn_zdim = cfgs.get('arti_feat_attn_zdim', 128)
349
+ pos_dim = 1 + 2 + 3*2
350
+ self.netFeatureAttn = networks.FeatureAttention(which_vit, pos_dim, embedder_freq_arti, arti_feat_attn_zdim, img_size=in_image_size)
351
+ embedder_scaler = np.pi * 0.9 # originally (-1, 1) rescale to (-pi, pi) * 0.9
352
+ self.netArticulation = networks.ArticulationNetwork(self.articulation_arch, feat_dim, 1+2+3*2, num_layers_arti, mlp_hidden_size, n_harmonic_functions=embedder_freq_arti, omega0=embedder_scaler)
353
+ self.kinematic_tree_epoch = -1
354
+
355
+ self.enable_lighting = cfgs.get('enable_lighting', False)
356
+ if self.enable_lighting:
357
+ num_layers_light = cfgs.get('num_layers_light', 5)
358
+ amb_diff_min = torch.FloatTensor(cfgs.get('amb_diff_min', [0., 0.]))
359
+ amb_diff_max = torch.FloatTensor(cfgs.get('amb_diff_max', [1., 1.]))
360
+ intensity_min_max = torch.stack((amb_diff_min, amb_diff_max), dim=0)
361
+ self.netLight = light.DirectionalLight(encoder_latent_dim, num_layers_light, mlp_hidden_size, intensity_min_max=intensity_min_max)
362
+
363
+ self.cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.)
364
+ self.crop_fov_approx = cfgs.get("crop_fov_approx", 25)
365
+
366
+ def forward_encoder(self, images, dino_features=None):
367
+ images_in = images.view(-1, *images.shape[2:]) * 2 - 1 # rescale to (-1, 1)
368
+ patch_out = patch_key = None
369
+ if self.dino_feature_input and self.cfgs.get('encoder_arch', 'simple') != 'vit':
370
+ dino_features_in = dino_features.view(-1, *dino_features.shape[2:]) * 2 - 1 # rescale to (-1, 1)
371
+ feat_out = self.netEncoder(images_in, dino_features_in) # Shape: (B, latent_dim)
372
+ elif self.cfgs.get('encoder_arch', 'simple') == 'vit':
373
+ feat_out, feat_key, patch_out, patch_key = self.netEncoder(images_in, return_patches=True)
374
+ else:
375
+ feat_out = self.netEncoder(images_in) # Shape: (B, latent_dim)
376
+ return feat_out, feat_key, patch_out, patch_key
377
+
378
+ def forward_pose(self, images, feat, patch_out, patch_key, dino_features):
379
+ if self.pose_arch == 'mlp':
380
+ pose = self.netPose(feat)
381
+ elif self.pose_arch == 'encoder':
382
+ images_in = images.view(-1, *images.shape[2:]) * 2 - 1 # rescale to (-1, 1)
383
+ if self.dino_feature_input:
384
+ dino_features_in = dino_features.view(-1, *dino_features.shape[2:]) * 2 - 1 # rescale to (-1, 1)
385
+ pose = self.netPose(images_in, dino_features_in) # Shape: (B, latent_dim)
386
+ else:
387
+ pose = self.netPose(images_in) # Shape: (B, latent_dim)
388
+ elif self.pose_arch == 'vit':
389
+ images_in = images.view(-1, *images.shape[2:]) * 2 - 1 # rescale to (-1, 1)
390
+ pose = self.netPose(images_in)
391
+ elif self.pose_arch == 'encoder_dino_patch_out':
392
+ pose = self.netPose(patch_out) # Shape: (B, latent_dim)
393
+ elif self.pose_arch == 'encoder_dino_patch_key':
394
+ pose = self.netPose(patch_key) # Shape: (B, latent_dim)
395
+ else:
396
+ raise NotImplementedError
397
+ trans_pred = pose[...,-3:].tanh() * torch.FloatTensor([self.max_trans_xy_range, self.max_trans_xy_range, self.max_trans_z_range]).to(pose.device)
398
+ if self.rot_rep == 'euler_angle':
399
+ multiplier = 1.
400
+ if self.gradually_expand_yaw:
401
+ # multiplier += (min(iteration, 20000) // 500) * 0.25
402
+ multiplier *= 1.2 ** (min(iteration, 20000) // 500) # 1.125^40 = 111.200
403
+ rot_pred = torch.cat([pose[...,:1], pose[...,1:2]*multiplier, pose[...,2:3]], -1).tanh()
404
+ rot_pred = rot_pred * torch.FloatTensor([self.max_rot_x_range, self.max_rot_y_range, self.max_rot_z_range]).to(pose.device) /180 * np.pi
405
+
406
+ elif self.rot_rep == 'quaternion':
407
+ quat_init = torch.FloatTensor([0.01,0,0,0]).to(pose.device)
408
+ rot_pred = pose[...,:4] + quat_init
409
+ rot_pred = nn.functional.normalize(rot_pred, p=2, dim=-1)
410
+ # rot_pred = torch.cat([rot_pred[...,:1].abs(), rot_pred[...,1:]], -1) # make real part non-negative
411
+ rot_pred = rot_pred * rot_pred[...,:1].sign() # make real part non-negative
412
+
413
+ elif self.rot_rep == 'lookat':
414
+ vec_forward_raw = pose[...,:3]
415
+ if self.lookat_init is not None:
416
+ vec_forward_raw = vec_forward_raw + torch.FloatTensor(self.lookat_init).to(pose.device)
417
+ if self.lookat_zeroy:
418
+ vec_forward_raw = vec_forward_raw * torch.FloatTensor([1,0,1]).to(pose.device)
419
+ vec_forward_raw = nn.functional.normalize(vec_forward_raw, p=2, dim=-1) # x right, y up, z forward
420
+ rot_pred = vec_forward_raw
421
+
422
+ elif self.rot_rep in ['quadlookat', 'octlookat']:
423
+ rots_pred = pose[..., :self.num_pose_hypos*4].view(-1, self.num_pose_hypos, 4) # (B, T, K, 4)
424
+ rots_logits = rots_pred[..., :1]
425
+ vec_forward_raw = rots_pred[..., 1:4]
426
+ xs, ys, zs = vec_forward_raw.unbind(-1)
427
+ margin = 0.
428
+ xs = nn.functional.softplus(xs, beta=np.log(2)/(0.5+margin)) - margin # initialize to 0.5
429
+ if self.rot_rep == 'octlookat':
430
+ ys = nn.functional.softplus(ys, beta=np.log(2)/(0.5+margin)) - margin # initialize to 0.5
431
+ if self.lookat_zeroy:
432
+ ys = ys * 0
433
+ zs = nn.functional.softplus(zs, beta=2*np.log(2)) # initialize to 0.5
434
+ vec_forward_raw = torch.stack([xs, ys, zs], -1)
435
+ vec_forward_raw = vec_forward_raw * self.orthant_signs.to(pose.device)
436
+ vec_forward_raw = nn.functional.normalize(vec_forward_raw, p=2, dim=-1) # x right, y up, z forward
437
+ rot_pred = torch.cat([rots_logits, vec_forward_raw], -1).view(-1, self.num_pose_hypos*4)
438
+
439
+ else:
440
+ raise NotImplementedError
441
+
442
+ pose = torch.cat([rot_pred, trans_pred], -1)
443
+ return pose
444
+
445
+ def forward_deformation(self, shape, feat=None):
446
+ original_verts = shape.v_pos
447
+ num_verts = original_verts.shape[1]
448
+ if feat is not None:
449
+ deform_feat = feat[:, None, :].repeat(1, num_verts, 1) # Shape: (B, num_verts, latent_dim)
450
+ original_verts = original_verts.repeat(len(feat),1,1)
451
+ deformation = self.netDeform(original_verts, deform_feat) * 0.1 # Shape: (B, num_verts, 3)
452
+ shape = shape.deform(deformation)
453
+ return shape, deformation
454
+
455
+ def forward_articulation(self, shape, feat, patch_feat, mvp, w2c, batch_size, num_frames, epoch):
456
+ """
457
+ Forward propagation of articulation. For each bone, the network takes: 1) the 3D location of the bone; 2) the feature of the patch which
458
+ the bone is projected to; and 3) an encoding of the bone's index to predict the bone's rotation (represented by an Euler angle).
459
+
460
+ Args:
461
+ shape: a Mesh object, whose v_pos has batch size BxF or 1.
462
+ feat: the feature of the patches. Shape: (BxF, feat_dim, num_patches_per_axis, num_patches_per_axis)
463
+ mvp: the model-view-projection matrix. Shape: (BxF, 4, 4)
464
+
465
+ Returns:
466
+ shape: a Mesh object, whose v_pos has batch size BxF (collapsed).
467
+ articulation_angles: the predicted bone rotations. Shape: (B, F, num_bones, 3)
468
+ aux: a dictionary containing auxiliary information.
469
+ """
470
+ verts = shape.v_pos
471
+ if len(verts) == 1:
472
+ verts = verts[None]
473
+ else:
474
+ verts = verts.view(batch_size, num_frames, *verts.shape[1:])
475
+
476
+ if self.kinematic_tree_epoch != epoch:
477
+ # if (epoch == self.articulation_epochs[0]) and (self.kinematic_tree_epoch != epoch):
478
+ # if (epoch in [self.articulation_epochs[0], self.articulation_epochs[0]+2, self.articulation_epochs[0]+4]) and (self.kinematic_tree_epoch != epoch):
479
+ attach_legs_to_body = epoch in self.attach_legs_to_body_epochs
480
+ bones, self.kinematic_tree, self.bone_aux = estimate_bones(verts.detach(), self.num_body_bones, n_legs=self.num_legs, n_leg_bones=self.num_leg_bones, body_bones_type=self.body_bones_type, compute_kinematic_chain=True, attach_legs_to_body=attach_legs_to_body)
481
+ self.kinematic_tree_epoch = epoch
482
+ else:
483
+ bones = estimate_bones(verts.detach(), self.num_body_bones, n_legs=self.num_legs, n_leg_bones=self.num_leg_bones, body_bones_type=self.body_bones_type, compute_kinematic_chain=False, aux=self.bone_aux)
484
+
485
+ bones_pos = bones # Shape: (B, F, K, 2, 3)
486
+ if batch_size > bones_pos.shape[0] or num_frames > bones_pos.shape[1]:
487
+ assert bones_pos.shape[0] == 1 and bones_pos.shape[1] == 1, "If there is a mismatch, then there must be only one canonical mesh."
488
+ bones_pos = bones_pos.repeat(batch_size, num_frames, 1, 1, 1)
489
+ num_bones = bones_pos.shape[2]
490
+ bones_pos = bones_pos.view(batch_size*num_frames, num_bones, 2, 3) # NxKx2x3
491
+ bones_mid_pos = bones_pos.mean(2) # NxKx3
492
+ bones_idx = torch.arange(num_bones).to(bones_pos.device)
493
+
494
+ bones_mid_pos_world4 = torch.cat([bones_mid_pos, torch.ones_like(bones_mid_pos[..., :1])], -1) # NxKx4
495
+ bones_mid_pos_clip4 = bones_mid_pos_world4 @ mvp.transpose(-1, -2)
496
+ bones_mid_pos_uv = bones_mid_pos_clip4[..., :2] / bones_mid_pos_clip4[..., 3:4]
497
+ bones_mid_pos_uv = bones_mid_pos_uv.detach()
498
+
499
+ bones_pos_world4 = torch.cat([bones_pos, torch.ones_like(bones_pos[..., :1])], -1) # NxKx2x4
500
+ bones_pos_cam4 = bones_pos_world4 @ w2c[:,None].transpose(-1, -2)
501
+ bones_pos_cam3 = bones_pos_cam4[..., :3] / bones_pos_cam4[..., 3:4]
502
+ bones_pos_cam3 = bones_pos_cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(bones_pos_cam3.device).view(1, 1, 1, 3)
503
+ bones_pos_in = bones_pos_cam3.view(batch_size*num_frames, num_bones, 2*3) / self.grid_scale * 2 # (-1, 1), NxKx(2*3)
504
+
505
+ bones_idx_in = ((bones_idx[None, :, None] + 0.5) / num_bones * 2 - 1).repeat(batch_size * num_frames, 1, 1) # (-1, 1)
506
+ bones_pos_in = torch.cat([bones_mid_pos_uv, bones_pos_in, bones_idx_in], -1).detach()
507
+
508
+ if self.articulation_feature_mode == 'global':
509
+ bones_patch_features = feat[:, None].repeat(1, num_bones, 1) # (BxF, K, feat_dim)
510
+ elif self.articulation_feature_mode == 'sample':
511
+ bones_patch_features = F.grid_sample(patch_feat, bones_mid_pos_uv.view(batch_size * num_frames, 1, -1, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # (BxF, K, feat_dim)
512
+ elif self.articulation_feature_mode == 'sample+global':
513
+ bones_patch_features = F.grid_sample(patch_feat, bones_mid_pos_uv.view(batch_size * num_frames, 1, -1, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # (BxF, K, feat_dim)
514
+ bones_patch_features = torch.cat([feat[:, None].repeat(1, num_bones, 1), bones_patch_features], -1)
515
+ elif self.articulation_feature_mode == 'attention':
516
+ bones_patch_features = self.netFeatureAttn(bones_pos_in, patch_feat)
517
+ else:
518
+ raise NotImplementedError
519
+
520
+ articulation_angles = self.netArticulation(bones_patch_features, bones_pos_in).view(batch_size, num_frames, num_bones, 3) * self.articulation_multiplier
521
+
522
+ if self.static_root_bones:
523
+ root_bones = [self.num_body_bones // 2 - 1, self.num_body_bones - 1]
524
+ tmp_mask = torch.ones_like(articulation_angles)
525
+ tmp_mask[:, :, root_bones] = 0
526
+ articulation_angles = articulation_angles * tmp_mask
527
+
528
+ articulation_angles = articulation_angles.tanh()
529
+
530
+ if self.constrain_legs:
531
+ leg_bones_posx = [self.num_body_bones + i for i in range(self.num_leg_bones * self.num_legs // 2)]
532
+ leg_bones_negx = [self.num_body_bones + self.num_leg_bones * self.num_legs // 2 + i for i in range(self.num_leg_bones * self.num_legs // 2)]
533
+
534
+ tmp_mask = torch.zeros_like(articulation_angles)
535
+ tmp_mask[:, :, leg_bones_posx + leg_bones_negx, 2] = 1
536
+ articulation_angles = tmp_mask * (articulation_angles * 0.3) + (1 - tmp_mask) * articulation_angles # no twist
537
+
538
+ tmp_mask = torch.zeros_like(articulation_angles)
539
+ tmp_mask[:, :, leg_bones_posx + leg_bones_negx, 1] = 1
540
+ articulation_angles = tmp_mask * (articulation_angles * 0.3) + (1 - tmp_mask) * articulation_angles # (-0.4, 0.4), limit side bending
541
+
542
+ if epoch in self.perturb_articulation_epochs:
543
+ articulation_angles = articulation_angles + torch.randn_like(articulation_angles) * 0.1
544
+ articulation_angles = articulation_angles * self.max_arti_angle / 180 * np.pi
545
+
546
+ verts_articulated, aux = skinning(verts, bones, self.kinematic_tree, articulation_angles,
547
+ output_posed_bones=True, temperature=self.skinning_temperature)
548
+ verts_articulated = verts_articulated.view(batch_size*num_frames, *verts_articulated.shape[2:])
549
+ v_tex = shape.v_tex
550
+ if len(v_tex) != len(verts_articulated):
551
+ v_tex = v_tex.repeat(len(verts_articulated), 1, 1)
552
+ shape = mesh.make_mesh(
553
+ verts_articulated,
554
+ shape.t_pos_idx,
555
+ v_tex,
556
+ shape.t_tex_idx,
557
+ shape.material)
558
+ return shape, articulation_angles, aux
559
+
560
+ def get_camera_extrinsics_from_pose(self, pose, znear=0.1, zfar=1000.):
561
+ N = len(pose)
562
+ cam_pos_offset = torch.FloatTensor([0, 0, -self.cam_pos_z_offset]).to(pose.device)
563
+ pose_R = pose[:, :9].view(N, 3, 3).transpose(2, 1)
564
+ pose_T = pose[:, -3:] + cam_pos_offset[None, None, :]
565
+ pose_T = pose_T.view(N, 3, 1)
566
+ pose_RT = torch.cat([pose_R, pose_T], axis=2) # Nx3x4
567
+ w2c = torch.cat([pose_RT, torch.FloatTensor([0, 0, 0, 1]).repeat(N, 1, 1).to(pose.device)], axis=1) # Nx4x4
568
+ # We assume the images are perfect square.
569
+ proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, znear, zfar)[None].to(pose.device)
570
+ mvp = torch.matmul(proj, w2c)
571
+ campos = -torch.matmul(pose_R.transpose(2, 1), pose_T).view(N, 3)
572
+ return mvp, w2c, campos
573
+
574
+ def forward(self, images=None, prior_shape=None, epoch=None, dino_features=None, dino_clusters=None, total_iter=None, is_training=True):
575
+ batch_size, num_frames = images.shape[:2]
576
+ if self.enable_encoder:
577
+ feat_out, feat_key, patch_out, patch_key = self.forward_encoder(images, dino_features)
578
+ else:
579
+ feat_out = feat_key = patch_out = patch_key = None
580
+ shape = prior_shape
581
+ texture = self.netTexture
582
+
583
+ multi_hypothesis_aux = {}
584
+ if self.enable_pose:
585
+ poses_raw = self.forward_pose(images, feat_out, patch_out, patch_key, dino_features)
586
+ pose_raw, pose, rot_idx, rot_prob, rot_logit, rots_probs, rand_pose_flag = sample_pose_hypothesis_from_quad_prediction(poses_raw, total_iter, batch_size, num_frames, rot_temp_scalar=self.rot_temp_scalar, num_hypos=self.num_pose_hypos, naive_probs_iter=self.naive_probs_iter, best_pose_start_iter=self.best_pose_start_iter, random_sample=is_training)
587
+ multi_hypothesis_aux['rot_idx'] = rot_idx
588
+ multi_hypothesis_aux['rot_prob'] = rot_prob
589
+ multi_hypothesis_aux['rot_logit'] = rot_logit
590
+ multi_hypothesis_aux['rots_probs'] = rots_probs
591
+ multi_hypothesis_aux['rand_pose_flag'] = rand_pose_flag
592
+ else:
593
+ raise NotImplementedError
594
+ mvp, w2c, campos = self.get_camera_extrinsics_from_pose(pose)
595
+
596
+ deformation = None
597
+ if self.enable_deform and epoch in self.deform_epochs:
598
+ shape, deformation = self.forward_deformation(shape, feat_key)
599
+
600
+ arti_params, articulation_aux = None, {}
601
+ if self.enable_articulation and epoch in self.articulation_epochs:
602
+ shape, arti_params, articulation_aux = self.forward_articulation(shape, feat_key, patch_key, mvp, w2c, batch_size, num_frames, epoch)
603
+
604
+ if self.enable_lighting:
605
+ light = self.netLight
606
+ else:
607
+ light = None
608
+
609
+ aux = articulation_aux
610
+ aux.update(multi_hypothesis_aux)
611
+
612
+ return shape, pose_raw, pose, mvp, w2c, campos, texture, feat_out, deformation, arti_params, light, aux
613
+
614
+
615
+ class Unsup3D:
616
+ def __init__(self, cfgs):
617
+ self.cfgs = cfgs
618
+ self.device = cfgs.get('device', 'cpu')
619
+ self.in_image_size = cfgs.get('in_image_size', 128)
620
+ self.out_image_size = cfgs.get('out_image_size', 128)
621
+
622
+ self.num_epochs = cfgs.get('num_epochs', 10)
623
+ self.lr = cfgs.get('lr', 1e-4)
624
+ self.use_scheduler = cfgs.get('use_scheduler', False)
625
+ if self.use_scheduler:
626
+ scheduler_milestone = cfgs.get('scheduler_milestone', [1,2,3,4,5])
627
+ scheduler_gamma = cfgs.get('scheduler_gamma', 0.5)
628
+ self.make_scheduler = lambda optim: torch.optim.lr_scheduler.MultiStepLR(optim, milestones=scheduler_milestone, gamma=scheduler_gamma)
629
+
630
+ self.cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.)
631
+ self.full_size_h = cfgs.get('full_size_h', 1080)
632
+ self.full_size_w = cfgs.get('full_size_w', 1920)
633
+ # self.fov_w = cfgs.get('fov_w', 60)
634
+ # self.fov_h = np.arctan(np.tan(self.fov_w /2 /180*np.pi) / self.full_size_w * self.full_size_h) *2 /np.pi*180 # 36
635
+ self.crop_fov_approx = cfgs.get("crop_fov_approx", 25)
636
+ self.mesh_regularization_mode = cfgs.get('mesh_regularization_mode', 'seq')
637
+
638
+ self.enable_prior = cfgs.get('enable_prior', False)
639
+ if self.enable_prior:
640
+ self.netPrior = PriorPredictor(self.cfgs)
641
+ self.prior_lr = cfgs.get('prior_lr', self.lr)
642
+ self.prior_weight_decay = cfgs.get('prior_weight_decay', 0.)
643
+ self.prior_only_epochs = cfgs.get('prior_only_epochs', 0)
644
+ self.netInstance = InstancePredictor(self.cfgs, tet_bbox=self.netPrior.netShape.getAABB())
645
+ self.perturb_sdf = cfgs.get('perturb_sdf', False)
646
+ self.blur_mask = cfgs.get('blur_mask', False)
647
+ self.blur_mask_iter = cfgs.get('blur_mask_iter', 1)
648
+
649
+ self.seqshape_epochs = np.arange(*cfgs.get('seqshape_epochs', [0, self.num_epochs]))
650
+ self.avg_texture_epochs = np.arange(*cfgs.get('avg_texture_epochs', [0, 0]))
651
+ self.swap_texture_epochs = np.arange(*cfgs.get('swap_texture_epochs', [0, 0]))
652
+ self.swap_priorshape_epochs = np.arange(*cfgs.get('swap_priorshape_epochs', [0, 0]))
653
+ self.avg_seqshape_epochs = np.arange(*cfgs.get('avg_seqshape_epochs', [0, 0]))
654
+ self.swap_seqshape_epochs = np.arange(*cfgs.get('swap_seqshape_epochs', [0, 0]))
655
+ self.pose_epochs = np.arange(*cfgs.get('pose_epochs', [0, 0]))
656
+ self.pose_iters = cfgs.get('pose_iters', 0)
657
+ self.deform_type = cfgs.get('deform_type', None)
658
+ self.mesh_reg_decay_epoch = cfgs.get('mesh_reg_decay_epoch', 0)
659
+ self.sdf_reg_decay_start_iter = cfgs.get('sdf_reg_decay_start_iter', 0)
660
+ self.mesh_reg_decay_rate = cfgs.get('mesh_reg_decay_rate', 1)
661
+ self.texture_epochs = np.arange(*cfgs.get('texture_epochs', [0, self.num_epochs]))
662
+ self.zflip_epochs = np.arange(*cfgs.get('zflip_epochs', [0, self.num_epochs]))
663
+ self.lookat_zflip_loss_epochs = np.arange(*cfgs.get('lookat_zflip_loss_epochs', [0, self.num_epochs]))
664
+ self.lookat_zflip_no_other_losses = cfgs.get('lookat_zflip_no_other_losses', False)
665
+ self.flow_loss_epochs = np.arange(*cfgs.get('flow_loss_epochs', [0, self.num_epochs]))
666
+ self.sdf_inflate_reg_loss_epochs = np.arange(*cfgs.get('sdf_inflate_reg_loss_epochs', [0, self.num_epochs]))
667
+ self.arti_reg_loss_epochs = np.arange(*cfgs.get('arti_reg_loss_epochs', [0, self.num_epochs]))
668
+ self.background_mode = cfgs.get('background_mode', 'background')
669
+ self.shape_prior_type = cfgs.get('shape_prior_type', 'deform')
670
+ self.backward_prior = cfgs.get('backward_prior', True)
671
+ self.resume_prior_optim = cfgs.get('resume_prior_optim', True)
672
+ self.dmtet_grid_smaller_epoch = cfgs.get('dmtet_grid_smaller_epoch', 0)
673
+ self.dmtet_grid_smaller = cfgs.get('dmtet_grid_smaller', 128)
674
+ self.dmtet_grid = cfgs.get('dmtet_grid', 256)
675
+ self.pose_xflip_recon_epochs = np.arange(*cfgs.get('pose_xflip_recon_epochs', [0, 0]))
676
+ self.rot_rand_quad_epochs = np.arange(*cfgs.get('rot_rand_quad_epochs', [0, 0]))
677
+ self.rot_all_quad_epochs = np.arange(*cfgs.get('rot_all_quad_epochs', [0, 0]))
678
+
679
+ ## perceptual loss
680
+ if cfgs.get('perceptual_loss_weight', 0.) > 0:
681
+ self.perceptual_loss_use_lin = cfgs.get('perceptual_loss_use_lin', True)
682
+ self.perceptual_loss = lpips.LPIPS(net='vgg', lpips=self.perceptual_loss_use_lin)
683
+
684
+ self.glctx = dr.RasterizeGLContext()
685
+ self.render_flow = self.cfgs.get('flow_loss_weight', 0.) > 0.
686
+ self.extra_renders = cfgs.get('extra_renders', [])
687
+ self.renderer_spp = cfgs.get('renderer_spp', 1)
688
+ self.dino_feature_recon_dim = cfgs.get('dino_feature_recon_dim', 64)
689
+
690
+ self.total_loss = 0.
691
+ self.all_scores = torch.Tensor()
692
+ self.checkpoint_dir = cfgs.get('checkpoint_dir', 'results')
693
+
694
+ @staticmethod
695
+ def get_data_loaders(cfgs, dataset, in_image_size=256, out_image_size=256, batch_size=64, num_workers=4, run_train=False, run_test=False, train_data_dir=None, val_data_dir=None, test_data_dir=None):
696
+ train_loader = val_loader = test_loader = None
697
+ color_jitter_train = cfgs.get('color_jitter_train', None)
698
+ color_jitter_val = cfgs.get('color_jitter_val', None)
699
+ random_flip_train = cfgs.get('random_flip_train', False)
700
+
701
+ ## video dataset
702
+ if dataset == 'video':
703
+ data_loader_mode = cfgs.get('data_loader_mode', 'n_frame')
704
+ skip_beginning = cfgs.get('skip_beginning', 4)
705
+ skip_end = cfgs.get('skip_end', 4)
706
+ num_sample_frames = cfgs.get('num_sample_frames', 2)
707
+ min_seq_len = cfgs.get('min_seq_len', 10)
708
+ max_seq_len = cfgs.get('max_seq_len', 10)
709
+ debug_seq = cfgs.get('debug_seq', False)
710
+ random_sample_train_frames = cfgs.get('random_sample_train_frames', False)
711
+ shuffle_train_seqs = cfgs.get('shuffle_train_seqs', False)
712
+ random_sample_val_frames = cfgs.get('random_sample_val_frames', False)
713
+ load_background = cfgs.get('background_mode', 'none') == 'background'
714
+ rgb_suffix = cfgs.get('rgb_suffix', '.png')
715
+ load_dino_feature = cfgs.get('load_dino_feature', False)
716
+ load_dino_cluster = cfgs.get('load_dino_cluster', False)
717
+ dino_feature_dim = cfgs.get('dino_feature_dim', 64)
718
+ get_loader = lambda **kwargs: get_sequence_loader(
719
+ mode=data_loader_mode,
720
+ batch_size=batch_size,
721
+ num_workers=num_workers,
722
+ in_image_size=in_image_size,
723
+ out_image_size=out_image_size,
724
+ debug_seq=debug_seq,
725
+ skip_beginning=skip_beginning,
726
+ skip_end=skip_end,
727
+ num_sample_frames=num_sample_frames,
728
+ min_seq_len=min_seq_len,
729
+ max_seq_len=max_seq_len,
730
+ load_background=load_background,
731
+ rgb_suffix=rgb_suffix,
732
+ load_dino_feature=load_dino_feature,
733
+ load_dino_cluster=load_dino_cluster,
734
+ dino_feature_dim=dino_feature_dim,
735
+ **kwargs)
736
+
737
+ if run_train:
738
+ assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}"
739
+ print(f"Loading training data from {train_data_dir}")
740
+ train_loader = get_loader(data_dir=train_data_dir, is_validation=False, random_sample=random_sample_train_frames, shuffle=shuffle_train_seqs, dense_sample=True, color_jitter=color_jitter_train, random_flip=random_flip_train)
741
+
742
+ if val_data_dir is not None:
743
+ assert osp.isdir(val_data_dir), f"Validation data directory does not exist: {val_data_dir}"
744
+ print(f"Loading validation data from {val_data_dir}")
745
+ val_loader = get_loader(data_dir=val_data_dir, is_validation=True, random_sample=random_sample_val_frames, shuffle=False, dense_sample=False, color_jitter=color_jitter_val, random_flip=False)
746
+ if run_test:
747
+ assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}"
748
+ print(f"Loading testing data from {test_data_dir}")
749
+ test_loader = get_loader(data_dir=test_data_dir, is_validation=True, dense_sample=False, color_jitter=None, random_flip=False)
750
+
751
+ ## CUB dataset
752
+ elif dataset == 'cub':
753
+ get_loader = lambda **kwargs: get_cub_loader(
754
+ batch_size=batch_size,
755
+ num_workers=num_workers,
756
+ image_size=in_image_size,
757
+ **kwargs)
758
+
759
+ if run_train:
760
+ assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}"
761
+ print(f"Loading training data from {train_data_dir}")
762
+ train_loader = get_loader(data_dir=train_data_dir, split='train', is_validation=False)
763
+ val_loader = get_loader(data_dir=val_data_dir, split='val', is_validation=True)
764
+
765
+ if run_test:
766
+ assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}"
767
+ print(f"Loading testing data from {test_data_dir}")
768
+ test_loader = get_loader(data_dir=test_data_dir, split='test', is_validation=True)
769
+
770
+ ## other datasets
771
+ else:
772
+ get_loader = lambda **kwargs: get_image_loader(
773
+ batch_size=batch_size,
774
+ num_workers=num_workers,
775
+ image_size=in_image_size,
776
+ **kwargs)
777
+
778
+ if run_train:
779
+ assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}"
780
+ print(f"Loading training data from {train_data_dir}")
781
+ train_loader = get_loader(data_dir=train_data_dir, is_validation=False, color_jitter=color_jitter_train)
782
+
783
+ if val_data_dir is not None:
784
+ assert osp.isdir(val_data_dir), f"Validation data directory does not exist: {val_data_dir}"
785
+ print(f"Loading validation data from {val_data_dir}")
786
+ val_loader = get_loader(data_dir=val_data_dir, is_validation=True, color_jitter=color_jitter_val)
787
+
788
+ if run_test:
789
+ assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}"
790
+ print(f"Loading testing data from {test_data_dir}")
791
+ test_loader = get_loader(data_dir=test_data_dir, is_validation=True, color_jitter=None)
792
+
793
+ return train_loader, val_loader, test_loader
794
+
795
+ def load_model_state(self, cp):
796
+ self.netInstance.load_state_dict(cp["netInstance"])
797
+ if self.enable_prior:
798
+ self.netPrior.load_state_dict(cp["netPrior"])
799
+
800
+ def load_optimizer_state(self, cp):
801
+ self.optimizerInstance.load_state_dict(cp["optimizerInstance"])
802
+ if self.use_scheduler:
803
+ if 'schedulerInstance' in cp:
804
+ self.schedulerInstance.load_state_dict(cp["schedulerInstance"])
805
+ if self.enable_prior and self.resume_prior_optim:
806
+ self.optimizerPrior.load_state_dict(cp["optimizerPrior"])
807
+ if self.use_scheduler:
808
+ if 'schedulerPrior' in cp:
809
+ self.schedulerPrior.load_state_dict(cp["schedulerPrior"])
810
+
811
+ def get_model_state(self):
812
+ state = {"netInstance": self.netInstance.state_dict()}
813
+ if self.enable_prior:
814
+ state["netPrior"] = self.netPrior.state_dict()
815
+ return state
816
+
817
+ def get_optimizer_state(self):
818
+ state = {"optimizerInstance": self.optimizerInstance.state_dict()}
819
+ if self.use_scheduler:
820
+ state["schedulerInstance"] = self.schedulerInstance.state_dict()
821
+ if self.enable_prior:
822
+ state["optimizerPrior"] = self.optimizerPrior.state_dict()
823
+ if self.use_scheduler:
824
+ state["schedulerPrior"] = self.schedulerPrior.state_dict()
825
+ return state
826
+
827
+ def to(self, device):
828
+ self.device = device
829
+ self.netInstance.to(device)
830
+ if self.enable_prior:
831
+ self.netPrior.to(device)
832
+ if hasattr(self, 'perceptual_loss'):
833
+ self.perceptual_loss.to(device)
834
+
835
+ def set_train(self):
836
+ self.netInstance.train()
837
+ if self.enable_prior:
838
+ self.netPrior.train()
839
+
840
+ def set_eval(self):
841
+ self.netInstance.eval()
842
+ if self.enable_prior:
843
+ self.netPrior.eval()
844
+
845
+ def reset_optimizers(self):
846
+ print("Resetting optimizers...")
847
+ self.optimizerInstance = get_optimizer(self.netInstance, self.lr)
848
+ if self.use_scheduler:
849
+ self.schedulerInstance = self.make_scheduler(self.optimizerInstance)
850
+ if self.enable_prior:
851
+ self.optimizerPrior = get_optimizer(self.netPrior, lr=self.prior_lr, weight_decay=self.prior_weight_decay)
852
+ if self.use_scheduler:
853
+ self.schedulerPrior = self.make_scheduler(self.optimizerPrior)
854
+
855
+ def backward(self):
856
+ self.optimizerInstance.zero_grad()
857
+ if self.backward_prior:
858
+ self.optimizerPrior.zero_grad()
859
+ self.total_loss.backward()
860
+ self.optimizerInstance.step()
861
+ if self.backward_prior:
862
+ self.optimizerPrior.step()
863
+ self.total_loss = 0.
864
+
865
+ def scheduler_step(self):
866
+ if self.use_scheduler:
867
+ self.schedulerInstance.step()
868
+ if self.enable_prior:
869
+ self.schedulerPrior.step()
870
+
871
+ def zflip_pose(self, pose):
872
+ if self.rot_rep == 'lookat':
873
+ vec_forward = pose[:,:,6:9]
874
+ vec_forward = vec_forward * torch.FloatTensor([1,1,-1]).view(1,1,3).to(vec_forward.device)
875
+ up = torch.FloatTensor([0,1,0]).to(pose.device).view(1,1,3)
876
+ vec_right = up.expand_as(vec_forward).cross(vec_forward, dim=-1)
877
+ vec_right = nn.functional.normalize(vec_right, p=2, dim=-1)
878
+ vec_up = vec_forward.cross(vec_right, dim=-1)
879
+ vec_up = nn.functional.normalize(vec_up, p=2, dim=-1)
880
+ rot_mat = torch.stack([vec_right, vec_up, vec_forward], 2)
881
+ rot_pred = rot_mat.reshape(*pose.shape[:-1], -1)
882
+ pose_zflip = torch.cat([rot_pred, pose[:,:,9:]], -1)
883
+ else:
884
+ raise NotImplementedError
885
+ return pose_zflip
886
+
887
+ def render(self, shape, texture, mvp, w2c, campos, resolution, background='none', im_features=None, light=None, prior_shape=None, render_flow=True, dino_pred=None, render_mode='diffuse', two_sided_shading=True, num_frames=None, spp=1):
888
+ h, w = resolution
889
+ N = len(mvp)
890
+ if background in ['none', 'black']:
891
+ bg_image = torch.zeros((N, h, w, 3), device=mvp.device)
892
+ elif background == 'white':
893
+ bg_image = torch.ones((N, h, w, 3), device=mvp.device)
894
+ elif background == 'checkerboard':
895
+ bg_image = torch.FloatTensor(util.checkerboard((h, w), 8), device=self.device).repeat(N, 1, 1, 1) # NxHxWxC
896
+ else:
897
+ raise NotImplementedError
898
+
899
+ frame_rendered = render.render_mesh(
900
+ self.glctx,
901
+ shape,
902
+ mtx_in=mvp,
903
+ w2c=w2c,
904
+ view_pos=campos,
905
+ material=texture,
906
+ lgt=light,
907
+ resolution=resolution,
908
+ spp=spp,
909
+ msaa=True,
910
+ background=bg_image,
911
+ bsdf=render_mode,
912
+ feat=im_features,
913
+ prior_mesh=prior_shape,
914
+ two_sided_shading=two_sided_shading,
915
+ render_flow=render_flow,
916
+ dino_pred=dino_pred,
917
+ num_frames=num_frames)
918
+ shaded = frame_rendered['shaded'].permute(0, 3, 1, 2)
919
+ image_pred = shaded[:, :3, :, :]
920
+ mask_pred = shaded[:, 3, :, :]
921
+ albedo = frame_rendered['kd'].permute(0, 3, 1, 2)[:, :3, :, :]
922
+ if 'shading' in frame_rendered:
923
+ shading = frame_rendered['shading'].permute(0, 3, 1, 2)[:, :1, :, :]
924
+ else:
925
+ shading = None
926
+ if render_flow:
927
+ flow_pred = frame_rendered['flow']
928
+ flow_pred = flow_pred.permute(0, 3, 1, 2)[:, :2, :, :]
929
+ else:
930
+ flow_pred = None
931
+ if dino_pred is not None:
932
+ dino_feat_im_pred = frame_rendered['dino_feat_im_pred']
933
+ dino_feat_im_pred = dino_feat_im_pred.permute(0, 3, 1, 2)[:, :-1]
934
+ else:
935
+ dino_feat_im_pred = None
936
+
937
+ return image_pred, mask_pred, flow_pred, dino_feat_im_pred, albedo, shading
938
+
939
+ def compute_reconstruction_losses(self, image_pred, image_gt, mask_pred, mask_gt, mask_dt, mask_valid, flow_pred, flow_gt, dino_feat_im_gt, dino_feat_im_pred, background_mode='none', reduce=False):
940
+ losses = {}
941
+ batch_size, num_frames, _, h, w = image_pred.shape # BxFxCxHxW
942
+
943
+ # image_loss = (image_pred - image_gt) ** 2
944
+ image_loss = (image_pred - image_gt).abs()
945
+
946
+ ## silhouette loss
947
+ mask_pred_valid = mask_pred * mask_valid
948
+ # mask_pred_valid = mask_pred
949
+ # losses["silhouette_loss"] = ((mask_pred - mask_gt) ** 2).mean()
950
+ # mask_loss_mask = (image_loss.mean(2).detach() > 0.05).float()
951
+ mask_loss = (mask_pred_valid - mask_gt) ** 2
952
+ # mask_loss = nn.functional.mse_loss(mask_pred, mask_gt)
953
+ # num_mask_pixels = mask_loss_mask.reshape(batch_size*num_frames, -1).sum(1).clamp(min=1)
954
+ # losses["silhouette_loss"] = (mask_loss.reshape(batch_size*num_frames, -1).sum(1) / num_mask_pixels).mean()
955
+ losses['silhouette_loss'] = mask_loss.view(batch_size, num_frames, -1).mean(2)
956
+ losses['silhouette_dt_loss'] = (mask_pred * mask_dt[:,:,1]).view(batch_size, num_frames, -1).mean(2)
957
+ losses['silhouette_inv_dt_loss'] = ((1-mask_pred) * mask_dt[:,:,0]).view(batch_size, num_frames, -1).mean(2)
958
+
959
+ mask_pred_binary = (mask_pred_valid > 0.).float().detach()
960
+ mask_both_binary = (mask_pred_binary * mask_gt).view(batch_size*num_frames, 1, *mask_pred.shape[2:])
961
+ mask_both_binary = (nn.functional.avg_pool2d(mask_both_binary, 3, stride=1, padding=1).view(batch_size, num_frames, *mask_pred.shape[2:]) > 0.99).float().detach() # erode by 1 pixel
962
+
963
+ ## reconstruction loss
964
+ # image_loss_mask = (mask_pred*mask_gt).unsqueeze(2).expand_as(image_gt)
965
+ # image_loss = image_loss * image_loss_mask
966
+ # num_mask_pixels = image_loss_mask.reshape(batch_size*num_frames, -1).sum(1).clamp(min=1)
967
+ # losses["rgb_loss"] = (image_loss.reshape(batch_size*num_frames, -1).sum(1) / num_mask_pixels).mean()
968
+ if background_mode in ['background', 'input']:
969
+ pass
970
+ else:
971
+ image_loss = image_loss * mask_both_binary.unsqueeze(2)
972
+ losses['rgb_loss'] = image_loss.reshape(batch_size, num_frames, -1).mean(2)
973
+
974
+ if self.cfgs.get('perceptual_loss_weight', 0.) > 0:
975
+ if background_mode in ['background', 'input']:
976
+ perc_image_pred = image_pred
977
+ perc_image_gt = image_gt
978
+ else:
979
+ perc_image_pred = image_pred * mask_pred_binary.unsqueeze(2) + 0.5 * (1-mask_pred_binary.unsqueeze(2))
980
+ perc_image_gt = image_gt * mask_pred_binary.unsqueeze(2) + 0.5 * (1-mask_pred_binary.unsqueeze(2))
981
+ losses['perceptual_loss'] = self.perceptual_loss(perc_image_pred.view(-1, *image_pred.shape[2:]) *2-1, perc_image_gt.view(-1, *image_gt.shape[2:]) *2-1).view(batch_size, num_frames)
982
+
983
+ ## flow loss - between first and second frame
984
+ if flow_pred is not None:
985
+ flow_loss = (flow_pred - flow_gt).abs()
986
+ flow_loss_mask = mask_both_binary[:,:-1].unsqueeze(2).expand_as(flow_gt).detach()
987
+
988
+ ## ignore frames where GT flow is too large (likely inaccurate)
989
+ large_flow = (flow_gt.abs() > 0.5).float() * flow_loss_mask
990
+ large_flow = (large_flow.view(batch_size, num_frames-1, -1).sum(2) > 0).float()
991
+ self.large_flow = large_flow
992
+
993
+ flow_loss = flow_loss * flow_loss_mask * (1 - large_flow[:,:,None,None,None])
994
+ num_mask_pixels = flow_loss_mask.reshape(batch_size, num_frames-1, -1).sum(2).clamp(min=1)
995
+ losses['flow_loss'] = (flow_loss.reshape(batch_size, num_frames-1, -1).sum(2) / num_mask_pixels)
996
+ # losses["flow_loss"] = flow_loss.mean()
997
+
998
+ if dino_feat_im_pred is not None:
999
+ dino_feat_loss = (dino_feat_im_pred - dino_feat_im_gt) ** 2
1000
+ dino_feat_loss = dino_feat_loss * mask_both_binary.unsqueeze(2)
1001
+ losses['dino_feat_im_loss'] = dino_feat_loss.reshape(batch_size, num_frames, -1).mean(2)
1002
+
1003
+ if reduce:
1004
+ for k, v in losses.item():
1005
+ losses[k] = v.mean()
1006
+ return losses
1007
+
1008
+ def compute_pose_xflip_reg_loss(self, input_image, dino_feat_im, pose_raw, input_image_xflip_flag=None):
1009
+ image_xflip = input_image.flip(4)
1010
+ if dino_feat_im is not None:
1011
+ dino_feat_im_xflip = dino_feat_im.flip(4)
1012
+ else:
1013
+ dino_feat_im_xflip = None
1014
+ feat_xflip, _ = self.netInstance.forward_encoder(image_xflip, dino_feat_im_xflip)
1015
+ batch_size, num_frames = input_image.shape[:2]
1016
+ pose_xflip_raw = self.netInstance.forward_pose(image_xflip, feat_xflip, dino_feat_im_xflip)
1017
+
1018
+ if input_image_xflip_flag is not None:
1019
+ pose_xflip_raw_xflip = pose_xflip_raw * torch.FloatTensor([-1,1,1,-1,1,1]).to(pose_raw.device) # forward x, trans x
1020
+ pose_xflip_raw = pose_xflip_raw * (1 - input_image_xflip_flag.view(batch_size * num_frames, 1)) + pose_xflip_raw_xflip * input_image_xflip_flag.view(batch_size * num_frames, 1)
1021
+
1022
+ rot_rep = self.netInstance.rot_rep
1023
+ if rot_rep == 'euler_angle' or rot_rep == 'soft_calss':
1024
+ pose_xflip_xflip = pose_xflip * torch.FloatTensor([1,-1,-1,-1,1,1]).to(pose_xflip.device) # rot y+z, trans x
1025
+ pose_xflip_reg_loss = ((pose_xflip_xflip - pose) ** 2.).mean()
1026
+ elif rot_rep == 'quaternion':
1027
+ rot_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose[...,:4]), convention='XYZ')
1028
+ pose_euler = torch.cat([rot_euler, pose[...,4:]], -1)
1029
+ rot_xflip_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose_xflip[...,:4]), convention='XYZ')
1030
+ pose_xflip_euler = torch.cat([rot_xflip_euler, pose_xflip[...,4:]], -1)
1031
+ pose_xflip_euler_xflip = pose_xflip_euler * torch.FloatTensor([1,-1,-1,-1,1,1]).to(pose_xflip.device) # rot y+z, trans x
1032
+ pose_xflip_reg_loss = ((pose_xflip_euler_xflip - pose_euler) ** 2.).mean()
1033
+ elif rot_rep == 'lookat':
1034
+ pose_xflip_raw_xflip = pose_xflip_raw * torch.FloatTensor([-1,1,1,-1,1,1]).to(pose_raw.device) # forward x, trans x
1035
+ pose_xflip_reg_loss = ((pose_xflip_raw_xflip - pose_raw)[...,0] ** 2.) # compute x only
1036
+ # if epoch >= self.nolookat_zflip_loss_epochs and self.lookat_zflip_no_other_losses:
1037
+ # pose_xflip_reg_loss = pose_xflip_reg_loss.mean(1) * is_pose_1_better
1038
+ pose_xflip_reg_loss = pose_xflip_reg_loss.mean()
1039
+ return pose_xflip_reg_loss, pose_xflip_raw
1040
+
1041
+ def compute_edge_length_reg_loss(self, mesh, prior_mesh):
1042
+ prior_edge_lengths = get_edge_length(prior_mesh.v_pos, prior_mesh.t_pos_idx)
1043
+ max_length = prior_edge_lengths.max().detach() *1.1
1044
+ edge_lengths = get_edge_length(mesh.v_pos, mesh.t_pos_idx)
1045
+ mesh_edge_length_loss = ((edge_lengths - max_length).clamp(min=0)**2).mean()
1046
+ return mesh_edge_length_loss, edge_lengths
1047
+
1048
+ def compute_regularizers(self, mesh, prior_mesh, input_image, dino_feat_im, pose_raw, input_image_xflip_flag=None, arti_params=None, deformation=None):
1049
+ losses = {}
1050
+ aux = {}
1051
+
1052
+ if self.enable_prior:
1053
+ losses.update(self.netPrior.netShape.get_sdf_reg_loss())
1054
+
1055
+ if self.cfgs.get('pose_xflip_reg_loss_weight', 0.) > 0:
1056
+ losses["pose_xflip_reg_loss"], aux['pose_xflip_raw'] = self.compute_pose_xflip_reg_loss(input_image, dino_feat_im, pose_raw, input_image_xflip_flag)
1057
+
1058
+ b, f = input_image.shape[:2]
1059
+ if b >= 2:
1060
+ vec_forward = pose_raw[..., :3]
1061
+ losses['pose_entropy_loss'] = (vec_forward[:b//2] * vec_forward[b//2:(b//2)*2]).sum(-1).mean()
1062
+ else:
1063
+ losses['pose_entropy_loss'] = 0.
1064
+
1065
+ losses['mesh_normal_consistency_loss'] = normal_consistency(mesh.v_pos, mesh.t_pos_idx)
1066
+ losses['mesh_edge_length_loss'], aux['edge_lengths'] = self.compute_edge_length_reg_loss(mesh, prior_mesh)
1067
+ if arti_params is not None:
1068
+ losses['arti_reg_loss'] = (arti_params ** 2).mean()
1069
+
1070
+ if deformation is not None:
1071
+ losses['deformation_reg_loss'] = (deformation ** 2).mean()
1072
+ # losses['deformation_reg_loss'] = deformation.abs().mean()
1073
+
1074
+ return losses, aux
1075
+
1076
+ def forward(self, batch, epoch, iter, is_train=True, viz_logger=None, total_iter=None, save_results=False, save_dir=None, which_data='', logger_prefix='', is_training=True):
1077
+ batch = [x.to(self.device) if x is not None else None for x in batch]
1078
+ input_image, mask_gt, mask_dt, mask_valid, flow_gt, bbox, bg_image, dino_feat_im, dino_cluster_im, seq_idx, frame_idx = batch
1079
+ batch_size, num_frames, _, h0, w0 = input_image.shape # BxFxCxHxW
1080
+ h = w = self.out_image_size
1081
+
1082
+ def collapseF(x):
1083
+ return None if x is None else x.view(batch_size * num_frames, *x.shape[2:])
1084
+ def expandF(x):
1085
+ return None if x is None else x.view(batch_size, num_frames, *x.shape[1:])
1086
+
1087
+ if flow_gt.dim() == 2: # dummy tensor for not loading flow
1088
+ flow_gt = None
1089
+ if dino_feat_im.dim() == 2: # dummy tensor for not loading dino features
1090
+ dino_feat_im = None
1091
+ dino_feat_im_gt = None
1092
+ else:
1093
+ dino_feat_im_gt = expandF(torch.nn.functional.interpolate(collapseF(dino_feat_im), size=[h, w], mode="bilinear"))[:, :, :self.dino_feature_recon_dim]
1094
+ if dino_cluster_im.dim() == 2: # dummy tensor for not loading dino clusters
1095
+ dino_cluster_im = None
1096
+ dino_cluster_im_gt = None
1097
+ else:
1098
+ dino_cluster_im_gt = expandF(torch.nn.functional.interpolate(collapseF(dino_cluster_im), size=[h, w], mode="nearest"))
1099
+
1100
+ seq_idx = seq_idx.squeeze(1)
1101
+ # seq_idx = seq_idx * 0 # single sequnce model
1102
+ frame_id, crop_x0, crop_y0, crop_w, crop_h, full_w, full_h, sharpness = bbox.unbind(2) # BxFx7
1103
+ bbox = torch.stack([crop_x0, crop_y0, crop_w, crop_h], 2)
1104
+ mask_gt = (mask_gt[:, :, 0, :, :] > 0.9).float() # BxFxHxW
1105
+ mask_dt = mask_dt / self.in_image_size
1106
+
1107
+ if which_data != 'video':
1108
+ flow_gt = None
1109
+
1110
+ aux_viz = {}
1111
+
1112
+ ## GT
1113
+ image_gt = input_image
1114
+ if self.out_image_size != self.in_image_size:
1115
+ image_gt = expandF(torch.nn.functional.interpolate(collapseF(image_gt), size=[h, w], mode='bilinear'))
1116
+ if flow_gt is not None:
1117
+ flow_gt = torch.nn.functional.interpolate(flow_gt.view(batch_size*(num_frames-1), 2, h0, w0), size=[h, w], mode="bilinear").view(batch_size, num_frames-1, 2, h, w)
1118
+
1119
+ self.train_pose_only = False
1120
+ if epoch in self.pose_epochs:
1121
+ if (total_iter // self.pose_iters) % 2 == 0:
1122
+ self.train_pose_only = True
1123
+
1124
+ ## flip input and pose
1125
+ if epoch in self.pose_xflip_recon_epochs:
1126
+ input_image_xflip = input_image.flip(-1)
1127
+ input_image_xflip_flag = torch.randint(0, 2, (batch_size, num_frames), device=input_image.device)
1128
+ input_image = input_image * (1 - input_image_xflip_flag[:,:,None,None,None]) + input_image_xflip * input_image_xflip_flag[:,:,None,None,None]
1129
+ else:
1130
+ input_image_xflip_flag = None
1131
+
1132
+ ## 1st pose hypothesis with original predictions
1133
+
1134
+ # ==============================================================================================
1135
+ # Predict prior mesh.
1136
+ # ==============================================================================================
1137
+ if self.enable_prior:
1138
+ if epoch < self.dmtet_grid_smaller_epoch:
1139
+ if self.netPrior.netShape.grid_res != self.dmtet_grid_smaller:
1140
+ self.netPrior.netShape.load_tets(self.dmtet_grid_smaller)
1141
+ else:
1142
+ if self.netPrior.netShape.grid_res != self.dmtet_grid:
1143
+ self.netPrior.netShape.load_tets(self.dmtet_grid)
1144
+
1145
+ perturb_sdf = self.perturb_sdf if is_train else False
1146
+ prior_shape, dino_pred = self.netPrior(perturb_sdf=perturb_sdf, total_iter=total_iter, is_training=is_training)
1147
+ else:
1148
+ prior_shape = None
1149
+ raise NotImplementedError
1150
+
1151
+ shape, pose_raw, pose, mvp, w2c, campos, texture, im_features, deformation, arti_params, light, forward_aux = self.netInstance(input_image, prior_shape, epoch, dino_feat_im, dino_cluster_im, total_iter, is_training=is_training) # frame dim collapsed N=(B*F)
1152
+ rot_logit = forward_aux['rot_logit']
1153
+ rot_idx = forward_aux['rot_idx']
1154
+ rot_prob = forward_aux['rot_prob']
1155
+ aux_viz.update(forward_aux)
1156
+
1157
+ if self.train_pose_only:
1158
+ safe_detach = lambda x: x.detach() if x is not None else None
1159
+ prior_shape = safe_detach(prior_shape)
1160
+ shape = safe_detach(shape)
1161
+ im_features = safe_detach(im_features)
1162
+ arti_params = safe_detach(arti_params)
1163
+ deformation = safe_detach(deformation)
1164
+ set_requires_grad(texture, False)
1165
+ set_requires_grad(light, False)
1166
+ set_requires_grad(dino_pred, False)
1167
+ else:
1168
+ set_requires_grad(texture, True)
1169
+ set_requires_grad(light, True)
1170
+ set_requires_grad(dino_pred, True)
1171
+
1172
+ render_flow = self.render_flow and num_frames > 1
1173
+ image_pred, mask_pred, flow_pred, dino_feat_im_pred, albedo, shading = self.render(shape, texture, mvp, w2c, campos, (h, w), background=self.background_mode, im_features=im_features, light=light, prior_shape=prior_shape, render_flow=render_flow, dino_pred=dino_pred, num_frames=num_frames, spp=self.renderer_spp)
1174
+ image_pred, mask_pred, flow_pred, dino_feat_im_pred = map(expandF, (image_pred, mask_pred, flow_pred, dino_feat_im_pred))
1175
+ if flow_pred is not None:
1176
+ flow_pred = flow_pred[:, :-1] # Bx(F-1)x2xHxW
1177
+
1178
+ if self.blur_mask:
1179
+ sigma = max(0.5, 3 * (1 - total_iter / self.blur_mask_iter))
1180
+ if sigma > 0.5:
1181
+ mask_gt = util.blur_image(mask_gt, kernel_size=9, sigma=sigma, mode='gaussian')
1182
+ # mask_pred = util.blur_image(mask_pred, kernel_size=7, mode='average')
1183
+
1184
+ losses = self.compute_reconstruction_losses(image_pred, image_gt, mask_pred, mask_gt, mask_dt, mask_valid, flow_pred, flow_gt, dino_feat_im_gt, dino_feat_im_pred, background_mode=self.background_mode, reduce=False)
1185
+
1186
+ ## TODO: assume flow loss is not used
1187
+ logit_loss_target = torch.zeros_like(expandF(rot_logit))
1188
+ final_losses = {}
1189
+ for name, loss in losses.items():
1190
+ loss_weight_logit = self.cfgs.get(f"{name}_weight", 0.)
1191
+ # if (name in ['flow_loss'] and epoch not in self.flow_loss_epochs) or (name in ['rgb_loss', 'perceptual_loss'] and epoch not in self.texture_epochs):
1192
+ # if name in ['flow_loss', 'rgb_loss', 'perceptual_loss']:
1193
+ # loss_weight_logit = 0.
1194
+ if name in ['sdf_bce_reg_loss', 'sdf_gradient_reg_loss', 'sdf_inflate_reg_loss']:
1195
+ if total_iter >= self.sdf_reg_decay_start_iter:
1196
+ decay_rate = max(0, 1 - (total_iter-self.sdf_reg_decay_start_iter) / 10000)
1197
+ loss_weight_logit = max(loss_weight_logit * decay_rate, self.cfgs.get(f"{name}_min_weight", 0.))
1198
+ if name in ['dino_feat_im_loss']:
1199
+ loss_weight_logit = loss_weight_logit * self.cfgs.get("logit_loss_dino_feat_im_loss_multiplier", 1.)
1200
+ if loss_weight_logit > 0:
1201
+ logit_loss_target += loss * loss_weight_logit
1202
+
1203
+ if self.netInstance.rot_rep in ['quadlookat', 'octlookat']:
1204
+ loss = loss * rot_prob.detach().view(batch_size, num_frames)[:, :loss.shape[1]] *self.netInstance.num_pose_hypos
1205
+ if name == 'flow_loss' and num_frames > 1:
1206
+ ri = rot_idx.view(batch_size, num_frames)
1207
+ same_rot_idx = (ri[:, 1:] == ri[:, :-1]).float()
1208
+ loss = loss * same_rot_idx
1209
+ final_losses[name] = loss.mean()
1210
+ final_losses['logit_loss'] = ((expandF(rot_logit) - logit_loss_target.detach())**2.).mean()
1211
+
1212
+ ## regularizers
1213
+ regularizers, aux = self.compute_regularizers(shape, prior_shape, input_image, dino_feat_im, pose_raw, input_image_xflip_flag, arti_params, deformation)
1214
+ final_losses.update(regularizers)
1215
+ aux_viz.update(aux)
1216
+
1217
+ total_loss = 0
1218
+ for name, loss in final_losses.items():
1219
+ loss_weight = self.cfgs.get(f"{name}_weight", 0.)
1220
+ if loss_weight <= 0:
1221
+ continue
1222
+
1223
+ if self.train_pose_only:
1224
+ if name not in ['silhouette_loss', 'silhouette_dt_loss', 'silhouette_inv_dt_loss', 'flow_loss', 'pose_xflip_reg_loss', 'lookat_zflip_loss', 'dino_feat_im_loss']:
1225
+ continue
1226
+ if epoch not in self.flow_loss_epochs:
1227
+ if name in ['flow_loss']:
1228
+ continue
1229
+ if epoch not in self.texture_epochs:
1230
+ if name in ['rgb_loss', 'perceptual_loss']:
1231
+ continue
1232
+ if epoch not in self.lookat_zflip_loss_epochs:
1233
+ if name in ['lookat_zflip_loss']:
1234
+ continue
1235
+ if name in ['mesh_laplacian_smoothing_loss', 'mesh_normal_consistency_loss']:
1236
+ if total_iter < self.cfgs.get('mesh_reg_start_iter', 0):
1237
+ continue
1238
+ if epoch >= self.mesh_reg_decay_epoch:
1239
+ decay_rate = self.mesh_reg_decay_rate ** (epoch - self.mesh_reg_decay_epoch)
1240
+ loss_weight = max(loss_weight * decay_rate, self.cfgs.get(f"{name}_min_weight", 0.))
1241
+ if epoch not in self.sdf_inflate_reg_loss_epochs:
1242
+ if name in ['sdf_inflate_reg_loss']:
1243
+ continue
1244
+ if epoch not in self.arti_reg_loss_epochs:
1245
+ if name in ['arti_reg_loss']:
1246
+ continue
1247
+ if name in ['sdf_bce_reg_loss', 'sdf_gradient_reg_loss', 'sdf_inflate_reg_loss']:
1248
+ if total_iter >= self.sdf_reg_decay_start_iter:
1249
+ decay_rate = max(0, 1 - (total_iter-self.sdf_reg_decay_start_iter) / 10000)
1250
+ loss_weight = max(loss_weight * decay_rate, self.cfgs.get(f"{name}_min_weight", 0.))
1251
+
1252
+ total_loss += loss * loss_weight
1253
+
1254
+ self.total_loss += total_loss # reset to 0 in backward step
1255
+
1256
+ if torch.isnan(self.total_loss):
1257
+ print("NaN in loss...")
1258
+ import ipdb; ipdb.set_trace()
1259
+
1260
+ final_losses['logit_loss_target'] = logit_loss_target.mean()
1261
+
1262
+ metrics = {'loss': total_loss, **final_losses}
1263
+
1264
+ ## log visuals
1265
+ if viz_logger is not None:
1266
+ b0 = max(min(batch_size, 16//num_frames), 1)
1267
+ viz_logger.add_image(logger_prefix+'image/image_gt', misc.image_grid(image_gt.detach().cpu()[:b0,:].reshape(-1,*input_image.shape[2:]).clamp(0,1)), total_iter)
1268
+ viz_logger.add_image(logger_prefix+'image/image_pred', misc.image_grid(image_pred.detach().cpu()[:b0,:].reshape(-1,*image_pred.shape[2:]).clamp(0,1)), total_iter)
1269
+ # viz_logger.add_image(logger_prefix+'image/flow_loss_mask', misc.image_grid(flow_loss_mask[:b0,:,:1].reshape(-1,1,*flow_loss_mask.shape[3:]).repeat(1,3,1,1).clamp(0,1)), total_iter)
1270
+ viz_logger.add_image(logger_prefix+'image/mask_gt', misc.image_grid(mask_gt.detach().cpu()[:b0,:].reshape(-1,*mask_gt.shape[2:]).unsqueeze(1).repeat(1,3,1,1).clamp(0,1)), total_iter)
1271
+ viz_logger.add_image(logger_prefix+'image/mask_pred', misc.image_grid(mask_pred.detach().cpu()[:b0,:].reshape(-1,*mask_pred.shape[2:]).unsqueeze(1).repeat(1,3,1,1).clamp(0,1)), total_iter)
1272
+
1273
+ if self.render_flow and flow_gt is not None:
1274
+ flow_gt = flow_gt.detach().cpu()
1275
+ flow_gt_viz = torch.cat([flow_gt[:b0], torch.zeros_like(flow_gt[:b0,:,:1])], 2) + 0.5 # -0.5~1.5
1276
+ flow_gt_viz = torch.nn.functional.pad(flow_gt_viz, pad=[0, 0, 0, 0, 0, 0, 0, 1])
1277
+
1278
+ ## draw marker on large flow frames
1279
+ large_flow_marker_mask = torch.zeros_like(flow_gt_viz)
1280
+ large_flow_marker_mask[:,:,:,:8,:8] = 1.
1281
+ large_flow = torch.cat([self.large_flow, self.large_flow[:,:1] *0.], 1).detach().cpu()[:b0]
1282
+ large_flow_marker_mask = large_flow_marker_mask * large_flow[:,:,None,None,None]
1283
+ red = torch.FloatTensor([1,0,0])[None,None,:,None,None]
1284
+ flow_gt_viz = large_flow_marker_mask * red + (1-large_flow_marker_mask) * flow_gt_viz
1285
+
1286
+ viz_logger.add_image(logger_prefix+'image/flow_gt', misc.image_grid(flow_gt_viz.reshape(-1,*flow_gt_viz.shape[2:])), total_iter)
1287
+
1288
+ if self.render_flow and flow_pred is not None:
1289
+ flow_pred = flow_pred.detach().cpu()
1290
+ flow_pred_viz = torch.cat([flow_pred[:b0], torch.zeros_like(flow_pred[:b0,:,:1])], 2) + 0.5 # -0.5~1.5
1291
+ flow_pred_viz = torch.nn.functional.pad(flow_pred_viz, pad=[0, 0, 0, 0, 0, 0, 0, 1])
1292
+ viz_logger.add_image(logger_prefix+'image/flow_pred', misc.image_grid(flow_pred_viz.reshape(-1,*flow_pred_viz.shape[2:])), total_iter)
1293
+
1294
+ if light is not None:
1295
+ param_names = ['dir_x', 'dir_y', 'dir_z', 'int_ambient', 'int_diffuse']
1296
+ for name, param in zip(param_names, light.light_params.unbind(-1)):
1297
+ viz_logger.add_histogram(logger_prefix+'light/'+name, param, total_iter)
1298
+ viz_logger.add_image(
1299
+ logger_prefix + f'image/albedo',
1300
+ misc.image_grid(expandF(albedo)[:b0, ...].view(-1, *albedo.shape[1:])),
1301
+ total_iter)
1302
+ viz_logger.add_image(
1303
+ logger_prefix + f'image/shading',
1304
+ misc.image_grid(expandF(shading)[:b0, ...].view(-1, *shading.shape[1:]).repeat(1, 3, 1, 1) /2.),
1305
+ total_iter)
1306
+
1307
+ viz_logger.add_histogram(logger_prefix+'sdf', self.netPrior.netShape.get_sdf(perturb_sdf=False), total_iter)
1308
+ viz_logger.add_histogram(logger_prefix+'coordinates', shape.v_pos, total_iter)
1309
+ if arti_params is not None:
1310
+ viz_logger.add_histogram(logger_prefix+'arti_params', arti_params, total_iter)
1311
+ viz_logger.add_histogram(logger_prefix+'edge_lengths', aux_viz['edge_lengths'], total_iter)
1312
+
1313
+ if deformation is not None:
1314
+ viz_logger.add_histogram(logger_prefix+'deformation', deformation, total_iter)
1315
+
1316
+ rot_rep = self.netInstance.rot_rep
1317
+ if rot_rep == 'euler_angle' or rot_rep == 'soft_calss':
1318
+ for i, name in enumerate(['rot_x', 'rot_y', 'rot_z', 'trans_x', 'trans_y', 'trans_z']):
1319
+ viz_logger.add_histogram(logger_prefix+'pose/'+name, pose[...,i], total_iter)
1320
+ elif rot_rep == 'quaternion':
1321
+ for i, name in enumerate(['qt_0', 'qt_1', 'qt_2', 'qt_3', 'trans_x', 'trans_y', 'trans_z']):
1322
+ viz_logger.add_histogram(logger_prefix+'pose/'+name, pose[...,i], total_iter)
1323
+ rot_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose.detach().cpu()[...,:4]), convention='XYZ')
1324
+ for i, name in enumerate(['rot_x', 'rot_y', 'rot_z']):
1325
+ viz_logger.add_histogram(logger_prefix+'pose/'+name, rot_euler[...,i], total_iter)
1326
+ elif rot_rep in ['lookat', 'quadlookat', 'octlookat']:
1327
+ for i, name in enumerate(['fwd_x', 'fwd_y', 'fwd_z']):
1328
+ viz_logger.add_histogram(logger_prefix+'pose/'+name, pose_raw[...,i], total_iter)
1329
+ for i, name in enumerate(['trans_x', 'trans_y', 'trans_z']):
1330
+ viz_logger.add_histogram(logger_prefix+'pose/'+name, pose_raw[...,-3+i], total_iter)
1331
+
1332
+ if rot_rep in ['quadlookat', 'octlookat']:
1333
+ for i, rp in enumerate(forward_aux['rots_probs'].unbind(-1)):
1334
+ viz_logger.add_histogram(logger_prefix+'pose/rot_prob_%d'%i, rp, total_iter)
1335
+
1336
+ if 'pose_xflip_raw' in aux_viz:
1337
+ pose_xflip_raw = aux_viz['pose_xflip_raw']
1338
+ if rot_rep == 'euler_angle' or rot_rep == 'soft_calss':
1339
+ for i, name in enumerate(['rot_x', 'rot_y', 'rot_z', 'trans_x', 'trans_y', 'trans_z']):
1340
+ viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip[...,i], total_iter)
1341
+ elif rot_rep == 'quaternion':
1342
+ for i, name in enumerate(['qt_0', 'qt_1', 'qt_2', 'qt_3', 'trans_x', 'trans_y', 'trans_z']):
1343
+ viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip[...,i], total_iter)
1344
+ rot_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose_xflip.detach().cpu()[...,:4]), convention='XYZ')
1345
+ for i, name in enumerate(['rot_x', 'rot_y', 'rot_z']):
1346
+ viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, rot_euler[...,i], total_iter)
1347
+ elif rot_rep in ['lookat', 'quadlookat', 'octlookat']:
1348
+ for i, name in enumerate(['fwd_x', 'fwd_y', 'fwd_z']):
1349
+ viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip_raw[...,i], total_iter)
1350
+ for i, name in enumerate(['trans_x', 'trans_y', 'trans_z']):
1351
+ viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip_raw[...,-3+i], total_iter)
1352
+
1353
+ if dino_feat_im_gt is not None:
1354
+ dino_feat_im_gt_first3 = dino_feat_im_gt[:,:,:3]
1355
+ viz_logger.add_image(logger_prefix+'image/dino_feat_im_gt', misc.image_grid(dino_feat_im_gt_first3.detach().cpu()[:b0,:].reshape(-1,*dino_feat_im_gt_first3.shape[2:]).clamp(0,1)), total_iter)
1356
+
1357
+ if dino_cluster_im_gt is not None:
1358
+ viz_logger.add_image(logger_prefix+'image/dino_cluster_im_gt', misc.image_grid(dino_cluster_im_gt.detach().cpu()[:b0,:].reshape(-1,*dino_cluster_im_gt.shape[2:]).clamp(0,1)), total_iter)
1359
+
1360
+ if dino_feat_im_pred is not None:
1361
+ dino_feat_im_pred_first3 = dino_feat_im_pred[:,:,:3]
1362
+ viz_logger.add_image(logger_prefix+'image/dino_feat_im_pred', misc.image_grid(dino_feat_im_pred_first3.detach().cpu()[:b0,:].reshape(-1,*dino_feat_im_pred_first3.shape[2:]).clamp(0,1)), total_iter)
1363
+
1364
+ for which_shape, modes in self.extra_renders.items():
1365
+ # This is wrong
1366
+ # if which_shape == "prior":
1367
+ # shape_to_render = prior_shape.extend(im_features.shape[0])
1368
+ # needed_im_features = None
1369
+ if which_shape == "instance":
1370
+ shape_to_render = shape
1371
+ needed_im_features = im_features
1372
+ else:
1373
+ raise NotImplementedError
1374
+
1375
+ for mode in modes:
1376
+ rendered, _, _, _, _, _ = self.render(shape_to_render, texture, mvp, w2c, campos, (h, w), background=self.background_mode, im_features=needed_im_features, prior_shape=prior_shape, render_mode=mode, render_flow=False, dino_pred=None)
1377
+ if 'kd' in mode:
1378
+ rendered = util.rgb_to_srgb(rendered)
1379
+ rendered = rendered.detach().cpu()
1380
+
1381
+ if 'posed_bones' in aux_viz:
1382
+ rendered_bone_image = self.render_bones(mvp, aux_viz['posed_bones'], (h, w))
1383
+ rendered_bone_image_mask = (rendered_bone_image < 1).any(1, keepdim=True).float()
1384
+ # viz_logger.add_image(logger_prefix+'image/articulation_bones', misc.image_grid(self.render_bones(mvp, aux_viz['posed_bones'])), total_iter)
1385
+ rendered = rendered_bone_image_mask*0.8 * rendered_bone_image + (1-rendered_bone_image_mask*0.8) * rendered
1386
+
1387
+ if rot_rep in ['quadlookat', 'octlookat']:
1388
+ rand_pose_flag = forward_aux['rand_pose_flag'].detach().cpu()
1389
+ rand_pose_marker_mask = torch.zeros_like(rendered)
1390
+ rand_pose_marker_mask[:,:,:16,:16] = 1.
1391
+ rand_pose_marker_mask = rand_pose_marker_mask * rand_pose_flag[:,None,None,None]
1392
+ red = torch.FloatTensor([1,0,0])[None,:,None,None]
1393
+ rendered = rand_pose_marker_mask * red + (1-rand_pose_marker_mask) * rendered
1394
+
1395
+ viz_logger.add_image(
1396
+ logger_prefix + f'image/{which_shape}_{mode}',
1397
+ misc.image_grid(expandF(rendered)[:b0, ...].view(-1, *rendered.shape[1:])),
1398
+ total_iter)
1399
+
1400
+ viz_logger.add_video(
1401
+ logger_prefix + f'animation/{which_shape}_{mode}',
1402
+ self.render_rotation_frames(shape_to_render, texture, light, (h, w), background=self.background_mode, im_features=needed_im_features, prior_shape=prior_shape, num_frames=15, render_mode=mode, b=1).detach().cpu().unsqueeze(0),
1403
+ total_iter,
1404
+ fps=2)
1405
+
1406
+ viz_logger.add_video(
1407
+ logger_prefix+'animation/prior_image_rotation',
1408
+ self.render_rotation_frames(prior_shape, texture, light, (h, w), background=self.background_mode, im_features=im_features, num_frames=15, b=1).detach().cpu().unsqueeze(0).clamp(0,1),
1409
+ total_iter,
1410
+ fps=2)
1411
+
1412
+ viz_logger.add_video(
1413
+ logger_prefix+'animation/prior_normal_rotation',
1414
+ self.render_rotation_frames(prior_shape, texture, light, (h, w), background=self.background_mode, im_features=im_features, num_frames=15, render_mode='geo_normal', b=1).detach().cpu().unsqueeze(0),
1415
+ total_iter,
1416
+ fps=2)
1417
+
1418
+ if save_results:
1419
+ b0 = self.cfgs.get('num_saved_from_each_batch', batch_size*num_frames)
1420
+ fnames = [f'{total_iter:07d}_{fid:10d}' for fid in collapseF(frame_id.int())][:b0]
1421
+
1422
+ misc.save_images(save_dir, collapseF(image_gt)[:b0].clamp(0,1).detach().cpu().numpy(), suffix='image_gt', fnames=fnames)
1423
+ misc.save_images(save_dir, collapseF(image_pred)[:b0].clamp(0,1).detach().cpu().numpy(), suffix='image_pred', fnames=fnames)
1424
+ misc.save_images(save_dir, collapseF(mask_gt)[:b0].unsqueeze(1).repeat(1,3,1,1).clamp(0,1).detach().cpu().numpy(), suffix='mask_gt', fnames=fnames)
1425
+ misc.save_images(save_dir, collapseF(mask_pred)[:b0].unsqueeze(1).repeat(1,3,1,1).clamp(0,1).detach().cpu().numpy(), suffix='mask_pred', fnames=fnames)
1426
+ # tmp_shape = shape.first_n(b0).clone()
1427
+ # tmp_shape.material = texture
1428
+ # feat = im_features[:b0] if im_features is not None else None
1429
+ # misc.save_obj(save_dir, tmp_shape, save_material=False, feat=feat, suffix="mesh", fnames=fnames) # Save the first mesh.
1430
+ # if self.render_flow and flow_gt is not None:
1431
+ # flow_gt_viz = torch.cat([flow_gt, torch.zeros_like(flow_gt[:,:,:1])], 2) + 0.5 # -0.5~1.5
1432
+ # flow_gt_viz = flow_gt_viz.view(-1, *flow_gt_viz.shape[2:])
1433
+ # misc.save_images(save_dir, flow_gt_viz[:b0].clamp(0,1).detach().cpu().numpy(), suffix='flow_gt', fnames=fnames)
1434
+ # if flow_pred is not None:
1435
+ # flow_pred_viz = torch.cat([flow_pred, torch.zeros_like(flow_pred[:,:,:1])], 2) + 0.5 # -0.5~1.5
1436
+ # flow_pred_viz = flow_pred_viz.view(-1, *flow_pred_viz.shape[2:])
1437
+ # misc.save_images(save_dir, flow_pred_viz[:b0].clamp(0,1).detach().cpu().numpy(), suffix='flow_pred', fnames=fnames)
1438
+
1439
+ misc.save_txt(save_dir, pose[:b0].detach().cpu().numpy(), suffix='pose', fnames=fnames)
1440
+
1441
+ return metrics
1442
+
1443
+ def save_scores(self, path):
1444
+ header = 'mask_mse, \
1445
+ mask_iou, \
1446
+ image_mse, \
1447
+ flow_mse'
1448
+ mean = self.all_scores.mean(0)
1449
+ std = self.all_scores.std(0)
1450
+ header = header + '\nMean: ' + ',\t'.join(['%.8f'%x for x in mean])
1451
+ header = header + '\nStd: ' + ',\t'.join(['%.8f'%x for x in std])
1452
+ misc.save_scores(path, self.all_scores, header=header)
1453
+ print(header)
1454
+
1455
+ def render_rotation_frames(self, mesh, texture, light, resolution, background='none', im_features=None, prior_shape=None, num_frames=36, render_mode='diffuse', b=None):
1456
+ frames = []
1457
+ if b is None:
1458
+ b = len(mesh)
1459
+ else:
1460
+ mesh = mesh.first_n(b)
1461
+ feat = im_features[:b] if im_features is not None else None
1462
+
1463
+ delta_angle = np.pi / num_frames * 2
1464
+ delta_rot_matrix = torch.FloatTensor([
1465
+ [np.cos(delta_angle), 0, np.sin(delta_angle), 0],
1466
+ [0, 1, 0, 0],
1467
+ [-np.sin(delta_angle), 0, np.cos(delta_angle), 0],
1468
+ [0, 0, 0, 1],
1469
+ ]).to(self.device).repeat(b, 1, 1)
1470
+
1471
+ w2c = torch.FloatTensor(np.diag([1., 1., 1., 1]))
1472
+ w2c[:3, 3] = torch.FloatTensor([0, 0, -self.cam_pos_z_offset *1.1])
1473
+ w2c = w2c.repeat(b, 1, 1).to(self.device)
1474
+ proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(b, 1, 1).to(self.device)
1475
+ mvp = torch.bmm(proj, w2c)
1476
+ campos = -w2c[:, :3, 3]
1477
+
1478
+ def rotate_pose(mvp, campos):
1479
+ mvp = torch.matmul(mvp, delta_rot_matrix)
1480
+ campos = torch.matmul(delta_rot_matrix[:,:3,:3].transpose(2,1), campos[:,:,None])[:,:,0]
1481
+ return mvp, campos
1482
+
1483
+ for _ in range(num_frames):
1484
+ image_pred, _, _, _, _, _ = self.render(mesh, texture, mvp, w2c, campos, resolution, background=background, im_features=feat, light=light, prior_shape=prior_shape, render_flow=False, dino_pred=None, render_mode=render_mode, two_sided_shading=False)
1485
+ frames += [misc.image_grid(image_pred)]
1486
+ mvp, campos = rotate_pose(mvp, campos)
1487
+ return torch.stack(frames, dim=0) # Shape: (T, C, H, W)
1488
+
1489
+ def render_bones(self, mvp, bones_pred, size=(256, 256)):
1490
+ bone_world4 = torch.concat([bones_pred, torch.ones_like(bones_pred[..., :1]).to(bones_pred.device)], dim=-1)
1491
+ b, f, num_bones = bone_world4.shape[:3]
1492
+ bones_clip4 = (bone_world4.view(b, f, num_bones*2, 1, 4) @ mvp.transpose(-1, -2).reshape(b, f, 1, 4, 4)).view(b, f, num_bones, 2, 4)
1493
+ bones_uv = bones_clip4[..., :2] / bones_clip4[..., 3:4] # b, f, num_bones, 2, 2
1494
+ dpi = 32
1495
+ fx, fy = size[1] // dpi, size[0] // dpi
1496
+
1497
+ rendered = []
1498
+ for b_idx in range(b):
1499
+ for f_idx in range(f):
1500
+ frame_bones_uv = bones_uv[b_idx, f_idx].cpu().numpy()
1501
+ fig = plt.figure(figsize=(fx, fy), dpi=dpi, frameon=False)
1502
+ ax = plt.Axes(fig, [0., 0., 1., 1.])
1503
+ ax.set_axis_off()
1504
+ for bone in frame_bones_uv:
1505
+ ax.plot(bone[:, 0], bone[:, 1], marker='o', linewidth=8, markersize=20)
1506
+ ax.set_xlim(-1, 1)
1507
+ ax.set_ylim(-1, 1)
1508
+ ax.invert_yaxis()
1509
+ # Convert to image
1510
+ fig.add_axes(ax)
1511
+ fig.canvas.draw_idle()
1512
+ image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
1513
+ w, h = fig.canvas.get_width_height()
1514
+ image.resize(h, w, 3)
1515
+ rendered += [image / 255.]
1516
+ return torch.from_numpy(np.stack(rendered, 0).transpose(0, 3, 1, 2))
1517
+
1518
+ def render_deformation_frames(self, mesh, texture, batch_size, num_frames, resolution, background='none', im_features=None, render_mode='diffuse', b=None):
1519
+ # frames = []
1520
+ # if b is None:
1521
+ # b = batch_size
1522
+ # im_features = im_features[]
1523
+ # mesh = mesh.first_n(num_frames * b)
1524
+ # for i in range(b):
1525
+ # tmp_mesh = mesh.get_m_to_n(i*num_frames:(i+1)*num_frames)
1526
+ pass
video3d/model_ddp.py ADDED
The diff for this file is too large to render. See raw diff
 
video3d/networks.py ADDED
@@ -0,0 +1,1724 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision
5
+ import torchvision.models as models
6
+ from typing import Union, List, Tuple
7
+ import os
8
+ import video3d.utils.misc as misc
9
+ import torch.nn.functional as F
10
+ from siren_pytorch import SirenNet
11
+ from video3d.triplane_texture.lift_architecture import Lift_Encoder
12
+ from video3d.triplane_texture.triplane_transformer import Triplane_Transformer
13
+
14
+
15
+ EPS = 1e-7
16
+
17
+
18
+ def get_activation(name, inplace=True, lrelu_param=0.2):
19
+ if name == 'tanh':
20
+ return nn.Tanh()
21
+ elif name == 'sigmoid':
22
+ return nn.Sigmoid()
23
+ elif name == 'relu':
24
+ return nn.ReLU(inplace=inplace)
25
+ elif name == 'lrelu':
26
+ return nn.LeakyReLU(lrelu_param, inplace=inplace)
27
+ else:
28
+ raise NotImplementedError
29
+
30
+
31
+ class MLPWithPositionalEncoding(nn.Module):
32
+ def __init__(self,
33
+ cin,
34
+ cout,
35
+ num_layers,
36
+ nf=256,
37
+ dropout=0,
38
+ activation=None,
39
+ n_harmonic_functions=10,
40
+ omega0=1,
41
+ extra_dim=0,
42
+ embed_concat_pts=True,
43
+ symmetrize=False):
44
+ super().__init__()
45
+ self.extra_dim = extra_dim
46
+
47
+ if n_harmonic_functions > 0:
48
+ self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0)
49
+ dim_in = cin * 2 * n_harmonic_functions
50
+ self.embed_concat_pts = embed_concat_pts
51
+ if embed_concat_pts:
52
+ dim_in += cin
53
+ else:
54
+ self.embedder = None
55
+ dim_in = cin
56
+
57
+ self.in_layer = nn.Linear(dim_in, nf)
58
+ self.relu = nn.ReLU(inplace=True)
59
+ self.mlp = MLP(nf + extra_dim, cout, num_layers, nf, dropout, activation)
60
+ self.symmetrize = symmetrize
61
+
62
+ def forward(self, x, feat=None):
63
+ assert (feat is None and self.extra_dim == 0) or feat.shape[-1] == self.extra_dim
64
+ if self.symmetrize:
65
+ xs, ys, zs = x.unbind(-1)
66
+ x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
67
+
68
+ if self.embedder is not None:
69
+ x_in = self.embedder(x)
70
+ if self.embed_concat_pts:
71
+ x_in = torch.cat([x, x_in], -1)
72
+ else:
73
+ x_in = x
74
+
75
+ x_in = self.relu(self.in_layer(x_in))
76
+
77
+ if feat is not None:
78
+ # if len(feat.shape) == 1:
79
+ # for _ in range(len(x_in.shape) - 1):
80
+ # feat = feat.unsqueeze(0)
81
+ # feat = feat.repeat(*x_in.shape[:-1], 1)
82
+ x_in = torch.concat([x_in, feat], dim=-1)
83
+
84
+ return self.mlp(x_in)
85
+
86
+
87
+ class MLPWithPositionalEncoding_Style(nn.Module):
88
+ def __init__(self,
89
+ cin,
90
+ cout,
91
+ num_layers,
92
+ nf=256,
93
+ dropout=0,
94
+ activation=None,
95
+ n_harmonic_functions=10,
96
+ omega0=1,
97
+ extra_dim=0,
98
+ embed_concat_pts=True,
99
+ symmetrize=False,
100
+ style_choice='film'):
101
+ super().__init__()
102
+ self.extra_dim = extra_dim
103
+
104
+ if n_harmonic_functions > 0:
105
+ self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0)
106
+ dim_in = cin * 2 * n_harmonic_functions
107
+ self.embed_concat_pts = embed_concat_pts
108
+ if embed_concat_pts:
109
+ dim_in += cin
110
+ else:
111
+ self.embedder = None
112
+ dim_in = cin
113
+
114
+ self.in_layer = nn.Linear(dim_in, nf)
115
+ self.relu = nn.ReLU(inplace=True)
116
+
117
+ if extra_dim == 0:
118
+ self.mlp = MLP(nf + extra_dim, cout, num_layers, nf, dropout, activation)
119
+
120
+ else:
121
+ if style_choice == 'film':
122
+ self.mlp = MLP_FiLM(nf, cout, num_layers, nf, dropout, activation)
123
+ self.style_mlp = MLP(extra_dim, nf*2, 2, nf, dropout, None)
124
+
125
+ elif style_choice == 'mod':
126
+ self.mlp = MLP_Mod(nf, cout, num_layers, nf, dropout, activation)
127
+ self.style_mlp = MLP(extra_dim, nf, 2, nf, dropout, None)
128
+
129
+ else:
130
+ raise NotImplementedError
131
+
132
+ self.style_choice = style_choice
133
+
134
+ self.symmetrize = symmetrize
135
+
136
+ def forward(self, x, feat=None):
137
+ assert (feat is None and self.extra_dim == 0) or feat.shape[-1] == self.extra_dim
138
+ if self.symmetrize:
139
+ xs, ys, zs = x.unbind(-1)
140
+ x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
141
+
142
+ if self.embedder is not None:
143
+ x_in = self.embedder(x)
144
+ if self.embed_concat_pts:
145
+ x_in = torch.cat([x, x_in], -1)
146
+ else:
147
+ x_in = x
148
+
149
+ x_in = self.relu(self.in_layer(x_in))
150
+
151
+ if feat is not None:
152
+ style = self.style_mlp(feat)
153
+
154
+ if self.style_choice == 'film':
155
+ style = style.reshape(style.shape[:-1] + (-1, 2))
156
+
157
+ out = self.mlp(x_in, style)
158
+
159
+ else:
160
+ out = self.mlp(x_in)
161
+
162
+ return out
163
+
164
+
165
+ class MLP_FiLM(nn.Module):
166
+ def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None):
167
+ # default no dropout
168
+ super().__init__()
169
+ assert num_layers >= 1
170
+ self.num_layers = num_layers
171
+ if num_layers == 1:
172
+ self.network = Linear_FiLM(cin, cout, bias=False)
173
+ else:
174
+ self.relu = nn.ReLU(inplace=True)
175
+ for i in range(num_layers):
176
+ if i == 0:
177
+ setattr(self, f'linear_{i}', Linear_FiLM(cin, nf, bias=False))
178
+ elif i == (num_layers-1):
179
+ setattr(self, f'linear_{i}', Linear_FiLM(nf, cout, bias=False))
180
+ else:
181
+ setattr(self, f'linear_{i}', Linear_FiLM(nf, nf, bias=False))
182
+
183
+ def forward(self, input, style):
184
+ if self.num_layers == 1:
185
+ out = self.network(input, style)
186
+ else:
187
+ x = input
188
+ for i in range(self.num_layers):
189
+ linear_layer = getattr(self, f'linear_{i}')
190
+ if i == (self.num_layers - 1):
191
+ x = linear_layer(x, style)
192
+ else:
193
+ x = linear_layer(x, style)
194
+ x = self.relu(x)
195
+
196
+ out = x
197
+ return out
198
+
199
+
200
+ class MLP_Mod(nn.Module):
201
+ def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None):
202
+ # default no dropout
203
+ super().__init__()
204
+ assert num_layers >= 1
205
+ self.num_layers = num_layers
206
+ if num_layers == 1:
207
+ self.network = Linear_Mod(cin, cout, bias=False)
208
+ else:
209
+ self.relu = nn.ReLU(inplace=True)
210
+ for i in range(num_layers):
211
+ if i == 0:
212
+ setattr(self, f'linear_{i}', Linear_Mod(cin, nf, bias=False))
213
+ elif i == (num_layers-1):
214
+ setattr(self, f'linear_{i}', Linear_Mod(nf, cout, bias=False))
215
+ else:
216
+ setattr(self, f'linear_{i}', Linear_Mod(nf, nf, bias=False))
217
+
218
+ def forward(self, input, style):
219
+ if self.num_layers == 1:
220
+ out = self.network(input, style)
221
+ else:
222
+ x = input
223
+ for i in range(self.num_layers):
224
+ linear_layer = getattr(self, f'linear_{i}')
225
+ if i == (self.num_layers - 1):
226
+ x = linear_layer(x, style)
227
+ else:
228
+ x = linear_layer(x, style)
229
+ x = self.relu(x)
230
+
231
+ out = x
232
+ return out
233
+
234
+
235
+ import math
236
+
237
+ class Linear_FiLM(nn.Module):
238
+ def __init__(self, in_features: int, out_features: int, bias: bool = True,
239
+ device=None, dtype=None) -> None:
240
+ factory_kwargs = {'device': device, 'dtype': dtype}
241
+ super().__init__()
242
+ self.in_features = in_features
243
+ self.out_features = out_features
244
+ self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
245
+ if bias:
246
+ self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
247
+ else:
248
+ self.register_parameter('bias', None)
249
+ self.reset_parameters()
250
+
251
+ def reset_parameters(self) -> None:
252
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
253
+ if self.bias is not None:
254
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
255
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
256
+ nn.init.uniform_(self.bias, -bound, bound)
257
+
258
+ def forward(self, input, style):
259
+ # if input is [..., D], style should be [..., D, 2]
260
+ x = input * style[..., 0] + style[..., 1]
261
+ return torch.nn.functional.linear(x, self.weight, self.bias)
262
+
263
+ def extra_repr(self) -> str:
264
+ return 'in_features={}, out_features={}, bias={}'.format(
265
+ self.in_features, self.out_features, self.bias is not None
266
+ )
267
+
268
+
269
+ class Linear_Mod(nn.Module):
270
+ def __init__(self, in_features: int, out_features: int, bias: bool = True,
271
+ device=None, dtype=None) -> None:
272
+ factory_kwargs = {'device': device, 'dtype': dtype}
273
+ super().__init__()
274
+ self.in_features = in_features
275
+ self.out_features = out_features
276
+ self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
277
+ if bias:
278
+ self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
279
+ else:
280
+ self.register_parameter('bias', None)
281
+ self.reset_parameters()
282
+
283
+ def reset_parameters(self) -> None:
284
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
285
+ if self.bias is not None:
286
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
287
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
288
+ nn.init.uniform_(self.bias, -bound, bound)
289
+
290
+ def forward(self, input, style):
291
+ # weight: [out_features, in_features]
292
+ # style: [..., in_features]
293
+ if len(style.shape) > 1:
294
+ style = style.reshape(-1, style.shape[-1])
295
+ style = style[0]
296
+
297
+ weight = self.weight * style.unsqueeze(0)
298
+ decoefs = ((weight * weight).sum(dim=-1, keepdim=True) + 1e-5).sqrt()
299
+ weight = weight / decoefs
300
+
301
+ return torch.nn.functional.linear(input, weight, self.bias)
302
+
303
+ def extra_repr(self) -> str:
304
+ return 'in_features={}, out_features={}, bias={}'.format(
305
+ self.in_features, self.out_features, self.bias is not None
306
+ )
307
+
308
+
309
+ class MLPTextureSimple(nn.Module):
310
+ def __init__(self,
311
+ cin,
312
+ cout,
313
+ num_layers,
314
+ nf=256,
315
+ dropout=0,
316
+ activation=None,
317
+ min_max=None,
318
+ n_harmonic_functions=10,
319
+ omega0=1,
320
+ extra_dim=0,
321
+ embed_concat_pts=True,
322
+ perturb_normal=False,
323
+ symmetrize=False,
324
+ texture_act='relu',
325
+ linear_bias=False):
326
+ super().__init__()
327
+ self.extra_dim = extra_dim
328
+
329
+ if n_harmonic_functions > 0:
330
+ self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0)
331
+ dim_in = cin * 2 * n_harmonic_functions
332
+ self.embed_concat_pts = embed_concat_pts
333
+ if embed_concat_pts:
334
+ dim_in += cin
335
+ else:
336
+ self.embedder = None
337
+ dim_in = cin
338
+
339
+ self.in_layer = nn.Linear(dim_in, nf)
340
+ self.relu = nn.ReLU(inplace=True)
341
+
342
+ if texture_act == 'sin':
343
+ print('using siren network for texture mlp here')
344
+ self.mlp = SirenNet(
345
+ dim_in=(nf + extra_dim),
346
+ dim_hidden=nf,
347
+ dim_out=cout,
348
+ num_layers=num_layers,
349
+ final_activation=get_activation(activation),
350
+ w0_initial=30,
351
+ use_bias=linear_bias,
352
+ dropout=dropout
353
+ )
354
+ else:
355
+ self.mlp = MLP(nf + extra_dim, cout, num_layers, nf, dropout, activation, inner_act=texture_act, linear_bias=linear_bias)
356
+ self.perturb_normal = perturb_normal
357
+ self.symmetrize = symmetrize
358
+ if min_max is not None:
359
+ self.register_buffer('min_max', min_max)
360
+ else:
361
+ self.min_max = None
362
+ self.bsdf = None
363
+
364
+ def sample(self, x, feat=None):
365
+ assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] == self.extra_dim)
366
+ b, h, w, c = x.shape
367
+
368
+ if self.symmetrize:
369
+ xs, ys, zs = x.unbind(-1)
370
+ x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
371
+
372
+ x = x.view(-1, c)
373
+ if self.embedder is not None:
374
+ x_in = self.embedder(x)
375
+ if self.embed_concat_pts:
376
+ x_in = torch.cat([x, x_in], -1)
377
+ else:
378
+ x_in = x
379
+
380
+ x_in = self.in_layer(x_in)
381
+ if feat is not None:
382
+ feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1)
383
+ x_in = torch.concat([x_in, feat], dim=-1)
384
+ out = self.mlp(self.relu(x_in))
385
+ if self.min_max is not None:
386
+ out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :]
387
+ return out.view(b, h, w, -1)
388
+
389
+
390
+ class MLPTextureTriplane(nn.Module):
391
+ def __init__(self,
392
+ cin,
393
+ cout,
394
+ num_layers,
395
+ nf=256,
396
+ dropout=0,
397
+ activation=None,
398
+ min_max=None,
399
+ n_harmonic_functions=10,
400
+ omega0=1,
401
+ extra_dim=0,
402
+ embed_concat_pts=True,
403
+ perturb_normal=False,
404
+ symmetrize=False,
405
+ texture_act='relu',
406
+ linear_bias=False,
407
+ cam_pos_z_offset=10.,
408
+ grid_scale=7,):
409
+ super().__init__()
410
+ self.extra_dim = extra_dim
411
+
412
+ if n_harmonic_functions > 0:
413
+ self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0)
414
+ dim_in = cin * 2 * n_harmonic_functions
415
+ self.embed_concat_pts = embed_concat_pts
416
+ if embed_concat_pts:
417
+ dim_in += cin
418
+ else:
419
+ self.embedder = None
420
+ dim_in = cin
421
+
422
+ self.in_layer = nn.Linear(dim_in, nf)
423
+ self.relu = nn.ReLU(inplace=True)
424
+
425
+ self.feat_net = Triplane_Transformer(
426
+ emb_dim=256,
427
+ num_layers=8,
428
+ triplane_dim=80,
429
+ triplane_scale=grid_scale
430
+ )
431
+ self.extra_dim -= extra_dim
432
+ self.extra_dim += (self.feat_net.triplane_dim * 3)
433
+
434
+ if texture_act == 'sin':
435
+ print('using siren network for texture mlp here')
436
+ self.mlp = SirenNet(
437
+ dim_in=(nf + self.extra_dim),
438
+ dim_hidden=nf,
439
+ dim_out=cout,
440
+ num_layers=num_layers,
441
+ final_activation=get_activation(activation),
442
+ w0_initial=30,
443
+ use_bias=linear_bias,
444
+ dropout=dropout
445
+ )
446
+ else:
447
+ self.mlp = MLP(nf + self.extra_dim, cout, num_layers, nf, dropout, activation, inner_act=texture_act, linear_bias=linear_bias)
448
+ self.perturb_normal = perturb_normal
449
+ self.symmetrize = symmetrize
450
+ if min_max is not None:
451
+ self.register_buffer('min_max', min_max)
452
+ else:
453
+ self.min_max = None
454
+ self.bsdf = None
455
+
456
+ def sample(self, x, feat=None, feat_map=None, mvp=None, w2c=None, deform_xyz=None):
457
+ # assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] == self.extra_dim)
458
+ b, h, w, c = x.shape
459
+
460
+ if self.symmetrize:
461
+ xs, ys, zs = x.unbind(-1)
462
+ x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
463
+
464
+ if isinstance(feat_map, dict):
465
+ feat_map = feat_map["im_features_map"]
466
+
467
+ feat_map = feat_map.permute(0, 2, 3, 1)
468
+ _, ph, pw, _ = feat_map.shape
469
+ feat_map = feat_map.reshape(feat_map.shape[0], ph*pw, feat_map.shape[-1])
470
+ pts_feat = self.feat_net(feat_map, x.reshape(b, -1, 3))
471
+ pts_c = pts_feat.shape[-1]
472
+ pts_feat = pts_feat.reshape(-1, pts_c)
473
+
474
+ x = x.view(-1, c)
475
+ if self.embedder is not None:
476
+ x_in = self.embedder(x)
477
+ if self.embed_concat_pts:
478
+ x_in = torch.cat([x, x_in], -1)
479
+ else:
480
+ x_in = x
481
+
482
+ x_in = self.in_layer(x_in)
483
+
484
+ x_in = torch.concat([x_in, pts_feat], dim=-1)
485
+
486
+ out = self.mlp(self.relu(x_in))
487
+ if self.min_max is not None:
488
+ out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :]
489
+ return out.view(b, h, w, -1)
490
+
491
+
492
+ class LocalFeatureBlock(nn.Module):
493
+ def __init__(self, local_feat_dim, input_dim=384, output_dim=384, upscale_num=3):
494
+ super().__init__()
495
+ self.local_feat_dim = local_feat_dim
496
+ self.conv_list = nn.ModuleList([])
497
+ self.upscale_list = nn.ModuleList([])
498
+
499
+ for i in range(upscale_num):
500
+ if i == 0:
501
+ self.conv_list.append(nn.Conv2d(input_dim, 4 * local_feat_dim, 3, stride=1, padding=1, dilation=1))
502
+ else:
503
+ self.conv_list.append(nn.Conv2d(local_feat_dim, 4 * local_feat_dim, 3, stride=1, padding=1, dilation=1))
504
+ self.upscale_list.append(nn.PixelShuffle(2))
505
+
506
+ self.conv_head = nn.Conv2d(local_feat_dim, output_dim, 3, stride=1, padding=1, dilation=1)
507
+
508
+ def forward(self, x):
509
+ for idx, conv in enumerate(self.conv_list):
510
+ x = conv(x)
511
+ x = self.upscale_list[idx](x)
512
+
513
+ out = self.conv_head(x)
514
+ return out
515
+
516
+
517
+ class MLPTextureLocal(nn.Module):
518
+ def __init__(self,
519
+ cin,
520
+ cout,
521
+ num_layers,
522
+ nf=256,
523
+ dropout=0,
524
+ activation=None,
525
+ min_max=None,
526
+ n_harmonic_functions=10,
527
+ omega0=1,
528
+ extra_dim=0,
529
+ embed_concat_pts=True,
530
+ perturb_normal=False,
531
+ symmetrize=False,
532
+ texture_way=None,
533
+ larger_tex_dim=False,
534
+ cam_pos_z_offset=10.,
535
+ grid_scale=7.):
536
+ super().__init__()
537
+ self.extra_dim = extra_dim
538
+ self.cam_pos_z_offset = cam_pos_z_offset
539
+ self.grid_scale = grid_scale
540
+
541
+ local_feat_dim = 64
542
+
543
+ assert texture_way is not None
544
+ self.texture_way = texture_way
545
+ if 'local' in texture_way and 'global' in texture_way:
546
+ # self.extra_dim = extra_dim + local_feat_dim
547
+ self.extra_dim = extra_dim
548
+ elif 'local' in texture_way and 'global' not in texture_way:
549
+ # self.extra_dim = local_feat_dim
550
+ self.extra_dim = extra_dim
551
+ elif 'local' not in texture_way and 'global' in texture_way:
552
+ self.extra_dim = extra_dim
553
+
554
+ if n_harmonic_functions > 0:
555
+ self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0)
556
+ dim_in = cin * 2 * n_harmonic_functions
557
+ self.embed_concat_pts = embed_concat_pts
558
+ if embed_concat_pts:
559
+ dim_in += cin
560
+ else:
561
+ self.embedder = None
562
+ dim_in = cin
563
+
564
+ # self.local_feature_block = LocalFeatureBlock(local_feat_dim=local_feat_dim, input_dim=384, output_dim=256)
565
+ self.local_feature_block = nn.Linear(384, nf, bias=False)
566
+
567
+ self.in_layer = nn.Linear(dim_in, nf)
568
+ self.relu = nn.ReLU(inplace=True)
569
+ self.mlp = MLP(nf + self.extra_dim, cout, num_layers, nf, dropout, activation)
570
+ self.perturb_normal = perturb_normal
571
+ self.symmetrize = symmetrize
572
+ if min_max is not None:
573
+ self.register_buffer('min_max', min_max)
574
+ else:
575
+ self.min_max = None
576
+ self.bsdf = None
577
+
578
+ def get_uv_depth(self, xyz, mvp):
579
+ # xyz: [b, k, 3]
580
+ # mvp: [b, 4, 4]
581
+ cam4 = torch.matmul(torch.nn.functional.pad(xyz, pad=(0,1), mode='constant', value=1.0), torch.transpose(mvp, 1, 2))
582
+ cam3 = cam4[..., :3] / cam4[..., 3:4]
583
+ cam_uv = cam3[..., :2]
584
+ # cam_uv = cam_uv.detach()
585
+ cam_depth = cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(xyz.device).view(1, 1, 3)
586
+ cam_depth = cam_depth / self.grid_scale * 2
587
+ cam_depth = cam_depth[..., 2:3]
588
+ # cam_depth = cam_depth.detach()
589
+ return cam_uv, cam_depth
590
+
591
+ def proj_sample_deform(self, xyz, feat_map, mvp, w2c, img_h, img_w):
592
+ # here the xyz is deformed points
593
+ # and we don't cast any symmtery here
594
+ b, k, c = xyz.shape
595
+ THRESHOLD = 1e-4
596
+ if isinstance(feat_map, torch.Tensor):
597
+ coordinates = xyz
598
+ # use pre-symmetry points to get feature and record depth
599
+ cam_uv, cam_depth = self.get_uv_depth(coordinates, mvp)
600
+ cam_uv = cam_uv.detach()
601
+ cam_depth = cam_depth.detach()
602
+
603
+ # get local feature
604
+ feature = F.grid_sample(feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c]
605
+
606
+ self.input_depth = cam_depth.reshape(b, 256, 256, 1) # [B, 256, 256, 1]
607
+ self.input_pts = coordinates.detach()
608
+
609
+ elif isinstance(feat_map, dict):
610
+ original_mvp = feat_map['original_mvp']
611
+ local_feat_map = feat_map['im_features_map']
612
+ original_depth = self.input_depth[0:b]
613
+
614
+ coordinates = xyz
615
+ cam_uv, cam_depth = self.get_uv_depth(coordinates, original_mvp)
616
+ cam_uv = cam_uv.detach()
617
+ cam_depth = cam_depth.detach()
618
+
619
+ project_feature = F.grid_sample(local_feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c]
620
+ project_depth = F.grid_sample(original_depth.permute(0, 3, 1, 2), cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1]
621
+
622
+ use_mask = cam_depth <= project_depth + THRESHOLD
623
+ feature = project_feature * use_mask.repeat(1, 1, project_feature.shape[-1])
624
+
625
+ ret_feature = self.local_feature_block(feature.reshape(b*k, -1)) # the linear is without bias, so 0 value feature will still get 0 value
626
+ return ret_feature
627
+
628
+ def proj_sample(self, xyz, feat_map, mvp, w2c, img_h, img_w, xyz_before_sym=None):
629
+ # the new one with no input feature map upsampling
630
+ # feat_map: [B, C, H, W]
631
+ b, k, c = xyz.shape
632
+ if isinstance(feat_map, torch.Tensor):
633
+ if xyz_before_sym is None:
634
+ coordinates = xyz
635
+ else:
636
+ coordinates = xyz_before_sym
637
+ # use pre-symmetry points to get feature and record depth
638
+ cam_uv, cam_depth = self.get_uv_depth(coordinates, mvp)
639
+ cam_uv = cam_uv.detach()
640
+ cam_depth = cam_depth.detach()
641
+
642
+ # get local feature
643
+ feature = F.grid_sample(feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c]
644
+
645
+ self.input_depth = cam_depth.reshape(b, 256, 256, 1) # [B, 256, 256, 1]
646
+ self.input_pts = coordinates.detach()
647
+
648
+ elif isinstance(feat_map, dict):
649
+ original_mvp = feat_map['original_mvp']
650
+ local_feat_map = feat_map['im_features_map']
651
+ THRESHOLD = 1e-4
652
+ original_depth = self.input_depth[0:b]
653
+ # if b == 1:
654
+ # from pdb import set_trace; set_trace()
655
+ # tmp_mask = xyz[0].reshape(256, 256, 3).sum(dim=-1) != 0
656
+ # tmp_mask = tmp_mask.cpu().numpy()
657
+ # tmp_mask = tmp_mask * 255
658
+ # src_dp = self.input_depth[0,:,:,0].cpu().numpy()
659
+ # input_pts = self.input_pts[0].cpu().numpy()
660
+ # input_mask = self.input_pts[0].reshape(256, 256, 3).sum(dim=-1) != 0
661
+ # input_mask = input_mask.int().cpu().numpy()
662
+ # input_mask = input_mask * 255
663
+ # np.save('./tmp_save/src_dp.npy', src_dp)
664
+ # np.save('./tmp_save/input_pts.npy', input_pts)
665
+ # import cv2
666
+ # cv2.imwrite('./tmp_save/input_mask.png', input_mask)
667
+ # cv2.imwrite('./tmp_save/mask.png', tmp_mask)
668
+ # test_pts_pos = xyz[0].cpu().numpy()
669
+ # np.save('./tmp_save/test_pts_pos.npy', test_pts_pos)
670
+ # test_pts_raw = xyz_before_sym[0].cpu().numpy()
671
+ # np.save('./tmp_save/test_pts_raw.npy', test_pts_raw)
672
+ # mvp_now = mvp[0].detach().cpu().numpy()
673
+ # mvp_original = original_mvp[0].detach().cpu().numpy()
674
+ # np.save('./tmp_save/mvp_now.npy', mvp_now)
675
+ # np.save('./tmp_save/mvp_original.npy', mvp_original)
676
+ if xyz_before_sym is None:
677
+ # just check the project depth of xyz
678
+ coordinates = xyz
679
+ cam_uv, cam_depth = self.get_uv_depth(coordinates, original_mvp)
680
+ cam_uv = cam_uv.detach()
681
+ cam_depth = cam_depth.detach()
682
+
683
+ project_feature = F.grid_sample(local_feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c]
684
+ project_depth = F.grid_sample(original_depth.permute(0, 3, 1, 2), cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1]
685
+
686
+ use_mask = cam_depth <= project_depth + THRESHOLD
687
+ feature = project_feature * use_mask.repeat(1, 1, project_feature.shape[-1])
688
+ else:
689
+ # need to double check, but now we are still use symmetry! Even if the two points are all visible in input view
690
+ coords_inp = xyz
691
+ x_check, y_check, z_check = xyz.unbind(-1)
692
+ xyz_check = torch.stack([-1 * x_check, y_check, z_check], -1)
693
+ coords_rev = xyz_check # we directly use neg-x to get the points of another side
694
+
695
+ uv_inp, dp_inp = self.get_uv_depth(coords_inp, original_mvp)
696
+ uv_rev, dp_rev = self.get_uv_depth(coords_rev, original_mvp)
697
+ uv_inp = uv_inp.detach()
698
+ uv_rev = uv_rev.detach()
699
+ dp_inp = dp_inp.detach()
700
+ dp_rev = dp_rev.detach()
701
+
702
+ proj_feat_inp = F.grid_sample(local_feat_map, uv_inp.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c]
703
+ proj_feat_rev = F.grid_sample(local_feat_map, uv_rev.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c]
704
+
705
+ proj_dp_inp = F.grid_sample(original_depth.permute(0, 3, 1, 2), uv_inp.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1]
706
+ proj_dp_rev = F.grid_sample(original_depth.permute(0, 3, 1, 2), uv_rev.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1]
707
+
708
+ use_mask_inp = dp_inp <= proj_dp_inp + THRESHOLD
709
+ use_mask_rev = dp_rev <= proj_dp_rev + THRESHOLD
710
+
711
+ # for those points we can see in two sides, we use average
712
+ use_mask_inp = use_mask_inp.int()
713
+ use_mask_rev = use_mask_rev.int()
714
+ both_vis = (use_mask_inp == 1) & (use_mask_rev == 1)
715
+ use_mask_inp[both_vis] = 0.5
716
+ use_mask_rev[both_vis] = 0.5
717
+
718
+ feature = proj_feat_inp * use_mask_inp.repeat(1, 1, proj_feat_inp.shape[-1]) + proj_feat_rev * use_mask_rev.repeat(1, 1, proj_feat_rev.shape[-1])
719
+ else:
720
+ raise NotImplementedError
721
+
722
+ ret_feature = self.local_feature_block(feature.reshape(b*k, -1)) # the linear is without bias, so 0 value feature will still get 0 value
723
+ return ret_feature
724
+
725
+ def sample(self, x, feat=None, feat_map=None, mvp=None, w2c=None, deform_xyz=None):
726
+ # assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] <= self.extra_dim)
727
+ b, h, w, c = x.shape
728
+
729
+ xyz_before_sym = None
730
+ if self.symmetrize:
731
+ xyz_before_sym = x.reshape(b, -1, c)
732
+ xs, ys, zs = x.unbind(-1)
733
+ x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
734
+
735
+ mvp = mvp.detach() # [b, 4, 4]
736
+ w2c = w2c.detach() # [b, 4, 4]
737
+
738
+ pts_xyz = x.reshape(b, -1, c)
739
+ deform_xyz = deform_xyz.reshape(b, -1, c)
740
+
741
+ if 'global' in self.texture_way and 'local' in self.texture_way:
742
+ global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1)
743
+ # local_feat = self.proj_sample(pts_xyz, feat_map, mvp, w2c, h, w, xyz_before_sym=xyz_before_sym)
744
+ local_feat = self.proj_sample_deform(deform_xyz, feat_map, mvp, w2c, h, w)
745
+ # feature_rep = torch.concat([global_feat, local_feat], dim=-1)
746
+ feature_rep = global_feat + local_feat
747
+ elif 'global' not in self.texture_way and 'local' in self.texture_way:
748
+ # local_feat = self.proj_sample(pts_xyz, feat_map, mvp, w2c, h, w, xyz_before_sym=xyz_before_sym)
749
+ local_feat = self.proj_sample_deform(deform_xyz, feat_map, mvp, w2c, h, w)
750
+ feature_rep = local_feat
751
+ elif 'global' in self.texture_way and 'local' not in self.texture_way:
752
+ global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1)
753
+ feature_rep = global_feat
754
+ else:
755
+ global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1)
756
+ feature_rep = global_feat
757
+
758
+ x = x.view(-1, c)
759
+
760
+ if self.embedder is not None:
761
+ x_in = self.embedder(x)
762
+ if self.embed_concat_pts:
763
+ x_in = torch.cat([x, x_in], -1)
764
+ else:
765
+ x_in = x
766
+
767
+ x_in = self.in_layer(x_in)
768
+
769
+ # if feat is not None:
770
+ # feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1)
771
+ # x_in = torch.concat([x_in, feat], dim=-1)
772
+
773
+ x_in = torch.concat([x_in, feature_rep], dim=-1)
774
+
775
+ out = self.mlp(self.relu(x_in))
776
+ if self.min_max is not None:
777
+ out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :]
778
+ return out.view(b, h, w, -1)
779
+
780
+
781
+ class LiftTexture(nn.Module):
782
+ def __init__(self,
783
+ cin,
784
+ cout,
785
+ num_layers,
786
+ nf=256,
787
+ dropout=0,
788
+ activation=None,
789
+ min_max=None,
790
+ n_harmonic_functions=10,
791
+ omega0=1,
792
+ extra_dim=0,
793
+ embed_concat_pts=True,
794
+ perturb_normal=False,
795
+ symmetrize=False,
796
+ texture_way=None,
797
+ cam_pos_z_offset=10.,
798
+ grid_scale=7.,
799
+ local_feat_dim=128,
800
+ grid_size=32,
801
+ optim_latent=False):
802
+ super().__init__()
803
+ self.extra_dim = extra_dim
804
+ self.cam_pos_z_offset = cam_pos_z_offset
805
+ self.grid_scale = grid_scale
806
+
807
+ assert texture_way is not None
808
+ self.extra_dim = local_feat_dim + extra_dim
809
+
810
+ if n_harmonic_functions > 0:
811
+ self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0)
812
+ dim_in = cin * 2 * n_harmonic_functions
813
+ self.embed_concat_pts = embed_concat_pts
814
+ if embed_concat_pts:
815
+ dim_in += cin
816
+ else:
817
+ self.embedder = None
818
+ dim_in = cin
819
+
820
+ self.encoder = Lift_Encoder(
821
+ cin=384,
822
+ feat_dim=local_feat_dim,
823
+ grid_scale=grid_scale / 2, # the dmtet is initialized in (-0.5, 0.5)
824
+ grid_size=grid_size,
825
+ optim_latent=optim_latent,
826
+ with_z_feature=True,
827
+ cam_pos_z_offset=cam_pos_z_offset
828
+ )
829
+
830
+
831
+ self.in_layer = nn.Linear(dim_in, nf)
832
+ self.relu = nn.ReLU(inplace=True)
833
+ self.mlp = MLP(nf + self.extra_dim, cout, num_layers, nf, dropout, activation)
834
+ self.perturb_normal = perturb_normal
835
+ self.symmetrize = symmetrize
836
+ if min_max is not None:
837
+ self.register_buffer('min_max', min_max)
838
+ else:
839
+ self.min_max = None
840
+ self.bsdf = None
841
+
842
+ def get_uv_depth(self, xyz, mvp):
843
+ # xyz: [b, k, 3]
844
+ # mvp: [b, 4, 4]
845
+ cam4 = torch.matmul(torch.nn.functional.pad(xyz, pad=(0,1), mode='constant', value=1.0), torch.transpose(mvp, 1, 2))
846
+ cam3 = cam4[..., :3] / cam4[..., 3:4]
847
+ cam_uv = cam3[..., :2]
848
+ # cam_uv = cam_uv.detach()
849
+ cam_depth = cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(xyz.device).view(1, 1, 3)
850
+ cam_depth = cam_depth / self.grid_scale * 2
851
+ cam_depth = cam_depth[..., 2:3]
852
+ # cam_depth = cam_depth.detach()
853
+ return cam_uv, cam_depth
854
+
855
+ def proj_sample_deform(self, xyz, feat_map, mvp, w2c, img_h, img_w):
856
+ # here the xyz is deformed points
857
+ # and we don't cast any symmtery here
858
+ if isinstance(feat_map, torch.Tensor):
859
+ feature = self.encoder(feat_map, mvp, xyz, inference="unproject")
860
+
861
+ elif isinstance(feat_map, dict):
862
+ feature = self.encoder(feat_map['im_features_map'], mvp, xyz, inference="sample")
863
+ C = feature.shape[-1]
864
+ feature = feature.reshape(-1, C)
865
+ return feature
866
+
867
+ def sample(self, x, feat=None, feat_map=None, mvp=None, w2c=None, deform_xyz=None):
868
+ # assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] <= self.extra_dim)
869
+ b, h, w, c = x.shape
870
+
871
+ xyz_before_sym = None
872
+ if self.symmetrize:
873
+ xyz_before_sym = x.reshape(b, -1, c)
874
+ xs, ys, zs = x.unbind(-1)
875
+ x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
876
+
877
+ mvp = mvp.detach() # [b, 4, 4]
878
+ w2c = w2c.detach() # [b, 4, 4]
879
+
880
+ pts_xyz = x.reshape(b, -1, c)
881
+ deform_xyz = deform_xyz.reshape(b, -1, c)
882
+
883
+ global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1)
884
+ local_feat = self.proj_sample_deform(deform_xyz, feat_map, mvp, w2c, h, w)
885
+ feature_rep = torch.concat([global_feat, local_feat], dim=-1)
886
+ x = x.view(-1, c)
887
+
888
+ if self.embedder is not None:
889
+ x_in = self.embedder(x)
890
+ if self.embed_concat_pts:
891
+ x_in = torch.cat([x, x_in], -1)
892
+ else:
893
+ x_in = x
894
+
895
+ x_in = self.in_layer(x_in)
896
+
897
+ # if feat is not None:
898
+ # feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1)
899
+ # x_in = torch.concat([x_in, feat], dim=-1)
900
+
901
+ x_in = torch.concat([x_in, feature_rep], dim=-1)
902
+
903
+ out = self.mlp(self.relu(x_in))
904
+ if self.min_max is not None:
905
+ out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :]
906
+ return out.view(b, h, w, -1)
907
+
908
+
909
+ class HarmonicEmbedding(nn.Module):
910
+ def __init__(self, n_harmonic_functions=10, omega0=1):
911
+ """
912
+ Positional Embedding implementation (adapted from Pytorch3D).
913
+ Given an input tensor `x` of shape [minibatch, ... , dim],
914
+ the harmonic embedding layer converts each feature
915
+ in `x` into a series of harmonic features `embedding`
916
+ as follows:
917
+ embedding[..., i*dim:(i+1)*dim] = [
918
+ sin(x[..., i]),
919
+ sin(2*x[..., i]),
920
+ sin(4*x[..., i]),
921
+ ...
922
+ sin(2**self.n_harmonic_functions * x[..., i]),
923
+ cos(x[..., i]),
924
+ cos(2*x[..., i]),
925
+ cos(4*x[..., i]),
926
+ ...
927
+ cos(2**self.n_harmonic_functions * x[..., i])
928
+ ]
929
+ Note that `x` is also premultiplied by `omega0` before
930
+ evaluting the harmonic functions.
931
+ """
932
+ super().__init__()
933
+ self.frequencies = omega0 * (2.0 ** torch.arange(n_harmonic_functions))
934
+
935
+ def forward(self, x):
936
+ """
937
+ Args:
938
+ x: tensor of shape [..., dim]
939
+ Returns:
940
+ embedding: a harmonic embedding of `x`
941
+ of shape [..., n_harmonic_functions * dim * 2]
942
+ """
943
+ embed = (x[..., None] * self.frequencies.to(x.device)).view(*x.shape[:-1], -1)
944
+ return torch.cat((embed.sin(), embed.cos()), dim=-1)
945
+
946
+
947
+ class VGGEncoder(nn.Module):
948
+ def __init__(self, cout, pretrained=False):
949
+ super().__init__()
950
+ if pretrained:
951
+ raise NotImplementedError
952
+ vgg = models.vgg16()
953
+ self.vgg_encoder = nn.Sequential(vgg.features, vgg.avgpool)
954
+ self.linear1 = nn.Linear(25088, 4096)
955
+ self.linear2 = nn.Linear(4096, cout)
956
+ self.relu = nn.ReLU(inplace=True)
957
+
958
+ def forward(self, x):
959
+ batch_size, _, _, _ = x.shape
960
+ out = self.relu(self.linear1(self.vgg_encoder(x).view(batch_size, -1)))
961
+ return self.linear2(out)
962
+
963
+
964
+ class ResnetEncoder(nn.Module):
965
+ def __init__(self, cout, pretrained=False):
966
+ super().__init__()
967
+ self.resnet = nn.Sequential(list(models.resnet18(weights="DEFAULT" if pretrained else None).modules())[:-1])
968
+ self.final_linear = nn.Linear(512, cout)
969
+
970
+ def forward(self, x):
971
+ return self.final_linear(self.resnet(x))
972
+
973
+
974
+ class Encoder(nn.Module):
975
+ def __init__(self, cin, cout, in_size=128, zdim=None, nf=64, activation=None):
976
+ super().__init__()
977
+ network = [
978
+ nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64
979
+ nn.GroupNorm(16, nf),
980
+ # nn.ReLU(inplace=True),
981
+ nn.LeakyReLU(0.2, inplace=True),
982
+ nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32
983
+ nn.GroupNorm(16*2, nf*2),
984
+ # nn.ReLU(inplace=True),
985
+ nn.LeakyReLU(0.2, inplace=True),
986
+ nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16
987
+ nn.GroupNorm(16*4, nf*4),
988
+ # nn.ReLU(inplace=True),
989
+ nn.LeakyReLU(0.2, inplace=True),
990
+ nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
991
+ # nn.GroupNorm(16*8, nf*8),
992
+ # nn.ReLU(inplace=True),
993
+ nn.LeakyReLU(0.2, inplace=True),
994
+ ]
995
+
996
+ add_downsample = int(np.log2(in_size//128))
997
+ if add_downsample > 0:
998
+ for _ in range(add_downsample):
999
+ network += [
1000
+ nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
1001
+ # nn.GroupNorm(16*8, nf*8),
1002
+ # nn.ReLU(inplace=True),
1003
+ nn.LeakyReLU(0.2, inplace=True),
1004
+ ]
1005
+
1006
+ network += [
1007
+ nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4
1008
+ nn.LeakyReLU(0.2, inplace=True),
1009
+ ]
1010
+
1011
+ if zdim is None:
1012
+ network += [
1013
+ nn.Conv2d(nf*8, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
1014
+ ]
1015
+ else:
1016
+ network += [
1017
+ nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
1018
+ # nn.ReLU(inplace=True),
1019
+ nn.LeakyReLU(0.2, inplace=True),
1020
+ nn.Conv2d(zdim, cout, kernel_size=1, stride=1, padding=0, bias=False),
1021
+ ]
1022
+
1023
+ if activation is not None:
1024
+ network += [get_activation(activation)]
1025
+ self.network = nn.Sequential(*network)
1026
+
1027
+ def forward(self, input):
1028
+ return self.network(input).reshape(input.size(0), -1)
1029
+
1030
+
1031
+ class EncoderWithDINO(nn.Module):
1032
+ def __init__(self, cin_rgb, cin_dino, cout, in_size=128, zdim=None, nf=64, activation=None):
1033
+ super().__init__()
1034
+ network_rgb_in = [
1035
+ nn.Conv2d(cin_rgb, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64
1036
+ nn.GroupNorm(16, nf),
1037
+ # nn.ReLU(inplace=True),
1038
+ nn.LeakyReLU(0.2, inplace=True),
1039
+ nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32
1040
+ nn.GroupNorm(16*2, nf*2),
1041
+ # nn.ReLU(inplace=True),
1042
+ nn.LeakyReLU(0.2, inplace=True),
1043
+ nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16
1044
+ nn.GroupNorm(16*4, nf*4),
1045
+ # nn.ReLU(inplace=True),
1046
+ nn.LeakyReLU(0.2, inplace=True),
1047
+ ]
1048
+ self.network_rgb_in = nn.Sequential(*network_rgb_in)
1049
+ network_dino_in = [
1050
+ nn.Conv2d(cin_dino, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64
1051
+ nn.GroupNorm(16, nf),
1052
+ # nn.ReLU(inplace=True),
1053
+ nn.LeakyReLU(0.2, inplace=True),
1054
+ nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32
1055
+ nn.GroupNorm(16*2, nf*2),
1056
+ # nn.ReLU(inplace=True),
1057
+ nn.LeakyReLU(0.2, inplace=True),
1058
+ nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16
1059
+ nn.GroupNorm(16*4, nf*4),
1060
+ # nn.ReLU(inplace=True),
1061
+ nn.LeakyReLU(0.2, inplace=True),
1062
+ ]
1063
+ self.network_dino_in = nn.Sequential(*network_dino_in)
1064
+
1065
+ network_fusion = [
1066
+ nn.Conv2d(nf*4*2, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
1067
+ # nn.GroupNorm(16*8, nf*8),
1068
+ # nn.ReLU(inplace=True),
1069
+ nn.LeakyReLU(0.2, inplace=True),
1070
+ ]
1071
+
1072
+ add_downsample = int(np.log2(in_size//128))
1073
+ if add_downsample > 0:
1074
+ for _ in range(add_downsample):
1075
+ network_fusion += [
1076
+ nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
1077
+ # nn.GroupNorm(16*8, nf*8),
1078
+ # nn.ReLU(inplace=True),
1079
+ nn.LeakyReLU(0.2, inplace=True),
1080
+ ]
1081
+
1082
+ network_fusion += [
1083
+ nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4
1084
+ nn.LeakyReLU(0.2, inplace=True),
1085
+ ]
1086
+
1087
+ if zdim is None:
1088
+ network_fusion += [
1089
+ nn.Conv2d(nf*8, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
1090
+ ]
1091
+ else:
1092
+ network_fusion += [
1093
+ nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
1094
+ # nn.ReLU(inplace=True),
1095
+ nn.LeakyReLU(0.2, inplace=True),
1096
+ nn.Conv2d(zdim, cout, kernel_size=1, stride=1, padding=0, bias=False),
1097
+ ]
1098
+
1099
+ if activation is not None:
1100
+ network_fusion += [get_activation(activation)]
1101
+ self.network_fusion = nn.Sequential(*network_fusion)
1102
+
1103
+ def forward(self, rgb_image, dino_image):
1104
+ rgb_feat = self.network_rgb_in(rgb_image)
1105
+ dino_feat = self.network_dino_in(dino_image)
1106
+ out = self.network_fusion(torch.cat([rgb_feat, dino_feat], dim=1))
1107
+ return out.reshape(rgb_image.size(0), -1)
1108
+
1109
+
1110
+ class Encoder32(nn.Module):
1111
+ def __init__(self, cin, cout, nf=256, activation=None):
1112
+ super().__init__()
1113
+ network = [
1114
+ nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16
1115
+ nn.GroupNorm(nf//4, nf),
1116
+ nn.LeakyReLU(0.2, inplace=True),
1117
+ nn.Conv2d(nf, nf, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
1118
+ nn.GroupNorm(nf//4, nf),
1119
+ nn.LeakyReLU(0.2, inplace=True),
1120
+ nn.Conv2d(nf, nf, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4
1121
+ nn.GroupNorm(nf//4, nf),
1122
+ nn.LeakyReLU(0.2, inplace=True),
1123
+ nn.Conv2d(nf, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
1124
+ ]
1125
+ if activation is not None:
1126
+ network += [get_activation(activation)]
1127
+ self.network = nn.Sequential(*network)
1128
+
1129
+ def forward(self, input):
1130
+ return self.network(input).reshape(input.size(0), -1)
1131
+
1132
+
1133
+ class MLP(nn.Module):
1134
+ def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None, inner_act='relu', linear_bias=False):
1135
+ super().__init__()
1136
+ assert num_layers >= 1
1137
+ layer_act = get_activation(inner_act)
1138
+ if num_layers == 1:
1139
+ network = [nn.Linear(cin, cout, bias=linear_bias)]
1140
+ else:
1141
+ # network = [nn.Linear(cin, nf, bias=False)]
1142
+ # for _ in range(num_layers-2):
1143
+ # network += [
1144
+ # nn.ReLU(inplace=True),
1145
+ # nn.Linear(nf, nf, bias=False)]
1146
+ # if dropout:
1147
+ # network += [nn.Dropout(dropout)]
1148
+ # network += [
1149
+ # nn.ReLU(inplace=True),
1150
+ # nn.Linear(nf, cout, bias=False)]
1151
+ network = [nn.Linear(cin, nf, bias=linear_bias)]
1152
+ for _ in range(num_layers-2):
1153
+ network += [
1154
+ layer_act,
1155
+ nn.Linear(nf, nf, bias=linear_bias)]
1156
+ if dropout:
1157
+ network += [nn.Dropout(dropout)]
1158
+ network += [
1159
+ layer_act,
1160
+ nn.Linear(nf, cout, bias=linear_bias)]
1161
+ if activation is not None:
1162
+ network += [get_activation(activation)]
1163
+ self.network = nn.Sequential(*network)
1164
+
1165
+ def forward(self, input):
1166
+ return self.network(input)
1167
+
1168
+
1169
+ class Embedding(nn.Module):
1170
+ def __init__(self, cin, cout, zdim=128, nf=64, activation=None):
1171
+ super().__init__()
1172
+ network = [
1173
+ nn.Linear(cin, nf, bias=False),
1174
+ nn.ReLU(inplace=True),
1175
+ nn.Linear(nf, zdim, bias=False),
1176
+ nn.ReLU(inplace=True),
1177
+ nn.Linear(zdim, cout, bias=False)]
1178
+ if activation is not None:
1179
+ network += [get_activation(activation)]
1180
+ self.network = nn.Sequential(*network)
1181
+
1182
+ def forward(self, input):
1183
+ return self.network(input.reshape(input.size(0), -1)).reshape(input.size(0), -1)
1184
+
1185
+
1186
+ class PerceptualLoss(nn.Module):
1187
+ def __init__(self, requires_grad=False):
1188
+ super(PerceptualLoss, self).__init__()
1189
+ mean_rgb = torch.FloatTensor([0.485, 0.456, 0.406])
1190
+ std_rgb = torch.FloatTensor([0.229, 0.224, 0.225])
1191
+ self.register_buffer('mean_rgb', mean_rgb)
1192
+ self.register_buffer('std_rgb', std_rgb)
1193
+
1194
+ vgg_pretrained_features = torchvision.models.vgg16(pretrained=True).features
1195
+ self.slice1 = nn.Sequential()
1196
+ self.slice2 = nn.Sequential()
1197
+ self.slice3 = nn.Sequential()
1198
+ self.slice4 = nn.Sequential()
1199
+ for x in range(4):
1200
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
1201
+ for x in range(4, 9):
1202
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
1203
+ for x in range(9, 16):
1204
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
1205
+ for x in range(16, 23):
1206
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
1207
+ if not requires_grad:
1208
+ for param in self.parameters():
1209
+ param.requires_grad = False
1210
+
1211
+ def normalize(self, x):
1212
+ out = x/2 + 0.5
1213
+ out = (out - self.mean_rgb.view(1,3,1,1)) / self.std_rgb.view(1,3,1,1)
1214
+ return out
1215
+
1216
+ def __call__(self, im1, im2, mask=None, conf_sigma=None):
1217
+ im = torch.cat([im1,im2], 0)
1218
+ im = self.normalize(im) # normalize input
1219
+
1220
+ ## compute features
1221
+ feats = []
1222
+ f = self.slice1(im)
1223
+ feats += [torch.chunk(f, 2, dim=0)]
1224
+ f = self.slice2(f)
1225
+ feats += [torch.chunk(f, 2, dim=0)]
1226
+ f = self.slice3(f)
1227
+ feats += [torch.chunk(f, 2, dim=0)]
1228
+ f = self.slice4(f)
1229
+ feats += [torch.chunk(f, 2, dim=0)]
1230
+
1231
+ losses = []
1232
+ for f1, f2 in feats[2:3]: # use relu3_3 features only
1233
+ loss = (f1-f2)**2
1234
+ if conf_sigma is not None:
1235
+ loss = loss / (2*conf_sigma**2 +EPS) + (conf_sigma +EPS).log()
1236
+ if mask is not None:
1237
+ b, c, h, w = loss.shape
1238
+ _, _, hm, wm = mask.shape
1239
+ sh, sw = hm//h, wm//w
1240
+ mask0 = nn.functional.avg_pool2d(mask, kernel_size=(sh,sw), stride=(sh,sw)).expand_as(loss)
1241
+ loss = (loss * mask0).sum() / mask0.sum()
1242
+ else:
1243
+ loss = loss.mean()
1244
+ losses += [loss]
1245
+ return sum(losses)
1246
+
1247
+
1248
+ ## from: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
1249
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
1250
+ """3x3 convolution with padding"""
1251
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
1252
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
1253
+
1254
+
1255
+ def conv1x1(in_planes, out_planes, stride=1):
1256
+ """1x1 convolution"""
1257
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
1258
+
1259
+
1260
+ class BasicBlock(nn.Module):
1261
+ expansion = 1
1262
+
1263
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
1264
+ base_width=64, dilation=1, norm_layer=None):
1265
+ super(BasicBlock, self).__init__()
1266
+ if groups != 1 or base_width != 64:
1267
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
1268
+ if dilation > 1:
1269
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
1270
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
1271
+ self.conv1 = conv3x3(inplanes, planes, stride)
1272
+ self.relu = nn.ReLU(inplace=True)
1273
+ self.conv2 = conv3x3(planes, planes)
1274
+
1275
+ self.norm_layer = norm_layer
1276
+ if norm_layer is not None:
1277
+ self.bn1 = norm_layer(planes)
1278
+ self.bn2 = norm_layer(planes)
1279
+
1280
+ if inplanes != planes:
1281
+ self.downsample = nn.Sequential(
1282
+ conv1x1(inplanes, planes, stride),
1283
+ norm_layer(planes),
1284
+ )
1285
+ else:
1286
+ self.downsample = None
1287
+ self.stride = stride
1288
+
1289
+ def forward(self, x):
1290
+ identity = x
1291
+
1292
+ out = self.conv1(x)
1293
+ if self.norm_layer is not None:
1294
+ out = self.bn1(out)
1295
+ out = self.relu(out)
1296
+
1297
+ out = self.conv2(out)
1298
+ if self.norm_layer is not None:
1299
+ out = self.bn2(out)
1300
+
1301
+ if self.downsample is not None:
1302
+ identity = self.downsample(x)
1303
+
1304
+ out += identity
1305
+ out = self.relu(out)
1306
+
1307
+ return out
1308
+
1309
+
1310
+ class ResEncoder(nn.Module):
1311
+ def __init__(self, cin, cout, in_size=128, zdim=None, nf=64, activation=None):
1312
+ super().__init__()
1313
+ network = [
1314
+ nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64
1315
+ # nn.GroupNorm(16, nf),
1316
+ # nn.ReLU(inplace=True),
1317
+ nn.LeakyReLU(0.2, inplace=True),
1318
+ nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32
1319
+ # nn.GroupNorm(16*2, nf*2),
1320
+ # nn.ReLU(inplace=True),
1321
+ nn.LeakyReLU(0.2, inplace=True),
1322
+ BasicBlock(nf*2, nf*2, norm_layer=None),
1323
+ BasicBlock(nf*2, nf*2, norm_layer=None),
1324
+ nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16
1325
+ # nn.GroupNorm(16*4, nf*4),
1326
+ # nn.ReLU(inplace=True),
1327
+ nn.LeakyReLU(0.2, inplace=True),
1328
+ BasicBlock(nf*4, nf*4, norm_layer=None),
1329
+ BasicBlock(nf*4, nf*4, norm_layer=None),
1330
+ nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
1331
+ # nn.ReLU(inplace=True),
1332
+ nn.LeakyReLU(0.2, inplace=True),
1333
+ BasicBlock(nf*8, nf*8, norm_layer=None),
1334
+ BasicBlock(nf*8, nf*8, norm_layer=None),
1335
+ ]
1336
+
1337
+ add_downsample = int(np.log2(in_size//64))
1338
+ if add_downsample > 0:
1339
+ for _ in range(add_downsample):
1340
+ network += [
1341
+ nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4
1342
+ # nn.ReLU(inplace=True),
1343
+ nn.LeakyReLU(0.2, inplace=True),
1344
+ BasicBlock(nf*8, nf*8, norm_layer=None),
1345
+ BasicBlock(nf*8, nf*8, norm_layer=None),
1346
+ ]
1347
+
1348
+ if zdim is None:
1349
+ network += [
1350
+ nn.Conv2d(nf*8, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
1351
+ ]
1352
+ else:
1353
+ network += [
1354
+ nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
1355
+ # nn.ReLU(inplace=True),
1356
+ nn.LeakyReLU(0.2, inplace=True),
1357
+ nn.Conv2d(zdim, cout, kernel_size=1, stride=1, padding=0, bias=False),
1358
+ ]
1359
+
1360
+ if activation is not None:
1361
+ network += [get_activation(activation)]
1362
+ self.network = nn.Sequential(*network)
1363
+
1364
+ def forward(self, input):
1365
+ return self.network(input).reshape(input.size(0), -1)
1366
+
1367
+
1368
+ class Attention(nn.Module):
1369
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
1370
+ super().__init__()
1371
+ self.num_heads = num_heads
1372
+ head_dim = dim // num_heads
1373
+ self.scale = qk_scale or head_dim ** -0.5
1374
+
1375
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
1376
+ self.attn_drop = nn.Dropout(attn_drop)
1377
+ self.proj = nn.Linear(dim, dim)
1378
+ self.proj_drop = nn.Dropout(proj_drop)
1379
+
1380
+ def forward(self, x):
1381
+ B, N, C = x.shape
1382
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
1383
+ q, k, v = qkv[0], qkv[1], qkv[2]
1384
+
1385
+ attn = (q @ k.transpose(-2, -1)) * self.scale
1386
+ attn = attn.softmax(dim=-1)
1387
+ attn = self.attn_drop(attn)
1388
+
1389
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
1390
+ x = self.proj(x)
1391
+ x = self.proj_drop(x)
1392
+ return x, attn
1393
+
1394
+
1395
+ class ViTEncoder(nn.Module):
1396
+ def __init__(self, cout, which_vit='dino_vits8', pretrained=False, frozen=False, in_size=256, final_layer_type='none', root='/root'):
1397
+ super().__init__()
1398
+ if misc.is_main_process():
1399
+ force_reload = not os.path.exists(os.path.join(root, ".cache/torch/hub/checkpoints/"))
1400
+ else:
1401
+ force_reload = False
1402
+ if "dinov2" in which_vit:
1403
+ self.ViT = torch.hub.load('facebookresearch/dinov2:main', which_vit, pretrained=pretrained, force_reload=force_reload)
1404
+ else:
1405
+ self.ViT = torch.hub.load('facebookresearch/dino:main', which_vit, pretrained=pretrained, force_reload=force_reload)
1406
+
1407
+ if frozen:
1408
+ for p in self.ViT.parameters():
1409
+ p.requires_grad = False
1410
+ if which_vit == 'dino_vits8':
1411
+ self.vit_feat_dim = 384
1412
+ self.patch_size = 8
1413
+ elif which_vit == 'dinov2_vits14':
1414
+ self.vit_feat_dim = 384
1415
+ self.patch_size = 14
1416
+ elif which_vit == 'dino_vitb8':
1417
+ self.vit_feat_dim = 768
1418
+ self.patch_size = 8
1419
+
1420
+ self._feats = []
1421
+ self.hook_handlers = []
1422
+
1423
+ if final_layer_type == 'none':
1424
+ pass
1425
+ elif final_layer_type == 'conv':
1426
+ self.final_layer_patch_out = Encoder32(self.vit_feat_dim, cout, nf=256, activation=None)
1427
+ self.final_layer_patch_key = Encoder32(self.vit_feat_dim, cout, nf=256, activation=None)
1428
+ elif final_layer_type == 'attention':
1429
+ raise NotImplementedError
1430
+ self.final_layer = Attention(
1431
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
1432
+ self.fc = nn.Linear(self.vit_feat_dim, cout)
1433
+ else:
1434
+ raise NotImplementedError
1435
+ self.final_layer_type = final_layer_type
1436
+
1437
+ def _get_hook(self, facet: str):
1438
+ """
1439
+ generate a hook method for a specific block and facet.
1440
+ """
1441
+ if facet in ['attn', 'token']:
1442
+ def _hook(model, input, output):
1443
+ self._feats.append(output)
1444
+ return _hook
1445
+
1446
+ if facet == 'query':
1447
+ facet_idx = 0
1448
+ elif facet == 'key':
1449
+ facet_idx = 1
1450
+ elif facet == 'value':
1451
+ facet_idx = 2
1452
+ else:
1453
+ raise TypeError(f"{facet} is not a supported facet.")
1454
+
1455
+ def _inner_hook(module, input, output):
1456
+ input = input[0]
1457
+ B, N, C = input.shape
1458
+ qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
1459
+ self._feats.append(qkv[facet_idx]) #Bxhxtxd
1460
+ return _inner_hook
1461
+
1462
+ def _register_hooks(self, layers: List[int], facet: str) -> None:
1463
+ """
1464
+ register hook to extract features.
1465
+ :param layers: layers from which to extract features.
1466
+ :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn']
1467
+ """
1468
+ for block_idx, block in enumerate(self.ViT.blocks):
1469
+ if block_idx in layers:
1470
+ if facet == 'token':
1471
+ self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet)))
1472
+ elif facet == 'attn':
1473
+ self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet)))
1474
+ elif facet in ['key', 'query', 'value']:
1475
+ self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet)))
1476
+ else:
1477
+ raise TypeError(f"{facet} is not a supported facet.")
1478
+
1479
+ def _unregister_hooks(self) -> None:
1480
+ """
1481
+ unregisters the hooks. should be called after feature extraction.
1482
+ """
1483
+ for handle in self.hook_handlers:
1484
+ handle.remove()
1485
+ self.hook_handlers = []
1486
+
1487
+ def forward(self, x, return_patches=False):
1488
+ b, c, h, w = x.shape
1489
+ self._feats = []
1490
+ self._register_hooks([11], 'key')
1491
+ #self._register_hooks([11], 'token')
1492
+ x = self.ViT.prepare_tokens(x)
1493
+ #x = self.ViT.prepare_tokens_with_masks(x)
1494
+
1495
+ for blk in self.ViT.blocks:
1496
+ x = blk(x)
1497
+ out = self.ViT.norm(x)
1498
+ self._unregister_hooks()
1499
+
1500
+ ph, pw = h // self.patch_size, w // self.patch_size
1501
+ patch_out = out[:, 1:] # first is class token
1502
+ patch_out = patch_out.reshape(b, ph, pw, self.vit_feat_dim).permute(0, 3, 1, 2)
1503
+
1504
+ patch_key = self._feats[0][:,:,1:] # B, num_heads, num_patches, dim
1505
+ patch_key = patch_key.permute(0, 1, 3, 2).reshape(b, self.vit_feat_dim, ph, pw)
1506
+
1507
+ if self.final_layer_type == 'none':
1508
+ global_feat_out = out[:, 0].reshape(b, -1) # first is class token
1509
+ global_feat_key = self._feats[0][:, :, 0].reshape(b, -1) # first is class token
1510
+ elif self.final_layer_type == 'conv':
1511
+ global_feat_out = self.final_layer_patch_out(patch_out).view(b, -1)
1512
+ global_feat_key = self.final_layer_patch_key(patch_key).view(b, -1)
1513
+ elif self.final_layer_type == 'attention':
1514
+ raise NotImplementedError
1515
+ else:
1516
+ raise NotImplementedError
1517
+ if not return_patches:
1518
+ patch_out = patch_key = None
1519
+ return global_feat_out, global_feat_key, patch_out, patch_key
1520
+
1521
+
1522
+ class ArticulationNetwork(nn.Module):
1523
+ def __init__(self, net_type, feat_dim, pos_dim, num_layers, nf, n_harmonic_functions=0, omega0=1, activation=None, enable_articulation_idadd=False):
1524
+ super().__init__()
1525
+ if n_harmonic_functions > 0:
1526
+ self.posenc = HarmonicEmbedding(n_harmonic_functions=n_harmonic_functions, omega0=omega0)
1527
+ pos_dim = pos_dim * (n_harmonic_functions * 2 + 1)
1528
+ else:
1529
+ self.posenc = None
1530
+ pos_dim = 4
1531
+ cout = 3
1532
+
1533
+ if net_type == 'mlp':
1534
+ self.network = MLP(
1535
+ feat_dim + pos_dim, # + bone xyz pos and index
1536
+ cout, # We represent the rotation of each bone by its Euler angles ψ, θ, and φ
1537
+ num_layers,
1538
+ nf=nf,
1539
+ dropout=0,
1540
+ activation=activation
1541
+ )
1542
+ elif net_type == 'attention':
1543
+ self.in_layer = nn.Sequential(
1544
+ nn.Linear(feat_dim + pos_dim, nf),
1545
+ nn.GELU(),
1546
+ nn.LayerNorm(nf),
1547
+ )
1548
+ self.blocks = nn.ModuleList([
1549
+ Block(
1550
+ dim=nf, num_heads=8, mlp_ratio=2., qkv_bias=False, qk_scale=None,
1551
+ drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm)
1552
+ for i in range(num_layers)])
1553
+ out_layer = [nn.Linear(nf, cout)]
1554
+ if activation:
1555
+ out_layer += [get_activation(activation)]
1556
+ self.out_layer = nn.Sequential(*out_layer)
1557
+ else:
1558
+ raise NotImplementedError
1559
+ self.net_type = net_type
1560
+ self.enable_articulation_idadd = enable_articulation_idadd
1561
+
1562
+ def forward(self, x, pos):
1563
+ pos_inp = pos
1564
+ if self.posenc is not None:
1565
+ pos = torch.cat([pos, self.posenc(pos)], dim=-1)
1566
+ x = torch.cat([x, pos], dim=-1)
1567
+ if self.enable_articulation_idadd:
1568
+ articulation_id = pos_inp[..., -1:]
1569
+ x = x + articulation_id
1570
+ if self.net_type == 'mlp':
1571
+ out = self.network(x)
1572
+ elif self.net_type == 'attention':
1573
+ x = self.in_layer(x)
1574
+ for blk in self.blocks:
1575
+ x = blk(x)
1576
+ out = self.out_layer(x)
1577
+ else:
1578
+ raise NotImplementedError
1579
+ return out
1580
+
1581
+
1582
+ ## Attention block from ViT (https://github.com/facebookresearch/dino/blob/main/vision_transformer.py)
1583
+ class Attention(nn.Module):
1584
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
1585
+ super().__init__()
1586
+ self.num_heads = num_heads
1587
+ head_dim = dim // num_heads
1588
+ self.scale = qk_scale or head_dim ** -0.5
1589
+
1590
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
1591
+ self.attn_drop = nn.Dropout(attn_drop)
1592
+ self.proj = nn.Linear(dim, dim)
1593
+ self.proj_drop = nn.Dropout(proj_drop)
1594
+
1595
+ def forward(self, x):
1596
+ B, N, C = x.shape
1597
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
1598
+ q, k, v = qkv[0], qkv[1], qkv[2]
1599
+
1600
+ attn = (q @ k.transpose(-2, -1)) * self.scale
1601
+ attn = attn.softmax(dim=-1)
1602
+ attn = self.attn_drop(attn)
1603
+
1604
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
1605
+ x = self.proj(x)
1606
+ x = self.proj_drop(x)
1607
+ return x, attn
1608
+
1609
+
1610
+ class Mlp(nn.Module):
1611
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
1612
+ super().__init__()
1613
+ out_features = out_features or in_features
1614
+ hidden_features = hidden_features or in_features
1615
+ self.fc1 = nn.Linear(in_features, hidden_features)
1616
+ self.act = act_layer()
1617
+ self.fc2 = nn.Linear(hidden_features, out_features)
1618
+ self.drop = nn.Dropout(drop)
1619
+
1620
+ def forward(self, x):
1621
+ x = self.fc1(x)
1622
+ x = self.act(x)
1623
+ x = self.drop(x)
1624
+ x = self.fc2(x)
1625
+ x = self.drop(x)
1626
+ return x
1627
+
1628
+
1629
+ class DropPath(nn.Module):
1630
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
1631
+ """
1632
+ def __init__(self, drop_prob=None):
1633
+ super(DropPath, self).__init__()
1634
+ self.drop_prob = drop_prob
1635
+
1636
+ def forward(self, x):
1637
+ return drop_path(x, self.drop_prob, self.training)
1638
+
1639
+
1640
+ class Block(nn.Module):
1641
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
1642
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
1643
+ super().__init__()
1644
+ self.norm1 = norm_layer(dim)
1645
+ self.attn = Attention(
1646
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
1647
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
1648
+ self.norm2 = norm_layer(dim)
1649
+ mlp_hidden_dim = int(dim * mlp_ratio)
1650
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
1651
+
1652
+ def forward(self, x, return_attention=False):
1653
+ y, attn = self.attn(self.norm1(x))
1654
+ if return_attention:
1655
+ return attn
1656
+ x = x + self.drop_path(y)
1657
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
1658
+ return x
1659
+
1660
+
1661
+ class FeatureAttention(nn.Module):
1662
+ def __init__(self, vit_type, pos_dim, embedder_freq=0, zdim=128, img_size=256, activation=None):
1663
+ super().__init__()
1664
+ self.zdim = zdim
1665
+ if embedder_freq > 0:
1666
+ self.posenc = HarmonicEmbedding(n_harmonic_functions=embedder_freq, omega0=1)
1667
+ pos_dim = pos_dim * (embedder_freq * 2 + 1)
1668
+ else:
1669
+ self.posenc = None
1670
+ self.pos_dim = pos_dim
1671
+
1672
+ if vit_type == 'dino_vits8':
1673
+ self.vit_feat_dim = 384
1674
+ patch_size = 8
1675
+ elif which_vit == 'dinov2_vits14':
1676
+ self.vit_feat_dim = 384
1677
+ self.patch_size = 14
1678
+ elif vit_type == 'dino_vitb8':
1679
+ self.vit_feat_dim = 768
1680
+ patch_size = 8
1681
+ else:
1682
+ raise NotImplementedError
1683
+ self.num_patches_per_dim = img_size // patch_size
1684
+
1685
+ self.kv = nn.Sequential(
1686
+ nn.Linear(self.vit_feat_dim, zdim),
1687
+ nn.ReLU(inplace=True),
1688
+ nn.LayerNorm(zdim),
1689
+ nn.Linear(zdim, zdim*2),
1690
+ )
1691
+
1692
+ self.q = nn.Sequential(
1693
+ nn.Linear(pos_dim, zdim),
1694
+ nn.ReLU(inplace=True),
1695
+ nn.LayerNorm(zdim),
1696
+ nn.Linear(zdim, zdim),
1697
+ )
1698
+
1699
+ final_mlp = [
1700
+ nn.Linear(zdim, zdim),
1701
+ nn.ReLU(inplace=True),
1702
+ nn.LayerNorm(zdim),
1703
+ nn.Linear(zdim, self.vit_feat_dim)
1704
+ ]
1705
+ if activation is not None:
1706
+ final_mlp += [get_activation(activation)]
1707
+ self.final_ln = nn.Sequential(*final_mlp)
1708
+
1709
+ def forward(self, x, feat):
1710
+ _, vit_feat_dim, ph, pw = feat.shape
1711
+ assert ph == pw and ph == self.num_patches_per_dim and vit_feat_dim == self.vit_feat_dim
1712
+
1713
+ if self.posenc is not None:
1714
+ x = torch.cat([x, self.posenc(x)], dim=-1)
1715
+ bxf, k, c = x.shape
1716
+ assert c == self.pos_dim
1717
+
1718
+ query = self.q(x)
1719
+ feat_in = feat.view(bxf, vit_feat_dim, ph*pw).permute(0, 2, 1) # N, K, C
1720
+ k, v = self.kv(feat_in).chunk(2, dim=-1)
1721
+ attn = torch.einsum('bnd,bpd->bnp', query, k).softmax(dim=-1)
1722
+ out = torch.einsum('bnp,bpd->bnd', attn, v)
1723
+ out = self.final_ln(out)
1724
+ return out
video3d/render/light.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ import os
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import nvdiffrast.torch as dr
15
+
16
+ from . import util
17
+ from . import renderutils as ru
18
+ from ..networks import MLP
19
+
20
+ ######################################################################################
21
+ # Utility functions
22
+ ######################################################################################
23
+
24
+ class cubemap_mip(torch.autograd.Function):
25
+ @staticmethod
26
+ def forward(ctx, cubemap):
27
+ return util.avg_pool_nhwc(cubemap, (2,2))
28
+
29
+ @staticmethod
30
+ def backward(ctx, dout):
31
+ res = dout.shape[1] * 2
32
+ out = torch.zeros(6, res, res, dout.shape[-1], dtype=torch.float32, device="cuda")
33
+ for s in range(6):
34
+ gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"),
35
+ torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"),
36
+ indexing='ij')
37
+ v = util.safe_normalize(util.cube_to_dir(s, gx, gy))
38
+ out[s, ...] = dr.texture(dout[None, ...] * 0.25, v[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')
39
+ return out
40
+
41
+ ######################################################################################
42
+ # Split-sum environment map light source with automatic mipmap generation
43
+ ######################################################################################
44
+
45
+ class EnvironmentLight(torch.nn.Module):
46
+ LIGHT_MIN_RES = 16
47
+
48
+ MIN_ROUGHNESS = 0.08
49
+ MAX_ROUGHNESS = 0.5
50
+
51
+ def __init__(self, base):
52
+ super(EnvironmentLight, self).__init__()
53
+ self.mtx = None
54
+ self.base = torch.nn.Parameter(base.clone().detach(), requires_grad=True)
55
+ self.register_parameter('env_base', self.base)
56
+
57
+ def xfm(self, mtx):
58
+ self.mtx = mtx
59
+
60
+ def clone(self):
61
+ return EnvironmentLight(self.base.clone().detach())
62
+
63
+ def clamp_(self, min=None, max=None):
64
+ self.base.clamp_(min, max)
65
+
66
+ def get_mip(self, roughness):
67
+ return torch.where(roughness < self.MAX_ROUGHNESS
68
+ , (torch.clamp(roughness, self.MIN_ROUGHNESS, self.MAX_ROUGHNESS) - self.MIN_ROUGHNESS) / (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) * (len(self.specular) - 2)
69
+ , (torch.clamp(roughness, self.MAX_ROUGHNESS, 1.0) - self.MAX_ROUGHNESS) / (1.0 - self.MAX_ROUGHNESS) + len(self.specular) - 2)
70
+
71
+ def build_mips(self, cutoff=0.99):
72
+ self.specular = [self.base]
73
+ while self.specular[-1].shape[1] > self.LIGHT_MIN_RES:
74
+ self.specular += [cubemap_mip.apply(self.specular[-1])]
75
+
76
+ self.diffuse = ru.diffuse_cubemap(self.specular[-1])
77
+
78
+ for idx in range(len(self.specular) - 1):
79
+ roughness = (idx / (len(self.specular) - 2)) * (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) + self.MIN_ROUGHNESS
80
+ self.specular[idx] = ru.specular_cubemap(self.specular[idx], roughness, cutoff)
81
+ self.specular[-1] = ru.specular_cubemap(self.specular[-1], 1.0, cutoff)
82
+
83
+ def regularizer(self):
84
+ white = (self.base[..., 0:1] + self.base[..., 1:2] + self.base[..., 2:3]) / 3.0
85
+ return torch.mean(torch.abs(self.base - white))
86
+
87
+ def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True):
88
+ wo = util.safe_normalize(view_pos - gb_pos)
89
+
90
+ if specular:
91
+ roughness = ks[..., 1:2] # y component
92
+ metallic = ks[..., 2:3] # z component
93
+ spec_col = (1.0 - metallic)*0.04 + kd * metallic
94
+ diff_col = kd * (1.0 - metallic)
95
+ else:
96
+ diff_col = kd
97
+
98
+ reflvec = util.safe_normalize(util.reflect(wo, gb_normal))
99
+ nrmvec = gb_normal
100
+ if self.mtx is not None: # Rotate lookup
101
+ mtx = torch.as_tensor(self.mtx, dtype=torch.float32, device='cuda')
102
+ reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape)
103
+ nrmvec = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], nrmvec.shape[3]), mtx).view(*nrmvec.shape)
104
+
105
+ # Diffuse lookup
106
+ diffuse = dr.texture(self.diffuse[None, ...], nrmvec.contiguous(), filter_mode='linear', boundary_mode='cube')
107
+ shaded_col = diffuse * diff_col
108
+
109
+ if specular:
110
+ # Lookup FG term from lookup texture
111
+ NdotV = torch.clamp(util.dot(wo, gb_normal), min=1e-4)
112
+ fg_uv = torch.cat((NdotV, roughness), dim=-1)
113
+ if not hasattr(self, '_FG_LUT'):
114
+ self._FG_LUT = torch.as_tensor(np.fromfile('data/irrmaps/bsdf_256_256.bin', dtype=np.float32).reshape(1, 256, 256, 2), dtype=torch.float32, device='cuda')
115
+ fg_lookup = dr.texture(self._FG_LUT, fg_uv, filter_mode='linear', boundary_mode='clamp')
116
+
117
+ # Roughness adjusted specular env lookup
118
+ miplevel = self.get_mip(roughness)
119
+ spec = dr.texture(self.specular[0][None, ...], reflvec.contiguous(), mip=list(m[None, ...] for m in self.specular[1:]), mip_level_bias=miplevel[..., 0], filter_mode='linear-mipmap-linear', boundary_mode='cube')
120
+
121
+ # Compute aggregate lighting
122
+ reflectance = spec_col * fg_lookup[...,0:1] + fg_lookup[...,1:2]
123
+ shaded_col += spec * reflectance
124
+
125
+ return shaded_col * (1.0 - ks[..., 0:1]) # Modulate by hemisphere visibility
126
+
127
+ ######################################################################################
128
+ # Load and store
129
+ ######################################################################################
130
+
131
+ # Load from latlong .HDR file
132
+ def _load_env_hdr(fn, scale=1.0):
133
+ latlong_img = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')*scale
134
+ cubemap = util.latlong_to_cubemap(latlong_img, [512, 512])
135
+
136
+ l = EnvironmentLight(cubemap)
137
+ l.build_mips()
138
+
139
+ return l
140
+
141
+ def load_env(fn, scale=1.0):
142
+ if os.path.splitext(fn)[1].lower() == ".hdr":
143
+ return _load_env_hdr(fn, scale)
144
+ else:
145
+ assert False, "Unknown envlight extension %s" % os.path.splitext(fn)[1]
146
+
147
+ def save_env_map(fn, light):
148
+ assert isinstance(light, EnvironmentLight), "Can only save EnvironmentLight currently"
149
+ if isinstance(light, EnvironmentLight):
150
+ color = util.cubemap_to_latlong(light.base, [512, 1024])
151
+ util.save_image_raw(fn, color.detach().cpu().numpy())
152
+
153
+ ######################################################################################
154
+ # Create trainable env map with random initialization
155
+ ######################################################################################
156
+
157
+ def create_trainable_env_rnd(base_res, scale=0.5, bias=0.25):
158
+ base = torch.rand(6, base_res, base_res, 3, dtype=torch.float32, device='cuda') * scale + bias
159
+ return EnvironmentLight(base)
160
+
161
+
162
+ ######################################################################################
163
+ # Directional light source
164
+ ######################################################################################
165
+
166
+ class DirectionalLight(torch.nn.Module):
167
+ def __init__(self, mlp_in, mlp_layers, mlp_hidden_size, intensity_min_max=None):
168
+ super(DirectionalLight, self).__init__()
169
+ self.mlp = MLP(mlp_in, 4, mlp_layers, nf=mlp_hidden_size, activation='sigmoid')
170
+ if intensity_min_max is not None:
171
+ self.register_buffer('intensity_min_max', intensity_min_max)
172
+ else:
173
+ self.intensity_min_max = None
174
+
175
+ def forward(self, feat):
176
+ # print('----------------- forward light !!! -----------------')
177
+ out = self.mlp(feat)
178
+ light_dir = F.normalize(torch.cat([out[..., 0:1] *2-1, torch.ones_like(out[..., :1]) * 0.5, out[..., 1:2] *2-1], dim=-1), dim=-1) # upper hemisphere
179
+ if self.intensity_min_max is not None:
180
+ int = out[..., 2:] * (self.intensity_min_max[1][None, :] - self.intensity_min_max[0][None, :]) + self.intensity_min_max[0][None, :]
181
+ self.light_params = torch.cat([light_dir, int], -1)
182
+ return self.light_params
183
+
184
+ def shade(self, feat, kd, normal):
185
+ light_params = self.forward(feat)
186
+ light_dir = light_params[..., :3][:, None, None, :]
187
+ int_amb = light_params[..., 3:4][:, None, None, :]
188
+ int_diff = light_params[..., 4:5][:, None, None, :]
189
+ shading = (int_amb + int_diff * torch.clamp(util.dot(light_dir, normal), min=0.0))
190
+ shaded = shading * kd
191
+ return shaded, shading
video3d/render/material.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ import os
11
+ import numpy as np
12
+ import torch
13
+ import nvdiffrast.torch as dr
14
+ import cv2
15
+
16
+ from video3d.render.render import render_uv
17
+
18
+ from . import util
19
+ from . import texture
20
+ from . import mlptexture
21
+ from ..utils import misc
22
+
23
+ ######################################################################################
24
+ # Wrapper to make materials behave like a python dict, but register textures as
25
+ # torch.nn.Module parameters.
26
+ ######################################################################################
27
+ class Material(torch.nn.Module):
28
+ def __init__(self, mat_dict):
29
+ super(Material, self).__init__()
30
+ self.mat_keys = set()
31
+ for key in mat_dict.keys():
32
+ self.mat_keys.add(key)
33
+ self[key] = mat_dict[key]
34
+
35
+ def __contains__(self, key):
36
+ return hasattr(self, key)
37
+
38
+ def __getitem__(self, key):
39
+ return getattr(self, key)
40
+
41
+ def __setitem__(self, key, val):
42
+ self.mat_keys.add(key)
43
+ setattr(self, key, val)
44
+
45
+ def __delitem__(self, key):
46
+ self.mat_keys.remove(key)
47
+ delattr(self, key)
48
+
49
+ def keys(self):
50
+ return self.mat_keys
51
+
52
+ ######################################################################################
53
+ # .mtl material format loading / storing
54
+ ######################################################################################
55
+ @torch.no_grad()
56
+ def load_mtl(fn, clear_ks=True):
57
+ import re
58
+ mtl_path = os.path.dirname(fn)
59
+
60
+ # Read file
61
+ with open(fn, 'r') as f:
62
+ lines = f.readlines()
63
+
64
+ # Parse materials
65
+ materials = []
66
+ for line in lines:
67
+ split_line = re.split(' +|\t+|\n+', line.strip())
68
+ prefix = split_line[0].lower()
69
+ data = split_line[1:]
70
+ if 'newmtl' in prefix:
71
+ material = Material({'name' : data[0]})
72
+ materials += [material]
73
+ elif materials:
74
+ if 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix:
75
+ material[prefix] = data[0]
76
+ else:
77
+ material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda')
78
+
79
+ # Convert everything to textures. Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps
80
+ for mat in materials:
81
+ if not 'bsdf' in mat:
82
+ mat['bsdf'] = 'pbr'
83
+
84
+ if 'map_kd' in mat:
85
+ mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd']))
86
+ else:
87
+ mat['kd'] = texture.Texture2D(mat['kd'])
88
+
89
+ if 'map_ks' in mat:
90
+ mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3)
91
+ else:
92
+ mat['ks'] = texture.Texture2D(mat['ks'])
93
+
94
+ if 'bump' in mat:
95
+ mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3)
96
+
97
+ # Convert Kd from sRGB to linear RGB
98
+ mat['kd'] = texture.srgb_to_rgb(mat['kd'])
99
+
100
+ if clear_ks:
101
+ # Override ORM occlusion (red) channel by zeros. We hijack this channel
102
+ for mip in mat['ks'].getMips():
103
+ mip[..., 0] = 0.0
104
+
105
+ return materials
106
+
107
+ @torch.no_grad()
108
+ def save_mtl(fn, material, mesh=None, feat=None, resolution=[256, 256], prior_shape=None):
109
+ folder = os.path.dirname(fn)
110
+ file = os.path.basename(fn)
111
+ prefix = '_'.join(file.split('_')[:-1]) + '_'
112
+ with open(fn, "w") as f:
113
+ f.write('newmtl defaultMat\n')
114
+ if material is not None:
115
+ f.write('bsdf %s\n' % material['bsdf'])
116
+ if 'kd_ks_normal' in material.keys():
117
+ assert mesh is not None
118
+ glctx = dr.RasterizeGLContext()
119
+ mask, kd, ks, normal = render_uv(glctx, mesh, resolution, material['kd_ks_normal'], feat=feat, prior_shape=prior_shape)
120
+
121
+ hole_mask = 1. - mask
122
+ hole_mask = hole_mask.int()[0]
123
+ def uv_padding(image):
124
+ uv_padding_size = 4
125
+ inpaint_image = (
126
+ cv2.inpaint(
127
+ (image.detach().cpu().numpy() * 255).astype(np.uint8),
128
+ (hole_mask.detach().cpu().numpy() * 255).astype(np.uint8),
129
+ uv_padding_size,
130
+ cv2.INPAINT_TELEA,
131
+ )
132
+ / 255.0
133
+ )
134
+ return torch.from_numpy(inpaint_image).to(image)
135
+
136
+ kd = uv_padding(kd[0])[None]
137
+
138
+ batch_size = kd.shape[0]
139
+ f.write(f'map_Kd {prefix}texture_kd.png\n')
140
+ misc.save_images(folder, kd.permute(0,3,1,2).detach().cpu().numpy(), fnames=[prefix + "texture_kd"] * batch_size)
141
+ f.write(f'map_Ks {prefix}texture_ks.png\n')
142
+ misc.save_images(folder, ks.permute(0,3,1,2).detach().cpu().numpy(), fnames=[prefix + "texture_ks"] * batch_size)
143
+ # disable normal
144
+ # f.write(f'bump {prefix}texture_n.png\n')
145
+ # misc.save_images(folder, normal.permute(0,3,1,2).detach().cpu().numpy(), fnames=[prefix + "texture_n"] * batch_size)
146
+ if 'kd' in material.keys():
147
+ f.write('map_Kd texture_kd.png\n')
148
+ texture.save_texture2D(os.path.join(folder, 'texture_Kd.png'), texture.rgb_to_srgb(material['kd']))
149
+ if 'ks' in material.keys():
150
+ f.write('map_Ks texture_ks.png\n')
151
+ texture.save_texture2D(os.path.join(folder, 'texture_Ks.png'), material['ks'])
152
+ if 'normal' in material.keys():
153
+ f.write('bump texture_n.png\n')
154
+ texture.save_texture2D(os.path.join(folder, 'texture_n.png'), material['normal'], lambda_fn=lambda x:(util.safe_normalize(x)+1)*0.5)
155
+ else:
156
+ f.write('Kd 1 1 1\n')
157
+ f.write('Ks 0 0 0\n')
158
+ f.write('Ka 0 0 0\n')
159
+ f.write('Tf 1 1 1\n')
160
+ f.write('Ni 1\n')
161
+ f.write('Ns 0\n')
162
+
163
+ ######################################################################################
164
+ # Merge multiple materials into a single uber-material
165
+ ######################################################################################
166
+
167
+ def _upscale_replicate(x, full_res):
168
+ x = x.permute(0, 3, 1, 2)
169
+ x = torch.nn.functional.pad(x, (0, full_res[1] - x.shape[3], 0, full_res[0] - x.shape[2]), 'replicate')
170
+ return x.permute(0, 2, 3, 1).contiguous()
171
+
172
+ def merge_materials(materials, texcoords, tfaces, mfaces):
173
+ assert len(materials) > 0
174
+ for mat in materials:
175
+ assert mat['bsdf'] == materials[0]['bsdf'], "All materials must have the same BSDF (uber shader)"
176
+ assert ('normal' in mat) is ('normal' in materials[0]), "All materials must have either normal map enabled or disabled"
177
+
178
+ uber_material = Material({
179
+ 'name' : 'uber_material',
180
+ 'bsdf' : materials[0]['bsdf'],
181
+ })
182
+
183
+ textures = ['kd', 'ks', 'normal']
184
+
185
+ # Find maximum texture resolution across all materials and textures
186
+ max_res = None
187
+ for mat in materials:
188
+ for tex in textures:
189
+ tex_res = np.array(mat[tex].getRes()) if tex in mat else np.array([1, 1])
190
+ max_res = np.maximum(max_res, tex_res) if max_res is not None else tex_res
191
+
192
+ # Compute size of compund texture and round up to nearest PoT
193
+ full_res = 2**np.ceil(np.log2(max_res * np.array([1, len(materials)]))).astype(np.int)
194
+
195
+ # Normalize texture resolution across all materials & combine into a single large texture
196
+ for tex in textures:
197
+ if tex in materials[0]:
198
+ tex_data = torch.cat(tuple(util.scale_img_nhwc(mat[tex].data, tuple(max_res)) for mat in materials), dim=2) # Lay out all textures horizontally, NHWC so dim2 is x
199
+ tex_data = _upscale_replicate(tex_data, full_res)
200
+ uber_material[tex] = texture.Texture2D(tex_data)
201
+
202
+ # Compute scaling values for used / unused texture area
203
+ s_coeff = [full_res[0] / max_res[0], full_res[1] / max_res[1]]
204
+
205
+ # Recompute texture coordinates to cooincide with new composite texture
206
+ new_tverts = {}
207
+ new_tverts_data = []
208
+ for fi in range(len(tfaces)):
209
+ matIdx = mfaces[fi]
210
+ for vi in range(3):
211
+ ti = tfaces[fi][vi]
212
+ if not (ti in new_tverts):
213
+ new_tverts[ti] = {}
214
+ if not (matIdx in new_tverts[ti]): # create new vertex
215
+ new_tverts_data.append([(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]) # Offset texture coodrinate (x direction) by material id & scale to local space. Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here
216
+ new_tverts[ti][matIdx] = len(new_tverts_data) - 1
217
+ tfaces[fi][vi] = new_tverts[ti][matIdx] # reindex vertex
218
+
219
+ return uber_material, new_tverts_data, tfaces
220
+
221
+ ######################################################################################
222
+ # Utility functions for material
223
+ ######################################################################################
224
+
225
+ def initial_guess_material(cfgs, mlp=False, init_mat=None, tet_bbox=None):
226
+ kd_min = torch.tensor(cfgs.get('kd_min', [0., 0., 0., 0.]), dtype=torch.float32)
227
+ kd_max = torch.tensor(cfgs.get('kd_max', [1., 1., 1., 1.]), dtype=torch.float32)
228
+ ks_min = torch.tensor(cfgs.get('ks_min', [0., 0., 0.]), dtype=torch.float32)
229
+ ks_max = torch.tensor(cfgs.get('ks_max', [0., 0., 0.]), dtype=torch.float32)
230
+ nrm_min = torch.tensor(cfgs.get('nrm_min', [-1., -1., 0.]), dtype=torch.float32)
231
+ nrm_max = torch.tensor(cfgs.get('nrm_max', [1., 1., 1.]), dtype=torch.float32)
232
+ if mlp:
233
+ num_layers = cfgs.get("num_layers_tex", 5)
234
+ nf = cfgs.get("hidden_size", 128)
235
+ enable_encoder = cfgs.get("enable_encoder", False)
236
+ feat_dim = cfgs.get("latent_dim", 64) if enable_encoder else 0
237
+
238
+ mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
239
+ mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
240
+ min_max = torch.stack((mlp_min, mlp_max), dim=0)
241
+ out_chn = 9
242
+ mlp_map_opt = mlptexture.MLPTexture3D(tet_bbox, channels=out_chn, internal_dims=nf, hidden=num_layers-1, feat_dim=feat_dim, min_max=min_max)
243
+ mat = Material({'kd_ks_normal' : mlp_map_opt})
244
+ else:
245
+ # Setup Kd (albedo) and Ks (x, roughness, metalness) textures
246
+ if cfgs.random_textures or init_mat is None:
247
+ num_channels = 4 if cfgs.layers > 1 else 3
248
+ kd_init = torch.rand(size=cfgs.texture_res + [num_channels]) * (kd_max - kd_min)[None, None, 0:num_channels] + kd_min[None, None, 0:num_channels]
249
+ kd_map_opt = texture.create_trainable(kd_init , cfgs.texture_res, not cfgs.custom_mip, [kd_min, kd_max])
250
+
251
+ ksR = np.random.uniform(size=cfgs.texture_res + [1], low=0.0, high=0.01)
252
+ ksG = np.random.uniform(size=cfgs.texture_res + [1], low=ks_min[1].cpu(), high=ks_max[1].cpu())
253
+ ksB = np.random.uniform(size=cfgs.texture_res + [1], low=ks_min[2].cpu(), high=ks_max[2].cpu())
254
+
255
+ ks_map_opt = texture.create_trainable(np.concatenate((ksR, ksG, ksB), axis=2), cfgs.texture_res, not cfgs.custom_mip, [ks_min, ks_max])
256
+ else:
257
+ kd_map_opt = texture.create_trainable(init_mat['kd'], cfgs.texture_res, not cfgs.custom_mip, [kd_min, kd_max])
258
+ ks_map_opt = texture.create_trainable(init_mat['ks'], cfgs.texture_res, not cfgs.custom_mip, [ks_min, ks_max])
259
+
260
+ # Setup normal map
261
+ if cfgs.random_textures or init_mat is None or 'normal' not in init_mat:
262
+ normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), cfgs.texture_res, not cfgs.custom_mip, [nrm_min, nrm_max])
263
+ else:
264
+ normal_map_opt = texture.create_trainable(init_mat['normal'], cfgs.texture_res, not cfgs.custom_mip, [nrm_min, nrm_max])
265
+
266
+ mat = Material({
267
+ 'kd' : kd_map_opt,
268
+ 'ks' : ks_map_opt,
269
+ 'normal' : normal_map_opt
270
+ })
271
+
272
+ if init_mat is not None:
273
+ mat['bsdf'] = init_mat['bsdf']
274
+ elif "bsdf" in cfgs:
275
+ mat['bsdf'] = cfgs["bsdf"]
276
+ else:
277
+ mat['bsdf'] = 'pbr'
278
+
279
+ if not cfgs.get("perturb_normal", False):
280
+ mat['no_perturbed_nrm'] = True
281
+
282
+ return mat
video3d/render/mesh.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ from difflib import unified_diff
11
+ import os
12
+ import numpy as np
13
+ import torch
14
+
15
+ from . import obj
16
+ from . import util
17
+
18
+ #########################################################################################
19
+ # Base mesh class
20
+ #
21
+ # Minibatch in mesh is supported, as long as each mesh shares the same edge connectivity.
22
+ #########################################################################################
23
+ class Mesh:
24
+ def __init__(self,
25
+ v_pos=None,
26
+ t_pos_idx=None,
27
+ v_nrm=None,
28
+ t_nrm_idx=None,
29
+ v_tex=None,
30
+ t_tex_idx=None,
31
+ v_tng=None,
32
+ t_tng_idx=None,
33
+ material=None,
34
+ base=None):
35
+ self.v_pos = v_pos
36
+ self.v_nrm = v_nrm
37
+ self.v_tex = v_tex
38
+ self.v_tng = v_tng
39
+ self.t_pos_idx = t_pos_idx
40
+ self.t_nrm_idx = t_nrm_idx
41
+ self.t_tex_idx = t_tex_idx
42
+ self.t_tng_idx = t_tng_idx
43
+ self.material = material
44
+
45
+ if base is not None:
46
+ self.copy_none(base)
47
+
48
+ def __len__(self):
49
+ return len(self.v_pos)
50
+
51
+ def copy_none(self, other):
52
+ if self.v_pos is None:
53
+ self.v_pos = other.v_pos
54
+ if self.t_pos_idx is None:
55
+ self.t_pos_idx = other.t_pos_idx
56
+ if self.v_nrm is None:
57
+ self.v_nrm = other.v_nrm
58
+ if self.t_nrm_idx is None:
59
+ self.t_nrm_idx = other.t_nrm_idx
60
+ if self.v_tex is None:
61
+ self.v_tex = other.v_tex
62
+ if self.t_tex_idx is None:
63
+ self.t_tex_idx = other.t_tex_idx
64
+ if self.v_tng is None:
65
+ self.v_tng = other.v_tng
66
+ if self.t_tng_idx is None:
67
+ self.t_tng_idx = other.t_tng_idx
68
+ if self.material is None:
69
+ self.material = other.material
70
+
71
+ def clone(self):
72
+ out = Mesh(base=self)
73
+ if out.v_pos is not None:
74
+ out.v_pos = out.v_pos.clone().detach()
75
+ if out.t_pos_idx is not None:
76
+ out.t_pos_idx = out.t_pos_idx.clone().detach()
77
+ if out.v_nrm is not None:
78
+ out.v_nrm = out.v_nrm.clone().detach()
79
+ if out.t_nrm_idx is not None:
80
+ out.t_nrm_idx = out.t_nrm_idx.clone().detach()
81
+ if out.v_tex is not None:
82
+ out.v_tex = out.v_tex.clone().detach()
83
+ if out.t_tex_idx is not None:
84
+ out.t_tex_idx = out.t_tex_idx.clone().detach()
85
+ if out.v_tng is not None:
86
+ out.v_tng = out.v_tng.clone().detach()
87
+ if out.t_tng_idx is not None:
88
+ out.t_tng_idx = out.t_tng_idx.clone().detach()
89
+ return out
90
+
91
+ def detach(self):
92
+ return self.clone()
93
+
94
+ def extend(self, N: int):
95
+ """
96
+ Create new Mesh class which contains each input mesh N times.
97
+
98
+ Args:
99
+ N: number of new copies of each mesh.
100
+
101
+ Returns:
102
+ new Mesh object.
103
+ """
104
+ verts = self.v_pos.repeat(N, 1, 1)
105
+ faces = self.t_pos_idx
106
+ uvs = self.v_tex.repeat(N, 1, 1)
107
+ uv_idx = self.t_tex_idx
108
+ mat = self.material
109
+
110
+ return make_mesh(verts, faces, uvs, uv_idx, self.material)
111
+
112
+ def deform(self, deformation):
113
+ """
114
+ Create new Mesh class which is obtained by performing the deformation to the self.
115
+
116
+ Args:
117
+ deformation: tensor with shape (B, V, 3)
118
+
119
+ Returns:
120
+ new Mesh object after the deformation.
121
+ """
122
+ assert deformation.shape[1] == self.v_pos.shape[1] and deformation.shape[2] == 3
123
+ verts = self.v_pos + deformation
124
+ return make_mesh(verts, self.t_pos_idx, self.v_tex.repeat(len(verts), 1, 1), self.t_tex_idx, self.material)
125
+
126
+ def get_m_to_n(self, m: int, n: int):
127
+ """
128
+ Create new Mesh class with the n-th (included) mesh to the m-th (not included) mesh in the batch.
129
+
130
+ Args:
131
+ m: the index of the starting mesh to be contained.
132
+ n: the index of the first mesh not to be contained.
133
+ """
134
+ verts = self.v_pos[m:n, ...]
135
+ faces = self.t_pos_idx
136
+ uvs = self.v_tex[m:n, ...]
137
+ uv_idx = self.t_tex_idx
138
+ mat = self.material
139
+
140
+ return make_mesh(verts, faces, uvs, uv_idx, mat)
141
+
142
+ def first_n(self, n: int):
143
+ """
144
+ Create new Mesh class with only the first n meshes in the batch.
145
+
146
+ Args:
147
+ n: number of meshes to be contained.
148
+
149
+ Returns:
150
+ new Mesh object with the first n meshes.
151
+ """
152
+ return self.get_m_to_n(0, n)
153
+ verts = self.v_pos[:n, ...]
154
+ faces = self.t_pos_idx
155
+ uvs = self.v_tex[:n, ...]
156
+ uv_idx = self.t_tex_idx
157
+ mat = self.material
158
+
159
+ return make_mesh(verts, faces, uvs, uv_idx, mat)
160
+
161
+ def get_n(self, n: int):
162
+ """
163
+ Create new Mesh class with only the n-th meshes in the batch.
164
+
165
+ Args:
166
+ n: the index of the mesh to be contained.
167
+
168
+ Returns:
169
+ new Mesh object with the n-th mesh.
170
+ """
171
+ verts = self.v_pos[n:n+1, ...]
172
+ faces = self.t_pos_idx
173
+ uvs = self.v_tex[n:n+1, ...]
174
+ uv_idx = self.t_tex_idx
175
+ mat = self.material
176
+
177
+ return make_mesh(verts, faces, uvs, uv_idx, mat)
178
+
179
+
180
+ ######################################################################################
181
+ # Mesh loading helper
182
+ ######################################################################################
183
+ def load_mesh(filename, mtl_override=None):
184
+ name, ext = os.path.splitext(filename)
185
+ if ext == ".obj":
186
+ return obj.load_obj(filename, clear_ks=True, mtl_override=mtl_override)
187
+ assert False, "Invalid mesh file extension"
188
+
189
+ ######################################################################################
190
+ # Compute AABB
191
+ ######################################################################################
192
+ def aabb(mesh):
193
+ return torch.min(mesh.v_pos, dim=0).values, torch.max(mesh.v_pos, dim=0).values
194
+
195
+ ######################################################################################
196
+ # Compute unique edge list from attribute/vertex index list
197
+ ######################################################################################
198
+ def compute_edges(attr_idx, return_inverse=False):
199
+ with torch.no_grad():
200
+ # Create all edges, packed by triangle
201
+ idx = attr_idx[0]
202
+ all_edges = torch.cat((
203
+ torch.stack((idx[:, 0], idx[:, 1]), dim=-1),
204
+ torch.stack((idx[:, 1], idx[:, 2]), dim=-1),
205
+ torch.stack((idx[:, 2], idx[:, 0]), dim=-1),
206
+ ), dim=-1).view(-1, 2)
207
+
208
+ # Swap edge order so min index is always first
209
+ order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
210
+ sorted_edges = torch.cat((
211
+ torch.gather(all_edges, 1, order),
212
+ torch.gather(all_edges, 1, 1 - order)
213
+ ), dim=-1)
214
+
215
+ # Eliminate duplicates and return inverse mapping
216
+ return torch.unique(sorted_edges, dim=0, return_inverse=return_inverse)
217
+
218
+ ######################################################################################
219
+ # Compute unique edge to face mapping from attribute/vertex index list
220
+ ######################################################################################
221
+ def compute_edge_to_face_mapping(attr_idx, return_inverse=False):
222
+ with torch.no_grad():
223
+ # Get unique edges
224
+ # Create all edges, packed by triangle
225
+ idx = attr_idx[0]
226
+ all_edges = torch.cat((
227
+ torch.stack((idx[:, 0], idx[:, 1]), dim=-1),
228
+ torch.stack((idx[:, 1], idx[:, 2]), dim=-1),
229
+ torch.stack((idx[:, 2], idx[:, 0]), dim=-1),
230
+ ), dim=-1).view(-1, 2)
231
+
232
+ # Swap edge order so min index is always first
233
+ order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
234
+ sorted_edges = torch.cat((
235
+ torch.gather(all_edges, 1, order),
236
+ torch.gather(all_edges, 1, 1 - order)
237
+ ), dim=-1)
238
+
239
+ # Elliminate duplicates and return inverse mapping
240
+ unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True)
241
+
242
+ tris = torch.arange(idx.shape[0]).repeat_interleave(3).cuda()
243
+
244
+ tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda()
245
+
246
+ # Compute edge to face table
247
+ mask0 = order[:,0] == 0
248
+ mask1 = order[:,0] == 1
249
+ tris_per_edge[idx_map[mask0], 0] = tris[mask0]
250
+ tris_per_edge[idx_map[mask1], 1] = tris[mask1]
251
+
252
+ return tris_per_edge
253
+
254
+ ######################################################################################
255
+ # Align base mesh to reference mesh:move & rescale to match bounding boxes.
256
+ ######################################################################################
257
+ def unit_size(mesh):
258
+ with torch.no_grad():
259
+ vmin, vmax = aabb(mesh)
260
+ scale = 2 / torch.max(vmax - vmin).item()
261
+ v_pos = mesh.v_pos - (vmax + vmin) / 2 # Center mesh on origin
262
+ v_pos = v_pos * scale # Rescale to unit size
263
+
264
+ return Mesh(v_pos, base=mesh)
265
+
266
+ ######################################################################################
267
+ # Center & scale mesh for rendering
268
+ ######################################################################################
269
+ def center_by_reference(base_mesh, ref_aabb, scale):
270
+ center = (ref_aabb[0] + ref_aabb[1]) * 0.5
271
+ scale = scale / torch.max(ref_aabb[1] - ref_aabb[0]).item()
272
+ v_pos = (base_mesh.v_pos - center[None, ...]) * scale
273
+ return Mesh(v_pos, base=base_mesh)
274
+
275
+ ######################################################################################
276
+ # Simple smooth vertex normal computation
277
+ ######################################################################################
278
+ def auto_normals(imesh):
279
+ batch_size = imesh.v_pos.shape[0]
280
+
281
+ i0 = imesh.t_pos_idx[0, :, 0] # Shape: (F)
282
+ i1 = imesh.t_pos_idx[0, :, 1] # Shape: (F)
283
+ i2 = imesh.t_pos_idx[0, :, 2] # Shape: (F)
284
+
285
+ v0 = imesh.v_pos[:, i0, :] # Shape: (B, F, 3)
286
+ v1 = imesh.v_pos[:, i1, :] # Shape: (B, F, 3)
287
+ v2 = imesh.v_pos[:, i2, :] # Shape: (B, F, 3)
288
+
289
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) # Shape: (B, F, 3)
290
+
291
+ # Splat face normals to vertices
292
+ v_nrm = torch.zeros_like(imesh.v_pos) # Shape: (B, V, 3)
293
+ v_nrm.scatter_add_(1, i0[None, :, None].repeat(batch_size, 1, 3), face_normals)
294
+ v_nrm.scatter_add_(1, i1[None, :, None].repeat(batch_size, 1, 3), face_normals)
295
+ v_nrm.scatter_add_(1, i2[None, :, None].repeat(batch_size, 1, 3), face_normals)
296
+
297
+ # Normalize, replace zero (degenerated) normals with some default value
298
+ v_nrm = torch.where(util.dot(v_nrm, v_nrm) > 1e-20,
299
+ v_nrm, torch.tensor([0.0, 0.0, 1.0],
300
+ dtype=torch.float32, device='cuda'))
301
+ v_nrm = util.safe_normalize(v_nrm)
302
+
303
+ if torch.is_anomaly_enabled():
304
+ assert torch.all(torch.isfinite(v_nrm))
305
+
306
+ return Mesh(v_nrm=v_nrm, t_nrm_idx=imesh.t_pos_idx, base=imesh)
307
+
308
+ ######################################################################################
309
+ # Compute tangent space from texture map coordinates
310
+ # Follows http://www.mikktspace.com/ conventions
311
+ ######################################################################################
312
+ def compute_tangents(imesh):
313
+ batch_size = imesh.v_pos.shape[0]
314
+
315
+ vn_idx = [None] * 3
316
+ pos = [None] * 3
317
+ tex = [None] * 3
318
+ for i in range(0,3):
319
+ pos[i] = imesh.v_pos[:, imesh.t_pos_idx[0, :, i]]
320
+ tex[i] = imesh.v_tex[:, imesh.t_tex_idx[0, :, i]]
321
+ vn_idx[i] = imesh.t_nrm_idx[..., i:i+1]
322
+
323
+ tangents = torch.zeros_like(imesh.v_nrm)
324
+ tansum = torch.zeros_like(imesh.v_nrm)
325
+
326
+ # Compute tangent space for each triangle
327
+ uve1 = tex[1] - tex[0] # Shape: (B, F, 2)
328
+ uve2 = tex[2] - tex[0] # Shape: (B, F, 2)
329
+ pe1 = pos[1] - pos[0] # Shape: (B, F, 3)
330
+ pe2 = pos[2] - pos[0] # Shape: (B, F, 3)
331
+
332
+ nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2] # Shape: (B, F, 3)
333
+ denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1] # Shape: (B, F, 1)
334
+
335
+ # Avoid division by zero for degenerated texture coordinates
336
+ tang = nom / torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)) # Shape: (B, F, 3)
337
+
338
+ # Update all 3 vertices
339
+ for i in range(0,3):
340
+ idx = vn_idx[i].repeat(batch_size, 1, 3) # Shape: (B, F, 3)
341
+ tangents.scatter_add_(1, idx, tang) # tangents[n_i] = tangents[n_i] + tang
342
+ tansum.scatter_add_(1, idx, torch.ones_like(tang)) # tansum[n_i] = tansum[n_i] + 1
343
+ tangents = tangents / tansum
344
+
345
+ # Normalize and make sure tangent is perpendicular to normal
346
+ tangents = util.safe_normalize(tangents)
347
+ tangents = util.safe_normalize(tangents - util.dot(tangents, imesh.v_nrm) * imesh.v_nrm)
348
+
349
+ if torch.is_anomaly_enabled():
350
+ assert torch.all(torch.isfinite(tangents))
351
+
352
+ return Mesh(v_tng=tangents, t_tng_idx=imesh.t_nrm_idx, base=imesh)
353
+
354
+ ######################################################################################
355
+ # Create new Mesh from verts, faces, uvs, and uv_idx. The rest is auto computed.
356
+ ######################################################################################
357
+ def make_mesh(verts, faces, uvs, uv_idx, material):
358
+ """
359
+ Create new Mesh class with given verts, faces, uvs, and uv_idx.
360
+
361
+ Args:
362
+ verts: tensor of shape (B, V, 3)
363
+ faces: tensor of shape (1, F, 3)
364
+ uvs: tensor of shape (B, V, 2)
365
+ uv_idx: tensor of shape (1, F, 3)
366
+ material: an Material instance, specifying the material of the mesh.
367
+
368
+ Returns:
369
+ new Mesh object.
370
+ """
371
+ assert len(verts.shape) == 3 and len(faces.shape) == 3 and len(uvs.shape) == 3 and len(uv_idx.shape) == 3, "All components must be batched."
372
+ assert faces.shape[0] == 1 and uv_idx.shape[0] == 1, "Every mesh must share the same edge connectivity."
373
+ assert verts.shape[0] == uvs.shape[0], "Batch size must be consistent."
374
+ ret = Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)
375
+ ret = auto_normals(ret)
376
+ ret = compute_tangents(ret)
377
+ return ret
video3d/render/mlptexture.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ import torch
11
+ import tinycudann as tcnn
12
+ import numpy as np
13
+
14
+ #######################################################################################################################################################
15
+ # Small MLP using PyTorch primitives, internal helper class
16
+ #######################################################################################################################################################
17
+
18
+ class _MLP(torch.nn.Module):
19
+ def __init__(self, cfg, loss_scale=1.0):
20
+ super(_MLP, self).__init__()
21
+ self.loss_scale = loss_scale
22
+ net = (torch.nn.Linear(cfg['n_input_dims'], cfg['n_neurons'], bias=False), torch.nn.ReLU())
23
+ for i in range(cfg['n_hidden_layers']-1):
24
+ net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_neurons'], bias=False), torch.nn.ReLU())
25
+ net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_output_dims'], bias=False),)
26
+ self.net = torch.nn.Sequential(*net).cuda()
27
+
28
+ self.net.apply(self._init_weights)
29
+
30
+ if self.loss_scale != 1.0:
31
+ self.net.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] * self.loss_scale, ))
32
+
33
+ def forward(self, x):
34
+ return self.net(x.to(torch.float32))
35
+
36
+ @staticmethod
37
+ def _init_weights(m):
38
+ if type(m) == torch.nn.Linear:
39
+ torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
40
+ if hasattr(m.bias, 'data'):
41
+ m.bias.data.fill_(0.0)
42
+
43
+ #######################################################################################################################################################
44
+ # Outward visible MLP class
45
+ #######################################################################################################################################################
46
+
47
+ class MLPTexture3D(torch.nn.Module):
48
+ def __init__(self, AABB, channels=3, internal_dims=32, hidden=2, feat_dim=0, min_max=None, bsdf='diffuse', perturb_normal=False, symmetrize=False):
49
+ super(MLPTexture3D, self).__init__()
50
+
51
+ self.channels = channels
52
+ self.feat_dim = feat_dim
53
+ self.internal_dims = internal_dims
54
+ self.AABB = AABB
55
+ self.bsdf = bsdf
56
+ self.perturb_normal = perturb_normal
57
+ self.symmetrize = symmetrize
58
+ if min_max is not None:
59
+ self.register_buffer('min_max', min_max)
60
+ else:
61
+ self.min_max = None
62
+
63
+ # Setup positional encoding, see https://github.com/NVlabs/tiny-cuda-nn for details.
64
+ desired_resolution = 4096
65
+ base_grid_resolution = 16
66
+ num_levels = 16
67
+ per_level_scale = np.exp(np.log(desired_resolution / base_grid_resolution) / (num_levels-1))
68
+
69
+ enc_cfg = {
70
+ "otype": "HashGrid",
71
+ "n_levels": num_levels,
72
+ "n_features_per_level": 2,
73
+ "log2_hashmap_size": 19,
74
+ "base_resolution": base_grid_resolution,
75
+ "per_level_scale" : per_level_scale
76
+ }
77
+
78
+ # gradient_scaling = 128.0
79
+ gradient_scaling = 1.0
80
+ self.encoder = tcnn.Encoding(3, enc_cfg)
81
+ self.encoder.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] / gradient_scaling, ))
82
+
83
+ # Setup MLP
84
+ mlp_cfg = {
85
+ "n_input_dims" : internal_dims + feat_dim,
86
+ "n_output_dims" : self.channels,
87
+ "n_hidden_layers" : hidden,
88
+ "n_neurons" : self.internal_dims
89
+ }
90
+ self.linear = torch.nn.Linear(self.encoder.n_output_dims, internal_dims)
91
+ self.net = _MLP(mlp_cfg, gradient_scaling)
92
+ self.relu = torch.nn.ReLU(inplace=True)
93
+ print("Encoder output: %d dims" % (self.encoder.n_output_dims))
94
+
95
+ # Sample texture at a given location
96
+ def sample(self, texc, feat=None):
97
+ assert (feat is None and self.feat_dim == 0) or feat.shape[-1] == self.feat_dim
98
+
99
+ if self.symmetrize:
100
+ xs, ys, zs = texc.unbind(-1)
101
+ texc = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
102
+
103
+ _texc = (texc.view(-1, 3) - self.AABB[0][None, ...]) / (self.AABB[1][None, ...] - self.AABB[0][None, ...])
104
+ _texc = torch.clamp(_texc, min=0, max=1)
105
+
106
+ _, image_h, image_w, _ = texc.shape
107
+ p_enc = self.encoder(_texc.contiguous())
108
+ x_in = self.linear(p_enc.type(texc.dtype))
109
+ if feat is not None:
110
+ feat_in = feat[:, None, None, :].repeat(1, image_h, image_w, 1).view(-1, self.feat_dim)
111
+ x_in = torch.concat([x_in, feat_in], dim=-1)
112
+ out = self.net(self.relu(x_in))
113
+
114
+ # Sigmoid limit and scale to the allowed range
115
+ out = torch.sigmoid(out)
116
+ if self.min_max is not None:
117
+ out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :]
118
+
119
+ return out.view(*texc.shape[:-1], self.channels) # Remap to [n, h, w, c]
120
+
121
+ def cleanup(self):
122
+ tcnn.free_temporary_memory()
video3d/render/obj.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ import os
11
+ import torch
12
+ import xatlas
13
+ import trimesh
14
+ import numpy as np
15
+ import cv2
16
+ import nvdiffrast.torch as dr
17
+ from video3d.render.render import render_uv
18
+ from video3d.render.mesh import Mesh
19
+ from . import texture
20
+ from . import mesh
21
+ from . import material
22
+
23
+ ######################################################################################
24
+ # Utility functions
25
+ ######################################################################################
26
+
27
+ def _find_mat(materials, name):
28
+ for mat in materials:
29
+ if mat['name'] == name:
30
+ return mat
31
+ return materials[0] # Materials 0 is the default
32
+
33
+ ######################################################################################
34
+ # Create mesh object from objfile
35
+ ######################################################################################
36
+
37
+ def load_obj(filename, clear_ks=True, mtl_override=None):
38
+ obj_path = os.path.dirname(filename)
39
+
40
+ # Read entire file
41
+ with open(filename, 'r') as f:
42
+ lines = f.readlines()
43
+
44
+ # Load materials
45
+ all_materials = [
46
+ {
47
+ 'name' : '_default_mat',
48
+ 'bsdf' : 'pbr',
49
+ 'kd' : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')),
50
+ 'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
51
+ }
52
+ ]
53
+ if mtl_override is None:
54
+ for line in lines:
55
+ if len(line.split()) == 0:
56
+ continue
57
+ if line.split()[0] == 'mtllib':
58
+ all_materials += material.load_mtl(os.path.join(obj_path, line.split()[1]), clear_ks) # Read in entire material library
59
+ else:
60
+ all_materials += material.load_mtl(mtl_override)
61
+
62
+ # load vertices
63
+ vertices, texcoords, normals = [], [], []
64
+ for line in lines:
65
+ if len(line.split()) == 0:
66
+ continue
67
+
68
+ prefix = line.split()[0].lower()
69
+ if prefix == 'v':
70
+ vertices.append([float(v) for v in line.split()[1:]])
71
+ elif prefix == 'vt':
72
+ val = [float(v) for v in line.split()[1:]]
73
+ texcoords.append([val[0], 1.0 - val[1]])
74
+ elif prefix == 'vn':
75
+ normals.append([float(v) for v in line.split()[1:]])
76
+
77
+ # load faces
78
+ activeMatIdx = None
79
+ used_materials = []
80
+ faces, tfaces, nfaces, mfaces = [], [], [], []
81
+ for line in lines:
82
+ if len(line.split()) == 0:
83
+ continue
84
+
85
+ prefix = line.split()[0].lower()
86
+ if prefix == 'usemtl': # Track used materials
87
+ mat = _find_mat(all_materials, line.split()[1])
88
+ if not mat in used_materials:
89
+ used_materials.append(mat)
90
+ activeMatIdx = used_materials.index(mat)
91
+ elif prefix == 'f': # Parse face
92
+ vs = line.split()[1:]
93
+ nv = len(vs)
94
+ vv = vs[0].split('/')
95
+ v0 = int(vv[0]) - 1
96
+ t0 = int(vv[1]) - 1 if vv[1] != "" else -1
97
+ n0 = int(vv[2]) - 1 if vv[2] != "" else -1
98
+ for i in range(nv - 2): # Triangulate polygons
99
+ vv = vs[i + 1].split('/')
100
+ v1 = int(vv[0]) - 1
101
+ t1 = int(vv[1]) - 1 if vv[1] != "" else -1
102
+ n1 = int(vv[2]) - 1 if vv[2] != "" else -1
103
+ vv = vs[i + 2].split('/')
104
+ v2 = int(vv[0]) - 1
105
+ t2 = int(vv[1]) - 1 if vv[1] != "" else -1
106
+ n2 = int(vv[2]) - 1 if vv[2] != "" else -1
107
+ mfaces.append(activeMatIdx)
108
+ faces.append([v0, v1, v2])
109
+ tfaces.append([t0, t1, t2])
110
+ nfaces.append([n0, n1, n2])
111
+ assert len(tfaces) == len(faces) and len(nfaces) == len (faces)
112
+
113
+ # Create an "uber" material by combining all textures into a larger texture
114
+ if len(used_materials) > 1:
115
+ uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces)
116
+ else:
117
+ uber_material = used_materials[0]
118
+
119
+ vertices = torch.tensor(vertices, dtype=torch.float32, device='cuda')
120
+ texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None
121
+ normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None
122
+
123
+ faces = torch.tensor(faces, dtype=torch.int64, device='cuda')
124
+ tfaces = torch.tensor(tfaces, dtype=torch.int64, device='cuda') if texcoords is not None else None
125
+ nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None
126
+
127
+ return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material)
128
+
129
+ ######################################################################################
130
+ # Save mesh object to objfile
131
+ ######################################################################################
132
+
133
+ def write_obj(folder, fname, mesh, idx, save_material=True, feat=None, resolution=[256, 256]):
134
+ obj_file = os.path.join(folder, fname + '.obj')
135
+ print("Writing mesh: ", obj_file)
136
+ with open(obj_file, "w") as f:
137
+ f.write(f"mtllib {fname}.mtl\n")
138
+ f.write("g default\n")
139
+
140
+ v_pos = mesh.v_pos[idx].detach().cpu().numpy() if mesh.v_pos is not None else None
141
+ v_nrm = mesh.v_nrm[idx].detach().cpu().numpy() if mesh.v_nrm is not None else None
142
+ v_tex = mesh.v_tex[idx].detach().cpu().numpy() if mesh.v_tex is not None else None
143
+
144
+ t_pos_idx = mesh.t_pos_idx[0].detach().cpu().numpy() if mesh.t_pos_idx is not None else None
145
+ t_nrm_idx = mesh.t_nrm_idx[0].detach().cpu().numpy() if mesh.t_nrm_idx is not None else None
146
+ t_tex_idx = mesh.t_tex_idx[0].detach().cpu().numpy() if mesh.t_tex_idx is not None else None
147
+
148
+ print(" writing %d vertices" % len(v_pos))
149
+ for v in v_pos:
150
+ f.write('v {} {} {} \n'.format(v[0], v[1], v[2]))
151
+
152
+ if v_tex is not None and save_material:
153
+ print(" writing %d texcoords" % len(v_tex))
154
+ assert(len(t_pos_idx) == len(t_tex_idx))
155
+ for v in v_tex:
156
+ f.write('vt {} {} \n'.format(v[0], 1.0 - v[1]))
157
+
158
+ if v_nrm is not None:
159
+ print(" writing %d normals" % len(v_nrm))
160
+ assert(len(t_pos_idx) == len(t_nrm_idx))
161
+ for v in v_nrm:
162
+ f.write('vn {} {} {}\n'.format(v[0], v[1], v[2]))
163
+
164
+ # faces
165
+ f.write("s 1 \n")
166
+ f.write("g pMesh1\n")
167
+ f.write("usemtl defaultMat\n")
168
+
169
+ # Write faces
170
+ print(" writing %d faces" % len(t_pos_idx))
171
+ for i in range(len(t_pos_idx)):
172
+ f.write("f ")
173
+ for j in range(3):
174
+ f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1)))
175
+ f.write("\n")
176
+
177
+ if save_material and mesh.material is not None:
178
+ mtl_file = os.path.join(folder, fname + '.mtl')
179
+ print("Writing material: ", mtl_file)
180
+ material.save_mtl(mtl_file, mesh.material, mesh=mesh.get_n(idx), feat=feat, resolution=resolution)
181
+
182
+ print("Done exporting mesh")
183
+
184
+
185
+ def write_textured_obj(folder, fname, mesh, idx, save_material=True, feat=None, resolution=[256, 256], prior_shape=None):
186
+ mesh = mesh.get_n(idx)
187
+ obj_file = os.path.join(folder, fname + '.obj')
188
+ print("Writing mesh: ", obj_file)
189
+
190
+ # Create uvs with xatlas
191
+ v_pos = mesh.v_pos.detach().cpu().numpy()
192
+ t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy()
193
+
194
+ # v_color = torch.Tensor(v_pos)[None].to("cuda")
195
+ # v_color = mesh.material.sample(v_color, feat)
196
+ # v_color = v_color[0,0,:,:3].detach().cpu()
197
+ # v_color = torch.concat([v_color, torch.ones((v_color.shape[0], 1))], dim=-1)
198
+ # v_color = v_color.numpy() * 255
199
+ # v_color = v_color.astype(np.int32)
200
+ # tmp = trimesh.Trimesh(vertices=v_pos[0], faces=t_pos_idx[0], vertex_colors=v_color)
201
+ # _ = tmp.export("tmp.obj")
202
+ # from pdb import set_trace; set_trace()
203
+
204
+ atlas = xatlas.Atlas()
205
+ atlas.add_mesh(
206
+ v_pos[0],
207
+ t_pos_idx[0],
208
+ )
209
+ co = xatlas.ChartOptions()
210
+ po = xatlas.PackOptions()
211
+ # for k, v in xatlas_chart_options.items():
212
+ # setattr(co, k, v)
213
+ # for k, v in xatlas_pack_options.items():
214
+ # setattr(po, k, v)
215
+ atlas.generate(co, po)
216
+ vmapping, indices, uvs = atlas.get_mesh(0)
217
+ # vmapping, indices, uvs = xatlas.parametrize(v_pos[0], t_pos_idx[0])
218
+
219
+ # Convert to tensors
220
+ indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
221
+
222
+ uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
223
+ faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')
224
+
225
+ # new_mesh = Mesh(v_tex=uvs, t_tex_idx=faces, base=mesh)
226
+ new_mesh = Mesh(v_tex=uvs[None], t_tex_idx=faces[None], base=mesh)
227
+
228
+ # glctx = dr.RasterizeGLContext()
229
+ # mask, kd, ks, normal = render_uv(glctx, new_mesh, resolution, mesh.material, feat=feat)
230
+
231
+ # kd_min, kd_max = torch.tensor([ 0.0, 0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'), torch.tensor([ 1.0, 1.0, 1.0, 1.0], dtype=torch.float32, device='cuda')
232
+ # ks_min, ks_max = torch.tensor([ 0.0, 0.0, 0.0] , dtype=torch.float32, device='cuda'), torch.tensor([ 0.0, 0.0, 0.0] , dtype=torch.float32, device='cuda')
233
+ # nrm_min, nrm_max = torch.tensor([-1.0, -1.0, 0.0], dtype=torch.float32, device='cuda'), torch.tensor([ 1.0, 1.0, 1.0], dtype=torch.float32, device='cuda')
234
+
235
+ new_mesh.material = material.Material({
236
+ 'bsdf' : 'diffuse',
237
+ # 'kd' : texture.Texture2D(kd, min_max=[kd_min, kd_max]),
238
+ # 'ks' : texture.Texture2D(ks, min_max=[ks_min, ks_max]),
239
+ # 'normal' : texture.Texture2D(normal, min_max=[nrm_min, nrm_max]),
240
+ 'kd_ks_normal': mesh.material
241
+ })
242
+
243
+ with open(obj_file, "w") as f:
244
+ f.write(f"mtllib {fname}.mtl\n")
245
+ f.write("g default\n")
246
+
247
+ v_pos = new_mesh.v_pos[idx].detach().cpu().numpy() if new_mesh.v_pos is not None else None
248
+ v_nrm = new_mesh.v_nrm[idx].detach().cpu().numpy() if new_mesh.v_nrm is not None else None
249
+ v_tex = new_mesh.v_tex[idx].detach().cpu().numpy() if new_mesh.v_tex is not None else None
250
+
251
+ t_pos_idx = new_mesh.t_pos_idx[0].detach().cpu().numpy() if new_mesh.t_pos_idx is not None else None
252
+ t_nrm_idx = new_mesh.t_nrm_idx[0].detach().cpu().numpy() if new_mesh.t_nrm_idx is not None else None
253
+ t_tex_idx = new_mesh.t_tex_idx[0].detach().cpu().numpy() if new_mesh.t_tex_idx is not None else None
254
+
255
+ print(" writing %d vertices" % len(v_pos))
256
+ for v in v_pos:
257
+ f.write('v {} {} {} \n'.format(v[0], v[1], v[2]))
258
+
259
+ if v_tex is not None and save_material:
260
+ print(" writing %d texcoords" % len(v_tex))
261
+ assert(len(t_pos_idx) == len(t_tex_idx))
262
+ for v in v_tex:
263
+ f.write('vt {} {} \n'.format(v[0], 1.0 - v[1]))
264
+
265
+ if v_nrm is not None:
266
+ print(" writing %d normals" % len(v_nrm))
267
+ assert(len(t_pos_idx) == len(t_nrm_idx))
268
+ for v in v_nrm:
269
+ f.write('vn {} {} {}\n'.format(v[0], v[1], v[2]))
270
+
271
+ # faces
272
+ f.write("s 1 \n")
273
+ f.write("g pMesh1\n")
274
+ f.write("usemtl defaultMat\n")
275
+
276
+ # Write faces
277
+ print(" writing %d faces" % len(t_pos_idx))
278
+ for i in range(len(t_pos_idx)):
279
+ f.write("f ")
280
+ for j in range(3):
281
+ f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1)))
282
+ f.write("\n")
283
+
284
+ mtl_file = os.path.join(folder, fname + '.mtl')
285
+ print("Writing material: ", mtl_file)
286
+ material.save_mtl(mtl_file, new_mesh.material, mesh=new_mesh, feat=feat, resolution=resolution, prior_shape=prior_shape)
287
+
288
+ print("Done exporting mesh")
video3d/render/regularizer.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ import torch
11
+ import nvdiffrast.torch as dr
12
+
13
+ from . import util
14
+ from . import mesh
15
+
16
+ ######################################################################################
17
+ # Computes the image gradient, useful for kd/ks smoothness losses
18
+ ######################################################################################
19
+ def image_grad(buf, std=0.01):
20
+ t, s = torch.meshgrid(torch.linspace(-1.0 + 1.0 / buf.shape[1], 1.0 - 1.0 / buf.shape[1], buf.shape[1], device="cuda"),
21
+ torch.linspace(-1.0 + 1.0 / buf.shape[2], 1.0 - 1.0 / buf.shape[2], buf.shape[2], device="cuda"),
22
+ indexing='ij')
23
+ tc = torch.normal(mean=0, std=std, size=(buf.shape[0], buf.shape[1], buf.shape[2], 2), device="cuda") + torch.stack((s, t), dim=-1)[None, ...]
24
+ tap = dr.texture(buf, tc, filter_mode='linear', boundary_mode='clamp')
25
+ return torch.abs(tap[..., :-1] - buf[..., :-1]) * tap[..., -1:] * buf[..., -1:]
26
+
27
+ ######################################################################################
28
+ # Computes the avergage edge length of a mesh.
29
+ # Rough estimate of the tessellation of a mesh. Can be used e.g. to clamp gradients
30
+ ######################################################################################
31
+ def avg_edge_length(v_pos, t_pos_idx):
32
+ e_pos_idx = mesh.compute_edges(t_pos_idx)
33
+ edge_len = util.length(v_pos[:, e_pos_idx[:, 0]] - v_pos[:, e_pos_idx[:, 1]])
34
+ return torch.mean(edge_len)
35
+
36
+ ######################################################################################
37
+ # Laplacian regularization using umbrella operator (Fujiwara / Desbrun).
38
+ # https://mgarland.org/class/geom04/material/smoothing.pdf
39
+ ######################################################################################
40
+ def laplace_regularizer_const(v_pos, t_pos_idx):
41
+ batch_size = v_pos.shape[0]
42
+
43
+ term = torch.zeros_like(v_pos)
44
+ norm = torch.zeros_like(v_pos[..., 0:1])
45
+
46
+ v0 = v_pos[:, t_pos_idx[0, :, 0], :]
47
+ v1 = v_pos[:, t_pos_idx[0, :, 1], :]
48
+ v2 = v_pos[:, t_pos_idx[0, :, 2], :]
49
+
50
+ term.scatter_add_(1, t_pos_idx[..., 0:1].repeat(batch_size, 1, 3), (v1 - v0) + (v2 - v0))
51
+ term.scatter_add_(1, t_pos_idx[..., 1:2].repeat(batch_size, 1, 3), (v0 - v1) + (v2 - v1))
52
+ term.scatter_add_(1, t_pos_idx[..., 2:3].repeat(batch_size, 1, 3), (v0 - v2) + (v1 - v2))
53
+
54
+ two = torch.ones_like(v0) * 2.0
55
+ # norm.scatter_add_(1, t_pos_idx[..., 0:1].repeat(batch_size, 1, 3), two)
56
+ # norm.scatter_add_(1, t_pos_idx[..., 1:2].repeat(batch_size, 1, 3), two)
57
+ # norm.scatter_add_(1, t_pos_idx[..., 2:3].repeat(batch_size, 1, 3), two)
58
+ norm.scatter_add_(1, t_pos_idx[..., 0:1].repeat(batch_size, 1, 1), two)
59
+ norm.scatter_add_(1, t_pos_idx[..., 1:2].repeat(batch_size, 1, 1), two)
60
+ norm.scatter_add_(1, t_pos_idx[..., 2:3].repeat(batch_size, 1, 1), two)
61
+
62
+ term = term / torch.clamp(norm, min=1.0)
63
+
64
+ return torch.mean(term ** 2)
65
+
66
+ ######################################################################################
67
+ # Smooth vertex normals
68
+ ######################################################################################
69
+ def normal_consistency(v_pos, t_pos_idx):
70
+ # Compute face normals
71
+ v0 = v_pos[:, t_pos_idx[0, :, 0]]
72
+ v1 = v_pos[:, t_pos_idx[0, :, 1]]
73
+ v2 = v_pos[:, t_pos_idx[0, :, 2]]
74
+
75
+ face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1))
76
+
77
+ tris_per_edge = mesh.compute_edge_to_face_mapping(t_pos_idx)
78
+
79
+ # Fetch normals for both faces sharing an edge
80
+ n0 = face_normals[:, tris_per_edge[:, 0], :]
81
+ n1 = face_normals[:, tris_per_edge[:, 1], :]
82
+
83
+ # Compute error metric based on normal difference
84
+ term = torch.clamp(util.dot(n0, n1), min=-1.0, max=1.0)
85
+ term = (1.0 - term) * 0.5
86
+
87
+ return torch.mean(torch.abs(term))
88
+
89
+
90
+ def get_edge_length(v_pos, t_pos_idx):
91
+ e_pos_idx = mesh.compute_edges(t_pos_idx)
92
+ edge_len = util.length(v_pos[:, e_pos_idx[:, 0]] - v_pos[:, e_pos_idx[:, 1]])
93
+ return edge_len
video3d/render/render.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ import torch
11
+ import nvdiffrast.torch as dr
12
+
13
+ from . import util
14
+ from . import renderutils as ru
15
+ from . import light
16
+
17
+ # ==============================================================================================
18
+ # Helper functions
19
+ # ==============================================================================================
20
+ def interpolate(attr, rast, attr_idx, rast_db=None):
21
+ return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all')
22
+
23
+ # ==============================================================================================
24
+ # pixel shader
25
+ # ==============================================================================================
26
+ def shade(
27
+ gb_pos,
28
+ gb_geometric_normal,
29
+ gb_normal,
30
+ gb_tangent,
31
+ gb_tex_pos,
32
+ gb_texc,
33
+ gb_texc_deriv,
34
+ w2c,
35
+ view_pos,
36
+ lgt,
37
+ material,
38
+ bsdf,
39
+ feat,
40
+ two_sided_shading,
41
+ delta_xy_interp=None,
42
+ dino_pred=None,
43
+ class_vector=None,
44
+ im_features_map=None,
45
+ mvp=None
46
+ ):
47
+
48
+ ################################################################################
49
+ # Texture lookups
50
+ ################################################################################
51
+ perturbed_nrm = None
52
+ # Combined texture, used for MLPs because lookups are expensive
53
+ # all_tex_jitter = material.sample(gb_tex_pos + torch.normal(mean=0, std=0.01, size=gb_tex_pos.shape, device="cuda"), feat=feat)
54
+ if material is not None:
55
+ if im_features_map is None:
56
+ all_tex = material.sample(gb_tex_pos, feat=feat)
57
+ else:
58
+ all_tex = material.sample(gb_tex_pos, feat=feat, feat_map=im_features_map, mvp=mvp, w2c=w2c, deform_xyz=gb_pos)
59
+ else:
60
+ all_tex = torch.ones(*gb_pos.shape[:-1], 9, device=gb_pos.device)
61
+ kd, ks, perturbed_nrm = all_tex[..., :3], all_tex[..., 3:6], all_tex[..., 6:9]
62
+
63
+ # Compute albedo (kd) gradient, used for material regularizer
64
+ # kd_grad = torch.sum(torch.abs(all_tex_jitter[..., :-6] - all_tex[..., :-6]), dim=-1, keepdim=True) /
65
+
66
+ if dino_pred is not None and class_vector is None:
67
+ # DOR: predive the dino value using x,y,z, we would concatenate the label vector.
68
+ # trained together, generated image as the supervision for the one-hot-vector.
69
+ dino_feat_im_pred = dino_pred.sample(gb_tex_pos)
70
+ # dino_feat_im_pred = dino_pred.sample(gb_tex_pos.detach())
71
+ if dino_pred is not None and class_vector is not None:
72
+ dino_feat_im_pred = dino_pred.sample(gb_tex_pos, feat=class_vector)
73
+
74
+ # else:
75
+ # kd_jitter = material['kd'].sample(gb_texc + torch.normal(mean=0, std=0.005, size=gb_texc.shape, device="cuda"), gb_texc_deriv)
76
+ # kd = material['kd'].sample(gb_texc, gb_texc_deriv)
77
+ # ks = material['ks'].sample(gb_texc, gb_texc_deriv)[..., 0:3] # skip alpha
78
+ # if 'normal' in material:
79
+ # perturbed_nrm = material['normal'].sample(gb_texc, gb_texc_deriv)
80
+ # kd_grad = torch.sum(torch.abs(kd_jitter[..., 0:3] - kd[..., 0:3]), dim=-1, keepdim=True) / 3
81
+
82
+ # Separate kd into alpha and color, default alpha = 1
83
+ # alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1])
84
+ # kd = kd[..., 0:3]
85
+ alpha = torch.ones_like(kd[..., 0:1])
86
+
87
+ ################################################################################
88
+ # Normal perturbation & normal bend
89
+ ################################################################################
90
+ if material is None or not material.perturb_normal:
91
+ perturbed_nrm = None
92
+
93
+ gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=two_sided_shading, opengl=True, use_python=True)
94
+
95
+ # if two_sided_shading:
96
+ # view_vec = util.safe_normalize(view_pos - gb_pos, -1)
97
+ # gb_normal = torch.where(torch.sum(gb_geometric_normal * view_vec, -1, keepdim=True) > 0, gb_geometric_normal, -gb_geometric_normal)
98
+ # else:
99
+ # gb_normal = gb_geometric_normal
100
+
101
+ b, h, w, _ = gb_normal.shape
102
+ cam_normal = util.safe_normalize(torch.matmul(gb_normal.view(b, -1, 3), w2c[:,:3,:3].transpose(2,1))).view(b, h, w, 3)
103
+
104
+ ################################################################################
105
+ # Evaluate BSDF
106
+ ################################################################################
107
+
108
+ assert bsdf is not None or material.bsdf is not None, "Material must specify a BSDF type"
109
+ bsdf = bsdf if bsdf is not None else material.bsdf
110
+ shading = None
111
+ if bsdf == 'pbr':
112
+ if isinstance(lgt, light.EnvironmentLight):
113
+ shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=True)
114
+ else:
115
+ assert False, "Invalid light type"
116
+ elif bsdf == 'diffuse':
117
+ if lgt is None:
118
+ shaded_col = kd
119
+ elif isinstance(lgt, light.EnvironmentLight):
120
+ shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=False)
121
+ # elif isinstance(lgt, light.DirectionalLight):
122
+ # shaded_col, shading = lgt.shade(feat, kd, cam_normal)
123
+ # else:
124
+ # assert False, "Invalid light type"
125
+ else:
126
+ shaded_col, shading = lgt.shade(feat, kd, cam_normal)
127
+ elif bsdf == 'normal':
128
+ shaded_col = (gb_normal + 1.0) * 0.5
129
+ elif bsdf == 'geo_normal':
130
+ shaded_col = (gb_geometric_normal + 1.0) * 0.5
131
+ elif bsdf == 'tangent':
132
+ shaded_col = (gb_tangent + 1.0) * 0.5
133
+ elif bsdf == 'kd':
134
+ shaded_col = kd
135
+ elif bsdf == 'ks':
136
+ shaded_col = ks
137
+ else:
138
+ assert False, "Invalid BSDF '%s'" % bsdf
139
+
140
+ # Return multiple buffers
141
+ buffers = {
142
+ 'kd' : torch.cat((kd, alpha), dim=-1),
143
+ 'shaded' : torch.cat((shaded_col, alpha), dim=-1),
144
+ # 'kd_grad' : torch.cat((kd_grad, alpha), dim=-1),
145
+ # 'occlusion' : torch.cat((ks[..., :1], alpha), dim=-1),
146
+ }
147
+
148
+ if dino_pred is not None:
149
+ buffers['dino_feat_im_pred'] = torch.cat((dino_feat_im_pred, alpha), dim=-1)
150
+
151
+ if delta_xy_interp is not None:
152
+ buffers['flow'] = torch.cat((delta_xy_interp, alpha), dim=-1)
153
+
154
+ if shading is not None:
155
+ buffers['shading'] = torch.cat((shading, alpha), dim=-1)
156
+
157
+ return buffers
158
+
159
+ # ==============================================================================================
160
+ # Render a depth slice of the mesh (scene), some limitations:
161
+ # - Single light
162
+ # - Single material
163
+ # ==============================================================================================
164
+ def render_layer(
165
+ rast,
166
+ rast_deriv,
167
+ mesh,
168
+ w2c,
169
+ view_pos,
170
+ material,
171
+ lgt,
172
+ resolution,
173
+ spp,
174
+ msaa,
175
+ bsdf,
176
+ feat,
177
+ prior_mesh,
178
+ two_sided_shading,
179
+ render_flow,
180
+ delta_xy=None,
181
+ dino_pred=None,
182
+ class_vector=None,
183
+ im_features_map=None,
184
+ mvp=None
185
+ ):
186
+
187
+ full_res = [resolution[0]*spp, resolution[1]*spp]
188
+
189
+ if prior_mesh is None:
190
+ prior_mesh = mesh
191
+
192
+ ################################################################################
193
+ # Rasterize
194
+ ################################################################################
195
+
196
+ # Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution
197
+ if spp > 1 and msaa:
198
+ rast_out_s = util.scale_img_nhwc(rast, resolution, mag='nearest', min='nearest')
199
+ rast_out_deriv_s = util.scale_img_nhwc(rast_deriv, resolution, mag='nearest', min='nearest') * spp
200
+ else:
201
+ rast_out_s = rast
202
+ rast_out_deriv_s = rast_deriv
203
+
204
+ if render_flow:
205
+ delta_xy_interp, _ = interpolate(delta_xy, rast_out_s, mesh.t_pos_idx[0].int())
206
+ else:
207
+ delta_xy_interp = None
208
+
209
+ ################################################################################
210
+ # Interpolate attributes
211
+ ################################################################################
212
+
213
+ # Interpolate world space position
214
+ gb_pos, _ = interpolate(mesh.v_pos, rast_out_s, mesh.t_pos_idx[0].int())
215
+
216
+ # Compute geometric normals. We need those because of bent normals trick (for bump mapping)
217
+ v0 = mesh.v_pos[:, mesh.t_pos_idx[0, :, 0], :]
218
+ v1 = mesh.v_pos[:, mesh.t_pos_idx[0, :, 1], :]
219
+ v2 = mesh.v_pos[:, mesh.t_pos_idx[0, :, 2], :]
220
+ face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1))
221
+ num_faces = face_normals.shape[1]
222
+ face_normal_indices = (torch.arange(0, num_faces, dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)
223
+ gb_geometric_normal, _ = interpolate(face_normals, rast_out_s, face_normal_indices.int())
224
+
225
+ # Compute tangent space
226
+ assert mesh.v_nrm is not None and mesh.v_tng is not None
227
+ gb_normal, _ = interpolate(mesh.v_nrm, rast_out_s, mesh.t_nrm_idx[0].int())
228
+ gb_tangent, _ = interpolate(mesh.v_tng, rast_out_s, mesh.t_tng_idx[0].int()) # Interpolate tangents
229
+
230
+ # Texture coordinate
231
+ assert mesh.v_tex is not None
232
+ gb_texc, gb_texc_deriv = interpolate(mesh.v_tex, rast_out_s, mesh.t_tex_idx[0].int(), rast_db=rast_out_deriv_s)
233
+
234
+ ################################################################################
235
+ # Shade
236
+ ################################################################################
237
+
238
+ gb_tex_pos, _ = interpolate(prior_mesh.v_pos, rast_out_s, mesh.t_pos_idx[0].int())
239
+ buffers = shade(gb_pos, gb_geometric_normal, gb_normal, gb_tangent, gb_tex_pos, gb_texc, gb_texc_deriv, w2c, view_pos, lgt, material, bsdf, feat=feat, two_sided_shading=two_sided_shading, delta_xy_interp=delta_xy_interp, dino_pred=dino_pred, class_vector=class_vector, im_features_map=im_features_map, mvp=mvp)
240
+
241
+ ################################################################################
242
+ # Prepare output
243
+ ################################################################################
244
+
245
+ # Scale back up to visibility resolution if using MSAA
246
+ if spp > 1 and msaa:
247
+ for key in buffers.keys():
248
+ buffers[key] = util.scale_img_nhwc(buffers[key], full_res, mag='nearest', min='nearest')
249
+
250
+ # Return buffers
251
+ return buffers
252
+
253
+ # ==============================================================================================
254
+ # Render a depth peeled mesh (scene), some limitations:
255
+ # - Single light
256
+ # - Single material
257
+ # ==============================================================================================
258
+ def render_mesh(
259
+ ctx,
260
+ mesh,
261
+ mtx_in,
262
+ w2c,
263
+ view_pos,
264
+ material,
265
+ lgt,
266
+ resolution,
267
+ spp = 1,
268
+ num_layers = 1,
269
+ msaa = False,
270
+ background = None,
271
+ bsdf = None,
272
+ feat = None,
273
+ prior_mesh = None,
274
+ two_sided_shading = True,
275
+ render_flow = False,
276
+ dino_pred = None,
277
+ class_vector = None,
278
+ num_frames = None,
279
+ im_features_map = None
280
+ ):
281
+
282
+ def prepare_input_vector(x):
283
+ x = torch.tensor(x, dtype=torch.float32, device='cuda') if not torch.is_tensor(x) else x
284
+ return x[:, None, None, :] if len(x.shape) == 2 else x
285
+
286
+ def composite_buffer(key, layers, background, antialias):
287
+ accum = background
288
+ for buffers, rast in reversed(layers):
289
+ alpha = (rast[..., -1:] > 0).float() * buffers[key][..., -1:]
290
+ accum = torch.lerp(accum, torch.cat((buffers[key][..., :-1], torch.ones_like(buffers[key][..., -1:])), dim=-1), alpha)
291
+ if antialias:
292
+ accum = dr.antialias(accum.contiguous(), rast, v_pos_clip, mesh.t_pos_idx[0].int())
293
+ return accum
294
+
295
+ assert mesh.t_pos_idx.shape[1] > 0, "Got empty training triangle mesh (unrecoverable discontinuity)"
296
+ assert background is None or (background.shape[1] == resolution[0] and background.shape[2] == resolution[1])
297
+
298
+ full_res = [resolution[0] * spp, resolution[1] * spp]
299
+
300
+ # Convert numpy arrays to torch tensors
301
+ mtx_in = torch.tensor(mtx_in, dtype=torch.float32, device='cuda') if not torch.is_tensor(mtx_in) else mtx_in
302
+ view_pos = prepare_input_vector(view_pos) # Shape: (B, 1, 1, 3)
303
+
304
+ # clip space transform
305
+ v_pos_clip = ru.xfm_points(mesh.v_pos, mtx_in, use_python=True)
306
+
307
+ # render flow
308
+ if render_flow:
309
+ v_pos_clip2 = v_pos_clip[..., :2] / v_pos_clip[..., -1:]
310
+ v_pos_clip2 = v_pos_clip2.view(-1, num_frames, *v_pos_clip2.shape[1:])
311
+ delta_xy = v_pos_clip2[:, 1:] - v_pos_clip2[:, :-1]
312
+ delta_xy = torch.cat([delta_xy, torch.zeros_like(delta_xy[:, :1])], dim=1)
313
+ delta_xy = delta_xy.view(-1, *delta_xy.shape[2:])
314
+ else:
315
+ delta_xy = None
316
+
317
+ # Render all layers front-to-back
318
+ layers = []
319
+ with dr.DepthPeeler(ctx, v_pos_clip, mesh.t_pos_idx[0].int(), full_res) as peeler:
320
+ for _ in range(num_layers):
321
+ rast, db = peeler.rasterize_next_layer()
322
+ rendered = render_layer(rast, db, mesh, w2c, view_pos, material, lgt, resolution, spp, msaa, bsdf, feat=feat, prior_mesh=prior_mesh, two_sided_shading=two_sided_shading, render_flow=render_flow, delta_xy=delta_xy, dino_pred=dino_pred, class_vector=class_vector, im_features_map=im_features_map, mvp=mtx_in)
323
+ layers += [(rendered, rast)]
324
+
325
+ # Setup background
326
+ if background is not None:
327
+ if spp > 1:
328
+ background = util.scale_img_nhwc(background, full_res, mag='nearest', min='nearest')
329
+ background = torch.cat((background, torch.zeros_like(background[..., 0:1])), dim=-1)
330
+ else:
331
+ background = torch.zeros(1, full_res[0], full_res[1], 4, dtype=torch.float32, device='cuda')
332
+
333
+ # Composite layers front-to-back
334
+ out_buffers = {}
335
+ for key in layers[0][0].keys():
336
+ antialias = key in ['shaded', 'dino_feat_im_pred', 'flow']
337
+ bg = background if key in ['shaded'] else torch.zeros_like(layers[0][0][key])
338
+ accum = composite_buffer(key, layers, bg, antialias)
339
+
340
+ # Downscale to framebuffer resolution. Use avg pooling
341
+ out_buffers[key] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum
342
+
343
+ return out_buffers
344
+
345
+ # ==============================================================================================
346
+ # Render UVs
347
+ # ==============================================================================================
348
+ def render_uv(ctx, mesh, resolution, mlp_texture, feat=None, prior_shape=None):
349
+
350
+ # clip space transform
351
+ uv_clip = mesh.v_tex * 2.0 - 1.0
352
+
353
+ # pad to four component coordinate
354
+ uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[...,0:1]), torch.ones_like(uv_clip[...,0:1])), dim = -1)
355
+
356
+ # rasterize
357
+ rast, _ = dr.rasterize(ctx, uv_clip4, mesh.t_tex_idx[0].int(), resolution)
358
+
359
+ # Interpolate world space position
360
+ if prior_shape is not None:
361
+ gb_pos, _ = interpolate(prior_shape.v_pos, rast, mesh.t_pos_idx[0].int())
362
+ else:
363
+ gb_pos, _ = interpolate(mesh.v_pos, rast, mesh.t_pos_idx[0].int())
364
+
365
+ # Sample out textures from MLP
366
+ all_tex = mlp_texture.sample(gb_pos, feat=feat)
367
+ assert all_tex.shape[-1] == 9 or all_tex.shape[-1] == 10, "Combined kd_ks_normal must be 9 or 10 channels"
368
+ perturbed_nrm = all_tex[..., -3:]
369
+ return (rast[..., -1:] > 0).float(), all_tex[..., :-6], all_tex[..., -6:-3], util.safe_normalize(perturbed_nrm)
video3d/render/renderutils/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ from .ops import xfm_points, xfm_vectors, image_loss, diffuse_cubemap, specular_cubemap, prepare_shading_normal, lambert, frostbite_diffuse, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith
11
+ __all__ = ["xfm_vectors", "xfm_points", "image_loss", "diffuse_cubemap","specular_cubemap", "prepare_shading_normal", "lambert", "frostbite_diffuse", "pbr_specular", "pbr_bsdf", "_fresnel_shlick", "_ndf_ggx", "_lambda_ggx", "_masking_smith", ]
video3d/render/renderutils/bsdf.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ import math
11
+ import torch
12
+
13
+ NORMAL_THRESHOLD = 0.1
14
+
15
+ ################################################################################
16
+ # Vector utility functions
17
+ ################################################################################
18
+
19
+ def _dot(x, y):
20
+ return torch.sum(x*y, -1, keepdim=True)
21
+
22
+ def _reflect(x, n):
23
+ return 2*_dot(x, n)*n - x
24
+
25
+ def _safe_normalize(x):
26
+ return torch.nn.functional.normalize(x, dim = -1)
27
+
28
+ def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading):
29
+ # Swap normal direction for backfacing surfaces
30
+ if two_sided_shading:
31
+ smooth_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, smooth_nrm, -smooth_nrm)
32
+ geom_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, geom_nrm, -geom_nrm)
33
+
34
+ t = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1)
35
+ return torch.lerp(geom_nrm, smooth_nrm, t)
36
+
37
+
38
+ def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl):
39
+ smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm, dim=-1))
40
+ if opengl:
41
+ shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] - smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
42
+ else:
43
+ shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] + smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
44
+ return _safe_normalize(shading_nrm)
45
+
46
+ def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):
47
+ smooth_nrm = _safe_normalize(smooth_nrm)
48
+ smooth_tng = _safe_normalize(smooth_tng)
49
+ view_vec = _safe_normalize(view_pos - pos)
50
+ shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl)
51
+ return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading)
52
+
53
+ ################################################################################
54
+ # Simple lambertian diffuse BSDF
55
+ ################################################################################
56
+
57
+ def bsdf_lambert(nrm, wi):
58
+ return torch.clamp(_dot(nrm, wi), min=0.0) / math.pi
59
+
60
+ ################################################################################
61
+ # Frostbite diffuse
62
+ ################################################################################
63
+
64
+ def bsdf_frostbite(nrm, wi, wo, linearRoughness):
65
+ wiDotN = _dot(wi, nrm)
66
+ woDotN = _dot(wo, nrm)
67
+
68
+ h = _safe_normalize(wo + wi)
69
+ wiDotH = _dot(wi, h)
70
+
71
+ energyBias = 0.5 * linearRoughness
72
+ energyFactor = 1.0 - (0.51 / 1.51) * linearRoughness
73
+ f90 = energyBias + 2.0 * wiDotH * wiDotH * linearRoughness
74
+ f0 = 1.0
75
+
76
+ wiScatter = bsdf_fresnel_shlick(f0, f90, wiDotN)
77
+ woScatter = bsdf_fresnel_shlick(f0, f90, woDotN)
78
+ res = wiScatter * woScatter * energyFactor
79
+ return torch.where((wiDotN > 0.0) & (woDotN > 0.0), res, torch.zeros_like(res))
80
+
81
+ ################################################################################
82
+ # Phong specular, loosely based on mitsuba implementation
83
+ ################################################################################
84
+
85
+ def bsdf_phong(nrm, wo, wi, N):
86
+ dp_r = torch.clamp(_dot(_reflect(wo, nrm), wi), min=0.0, max=1.0)
87
+ dp_l = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0)
88
+ return (dp_r ** N) * dp_l * (N + 2) / (2 * math.pi)
89
+
90
+ ################################################################################
91
+ # PBR's implementation of GGX specular
92
+ ################################################################################
93
+
94
+ specular_epsilon = 1e-4
95
+
96
+ def bsdf_fresnel_shlick(f0, f90, cosTheta):
97
+ _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
98
+ return f0 + (f90 - f0) * (1.0 - _cosTheta) ** 5.0
99
+
100
+ def bsdf_ndf_ggx(alphaSqr, cosTheta):
101
+ _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
102
+ d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1
103
+ return alphaSqr / (d * d * math.pi)
104
+
105
+ def bsdf_lambda_ggx(alphaSqr, cosTheta):
106
+ _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
107
+ cosThetaSqr = _cosTheta * _cosTheta
108
+ tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr
109
+ res = 0.5 * (torch.sqrt(1 + alphaSqr * tanThetaSqr) - 1.0)
110
+ return res
111
+
112
+ def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO):
113
+ lambdaI = bsdf_lambda_ggx(alphaSqr, cosThetaI)
114
+ lambdaO = bsdf_lambda_ggx(alphaSqr, cosThetaO)
115
+ return 1 / (1 + lambdaI + lambdaO)
116
+
117
+ def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08):
118
+ _alpha = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0)
119
+ alphaSqr = _alpha * _alpha
120
+
121
+ h = _safe_normalize(wo + wi)
122
+ woDotN = _dot(wo, nrm)
123
+ wiDotN = _dot(wi, nrm)
124
+ woDotH = _dot(wo, h)
125
+ nDotH = _dot(nrm, h)
126
+
127
+ D = bsdf_ndf_ggx(alphaSqr, nDotH)
128
+ G = bsdf_masking_smith_ggx_correlated(alphaSqr, woDotN, wiDotN)
129
+ F = bsdf_fresnel_shlick(col, 1, woDotH)
130
+
131
+ w = F * D * G * 0.25 / torch.clamp(woDotN, min=specular_epsilon)
132
+
133
+ frontfacing = (woDotN > specular_epsilon) & (wiDotN > specular_epsilon)
134
+ return torch.where(frontfacing, w, torch.zeros_like(w))
135
+
136
+ def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):
137
+ wo = _safe_normalize(view_pos - pos)
138
+ wi = _safe_normalize(light_pos - pos)
139
+
140
+ spec_str = arm[..., 0:1] # x component
141
+ roughness = arm[..., 1:2] # y component
142
+ metallic = arm[..., 2:3] # z component
143
+ ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str)
144
+ kd = kd * (1.0 - metallic)
145
+
146
+ if BSDF == 0:
147
+ diffuse = kd * bsdf_lambert(nrm, wi)
148
+ else:
149
+ diffuse = kd * bsdf_frostbite(nrm, wi, wo, roughness)
150
+ specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness)
151
+ return diffuse + specular
video3d/render/renderutils/c_src/bsdf.cu ADDED
@@ -0,0 +1,710 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #include "common.h"
13
+ #include "bsdf.h"
14
+
15
+ #define SPECULAR_EPSILON 1e-4f
16
+
17
+ //------------------------------------------------------------------------
18
+ // Lambert functions
19
+
20
+ __device__ inline float fwdLambert(const vec3f nrm, const vec3f wi)
21
+ {
22
+ return max(dot(nrm, wi) / M_PI, 0.0f);
23
+ }
24
+
25
+ __device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out)
26
+ {
27
+ if (dot(nrm, wi) > 0.0f)
28
+ bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI);
29
+ }
30
+
31
+ //------------------------------------------------------------------------
32
+ // Fresnel Schlick
33
+
34
+ __device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta)
35
+ {
36
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
37
+ float scale = powf(1.0f - _cosTheta, 5.0f);
38
+ return f0 * (1.0f - scale) + f90 * scale;
39
+ }
40
+
41
+ __device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out)
42
+ {
43
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
44
+ float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
45
+ d_f0 += d_out * (1.0 - scale);
46
+ d_f90 += d_out * scale;
47
+ if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
48
+ {
49
+ d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f);
50
+ }
51
+ }
52
+
53
+ __device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta)
54
+ {
55
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
56
+ float scale = powf(1.0f - _cosTheta, 5.0f);
57
+ return f0 * (1.0f - scale) + f90 * scale;
58
+ }
59
+
60
+ __device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out)
61
+ {
62
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
63
+ float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
64
+ d_f0 += d_out * (1.0 - scale);
65
+ d_f90 += d_out * scale;
66
+ if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
67
+ {
68
+ d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f));
69
+ }
70
+ }
71
+
72
+ //------------------------------------------------------------------------
73
+ // Frostbite diffuse
74
+
75
+ __device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness)
76
+ {
77
+ float wiDotN = dot(wi, nrm);
78
+ float woDotN = dot(wo, nrm);
79
+ if (wiDotN > 0.0f && woDotN > 0.0f)
80
+ {
81
+ vec3f h = safeNormalize(wo + wi);
82
+ float wiDotH = dot(wi, h);
83
+
84
+ float energyBias = 0.5f * linearRoughness;
85
+ float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
86
+ float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
87
+ float f0 = 1.f;
88
+
89
+ float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
90
+ float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
91
+
92
+ return wiScatter * woScatter * energyFactor;
93
+ }
94
+ else return 0.0f;
95
+ }
96
+
97
+ __device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out)
98
+ {
99
+ float wiDotN = dot(wi, nrm);
100
+ float woDotN = dot(wo, nrm);
101
+
102
+ if (wiDotN > 0.0f && woDotN > 0.0f)
103
+ {
104
+ vec3f h = safeNormalize(wo + wi);
105
+ float wiDotH = dot(wi, h);
106
+
107
+ float energyBias = 0.5f * linearRoughness;
108
+ float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
109
+ float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
110
+ float f0 = 1.f;
111
+
112
+ float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
113
+ float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
114
+
115
+ // -------------- BWD --------------
116
+ // Backprop: return wiScatter * woScatter * energyFactor;
117
+ float d_wiScatter = d_out * woScatter * energyFactor;
118
+ float d_woScatter = d_out * wiScatter * energyFactor;
119
+ float d_energyFactor = d_out * wiScatter * woScatter;
120
+
121
+ // Backprop: float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
122
+ float d_woDotN = 0.0f, d_f0 = 0.0, d_f90 = 0.0f;
123
+ bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter);
124
+
125
+ // Backprop: float wiScatter = fwdFresnelSchlick(fd0, fd90, wiDotN);
126
+ float d_wiDotN = 0.0f;
127
+ bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter);
128
+
129
+ // Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
130
+ float d_energyBias = d_f90;
131
+ float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness;
132
+ d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH;
133
+
134
+ // Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
135
+ d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor;
136
+
137
+ // Backprop: float energyBias = 0.5f * linearRoughness;
138
+ d_linearRoughness += 0.5 * d_energyBias;
139
+
140
+ // Backprop: float wiDotH = dot(wi, h);
141
+ vec3f d_h(0);
142
+ bwdDot(wi, h, d_wi, d_h, d_wiDotH);
143
+
144
+ // Backprop: vec3f h = safeNormalize(wo + wi);
145
+ vec3f d_wo_wi(0);
146
+ bwdSafeNormalize(wo + wi, d_wo_wi, d_h);
147
+ d_wi += d_wo_wi; d_wo += d_wo_wi;
148
+
149
+ bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
150
+ bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
151
+ }
152
+ }
153
+
154
+ //------------------------------------------------------------------------
155
+ // Ndf GGX
156
+
157
+ __device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta)
158
+ {
159
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
160
+ float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
161
+ return alphaSqr / (d * d * M_PI);
162
+ }
163
+
164
+ __device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
165
+ {
166
+ // Torch only back propagates if clamp doesn't trigger
167
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
168
+ float cosThetaSqr = _cosTheta * _cosTheta;
169
+ d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
170
+ if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
171
+ {
172
+ d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
173
+ }
174
+ }
175
+
176
+ //------------------------------------------------------------------------
177
+ // Lambda GGX
178
+
179
+ __device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta)
180
+ {
181
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
182
+ float cosThetaSqr = _cosTheta * _cosTheta;
183
+ float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
184
+ float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
185
+ return res;
186
+ }
187
+
188
+ __device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
189
+ {
190
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
191
+ float cosThetaSqr = _cosTheta * _cosTheta;
192
+ float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
193
+ float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
194
+
195
+ d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f);
196
+ if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
197
+ d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f));
198
+ }
199
+
200
+ //------------------------------------------------------------------------
201
+ // Masking GGX
202
+
203
+ __device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO)
204
+ {
205
+ float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
206
+ float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
207
+ return 1.0f / (1.0f + lambdaI + lambdaO);
208
+ }
209
+
210
+ __device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out)
211
+ {
212
+ // FWD eval
213
+ float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
214
+ float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
215
+
216
+ // BWD eval
217
+ float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f);
218
+ bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO);
219
+ bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO);
220
+ }
221
+
222
+ //------------------------------------------------------------------------
223
+ // GGX specular
224
+
225
+ __device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness)
226
+ {
227
+ float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
228
+ float alphaSqr = _alpha * _alpha;
229
+
230
+ vec3f h = safeNormalize(wo + wi);
231
+ float woDotN = dot(wo, nrm);
232
+ float wiDotN = dot(wi, nrm);
233
+ float woDotH = dot(wo, h);
234
+ float nDotH = dot(nrm, h);
235
+
236
+ float D = fwdNdfGGX(alphaSqr, nDotH);
237
+ float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
238
+ vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
239
+ vec3f w = F * D * G * 0.25 / woDotN;
240
+
241
+ bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
242
+ return frontfacing ? w : 0.0f;
243
+ }
244
+
245
+ __device__ void bwdPbrSpecular(
246
+ const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness,
247
+ vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out)
248
+ {
249
+ ///////////////////////////////////////////////////////////////////////
250
+ // FWD eval
251
+
252
+ float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
253
+ float alphaSqr = _alpha * _alpha;
254
+
255
+ vec3f h = safeNormalize(wo + wi);
256
+ float woDotN = dot(wo, nrm);
257
+ float wiDotN = dot(wi, nrm);
258
+ float woDotH = dot(wo, h);
259
+ float nDotH = dot(nrm, h);
260
+
261
+ float D = fwdNdfGGX(alphaSqr, nDotH);
262
+ float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
263
+ vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
264
+ vec3f w = F * D * G * 0.25 / woDotN;
265
+ bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
266
+
267
+ if (frontfacing)
268
+ {
269
+ ///////////////////////////////////////////////////////////////////////
270
+ // BWD eval
271
+
272
+ vec3f d_F = d_out * D * G * 0.25f / woDotN;
273
+ float d_D = sum(d_out * F * G * 0.25f / woDotN);
274
+ float d_G = sum(d_out * F * D * 0.25f / woDotN);
275
+
276
+ float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN));
277
+
278
+ vec3f d_f90(0);
279
+ float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0);
280
+ bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F);
281
+ bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G);
282
+ bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D);
283
+
284
+ vec3f d_h(0);
285
+ bwdDot(nrm, h, d_nrm, d_h, d_nDotH);
286
+ bwdDot(wo, h, d_wo, d_h, d_woDotH);
287
+ bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
288
+ bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
289
+
290
+ vec3f d_h_unnorm(0);
291
+ bwdSafeNormalize(wo + wi, d_h_unnorm, d_h);
292
+ d_wo += d_h_unnorm;
293
+ d_wi += d_h_unnorm;
294
+
295
+ if (alpha > min_roughness * min_roughness)
296
+ d_alpha += d_alphaSqr * 2 * alpha;
297
+ }
298
+ }
299
+
300
+ //------------------------------------------------------------------------
301
+ // Full PBR BSDF
302
+
303
+ __device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF)
304
+ {
305
+ vec3f wo = safeNormalize(view_pos - pos);
306
+ vec3f wi = safeNormalize(light_pos - pos);
307
+
308
+ float alpha = arm.y * arm.y;
309
+ vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
310
+ vec3f diff_col = kd * (1.0f - arm.z);
311
+
312
+ float diff = 0.0f;
313
+ if (BSDF == 0)
314
+ diff = fwdLambert(nrm, wi);
315
+ else
316
+ diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
317
+ vec3f diffuse = diff_col * diff;
318
+ vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness);
319
+
320
+ return diffuse + specular;
321
+ }
322
+
323
+ __device__ void bwdPbrBSDF(
324
+ const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF,
325
+ vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out)
326
+ {
327
+ ////////////////////////////////////////////////////////////////////////
328
+ // FWD
329
+ vec3f _wi = light_pos - pos;
330
+ vec3f _wo = view_pos - pos;
331
+ vec3f wi = safeNormalize(_wi);
332
+ vec3f wo = safeNormalize(_wo);
333
+
334
+ float alpha = arm.y * arm.y;
335
+ vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
336
+ vec3f diff_col = kd * (1.0f - arm.z);
337
+ float diff = 0.0f;
338
+ if (BSDF == 0)
339
+ diff = fwdLambert(nrm, wi);
340
+ else
341
+ diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
342
+
343
+ ////////////////////////////////////////////////////////////////////////
344
+ // BWD
345
+
346
+ float d_alpha(0);
347
+ vec3f d_spec_col(0), d_wi(0), d_wo(0);
348
+ bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
349
+
350
+ float d_diff = sum(diff_col * d_out);
351
+ if (BSDF == 0)
352
+ bwdLambert(nrm, wi, d_nrm, d_wi, d_diff);
353
+ else
354
+ bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff);
355
+
356
+ // Backprop: diff_col = kd * (1.0f - arm.z)
357
+ vec3f d_diff_col = d_out * diff;
358
+ d_kd += d_diff_col * (1.0f - arm.z);
359
+ d_arm.z -= sum(d_diff_col * kd);
360
+
361
+ // Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x)
362
+ d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z;
363
+ d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f));
364
+ d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f));
365
+
366
+ // Backprop: alpha = arm.y * arm.y
367
+ d_arm.y += d_alpha * 2 * arm.y;
368
+
369
+ // Backprop: vec3f wi = safeNormalize(light_pos - pos);
370
+ vec3f d__wi(0);
371
+ bwdSafeNormalize(_wi, d__wi, d_wi);
372
+ d_light_pos += d__wi;
373
+ d_pos -= d__wi;
374
+
375
+ // Backprop: vec3f wo = safeNormalize(view_pos - pos);
376
+ vec3f d__wo(0);
377
+ bwdSafeNormalize(_wo, d__wo, d_wo);
378
+ d_view_pos += d__wo;
379
+ d_pos -= d__wo;
380
+ }
381
+
382
+ //------------------------------------------------------------------------
383
+ // Kernels
384
+
385
+ __global__ void LambertFwdKernel(LambertKernelParams p)
386
+ {
387
+ // Calculate pixel position.
388
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
389
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
390
+ unsigned int pz = blockIdx.z;
391
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
392
+ return;
393
+
394
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
395
+ vec3f wi = p.wi.fetch3(px, py, pz);
396
+
397
+ float res = fwdLambert(nrm, wi);
398
+
399
+ p.out.store(px, py, pz, res);
400
+ }
401
+
402
+ __global__ void LambertBwdKernel(LambertKernelParams p)
403
+ {
404
+ // Calculate pixel position.
405
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
406
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
407
+ unsigned int pz = blockIdx.z;
408
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
409
+ return;
410
+
411
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
412
+ vec3f wi = p.wi.fetch3(px, py, pz);
413
+ float d_out = p.out.fetch1(px, py, pz);
414
+
415
+ vec3f d_nrm(0), d_wi(0);
416
+ bwdLambert(nrm, wi, d_nrm, d_wi, d_out);
417
+
418
+ p.nrm.store_grad(px, py, pz, d_nrm);
419
+ p.wi.store_grad(px, py, pz, d_wi);
420
+ }
421
+
422
+ __global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p)
423
+ {
424
+ // Calculate pixel position.
425
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
426
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
427
+ unsigned int pz = blockIdx.z;
428
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
429
+ return;
430
+
431
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
432
+ vec3f wi = p.wi.fetch3(px, py, pz);
433
+ vec3f wo = p.wo.fetch3(px, py, pz);
434
+ float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
435
+
436
+ float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness);
437
+
438
+ p.out.store(px, py, pz, res);
439
+ }
440
+
441
+ __global__ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p)
442
+ {
443
+ // Calculate pixel position.
444
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
445
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
446
+ unsigned int pz = blockIdx.z;
447
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
448
+ return;
449
+
450
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
451
+ vec3f wi = p.wi.fetch3(px, py, pz);
452
+ vec3f wo = p.wo.fetch3(px, py, pz);
453
+ float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
454
+ float d_out = p.out.fetch1(px, py, pz);
455
+
456
+ float d_linearRoughness = 0.0f;
457
+ vec3f d_nrm(0), d_wi(0), d_wo(0);
458
+ bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out);
459
+
460
+ p.nrm.store_grad(px, py, pz, d_nrm);
461
+ p.wi.store_grad(px, py, pz, d_wi);
462
+ p.wo.store_grad(px, py, pz, d_wo);
463
+ p.linearRoughness.store_grad(px, py, pz, d_linearRoughness);
464
+ }
465
+
466
+ __global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p)
467
+ {
468
+ // Calculate pixel position.
469
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
470
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
471
+ unsigned int pz = blockIdx.z;
472
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
473
+ return;
474
+
475
+ vec3f f0 = p.f0.fetch3(px, py, pz);
476
+ vec3f f90 = p.f90.fetch3(px, py, pz);
477
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
478
+
479
+ vec3f res = fwdFresnelSchlick(f0, f90, cosTheta);
480
+ p.out.store(px, py, pz, res);
481
+ }
482
+
483
+ __global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p)
484
+ {
485
+ // Calculate pixel position.
486
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
487
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
488
+ unsigned int pz = blockIdx.z;
489
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
490
+ return;
491
+
492
+ vec3f f0 = p.f0.fetch3(px, py, pz);
493
+ vec3f f90 = p.f90.fetch3(px, py, pz);
494
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
495
+ vec3f d_out = p.out.fetch3(px, py, pz);
496
+
497
+ vec3f d_f0(0), d_f90(0);
498
+ float d_cosTheta(0);
499
+ bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out);
500
+
501
+ p.f0.store_grad(px, py, pz, d_f0);
502
+ p.f90.store_grad(px, py, pz, d_f90);
503
+ p.cosTheta.store_grad(px, py, pz, d_cosTheta);
504
+ }
505
+
506
+ __global__ void ndfGGXFwdKernel(NdfGGXParams p)
507
+ {
508
+ // Calculate pixel position.
509
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
510
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
511
+ unsigned int pz = blockIdx.z;
512
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
513
+ return;
514
+
515
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
516
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
517
+ float res = fwdNdfGGX(alphaSqr, cosTheta);
518
+
519
+ p.out.store(px, py, pz, res);
520
+ }
521
+
522
+ __global__ void ndfGGXBwdKernel(NdfGGXParams p)
523
+ {
524
+ // Calculate pixel position.
525
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
526
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
527
+ unsigned int pz = blockIdx.z;
528
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
529
+ return;
530
+
531
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
532
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
533
+ float d_out = p.out.fetch1(px, py, pz);
534
+
535
+ float d_alphaSqr(0), d_cosTheta(0);
536
+ bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
537
+
538
+ p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
539
+ p.cosTheta.store_grad(px, py, pz, d_cosTheta);
540
+ }
541
+
542
+ __global__ void lambdaGGXFwdKernel(NdfGGXParams p)
543
+ {
544
+ // Calculate pixel position.
545
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
546
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
547
+ unsigned int pz = blockIdx.z;
548
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
549
+ return;
550
+
551
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
552
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
553
+ float res = fwdLambdaGGX(alphaSqr, cosTheta);
554
+
555
+ p.out.store(px, py, pz, res);
556
+ }
557
+
558
+ __global__ void lambdaGGXBwdKernel(NdfGGXParams p)
559
+ {
560
+ // Calculate pixel position.
561
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
562
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
563
+ unsigned int pz = blockIdx.z;
564
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
565
+ return;
566
+
567
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
568
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
569
+ float d_out = p.out.fetch1(px, py, pz);
570
+
571
+ float d_alphaSqr(0), d_cosTheta(0);
572
+ bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
573
+
574
+ p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
575
+ p.cosTheta.store_grad(px, py, pz, d_cosTheta);
576
+ }
577
+
578
+ __global__ void maskingSmithFwdKernel(MaskingSmithParams p)
579
+ {
580
+ // Calculate pixel position.
581
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
582
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
583
+ unsigned int pz = blockIdx.z;
584
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
585
+ return;
586
+
587
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
588
+ float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
589
+ float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
590
+ float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO);
591
+
592
+ p.out.store(px, py, pz, res);
593
+ }
594
+
595
+ __global__ void maskingSmithBwdKernel(MaskingSmithParams p)
596
+ {
597
+ // Calculate pixel position.
598
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
599
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
600
+ unsigned int pz = blockIdx.z;
601
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
602
+ return;
603
+
604
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
605
+ float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
606
+ float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
607
+ float d_out = p.out.fetch1(px, py, pz);
608
+
609
+ float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0);
610
+ bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out);
611
+
612
+ p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
613
+ p.cosThetaI.store_grad(px, py, pz, d_cosThetaI);
614
+ p.cosThetaO.store_grad(px, py, pz, d_cosThetaO);
615
+ }
616
+
617
+ __global__ void pbrSpecularFwdKernel(PbrSpecular p)
618
+ {
619
+ // Calculate pixel position.
620
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
621
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
622
+ unsigned int pz = blockIdx.z;
623
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
624
+ return;
625
+
626
+ vec3f col = p.col.fetch3(px, py, pz);
627
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
628
+ vec3f wo = p.wo.fetch3(px, py, pz);
629
+ vec3f wi = p.wi.fetch3(px, py, pz);
630
+ float alpha = p.alpha.fetch1(px, py, pz);
631
+
632
+ vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness);
633
+
634
+ p.out.store(px, py, pz, res);
635
+ }
636
+
637
+ __global__ void pbrSpecularBwdKernel(PbrSpecular p)
638
+ {
639
+ // Calculate pixel position.
640
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
641
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
642
+ unsigned int pz = blockIdx.z;
643
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
644
+ return;
645
+
646
+ vec3f col = p.col.fetch3(px, py, pz);
647
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
648
+ vec3f wo = p.wo.fetch3(px, py, pz);
649
+ vec3f wi = p.wi.fetch3(px, py, pz);
650
+ float alpha = p.alpha.fetch1(px, py, pz);
651
+ vec3f d_out = p.out.fetch3(px, py, pz);
652
+
653
+ float d_alpha(0);
654
+ vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0);
655
+ bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
656
+
657
+ p.col.store_grad(px, py, pz, d_col);
658
+ p.nrm.store_grad(px, py, pz, d_nrm);
659
+ p.wo.store_grad(px, py, pz, d_wo);
660
+ p.wi.store_grad(px, py, pz, d_wi);
661
+ p.alpha.store_grad(px, py, pz, d_alpha);
662
+ }
663
+
664
+ __global__ void pbrBSDFFwdKernel(PbrBSDF p)
665
+ {
666
+ // Calculate pixel position.
667
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
668
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
669
+ unsigned int pz = blockIdx.z;
670
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
671
+ return;
672
+
673
+ vec3f kd = p.kd.fetch3(px, py, pz);
674
+ vec3f arm = p.arm.fetch3(px, py, pz);
675
+ vec3f pos = p.pos.fetch3(px, py, pz);
676
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
677
+ vec3f view_pos = p.view_pos.fetch3(px, py, pz);
678
+ vec3f light_pos = p.light_pos.fetch3(px, py, pz);
679
+
680
+ vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF);
681
+
682
+ p.out.store(px, py, pz, res);
683
+ }
684
+ __global__ void pbrBSDFBwdKernel(PbrBSDF p)
685
+ {
686
+ // Calculate pixel position.
687
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
688
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
689
+ unsigned int pz = blockIdx.z;
690
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
691
+ return;
692
+
693
+ vec3f kd = p.kd.fetch3(px, py, pz);
694
+ vec3f arm = p.arm.fetch3(px, py, pz);
695
+ vec3f pos = p.pos.fetch3(px, py, pz);
696
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
697
+ vec3f view_pos = p.view_pos.fetch3(px, py, pz);
698
+ vec3f light_pos = p.light_pos.fetch3(px, py, pz);
699
+ vec3f d_out = p.out.fetch3(px, py, pz);
700
+
701
+ vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0);
702
+ bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out);
703
+
704
+ p.kd.store_grad(px, py, pz, d_kd);
705
+ p.arm.store_grad(px, py, pz, d_arm);
706
+ p.pos.store_grad(px, py, pz, d_pos);
707
+ p.nrm.store_grad(px, py, pz, d_nrm);
708
+ p.view_pos.store_grad(px, py, pz, d_view_pos);
709
+ p.light_pos.store_grad(px, py, pz, d_light_pos);
710
+ }
video3d/render/renderutils/c_src/bsdf.h ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #pragma once
13
+
14
+ #include "common.h"
15
+
16
+ struct LambertKernelParams
17
+ {
18
+ Tensor nrm;
19
+ Tensor wi;
20
+ Tensor out;
21
+ dim3 gridSize;
22
+ };
23
+
24
+ struct FrostbiteDiffuseKernelParams
25
+ {
26
+ Tensor nrm;
27
+ Tensor wi;
28
+ Tensor wo;
29
+ Tensor linearRoughness;
30
+ Tensor out;
31
+ dim3 gridSize;
32
+ };
33
+
34
+ struct FresnelShlickKernelParams
35
+ {
36
+ Tensor f0;
37
+ Tensor f90;
38
+ Tensor cosTheta;
39
+ Tensor out;
40
+ dim3 gridSize;
41
+ };
42
+
43
+ struct NdfGGXParams
44
+ {
45
+ Tensor alphaSqr;
46
+ Tensor cosTheta;
47
+ Tensor out;
48
+ dim3 gridSize;
49
+ };
50
+
51
+ struct MaskingSmithParams
52
+ {
53
+ Tensor alphaSqr;
54
+ Tensor cosThetaI;
55
+ Tensor cosThetaO;
56
+ Tensor out;
57
+ dim3 gridSize;
58
+ };
59
+
60
+ struct PbrSpecular
61
+ {
62
+ Tensor col;
63
+ Tensor nrm;
64
+ Tensor wo;
65
+ Tensor wi;
66
+ Tensor alpha;
67
+ Tensor out;
68
+ dim3 gridSize;
69
+ float min_roughness;
70
+ };
71
+
72
+ struct PbrBSDF
73
+ {
74
+ Tensor kd;
75
+ Tensor arm;
76
+ Tensor pos;
77
+ Tensor nrm;
78
+ Tensor view_pos;
79
+ Tensor light_pos;
80
+ Tensor out;
81
+ dim3 gridSize;
82
+ float min_roughness;
83
+ int BSDF;
84
+ };
video3d/render/renderutils/c_src/common.cpp ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #include <cuda_runtime.h>
13
+ #include <algorithm>
14
+
15
+ //------------------------------------------------------------------------
16
+ // Block and grid size calculators for kernel launches.
17
+
18
+ dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims)
19
+ {
20
+ int maxThreads = maxWidth * maxHeight;
21
+ if (maxThreads <= 1 || (dims.x * dims.y) <= 1)
22
+ return dim3(1, 1, 1); // Degenerate.
23
+
24
+ // Start from max size.
25
+ int bw = maxWidth;
26
+ int bh = maxHeight;
27
+
28
+ // Optimizations for weirdly sized buffers.
29
+ if (dims.x < bw)
30
+ {
31
+ // Decrease block width to smallest power of two that covers the buffer width.
32
+ while ((bw >> 1) >= dims.x)
33
+ bw >>= 1;
34
+
35
+ // Maximize height.
36
+ bh = maxThreads / bw;
37
+ if (bh > dims.y)
38
+ bh = dims.y;
39
+ }
40
+ else if (dims.y < bh)
41
+ {
42
+ // Halve height and double width until fits completely inside buffer vertically.
43
+ while (bh > dims.y)
44
+ {
45
+ bh >>= 1;
46
+ if (bw < dims.x)
47
+ bw <<= 1;
48
+ }
49
+ }
50
+
51
+ // Done.
52
+ return dim3(bw, bh, 1);
53
+ }
54
+
55
+ // returns the size of a block that can be reduced using horizontal SIMD operations (e.g. __shfl_xor_sync)
56
+ dim3 getWarpSize(dim3 blockSize)
57
+ {
58
+ return dim3(
59
+ std::min(blockSize.x, 32u),
60
+ std::min(std::max(32u / blockSize.x, 1u), std::min(32u, blockSize.y)),
61
+ std::min(std::max(32u / (blockSize.x * blockSize.y), 1u), std::min(32u, blockSize.z))
62
+ );
63
+ }
64
+
65
+ dim3 getLaunchGridSize(dim3 blockSize, dim3 dims)
66
+ {
67
+ dim3 gridSize;
68
+ gridSize.x = (dims.x - 1) / blockSize.x + 1;
69
+ gridSize.y = (dims.y - 1) / blockSize.y + 1;
70
+ gridSize.z = (dims.z - 1) / blockSize.z + 1;
71
+ return gridSize;
72
+ }
73
+
74
+ //------------------------------------------------------------------------
video3d/render/renderutils/c_src/common.h ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #pragma once
13
+ #include <cuda.h>
14
+ #include <stdint.h>
15
+
16
+ #include "vec3f.h"
17
+ #include "vec4f.h"
18
+ #include "tensor.h"
19
+
20
+ dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims);
21
+ dim3 getLaunchGridSize(dim3 blockSize, dim3 dims);
22
+
23
+ #ifdef __CUDACC__
24
+
25
+ #ifdef _MSC_VER
26
+ #define M_PI 3.14159265358979323846f
27
+ #endif
28
+
29
+ __host__ __device__ static inline dim3 getWarpSize(dim3 blockSize)
30
+ {
31
+ return dim3(
32
+ min(blockSize.x, 32u),
33
+ min(max(32u / blockSize.x, 1u), min(32u, blockSize.y)),
34
+ min(max(32u / (blockSize.x * blockSize.y), 1u), min(32u, blockSize.z))
35
+ );
36
+ }
37
+
38
+ __device__ static inline float clamp(float val, float mn, float mx) { return min(max(val, mn), mx); }
39
+ #else
40
+ dim3 getWarpSize(dim3 blockSize);
41
+ #endif
video3d/render/renderutils/c_src/cubemap.cu ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #include "common.h"
13
+ #include "cubemap.h"
14
+ #include <float.h>
15
+
16
+ // https://cgvr.cs.uni-bremen.de/teaching/cg_literatur/Spherical,%20Cubic,%20and%20Parabolic%20Environment%20Mappings.pdf
17
+ __device__ float pixel_area(int x, int y, int N)
18
+ {
19
+ if (N > 1)
20
+ {
21
+ int H = N / 2;
22
+ x = abs(x - H);
23
+ y = abs(y - H);
24
+ float dx = atan((float)(x + 1) / (float)H) - atan((float)x / (float)H);
25
+ float dy = atan((float)(y + 1) / (float)H) - atan((float)y / (float)H);
26
+ return dx * dy;
27
+ }
28
+ else
29
+ return 1;
30
+ }
31
+
32
+ __device__ vec3f cube_to_dir(int x, int y, int side, int N)
33
+ {
34
+ float fx = 2.0f * (((float)x + 0.5f) / (float)N) - 1.0f;
35
+ float fy = 2.0f * (((float)y + 0.5f) / (float)N) - 1.0f;
36
+ switch (side)
37
+ {
38
+ case 0: return safeNormalize(vec3f(1, -fy, -fx));
39
+ case 1: return safeNormalize(vec3f(-1, -fy, fx));
40
+ case 2: return safeNormalize(vec3f(fx, 1, fy));
41
+ case 3: return safeNormalize(vec3f(fx, -1, -fy));
42
+ case 4: return safeNormalize(vec3f(fx, -fy, 1));
43
+ case 5: return safeNormalize(vec3f(-fx, -fy, -1));
44
+ }
45
+ return vec3f(0,0,0); // Unreachable
46
+ }
47
+
48
+ __device__ vec3f dir_to_side(int side, vec3f v)
49
+ {
50
+ switch (side)
51
+ {
52
+ case 0: return vec3f(-v.z, -v.y, v.x);
53
+ case 1: return vec3f( v.z, -v.y, -v.x);
54
+ case 2: return vec3f( v.x, v.z, v.y);
55
+ case 3: return vec3f( v.x, -v.z, -v.y);
56
+ case 4: return vec3f( v.x, -v.y, v.z);
57
+ case 5: return vec3f(-v.x, -v.y, -v.z);
58
+ }
59
+ return vec3f(0,0,0); // Unreachable
60
+ }
61
+
62
+ __device__ void extents_1d(float x, float z, float theta, float& _min, float& _max)
63
+ {
64
+ float l = sqrtf(x * x + z * z);
65
+ float pxr = x + z * tan(theta) * l, pzr = z - x * tan(theta) * l;
66
+ float pxl = x - z * tan(theta) * l, pzl = z + x * tan(theta) * l;
67
+ if (pzl <= 0.00001f)
68
+ _min = pxl > 0.0f ? FLT_MAX : -FLT_MAX;
69
+ else
70
+ _min = pxl / pzl;
71
+ if (pzr <= 0.00001f)
72
+ _max = pxr > 0.0f ? FLT_MAX : -FLT_MAX;
73
+ else
74
+ _max = pxr / pzr;
75
+ }
76
+
77
+ __device__ void dir_extents(int side, int N, vec3f v, float theta, int &_xmin, int& _xmax, int& _ymin, int& _ymax)
78
+ {
79
+ vec3f c = dir_to_side(side, v); // remap to (x,y,z) where side is at z = 1
80
+
81
+ if (theta < 0.785398f) // PI/4
82
+ {
83
+ float xmin, xmax, ymin, ymax;
84
+ extents_1d(c.x, c.z, theta, xmin, xmax);
85
+ extents_1d(c.y, c.z, theta, ymin, ymax);
86
+
87
+ if (xmin > 1.0f || xmax < -1.0f || ymin > 1.0f || ymax < -1.0f)
88
+ {
89
+ _xmin = -1; _xmax = -1; _ymin = -1; _ymax = -1; // Bad aabb
90
+ }
91
+ else
92
+ {
93
+ _xmin = (int)min(max((xmin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
94
+ _xmax = (int)min(max((xmax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
95
+ _ymin = (int)min(max((ymin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
96
+ _ymax = (int)min(max((ymax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
97
+ }
98
+ }
99
+ else
100
+ {
101
+ _xmin = 0.0f;
102
+ _xmax = (float)(N-1);
103
+ _ymin = 0.0f;
104
+ _ymax = (float)(N-1);
105
+ }
106
+ }
107
+
108
+ ///////////////////////////////////////////////////////////////////////////////////////////////////////////
109
+ // Diffuse kernel
110
+ __global__ void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p)
111
+ {
112
+ // Calculate pixel position.
113
+ int px = blockIdx.x * blockDim.x + threadIdx.x;
114
+ int py = blockIdx.y * blockDim.y + threadIdx.y;
115
+ int pz = blockIdx.z;
116
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
117
+ return;
118
+
119
+ int Npx = p.cubemap.dims[1];
120
+ vec3f N = cube_to_dir(px, py, pz, Npx);
121
+
122
+ vec3f col(0);
123
+
124
+ for (int s = 0; s < p.cubemap.dims[0]; ++s)
125
+ {
126
+ for (int y = 0; y < Npx; ++y)
127
+ {
128
+ for (int x = 0; x < Npx; ++x)
129
+ {
130
+ vec3f L = cube_to_dir(x, y, s, Npx);
131
+ float costheta = min(max(dot(N, L), 0.0f), 0.999f);
132
+ float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere
133
+ col += p.cubemap.fetch3(x, y, s) * w;
134
+ }
135
+ }
136
+ }
137
+
138
+ p.out.store(px, py, pz, col);
139
+ }
140
+
141
+ __global__ void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p)
142
+ {
143
+ // Calculate pixel position.
144
+ int px = blockIdx.x * blockDim.x + threadIdx.x;
145
+ int py = blockIdx.y * blockDim.y + threadIdx.y;
146
+ int pz = blockIdx.z;
147
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
148
+ return;
149
+
150
+ int Npx = p.cubemap.dims[1];
151
+ vec3f N = cube_to_dir(px, py, pz, Npx);
152
+ vec3f grad = p.out.fetch3(px, py, pz);
153
+
154
+ for (int s = 0; s < p.cubemap.dims[0]; ++s)
155
+ {
156
+ for (int y = 0; y < Npx; ++y)
157
+ {
158
+ for (int x = 0; x < Npx; ++x)
159
+ {
160
+ vec3f L = cube_to_dir(x, y, s, Npx);
161
+ float costheta = min(max(dot(N, L), 0.0f), 0.999f);
162
+ float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere
163
+ atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w);
164
+ atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w);
165
+ atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w);
166
+ }
167
+ }
168
+ }
169
+ }
170
+
171
+ ///////////////////////////////////////////////////////////////////////////////////////////////////////////
172
+ // GGX splitsum kernel
173
+
174
+ __device__ inline float ndfGGX(const float alphaSqr, const float cosTheta)
175
+ {
176
+ float _cosTheta = clamp(cosTheta, 0.0, 1.0f);
177
+ float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
178
+ return alphaSqr / (d * d * M_PI);
179
+ }
180
+
181
+ __global__ void SpecularBoundsKernel(SpecularBoundsKernelParams p)
182
+ {
183
+ int px = blockIdx.x * blockDim.x + threadIdx.x;
184
+ int py = blockIdx.y * blockDim.y + threadIdx.y;
185
+ int pz = blockIdx.z;
186
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
187
+ return;
188
+
189
+ int Npx = p.gridSize.x;
190
+ vec3f VNR = cube_to_dir(px, py, pz, Npx);
191
+
192
+ const int TILE_SIZE = 16;
193
+
194
+ // Brute force entire cubemap and compute bounds for the cone
195
+ for (int s = 0; s < p.gridSize.z; ++s)
196
+ {
197
+ // Assume empty BBox
198
+ int _min_x = p.gridSize.x - 1, _max_x = 0;
199
+ int _min_y = p.gridSize.y - 1, _max_y = 0;
200
+
201
+ // For each (8x8) tile
202
+ for (int tx = 0; tx < (p.gridSize.x + TILE_SIZE - 1) / TILE_SIZE; tx++)
203
+ {
204
+ for (int ty = 0; ty < (p.gridSize.y + TILE_SIZE - 1) / TILE_SIZE; ty++)
205
+ {
206
+ // Compute tile extents
207
+ int tsx = tx * TILE_SIZE, tsy = ty * TILE_SIZE;
208
+ int tex = min((tx + 1) * TILE_SIZE, p.gridSize.x), tey = min((ty + 1) * TILE_SIZE, p.gridSize.y);
209
+
210
+ // Use some blunt interval arithmetics to cull tiles
211
+ vec3f L0 = cube_to_dir(tsx, tsy, s, Npx), L1 = cube_to_dir(tex, tsy, s, Npx);
212
+ vec3f L2 = cube_to_dir(tsx, tey, s, Npx), L3 = cube_to_dir(tex, tey, s, Npx);
213
+
214
+ float minx = min(min(L0.x, L1.x), min(L2.x, L3.x)), maxx = max(max(L0.x, L1.x), max(L2.x, L3.x));
215
+ float miny = min(min(L0.y, L1.y), min(L2.y, L3.y)), maxy = max(max(L0.y, L1.y), max(L2.y, L3.y));
216
+ float minz = min(min(L0.z, L1.z), min(L2.z, L3.z)), maxz = max(max(L0.z, L1.z), max(L2.z, L3.z));
217
+
218
+ float maxdp = max(minx * VNR.x, maxx * VNR.x) + max(miny * VNR.y, maxy * VNR.y) + max(minz * VNR.z, maxz * VNR.z);
219
+ if (maxdp >= p.costheta_cutoff)
220
+ {
221
+ // Test all pixels in tile.
222
+ for (int y = tsy; y < tey; ++y)
223
+ {
224
+ for (int x = tsx; x < tex; ++x)
225
+ {
226
+ vec3f L = cube_to_dir(x, y, s, Npx);
227
+ if (dot(L, VNR) >= p.costheta_cutoff)
228
+ {
229
+ _min_x = min(_min_x, x);
230
+ _max_x = max(_max_x, x);
231
+ _min_y = min(_min_y, y);
232
+ _max_y = max(_max_y, y);
233
+ }
234
+ }
235
+ }
236
+ }
237
+ }
238
+ }
239
+ p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 0), _min_x);
240
+ p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 1), _max_x);
241
+ p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 2), _min_y);
242
+ p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 3), _max_y);
243
+ }
244
+ }
245
+
246
+ __global__ void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p)
247
+ {
248
+ // Calculate pixel position.
249
+ int px = blockIdx.x * blockDim.x + threadIdx.x;
250
+ int py = blockIdx.y * blockDim.y + threadIdx.y;
251
+ int pz = blockIdx.z;
252
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
253
+ return;
254
+
255
+ int Npx = p.cubemap.dims[1];
256
+ vec3f VNR = cube_to_dir(px, py, pz, Npx);
257
+
258
+ float alpha = p.roughness * p.roughness;
259
+ float alphaSqr = alpha * alpha;
260
+
261
+ float wsum = 0.0f;
262
+ vec3f col(0);
263
+ for (int s = 0; s < p.cubemap.dims[0]; ++s)
264
+ {
265
+ int xmin, xmax, ymin, ymax;
266
+ xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0));
267
+ xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1));
268
+ ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2));
269
+ ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3));
270
+
271
+ if (xmin <= xmax)
272
+ {
273
+ for (int y = ymin; y <= ymax; ++y)
274
+ {
275
+ for (int x = xmin; x <= xmax; ++x)
276
+ {
277
+ vec3f L = cube_to_dir(x, y, s, Npx);
278
+ if (dot(L, VNR) >= p.costheta_cutoff)
279
+ {
280
+ vec3f H = safeNormalize(L + VNR);
281
+
282
+ float wiDotN = max(dot(L, VNR), 0.0f);
283
+ float VNRDotH = max(dot(VNR, H), 0.0f);
284
+
285
+ float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f;
286
+ col += p.cubemap.fetch3(x, y, s) * w;
287
+ wsum += w;
288
+ }
289
+ }
290
+ }
291
+ }
292
+ }
293
+
294
+ p.out.store(p.out._nhwcIndex(pz, py, px, 0), col.x);
295
+ p.out.store(p.out._nhwcIndex(pz, py, px, 1), col.y);
296
+ p.out.store(p.out._nhwcIndex(pz, py, px, 2), col.z);
297
+ p.out.store(p.out._nhwcIndex(pz, py, px, 3), wsum);
298
+ }
299
+
300
+ __global__ void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p)
301
+ {
302
+ // Calculate pixel position.
303
+ int px = blockIdx.x * blockDim.x + threadIdx.x;
304
+ int py = blockIdx.y * blockDim.y + threadIdx.y;
305
+ int pz = blockIdx.z;
306
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
307
+ return;
308
+
309
+ int Npx = p.cubemap.dims[1];
310
+ vec3f VNR = cube_to_dir(px, py, pz, Npx);
311
+
312
+ vec3f grad = p.out.fetch3(px, py, pz);
313
+
314
+ float alpha = p.roughness * p.roughness;
315
+ float alphaSqr = alpha * alpha;
316
+
317
+ vec3f col(0);
318
+ for (int s = 0; s < p.cubemap.dims[0]; ++s)
319
+ {
320
+ int xmin, xmax, ymin, ymax;
321
+ xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0));
322
+ xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1));
323
+ ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2));
324
+ ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3));
325
+
326
+ if (xmin <= xmax)
327
+ {
328
+ for (int y = ymin; y <= ymax; ++y)
329
+ {
330
+ for (int x = xmin; x <= xmax; ++x)
331
+ {
332
+ vec3f L = cube_to_dir(x, y, s, Npx);
333
+ if (dot(L, VNR) >= p.costheta_cutoff)
334
+ {
335
+ vec3f H = safeNormalize(L + VNR);
336
+
337
+ float wiDotN = max(dot(L, VNR), 0.0f);
338
+ float VNRDotH = max(dot(VNR, H), 0.0f);
339
+
340
+ float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f;
341
+
342
+ atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w);
343
+ atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w);
344
+ atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w);
345
+ }
346
+ }
347
+ }
348
+ }
349
+ }
350
+ }
video3d/render/renderutils/c_src/cubemap.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #pragma once
13
+
14
+ #include "common.h"
15
+
16
+ struct DiffuseCubemapKernelParams
17
+ {
18
+ Tensor cubemap;
19
+ Tensor out;
20
+ dim3 gridSize;
21
+ };
22
+
23
+ struct SpecularCubemapKernelParams
24
+ {
25
+ Tensor cubemap;
26
+ Tensor bounds;
27
+ Tensor out;
28
+ dim3 gridSize;
29
+ float costheta_cutoff;
30
+ float roughness;
31
+ };
32
+
33
+ struct SpecularBoundsKernelParams
34
+ {
35
+ float costheta_cutoff;
36
+ Tensor out;
37
+ dim3 gridSize;
38
+ };
video3d/render/renderutils/c_src/loss.cu ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #include <cuda.h>
13
+
14
+ #include "common.h"
15
+ #include "loss.h"
16
+
17
+ //------------------------------------------------------------------------
18
+ // Utils
19
+
20
+ __device__ inline float bwdAbs(float x) { return x == 0.0f ? 0.0f : x < 0.0f ? -1.0f : 1.0f; }
21
+
22
+ __device__ float warpSum(float val) {
23
+ for (int i = 1; i < 32; i *= 2)
24
+ val += __shfl_xor_sync(0xFFFFFFFF, val, i);
25
+ return val;
26
+ }
27
+
28
+ //------------------------------------------------------------------------
29
+ // Tonemapping
30
+
31
+ __device__ inline float fwdSRGB(float x)
32
+ {
33
+ return x > 0.0031308f ? powf(max(x, 0.0031308f), 1.0f / 2.4f) * 1.055f - 0.055f : 12.92f * max(x, 0.0f);
34
+ }
35
+
36
+ __device__ inline void bwdSRGB(float x, float &d_x, float d_out)
37
+ {
38
+ if (x > 0.0031308f)
39
+ d_x += d_out * 0.439583f / powf(x, 0.583333f);
40
+ else if (x > 0.0f)
41
+ d_x += d_out * 12.92f;
42
+ }
43
+
44
+ __device__ inline vec3f fwdTonemapLogSRGB(vec3f x)
45
+ {
46
+ return vec3f(fwdSRGB(logf(x.x + 1.0f)), fwdSRGB(logf(x.y + 1.0f)), fwdSRGB(logf(x.z + 1.0f)));
47
+ }
48
+
49
+ __device__ inline void bwdTonemapLogSRGB(vec3f x, vec3f& d_x, vec3f d_out)
50
+ {
51
+ if (x.x > 0.0f && x.x < 65535.0f)
52
+ {
53
+ bwdSRGB(logf(x.x + 1.0f), d_x.x, d_out.x);
54
+ d_x.x *= 1 / (x.x + 1.0f);
55
+ }
56
+ if (x.y > 0.0f && x.y < 65535.0f)
57
+ {
58
+ bwdSRGB(logf(x.y + 1.0f), d_x.y, d_out.y);
59
+ d_x.y *= 1 / (x.y + 1.0f);
60
+ }
61
+ if (x.z > 0.0f && x.z < 65535.0f)
62
+ {
63
+ bwdSRGB(logf(x.z + 1.0f), d_x.z, d_out.z);
64
+ d_x.z *= 1 / (x.z + 1.0f);
65
+ }
66
+ }
67
+
68
+ __device__ inline float fwdRELMSE(float img, float target, float eps = 0.1f)
69
+ {
70
+ return (img - target) * (img - target) / (img * img + target * target + eps);
71
+ }
72
+
73
+ __device__ inline void bwdRELMSE(float img, float target, float &d_img, float &d_target, float d_out, float eps = 0.1f)
74
+ {
75
+ float denom = (target * target + img * img + eps);
76
+ d_img += d_out * 2 * (img - target) * (target * (target + img) + eps) / (denom * denom);
77
+ d_target -= d_out * 2 * (img - target) * (img * (target + img) + eps) / (denom * denom);
78
+ }
79
+
80
+ __device__ inline float fwdSMAPE(float img, float target, float eps=0.01f)
81
+ {
82
+ return abs(img - target) / (img + target + eps);
83
+ }
84
+
85
+ __device__ inline void bwdSMAPE(float img, float target, float& d_img, float& d_target, float d_out, float eps = 0.01f)
86
+ {
87
+ float denom = (target + img + eps);
88
+ d_img += d_out * bwdAbs(img - target) * (2 * target + eps) / (denom * denom);
89
+ d_target -= d_out * bwdAbs(img - target) * (2 * img + eps) / (denom * denom);
90
+ }
91
+
92
+ //------------------------------------------------------------------------
93
+ // Kernels
94
+
95
+ __global__ void imgLossFwdKernel(LossKernelParams p)
96
+ {
97
+ // Calculate pixel position.
98
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
99
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
100
+ unsigned int pz = blockIdx.z;
101
+
102
+ float floss = 0.0f;
103
+ if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z)
104
+ {
105
+ vec3f img = p.img.fetch3(px, py, pz);
106
+ vec3f target = p.target.fetch3(px, py, pz);
107
+
108
+ img = vec3f(clamp(img.x, 0.0f, 65535.0f), clamp(img.y, 0.0f, 65535.0f), clamp(img.z, 0.0f, 65535.0f));
109
+ target = vec3f(clamp(target.x, 0.0f, 65535.0f), clamp(target.y, 0.0f, 65535.0f), clamp(target.z, 0.0f, 65535.0f));
110
+
111
+ if (p.tonemapper == TONEMAPPER_LOG_SRGB)
112
+ {
113
+ img = fwdTonemapLogSRGB(img);
114
+ target = fwdTonemapLogSRGB(target);
115
+ }
116
+
117
+ vec3f vloss(0);
118
+ if (p.loss == LOSS_MSE)
119
+ vloss = (img - target) * (img - target);
120
+ else if (p.loss == LOSS_RELMSE)
121
+ vloss = vec3f(fwdRELMSE(img.x, target.x), fwdRELMSE(img.y, target.y), fwdRELMSE(img.z, target.z));
122
+ else if (p.loss == LOSS_SMAPE)
123
+ vloss = vec3f(fwdSMAPE(img.x, target.x), fwdSMAPE(img.y, target.y), fwdSMAPE(img.z, target.z));
124
+ else
125
+ vloss = vec3f(abs(img.x - target.x), abs(img.y - target.y), abs(img.z - target.z));
126
+
127
+ floss = sum(vloss) / 3.0f;
128
+ }
129
+
130
+ floss = warpSum(floss);
131
+
132
+ dim3 warpSize = getWarpSize(blockDim);
133
+ if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z && threadIdx.x % warpSize.x == 0 && threadIdx.y % warpSize.y == 0 && threadIdx.z % warpSize.z == 0)
134
+ p.out.store(px / warpSize.x, py / warpSize.y, pz / warpSize.z, floss);
135
+ }
136
+
137
+ __global__ void imgLossBwdKernel(LossKernelParams p)
138
+ {
139
+ // Calculate pixel position.
140
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
141
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
142
+ unsigned int pz = blockIdx.z;
143
+
144
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
145
+ return;
146
+
147
+ dim3 warpSize = getWarpSize(blockDim);
148
+
149
+ vec3f _img = p.img.fetch3(px, py, pz);
150
+ vec3f _target = p.target.fetch3(px, py, pz);
151
+ float d_out = p.out.fetch1(px / warpSize.x, py / warpSize.y, pz / warpSize.z);
152
+
153
+ /////////////////////////////////////////////////////////////////////
154
+ // FWD
155
+
156
+ vec3f img = _img, target = _target;
157
+ if (p.tonemapper == TONEMAPPER_LOG_SRGB)
158
+ {
159
+ img = fwdTonemapLogSRGB(img);
160
+ target = fwdTonemapLogSRGB(target);
161
+ }
162
+
163
+ /////////////////////////////////////////////////////////////////////
164
+ // BWD
165
+
166
+ vec3f d_vloss = vec3f(d_out, d_out, d_out) / 3.0f;
167
+
168
+ vec3f d_img(0), d_target(0);
169
+ if (p.loss == LOSS_MSE)
170
+ {
171
+ d_img = vec3f(d_vloss.x * 2 * (img.x - target.x), d_vloss.y * 2 * (img.y - target.y), d_vloss.x * 2 * (img.z - target.z));
172
+ d_target = -d_img;
173
+ }
174
+ else if (p.loss == LOSS_RELMSE)
175
+ {
176
+ bwdRELMSE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);
177
+ bwdRELMSE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);
178
+ bwdRELMSE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);
179
+ }
180
+ else if (p.loss == LOSS_SMAPE)
181
+ {
182
+ bwdSMAPE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);
183
+ bwdSMAPE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);
184
+ bwdSMAPE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);
185
+ }
186
+ else
187
+ {
188
+ d_img = d_vloss * vec3f(bwdAbs(img.x - target.x), bwdAbs(img.y - target.y), bwdAbs(img.z - target.z));
189
+ d_target = -d_img;
190
+ }
191
+
192
+
193
+ if (p.tonemapper == TONEMAPPER_LOG_SRGB)
194
+ {
195
+ vec3f d__img(0), d__target(0);
196
+ bwdTonemapLogSRGB(_img, d__img, d_img);
197
+ bwdTonemapLogSRGB(_target, d__target, d_target);
198
+ d_img = d__img; d_target = d__target;
199
+ }
200
+
201
+ if (_img.x <= 0.0f || _img.x >= 65535.0f) d_img.x = 0;
202
+ if (_img.y <= 0.0f || _img.y >= 65535.0f) d_img.y = 0;
203
+ if (_img.z <= 0.0f || _img.z >= 65535.0f) d_img.z = 0;
204
+ if (_target.x <= 0.0f || _target.x >= 65535.0f) d_target.x = 0;
205
+ if (_target.y <= 0.0f || _target.y >= 65535.0f) d_target.y = 0;
206
+ if (_target.z <= 0.0f || _target.z >= 65535.0f) d_target.z = 0;
207
+
208
+ p.img.store_grad(px, py, pz, d_img);
209
+ p.target.store_grad(px, py, pz, d_target);
210
+ }
video3d/render/renderutils/c_src/loss.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #pragma once
13
+
14
+ #include "common.h"
15
+
16
+ enum TonemapperType
17
+ {
18
+ TONEMAPPER_NONE = 0,
19
+ TONEMAPPER_LOG_SRGB = 1
20
+ };
21
+
22
+ enum LossType
23
+ {
24
+ LOSS_L1 = 0,
25
+ LOSS_MSE = 1,
26
+ LOSS_RELMSE = 2,
27
+ LOSS_SMAPE = 3
28
+ };
29
+
30
+ struct LossKernelParams
31
+ {
32
+ Tensor img;
33
+ Tensor target;
34
+ Tensor out;
35
+ dim3 gridSize;
36
+ TonemapperType tonemapper;
37
+ LossType loss;
38
+ };
video3d/render/renderutils/c_src/mesh.cu ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #include <cuda.h>
13
+ #include <stdio.h>
14
+
15
+ #include "common.h"
16
+ #include "mesh.h"
17
+
18
+
19
+ //------------------------------------------------------------------------
20
+ // Kernels
21
+
22
+ __global__ void xfmPointsFwdKernel(XfmKernelParams p)
23
+ {
24
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
25
+ unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z;
26
+
27
+ __shared__ float mtx[4][4];
28
+ if (threadIdx.x < 16)
29
+ mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0));
30
+ __syncthreads();
31
+
32
+ if (px >= p.gridSize.x)
33
+ return;
34
+
35
+ vec3f pos(
36
+ p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)),
37
+ p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)),
38
+ p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0))
39
+ );
40
+
41
+ if (p.isPoints)
42
+ {
43
+ p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0] + mtx[3][0]);
44
+ p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1] + mtx[3][1]);
45
+ p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2] + mtx[3][2]);
46
+ p.out.store(p.out.nhwcIndex(pz, px, 3, 0), pos.x * mtx[0][3] + pos.y * mtx[1][3] + pos.z * mtx[2][3] + mtx[3][3]);
47
+ }
48
+ else
49
+ {
50
+ p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0]);
51
+ p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1]);
52
+ p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2]);
53
+ }
54
+ }
55
+
56
+ __global__ void xfmPointsBwdKernel(XfmKernelParams p)
57
+ {
58
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
59
+ unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z;
60
+
61
+ __shared__ float mtx[4][4];
62
+ if (threadIdx.x < 16)
63
+ mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0));
64
+ __syncthreads();
65
+
66
+ if (px >= p.gridSize.x)
67
+ return;
68
+
69
+ vec3f pos(
70
+ p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)),
71
+ p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)),
72
+ p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0))
73
+ );
74
+
75
+ vec4f d_out(
76
+ p.out.fetch(p.out.nhwcIndex(pz, px, 0, 0)),
77
+ p.out.fetch(p.out.nhwcIndex(pz, px, 1, 0)),
78
+ p.out.fetch(p.out.nhwcIndex(pz, px, 2, 0)),
79
+ p.out.fetch(p.out.nhwcIndex(pz, px, 3, 0))
80
+ );
81
+
82
+ if (p.isPoints)
83
+ {
84
+ p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2] + d_out.w * mtx[0][3]);
85
+ p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2] + d_out.w * mtx[1][3]);
86
+ p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2] + d_out.w * mtx[2][3]);
87
+ }
88
+ else
89
+ {
90
+ p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2]);
91
+ p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2]);
92
+ p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2]);
93
+ }
94
+ }
video3d/render/renderutils/c_src/mesh.h ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #pragma once
13
+
14
+ #include "common.h"
15
+
16
+ struct XfmKernelParams
17
+ {
18
+ bool isPoints;
19
+ Tensor points;
20
+ Tensor matrix;
21
+ Tensor out;
22
+ dim3 gridSize;
23
+ };
video3d/render/renderutils/c_src/normal.cu ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #include "common.h"
13
+ #include "normal.h"
14
+
15
+ #define NORMAL_THRESHOLD 0.1f
16
+
17
+ //------------------------------------------------------------------------
18
+ // Perturb shading normal by tangent frame
19
+
20
+ __device__ vec3f fwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, bool opengl)
21
+ {
22
+ vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm);
23
+ vec3f smooth_bitng = safeNormalize(_smooth_bitng);
24
+ vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f);
25
+ return safeNormalize(_shading_nrm);
26
+ }
27
+
28
+ __device__ void bwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, vec3f &d_perturbed_nrm, vec3f &d_smooth_nrm, vec3f &d_smooth_tng, const vec3f d_out, bool opengl)
29
+ {
30
+ ////////////////////////////////////////////////////////////////////////
31
+ // FWD
32
+ vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm);
33
+ vec3f smooth_bitng = safeNormalize(_smooth_bitng);
34
+ vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f);
35
+
36
+ ////////////////////////////////////////////////////////////////////////
37
+ // BWD
38
+ vec3f d_shading_nrm(0);
39
+ bwdSafeNormalize(_shading_nrm, d_shading_nrm, d_out);
40
+
41
+ vec3f d_smooth_bitng(0);
42
+
43
+ if (perturbed_nrm.z > 0.0f)
44
+ {
45
+ d_smooth_nrm += d_shading_nrm * perturbed_nrm.z;
46
+ d_perturbed_nrm.z += sum(d_shading_nrm * smooth_nrm);
47
+ }
48
+
49
+ d_smooth_bitng += (opengl ? -1 : 1) * d_shading_nrm * perturbed_nrm.y;
50
+ d_perturbed_nrm.y += (opengl ? -1 : 1) * sum(d_shading_nrm * smooth_bitng);
51
+
52
+ d_smooth_tng += d_shading_nrm * perturbed_nrm.x;
53
+ d_perturbed_nrm.x += sum(d_shading_nrm * smooth_tng);
54
+
55
+ vec3f d__smooth_bitng(0);
56
+ bwdSafeNormalize(_smooth_bitng, d__smooth_bitng, d_smooth_bitng);
57
+
58
+ bwdCross(smooth_tng, smooth_nrm, d_smooth_tng, d_smooth_nrm, d__smooth_bitng);
59
+ }
60
+
61
+ //------------------------------------------------------------------------
62
+ #define bent_nrm_eps 0.001f
63
+
64
+ __device__ vec3f fwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm)
65
+ {
66
+ float dp = dot(view_vec, smooth_nrm);
67
+ float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f);
68
+ return geom_nrm * (1.0f - t) + smooth_nrm * t;
69
+ }
70
+
71
+ __device__ void bwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm, vec3f& d_view_vec, vec3f& d_smooth_nrm, vec3f& d_geom_nrm, const vec3f d_out)
72
+ {
73
+ ////////////////////////////////////////////////////////////////////////
74
+ // FWD
75
+ float dp = dot(view_vec, smooth_nrm);
76
+ float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f);
77
+
78
+ ////////////////////////////////////////////////////////////////////////
79
+ // BWD
80
+ if (dp > NORMAL_THRESHOLD)
81
+ d_smooth_nrm += d_out;
82
+ else
83
+ {
84
+ // geom_nrm * (1.0f - t) + smooth_nrm * t;
85
+ d_geom_nrm += d_out * (1.0f - t);
86
+ d_smooth_nrm += d_out * t;
87
+ float d_t = sum(d_out * (smooth_nrm - geom_nrm));
88
+
89
+ float d_dp = dp < 0.0f || dp > NORMAL_THRESHOLD ? 0.0f : d_t / NORMAL_THRESHOLD;
90
+
91
+ bwdDot(view_vec, smooth_nrm, d_view_vec, d_smooth_nrm, d_dp);
92
+ }
93
+ }
94
+
95
+ //------------------------------------------------------------------------
96
+ // Kernels
97
+
98
+ __global__ void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p)
99
+ {
100
+ // Calculate pixel position.
101
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
102
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
103
+ unsigned int pz = blockIdx.z;
104
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
105
+ return;
106
+
107
+ vec3f pos = p.pos.fetch3(px, py, pz);
108
+ vec3f view_pos = p.view_pos.fetch3(px, py, pz);
109
+ vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz);
110
+ vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz);
111
+ vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz);
112
+ vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz);
113
+
114
+ vec3f smooth_nrm = safeNormalize(_smooth_nrm);
115
+ vec3f smooth_tng = safeNormalize(_smooth_tng);
116
+ vec3f view_vec = safeNormalize(view_pos - pos);
117
+ vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl);
118
+
119
+ vec3f res;
120
+ if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f)
121
+ res = fwdBendNormal(view_vec, -shading_nrm, -geom_nrm);
122
+ else
123
+ res = fwdBendNormal(view_vec, shading_nrm, geom_nrm);
124
+
125
+ p.out.store(px, py, pz, res);
126
+ }
127
+
128
+ __global__ void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p)
129
+ {
130
+ // Calculate pixel position.
131
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
132
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
133
+ unsigned int pz = blockIdx.z;
134
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
135
+ return;
136
+
137
+ vec3f pos = p.pos.fetch3(px, py, pz);
138
+ vec3f view_pos = p.view_pos.fetch3(px, py, pz);
139
+ vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz);
140
+ vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz);
141
+ vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz);
142
+ vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz);
143
+ vec3f d_out = p.out.fetch3(px, py, pz);
144
+
145
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
146
+ // FWD
147
+
148
+ vec3f smooth_nrm = safeNormalize(_smooth_nrm);
149
+ vec3f smooth_tng = safeNormalize(_smooth_tng);
150
+ vec3f _view_vec = view_pos - pos;
151
+ vec3f view_vec = safeNormalize(view_pos - pos);
152
+
153
+ vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl);
154
+
155
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
156
+ // BWD
157
+
158
+ vec3f d_view_vec(0), d_shading_nrm(0), d_geom_nrm(0);
159
+ if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f)
160
+ {
161
+ bwdBendNormal(view_vec, -shading_nrm, -geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out);
162
+ d_shading_nrm = -d_shading_nrm;
163
+ d_geom_nrm = -d_geom_nrm;
164
+ }
165
+ else
166
+ bwdBendNormal(view_vec, shading_nrm, geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out);
167
+
168
+ vec3f d_perturbed_nrm(0), d_smooth_nrm(0), d_smooth_tng(0);
169
+ bwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, d_perturbed_nrm, d_smooth_nrm, d_smooth_tng, d_shading_nrm, p.opengl);
170
+
171
+ vec3f d__view_vec(0), d__smooth_nrm(0), d__smooth_tng(0);
172
+ bwdSafeNormalize(_view_vec, d__view_vec, d_view_vec);
173
+ bwdSafeNormalize(_smooth_nrm, d__smooth_nrm, d_smooth_nrm);
174
+ bwdSafeNormalize(_smooth_tng, d__smooth_tng, d_smooth_tng);
175
+
176
+ p.pos.store_grad(px, py, pz, -d__view_vec);
177
+ p.view_pos.store_grad(px, py, pz, d__view_vec);
178
+ p.perturbed_nrm.store_grad(px, py, pz, d_perturbed_nrm);
179
+ p.smooth_nrm.store_grad(px, py, pz, d__smooth_nrm);
180
+ p.smooth_tng.store_grad(px, py, pz, d__smooth_tng);
181
+ p.geom_nrm.store_grad(px, py, pz, d_geom_nrm);
182
+ }
video3d/render/renderutils/c_src/normal.h ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #pragma once
13
+
14
+ #include "common.h"
15
+
16
+ struct PrepareShadingNormalKernelParams
17
+ {
18
+ Tensor pos;
19
+ Tensor view_pos;
20
+ Tensor perturbed_nrm;
21
+ Tensor smooth_nrm;
22
+ Tensor smooth_tng;
23
+ Tensor geom_nrm;
24
+ Tensor out;
25
+ dim3 gridSize;
26
+ bool two_sided_shading, opengl;
27
+ };
video3d/render/renderutils/c_src/tensor.h ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #pragma once
13
+ #if defined(__CUDACC__) && defined(BFLOAT16)
14
+ #include <cuda_bf16.h> // bfloat16 is float32 compatible with less mantissa bits
15
+ #endif
16
+
17
+ //---------------------------------------------------------------------------------
18
+ // CUDA-side Tensor class for in/out parameter parsing. Can be float32 or bfloat16
19
+
20
+ struct Tensor
21
+ {
22
+ void* val;
23
+ void* d_val;
24
+ int dims[4], _dims[4];
25
+ int strides[4];
26
+ bool fp16;
27
+
28
+ #if defined(__CUDA__) && !defined(__CUDA_ARCH__)
29
+ Tensor() : val(nullptr), d_val(nullptr), fp16(true), dims{ 0, 0, 0, 0 }, _dims{ 0, 0, 0, 0 }, strides{ 0, 0, 0, 0 } {}
30
+ #endif
31
+
32
+ #ifdef __CUDACC__
33
+ // Helpers to index and read/write a single element
34
+ __device__ inline int _nhwcIndex(int n, int h, int w, int c) const { return n * strides[0] + h * strides[1] + w * strides[2] + c * strides[3]; }
35
+ __device__ inline int nhwcIndex(int n, int h, int w, int c) const { return (dims[0] == 1 ? 0 : n * strides[0]) + (dims[1] == 1 ? 0 : h * strides[1]) + (dims[2] == 1 ? 0 : w * strides[2]) + (dims[3] == 1 ? 0 : c * strides[3]); }
36
+ __device__ inline int nhwcIndexContinuous(int n, int h, int w, int c) const { return ((n * _dims[1] + h) * _dims[2] + w) * _dims[3] + c; }
37
+ #ifdef BFLOAT16
38
+ __device__ inline float fetch(unsigned int idx) const { return fp16 ? __bfloat162float(((__nv_bfloat16*)val)[idx]) : ((float*)val)[idx]; }
39
+ __device__ inline void store(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)val)[idx] = __float2bfloat16(_val); else ((float*)val)[idx] = _val; }
40
+ __device__ inline void store_grad(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)d_val)[idx] = __float2bfloat16(_val); else ((float*)d_val)[idx] = _val; }
41
+ #else
42
+ __device__ inline float fetch(unsigned int idx) const { return ((float*)val)[idx]; }
43
+ __device__ inline void store(unsigned int idx, float _val) { ((float*)val)[idx] = _val; }
44
+ __device__ inline void store_grad(unsigned int idx, float _val) { ((float*)d_val)[idx] = _val; }
45
+ #endif
46
+
47
+ //////////////////////////////////////////////////////////////////////////////////////////
48
+ // Fetch, use broadcasting for tensor dimensions of size 1
49
+ __device__ inline float fetch1(unsigned int x, unsigned int y, unsigned int z) const
50
+ {
51
+ return fetch(nhwcIndex(z, y, x, 0));
52
+ }
53
+
54
+ __device__ inline vec3f fetch3(unsigned int x, unsigned int y, unsigned int z) const
55
+ {
56
+ return vec3f(
57
+ fetch(nhwcIndex(z, y, x, 0)),
58
+ fetch(nhwcIndex(z, y, x, 1)),
59
+ fetch(nhwcIndex(z, y, x, 2))
60
+ );
61
+ }
62
+
63
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////////
64
+ // Store, no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside
65
+ __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, float _val)
66
+ {
67
+ store(_nhwcIndex(z, y, x, 0), _val);
68
+ }
69
+
70
+ __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, vec3f _val)
71
+ {
72
+ store(_nhwcIndex(z, y, x, 0), _val.x);
73
+ store(_nhwcIndex(z, y, x, 1), _val.y);
74
+ store(_nhwcIndex(z, y, x, 2), _val.z);
75
+ }
76
+
77
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////////
78
+ // Store gradient , no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside
79
+ __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, float _val)
80
+ {
81
+ store_grad(nhwcIndexContinuous(z, y, x, 0), _val);
82
+ }
83
+
84
+ __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, vec3f _val)
85
+ {
86
+ store_grad(nhwcIndexContinuous(z, y, x, 0), _val.x);
87
+ store_grad(nhwcIndexContinuous(z, y, x, 1), _val.y);
88
+ store_grad(nhwcIndexContinuous(z, y, x, 2), _val.z);
89
+ }
90
+ #endif
91
+
92
+ };
video3d/render/renderutils/c_src/torch_bindings.cpp ADDED
@@ -0,0 +1,1062 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #ifdef _MSC_VER
13
+ #pragma warning(push, 0)
14
+ #include <torch/extension.h>
15
+ #pragma warning(pop)
16
+ #else
17
+ #include <torch/extension.h>
18
+ #endif
19
+
20
+ #include <ATen/cuda/CUDAContext.h>
21
+ #include <ATen/cuda/CUDAUtils.h>
22
+ #include <algorithm>
23
+ #include <string>
24
+
25
+ #define NVDR_CHECK_CUDA_ERROR(CUDA_CALL) { cudaError_t err = CUDA_CALL; AT_CUDA_CHECK(cudaGetLastError()); }
26
+ #define NVDR_CHECK_GL_ERROR(GL_CALL) { GL_CALL; GLenum err = glGetError(); TORCH_CHECK(err == GL_NO_ERROR, "OpenGL error: ", getGLErrorString(err), "[", #GL_CALL, ";]"); }
27
+ #define CHECK_TENSOR(X, DIMS, CHANNELS) \
28
+ TORCH_CHECK(X.is_cuda(), #X " must be a cuda tensor") \
29
+ TORCH_CHECK(X.scalar_type() == torch::kFloat || X.scalar_type() == torch::kBFloat16, #X " must be fp32 or bf16") \
30
+ TORCH_CHECK(X.dim() == DIMS, #X " must have " #DIMS " dimensions") \
31
+ TORCH_CHECK(X.size(DIMS - 1) == CHANNELS, #X " must have " #CHANNELS " channels")
32
+
33
+ #include "common.h"
34
+ #include "loss.h"
35
+ #include "normal.h"
36
+ #include "cubemap.h"
37
+ #include "bsdf.h"
38
+ #include "mesh.h"
39
+
40
+ #define BLOCK_X 8
41
+ #define BLOCK_Y 8
42
+
43
+ //------------------------------------------------------------------------
44
+ // mesh.cu
45
+
46
+ void xfmPointsFwdKernel(XfmKernelParams p);
47
+ void xfmPointsBwdKernel(XfmKernelParams p);
48
+
49
+ //------------------------------------------------------------------------
50
+ // loss.cu
51
+
52
+ void imgLossFwdKernel(LossKernelParams p);
53
+ void imgLossBwdKernel(LossKernelParams p);
54
+
55
+ //------------------------------------------------------------------------
56
+ // normal.cu
57
+
58
+ void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p);
59
+ void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p);
60
+
61
+ //------------------------------------------------------------------------
62
+ // cubemap.cu
63
+
64
+ void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p);
65
+ void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p);
66
+ void SpecularBoundsKernel(SpecularBoundsKernelParams p);
67
+ void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p);
68
+ void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p);
69
+
70
+ //------------------------------------------------------------------------
71
+ // bsdf.cu
72
+
73
+ void LambertFwdKernel(LambertKernelParams p);
74
+ void LambertBwdKernel(LambertKernelParams p);
75
+
76
+ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p);
77
+ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p);
78
+
79
+ void FresnelShlickFwdKernel(FresnelShlickKernelParams p);
80
+ void FresnelShlickBwdKernel(FresnelShlickKernelParams p);
81
+
82
+ void ndfGGXFwdKernel(NdfGGXParams p);
83
+ void ndfGGXBwdKernel(NdfGGXParams p);
84
+
85
+ void lambdaGGXFwdKernel(NdfGGXParams p);
86
+ void lambdaGGXBwdKernel(NdfGGXParams p);
87
+
88
+ void maskingSmithFwdKernel(MaskingSmithParams p);
89
+ void maskingSmithBwdKernel(MaskingSmithParams p);
90
+
91
+ void pbrSpecularFwdKernel(PbrSpecular p);
92
+ void pbrSpecularBwdKernel(PbrSpecular p);
93
+
94
+ void pbrBSDFFwdKernel(PbrBSDF p);
95
+ void pbrBSDFBwdKernel(PbrBSDF p);
96
+
97
+ //------------------------------------------------------------------------
98
+ // Tensor helpers
99
+
100
+ void update_grid(dim3 &gridSize, torch::Tensor x)
101
+ {
102
+ gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2));
103
+ gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1));
104
+ gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0));
105
+ }
106
+
107
+ template<typename... Ts>
108
+ void update_grid(dim3& gridSize, torch::Tensor x, Ts&&... vs)
109
+ {
110
+ gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2));
111
+ gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1));
112
+ gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0));
113
+ update_grid(gridSize, std::forward<Ts>(vs)...);
114
+ }
115
+
116
+ Tensor make_cuda_tensor(torch::Tensor val)
117
+ {
118
+ Tensor res;
119
+ for (int i = 0; i < val.dim(); ++i)
120
+ {
121
+ res.dims[i] = val.size(i);
122
+ res.strides[i] = val.stride(i);
123
+ }
124
+ res.fp16 = val.scalar_type() == torch::kBFloat16;
125
+ res.val = res.fp16 ? (void*)val.data_ptr<torch::BFloat16>() : (void*)val.data_ptr<float>();
126
+ res.d_val = nullptr;
127
+ return res;
128
+ }
129
+
130
+ Tensor make_cuda_tensor(torch::Tensor val, dim3 outDims, torch::Tensor* grad = nullptr)
131
+ {
132
+ Tensor res;
133
+ for (int i = 0; i < val.dim(); ++i)
134
+ {
135
+ res.dims[i] = val.size(i);
136
+ res.strides[i] = val.stride(i);
137
+ }
138
+ if (val.dim() == 4)
139
+ res._dims[0] = outDims.z, res._dims[1] = outDims.y, res._dims[2] = outDims.x, res._dims[3] = val.size(3);
140
+ else
141
+ res._dims[0] = outDims.z, res._dims[1] = outDims.x, res._dims[2] = val.size(2), res._dims[3] = 1; // Add a trailing one for indexing math to work out
142
+
143
+ res.fp16 = val.scalar_type() == torch::kBFloat16;
144
+ res.val = res.fp16 ? (void*)val.data_ptr<torch::BFloat16>() : (void*)val.data_ptr<float>();
145
+ res.d_val = nullptr;
146
+ if (grad != nullptr)
147
+ {
148
+ if (val.dim() == 4)
149
+ *grad = torch::empty({ outDims.z, outDims.y, outDims.x, val.size(3) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA));
150
+ else // 3
151
+ *grad = torch::empty({ outDims.z, outDims.x, val.size(2) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA));
152
+
153
+ res.d_val = res.fp16 ? (void*)grad->data_ptr<torch::BFloat16>() : (void*)grad->data_ptr<float>();
154
+ }
155
+ return res;
156
+ }
157
+
158
+ //------------------------------------------------------------------------
159
+ // prepare_shading_normal
160
+
161
+ torch::Tensor prepare_shading_normal_fwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, bool two_sided_shading, bool opengl, bool fp16)
162
+ {
163
+ CHECK_TENSOR(pos, 4, 3);
164
+ CHECK_TENSOR(view_pos, 4, 3);
165
+ CHECK_TENSOR(perturbed_nrm, 4, 3);
166
+ CHECK_TENSOR(smooth_nrm, 4, 3);
167
+ CHECK_TENSOR(smooth_tng, 4, 3);
168
+ CHECK_TENSOR(geom_nrm, 4, 3);
169
+
170
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
171
+
172
+ // Extract input parameters.
173
+ PrepareShadingNormalKernelParams p;
174
+ p.two_sided_shading = two_sided_shading;
175
+ p.opengl = opengl;
176
+ p.out.fp16 = fp16;
177
+ update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm);
178
+
179
+ // Allocate output tensors.
180
+ torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
181
+ torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);
182
+
183
+ // Choose launch parameters.
184
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
185
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
186
+
187
+ // Setup tensors
188
+ p.pos = make_cuda_tensor(pos, p.gridSize);
189
+ p.view_pos = make_cuda_tensor(view_pos, p.gridSize);
190
+ p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize);
191
+ p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize);
192
+ p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize);
193
+ p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize);
194
+ p.out = make_cuda_tensor(out, p.gridSize);
195
+
196
+ // Launch CUDA kernel.
197
+ void* args[] = { &p };
198
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalFwdKernel, gridSize, blockSize, args, 0, stream));
199
+
200
+ return out;
201
+ }
202
+
203
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> prepare_shading_normal_bwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, torch::Tensor grad, bool two_sided_shading, bool opengl)
204
+ {
205
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
206
+
207
+ // Extract input parameters.
208
+ PrepareShadingNormalKernelParams p;
209
+ p.two_sided_shading = two_sided_shading;
210
+ p.opengl = opengl;
211
+ update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm);
212
+
213
+ // Choose launch parameters.
214
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
215
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
216
+
217
+ // Setup tensors
218
+ torch::Tensor pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad;
219
+ p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad);
220
+ p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad);
221
+ p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize, &perturbed_nrm_grad);
222
+ p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize, &smooth_nrm_grad);
223
+ p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize, &smooth_tng_grad);
224
+ p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize, &geom_nrm_grad);
225
+ p.out = make_cuda_tensor(grad, p.gridSize);
226
+
227
+ // Launch CUDA kernel.
228
+ void* args[] = { &p };
229
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalBwdKernel, gridSize, blockSize, args, 0, stream));
230
+
231
+ return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad);
232
+ }
233
+
234
+ //------------------------------------------------------------------------
235
+ // lambert
236
+
237
+ torch::Tensor lambert_fwd(torch::Tensor nrm, torch::Tensor wi, bool fp16)
238
+ {
239
+ CHECK_TENSOR(nrm, 4, 3);
240
+ CHECK_TENSOR(wi, 4, 3);
241
+
242
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
243
+
244
+ // Extract input parameters.
245
+ LambertKernelParams p;
246
+ p.out.fp16 = fp16;
247
+ update_grid(p.gridSize, nrm, wi);
248
+
249
+ // Allocate output tensors.
250
+ torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
251
+ torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);
252
+
253
+ // Choose launch parameters.
254
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
255
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
256
+
257
+ p.nrm = make_cuda_tensor(nrm, p.gridSize);
258
+ p.wi = make_cuda_tensor(wi, p.gridSize);
259
+ p.out = make_cuda_tensor(out, p.gridSize);
260
+
261
+ // Launch CUDA kernel.
262
+ void* args[] = { &p };
263
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertFwdKernel, gridSize, blockSize, args, 0, stream));
264
+
265
+ return out;
266
+ }
267
+
268
+ std::tuple<torch::Tensor, torch::Tensor> lambert_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor grad)
269
+ {
270
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
271
+
272
+ // Extract input parameters.
273
+ LambertKernelParams p;
274
+ update_grid(p.gridSize, nrm, wi);
275
+
276
+ // Choose launch parameters.
277
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
278
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
279
+
280
+ torch::Tensor nrm_grad, wi_grad;
281
+ p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad);
282
+ p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad);
283
+ p.out = make_cuda_tensor(grad, p.gridSize);
284
+
285
+ // Launch CUDA kernel.
286
+ void* args[] = { &p };
287
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertBwdKernel, gridSize, blockSize, args, 0, stream));
288
+
289
+ return std::tuple<torch::Tensor, torch::Tensor>(nrm_grad, wi_grad);
290
+ }
291
+
292
+ //------------------------------------------------------------------------
293
+ // frostbite diffuse
294
+
295
+ torch::Tensor frostbite_fwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor wo, torch::Tensor linearRoughness, bool fp16)
296
+ {
297
+ CHECK_TENSOR(nrm, 4, 3);
298
+ CHECK_TENSOR(wi, 4, 3);
299
+ CHECK_TENSOR(wo, 4, 3);
300
+ CHECK_TENSOR(linearRoughness, 4, 1);
301
+
302
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
303
+
304
+ // Extract input parameters.
305
+ FrostbiteDiffuseKernelParams p;
306
+ p.out.fp16 = fp16;
307
+ update_grid(p.gridSize, nrm, wi, wo, linearRoughness);
308
+
309
+ // Allocate output tensors.
310
+ torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
311
+ torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);
312
+
313
+ // Choose launch parameters.
314
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
315
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
316
+
317
+ p.nrm = make_cuda_tensor(nrm, p.gridSize);
318
+ p.wi = make_cuda_tensor(wi, p.gridSize);
319
+ p.wo = make_cuda_tensor(wo, p.gridSize);
320
+ p.linearRoughness = make_cuda_tensor(linearRoughness, p.gridSize);
321
+ p.out = make_cuda_tensor(out, p.gridSize);
322
+
323
+ // Launch CUDA kernel.
324
+ void* args[] = { &p };
325
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FrostbiteDiffuseFwdKernel, gridSize, blockSize, args, 0, stream));
326
+
327
+ return out;
328
+ }
329
+
330
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> frostbite_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor wo, torch::Tensor linearRoughness, torch::Tensor grad)
331
+ {
332
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
333
+
334
+ // Extract input parameters.
335
+ FrostbiteDiffuseKernelParams p;
336
+ update_grid(p.gridSize, nrm, wi, wo, linearRoughness);
337
+
338
+ // Choose launch parameters.
339
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
340
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
341
+
342
+ torch::Tensor nrm_grad, wi_grad, wo_grad, linearRoughness_grad;
343
+ p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad);
344
+ p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad);
345
+ p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad);
346
+ p.linearRoughness = make_cuda_tensor(linearRoughness, p.gridSize, &linearRoughness_grad);
347
+ p.out = make_cuda_tensor(grad, p.gridSize);
348
+
349
+ // Launch CUDA kernel.
350
+ void* args[] = { &p };
351
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FrostbiteDiffuseBwdKernel, gridSize, blockSize, args, 0, stream));
352
+
353
+ return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(nrm_grad, wi_grad, wo_grad, linearRoughness_grad);
354
+ }
355
+
356
+ //------------------------------------------------------------------------
357
+ // fresnel_shlick
358
+
359
+ torch::Tensor fresnel_shlick_fwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, bool fp16)
360
+ {
361
+ CHECK_TENSOR(f0, 4, 3);
362
+ CHECK_TENSOR(f90, 4, 3);
363
+ CHECK_TENSOR(cosTheta, 4, 1);
364
+
365
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
366
+
367
+ // Extract input parameters.
368
+ FresnelShlickKernelParams p;
369
+ p.out.fp16 = fp16;
370
+ update_grid(p.gridSize, f0, f90, cosTheta);
371
+
372
+ // Allocate output tensors.
373
+ torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
374
+ torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);
375
+
376
+ // Choose launch parameters.
377
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
378
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
379
+
380
+ p.f0 = make_cuda_tensor(f0, p.gridSize);
381
+ p.f90 = make_cuda_tensor(f90, p.gridSize);
382
+ p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize);
383
+ p.out = make_cuda_tensor(out, p.gridSize);
384
+
385
+ // Launch CUDA kernel.
386
+ void* args[] = { &p };
387
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickFwdKernel, gridSize, blockSize, args, 0, stream));
388
+
389
+ return out;
390
+ }
391
+
392
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> fresnel_shlick_bwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, torch::Tensor grad)
393
+ {
394
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
395
+
396
+ // Extract input parameters.
397
+ FresnelShlickKernelParams p;
398
+ update_grid(p.gridSize, f0, f90, cosTheta);
399
+
400
+ // Choose launch parameters.
401
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
402
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
403
+
404
+ torch::Tensor f0_grad, f90_grad, cosT_grad;
405
+ p.f0 = make_cuda_tensor(f0, p.gridSize, &f0_grad);
406
+ p.f90 = make_cuda_tensor(f90, p.gridSize, &f90_grad);
407
+ p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosT_grad);
408
+ p.out = make_cuda_tensor(grad, p.gridSize);
409
+
410
+ // Launch CUDA kernel.
411
+ void* args[] = { &p };
412
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickBwdKernel, gridSize, blockSize, args, 0, stream));
413
+
414
+ return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>(f0_grad, f90_grad, cosT_grad);
415
+ }
416
+
417
+ //------------------------------------------------------------------------
418
+ // ndf_ggd
419
+
420
+ torch::Tensor ndf_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16)
421
+ {
422
+ CHECK_TENSOR(alphaSqr, 4, 1);
423
+ CHECK_TENSOR(cosTheta, 4, 1);
424
+
425
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
426
+
427
+ // Extract input parameters.
428
+ NdfGGXParams p;
429
+ p.out.fp16 = fp16;
430
+ update_grid(p.gridSize, alphaSqr, cosTheta);
431
+
432
+ // Allocate output tensors.
433
+ torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
434
+ torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);
435
+
436
+ // Choose launch parameters.
437
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
438
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
439
+
440
+ p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize);
441
+ p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize);
442
+ p.out = make_cuda_tensor(out, p.gridSize);
443
+
444
+ // Launch CUDA kernel.
445
+ void* args[] = { &p };
446
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXFwdKernel, gridSize, blockSize, args, 0, stream));
447
+
448
+ return out;
449
+ }
450
+
451
+ std::tuple<torch::Tensor, torch::Tensor> ndf_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad)
452
+ {
453
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
454
+
455
+ // Extract input parameters.
456
+ NdfGGXParams p;
457
+ update_grid(p.gridSize, alphaSqr, cosTheta);
458
+
459
+ // Choose launch parameters.
460
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
461
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
462
+
463
+ torch::Tensor alphaSqr_grad, cosTheta_grad;
464
+ p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad);
465
+ p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad);
466
+ p.out = make_cuda_tensor(grad, p.gridSize);
467
+
468
+ // Launch CUDA kernel.
469
+ void* args[] = { &p };
470
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXBwdKernel, gridSize, blockSize, args, 0, stream));
471
+
472
+ return std::tuple<torch::Tensor, torch::Tensor>(alphaSqr_grad, cosTheta_grad);
473
+ }
474
+
475
+ //------------------------------------------------------------------------
476
+ // lambda_ggx
477
+
478
+ torch::Tensor lambda_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16)
479
+ {
480
+ CHECK_TENSOR(alphaSqr, 4, 1);
481
+ CHECK_TENSOR(cosTheta, 4, 1);
482
+
483
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
484
+
485
+ // Extract input parameters.
486
+ NdfGGXParams p;
487
+ p.out.fp16 = fp16;
488
+ update_grid(p.gridSize, alphaSqr, cosTheta);
489
+
490
+ // Allocate output tensors.
491
+ torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
492
+ torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);
493
+
494
+ // Choose launch parameters.
495
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
496
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
497
+
498
+ p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize);
499
+ p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize);
500
+ p.out = make_cuda_tensor(out, p.gridSize);
501
+
502
+ // Launch CUDA kernel.
503
+ void* args[] = { &p };
504
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXFwdKernel, gridSize, blockSize, args, 0, stream));
505
+
506
+ return out;
507
+ }
508
+
509
+ std::tuple<torch::Tensor, torch::Tensor> lambda_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad)
510
+ {
511
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
512
+
513
+ // Extract input parameters.
514
+ NdfGGXParams p;
515
+ update_grid(p.gridSize, alphaSqr, cosTheta);
516
+
517
+ // Choose launch parameters.
518
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
519
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
520
+
521
+ torch::Tensor alphaSqr_grad, cosTheta_grad;
522
+ p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad);
523
+ p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad);
524
+ p.out = make_cuda_tensor(grad, p.gridSize);
525
+
526
+ // Launch CUDA kernel.
527
+ void* args[] = { &p };
528
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXBwdKernel, gridSize, blockSize, args, 0, stream));
529
+
530
+ return std::tuple<torch::Tensor, torch::Tensor>(alphaSqr_grad, cosTheta_grad);
531
+ }
532
+
533
+ //------------------------------------------------------------------------
534
+ // masking_smith
535
+
536
+ torch::Tensor masking_smith_fwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, bool fp16)
537
+ {
538
+ CHECK_TENSOR(alphaSqr, 4, 1);
539
+ CHECK_TENSOR(cosThetaI, 4, 1);
540
+ CHECK_TENSOR(cosThetaO, 4, 1);
541
+
542
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
543
+
544
+ // Extract input parameters.
545
+ MaskingSmithParams p;
546
+ p.out.fp16 = fp16;
547
+ update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO);
548
+
549
+ // Allocate output tensors.
550
+ torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
551
+ torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);
552
+
553
+ // Choose launch parameters.
554
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
555
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
556
+
557
+ p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize);
558
+ p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize);
559
+ p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize);
560
+ p.out = make_cuda_tensor(out, p.gridSize);
561
+
562
+ // Launch CUDA kernel.
563
+ void* args[] = { &p };
564
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithFwdKernel, gridSize, blockSize, args, 0, stream));
565
+
566
+ return out;
567
+ }
568
+
569
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> masking_smith_bwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, torch::Tensor grad)
570
+ {
571
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
572
+
573
+ // Extract input parameters.
574
+ MaskingSmithParams p;
575
+ update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO);
576
+
577
+ // Choose launch parameters.
578
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
579
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
580
+
581
+ torch::Tensor alphaSqr_grad, cosThetaI_grad, cosThetaO_grad;
582
+ p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad);
583
+ p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize, &cosThetaI_grad);
584
+ p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize, &cosThetaO_grad);
585
+ p.out = make_cuda_tensor(grad, p.gridSize);
586
+
587
+ // Launch CUDA kernel.
588
+ void* args[] = { &p };
589
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithBwdKernel, gridSize, blockSize, args, 0, stream));
590
+
591
+ return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>(alphaSqr_grad, cosThetaI_grad, cosThetaO_grad);
592
+ }
593
+
594
+ //------------------------------------------------------------------------
595
+ // pbr_specular
596
+
597
+ torch::Tensor pbr_specular_fwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, bool fp16)
598
+ {
599
+ CHECK_TENSOR(col, 4, 3);
600
+ CHECK_TENSOR(nrm, 4, 3);
601
+ CHECK_TENSOR(wo, 4, 3);
602
+ CHECK_TENSOR(wi, 4, 3);
603
+ CHECK_TENSOR(alpha, 4, 1);
604
+
605
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
606
+
607
+ // Extract input parameters.
608
+ PbrSpecular p;
609
+ p.out.fp16 = fp16;
610
+ p.min_roughness = min_roughness;
611
+ update_grid(p.gridSize, col, nrm, wo, wi, alpha);
612
+
613
+ // Allocate output tensors.
614
+ torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
615
+ torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);
616
+
617
+ // Choose launch parameters.
618
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
619
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
620
+
621
+ p.col = make_cuda_tensor(col, p.gridSize);
622
+ p.nrm = make_cuda_tensor(nrm, p.gridSize);
623
+ p.wo = make_cuda_tensor(wo, p.gridSize);
624
+ p.wi = make_cuda_tensor(wi, p.gridSize);
625
+ p.alpha = make_cuda_tensor(alpha, p.gridSize);
626
+ p.out = make_cuda_tensor(out, p.gridSize);
627
+
628
+ // Launch CUDA kernel.
629
+ void* args[] = { &p };
630
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularFwdKernel, gridSize, blockSize, args, 0, stream));
631
+
632
+ return out;
633
+ }
634
+
635
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> pbr_specular_bwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, torch::Tensor grad)
636
+ {
637
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
638
+
639
+ // Extract input parameters.
640
+ PbrSpecular p;
641
+ update_grid(p.gridSize, col, nrm, wo, wi, alpha);
642
+ p.min_roughness = min_roughness;
643
+
644
+ // Choose launch parameters.
645
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
646
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
647
+
648
+ torch::Tensor col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad;
649
+ p.col = make_cuda_tensor(col, p.gridSize, &col_grad);
650
+ p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad);
651
+ p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad);
652
+ p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad);
653
+ p.alpha = make_cuda_tensor(alpha, p.gridSize, &alpha_grad);
654
+ p.out = make_cuda_tensor(grad, p.gridSize);
655
+
656
+ // Launch CUDA kernel.
657
+ void* args[] = { &p };
658
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularBwdKernel, gridSize, blockSize, args, 0, stream));
659
+
660
+ return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad);
661
+ }
662
+
663
+ //------------------------------------------------------------------------
664
+ // pbr_bsdf
665
+
666
+ torch::Tensor pbr_bsdf_fwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, int BSDF, bool fp16)
667
+ {
668
+ CHECK_TENSOR(kd, 4, 3);
669
+ CHECK_TENSOR(arm, 4, 3);
670
+ CHECK_TENSOR(pos, 4, 3);
671
+ CHECK_TENSOR(nrm, 4, 3);
672
+ CHECK_TENSOR(view_pos, 4, 3);
673
+ CHECK_TENSOR(light_pos, 4, 3);
674
+
675
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
676
+
677
+ // Extract input parameters.
678
+ PbrBSDF p;
679
+ p.out.fp16 = fp16;
680
+ p.min_roughness = min_roughness;
681
+ p.BSDF = BSDF;
682
+ update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos);
683
+
684
+ // Allocate output tensors.
685
+ torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
686
+ torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);
687
+
688
+ // Choose launch parameters.
689
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
690
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
691
+
692
+ p.kd = make_cuda_tensor(kd, p.gridSize);
693
+ p.arm = make_cuda_tensor(arm, p.gridSize);
694
+ p.pos = make_cuda_tensor(pos, p.gridSize);
695
+ p.nrm = make_cuda_tensor(nrm, p.gridSize);
696
+ p.view_pos = make_cuda_tensor(view_pos, p.gridSize);
697
+ p.light_pos = make_cuda_tensor(light_pos, p.gridSize);
698
+ p.out = make_cuda_tensor(out, p.gridSize);
699
+
700
+ // Launch CUDA kernel.
701
+ void* args[] = { &p };
702
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFFwdKernel, gridSize, blockSize, args, 0, stream));
703
+
704
+ return out;
705
+ }
706
+
707
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> pbr_bsdf_bwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, int BSDF, torch::Tensor grad)
708
+ {
709
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
710
+
711
+ // Extract input parameters.
712
+ PbrBSDF p;
713
+ update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos);
714
+ p.min_roughness = min_roughness;
715
+ p.BSDF = BSDF;
716
+
717
+ // Choose launch parameters.
718
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
719
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
720
+
721
+ torch::Tensor kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad;
722
+ p.kd = make_cuda_tensor(kd, p.gridSize, &kd_grad);
723
+ p.arm = make_cuda_tensor(arm, p.gridSize, &arm_grad);
724
+ p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad);
725
+ p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad);
726
+ p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad);
727
+ p.light_pos = make_cuda_tensor(light_pos, p.gridSize, &light_pos_grad);
728
+ p.out = make_cuda_tensor(grad, p.gridSize);
729
+
730
+ // Launch CUDA kernel.
731
+ void* args[] = { &p };
732
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFBwdKernel, gridSize, blockSize, args, 0, stream));
733
+
734
+ return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad);
735
+ }
736
+
737
+ //------------------------------------------------------------------------
738
+ // filter_cubemap
739
+
740
+ torch::Tensor diffuse_cubemap_fwd(torch::Tensor cubemap)
741
+ {
742
+ CHECK_TENSOR(cubemap, 4, 3);
743
+
744
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
745
+
746
+ // Extract input parameters.
747
+ DiffuseCubemapKernelParams p;
748
+ update_grid(p.gridSize, cubemap);
749
+
750
+ // Allocate output tensors.
751
+ torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
752
+ torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);
753
+
754
+ // Choose launch parameters.
755
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
756
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
757
+
758
+ // Setup tensors
759
+ p.cubemap = make_cuda_tensor(cubemap, p.gridSize);
760
+ p.out = make_cuda_tensor(out, p.gridSize);
761
+
762
+ // Launch CUDA kernel.
763
+ void* args[] = { &p };
764
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)DiffuseCubemapFwdKernel, gridSize, blockSize, args, 0, stream));
765
+
766
+ return out;
767
+ }
768
+
769
+ torch::Tensor diffuse_cubemap_bwd(torch::Tensor cubemap, torch::Tensor grad)
770
+ {
771
+ CHECK_TENSOR(cubemap, 4, 3);
772
+ CHECK_TENSOR(grad, 4, 3);
773
+
774
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
775
+
776
+ // Extract input parameters.
777
+ DiffuseCubemapKernelParams p;
778
+ update_grid(p.gridSize, cubemap);
779
+
780
+ // Choose launch parameters.
781
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
782
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
783
+
784
+ // Setup tensors
785
+ torch::Tensor cubemap_grad;
786
+ p.cubemap = make_cuda_tensor(cubemap, p.gridSize);
787
+ p.out = make_cuda_tensor(grad, p.gridSize);
788
+
789
+ cubemap_grad = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, cubemap.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));
790
+ p.cubemap.d_val = (void*)cubemap_grad.data_ptr<float>();
791
+
792
+ // Launch CUDA kernel.
793
+ void* args[] = { &p };
794
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)DiffuseCubemapBwdKernel, gridSize, blockSize, args, 0, stream));
795
+
796
+ return cubemap_grad;
797
+ }
798
+
799
+ torch::Tensor specular_bounds(int resolution, float costheta_cutoff)
800
+ {
801
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
802
+
803
+ // Extract input parameters.
804
+ SpecularBoundsKernelParams p;
805
+ p.costheta_cutoff = costheta_cutoff;
806
+ p.gridSize = dim3(resolution, resolution, 6);
807
+
808
+ // Allocate output tensors.
809
+ torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
810
+ torch::Tensor out = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 6*4 }, opts);
811
+
812
+ // Choose launch parameters.
813
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
814
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
815
+
816
+ // Setup tensors
817
+ p.out = make_cuda_tensor(out, p.gridSize);
818
+
819
+ // Launch CUDA kernel.
820
+ void* args[] = { &p };
821
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularBoundsKernel, gridSize, blockSize, args, 0, stream));
822
+
823
+ return out;
824
+ }
825
+
826
+ torch::Tensor specular_cubemap_fwd(torch::Tensor cubemap, torch::Tensor bounds, float roughness, float costheta_cutoff)
827
+ {
828
+ CHECK_TENSOR(cubemap, 4, 3);
829
+ CHECK_TENSOR(bounds, 4, 6*4);
830
+
831
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
832
+
833
+ // Extract input parameters.
834
+ SpecularCubemapKernelParams p;
835
+ p.roughness = roughness;
836
+ p.costheta_cutoff = costheta_cutoff;
837
+ update_grid(p.gridSize, cubemap);
838
+
839
+ // Allocate output tensors.
840
+ torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
841
+ torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 4 }, opts);
842
+
843
+ // Choose launch parameters.
844
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
845
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
846
+
847
+ // Setup tensors
848
+ p.cubemap = make_cuda_tensor(cubemap, p.gridSize);
849
+ p.bounds = make_cuda_tensor(bounds, p.gridSize);
850
+ p.out = make_cuda_tensor(out, p.gridSize);
851
+
852
+ // Launch CUDA kernel.
853
+ void* args[] = { &p };
854
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularCubemapFwdKernel, gridSize, blockSize, args, 0, stream));
855
+
856
+ return out;
857
+ }
858
+
859
+ torch::Tensor specular_cubemap_bwd(torch::Tensor cubemap, torch::Tensor bounds, torch::Tensor grad, float roughness, float costheta_cutoff)
860
+ {
861
+ CHECK_TENSOR(cubemap, 4, 3);
862
+ CHECK_TENSOR(bounds, 4, 6*4);
863
+
864
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
865
+
866
+ // Extract input parameters.
867
+ SpecularCubemapKernelParams p;
868
+ p.roughness = roughness;
869
+ p.costheta_cutoff = costheta_cutoff;
870
+ update_grid(p.gridSize, cubemap);
871
+
872
+ // Choose launch parameters.
873
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
874
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
875
+
876
+ // Setup tensors
877
+ torch::Tensor cubemap_grad;
878
+ p.cubemap = make_cuda_tensor(cubemap, p.gridSize);
879
+ p.bounds = make_cuda_tensor(bounds, p.gridSize);
880
+ p.out = make_cuda_tensor(grad, p.gridSize);
881
+
882
+ cubemap_grad = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, cubemap.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));
883
+ p.cubemap.d_val = (void*)cubemap_grad.data_ptr<float>();
884
+
885
+ // Launch CUDA kernel.
886
+ void* args[] = { &p };
887
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularCubemapBwdKernel, gridSize, blockSize, args, 0, stream));
888
+
889
+ return cubemap_grad;
890
+ }
891
+
892
+ //------------------------------------------------------------------------
893
+ // loss function
894
+
895
+ LossType strToLoss(std::string str)
896
+ {
897
+ if (str == "mse")
898
+ return LOSS_MSE;
899
+ else if (str == "relmse")
900
+ return LOSS_RELMSE;
901
+ else if (str == "smape")
902
+ return LOSS_SMAPE;
903
+ else
904
+ return LOSS_L1;
905
+ }
906
+
907
+ torch::Tensor image_loss_fwd(torch::Tensor img, torch::Tensor target, std::string loss, std::string tonemapper, bool fp16)
908
+ {
909
+ CHECK_TENSOR(img, 4, 3);
910
+ CHECK_TENSOR(target, 4, 3);
911
+
912
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
913
+
914
+ // Extract input parameters.
915
+ LossKernelParams p;
916
+ p.out.fp16 = fp16;
917
+ p.loss = strToLoss(loss);
918
+ p.tonemapper = tonemapper == "log_srgb" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE;
919
+ update_grid(p.gridSize, img, target);
920
+
921
+ // Choose launch parameters.
922
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
923
+ dim3 warpSize = getWarpSize(blockSize);
924
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
925
+
926
+ // Allocate output tensors.
927
+ torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
928
+ torch::Tensor out = torch::empty({ (p.gridSize.z - 1)/ warpSize.z + 1, (p.gridSize.y - 1) / warpSize.y + 1, (p.gridSize.x - 1) / warpSize.x + 1, 1 }, opts);
929
+
930
+ p.img = make_cuda_tensor(img, p.gridSize);
931
+ p.target = make_cuda_tensor(target, p.gridSize);
932
+ p.out = make_cuda_tensor(out, p.gridSize);
933
+
934
+ // Launch CUDA kernel.
935
+ void* args[] = { &p };
936
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossFwdKernel, gridSize, blockSize, args, 0, stream));
937
+
938
+ return out;
939
+ }
940
+
941
+ std::tuple<torch::Tensor, torch::Tensor> image_loss_bwd(torch::Tensor img, torch::Tensor target, torch::Tensor grad, std::string loss, std::string tonemapper)
942
+ {
943
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
944
+
945
+ // Extract input parameters.
946
+ LossKernelParams p;
947
+ p.loss = strToLoss(loss);
948
+ p.tonemapper = tonemapper == "log_srgb" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE;
949
+ update_grid(p.gridSize, img, target);
950
+
951
+ // Choose launch parameters.
952
+ dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
953
+ dim3 warpSize = getWarpSize(blockSize);
954
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
955
+
956
+ torch::Tensor img_grad, target_grad;
957
+ p.img = make_cuda_tensor(img, p.gridSize, &img_grad);
958
+ p.target = make_cuda_tensor(target, p.gridSize, &target_grad);
959
+ p.out = make_cuda_tensor(grad, p.gridSize);
960
+
961
+ // Launch CUDA kernel.
962
+ void* args[] = { &p };
963
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossBwdKernel, gridSize, blockSize, args, 0, stream));
964
+
965
+ return std::tuple<torch::Tensor, torch::Tensor>(img_grad, target_grad);
966
+ }
967
+
968
+ //------------------------------------------------------------------------
969
+ // transform function
970
+
971
+ torch::Tensor xfm_fwd(torch::Tensor points, torch::Tensor matrix, bool isPoints, bool fp16)
972
+ {
973
+ CHECK_TENSOR(points, 3, 3);
974
+ CHECK_TENSOR(matrix, 3, 4);
975
+
976
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
977
+
978
+ // Extract input parameters.
979
+ XfmKernelParams p;
980
+ p.out.fp16 = fp16;
981
+ p.isPoints = isPoints;
982
+ p.gridSize.x = points.size(1);
983
+ p.gridSize.y = 1;
984
+ p.gridSize.z = std::max(matrix.size(0), points.size(0));
985
+
986
+ // Choose launch parameters.
987
+ dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1);
988
+ dim3 warpSize = getWarpSize(blockSize);
989
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
990
+
991
+ // Allocate output tensors.
992
+ torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
993
+ torch::Tensor out = isPoints ? torch::empty({ matrix.size(0), points.size(1), 4 }, opts) : torch::empty({ matrix.size(0), points.size(1), 3 }, opts);
994
+
995
+ p.points = make_cuda_tensor(points, p.gridSize);
996
+ p.matrix = make_cuda_tensor(matrix, p.gridSize);
997
+ p.out = make_cuda_tensor(out, p.gridSize);
998
+
999
+ // Launch CUDA kernel.
1000
+ void* args[] = { &p };
1001
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsFwdKernel, gridSize, blockSize, args, 0, stream));
1002
+
1003
+ return out;
1004
+ }
1005
+
1006
+ torch::Tensor xfm_bwd(torch::Tensor points, torch::Tensor matrix, torch::Tensor grad, bool isPoints)
1007
+ {
1008
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
1009
+
1010
+ // Extract input parameters.
1011
+ XfmKernelParams p;
1012
+ p.isPoints = isPoints;
1013
+ p.gridSize.x = points.size(1);
1014
+ p.gridSize.y = 1;
1015
+ p.gridSize.z = std::max(matrix.size(0), points.size(0));
1016
+
1017
+ // Choose launch parameters.
1018
+ dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1);
1019
+ dim3 warpSize = getWarpSize(blockSize);
1020
+ dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
1021
+
1022
+ torch::Tensor points_grad;
1023
+ p.points = make_cuda_tensor(points, p.gridSize, &points_grad);
1024
+ p.matrix = make_cuda_tensor(matrix, p.gridSize);
1025
+ p.out = make_cuda_tensor(grad, p.gridSize);
1026
+
1027
+ // Launch CUDA kernel.
1028
+ void* args[] = { &p };
1029
+ NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsBwdKernel, gridSize, blockSize, args, 0, stream));
1030
+
1031
+ return points_grad;
1032
+ }
1033
+
1034
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
1035
+ m.def("prepare_shading_normal_fwd", &prepare_shading_normal_fwd, "prepare_shading_normal_fwd");
1036
+ m.def("prepare_shading_normal_bwd", &prepare_shading_normal_bwd, "prepare_shading_normal_bwd");
1037
+ m.def("lambert_fwd", &lambert_fwd, "lambert_fwd");
1038
+ m.def("lambert_bwd", &lambert_bwd, "lambert_bwd");
1039
+ m.def("frostbite_fwd", &frostbite_fwd, "frostbite_fwd");
1040
+ m.def("frostbite_bwd", &frostbite_bwd, "frostbite_bwd");
1041
+ m.def("fresnel_shlick_fwd", &fresnel_shlick_fwd, "fresnel_shlick_fwd");
1042
+ m.def("fresnel_shlick_bwd", &fresnel_shlick_bwd, "fresnel_shlick_bwd");
1043
+ m.def("ndf_ggx_fwd", &ndf_ggx_fwd, "ndf_ggx_fwd");
1044
+ m.def("ndf_ggx_bwd", &ndf_ggx_bwd, "ndf_ggx_bwd");
1045
+ m.def("lambda_ggx_fwd", &lambda_ggx_fwd, "lambda_ggx_fwd");
1046
+ m.def("lambda_ggx_bwd", &lambda_ggx_bwd, "lambda_ggx_bwd");
1047
+ m.def("masking_smith_fwd", &masking_smith_fwd, "masking_smith_fwd");
1048
+ m.def("masking_smith_bwd", &masking_smith_bwd, "masking_smith_bwd");
1049
+ m.def("pbr_specular_fwd", &pbr_specular_fwd, "pbr_specular_fwd");
1050
+ m.def("pbr_specular_bwd", &pbr_specular_bwd, "pbr_specular_bwd");
1051
+ m.def("pbr_bsdf_fwd", &pbr_bsdf_fwd, "pbr_bsdf_fwd");
1052
+ m.def("pbr_bsdf_bwd", &pbr_bsdf_bwd, "pbr_bsdf_bwd");
1053
+ m.def("diffuse_cubemap_fwd", &diffuse_cubemap_fwd, "diffuse_cubemap_fwd");
1054
+ m.def("diffuse_cubemap_bwd", &diffuse_cubemap_bwd, "diffuse_cubemap_bwd");
1055
+ m.def("specular_bounds", &specular_bounds, "specular_bounds");
1056
+ m.def("specular_cubemap_fwd", &specular_cubemap_fwd, "specular_cubemap_fwd");
1057
+ m.def("specular_cubemap_bwd", &specular_cubemap_bwd, "specular_cubemap_bwd");
1058
+ m.def("image_loss_fwd", &image_loss_fwd, "image_loss_fwd");
1059
+ m.def("image_loss_bwd", &image_loss_bwd, "image_loss_bwd");
1060
+ m.def("xfm_fwd", &xfm_fwd, "xfm_fwd");
1061
+ m.def("xfm_bwd", &xfm_bwd, "xfm_bwd");
1062
+ }
video3d/render/renderutils/c_src/vec3f.h ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #pragma once
13
+
14
+ struct vec3f
15
+ {
16
+ float x, y, z;
17
+
18
+ #ifdef __CUDACC__
19
+ __device__ vec3f() { }
20
+ __device__ vec3f(float v) { x = v; y = v; z = v; }
21
+ __device__ vec3f(float _x, float _y, float _z) { x = _x; y = _y; z = _z; }
22
+ __device__ vec3f(float3 v) { x = v.x; y = v.y; z = v.z; }
23
+
24
+ __device__ inline vec3f& operator+=(const vec3f& b) { x += b.x; y += b.y; z += b.z; return *this; }
25
+ __device__ inline vec3f& operator-=(const vec3f& b) { x -= b.x; y -= b.y; z -= b.z; return *this; }
26
+ __device__ inline vec3f& operator*=(const vec3f& b) { x *= b.x; y *= b.y; z *= b.z; return *this; }
27
+ __device__ inline vec3f& operator/=(const vec3f& b) { x /= b.x; y /= b.y; z /= b.z; return *this; }
28
+ #endif
29
+ };
30
+
31
+ #ifdef __CUDACC__
32
+ __device__ static inline vec3f operator+(const vec3f& a, const vec3f& b) { return vec3f(a.x + b.x, a.y + b.y, a.z + b.z); }
33
+ __device__ static inline vec3f operator-(const vec3f& a, const vec3f& b) { return vec3f(a.x - b.x, a.y - b.y, a.z - b.z); }
34
+ __device__ static inline vec3f operator*(const vec3f& a, const vec3f& b) { return vec3f(a.x * b.x, a.y * b.y, a.z * b.z); }
35
+ __device__ static inline vec3f operator/(const vec3f& a, const vec3f& b) { return vec3f(a.x / b.x, a.y / b.y, a.z / b.z); }
36
+ __device__ static inline vec3f operator-(const vec3f& a) { return vec3f(-a.x, -a.y, -a.z); }
37
+
38
+ __device__ static inline float sum(vec3f a)
39
+ {
40
+ return a.x + a.y + a.z;
41
+ }
42
+
43
+ __device__ static inline vec3f cross(vec3f a, vec3f b)
44
+ {
45
+ vec3f out;
46
+ out.x = a.y * b.z - a.z * b.y;
47
+ out.y = a.z * b.x - a.x * b.z;
48
+ out.z = a.x * b.y - a.y * b.x;
49
+ return out;
50
+ }
51
+
52
+ __device__ static inline void bwdCross(vec3f a, vec3f b, vec3f &d_a, vec3f &d_b, vec3f d_out)
53
+ {
54
+ d_a.x += d_out.z * b.y - d_out.y * b.z;
55
+ d_a.y += d_out.x * b.z - d_out.z * b.x;
56
+ d_a.z += d_out.y * b.x - d_out.x * b.y;
57
+
58
+ d_b.x += d_out.y * a.z - d_out.z * a.y;
59
+ d_b.y += d_out.z * a.x - d_out.x * a.z;
60
+ d_b.z += d_out.x * a.y - d_out.y * a.x;
61
+ }
62
+
63
+ __device__ static inline float dot(vec3f a, vec3f b)
64
+ {
65
+ return a.x * b.x + a.y * b.y + a.z * b.z;
66
+ }
67
+
68
+ __device__ static inline void bwdDot(vec3f a, vec3f b, vec3f& d_a, vec3f& d_b, float d_out)
69
+ {
70
+ d_a.x += d_out * b.x; d_a.y += d_out * b.y; d_a.z += d_out * b.z;
71
+ d_b.x += d_out * a.x; d_b.y += d_out * a.y; d_b.z += d_out * a.z;
72
+ }
73
+
74
+ __device__ static inline vec3f reflect(vec3f x, vec3f n)
75
+ {
76
+ return n * 2.0f * dot(n, x) - x;
77
+ }
78
+
79
+ __device__ static inline void bwdReflect(vec3f x, vec3f n, vec3f& d_x, vec3f& d_n, const vec3f d_out)
80
+ {
81
+ d_x.x += d_out.x * (2 * n.x * n.x - 1) + d_out.y * (2 * n.x * n.y) + d_out.z * (2 * n.x * n.z);
82
+ d_x.y += d_out.x * (2 * n.x * n.y) + d_out.y * (2 * n.y * n.y - 1) + d_out.z * (2 * n.y * n.z);
83
+ d_x.z += d_out.x * (2 * n.x * n.z) + d_out.y * (2 * n.y * n.z) + d_out.z * (2 * n.z * n.z - 1);
84
+
85
+ d_n.x += d_out.x * (2 * (2 * n.x * x.x + n.y * x.y + n.z * x.z)) + d_out.y * (2 * n.y * x.x) + d_out.z * (2 * n.z * x.x);
86
+ d_n.y += d_out.x * (2 * n.x * x.y) + d_out.y * (2 * (n.x * x.x + 2 * n.y * x.y + n.z * x.z)) + d_out.z * (2 * n.z * x.y);
87
+ d_n.z += d_out.x * (2 * n.x * x.z) + d_out.y * (2 * n.y * x.z) + d_out.z * (2 * (n.x * x.x + n.y * x.y + 2 * n.z * x.z));
88
+ }
89
+
90
+ __device__ static inline vec3f safeNormalize(vec3f v)
91
+ {
92
+ float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
93
+ return l > 0.0f ? (v / l) : vec3f(0.0f);
94
+ }
95
+
96
+ __device__ static inline void bwdSafeNormalize(const vec3f v, vec3f& d_v, const vec3f d_out)
97
+ {
98
+
99
+ float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
100
+ if (l > 0.0f)
101
+ {
102
+ float fac = 1.0 / powf(v.x * v.x + v.y * v.y + v.z * v.z, 1.5f);
103
+ d_v.x += (d_out.x * (v.y * v.y + v.z * v.z) - d_out.y * (v.x * v.y) - d_out.z * (v.x * v.z)) * fac;
104
+ d_v.y += (d_out.y * (v.x * v.x + v.z * v.z) - d_out.x * (v.y * v.x) - d_out.z * (v.y * v.z)) * fac;
105
+ d_v.z += (d_out.z * (v.x * v.x + v.y * v.y) - d_out.x * (v.z * v.x) - d_out.y * (v.z * v.y)) * fac;
106
+ }
107
+ }
108
+
109
+ #endif
video3d/render/renderutils/c_src/vec4f.h ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #pragma once
13
+
14
+ struct vec4f
15
+ {
16
+ float x, y, z, w;
17
+
18
+ #ifdef __CUDACC__
19
+ __device__ vec4f() { }
20
+ __device__ vec4f(float v) { x = v; y = v; z = v; w = v; }
21
+ __device__ vec4f(float _x, float _y, float _z, float _w) { x = _x; y = _y; z = _z; w = _w; }
22
+ __device__ vec4f(float4 v) { x = v.x; y = v.y; z = v.z; w = v.w; }
23
+ #endif
24
+ };
25
+
video3d/render/renderutils/loss.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ import torch
11
+
12
+ #----------------------------------------------------------------------------
13
+ # HDR image losses
14
+ #----------------------------------------------------------------------------
15
+
16
+ def _tonemap_srgb(f):
17
+ return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f)
18
+
19
+ def _SMAPE(img, target, eps=0.01):
20
+ nom = torch.abs(img - target)
21
+ denom = torch.abs(img) + torch.abs(target) + 0.01
22
+ return torch.mean(nom / denom)
23
+
24
+ def _RELMSE(img, target, eps=0.1):
25
+ nom = (img - target) * (img - target)
26
+ denom = img * img + target * target + 0.1
27
+ return torch.mean(nom / denom)
28
+
29
+ def image_loss_fn(img, target, loss, tonemapper):
30
+ if tonemapper == 'log_srgb':
31
+ img = _tonemap_srgb(torch.log(torch.clamp(img, min=0, max=65535) + 1))
32
+ target = _tonemap_srgb(torch.log(torch.clamp(target, min=0, max=65535) + 1))
33
+
34
+ if loss == 'mse':
35
+ return torch.nn.functional.mse_loss(img, target)
36
+ elif loss == 'smape':
37
+ return _SMAPE(img, target)
38
+ elif loss == 'relmse':
39
+ return _RELMSE(img, target)
40
+ else:
41
+ return torch.nn.functional.l1_loss(img, target)
video3d/render/renderutils/ops.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ import numpy as np
11
+ import os
12
+ import sys
13
+ import torch
14
+ import torch.utils.cpp_extension
15
+
16
+ from .bsdf import *
17
+ from .loss import *
18
+
19
+ #----------------------------------------------------------------------------
20
+ # C++/Cuda plugin compiler/loader.
21
+
22
+ _cached_plugin = None
23
+ def _get_plugin():
24
+ # Return cached plugin if already loaded.
25
+ global _cached_plugin
26
+ if _cached_plugin is not None:
27
+ return _cached_plugin
28
+
29
+ # Make sure we can find the necessary compiler and libary binaries.
30
+ if os.name == 'nt':
31
+ def find_cl_path():
32
+ import glob
33
+ for edition in ['Enterprise', 'Professional', 'BuildTools', 'Community']:
34
+ paths = sorted(glob.glob(r"C:\Program Files (x86)\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % edition), reverse=True)
35
+ if paths:
36
+ return paths[0]
37
+
38
+ # If cl.exe is not on path, try to find it.
39
+ if os.system("where cl.exe >nul 2>nul") != 0:
40
+ cl_path = find_cl_path()
41
+ if cl_path is None:
42
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
43
+ os.environ['PATH'] += ';' + cl_path
44
+
45
+ # Compiler options.
46
+ opts = ['-DNVDR_TORCH']
47
+
48
+ # Linker options.
49
+ if os.name == 'posix':
50
+ ldflags = ['-lcuda', '-lnvrtc']
51
+ elif os.name == 'nt':
52
+ ldflags = ['cuda.lib', 'advapi32.lib', 'nvrtc.lib']
53
+
54
+ # List of sources.
55
+ source_files = [
56
+ 'c_src/mesh.cu',
57
+ 'c_src/loss.cu',
58
+ 'c_src/bsdf.cu',
59
+ 'c_src/normal.cu',
60
+ 'c_src/cubemap.cu',
61
+ 'c_src/common.cpp',
62
+ 'c_src/torch_bindings.cpp'
63
+ ]
64
+
65
+ # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine.
66
+ os.environ['TORCH_CUDA_ARCH_LIST'] = ''
67
+
68
+ # Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment.
69
+ try:
70
+ lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory('renderutils_plugin', False), 'lock')
71
+ if os.path.exists(lock_fn):
72
+ print("Warning: Lock file exists in build directory: '%s'" % lock_fn)
73
+ except:
74
+ pass
75
+
76
+ # Compile and load.
77
+ source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files]
78
+ torch.utils.cpp_extension.load(name='renderutils_plugin', sources=source_paths, extra_cflags=opts,
79
+ extra_cuda_cflags=opts, extra_ldflags=ldflags, with_cuda=True, verbose=True)
80
+
81
+ # Import, cache, and return the compiled module.
82
+ import renderutils_plugin
83
+ _cached_plugin = renderutils_plugin
84
+ return _cached_plugin
85
+
86
+ #----------------------------------------------------------------------------
87
+ # Internal kernels, just used for testing functionality
88
+
89
+ class _fresnel_shlick_func(torch.autograd.Function):
90
+ @staticmethod
91
+ def forward(ctx, f0, f90, cosTheta):
92
+ out = _get_plugin().fresnel_shlick_fwd(f0, f90, cosTheta, False)
93
+ ctx.save_for_backward(f0, f90, cosTheta)
94
+ return out
95
+
96
+ @staticmethod
97
+ def backward(ctx, dout):
98
+ f0, f90, cosTheta = ctx.saved_variables
99
+ return _get_plugin().fresnel_shlick_bwd(f0, f90, cosTheta, dout) + (None,)
100
+
101
+ def _fresnel_shlick(f0, f90, cosTheta, use_python=False):
102
+ if use_python:
103
+ out = bsdf_fresnel_shlick(f0, f90, cosTheta)
104
+ else:
105
+ out = _fresnel_shlick_func.apply(f0, f90, cosTheta)
106
+
107
+ if torch.is_anomaly_enabled():
108
+ assert torch.all(torch.isfinite(out)), "Output of _fresnel_shlick contains inf or NaN"
109
+ return out
110
+
111
+
112
+ class _ndf_ggx_func(torch.autograd.Function):
113
+ @staticmethod
114
+ def forward(ctx, alphaSqr, cosTheta):
115
+ out = _get_plugin().ndf_ggx_fwd(alphaSqr, cosTheta, False)
116
+ ctx.save_for_backward(alphaSqr, cosTheta)
117
+ return out
118
+
119
+ @staticmethod
120
+ def backward(ctx, dout):
121
+ alphaSqr, cosTheta = ctx.saved_variables
122
+ return _get_plugin().ndf_ggx_bwd(alphaSqr, cosTheta, dout) + (None,)
123
+
124
+ def _ndf_ggx(alphaSqr, cosTheta, use_python=False):
125
+ if use_python:
126
+ out = bsdf_ndf_ggx(alphaSqr, cosTheta)
127
+ else:
128
+ out = _ndf_ggx_func.apply(alphaSqr, cosTheta)
129
+
130
+ if torch.is_anomaly_enabled():
131
+ assert torch.all(torch.isfinite(out)), "Output of _ndf_ggx contains inf or NaN"
132
+ return out
133
+
134
+ class _lambda_ggx_func(torch.autograd.Function):
135
+ @staticmethod
136
+ def forward(ctx, alphaSqr, cosTheta):
137
+ out = _get_plugin().lambda_ggx_fwd(alphaSqr, cosTheta, False)
138
+ ctx.save_for_backward(alphaSqr, cosTheta)
139
+ return out
140
+
141
+ @staticmethod
142
+ def backward(ctx, dout):
143
+ alphaSqr, cosTheta = ctx.saved_variables
144
+ return _get_plugin().lambda_ggx_bwd(alphaSqr, cosTheta, dout) + (None,)
145
+
146
+ def _lambda_ggx(alphaSqr, cosTheta, use_python=False):
147
+ if use_python:
148
+ out = bsdf_lambda_ggx(alphaSqr, cosTheta)
149
+ else:
150
+ out = _lambda_ggx_func.apply(alphaSqr, cosTheta)
151
+
152
+ if torch.is_anomaly_enabled():
153
+ assert torch.all(torch.isfinite(out)), "Output of _lambda_ggx contains inf or NaN"
154
+ return out
155
+
156
+ class _masking_smith_func(torch.autograd.Function):
157
+ @staticmethod
158
+ def forward(ctx, alphaSqr, cosThetaI, cosThetaO):
159
+ ctx.save_for_backward(alphaSqr, cosThetaI, cosThetaO)
160
+ out = _get_plugin().masking_smith_fwd(alphaSqr, cosThetaI, cosThetaO, False)
161
+ return out
162
+
163
+ @staticmethod
164
+ def backward(ctx, dout):
165
+ alphaSqr, cosThetaI, cosThetaO = ctx.saved_variables
166
+ return _get_plugin().masking_smith_bwd(alphaSqr, cosThetaI, cosThetaO, dout) + (None,)
167
+
168
+ def _masking_smith(alphaSqr, cosThetaI, cosThetaO, use_python=False):
169
+ if use_python:
170
+ out = bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO)
171
+ else:
172
+ out = _masking_smith_func.apply(alphaSqr, cosThetaI, cosThetaO)
173
+
174
+ if torch.is_anomaly_enabled():
175
+ assert torch.all(torch.isfinite(out)), "Output of _masking_smith contains inf or NaN"
176
+ return out
177
+
178
+ #----------------------------------------------------------------------------
179
+ # Shading normal setup (bump mapping + bent normals)
180
+
181
+ class _prepare_shading_normal_func(torch.autograd.Function):
182
+ @staticmethod
183
+ def forward(ctx, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):
184
+ ctx.two_sided_shading, ctx.opengl = two_sided_shading, opengl
185
+ out = _get_plugin().prepare_shading_normal_fwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl, False)
186
+ ctx.save_for_backward(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm)
187
+ return out
188
+
189
+ @staticmethod
190
+ def backward(ctx, dout):
191
+ pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm = ctx.saved_variables
192
+ return _get_plugin().prepare_shading_normal_bwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, dout, ctx.two_sided_shading, ctx.opengl) + (None, None, None)
193
+
194
+ def prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading=True, opengl=True, use_python=False):
195
+ '''Takes care of all corner cases and produces a final normal used for shading:
196
+ - Constructs tangent space
197
+ - Flips normal direction based on geometric normal for two sided Shading
198
+ - Perturbs shading normal by normal map
199
+ - Bends backfacing normals towards the camera to avoid shading artifacts
200
+
201
+ All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.
202
+
203
+ Args:
204
+ pos: World space g-buffer position.
205
+ view_pos: Camera position in world space (typically using broadcasting).
206
+ perturbed_nrm: Trangent-space normal perturbation from normal map lookup.
207
+ smooth_nrm: Interpolated vertex normals.
208
+ smooth_tng: Interpolated vertex tangents.
209
+ geom_nrm: Geometric (face) normals.
210
+ two_sided_shading: Use one/two sided shading
211
+ opengl: Use OpenGL/DirectX normal map conventions
212
+ use_python: Use PyTorch implementation (for validation)
213
+ Returns:
214
+ Final shading normal
215
+ '''
216
+
217
+ if perturbed_nrm is None:
218
+ perturbed_nrm = torch.tensor([0, 0, 1], dtype=torch.float32, device='cuda', requires_grad=False)[None, None, None, ...]
219
+
220
+ if use_python:
221
+ out = bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl)
222
+ else:
223
+ out = _prepare_shading_normal_func.apply(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl)
224
+
225
+ if torch.is_anomaly_enabled():
226
+ assert torch.all(torch.isfinite(out)), "Output of prepare_shading_normal contains inf or NaN"
227
+ return out
228
+
229
+ #----------------------------------------------------------------------------
230
+ # BSDF functions
231
+
232
+ class _lambert_func(torch.autograd.Function):
233
+ @staticmethod
234
+ def forward(ctx, nrm, wi):
235
+ out = _get_plugin().lambert_fwd(nrm, wi, False)
236
+ ctx.save_for_backward(nrm, wi)
237
+ return out
238
+
239
+ @staticmethod
240
+ def backward(ctx, dout):
241
+ nrm, wi = ctx.saved_variables
242
+ return _get_plugin().lambert_bwd(nrm, wi, dout) + (None,)
243
+
244
+ def lambert(nrm, wi, use_python=False):
245
+ '''Lambertian bsdf.
246
+ All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.
247
+
248
+ Args:
249
+ nrm: World space shading normal.
250
+ wi: World space light vector.
251
+ use_python: Use PyTorch implementation (for validation)
252
+
253
+ Returns:
254
+ Shaded diffuse value with shape [minibatch_size, height, width, 1]
255
+ '''
256
+
257
+ if use_python:
258
+ out = bsdf_lambert(nrm, wi)
259
+ else:
260
+ out = _lambert_func.apply(nrm, wi)
261
+
262
+ if torch.is_anomaly_enabled():
263
+ assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN"
264
+ return out
265
+
266
+ class _frostbite_diffuse_func(torch.autograd.Function):
267
+ @staticmethod
268
+ def forward(ctx, nrm, wi, wo, linearRoughness):
269
+ out = _get_plugin().frostbite_fwd(nrm, wi, wo, linearRoughness, False)
270
+ ctx.save_for_backward(nrm, wi, wo, linearRoughness)
271
+ return out
272
+
273
+ @staticmethod
274
+ def backward(ctx, dout):
275
+ nrm, wi, wo, linearRoughness = ctx.saved_variables
276
+ return _get_plugin().frostbite_bwd(nrm, wi, wo, linearRoughness, dout) + (None,)
277
+
278
+ def frostbite_diffuse(nrm, wi, wo, linearRoughness, use_python=False):
279
+ '''Frostbite, normalized Disney Diffuse bsdf.
280
+ All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.
281
+
282
+ Args:
283
+ nrm: World space shading normal.
284
+ wi: World space light vector.
285
+ wo: World space camera vector.
286
+ linearRoughness: Material roughness
287
+ use_python: Use PyTorch implementation (for validation)
288
+
289
+ Returns:
290
+ Shaded diffuse value with shape [minibatch_size, height, width, 1]
291
+ '''
292
+
293
+ if use_python:
294
+ out = bsdf_frostbite(nrm, wi, wo, linearRoughness)
295
+ else:
296
+ out = _frostbite_diffuse_func.apply(nrm, wi, wo, linearRoughness)
297
+
298
+ if torch.is_anomaly_enabled():
299
+ assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN"
300
+ return out
301
+
302
+ class _pbr_specular_func(torch.autograd.Function):
303
+ @staticmethod
304
+ def forward(ctx, col, nrm, wo, wi, alpha, min_roughness):
305
+ ctx.save_for_backward(col, nrm, wo, wi, alpha)
306
+ ctx.min_roughness = min_roughness
307
+ out = _get_plugin().pbr_specular_fwd(col, nrm, wo, wi, alpha, min_roughness, False)
308
+ return out
309
+
310
+ @staticmethod
311
+ def backward(ctx, dout):
312
+ col, nrm, wo, wi, alpha = ctx.saved_variables
313
+ return _get_plugin().pbr_specular_bwd(col, nrm, wo, wi, alpha, ctx.min_roughness, dout) + (None, None)
314
+
315
+ def pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08, use_python=False):
316
+ '''Physically-based specular bsdf.
317
+ All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.
318
+
319
+ Args:
320
+ col: Specular lobe color
321
+ nrm: World space shading normal.
322
+ wo: World space camera vector.
323
+ wi: World space light vector
324
+ alpha: Specular roughness parameter with shape [minibatch_size, height, width, 1]
325
+ min_roughness: Scalar roughness clamping threshold
326
+
327
+ use_python: Use PyTorch implementation (for validation)
328
+ Returns:
329
+ Shaded specular color
330
+ '''
331
+
332
+ if use_python:
333
+ out = bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=min_roughness)
334
+ else:
335
+ out = _pbr_specular_func.apply(col, nrm, wo, wi, alpha, min_roughness)
336
+
337
+ if torch.is_anomaly_enabled():
338
+ assert torch.all(torch.isfinite(out)), "Output of pbr_specular contains inf or NaN"
339
+ return out
340
+
341
+ class _pbr_bsdf_func(torch.autograd.Function):
342
+ @staticmethod
343
+ def forward(ctx, kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):
344
+ ctx.save_for_backward(kd, arm, pos, nrm, view_pos, light_pos)
345
+ ctx.min_roughness = min_roughness
346
+ ctx.BSDF = BSDF
347
+ out = _get_plugin().pbr_bsdf_fwd(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF, False)
348
+ return out
349
+
350
+ @staticmethod
351
+ def backward(ctx, dout):
352
+ kd, arm, pos, nrm, view_pos, light_pos = ctx.saved_variables
353
+ return _get_plugin().pbr_bsdf_bwd(kd, arm, pos, nrm, view_pos, light_pos, ctx.min_roughness, ctx.BSDF, dout) + (None, None, None)
354
+
355
+ def pbr_bsdf(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08, bsdf="lambert", use_python=False):
356
+ '''Physically-based bsdf, both diffuse & specular lobes
357
+ All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.
358
+
359
+ Args:
360
+ kd: Diffuse albedo.
361
+ arm: Specular parameters (attenuation, linear roughness, metalness).
362
+ pos: World space position.
363
+ nrm: World space shading normal.
364
+ view_pos: Camera position in world space, typically using broadcasting.
365
+ light_pos: Light position in world space, typically using broadcasting.
366
+ min_roughness: Scalar roughness clamping threshold
367
+ bsdf: Controls diffuse BSDF, can be either 'lambert' or 'frostbite'
368
+
369
+ use_python: Use PyTorch implementation (for validation)
370
+
371
+ Returns:
372
+ Shaded color.
373
+ '''
374
+
375
+ BSDF = 0
376
+ if bsdf == 'frostbite':
377
+ BSDF = 1
378
+
379
+ if use_python:
380
+ out = bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF)
381
+ else:
382
+ out = _pbr_bsdf_func.apply(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF)
383
+
384
+ if torch.is_anomaly_enabled():
385
+ assert torch.all(torch.isfinite(out)), "Output of pbr_bsdf contains inf or NaN"
386
+ return out
387
+
388
+ #----------------------------------------------------------------------------
389
+ # cubemap filter with filtering across edges
390
+
391
+ class _diffuse_cubemap_func(torch.autograd.Function):
392
+ @staticmethod
393
+ def forward(ctx, cubemap):
394
+ out = _get_plugin().diffuse_cubemap_fwd(cubemap)
395
+ ctx.save_for_backward(cubemap)
396
+ return out
397
+
398
+ @staticmethod
399
+ def backward(ctx, dout):
400
+ cubemap, = ctx.saved_variables
401
+ cubemap_grad = _get_plugin().diffuse_cubemap_bwd(cubemap, dout)
402
+ return cubemap_grad, None
403
+
404
+ def diffuse_cubemap(cubemap, use_python=False):
405
+ if use_python:
406
+ assert False
407
+ else:
408
+ out = _diffuse_cubemap_func.apply(cubemap)
409
+ if torch.is_anomaly_enabled():
410
+ assert torch.all(torch.isfinite(out)), "Output of diffuse_cubemap contains inf or NaN"
411
+ return out
412
+
413
+ class _specular_cubemap(torch.autograd.Function):
414
+ @staticmethod
415
+ def forward(ctx, cubemap, roughness, costheta_cutoff, bounds):
416
+ out = _get_plugin().specular_cubemap_fwd(cubemap, bounds, roughness, costheta_cutoff)
417
+ ctx.save_for_backward(cubemap, bounds)
418
+ ctx.roughness, ctx.theta_cutoff = roughness, costheta_cutoff
419
+ return out
420
+
421
+ @staticmethod
422
+ def backward(ctx, dout):
423
+ cubemap, bounds = ctx.saved_variables
424
+ cubemap_grad = _get_plugin().specular_cubemap_bwd(cubemap, bounds, dout, ctx.roughness, ctx.theta_cutoff)
425
+ return cubemap_grad, None, None, None
426
+
427
+ # Compute the bounds of the GGX NDF lobe to retain "cutoff" percent of the energy
428
+ def __ndfBounds(res, roughness, cutoff):
429
+ def ndfGGX(alphaSqr, costheta):
430
+ costheta = np.clip(costheta, 0.0, 1.0)
431
+ d = (costheta * alphaSqr - costheta) * costheta + 1.0
432
+ return alphaSqr / (d * d * np.pi)
433
+
434
+ # Sample out cutoff angle
435
+ nSamples = 1000000
436
+ costheta = np.cos(np.linspace(0, np.pi/2.0, nSamples))
437
+ D = np.cumsum(ndfGGX(roughness**4, costheta))
438
+ idx = np.argmax(D >= D[..., -1] * cutoff)
439
+
440
+ # Brute force compute lookup table with bounds
441
+ bounds = _get_plugin().specular_bounds(res, costheta[idx])
442
+
443
+ return costheta[idx], bounds
444
+ __ndfBoundsDict = {}
445
+
446
+ def specular_cubemap(cubemap, roughness, cutoff=0.99, use_python=False):
447
+ assert cubemap.shape[0] == 6 and cubemap.shape[1] == cubemap.shape[2], "Bad shape for cubemap tensor: %s" % str(cubemap.shape)
448
+
449
+ if use_python:
450
+ assert False
451
+ else:
452
+ key = (cubemap.shape[1], roughness, cutoff)
453
+ if key not in __ndfBoundsDict:
454
+ __ndfBoundsDict[key] = __ndfBounds(*key)
455
+ out = _specular_cubemap.apply(cubemap, roughness, *__ndfBoundsDict[key])
456
+ if torch.is_anomaly_enabled():
457
+ assert torch.all(torch.isfinite(out)), "Output of specular_cubemap contains inf or NaN"
458
+ return out[..., 0:3] / out[..., 3:]
459
+
460
+ #----------------------------------------------------------------------------
461
+ # Fast image loss function
462
+
463
+ class _image_loss_func(torch.autograd.Function):
464
+ @staticmethod
465
+ def forward(ctx, img, target, loss, tonemapper):
466
+ ctx.loss, ctx.tonemapper = loss, tonemapper
467
+ ctx.save_for_backward(img, target)
468
+ out = _get_plugin().image_loss_fwd(img, target, loss, tonemapper, False)
469
+ return out
470
+
471
+ @staticmethod
472
+ def backward(ctx, dout):
473
+ img, target = ctx.saved_variables
474
+ return _get_plugin().image_loss_bwd(img, target, dout, ctx.loss, ctx.tonemapper) + (None, None, None)
475
+
476
+ def image_loss(img, target, loss='l1', tonemapper='none', use_python=False):
477
+ '''Compute HDR image loss. Combines tonemapping and loss into a single kernel for better perf.
478
+ All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.
479
+
480
+ Args:
481
+ img: Input image.
482
+ target: Target (reference) image.
483
+ loss: Type of loss. Valid options are ['l1', 'mse', 'smape', 'relmse']
484
+ tonemapper: Tonemapping operations. Valid options are ['none', 'log_srgb']
485
+ use_python: Use PyTorch implementation (for validation)
486
+
487
+ Returns:
488
+ Image space loss (scalar value).
489
+ '''
490
+ if use_python:
491
+ out = image_loss_fn(img, target, loss, tonemapper)
492
+ else:
493
+ out = _image_loss_func.apply(img, target, loss, tonemapper)
494
+ out = torch.sum(out) / (img.shape[0]*img.shape[1]*img.shape[2])
495
+
496
+ if torch.is_anomaly_enabled():
497
+ assert torch.all(torch.isfinite(out)), "Output of image_loss contains inf or NaN"
498
+ return out
499
+
500
+ #----------------------------------------------------------------------------
501
+ # Transform points function
502
+
503
+ class _xfm_func(torch.autograd.Function):
504
+ @staticmethod
505
+ def forward(ctx, points, matrix, isPoints):
506
+ ctx.save_for_backward(points, matrix)
507
+ ctx.isPoints = isPoints
508
+ return _get_plugin().xfm_fwd(points, matrix, isPoints, False)
509
+
510
+ @staticmethod
511
+ def backward(ctx, dout):
512
+ points, matrix = ctx.saved_variables
513
+ return (_get_plugin().xfm_bwd(points, matrix, dout, ctx.isPoints),) + (None, None, None)
514
+
515
+ def xfm_points(points, matrix, use_python=False):
516
+ '''Transform points.
517
+ Args:
518
+ points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
519
+ matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
520
+ use_python: Use PyTorch's torch.matmul (for validation)
521
+ Returns:
522
+ Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
523
+ '''
524
+ if use_python:
525
+ out = torch.matmul(torch.nn.functional.pad(points, pad=(0,1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))
526
+ else:
527
+ out = _xfm_func.apply(points, matrix, True)
528
+
529
+ if torch.is_anomaly_enabled():
530
+ assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
531
+ return out
532
+
533
+ def xfm_vectors(vectors, matrix, use_python=False):
534
+ '''Transform vectors.
535
+ Args:
536
+ vectors: Tensor containing 3D vectors with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
537
+ matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
538
+ use_python: Use PyTorch's torch.matmul (for validation)
539
+
540
+ Returns:
541
+ Transformed vectors in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
542
+ '''
543
+
544
+ if use_python:
545
+ out = torch.matmul(torch.nn.functional.pad(vectors, pad=(0,1), mode='constant', value=0.0), torch.transpose(matrix, 1, 2))[..., 0:3].contiguous()
546
+ else:
547
+ out = _xfm_func.apply(vectors, matrix, False)
548
+
549
+ if torch.is_anomaly_enabled():
550
+ assert torch.all(torch.isfinite(out)), "Output of xfm_vectors contains inf or NaN"
551
+ return out
552
+
553
+
554
+
video3d/render/renderutils/tests/test_bsdf.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ import torch
11
+
12
+ import os
13
+ import sys
14
+ sys.path.insert(0, os.path.join(sys.path[0], '../..'))
15
+ import renderutils as ru
16
+
17
+ RES = 4
18
+ DTYPE = torch.float32
19
+
20
+ def relative_loss(name, ref, cuda):
21
+ ref = ref.float()
22
+ cuda = cuda.float()
23
+ print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item())
24
+
25
+ def test_normal():
26
+ pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
27
+ pos_ref = pos_cuda.clone().detach().requires_grad_(True)
28
+ view_pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
29
+ view_pos_ref = view_pos_cuda.clone().detach().requires_grad_(True)
30
+ perturbed_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
31
+ perturbed_nrm_ref = perturbed_nrm_cuda.clone().detach().requires_grad_(True)
32
+ smooth_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
33
+ smooth_nrm_ref = smooth_nrm_cuda.clone().detach().requires_grad_(True)
34
+ smooth_tng_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
35
+ smooth_tng_ref = smooth_tng_cuda.clone().detach().requires_grad_(True)
36
+ geom_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
37
+ geom_nrm_ref = geom_nrm_cuda.clone().detach().requires_grad_(True)
38
+ target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')
39
+
40
+ ref = ru.prepare_shading_normal(pos_ref, view_pos_ref, perturbed_nrm_ref, smooth_nrm_ref, smooth_tng_ref, geom_nrm_ref, True, use_python=True)
41
+ ref_loss = torch.nn.MSELoss()(ref, target)
42
+ ref_loss.backward()
43
+
44
+ cuda = ru.prepare_shading_normal(pos_cuda, view_pos_cuda, perturbed_nrm_cuda, smooth_nrm_cuda, smooth_tng_cuda, geom_nrm_cuda, True)
45
+ cuda_loss = torch.nn.MSELoss()(cuda, target)
46
+ cuda_loss.backward()
47
+
48
+ print("-------------------------------------------------------------")
49
+ print(" bent normal")
50
+ print("-------------------------------------------------------------")
51
+ relative_loss("res:", ref, cuda)
52
+ relative_loss("pos:", pos_ref.grad, pos_cuda.grad)
53
+ relative_loss("view_pos:", view_pos_ref.grad, view_pos_cuda.grad)
54
+ relative_loss("perturbed_nrm:", perturbed_nrm_ref.grad, perturbed_nrm_cuda.grad)
55
+ relative_loss("smooth_nrm:", smooth_nrm_ref.grad, smooth_nrm_cuda.grad)
56
+ relative_loss("smooth_tng:", smooth_tng_ref.grad, smooth_tng_cuda.grad)
57
+ relative_loss("geom_nrm:", geom_nrm_ref.grad, geom_nrm_cuda.grad)
58
+
59
+ def test_schlick():
60
+ f0_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
61
+ f0_ref = f0_cuda.clone().detach().requires_grad_(True)
62
+ f90_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
63
+ f90_ref = f90_cuda.clone().detach().requires_grad_(True)
64
+ cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 2.0
65
+ cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)
66
+ cosT_ref = cosT_cuda.clone().detach().requires_grad_(True)
67
+ target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')
68
+
69
+ ref = ru._fresnel_shlick(f0_ref, f90_ref, cosT_ref, use_python=True)
70
+ ref_loss = torch.nn.MSELoss()(ref, target)
71
+ ref_loss.backward()
72
+
73
+ cuda = ru._fresnel_shlick(f0_cuda, f90_cuda, cosT_cuda)
74
+ cuda_loss = torch.nn.MSELoss()(cuda, target)
75
+ cuda_loss.backward()
76
+
77
+ print("-------------------------------------------------------------")
78
+ print(" Fresnel shlick")
79
+ print("-------------------------------------------------------------")
80
+ relative_loss("res:", ref, cuda)
81
+ relative_loss("f0:", f0_ref.grad, f0_cuda.grad)
82
+ relative_loss("f90:", f90_ref.grad, f90_cuda.grad)
83
+ relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad)
84
+
85
+ def test_ndf_ggx():
86
+ alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
87
+ alphaSqr_cuda = alphaSqr_cuda.clone().detach().requires_grad_(True)
88
+ alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)
89
+ cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1
90
+ cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)
91
+ cosT_ref = cosT_cuda.clone().detach().requires_grad_(True)
92
+ target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')
93
+
94
+ ref = ru._ndf_ggx(alphaSqr_ref, cosT_ref, use_python=True)
95
+ ref_loss = torch.nn.MSELoss()(ref, target)
96
+ ref_loss.backward()
97
+
98
+ cuda = ru._ndf_ggx(alphaSqr_cuda, cosT_cuda)
99
+ cuda_loss = torch.nn.MSELoss()(cuda, target)
100
+ cuda_loss.backward()
101
+
102
+ print("-------------------------------------------------------------")
103
+ print(" Ndf GGX")
104
+ print("-------------------------------------------------------------")
105
+ relative_loss("res:", ref, cuda)
106
+ relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad)
107
+ relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad)
108
+
109
+ def test_lambda_ggx():
110
+ alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
111
+ alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)
112
+ cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1
113
+ cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)
114
+ cosT_ref = cosT_cuda.clone().detach().requires_grad_(True)
115
+ target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')
116
+
117
+ ref = ru._lambda_ggx(alphaSqr_ref, cosT_ref, use_python=True)
118
+ ref_loss = torch.nn.MSELoss()(ref, target)
119
+ ref_loss.backward()
120
+
121
+ cuda = ru._lambda_ggx(alphaSqr_cuda, cosT_cuda)
122
+ cuda_loss = torch.nn.MSELoss()(cuda, target)
123
+ cuda_loss.backward()
124
+
125
+ print("-------------------------------------------------------------")
126
+ print(" Lambda GGX")
127
+ print("-------------------------------------------------------------")
128
+ relative_loss("res:", ref, cuda)
129
+ relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad)
130
+ relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad)
131
+
132
+ def test_masking_smith():
133
+ alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
134
+ alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)
135
+ cosThetaI_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
136
+ cosThetaI_ref = cosThetaI_cuda.clone().detach().requires_grad_(True)
137
+ cosThetaO_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
138
+ cosThetaO_ref = cosThetaO_cuda.clone().detach().requires_grad_(True)
139
+ target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')
140
+
141
+ ref = ru._masking_smith(alphaSqr_ref, cosThetaI_ref, cosThetaO_ref, use_python=True)
142
+ ref_loss = torch.nn.MSELoss()(ref, target)
143
+ ref_loss.backward()
144
+
145
+ cuda = ru._masking_smith(alphaSqr_cuda, cosThetaI_cuda, cosThetaO_cuda)
146
+ cuda_loss = torch.nn.MSELoss()(cuda, target)
147
+ cuda_loss.backward()
148
+
149
+ print("-------------------------------------------------------------")
150
+ print(" Smith masking term")
151
+ print("-------------------------------------------------------------")
152
+ relative_loss("res:", ref, cuda)
153
+ relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad)
154
+ relative_loss("cosThetaI:", cosThetaI_ref.grad, cosThetaI_cuda.grad)
155
+ relative_loss("cosThetaO:", cosThetaO_ref.grad, cosThetaO_cuda.grad)
156
+
157
+ def test_lambert():
158
+ normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
159
+ normals_ref = normals_cuda.clone().detach().requires_grad_(True)
160
+ wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
161
+ wi_ref = wi_cuda.clone().detach().requires_grad_(True)
162
+ target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')
163
+
164
+ ref = ru.lambert(normals_ref, wi_ref, use_python=True)
165
+ ref_loss = torch.nn.MSELoss()(ref, target)
166
+ ref_loss.backward()
167
+
168
+ cuda = ru.lambert(normals_cuda, wi_cuda)
169
+ cuda_loss = torch.nn.MSELoss()(cuda, target)
170
+ cuda_loss.backward()
171
+
172
+ print("-------------------------------------------------------------")
173
+ print(" Lambert")
174
+ print("-------------------------------------------------------------")
175
+ relative_loss("res:", ref, cuda)
176
+ relative_loss("nrm:", normals_ref.grad, normals_cuda.grad)
177
+ relative_loss("wi:", wi_ref.grad, wi_cuda.grad)
178
+
179
+ def test_frostbite():
180
+ normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
181
+ normals_ref = normals_cuda.clone().detach().requires_grad_(True)
182
+ wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
183
+ wi_ref = wi_cuda.clone().detach().requires_grad_(True)
184
+ wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
185
+ wo_ref = wo_cuda.clone().detach().requires_grad_(True)
186
+ rough_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
187
+ rough_ref = rough_cuda.clone().detach().requires_grad_(True)
188
+ target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')
189
+
190
+ ref = ru.frostbite_diffuse(normals_ref, wi_ref, wo_ref, rough_ref, use_python=True)
191
+ ref_loss = torch.nn.MSELoss()(ref, target)
192
+ ref_loss.backward()
193
+
194
+ cuda = ru.frostbite_diffuse(normals_cuda, wi_cuda, wo_cuda, rough_cuda)
195
+ cuda_loss = torch.nn.MSELoss()(cuda, target)
196
+ cuda_loss.backward()
197
+
198
+ print("-------------------------------------------------------------")
199
+ print(" Frostbite")
200
+ print("-------------------------------------------------------------")
201
+ relative_loss("res:", ref, cuda)
202
+ relative_loss("nrm:", normals_ref.grad, normals_cuda.grad)
203
+ relative_loss("wo:", wo_ref.grad, wo_cuda.grad)
204
+ relative_loss("wi:", wi_ref.grad, wi_cuda.grad)
205
+ relative_loss("rough:", rough_ref.grad, rough_cuda.grad)
206
+
207
+ def test_pbr_specular():
208
+ col_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
209
+ col_ref = col_cuda.clone().detach().requires_grad_(True)
210
+ nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
211
+ nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
212
+ wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
213
+ wi_ref = wi_cuda.clone().detach().requires_grad_(True)
214
+ wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
215
+ wo_ref = wo_cuda.clone().detach().requires_grad_(True)
216
+ alpha_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
217
+ alpha_ref = alpha_cuda.clone().detach().requires_grad_(True)
218
+ target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')
219
+
220
+ ref = ru.pbr_specular(col_ref, nrm_ref, wo_ref, wi_ref, alpha_ref, use_python=True)
221
+ ref_loss = torch.nn.MSELoss()(ref, target)
222
+ ref_loss.backward()
223
+
224
+ cuda = ru.pbr_specular(col_cuda, nrm_cuda, wo_cuda, wi_cuda, alpha_cuda)
225
+ cuda_loss = torch.nn.MSELoss()(cuda, target)
226
+ cuda_loss.backward()
227
+
228
+ print("-------------------------------------------------------------")
229
+ print(" Pbr specular")
230
+ print("-------------------------------------------------------------")
231
+
232
+ relative_loss("res:", ref, cuda)
233
+ if col_ref.grad is not None:
234
+ relative_loss("col:", col_ref.grad, col_cuda.grad)
235
+ if nrm_ref.grad is not None:
236
+ relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad)
237
+ if wi_ref.grad is not None:
238
+ relative_loss("wi:", wi_ref.grad, wi_cuda.grad)
239
+ if wo_ref.grad is not None:
240
+ relative_loss("wo:", wo_ref.grad, wo_cuda.grad)
241
+ if alpha_ref.grad is not None:
242
+ relative_loss("alpha:", alpha_ref.grad, alpha_cuda.grad)
243
+
244
+ def test_pbr_bsdf(bsdf):
245
+ kd_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
246
+ kd_ref = kd_cuda.clone().detach().requires_grad_(True)
247
+ arm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
248
+ arm_ref = arm_cuda.clone().detach().requires_grad_(True)
249
+ pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
250
+ pos_ref = pos_cuda.clone().detach().requires_grad_(True)
251
+ nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
252
+ nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
253
+ view_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
254
+ view_ref = view_cuda.clone().detach().requires_grad_(True)
255
+ light_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
256
+ light_ref = light_cuda.clone().detach().requires_grad_(True)
257
+ target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')
258
+
259
+ ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True, bsdf=bsdf)
260
+ ref_loss = torch.nn.MSELoss()(ref, target)
261
+ ref_loss.backward()
262
+
263
+ cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda, bsdf=bsdf)
264
+ cuda_loss = torch.nn.MSELoss()(cuda, target)
265
+ cuda_loss.backward()
266
+
267
+ print("-------------------------------------------------------------")
268
+ print(" Pbr BSDF")
269
+ print("-------------------------------------------------------------")
270
+
271
+ relative_loss("res:", ref, cuda)
272
+ if kd_ref.grad is not None:
273
+ relative_loss("kd:", kd_ref.grad, kd_cuda.grad)
274
+ if arm_ref.grad is not None:
275
+ relative_loss("arm:", arm_ref.grad, arm_cuda.grad)
276
+ if pos_ref.grad is not None:
277
+ relative_loss("pos:", pos_ref.grad, pos_cuda.grad)
278
+ if nrm_ref.grad is not None:
279
+ relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad)
280
+ if view_ref.grad is not None:
281
+ relative_loss("view:", view_ref.grad, view_cuda.grad)
282
+ if light_ref.grad is not None:
283
+ relative_loss("light:", light_ref.grad, light_cuda.grad)
284
+
285
+ test_normal()
286
+
287
+ test_schlick()
288
+ test_ndf_ggx()
289
+ test_lambda_ggx()
290
+ test_masking_smith()
291
+
292
+ test_lambert()
293
+ test_frostbite()
294
+ test_pbr_specular()
295
+ test_pbr_bsdf('lambert')
296
+ test_pbr_bsdf('frostbite')
video3d/render/renderutils/tests/test_cubemap.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ import torch
11
+
12
+ import os
13
+ import sys
14
+ sys.path.insert(0, os.path.join(sys.path[0], '../..'))
15
+ import renderutils as ru
16
+
17
+ RES = 4
18
+ DTYPE = torch.float32
19
+
20
+ def relative_loss(name, ref, cuda):
21
+ ref = ref.float()
22
+ cuda = cuda.float()
23
+ print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item())
24
+
25
+ def test_cubemap():
26
+ cubemap_cuda = torch.rand(6, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
27
+ cubemap_ref = cubemap_cuda.clone().detach().requires_grad_(True)
28
+ weights = torch.rand(3, 3, 1, dtype=DTYPE, device='cuda')
29
+ target = torch.rand(6, RES, RES, 3, dtype=DTYPE, device='cuda')
30
+
31
+ ref = ru.filter_cubemap(cubemap_ref, weights, use_python=True)
32
+ ref_loss = torch.nn.MSELoss()(ref, target)
33
+ ref_loss.backward()
34
+
35
+ cuda = ru.filter_cubemap(cubemap_cuda, weights, use_python=False)
36
+ cuda_loss = torch.nn.MSELoss()(cuda, target)
37
+ cuda_loss.backward()
38
+
39
+ print("-------------------------------------------------------------")
40
+ print(" Cubemap:")
41
+ print("-------------------------------------------------------------")
42
+
43
+ relative_loss("flt:", ref, cuda)
44
+ relative_loss("cubemap:", cubemap_ref.grad, cubemap_cuda.grad)
45
+
46
+
47
+ test_cubemap()