HaolinLiu commited on
Commit
cc9780d
·
1 Parent(s): 78c29b6

first commit of codes and update readme.md

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +72 -1
  2. configs/config_utils.py +70 -0
  3. configs/finetune_triplane_diffusion.yaml +67 -0
  4. configs/train_triplane_diffusion.yaml +64 -0
  5. configs/train_triplane_vae.yaml +30 -0
  6. data/download_preprocess_data_here +1 -0
  7. datasets/SingleView_dataset.py +453 -0
  8. datasets/__init__.py +91 -0
  9. datasets/taxonomy.py +111 -0
  10. datasets/transforms.py +180 -0
  11. engine/engine_triplane_dm.py +136 -0
  12. engine/engine_triplane_vae.py +185 -0
  13. evaluation/dist_eval.sh +16 -0
  14. evaluation/evaluate_object_reconstruction.py +239 -0
  15. evaluation/pyTorchChamferDistance/.gitignore +3 -0
  16. evaluation/pyTorchChamferDistance/LICENSE.md +21 -0
  17. evaluation/pyTorchChamferDistance/README.md +23 -0
  18. evaluation/pyTorchChamferDistance/__init__.py +0 -0
  19. evaluation/pyTorchChamferDistance/chamfer_distance/__init__.py +1 -0
  20. evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp +185 -0
  21. evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu +209 -0
  22. evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.py +58 -0
  23. finetune_diffusion.sh +18 -0
  24. models/TriplaneVAE.py +94 -0
  25. models/Triplane_Diffusion.py +190 -0
  26. models/__init__.py +20 -0
  27. models/modules/PointEMB.py +34 -0
  28. models/modules/Positional_Embedding.py +15 -0
  29. models/modules/__init__.py +5 -0
  30. models/modules/decoder.py +121 -0
  31. models/modules/diffusion_sampler.py +89 -0
  32. models/modules/encoder.py +235 -0
  33. models/modules/image_sampler.py +1046 -0
  34. models/modules/parpoints_encoder.py +168 -0
  35. models/modules/point_transformer.py +442 -0
  36. models/modules/pointnet2_backbone.py +188 -0
  37. models/modules/resnet_block.py +47 -0
  38. models/modules/resunet.py +440 -0
  39. models/modules/unet.py +304 -0
  40. models/modules/utils.py +25 -0
  41. output/put_checkpoints_here +1 -0
  42. process_scripts/augment_arkit_partial_point.py +64 -0
  43. process_scripts/augment_synthetic_partial_points.py +64 -0
  44. process_scripts/dist_export_triplane_features.sh +8 -0
  45. process_scripts/dist_extract_vit.sh +6 -0
  46. process_scripts/export_triplane_features.py +122 -0
  47. process_scripts/extract_img_vit_features.py +73 -0
  48. process_scripts/generate_split_for_arkit.py +102 -0
  49. process_scripts/generate_split_for_synthetic_data.py +78 -0
  50. process_scripts/unzip_all_data.py +38 -0
README.md CHANGED
@@ -8,4 +8,75 @@ Repository of LASA: Instance Reconstruction from Real Scans using A Large-scale
8
  ![292080628-a4b020dc-2673-4b1b-bfa6-ec9422625624](https://github.com/GAP-LAB-CUHK-SZ/LASA/assets/40767265/7a0dfc11-5454-428f-bfba-e8cd0d0af96e)
9
  ![292080638-324bbef9-c93b-4d96-b814-120204374383](https://github.com/GAP-LAB-CUHK-SZ/LASA/assets/40767265/ee07691a-8767-4701-9a32-19a70e0e240a)
10
 
11
- #### Codes and dataset will be released soon!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  ![292080628-a4b020dc-2673-4b1b-bfa6-ec9422625624](https://github.com/GAP-LAB-CUHK-SZ/LASA/assets/40767265/7a0dfc11-5454-428f-bfba-e8cd0d0af96e)
9
  ![292080638-324bbef9-c93b-4d96-b814-120204374383](https://github.com/GAP-LAB-CUHK-SZ/LASA/assets/40767265/ee07691a-8767-4701-9a32-19a70e0e240a)
10
 
11
+ ## Dataset
12
+ Complete raw data will be released soon.
13
+
14
+ ## Download preprocessed data and processing
15
+ Download the preprocessed data from <a href="https://pan.baidu.com/s/1tCEGYBH0DEh8NcAURTnMbw?pwd=62ux">
16
+ BaiduYun (code: 62ux)<a/>. (These data will be updated as cleaning process continues.) Put all the downloaded data under LASA, unzip the align_mat_all.zip mannually.
17
+ You can choose to the the script ./process_scripts/unzip_all_data to unzip all the data in occ_data and other_data by following commands:
18
+ ```angular2html
19
+ cd process_scripts
20
+ python unzip_all_data.py --unzip_occ --unzip_other
21
+ ```
22
+ Run the following commands to generate augmented partial point cloud for synthetic dataset and LASA dataset
23
+ ```angular2html
24
+ cd process_scripts
25
+ python augment_arkit_partial_point.py --cat arkit_chair arkit_stool ...
26
+ python augment_synthetic_partial_point.py --cat 03001627 future_chair ABO_chair ...
27
+ ```
28
+ Run the following command to extract image features
29
+ ```angular2html
30
+ cd process_scripts
31
+ bash dist_extract_vit.sh
32
+ ```
33
+ Finally, run the following command to generate train/val splits:
34
+ ```angular2html
35
+ cd process_scripts
36
+ python generate_split_for_arkit --cat arkit_chair arkit_stool ...
37
+ python generate_split_for_synthetic_data.py --cat 03001627 future_chair ABO_chair ...
38
+ ```
39
+
40
+ ## Evaluation
41
+ Download the pretrained weight for chair from <a href="https://pan.baidu.com/s/10liUOaC4CXGn7bN6SQkZsw?pwd=hlf9"> chair_checkpoint.<a/> (code:hlf9).
42
+ Put these folder under LASA/output.<br> The ae folder stores the VAE weight, dm folder stores the diffusion model trained on synthetic data.
43
+ finetune_dm folder stores the diffusion model finetuned on LASA dataset.
44
+ Run the following commands to evaluate and extract the mesh:
45
+ ```angular2html
46
+ cd evaluation
47
+ bash dist_eval.sh
48
+ ```
49
+ The category entries are the sub-category from arkit scenes, please see ./datasets/taxonomy.py about how they are defined.
50
+ For example, if you want to evaluate on LASA's chair, category should contain both arkit_chair and arkit_stool.
51
+ make sure the --ae-pth and --dm-pth entry points to the correct checkpoint path. If you are evaluating on LASA,
52
+ make sure the --dm-pth points to the finetuned weight in the ./output/finetune_dm folder. The result will be saved
53
+ under ./output_result.
54
+
55
+ ## Training
56
+ Run the <strong>train_VAE.sh</strong> to train the VAE model. If you aims to train on one category, just specify one category from <strong> chair,
57
+ cabinet, table, sofa, bed, shelf</strong>. Inputting <strong>all</strong> will train on all categories. Makes sure to download and preprocess all
58
+ the required sub-category data. The sub-category arrangement can be found in ./datasets/taxonomy.py <br>
59
+ After finish training the VAE model, run the following commands to pre-extract the VAE features for every object:
60
+ ```angular2html
61
+ cd process_scripts
62
+ bash dist_export_triplane_features.sh
63
+ ```
64
+ Then, we can start training the diffusion model on the synthetic dataset by running the <strong>train_diffusion.sh</strong>.<br>
65
+ Finally, finetune the diffusion model on LASA dataset by running <strong> finetune_diffusion.sh</strong>. <br><br>
66
+
67
+ Early stopping is used by mannualy stopping the training by 150 epochs and 500 epochs for training VAE model and diffusion model respetively.
68
+ All experiments in the paper are conducted on 8 A100 GPUs with batch size = 22.
69
+ ## TODO
70
+
71
+ - [ ] Object Detection Code
72
+ - [ ] Code for Demo on both arkitscene and in the wild data
73
+
74
+ ## Citation
75
+ ```
76
+ @article{liu2023lasa,
77
+ title={LASA: Instance Reconstruction from Real Scans using A Large-scale Aligned Shape Annotation Dataset},
78
+ author={Liu, Haolin and Ye, Chongjie and Nie, Yinyu and He, Yingfan and Han, Xiaoguang},
79
+ journal={arXiv preprint arXiv:2312.12418},
80
+ year={2023}
81
+ }
82
+ ```
configs/config_utils.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import logging
4
+ from datetime import datetime
5
+
6
+ def update_recursive(dict1, dict2):
7
+ ''' Update two config dictionaries recursively.
8
+
9
+ Args:
10
+ dict1 (dict): first dictionary to be updated
11
+ dict2 (dict): second dictionary which entries should be used
12
+
13
+ '''
14
+ for k, v in dict2.items():
15
+ if k not in dict1:
16
+ dict1[k] = dict()
17
+ if isinstance(v, dict):
18
+ update_recursive(dict1[k], v)
19
+ else:
20
+ dict1[k] = v
21
+
22
+ class CONFIG(object):
23
+ '''
24
+ Stores all configures
25
+ '''
26
+ def __init__(self, input=None):
27
+ '''
28
+ Loads config file
29
+ :param path (str): path to config file
30
+ :return:
31
+ '''
32
+ self.config = self.read_to_dict(input)
33
+
34
+ def read_to_dict(self, input):
35
+ if not input:
36
+ return dict()
37
+ if isinstance(input, str) and os.path.isfile(input):
38
+ if input.endswith('yaml'):
39
+ with open(input, 'r') as f:
40
+ config = yaml.load(f, Loader=yaml.FullLoader)
41
+ else:
42
+ ValueError('Config file should be with the format of *.yaml')
43
+ elif isinstance(input, dict):
44
+ config = input
45
+ else:
46
+ raise ValueError('Unrecognized input type (i.e. not *.yaml file nor dict).')
47
+
48
+ return config
49
+
50
+ def update_config(self, *args, **kwargs):
51
+ '''
52
+ update config and corresponding logger setting
53
+ :param input: dict settings add to config file
54
+ :return:
55
+ '''
56
+ cfg1 = dict()
57
+ for item in args:
58
+ cfg1.update(self.read_to_dict(item))
59
+
60
+ cfg2 = self.read_to_dict(kwargs)
61
+
62
+ new_cfg = {**cfg1, **cfg2}
63
+
64
+ update_recursive(self.config, new_cfg)
65
+ # when update config file, the corresponding logger should also be updated.
66
+ self.__update_logger()
67
+
68
+ def write_config(self,save_path):
69
+ with open(save_path, 'w') as file:
70
+ yaml.dump(self.config, file, default_flow_style = False)
configs/finetune_triplane_diffusion.yaml ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ ae: #ae model is loaded to
3
+ type: TriVAE
4
+ point_emb_dim: 48
5
+ padding: 0.1
6
+ encoder:
7
+ plane_reso: 128
8
+ plane_latent_dim: 32
9
+ latent_dim: 32
10
+ unet:
11
+ depth: 4
12
+ merge_mode: concat
13
+ start_filts: 32
14
+ output_dim: 64
15
+ decoder:
16
+ plane_reso: 128
17
+ latent_dim: 32
18
+ n_blocks: 5
19
+ query_emb_dim: 48
20
+ hidden_dim: 128
21
+ unet:
22
+ depth: 4
23
+ merge_mode: concat
24
+ start_filts: 64
25
+ output_dim: 32
26
+ dm:
27
+ type: triplane_diff_multiimg_cond
28
+ backbone: resunet_multiimg_direct_atten
29
+ diff_reso: 64
30
+ input_channel: 32
31
+ output_channel: 32
32
+ triplane_padding: 0.1 #should be consistent with padding in ae
33
+
34
+ use_par: True
35
+ par_channel: 32
36
+ par_emb_dim: 48
37
+ norm: "batch"
38
+ img_in_channels: 1280
39
+ vit_reso: 16
40
+ use_cat_embedding: ???
41
+ block_type: multiview_local
42
+ par_point_encoder:
43
+ plane_reso: 64
44
+ plane_latent_dim: 32
45
+ n_blocks: 5
46
+ unet:
47
+ depth: 3
48
+ merge_mode: concat
49
+ start_filts: 32
50
+ output_dim: 32
51
+ criterion:
52
+ type: EDMLoss_MultiImgCond
53
+ use_par: True
54
+ dataset:
55
+ type: Occ_Par_MultiImg_Finetune
56
+ data_path: ???
57
+ surface_size: 20000
58
+ par_pc_size: 2048
59
+ load_proj_mat: True
60
+ load_image: True
61
+ par_point_aug: 0.5
62
+ par_prefix: "aug7_"
63
+ keyword: lowres #use lowres arkitscene or highres to train, lowres scene is more user accessible
64
+ jitter_partial_pretrain: 0.02
65
+ jitter_partial_finetune: 0.02
66
+ jitter_partial_val: 0.0
67
+ use_pretrain_data: False
configs/train_triplane_diffusion.yaml ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ ae: #ae model is loaded to
3
+ type: TriVAE
4
+ point_emb_dim: 48
5
+ padding: 0.1
6
+ encoder:
7
+ plane_reso: 128
8
+ plane_latent_dim: 32
9
+ latent_dim: 32
10
+ unet:
11
+ depth: 4
12
+ merge_mode: concat
13
+ start_filts: 32
14
+ output_dim: 64
15
+ decoder:
16
+ plane_reso: 128
17
+ latent_dim: 32
18
+ n_blocks: 5
19
+ query_emb_dim: 48
20
+ hidden_dim: 128
21
+ unet:
22
+ depth: 4
23
+ merge_mode: concat
24
+ start_filts: 64
25
+ output_dim: 32
26
+ dm:
27
+ type: triplane_diff_multiimg_cond
28
+ backbone: resunet_multiimg_direct_atten
29
+ diff_reso: 64
30
+ input_channel: 32
31
+ output_channel: 32
32
+ triplane_padding: 0.1 #should be consistent with padding in ae
33
+
34
+ use_par: True
35
+ par_channel: 32
36
+ par_emb_dim: 48
37
+ norm: "batch"
38
+ img_in_channels: 1280
39
+ vit_reso: 16
40
+ use_cat_embedding: ???
41
+ block_type: multiview_local
42
+ par_point_encoder:
43
+ plane_reso: 64
44
+ plane_latent_dim: 32
45
+ n_blocks: 5
46
+ unet:
47
+ depth: 3
48
+ merge_mode: concat
49
+ start_filts: 32
50
+ output_dim: 32
51
+ criterion:
52
+ type: EDMLoss_MultiImgCond
53
+ use_par: True
54
+ dataset:
55
+ type: Occ_Par_MultiImg
56
+ data_path: ???
57
+ surface_size: 20000
58
+ par_pc_size: 2048
59
+ load_proj_mat: True
60
+ load_image: True
61
+ par_point_aug: 0.5
62
+ par_prefix: "aug7_" # prefix of the filenames of the partial point cloud
63
+ jitter_partial_train: 0.02
64
+ jitter_partial_val: 0.0
configs/train_triplane_vae.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ type: TriVAE
3
+ point_emb_dim: 48
4
+ padding: 0.1
5
+ encoder:
6
+ plane_reso: 128
7
+ plane_latent_dim: 32
8
+ latent_dim: 32
9
+ unet:
10
+ depth: 4
11
+ merge_mode: concat
12
+ start_filts: 32
13
+ output_dim: 64
14
+ decoder:
15
+ plane_reso: 128
16
+ latent_dim: 32
17
+ n_blocks: 5
18
+ query_emb_dim: 48
19
+ hidden_dim: 128
20
+ unet:
21
+ depth: 4
22
+ merge_mode: concat
23
+ start_filts: 64
24
+ output_dim: 32
25
+ dataset:
26
+ type: Occ
27
+ category: chair
28
+ data_path: ???
29
+ surface_size: 20000
30
+ num_samples: 2048
data/download_preprocess_data_here ADDED
@@ -0,0 +1 @@
 
 
1
+ 1
datasets/SingleView_dataset.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import random
4
+
5
+ import yaml
6
+
7
+ import torch
8
+ from torch.utils import data
9
+
10
+ import numpy as np
11
+ import json
12
+
13
+ from PIL import Image
14
+
15
+ import h5py
16
+ import torch.distributed as dist
17
+ import open3d as o3d
18
+ o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error)
19
+ import pickle as p
20
+ import time
21
+ import cv2
22
+ from torchvision import transforms
23
+ import copy
24
+ from datasets.taxonomy import category_map_from_synthetic as category_ids
25
+ class Object_Occ(data.Dataset):
26
+ def __init__(self, dataset_folder, split, categories=['03001627', "future_chair", 'ABO_chair'], transform=None,
27
+ sampling=True,
28
+ num_samples=4096, return_surface=True, surface_sampling=True, surface_size=2048, replica=16):
29
+
30
+ self.pc_size = surface_size
31
+
32
+ self.transform = transform
33
+ self.num_samples = num_samples
34
+ self.sampling = sampling
35
+ self.split = split
36
+
37
+ self.dataset_folder = dataset_folder
38
+ self.return_surface = return_surface
39
+ self.surface_sampling = surface_sampling
40
+
41
+ self.dataset_folder = dataset_folder
42
+ self.point_folder = os.path.join(self.dataset_folder, 'occ_data')
43
+ self.mesh_folder = os.path.join(self.dataset_folder, 'other_data')
44
+
45
+ if categories is None:
46
+ categories = os.listdir(self.point_folder)
47
+ categories = [c for c in categories if
48
+ os.path.isdir(os.path.join(self.point_folder, c)) and c.startswith('0')]
49
+ categories.sort()
50
+
51
+ print(categories)
52
+
53
+ self.models = []
54
+ for c_idx, c in enumerate(categories):
55
+ subpath = os.path.join(self.point_folder, c)
56
+ print(subpath)
57
+ assert os.path.isdir(subpath)
58
+
59
+ split_file = os.path.join(subpath, split + '.lst')
60
+ with open(split_file, 'r') as f:
61
+ models_c = f.readlines()
62
+ models_c = [item.rstrip('\n') for item in models_c]
63
+
64
+ for m in models_c[:]:
65
+ if len(m)<=1:
66
+ continue
67
+ if m.endswith('.npz'):
68
+ model_id = m[:-4]
69
+ else:
70
+ model_id = m
71
+ self.models.append({
72
+ 'category': c, 'model': model_id
73
+ })
74
+ self.replica = replica
75
+
76
+ def __getitem__(self, idx):
77
+ if self.replica >= 1:
78
+ idx = idx % len(self.models)
79
+ else:
80
+ random_segment = random.randint(0, int(1 / self.replica) - 1)
81
+ idx = int(random_segment * self.replica * len(self.models) + idx)
82
+
83
+ category = self.models[idx]['category']
84
+ model = self.models[idx]['model']
85
+
86
+ point_path = os.path.join(self.point_folder, category, model + '.npz')
87
+ # print(point_path)
88
+ try:
89
+ start_t = time.time()
90
+ with np.load(point_path) as data:
91
+ vol_points = data['vol_points']
92
+ vol_label = data['vol_label']
93
+ near_points = data['near_points']
94
+ near_label = data['near_label']
95
+ end_t = time.time()
96
+ # print("loading time %f"%(end_t-start_t))
97
+ except Exception as e:
98
+ print(e)
99
+ print(point_path)
100
+
101
+ with open(point_path.replace('.npz', '.npy'), 'rb') as f:
102
+ scale = np.load(f).item()
103
+ # scale=1.0
104
+
105
+ if self.return_surface:
106
+ pc_path = os.path.join(self.mesh_folder, category, '4_pointcloud', model + '.npz')
107
+ with np.load(pc_path) as data:
108
+ try:
109
+ surface = data['points'].astype(np.float32)
110
+ except:
111
+ print(pc_path,"has problems")
112
+ raise AttributeError
113
+ surface = surface * scale
114
+ if self.surface_sampling:
115
+ ind = np.random.default_rng().choice(surface.shape[0], self.pc_size, replace=False)
116
+ surface = surface[ind]
117
+ surface = torch.from_numpy(surface)
118
+
119
+ if self.sampling:
120
+ '''need to conduct label balancing'''
121
+ vol_ind=np.random.default_rng().choice(vol_points.shape[0], self.num_samples,
122
+ replace=(vol_points.shape[0]<self.num_samples))
123
+ near_ind=np.random.default_rng().choice(near_points.shape[0], self.num_samples,
124
+ replace=(near_points.shape[0]<self.num_samples))
125
+ vol_points=vol_points[vol_ind]
126
+ vol_label=vol_label[vol_ind]
127
+ near_points=near_points[near_ind]
128
+ near_label=near_label[near_ind]
129
+
130
+ vol_points = torch.from_numpy(vol_points)
131
+ vol_label = torch.from_numpy(vol_label).float()
132
+
133
+ if self.split == 'train':
134
+ near_points = torch.from_numpy(near_points)
135
+ near_label = torch.from_numpy(near_label).float()
136
+
137
+ points = torch.cat([vol_points, near_points], dim=0)
138
+ labels = torch.cat([vol_label, near_label], dim=0)
139
+ else:
140
+ points = vol_points
141
+ labels = vol_label
142
+
143
+ tran_mat=np.eye(4)
144
+ if self.transform:
145
+ surface, points, _,_, tran_mat = self.transform(surface, points)
146
+
147
+ data_dict = {
148
+ "points": points,
149
+ "labels": labels,
150
+ "category_ids": category_ids[category],
151
+ "model_id": model,
152
+ "tran_mat":tran_mat,
153
+ "category":category,
154
+ }
155
+ if self.return_surface:
156
+ data_dict["surface"] = surface
157
+
158
+ return data_dict
159
+
160
+ def __len__(self):
161
+ if self.split != 'train':
162
+ return len(self.models)
163
+ else:
164
+ return int(len(self.models) * self.replica)
165
+
166
+ class Object_PartialPoints_MultiImg(data.Dataset):
167
+ def __init__(self, dataset_folder, split, split_filename, categories=['03001627', 'future_chair', 'ABO_chair'],
168
+ transform=None, sampling=True, num_samples=4096,
169
+ return_surface=True, ret_sample=True,surface_sampling=True,
170
+ surface_size=20000,par_pc_size=2048, par_point_aug=None,par_prefix="aug7_",
171
+ load_proj_mat=False,load_image=False,load_org_img=False,max_img_length=5,load_triplane=True,replica=2,
172
+ eval_multiview=False,scene_id=None,num_objects=-1):
173
+
174
+ self.surface_size = surface_size
175
+ self.par_pc_size=par_pc_size
176
+ self.transform = transform
177
+ self.num_samples = num_samples
178
+ self.sampling = sampling
179
+ self.split = split
180
+ self.par_point_aug=par_point_aug
181
+ self.par_prefix=par_prefix
182
+
183
+ self.dataset_folder = dataset_folder
184
+ self.return_surface = return_surface
185
+ self.ret_sample=ret_sample
186
+ self.surface_sampling = surface_sampling
187
+ self.load_proj_mat=load_proj_mat
188
+ self.load_img=load_image
189
+ self.load_org_img=load_org_img
190
+ self.load_triplane=load_triplane
191
+ self.max_img_length=max_img_length
192
+ self.eval_multiview=eval_multiview
193
+
194
+ self.dataset_folder = dataset_folder
195
+ self.point_folder = os.path.join(self.dataset_folder, 'occ_data')
196
+ self.mesh_folder = os.path.join(self.dataset_folder, 'other_data')
197
+
198
+ if scene_id is not None:
199
+ scene_model_map_path=os.path.join(self.dataset_folder,"modelid_in_sceneid.json")
200
+ with open(scene_model_map_path,'r') as f:
201
+ scene_model_map=json.load(f)
202
+ valid_modelid=scene_model_map[scene_id]
203
+
204
+ if categories is None:
205
+ categories = os.listdir(self.point_folder)
206
+ categories = [c for c in categories if
207
+ os.path.isdir(os.path.join(self.point_folder, c)) and c.startswith('0')]
208
+ categories.sort()
209
+
210
+ print(categories)
211
+ self.models = []
212
+ self.model_images_names = {}
213
+ for c_idx, c in enumerate(categories):
214
+ cat_count=0
215
+ subpath = os.path.join(self.point_folder, c)
216
+ print(subpath)
217
+ assert os.path.isdir(subpath)
218
+
219
+ split_file = os.path.join(subpath, split_filename)
220
+ with open(split_file, 'r') as f:
221
+ splits = json.load(f)
222
+ for item in splits:
223
+ # print(item)
224
+ model_id = item['model_id']
225
+ if scene_id is not None and model_id not in valid_modelid:
226
+ continue
227
+ image_filenames = item['image_filenames']
228
+ partial_filenames = item['partial_filenames']
229
+ if len(image_filenames)==0 or len(partial_filenames)==0:
230
+ continue
231
+ self.model_images_names[model_id] = image_filenames
232
+ if split=="train":
233
+ self.models += [
234
+ {'category': c, 'model': model_id, "partial_filenames": partial_filenames,
235
+ "image_filenames": image_filenames}
236
+ ]
237
+ else:
238
+ if self.eval_multiview:
239
+ for length in range(0,len(image_filenames)):
240
+ self.models+=[
241
+ {'category': c, 'model': model_id, "partial_filenames": partial_filenames[0:1],
242
+ "image_filenames": image_filenames[0:length+1]}
243
+ ]
244
+ self.models += [
245
+ {'category': c, 'model': model_id, "partial_filenames": partial_filenames[0:1],
246
+ "image_filenames": image_filenames}
247
+ ]
248
+ if num_objects!=-1:
249
+ indexes=np.linspace(0,len(self.models)-1,num=num_objects).astype(np.int32)
250
+ self.models = [self.models[i] for i in indexes]
251
+
252
+ self.replica = replica
253
+
254
+ def load_samples(self,point_path):
255
+ try:
256
+ start_t = time.time()
257
+ with np.load(point_path) as data:
258
+ vol_points = data['vol_points']
259
+ vol_label = data['vol_label']
260
+ near_points = data['near_points']
261
+ near_label = data['near_label']
262
+ end_t = time.time()
263
+ # print("reading time %f"%(end_t-start_t))
264
+ except Exception as e:
265
+ print(e)
266
+ print(point_path)
267
+ return vol_points,vol_label,near_points,near_label
268
+
269
+ def load_surface(self,surface_path,scale):
270
+ with np.load(surface_path) as data:
271
+ surface = data['points'].astype(np.float32)
272
+ surface = surface * scale
273
+ if self.surface_sampling:
274
+ ind = np.random.default_rng().choice(surface.shape[0], self.surface_size, replace=False)
275
+ surface = surface[ind]
276
+ surface = torch.from_numpy(surface).float()
277
+ return surface
278
+
279
+ def load_par_points(self,partial_path,scale):
280
+ # print(partial_path)
281
+ par_point_o3d = o3d.io.read_point_cloud(partial_path)
282
+ par_points = np.asarray(par_point_o3d.points)
283
+ par_points = par_points * scale
284
+ replace = par_points.shape[0] < self.par_pc_size
285
+ ind = np.random.default_rng().choice(par_points.shape[0], self.par_pc_size, replace=replace)
286
+ par_points = par_points[ind]
287
+ par_points = torch.from_numpy(par_points).float()
288
+ return par_points
289
+
290
+ def process_samples(self,vol_points,vol_label,near_points,near_label):
291
+ if self.sampling:
292
+ ind = np.random.default_rng().choice(vol_points.shape[0], self.num_samples, replace=False)
293
+ vol_points = vol_points[ind]
294
+ vol_label = vol_label[ind]
295
+
296
+ ind = np.random.default_rng().choice(near_points.shape[0], self.num_samples, replace=False)
297
+ near_points = near_points[ind]
298
+ near_label = near_label[ind]
299
+ vol_points = torch.from_numpy(vol_points)
300
+ vol_label = torch.from_numpy(vol_label).float()
301
+ if self.split == 'train':
302
+ near_points = torch.from_numpy(near_points)
303
+ near_label = torch.from_numpy(near_label).float()
304
+
305
+ points = torch.cat([vol_points, near_points], dim=0)
306
+ labels = torch.cat([vol_label, near_label], dim=0)
307
+ else:
308
+ ind = np.random.default_rng().choice(vol_points.shape[0], 100000, replace=False)
309
+ points = vol_points[ind]
310
+ labels = vol_label[ind]
311
+ return points,labels
312
+
313
+ def __getitem__(self, idx):
314
+ if self.replica >= 1:
315
+ idx = idx % len(self.models)
316
+ else:
317
+ random_segment = random.randint(0, int(1 / self.replica) - 1)
318
+ idx = int(random_segment * self.replica * len(self.models) + idx)
319
+ category = self.models[idx]['category']
320
+ model = self.models[idx]['model']
321
+ #image_filenames = self.model_images_names[model]
322
+ image_filenames = self.models[idx]["image_filenames"]
323
+ if self.split=="train":
324
+ n_frames = np.random.randint(min(2,len(image_filenames)), min(len(image_filenames) + 1, self.max_img_length + 1))
325
+ img_indexes = np.random.choice(len(image_filenames), n_frames,
326
+ replace=(n_frames > len(image_filenames))).tolist()
327
+ else:
328
+ if self.eval_multiview:
329
+ '''use all images'''
330
+ n_frames=len(image_filenames)
331
+ img_indexes=[i for i in range(n_frames)]
332
+ else:
333
+ n_frames = min(len(image_filenames),self.max_img_length)
334
+ img_indexes=np.linspace(start=0,stop=len(image_filenames)-1,num=n_frames).astype(np.int32)
335
+
336
+ partial_filenames = self.models[idx]['partial_filenames']
337
+ par_index = np.random.choice(len(partial_filenames), 1)[0]
338
+ partial_name = partial_filenames[par_index]
339
+
340
+ vol_points,vol_label,near_points,near_label=None,None,None,None
341
+ points,labels=None,None
342
+ point_path = os.path.join(self.point_folder, category, model + '.npz')
343
+ if self.ret_sample:
344
+ vol_points,vol_label,near_points,near_label=self.load_samples(point_path)
345
+ points,labels = self.process_samples(vol_points, vol_label, near_points,near_label)
346
+
347
+ with open(point_path.replace('.npz', '.npy'), 'rb') as f:
348
+ scale = np.load(f).item()
349
+
350
+ surface=None
351
+ pc_path = os.path.join(self.mesh_folder, category, '4_pointcloud', model + '.npz')
352
+ if self.return_surface:
353
+ surface=self.load_surface(pc_path,scale)
354
+
355
+ partial_path = os.path.join(self.mesh_folder, category, "5_partial_points", model, partial_name)
356
+ if self.par_point_aug is not None and random.random()<self.par_point_aug: #add augmentation
357
+ par_aug_path=os.path.join(self.mesh_folder, category, "5_partial_points", model, self.par_prefix+partial_name)
358
+ #print(par_aug_path,os.path.exists(par_aug_path))
359
+ if os.path.exists(par_aug_path):
360
+ partial_path=par_aug_path
361
+ else:
362
+ raise FileNotFoundError
363
+ par_points=self.load_par_points(partial_path,scale)
364
+
365
+ image_list=[]
366
+ valid_frames=[]
367
+ image_namelist=[]
368
+ if self.load_img:
369
+ for img_index in img_indexes:
370
+ image_name=image_filenames[img_index]
371
+ image_feat_path=os.path.join(self.mesh_folder,category,"7_img_features",model,image_name[:-4]+'.npz')
372
+ image=np.load(image_feat_path)["img_features"]
373
+ image_list.append(torch.from_numpy(image).float())
374
+ valid_frames.append(True)
375
+ image_namelist.append(image_name)
376
+ while len(image_list)<self.max_img_length:
377
+ image_list.append(torch.from_numpy(np.zeros(image_list[0].shape).astype(np.float32)).float())
378
+ valid_frames.append(False)
379
+ org_img_list=[]
380
+ if self.load_org_img:
381
+ for img_index in img_indexes:
382
+ image_name = image_filenames[img_index]
383
+ image_path = os.path.join(self.mesh_folder, category, "6_images", model,
384
+ image_name)
385
+ org_image = cv2.imread(image_path)
386
+ org_image = cv2.resize(org_image,dsize=(224,224),interpolation=cv2.INTER_LINEAR)
387
+ org_img_list.append(org_image)
388
+
389
+ proj_mat=None
390
+ proj_mat_list=[]
391
+ if self.load_proj_mat:
392
+ for img_index in img_indexes:
393
+ image_name = image_filenames[img_index]
394
+ proj_mat_path = os.path.join(self.mesh_folder, category, "8_proj_matrix", model, image_name[:-4]+".npy")
395
+ proj_mat=np.load(proj_mat_path)
396
+ proj_mat_list.append(proj_mat)
397
+ while len(proj_mat_list)<self.max_img_length:
398
+ proj_mat_list.append(np.eye(4))
399
+ tran_mat=None
400
+ if self.load_triplane:
401
+ triplane_folder=os.path.join(self.mesh_folder,category,'9_triplane_kl25_64',model)
402
+ triplane_list=os.listdir(triplane_folder)
403
+ select_index=np.random.randint(0,len(triplane_list))
404
+ triplane_path=os.path.join(triplane_folder,triplane_list[select_index])
405
+ #triplane_path=os.path.join(triplane_folder,"triplane_feat_0.npz")
406
+ triplane_content=np.load(triplane_path)
407
+ triplane_mean,triplane_logvar,tran_mat=triplane_content['mean'],triplane_content['logvar'],triplane_content['tran_mat']
408
+ tran_mat=torch.from_numpy(tran_mat).float()
409
+
410
+ if self.transform:
411
+ if not self.load_triplane:
412
+ surface, points, par_points,proj_mat,tran_mat = self.transform(surface, points, par_points,proj_mat_list)
413
+ tran_mat=torch.from_numpy(tran_mat).float()
414
+ else:
415
+ surface, points, par_points, proj_mat = self.transform(surface, points, par_points, proj_mat_list,tran_mat)
416
+
417
+ category_id=category_ids[category]
418
+ one_hot=torch.zeros((6)).float()
419
+ one_hot[category_id]=1.0
420
+ ret_dict = {
421
+ "category_ids": category_ids[category],
422
+ "category":category,
423
+ "category_code":one_hot,
424
+ "model_id": model,
425
+ "partial_name": partial_name[:-4],
426
+ "class_name": category,
427
+ }
428
+ if tran_mat is not None:
429
+ ret_dict["tran_mat"]=tran_mat
430
+ if self.ret_sample:
431
+ ret_dict["points"]=points
432
+ ret_dict["labels"]=labels
433
+ if self.return_surface:
434
+ ret_dict["surface"] = surface
435
+ ret_dict["par_points"] = par_points
436
+ if self.load_img:
437
+ ret_dict["image"] = torch.stack(image_list,dim=0)
438
+ ret_dict["valid_frames"]= torch.tensor(valid_frames).bool()
439
+ if self.load_org_img:
440
+ ret_dict["org_image"]=org_img_list
441
+ ret_dict["image_namelist"]=image_namelist
442
+ if self.load_proj_mat:
443
+ ret_dict["proj_mat"]=torch.stack([torch.from_numpy(mat) for mat in proj_mat_list],dim=0)
444
+ if self.load_triplane:
445
+ ret_dict['triplane_mean']=torch.from_numpy(triplane_mean).float()
446
+ ret_dict['triplane_logvar'] = torch.from_numpy(triplane_logvar).float()
447
+ return ret_dict
448
+
449
+ def __len__(self):
450
+ if self.split != 'train':
451
+ return len(self.models)
452
+ else:
453
+ return int(len(self.models) * self.replica)
datasets/__init__.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.utils.data
2
+
3
+ from .SingleView_dataset import Object_Occ,Object_PartialPoints_MultiImg
4
+ from .transforms import Scale_Shift_Rotate,Aug_with_Tran, Augment_Points
5
+ from .taxonomy import synthetic_category_combined,synthetic_arkit_category_combined,arkit_category
6
+
7
+ def build_object_occ_dataset(split,args):
8
+ transform = Scale_Shift_Rotate(rot_shift_surface=True,use_scale=True,use_whole_scale=True)
9
+ category=args['category']
10
+ #category_list=synthetic_category_combined[category]
11
+ category_list=synthetic_arkit_category_combined[category]
12
+ replica=args['replica']
13
+ if split == "train":
14
+ return Object_Occ(args['data_path'], split=split, categories=category_list,
15
+ transform=transform, sampling=True,
16
+ num_samples=args['num_samples'], return_surface=True,
17
+ surface_sampling=True, surface_size=args['surface_size'],replica=replica)
18
+ elif split == "val":
19
+ return Object_Occ(args['data_path'], split=split,categories=category_list,
20
+ transform=transform, sampling=False,
21
+ num_samples=args['num_samples'], return_surface=True,
22
+ surface_sampling=True,surface_size=args['surface_size'], replica=1)
23
+
24
+ def build_par_multiimg_dataset(split,args):
25
+ #transform=Scale_Shift_Rotate(rot_shift_surface=False,use_scale=False,use_shift=False,use_rot=False) #fix the encoder into cannonical space
26
+ #transform=Scale_Shift_Rotate(rot_shift_surface=True)
27
+ transform=Aug_with_Tran(par_jitter_sigma=args['jitter_partial_train'])
28
+ val_transform=Aug_with_Tran(par_jitter_sigma=args['jitter_partial_val'])
29
+ category=args['category']
30
+ category_list=synthetic_category_combined[category]
31
+ if split == "train":
32
+ return Object_PartialPoints_MultiImg(args['data_path'], split_filename="train_par_img.json",split=split,
33
+ categories=category_list,
34
+ transform=transform, sampling=True,
35
+ num_samples=1024, return_surface=False,ret_sample=False,
36
+ surface_sampling=True, par_pc_size=args['par_pc_size'],surface_size=args['surface_size'],
37
+ load_proj_mat=args['load_proj_mat'],load_image=args['load_image'],load_triplane=True,
38
+ par_prefix=args['par_prefix'],par_point_aug=args['par_point_aug'],replica=args['replica'],
39
+ num_objects=args['num_objects'])
40
+ elif split =="val":
41
+ return Object_PartialPoints_MultiImg(args['data_path'], split_filename="val_par_img.json",split=split,
42
+ categories=category_list,
43
+ transform=val_transform, sampling=False,
44
+ num_samples=1024, return_surface=False,ret_sample=True,
45
+ surface_sampling=True, par_pc_size=args['par_pc_size'],surface_size=args['surface_size'],
46
+ load_proj_mat=args['load_proj_mat'],load_image=args['load_image'],load_triplane=True,
47
+ par_prefix=args['par_prefix'],par_point_aug=None,replica=1)
48
+
49
+ def build_finetune_par_multiimg_dataset(split,args):
50
+ #transform=Scale_Shift_Rotate(rot_shift_surface=False,use_scale=False,use_shift=False,use_rot=False) #fix the encoder into cannonical space
51
+ #transform=Scale_Shift_Rotate(rot_shift_surface=True)
52
+ keyword=args['keyword']
53
+ pretrain_transform=Aug_with_Tran(par_jitter_sigma=args['jitter_partial_pretrain']) #add more noise to partial points
54
+ finetune_transform=Aug_with_Tran(par_jitter_sigma=args['jitter_partial_finetune'])
55
+ val_transform=Aug_with_Tran(par_jitter_sigma=args['jitter_partial_val'])
56
+
57
+ pretrain_cat=synthetic_category_combined[args['category']]
58
+ arkit_cat=arkit_category[args['category']]
59
+ use_pretrain_data=args["use_pretrain_data"]
60
+ #print(arkit_cat,pretrain_cat)
61
+ if split == "train":
62
+ if use_pretrain_data:
63
+ pretrain_dataset=Object_PartialPoints_MultiImg(args['data_path'], split_filename="train_par_img.json",categories=pretrain_cat,
64
+ split=split,transform=pretrain_transform, sampling=True,num_samples=1024, return_surface=False,ret_sample=False,
65
+ surface_sampling=True, par_pc_size=args['par_pc_size'],surface_size=args['surface_size'],
66
+ load_proj_mat=args['load_proj_mat'],load_image=args['load_image'],load_triplane=True,par_point_aug=args['par_point_aug'],
67
+ par_prefix=args['par_prefix'],replica=1)
68
+ finetune_dataset=Object_PartialPoints_MultiImg(args['data_path'], split_filename=keyword+"_train_par_img.json",categories=arkit_cat,
69
+ split=split,transform=finetune_transform, sampling=True,num_samples=1024, return_surface=False,ret_sample=False,
70
+ surface_sampling=True, par_pc_size=args['par_pc_size'],surface_size=args['surface_size'],
71
+ load_proj_mat=args['load_proj_mat'],load_image=args['load_image'],load_triplane=True,par_point_aug=None,replica=args['replica'])
72
+ if use_pretrain_data:
73
+ return torch.utils.data.ConcatDataset([pretrain_dataset,finetune_dataset])
74
+ else:
75
+ return finetune_dataset
76
+ elif split =="val":
77
+ return Object_PartialPoints_MultiImg(args['data_path'], split_filename=keyword+"_val_par_img.json",categories=arkit_cat,split=split,
78
+ transform=val_transform, sampling=False,
79
+ num_samples=1024, return_surface=False,ret_sample=True,
80
+ surface_sampling=True, par_pc_size=args['par_pc_size'],surface_size=args['surface_size'],
81
+ load_proj_mat=args['load_proj_mat'],load_image=args['load_image'],load_triplane=True,par_point_aug=None,replica=1)
82
+
83
+ def build_dataset(split,args):
84
+ if args['type']=="Occ":
85
+ return build_object_occ_dataset(split,args)
86
+ elif args['type']=="Occ_Par_MultiImg":
87
+ return build_par_multiimg_dataset(split,args)
88
+ elif args['type']=="Occ_Par_MultiImg_Finetune":
89
+ return build_finetune_par_multiimg_dataset(split,args)
90
+ else:
91
+ raise NotImplementedError
datasets/taxonomy.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ category_map={
2
+ "bathtub":0,
3
+ "bed":1,
4
+ "cabinet":2,
5
+ "chair":3,
6
+ "dishwasher":4,
7
+ "fireplace":5,
8
+ "oven":6,
9
+ "refrigerator":7,
10
+ "shelf":8,
11
+ "sink":9,
12
+ "sofa":10,
13
+ "stool":11,
14
+ "stove":12,
15
+ "table":13,
16
+ "toilet":14,
17
+ "washer":15
18
+ }
19
+
20
+ category_map_from_synthetic={
21
+ "03001627":0,
22
+ "future_chair":0,
23
+ "ABO_chair":0,
24
+ "arkit_chair":0,
25
+ "future_stool":0,
26
+ "arkit_stool":0,
27
+
28
+ "04256520":1,
29
+ "future_sofa":1,
30
+ "ABO_sofa":1,
31
+ "arkit_sofa":1,
32
+
33
+ "04379243":2,
34
+ "ABO_table":2,
35
+ "future_table":2,
36
+ "arkit_table":2,
37
+
38
+ "02933112":3,
39
+ "future_cabinet":3,
40
+ "ABO_cabinet":3,
41
+ "arkit_cabinet":3,
42
+ "arkit_oven":3,
43
+ "arkit_refrigerator":3,
44
+ "arkit_dishwasher":3,
45
+ "03207941":3,
46
+
47
+ "02818832":4,
48
+ "future_bed":4,
49
+ "ABO_bed":4,
50
+ "arkit_bed":4,
51
+
52
+ "02871439":5,
53
+ "future_shelf":5,
54
+ "ABO_shelf":5,
55
+ "arkit_shelf":5,
56
+
57
+ }
58
+
59
+ synthetic_category_combined={
60
+ "sofa":["future_sofa","ABO_sofa","04256520"],
61
+ "chair":["03001627","future_chair","ABO_chair",
62
+ "future_stool"],
63
+ "table":[
64
+ "04379243",
65
+ "future_table",
66
+ "ABO_table",
67
+ ],
68
+ "cabinet":["02933112","03207941","future_cabinet","ABO_cabinet"],
69
+ "bed":["02818832","future_bed","ABO_bed"],
70
+ "shelf":["02871439","future_shelf","ABO_shelf"],
71
+ "all":["future_sofa","ABO_sofa","04256520",
72
+ "03001627", "future_chair", "ABO_chair",
73
+ "future_stool","04379243","future_table",
74
+ "ABO_table","02933112","03207941","future_cabinet","ABO_cabinet",
75
+ "02818832","future_bed","ABO_bed",
76
+ "02871439","future_shelf","ABO_shelf"
77
+ ]
78
+ }
79
+
80
+ synthetic_arkit_category_combined={
81
+ "sofa":["future_sofa","ABO_sofa","04256520","arkit_sofa"],
82
+ "chair":["03001627","future_chair","ABO_chair",
83
+ "future_stool","arkit_chair","arkit_stool"],
84
+ "table":["04379243","ABO_table","future_table","arkit_table"],
85
+ "cabinet":["02933112","03207941","future_cabinet","ABO_cabinet","arkit_cabinet","arkit_stove","arkit_washer","arkit_dishwasher","arkit_refrigerator","arkit_oven"],
86
+ "bed":["02818832","future_bed","ABO_bed","arkit_bed"],
87
+ "shelf":["02871439","future_shelf","ABO_shelf","arkit_shelf"],
88
+ "all":[
89
+ "future_sofa","ABO_sofa","04256520","arkit_sofa",
90
+ "03001627","future_chair","ABO_chair",
91
+ "future_stool","arkit_chair","arkit_stool",
92
+ "04379243","ABO_table","future_table","arkit_table",
93
+ "02933112","03207941","future_cabinet","ABO_cabinet","arkit_cabinet","arkit_dishwasher","arkit_refrigerator","arkit_oven",
94
+ "02818832","future_bed","ABO_bed","arkit_bed",
95
+ "02871439","future_shelf","ABO_shelf","arkit_shelf"
96
+ ]
97
+ }
98
+
99
+ arkit_category={
100
+ "chair":["arkit_chair","arkit_stool"],
101
+ "sofa":["arkit_sofa"],
102
+ "table":["arkit_table"],
103
+ "cabinet":["arkit_cabinet","arkit_dishwasher","arkit_refrigerator","arkit_oven"],
104
+ "bed":["arkit_bed"],
105
+ "shelf":["arkit_shelf"],
106
+ "all":["arkit_chair","arkit_stool",
107
+ "arkit_sofa","arkit_table",
108
+ "arkit_cabinet","arkit_dishwasher","arkit_refrigerator","arkit_oven",
109
+ "arkit_bed",
110
+ "arkit_shelf"],
111
+ }
datasets/transforms.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ def get_rot_from_yaw(angle):
5
+ cy=torch.cos(angle)
6
+ sy=torch.sin(angle)
7
+ R=torch.tensor([[cy,0,-sy],
8
+ [0,1,0],
9
+ [sy,0,cy]]).float()
10
+ return R
11
+
12
+ class Aug_with_Tran(object):
13
+ def __init__(self,jitter_surface=True,jitter_partial=True,par_jitter_sigma=0.02):
14
+ self.jitter_surface=jitter_surface
15
+ self.jitter_partial=jitter_partial
16
+ self.par_jitter_sigma=par_jitter_sigma
17
+
18
+ def __call__(self,surface,point,par_points,proj_mat,tran_mat):
19
+ if surface is not None:surface=torch.mm(surface,tran_mat[0:3,0:3].transpose(0,1))+tran_mat[0:3,3]
20
+ if point is not None:point=torch.mm(point,tran_mat[0:3,0:3].transpose(0,1))+tran_mat[0:3,3]
21
+ if par_points is not None:par_points=torch.mm(par_points,tran_mat[0:3,0:3].transpose(0,1))+tran_mat[0:3,3]
22
+ if proj_mat is not None:
23
+ '''need to put the augmentation back'''
24
+ inv_tran_mat = np.linalg.inv(tran_mat)
25
+ if isinstance(proj_mat, list):
26
+ for idx, mat in enumerate(proj_mat):
27
+ mat = np.dot(mat, inv_tran_mat)
28
+ proj_mat[idx] = mat
29
+ else:
30
+ proj_mat = np.dot(proj_mat, inv_tran_mat)
31
+
32
+ if self.jitter_surface and surface is not None:
33
+ surface += 0.005 * torch.randn_like(surface)
34
+ surface.clamp_(min=-1, max=1)
35
+ if self.jitter_partial and par_points is not None:
36
+ par_points+=self.par_jitter_sigma * torch.randn_like(par_points)
37
+
38
+
39
+ return surface,point,par_points,proj_mat
40
+
41
+
42
+ #add small augmentation
43
+ class Scale_Shift_Rotate(object):
44
+ def __init__(self, interval=(0.75, 1.25), angle=(-5,5), shift=(-0.1,0.1), use_scale=True,use_whole_scale=False,use_rot=True,
45
+ use_shift=True,jitter=True,jitter_partial=True,par_jitter_sigma=0.02,rot_shift_surface=True):
46
+ assert isinstance(interval, tuple)
47
+ self.interval = interval
48
+ self.angle=angle
49
+ self.shift=shift
50
+ self.jitter = jitter
51
+ self.jitter_partial=jitter_partial
52
+ self.rot_shift_surface=rot_shift_surface
53
+ self.use_scale=use_scale
54
+ self.use_rot=use_rot
55
+ self.use_shift=use_shift
56
+ self.par_jitter_sigma=par_jitter_sigma
57
+ self.use_whole_scale=use_whole_scale
58
+
59
+ def __call__(self, surface, point, par_points=None,proj_mat=None):
60
+ if self.use_scale:
61
+ scaling = torch.rand(1, 3) * 0.5 + 0.75
62
+ else:
63
+ scaling = torch.ones((1,3)).float()
64
+ if self.use_shift:
65
+ shifting = torch.rand(1,3) *(self.shift[1]-self.shift[0])+self.shift[0]
66
+ else:
67
+ shifting=np.zeros((1,3))
68
+ if self.use_rot:
69
+ angle=torch.rand(1)*(self.angle[1]-self.angle[0])+self.angle[0]
70
+ else:
71
+ angle=torch.tensor((0))
72
+ #print(angle)
73
+ angle=angle/180*np.pi
74
+ rot_mat=get_rot_from_yaw(angle)
75
+
76
+ surface = surface * scaling
77
+ point = point * scaling
78
+
79
+ scale = (1 / torch.abs(surface).max().item()) * 0.999999
80
+ if self.use_whole_scale:
81
+ scale = scale*(np.random.random()*0.3+0.7)
82
+ surface *= scale
83
+ point *= scale
84
+
85
+ #scale = 1
86
+
87
+ if self.rot_shift_surface:
88
+ surface=torch.mm(surface,rot_mat.transpose(0,1))
89
+ surface = surface + shifting
90
+ point=torch.mm(point,rot_mat.transpose(0,1))
91
+ point=point+shifting
92
+
93
+ if par_points is not None:
94
+ par_points = par_points * scaling
95
+ par_points=torch.mm(par_points,rot_mat.transpose(0,1))
96
+ par_points+=shifting
97
+ par_points *= scale
98
+
99
+ post_scale_tran=np.eye(4)
100
+ post_scale_tran[0,0],post_scale_tran[1,1],post_scale_tran[2,2]=scale,scale,scale
101
+ shift_tran = np.eye(4)
102
+ shift_tran[0:3, 3] = shifting
103
+ rot_tran = np.eye(4)
104
+ rot_tran[0:3, 0:3] = rot_mat
105
+ scale_tran = np.eye(4)
106
+ scale_tran[0, 0], scale_tran[1, 1], scale_tran[2, 2] = scaling[0, 0], scaling[
107
+ 0, 1], scaling[0, 2]
108
+
109
+ #print(post_scale_tran,np.dot(np.dot(shift_tran,np.dot(rot_tran,scale_tran))))
110
+ tran_mat=np.dot(post_scale_tran,np.dot(shift_tran,np.dot(rot_tran,scale_tran)))
111
+ #tran_mat=np.dot(post_scale_tran,tran_mat)
112
+ #print(np.linalg.norm(surface - (np.dot(org_surface,tran_mat[0:3,0:3].T)+tran_mat[0:3,3])))
113
+ if proj_mat is not None:
114
+ '''need to put the augmentation back'''
115
+ inv_tran_mat=np.linalg.inv(tran_mat)
116
+ if isinstance(proj_mat,list):
117
+ for idx,mat in enumerate(proj_mat):
118
+ mat=np.dot(mat,inv_tran_mat)
119
+ proj_mat[idx]=mat
120
+ else:
121
+ proj_mat=np.dot(proj_mat,inv_tran_mat)
122
+
123
+
124
+ if self.jitter:
125
+ surface += 0.005 * torch.randn_like(surface)
126
+ surface.clamp_(min=-1, max=1)
127
+ if self.jitter_partial and par_points is not None:
128
+ par_points+=self.par_jitter_sigma * torch.randn_like(par_points)
129
+
130
+ return surface, point, par_points, proj_mat, tran_mat
131
+
132
+
133
+ class Augment_Points(object):
134
+ def __init__(self, interval=(0.75, 1.25), angle=(-5,5), shift=(-0.1,0.1), use_scale=True,use_rot=True,
135
+ use_shift=True,jitter=True,jitter_sigma=0.02):
136
+ assert isinstance(interval, tuple)
137
+ self.interval = interval
138
+ self.angle=angle
139
+ self.shift=shift
140
+ self.jitter = jitter
141
+ self.use_scale=use_scale
142
+ self.use_rot=use_rot
143
+ self.use_shift=use_shift
144
+ self.jitter_sigma=jitter_sigma
145
+
146
+ def __call__(self, points1,points2):
147
+ if self.use_scale:
148
+ scaling = torch.rand(1, 3) * 0.5 + 0.75
149
+ else:
150
+ scaling = torch.ones((1,3)).float()
151
+ if self.use_shift:
152
+ shifting = torch.rand(1,3) *(self.shift[1]-self.shift[0])+self.shift[0]
153
+ else:
154
+ shifting=np.zeros((1,3))
155
+ if self.use_rot:
156
+ angle=torch.rand(1)*(self.angle[1]-self.angle[0])+self.angle[0]
157
+ else:
158
+ angle=torch.tensor((0))
159
+ #print(angle)
160
+ angle=angle/180*np.pi
161
+ rot_mat=get_rot_from_yaw(angle)
162
+
163
+ points1 = points1 * scaling
164
+ points2 = points2 * scaling
165
+
166
+ #scale = 1
167
+ scale = min((1 / torch.abs(points1).max().item()) * 0.999999,(1 / torch.abs(points2).max().item()) * 0.999999)
168
+ points1 *= scale
169
+ points2 *= scale
170
+
171
+ points1=torch.mm(points1,rot_mat.transpose(0,1))
172
+ points1 = points1 + shifting
173
+ points2=torch.mm(points2,rot_mat.transpose(0,1))
174
+ points2=points2+shifting
175
+
176
+ if self.jitter:
177
+ points1 += self.jitter_sigma * torch.randn_like(points1)
178
+ points2 += self.jitter_sigma * torch.randn_like(points2)
179
+
180
+ return points1,points2
engine/engine_triplane_dm.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # References:
3
+ # MAE: https://github.com/facebookresearch/mae
4
+ # DeiT: https://github.com/facebookresearch/deit
5
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
6
+ # --------------------------------------------------------
7
+
8
+ import math
9
+ import sys
10
+ from typing import Iterable
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+ import util.misc as misc
16
+ import util.lr_sched as lr_sched
17
+ import numpy as np
18
+ import os
19
+ import pickle as p
20
+ import torch.distributed as dist
21
+ import time
22
+ from models.modules.encoder import DiagonalGaussianDistribution
23
+
24
+
25
+ def train_one_epoch(model: torch.nn.Module, ae: torch.nn.Module, criterion: torch.nn.Module,
26
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
27
+ device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
28
+ log_writer=None,log_dir=None, args=None):
29
+ model.train(True)
30
+ metric_logger = misc.MetricLogger(delimiter=" ")
31
+ metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
32
+ header = 'Epoch: [{}]'.format(epoch)
33
+ print_freq = 20
34
+
35
+ accum_iter = args.accum_iter
36
+ use_cls_free= args.use_cls_free
37
+
38
+ optimizer.zero_grad()
39
+
40
+ if log_writer is not None:
41
+ print('log_dir: {}'.format(log_writer.log_dir))
42
+
43
+ for data_iter_step, data_batch in enumerate(
44
+ metric_logger.log_every(data_loader, print_freq, header)):
45
+
46
+ # we use a per iteration (instead of per epoch) lr scheduler
47
+ if not args.constant_lr:
48
+ if data_iter_step % accum_iter == 0:
49
+ lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
50
+
51
+ input_dict=model.module.prepare_data(data_batch)
52
+ with torch.cuda.amp.autocast(enabled=False):
53
+ loss_all = criterion(model,input_dict,classifier_free=use_cls_free)
54
+ loss=loss_all.mean()
55
+
56
+ loss_value = loss.item()
57
+ if not math.isfinite(loss_value):
58
+ print("Loss is {}, stopping training".format(loss_value))
59
+ sys.exit(1)
60
+
61
+ loss /= accum_iter
62
+ loss_scaler(loss, optimizer, clip_grad=max_norm,
63
+ parameters=model.parameters(), create_graph=False,
64
+ update_grad=(data_iter_step + 1) % accum_iter == 0)
65
+ if (data_iter_step + 1) % accum_iter == 0:
66
+ optimizer.zero_grad()
67
+
68
+ torch.cuda.synchronize()
69
+
70
+ metric_logger.update(loss=loss_value)
71
+
72
+ min_lr = 10.
73
+ max_lr = 0.
74
+ for group in optimizer.param_groups:
75
+ min_lr = min(min_lr, group["lr"])
76
+ max_lr = max(max_lr, group["lr"])
77
+
78
+ metric_logger.update(lr=max_lr)
79
+
80
+ loss_value_reduce = misc.all_reduce_mean(loss_value)
81
+ if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
82
+ """ We use epoch_1000x as the x-axis in tensorboard.
83
+ This calibrates different curves when batch size changes.
84
+ """
85
+ epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
86
+ log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
87
+ log_writer.add_scalar('lr', max_lr, epoch_1000x)
88
+
89
+ # gather the stats from all processes
90
+ metric_logger.synchronize_between_processes()
91
+ print("Averaged stats:", metric_logger)
92
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
93
+
94
+ @torch.no_grad()
95
+ def evaluate_reconstruction(data_loader, model, ae, criterion, device):
96
+ metric_logger = misc.MetricLogger(delimiter=" ")
97
+ header = 'Test:'
98
+
99
+ # switch to evaluation mode
100
+ model.eval()
101
+ for data_batch in metric_logger.log_every(data_loader, 50, header):
102
+ with torch.no_grad():
103
+ input_dict=model.module.prepare_data(data_batch)
104
+ loss_all = criterion(model, input_dict,classifier_free=False)
105
+ loss = loss_all.mean()
106
+ sample_input=model.module.prepare_sample_data(data_batch)
107
+ sampled_array = model.module.sample(sample_input).float()
108
+ sampled_array = torch.nn.functional.interpolate(sampled_array, scale_factor=2, mode="bilinear")
109
+ eval_input=model.module.prepare_eval_data(data_batch)
110
+ samples=eval_input["samples"]
111
+ labels=eval_input["labels"]
112
+ for j in range(sampled_array.shape[0]):
113
+ output = ae.decode(sampled_array[j:j + 1], samples[j:j+1]).squeeze(-1)
114
+ pred = torch.zeros_like(output)
115
+ pred[output >= 0.0] = 1
116
+ label=labels[j:j+1]
117
+
118
+ accuracy = (pred == label).float().sum(dim=1) / label.shape[1]
119
+ accuracy = accuracy.mean()
120
+ intersection = (pred * label).sum(dim=1)
121
+ union = (pred + label).gt(0).sum(dim=1)
122
+ iou = intersection * 1.0 / union + 1e-5
123
+ iou = iou.mean()
124
+
125
+ metric_logger.update(iou=iou.item())
126
+ metric_logger.update(accuracy=accuracy.item())
127
+ metric_logger.update(loss=loss.item())
128
+ metric_logger.synchronize_between_processes()
129
+ print('* iou {ious.global_avg:.3f}'
130
+ .format(ious=metric_logger.iou))
131
+ print('* accuracy {accuracies.global_avg:.3f}'
132
+ .format(accuracies=metric_logger.accuracy))
133
+ print('* loss {losses.global_avg:.3f}'
134
+ .format(losses=metric_logger.loss))
135
+
136
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
engine/engine_triplane_vae.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # References:
3
+ # MAE: https://github.com/facebookresearch/mae
4
+ # DeiT: https://github.com/facebookresearch/deit
5
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
6
+ # --------------------------------------------------------
7
+
8
+ import math
9
+ import sys
10
+ sys.path.append("..")
11
+ from typing import Iterable
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ import util.misc as misc
17
+ import util.lr_sched as lr_sched
18
+
19
+
20
+ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
21
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
22
+ device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
23
+ log_writer=None, args=None):
24
+ model.train(True)
25
+ metric_logger = misc.MetricLogger(delimiter=" ")
26
+ metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
27
+ header = 'Epoch: [{}]'.format(epoch)
28
+ print_freq = 20
29
+
30
+ accum_iter = args.accum_iter
31
+
32
+ optimizer.zero_grad()
33
+
34
+ kl_weight = 25e-3 #TODO: try to modify this, it is 1e-3 originally, large kl ease the training of diffusion, but decrease in VAE results
35
+
36
+ if log_writer is not None:
37
+ print('log_dir: {}'.format(log_writer.log_dir))
38
+
39
+ for data_iter_step, data_batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
40
+
41
+ # we use a per iteration (instead of per epoch) lr scheduler
42
+ if data_iter_step % accum_iter == 0:
43
+ lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
44
+
45
+ points = data_batch['points'].to(device, non_blocking=True)
46
+ labels = data_batch['labels'].to(device, non_blocking=True)
47
+ surface = data_batch['surface'].to(device, non_blocking=True)
48
+ # print(points.shape)
49
+ with torch.cuda.amp.autocast(enabled=False):
50
+ outputs = model(surface, points)
51
+ if 'kl' in outputs:
52
+ loss_kl = outputs['kl']
53
+ #print(loss_kl.shape)
54
+ loss_kl = torch.sum(loss_kl) / loss_kl.shape[0]
55
+ else:
56
+ loss_kl = None
57
+
58
+ outputs = outputs['logits']
59
+
60
+ num_samples=outputs.shape[1]//2
61
+ #print(num_samples)
62
+ loss_vol = criterion(outputs[:, :num_samples], labels[:, :num_samples])
63
+ loss_near = criterion(outputs[:, num_samples:], labels[:, num_samples:])
64
+
65
+ if loss_kl is not None:
66
+ loss = loss_vol + 0.1 * loss_near + kl_weight * loss_kl
67
+ else:
68
+ loss = loss_vol + 0.1 * loss_near
69
+
70
+ loss_value = loss.item()
71
+
72
+ threshold = 0
73
+
74
+ pred = torch.zeros_like(outputs[:, :num_samples])
75
+ pred[outputs[:, :num_samples] >= threshold] = 1
76
+
77
+ accuracy = (pred == labels[:, :num_samples]).float().sum(dim=1) / labels[:, :num_samples].shape[1]
78
+ accuracy = accuracy.mean()
79
+ intersection = (pred * labels[:, :num_samples]).sum(dim=1)
80
+ union = (pred + labels[:, :num_samples]).gt(0).sum(dim=1) + 1e-5
81
+ iou = intersection * 1.0 / union
82
+ iou = iou.mean()
83
+
84
+ if not math.isfinite(loss_value):
85
+ print("Loss is {}, stopping training".format(loss_value))
86
+ sys.exit(1)
87
+
88
+ loss /= accum_iter
89
+ loss_scaler(loss, optimizer, clip_grad=max_norm,
90
+ parameters=model.parameters(), create_graph=False,
91
+ update_grad=(data_iter_step + 1) % accum_iter == 0)
92
+ if (data_iter_step + 1) % accum_iter == 0:
93
+ optimizer.zero_grad()
94
+
95
+ torch.cuda.synchronize()
96
+
97
+ metric_logger.update(loss=loss_value)
98
+
99
+ metric_logger.update(loss_vol=loss_vol.item())
100
+ metric_logger.update(loss_near=loss_near.item())
101
+
102
+ if loss_kl is not None:
103
+ metric_logger.update(loss_kl=loss_kl.item())
104
+
105
+ metric_logger.update(iou=iou.item())
106
+
107
+ min_lr = 10.
108
+ max_lr = 0.
109
+ for group in optimizer.param_groups:
110
+ min_lr = min(min_lr, group["lr"])
111
+ max_lr = max(max_lr, group["lr"])
112
+
113
+ metric_logger.update(lr=max_lr)
114
+
115
+ loss_value_reduce = misc.all_reduce_mean(loss_value)
116
+ iou_reduce=misc.all_reduce_mean(iou)
117
+ if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
118
+ """ We use epoch_1000x as the x-axis in tensorboard.
119
+ This calibrates different curves when batch size changes.
120
+ """
121
+ epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
122
+ log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
123
+ log_writer.add_scalar('iou', iou_reduce, epoch_1000x)
124
+ log_writer.add_scalar('lr', max_lr, epoch_1000x)
125
+
126
+ # gather the stats from all processes
127
+ metric_logger.synchronize_between_processes()
128
+ print("Averaged stats:", metric_logger)
129
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
130
+
131
+
132
+ @torch.no_grad()
133
+ def evaluate(data_loader, model, device):
134
+ criterion = torch.nn.BCEWithLogitsLoss()
135
+
136
+ metric_logger = misc.MetricLogger(delimiter=" ")
137
+ header = 'Test:'
138
+
139
+ # switch to evaluation mode
140
+ model.eval()
141
+
142
+ for data_batch in metric_logger.log_every(data_loader, 50, header):
143
+
144
+ points = data_batch['points'].to(device, non_blocking=True)
145
+ labels = data_batch['labels'].to(device, non_blocking=True)
146
+ surface = data_batch['surface'].to(device, non_blocking=True)
147
+ # compute output
148
+ with torch.cuda.amp.autocast(enabled=False):
149
+
150
+ outputs = model(surface, points)
151
+ if 'kl' in outputs:
152
+ loss_kl = outputs['kl']
153
+ loss_kl = torch.sum(loss_kl) / loss_kl.shape[0]
154
+ else:
155
+ loss_kl = None
156
+
157
+ outputs = outputs['logits']
158
+
159
+ loss = criterion(outputs, labels)
160
+
161
+ threshold = 0
162
+
163
+ pred = torch.zeros_like(outputs)
164
+ pred[outputs >= threshold] = 1
165
+
166
+ accuracy = (pred == labels).float().sum(dim=1) / labels.shape[1]
167
+ accuracy = accuracy.mean()
168
+ intersection = (pred * labels).sum(dim=1)
169
+ union = (pred + labels).gt(0).sum(dim=1)
170
+ iou = intersection * 1.0 / union + 1e-5
171
+ iou = iou.mean()
172
+
173
+ batch_size = points.shape[0]
174
+ metric_logger.update(loss=loss.item())
175
+ metric_logger.meters['iou'].update(iou.item(), n=batch_size)
176
+
177
+ if loss_kl is not None:
178
+ metric_logger.update(loss_kl=loss_kl.item())
179
+
180
+ # gather the stats from all processes
181
+ metric_logger.synchronize_between_processes()
182
+ print('* iou {iou.global_avg:.3f} loss {losses.global_avg:.3f}'
183
+ .format(iou=metric_logger.iou, losses=metric_logger.loss))
184
+
185
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
evaluation/dist_eval.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES='0,1,2,3' torchrun --master_port 15002 --nproc_per_node=2 \
2
+ evaluate_object_reconstruction.py \
3
+ --configs ../configs/finetune_triplane_diffusion.yaml \
4
+ --category arkit_chair arkit_stool \
5
+ --ae-pth ../output/ae/chair/best-checkpoint.pth \
6
+ --dm-pth ../output/finetune_dm/lowres_chair/best-checkpoint.pth \
7
+ --output_folder ../output_result/chair_result \
8
+ --data-pth ../data \
9
+ --eval_cd \
10
+ --reso 256 \
11
+ --save_mesh \
12
+ --save_par_points \
13
+ --save_image \
14
+ --save_surface
15
+
16
+ #check ./datasets/taxonomy to see how sub categories are defined
evaluation/evaluate_object_reconstruction.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import sys
3
+ sys.path.append("..")
4
+ sys.path.append(".")
5
+ import numpy as np
6
+
7
+ import mcubes
8
+ import os
9
+ import torch
10
+
11
+ import trimesh
12
+
13
+ from datasets.SingleView_dataset import Object_PartialPoints_MultiImg
14
+ from datasets.transforms import Scale_Shift_Rotate
15
+ from models import get_model
16
+ from pathlib import Path
17
+ import open3d as o3d
18
+ from configs.config_utils import CONFIG
19
+ import cv2
20
+ from util.misc import MetricLogger
21
+ import scipy
22
+ from pyTorchChamferDistance.chamfer_distance import ChamferDistance
23
+ from util.projection_utils import draw_proj_image
24
+ from util import misc
25
+ import time
26
+ dist_chamfer=ChamferDistance()
27
+
28
+
29
+ def pc_metrics(p1, p2, space_ext=2, fscore_param=0.01, scale=.5):
30
+ """ p2: reference ponits
31
+ (B, N, 3)
32
+ """
33
+ p1, p2, space_ext = p1 * scale, p2 * scale, space_ext * scale
34
+ f_thresh = space_ext * fscore_param
35
+
36
+ #print(p1.shape,p2.shape)
37
+ d1, d2, _, _ = dist_chamfer(p1, p2)
38
+ #print(d1.shape,d2.shape)
39
+ d1sqrt, d2sqrt = (d1 ** .5), (d2 ** .5)
40
+ chamfer_L1 = d1sqrt.mean(axis=-1) + d2sqrt.mean(axis=-1)
41
+ chamfer_L2 = d1.mean(axis=-1) + d2.mean(axis=-1)
42
+ precision = (d1sqrt < f_thresh).sum(axis=-1).float() / p1.shape[1]
43
+ recall = (d2sqrt < f_thresh).sum(axis=-1).float() / p2.shape[1]
44
+ #print(precision,recall)
45
+ fscore = 2 * torch.div(recall * precision, recall + precision)
46
+ fscore[fscore == float("inf")] = 0
47
+ return chamfer_L1,chamfer_L2,fscore
48
+
49
+ if __name__ == "__main__":
50
+
51
+ parser = argparse.ArgumentParser('this script can be used to compute iou fscore chamfer distance before icp align', add_help=False)
52
+ parser.add_argument('--configs',type=str,required=True)
53
+ parser.add_argument('--output_folder', type=str, default="../output_result/Triplane_diff_parcond_0926")
54
+ parser.add_argument('--dm-pth',type=str)
55
+ parser.add_argument('--ae-pth',type=str)
56
+ parser.add_argument('--data-pth', type=str,default="../")
57
+ parser.add_argument('--save_mesh',action="store_true",default=False)
58
+ parser.add_argument('--save_image',action="store_true",default=False)
59
+ parser.add_argument('--save_par_points', action="store_true", default=False)
60
+ parser.add_argument('--save_proj_img',action="store_true",default=False)
61
+ parser.add_argument('--save_surface',action="store_true",default=False)
62
+ parser.add_argument('--reso',default=128,type=int)
63
+ parser.add_argument('--category',nargs="+",type=str)
64
+ parser.add_argument('--eval_cd',action="store_true",default=False)
65
+ parser.add_argument('--use_augmentation',action="store_true",default=False)
66
+
67
+ parser.add_argument('--world_size', default=1, type=int,
68
+ help='number of distributed processes')
69
+ parser.add_argument('--local_rank', default=-1, type=int)
70
+ parser.add_argument('--dist_on_itp', action='store_true')
71
+ parser.add_argument('--dist_url', default='env://',
72
+ help='url used to set up distributed training')
73
+ parser.add_argument('--device', default='cuda',
74
+ help='device to use for training / testing')
75
+ args = parser.parse_args()
76
+ misc.init_distributed_mode(args)
77
+ config_path=args.configs
78
+ config=CONFIG(config_path)
79
+ dataset_config=config.config['dataset']
80
+ dataset_config['data_path']=args.data_pth
81
+ if "arkit" in args.category[0]:
82
+ split_filename=dataset_config['keyword']+'_val_par_img.json'
83
+ else:
84
+ split_filename='val_par_img.json'
85
+
86
+ transform = None
87
+ if args.use_augmentation:
88
+ transform=Scale_Shift_Rotate(jitter_partial=False,jitter=False,use_scale=False,angle=(-10,10),shift=(-0.1,0.1))
89
+ dataset_val = Object_PartialPoints_MultiImg(dataset_config['data_path'], split_filename=split_filename,categories=args.category,split="val",
90
+ transform=transform, sampling=False,
91
+ num_samples=1024, return_surface=True,ret_sample=True,
92
+ surface_sampling=True, par_pc_size=dataset_config['par_pc_size'],surface_size=100000,
93
+ load_proj_mat=True,load_image=True,load_org_img=True,load_triplane=None,par_point_aug=None,replica=1)
94
+ batch_size=1
95
+
96
+ num_tasks = misc.get_world_size()
97
+ global_rank = misc.get_rank()
98
+ val_sampler = torch.utils.data.DistributedSampler(
99
+ dataset_val, num_replicas=num_tasks, rank=global_rank,
100
+ shuffle=False) # shu
101
+ dataloader_val=torch.utils.data.DataLoader(
102
+ dataset_val,
103
+ sampler=val_sampler,
104
+ batch_size=batch_size,
105
+ num_workers=10,
106
+ shuffle=False,
107
+ )
108
+ output_folder=args.output_folder
109
+
110
+ device = torch.device('cuda')
111
+
112
+ ae_config=config.config['model']['ae']
113
+ dm_config=config.config['model']['dm']
114
+ ae_model=get_model(ae_config).to(device)
115
+ if args.category[0] == "all":
116
+ dm_config["use_cat_embedding"]=True
117
+ else:
118
+ dm_config["use_cat_embedding"] = False
119
+ dm_model=get_model(dm_config).to(device)
120
+ ae_model.eval()
121
+ dm_model.eval()
122
+ ae_model.load_state_dict(torch.load(args.ae_pth)['model'])
123
+ dm_model.load_state_dict(torch.load(args.dm_pth)['model'])
124
+
125
+ density = args.reso
126
+ gap = 2.2 / density
127
+ x = np.linspace(-1.1, 1.1, int(density + 1))
128
+ y = np.linspace(-1.1, 1.1, int(density + 1))
129
+ z = np.linspace(-1.1, 1.1, int(density + 1))
130
+ xv, yv, zv = np.meshgrid(x, y, z,indexing='ij')
131
+ grid = torch.from_numpy(np.stack([xv, yv, zv]).astype(np.float32)).view(3, -1).transpose(0, 1)[None].to(device,non_blocking=True)
132
+
133
+ metric_logger=MetricLogger(delimiter=" ")
134
+ header = 'Test:'
135
+
136
+ with torch.no_grad():
137
+ for data_batch in metric_logger.log_every(dataloader_val,10, header):
138
+ # if data_iter_step==100:
139
+ # break
140
+ partial_name = data_batch['partial_name']
141
+ class_name = data_batch['class_name']
142
+ model_ids=data_batch['model_id']
143
+ surface=data_batch['surface']
144
+ proj_matrices=data_batch['proj_mat']
145
+ sample_points=data_batch["points"].cuda().float()
146
+ labels=data_batch["labels"].cuda().float()
147
+ sample_input=dm_model.prepare_sample_data(data_batch)
148
+ #t1 = time.time()
149
+ sampled_array = dm_model.sample(sample_input,num_steps=36).float()
150
+ #t2 = time.time()
151
+ #sample_time = t2 - t1
152
+ #print("sampling time %f" % (sample_time))
153
+ sampled_array = torch.nn.functional.interpolate(sampled_array, scale_factor=2, mode="bilinear")
154
+ for j in range(sampled_array.shape[0]):
155
+ if args.save_mesh | args.save_par_points | args.save_image:
156
+ object_folder = os.path.join(output_folder, class_name[j], model_ids[j])
157
+ Path(object_folder).mkdir(parents=True, exist_ok=True)
158
+ '''calculate iou'''
159
+ sample_point=sample_points[j:j+1]
160
+ sample_output=ae_model.decode(sampled_array[j:j + 1],sample_point)
161
+ sample_pred=torch.zeros_like(sample_output)
162
+ sample_pred[sample_output>=0.0]=1
163
+ label=labels[j:j+1]
164
+ intersection = (sample_pred * label).sum(dim=1)
165
+ union = (sample_pred + label).gt(0).sum(dim=1)
166
+ iou = intersection * 1.0 / union + 1e-5
167
+ iou = iou.mean()
168
+ metric_logger.update(iou=iou.item())
169
+
170
+ if args.use_augmentation:
171
+ tran_mat=data_batch["tran_mat"][j].numpy()
172
+ mat_save_path='{}/tran_mat.npy'.format(object_folder)
173
+ np.save(mat_save_path,tran_mat)
174
+
175
+ if args.eval_cd:
176
+ grid_list=torch.split(grid,128**3,dim=1)
177
+ output_list=[]
178
+ #t3=time.time()
179
+ for sub_grid in grid_list:
180
+ output_list.append(ae_model.decode(sampled_array[j:j + 1],sub_grid))
181
+ output=torch.cat(output_list,dim=1)
182
+ #t4=time.time()
183
+ #decoding_time=t4-t3
184
+ #print("decoding time:",decoding_time)
185
+ logits = output[j].detach()
186
+
187
+ volume = logits.view(density + 1, density + 1, density + 1).cpu().numpy()
188
+ verts, faces = mcubes.marching_cubes(volume, 0)
189
+
190
+ verts *= gap
191
+ verts -= 1.1
192
+ #print("vertice max min",np.amin(verts,axis=0),np.amax(verts,axis=0))
193
+
194
+
195
+ m = trimesh.Trimesh(verts, faces)
196
+ '''calculate fscore and chamfer distance'''
197
+ result_surface,_=trimesh.sample.sample_surface(m,100000)
198
+ gt_surface=surface[j]
199
+ assert gt_surface.shape[0]==result_surface.shape[0]
200
+
201
+ result_surface_gpu = torch.from_numpy(result_surface).float().cuda().unsqueeze(0)
202
+ gt_surface_gpu = gt_surface.float().cuda().unsqueeze(0)
203
+ _,chamfer_L2,fscore=pc_metrics(result_surface_gpu,gt_surface_gpu)
204
+ metric_logger.update(chamferl2=chamfer_L2*1000.0)
205
+ metric_logger.update(fscore=fscore)
206
+
207
+ if args.save_mesh:
208
+ m.export('{}/{}_mesh.ply'.format(object_folder, partial_name[j]))
209
+
210
+ if args.save_par_points:
211
+ par_point_input = data_batch['par_points'][j].numpy()
212
+ #print("input max min", np.amin(par_point_input, axis=0), np.amax(par_point_input, axis=0))
213
+ par_point_o3d = o3d.geometry.PointCloud()
214
+ par_point_o3d.points = o3d.utility.Vector3dVector(par_point_input[:, 0:3])
215
+ o3d.io.write_point_cloud('{}/{}.ply'.format(object_folder, partial_name[j]), par_point_o3d)
216
+ if args.save_image:
217
+ image_list=data_batch["org_image"]
218
+ for idx,image in enumerate(image_list):
219
+ image=image[0].numpy().astype(np.uint8)
220
+ if args.save_proj_img:
221
+ proj_mat=proj_matrices[j,idx].numpy()
222
+ proj_image=draw_proj_image(image,proj_mat,result_surface)
223
+ proj_save_path = '{}/proj_{}.jpg'.format(object_folder, idx)
224
+ cv2.imwrite(proj_save_path,proj_image)
225
+ save_path='{}/{}.jpg'.format(object_folder, idx)
226
+ cv2.imwrite(save_path,image)
227
+ if args.save_surface:
228
+ surface=gt_surface.numpy().astype(np.float32)
229
+ surface_o3d = o3d.geometry.PointCloud()
230
+ surface_o3d.points = o3d.utility.Vector3dVector(surface[:, 0:3])
231
+ o3d.io.write_point_cloud('{}/surface.ply'.format(object_folder), surface_o3d)
232
+ metric_logger.synchronize_between_processes()
233
+ print('* iou {ious.global_avg:.3f}'
234
+ .format(ious=metric_logger.iou))
235
+ if args.eval_cd:
236
+ print('* chamferl2 {chamferl2s.global_avg:.3f}'
237
+ .format(chamferl2s=metric_logger.chamferl2))
238
+ print('* fscore {fscores.global_avg:.3f}'
239
+ .format(fscores=metric_logger.fscore))
evaluation/pyTorchChamferDistance/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .DS_Store
2
+ ._*
3
+
evaluation/pyTorchChamferDistance/LICENSE.md ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) [year] [fullname]
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
evaluation/pyTorchChamferDistance/README.md ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Chamfer Distance for pyTorch
2
+
3
+ This is an implementation of the Chamfer Distance as a module for pyTorch. It is written as a custom C++/CUDA extension.
4
+
5
+ As it is using pyTorch's [JIT compilation](https://pytorch.org/tutorials/advanced/cpp_extension.html), there are no additional prerequisite steps that have to be taken. Simply import the module as shown below; CUDA and C++ code will be compiled on the first run.
6
+
7
+ ### Usage
8
+ ```python
9
+ from chamfer_distance import ChamferDistance
10
+ chamfer_dist = ChamferDistance()
11
+
12
+ #...
13
+ # points and points_reconstructed are n_points x 3 matrices
14
+
15
+ dist1, dist2 = chamfer_dist(points, points_reconstructed)
16
+ loss = (torch.mean(dist1)) + (torch.mean(dist2))
17
+
18
+
19
+ #...
20
+ ```
21
+
22
+ ### Integration
23
+ This code has been integrated into the [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) library for 3D Deep Learning by NVIDIAGameWorks. You should probably take a look at it if you are working on anything 3D :)
evaluation/pyTorchChamferDistance/__init__.py ADDED
File without changes
evaluation/pyTorchChamferDistance/chamfer_distance/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .chamfer_distance import ChamferDistance
evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/torch.h>
2
+
3
+ // CUDA forward declarations
4
+ int ChamferDistanceKernelLauncher(
5
+ const int b, const int n,
6
+ const float* xyz,
7
+ const int m,
8
+ const float* xyz2,
9
+ float* result,
10
+ int* result_i,
11
+ float* result2,
12
+ int* result2_i);
13
+
14
+ int ChamferDistanceGradKernelLauncher(
15
+ const int b, const int n,
16
+ const float* xyz1,
17
+ const int m,
18
+ const float* xyz2,
19
+ const float* grad_dist1,
20
+ const int* idx1,
21
+ const float* grad_dist2,
22
+ const int* idx2,
23
+ float* grad_xyz1,
24
+ float* grad_xyz2);
25
+
26
+
27
+ void chamfer_distance_forward_cuda(
28
+ const at::Tensor xyz1,
29
+ const at::Tensor xyz2,
30
+ const at::Tensor dist1,
31
+ const at::Tensor dist2,
32
+ const at::Tensor idx1,
33
+ const at::Tensor idx2)
34
+ {
35
+ ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
36
+ xyz2.size(1), xyz2.data<float>(),
37
+ dist1.data<float>(), idx1.data<int>(),
38
+ dist2.data<float>(), idx2.data<int>());
39
+ }
40
+
41
+ void chamfer_distance_backward_cuda(
42
+ const at::Tensor xyz1,
43
+ const at::Tensor xyz2,
44
+ at::Tensor gradxyz1,
45
+ at::Tensor gradxyz2,
46
+ at::Tensor graddist1,
47
+ at::Tensor graddist2,
48
+ at::Tensor idx1,
49
+ at::Tensor idx2)
50
+ {
51
+ ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
52
+ xyz2.size(1), xyz2.data<float>(),
53
+ graddist1.data<float>(), idx1.data<int>(),
54
+ graddist2.data<float>(), idx2.data<int>(),
55
+ gradxyz1.data<float>(), gradxyz2.data<float>());
56
+ }
57
+
58
+
59
+ void nnsearch(
60
+ const int b, const int n, const int m,
61
+ const float* xyz1,
62
+ const float* xyz2,
63
+ float* dist,
64
+ int* idx)
65
+ {
66
+ for (int i = 0; i < b; i++) {
67
+ for (int j = 0; j < n; j++) {
68
+ const float x1 = xyz1[(i*n+j)*3+0];
69
+ const float y1 = xyz1[(i*n+j)*3+1];
70
+ const float z1 = xyz1[(i*n+j)*3+2];
71
+ double best = 0;
72
+ int besti = 0;
73
+ for (int k = 0; k < m; k++) {
74
+ const float x2 = xyz2[(i*m+k)*3+0] - x1;
75
+ const float y2 = xyz2[(i*m+k)*3+1] - y1;
76
+ const float z2 = xyz2[(i*m+k)*3+2] - z1;
77
+ const double d=x2*x2+y2*y2+z2*z2;
78
+ if (k==0 || d < best){
79
+ best = d;
80
+ besti = k;
81
+ }
82
+ }
83
+ dist[i*n+j] = best;
84
+ idx[i*n+j] = besti;
85
+ }
86
+ }
87
+ }
88
+
89
+
90
+ void chamfer_distance_forward(
91
+ const at::Tensor xyz1,
92
+ const at::Tensor xyz2,
93
+ const at::Tensor dist1,
94
+ const at::Tensor dist2,
95
+ const at::Tensor idx1,
96
+ const at::Tensor idx2)
97
+ {
98
+ const int batchsize = xyz1.size(0);
99
+ const int n = xyz1.size(1);
100
+ const int m = xyz2.size(1);
101
+
102
+ const float* xyz1_data = xyz1.data<float>();
103
+ const float* xyz2_data = xyz2.data<float>();
104
+ float* dist1_data = dist1.data<float>();
105
+ float* dist2_data = dist2.data<float>();
106
+ int* idx1_data = idx1.data<int>();
107
+ int* idx2_data = idx2.data<int>();
108
+
109
+ nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data);
110
+ nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data);
111
+ }
112
+
113
+
114
+ void chamfer_distance_backward(
115
+ const at::Tensor xyz1,
116
+ const at::Tensor xyz2,
117
+ at::Tensor gradxyz1,
118
+ at::Tensor gradxyz2,
119
+ at::Tensor graddist1,
120
+ at::Tensor graddist2,
121
+ at::Tensor idx1,
122
+ at::Tensor idx2)
123
+ {
124
+ const int b = xyz1.size(0);
125
+ const int n = xyz1.size(1);
126
+ const int m = xyz2.size(1);
127
+
128
+ const float* xyz1_data = xyz1.data<float>();
129
+ const float* xyz2_data = xyz2.data<float>();
130
+ float* gradxyz1_data = gradxyz1.data<float>();
131
+ float* gradxyz2_data = gradxyz2.data<float>();
132
+ float* graddist1_data = graddist1.data<float>();
133
+ float* graddist2_data = graddist2.data<float>();
134
+ const int* idx1_data = idx1.data<int>();
135
+ const int* idx2_data = idx2.data<int>();
136
+
137
+ for (int i = 0; i < b*n*3; i++)
138
+ gradxyz1_data[i] = 0;
139
+ for (int i = 0; i < b*m*3; i++)
140
+ gradxyz2_data[i] = 0;
141
+ for (int i = 0;i < b; i++) {
142
+ for (int j = 0; j < n; j++) {
143
+ const float x1 = xyz1_data[(i*n+j)*3+0];
144
+ const float y1 = xyz1_data[(i*n+j)*3+1];
145
+ const float z1 = xyz1_data[(i*n+j)*3+2];
146
+ const int j2 = idx1_data[i*n+j];
147
+
148
+ const float x2 = xyz2_data[(i*m+j2)*3+0];
149
+ const float y2 = xyz2_data[(i*m+j2)*3+1];
150
+ const float z2 = xyz2_data[(i*m+j2)*3+2];
151
+ const float g = graddist1_data[i*n+j]*2;
152
+
153
+ gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2);
154
+ gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2);
155
+ gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2);
156
+ gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2));
157
+ gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2));
158
+ gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2));
159
+ }
160
+ for (int j = 0; j < m; j++) {
161
+ const float x1 = xyz2_data[(i*m+j)*3+0];
162
+ const float y1 = xyz2_data[(i*m+j)*3+1];
163
+ const float z1 = xyz2_data[(i*m+j)*3+2];
164
+ const int j2 = idx2_data[i*m+j];
165
+ const float x2 = xyz1_data[(i*n+j2)*3+0];
166
+ const float y2 = xyz1_data[(i*n+j2)*3+1];
167
+ const float z2 = xyz1_data[(i*n+j2)*3+2];
168
+ const float g = graddist2_data[i*m+j]*2;
169
+ gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2);
170
+ gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2);
171
+ gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2);
172
+ gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2));
173
+ gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2));
174
+ gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2));
175
+ }
176
+ }
177
+ }
178
+
179
+
180
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
181
+ m.def("forward", &chamfer_distance_forward, "ChamferDistance forward");
182
+ m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)");
183
+ m.def("backward", &chamfer_distance_backward, "ChamferDistance backward");
184
+ m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)");
185
+ }
evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+
3
+ #include <cuda.h>
4
+ #include <cuda_runtime.h>
5
+
6
+ __global__
7
+ void ChamferDistanceKernel(
8
+ int b,
9
+ int n,
10
+ const float* xyz,
11
+ int m,
12
+ const float* xyz2,
13
+ float* result,
14
+ int* result_i)
15
+ {
16
+ const int batch=512;
17
+ __shared__ float buf[batch*3];
18
+ for (int i=blockIdx.x;i<b;i+=gridDim.x){
19
+ for (int k2=0;k2<m;k2+=batch){
20
+ int end_k=min(m,k2+batch)-k2;
21
+ for (int j=threadIdx.x;j<end_k*3;j+=blockDim.x){
22
+ buf[j]=xyz2[(i*m+k2)*3+j];
23
+ }
24
+ __syncthreads();
25
+ for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
26
+ float x1=xyz[(i*n+j)*3+0];
27
+ float y1=xyz[(i*n+j)*3+1];
28
+ float z1=xyz[(i*n+j)*3+2];
29
+ int best_i=0;
30
+ float best=0;
31
+ int end_ka=end_k-(end_k&3);
32
+ if (end_ka==batch){
33
+ for (int k=0;k<batch;k+=4){
34
+ {
35
+ float x2=buf[k*3+0]-x1;
36
+ float y2=buf[k*3+1]-y1;
37
+ float z2=buf[k*3+2]-z1;
38
+ float d=x2*x2+y2*y2+z2*z2;
39
+ if (k==0 || d<best){
40
+ best=d;
41
+ best_i=k+k2;
42
+ }
43
+ }
44
+ {
45
+ float x2=buf[k*3+3]-x1;
46
+ float y2=buf[k*3+4]-y1;
47
+ float z2=buf[k*3+5]-z1;
48
+ float d=x2*x2+y2*y2+z2*z2;
49
+ if (d<best){
50
+ best=d;
51
+ best_i=k+k2+1;
52
+ }
53
+ }
54
+ {
55
+ float x2=buf[k*3+6]-x1;
56
+ float y2=buf[k*3+7]-y1;
57
+ float z2=buf[k*3+8]-z1;
58
+ float d=x2*x2+y2*y2+z2*z2;
59
+ if (d<best){
60
+ best=d;
61
+ best_i=k+k2+2;
62
+ }
63
+ }
64
+ {
65
+ float x2=buf[k*3+9]-x1;
66
+ float y2=buf[k*3+10]-y1;
67
+ float z2=buf[k*3+11]-z1;
68
+ float d=x2*x2+y2*y2+z2*z2;
69
+ if (d<best){
70
+ best=d;
71
+ best_i=k+k2+3;
72
+ }
73
+ }
74
+ }
75
+ }else{
76
+ for (int k=0;k<end_ka;k+=4){
77
+ {
78
+ float x2=buf[k*3+0]-x1;
79
+ float y2=buf[k*3+1]-y1;
80
+ float z2=buf[k*3+2]-z1;
81
+ float d=x2*x2+y2*y2+z2*z2;
82
+ if (k==0 || d<best){
83
+ best=d;
84
+ best_i=k+k2;
85
+ }
86
+ }
87
+ {
88
+ float x2=buf[k*3+3]-x1;
89
+ float y2=buf[k*3+4]-y1;
90
+ float z2=buf[k*3+5]-z1;
91
+ float d=x2*x2+y2*y2+z2*z2;
92
+ if (d<best){
93
+ best=d;
94
+ best_i=k+k2+1;
95
+ }
96
+ }
97
+ {
98
+ float x2=buf[k*3+6]-x1;
99
+ float y2=buf[k*3+7]-y1;
100
+ float z2=buf[k*3+8]-z1;
101
+ float d=x2*x2+y2*y2+z2*z2;
102
+ if (d<best){
103
+ best=d;
104
+ best_i=k+k2+2;
105
+ }
106
+ }
107
+ {
108
+ float x2=buf[k*3+9]-x1;
109
+ float y2=buf[k*3+10]-y1;
110
+ float z2=buf[k*3+11]-z1;
111
+ float d=x2*x2+y2*y2+z2*z2;
112
+ if (d<best){
113
+ best=d;
114
+ best_i=k+k2+3;
115
+ }
116
+ }
117
+ }
118
+ }
119
+ for (int k=end_ka;k<end_k;k++){
120
+ float x2=buf[k*3+0]-x1;
121
+ float y2=buf[k*3+1]-y1;
122
+ float z2=buf[k*3+2]-z1;
123
+ float d=x2*x2+y2*y2+z2*z2;
124
+ if (k==0 || d<best){
125
+ best=d;
126
+ best_i=k+k2;
127
+ }
128
+ }
129
+ if (k2==0 || result[(i*n+j)]>best){
130
+ result[(i*n+j)]=best;
131
+ result_i[(i*n+j)]=best_i;
132
+ }
133
+ }
134
+ __syncthreads();
135
+ }
136
+ }
137
+ }
138
+
139
+ void ChamferDistanceKernelLauncher(
140
+ const int b, const int n,
141
+ const float* xyz,
142
+ const int m,
143
+ const float* xyz2,
144
+ float* result,
145
+ int* result_i,
146
+ float* result2,
147
+ int* result2_i)
148
+ {
149
+ ChamferDistanceKernel<<<dim3(32,16,1),512>>>(b, n, xyz, m, xyz2, result, result_i);
150
+ ChamferDistanceKernel<<<dim3(32,16,1),512>>>(b, m, xyz2, n, xyz, result2, result2_i);
151
+
152
+ cudaError_t err = cudaGetLastError();
153
+ if (err != cudaSuccess)
154
+ printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err));
155
+ }
156
+
157
+
158
+ __global__
159
+ void ChamferDistanceGradKernel(
160
+ int b, int n,
161
+ const float* xyz1,
162
+ int m,
163
+ const float* xyz2,
164
+ const float* grad_dist1,
165
+ const int* idx1,
166
+ float* grad_xyz1,
167
+ float* grad_xyz2)
168
+ {
169
+ for (int i = blockIdx.x; i<b; i += gridDim.x) {
170
+ for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x*gridDim.y) {
171
+ float x1=xyz1[(i*n+j)*3+0];
172
+ float y1=xyz1[(i*n+j)*3+1];
173
+ float z1=xyz1[(i*n+j)*3+2];
174
+ int j2=idx1[i*n+j];
175
+ float x2=xyz2[(i*m+j2)*3+0];
176
+ float y2=xyz2[(i*m+j2)*3+1];
177
+ float z2=xyz2[(i*m+j2)*3+2];
178
+ float g=grad_dist1[i*n+j]*2;
179
+ atomicAdd(&(grad_xyz1[(i*n+j)*3+0]),g*(x1-x2));
180
+ atomicAdd(&(grad_xyz1[(i*n+j)*3+1]),g*(y1-y2));
181
+ atomicAdd(&(grad_xyz1[(i*n+j)*3+2]),g*(z1-z2));
182
+ atomicAdd(&(grad_xyz2[(i*m+j2)*3+0]),-(g*(x1-x2)));
183
+ atomicAdd(&(grad_xyz2[(i*m+j2)*3+1]),-(g*(y1-y2)));
184
+ atomicAdd(&(grad_xyz2[(i*m+j2)*3+2]),-(g*(z1-z2)));
185
+ }
186
+ }
187
+ }
188
+
189
+ void ChamferDistanceGradKernelLauncher(
190
+ const int b, const int n,
191
+ const float* xyz1,
192
+ const int m,
193
+ const float* xyz2,
194
+ const float* grad_dist1,
195
+ const int* idx1,
196
+ const float* grad_dist2,
197
+ const int* idx2,
198
+ float* grad_xyz1,
199
+ float* grad_xyz2)
200
+ {
201
+ cudaMemset(grad_xyz1, 0, b*n*3*4);
202
+ cudaMemset(grad_xyz2, 0, b*m*3*4);
203
+ ChamferDistanceGradKernel<<<dim3(1,16,1), 256>>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2);
204
+ ChamferDistanceGradKernel<<<dim3(1,16,1), 256>>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1);
205
+
206
+ cudaError_t err = cudaGetLastError();
207
+ if (err != cudaSuccess)
208
+ printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err));
209
+ }
evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+
4
+ from torch.utils.cpp_extension import load
5
+ cd = load(name="build",
6
+ sources=["pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp",
7
+ "pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu"],
8
+ build_directory="pyTorchChamferDistance/build")
9
+
10
+ class ChamferDistanceFunction(torch.autograd.Function):
11
+ @staticmethod
12
+ def forward(ctx, xyz1, xyz2):
13
+ batchsize, n, _ = xyz1.size()
14
+ _, m, _ = xyz2.size()
15
+ xyz1 = xyz1.contiguous()
16
+ xyz2 = xyz2.contiguous()
17
+ dist1 = torch.zeros(batchsize, n)
18
+ dist2 = torch.zeros(batchsize, m)
19
+
20
+ idx1 = torch.zeros(batchsize, n, dtype=torch.int)
21
+ idx2 = torch.zeros(batchsize, m, dtype=torch.int)
22
+
23
+ if not xyz1.is_cuda:
24
+ cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
25
+ else:
26
+ dist1 = dist1.cuda()
27
+ dist2 = dist2.cuda()
28
+ idx1 = idx1.cuda()
29
+ idx2 = idx2.cuda()
30
+ cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2)
31
+
32
+ ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
33
+
34
+ return dist1, dist2, idx1, idx2
35
+
36
+ @staticmethod
37
+ def backward(ctx, graddist1, graddist2, *args):
38
+ xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
39
+
40
+ graddist1 = graddist1.contiguous()
41
+ graddist2 = graddist2.contiguous()
42
+
43
+ gradxyz1 = torch.zeros(xyz1.size())
44
+ gradxyz2 = torch.zeros(xyz2.size())
45
+
46
+ if not graddist1.is_cuda:
47
+ cd.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
48
+ else:
49
+ gradxyz1 = gradxyz1.cuda()
50
+ gradxyz2 = gradxyz2.cuda()
51
+ cd.backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
52
+
53
+ return gradxyz1, gradxyz2
54
+
55
+
56
+ class ChamferDistance(torch.nn.Module):
57
+ def forward(self, xyz1, xyz2):
58
+ return ChamferDistanceFunction.apply(xyz1, xyz2)
finetune_diffusion.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cd scripts
2
+ CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' torchrun --master_port 15003 --nproc_per_node=8 \
3
+ train_triplane_diffusion.py \
4
+ --configs ../configs/finetune_triplane_diffusion.yaml \
5
+ --accum_iter 2 \
6
+ --output_dir ../output/finetune_dm/lowres_chair \
7
+ --log_dir ../output/finetune_dm/lowres_chair --num_workers 8 \
8
+ --batch_size 22 \
9
+ --blr 1e-4 \
10
+ --epochs 500 \
11
+ --dist_eval \
12
+ --warmup_epochs 20 \
13
+ --ae-pth ../output/ae/chair/best-checkpoint.pth \
14
+ --category chair \
15
+ --finetune \
16
+ --finetune-pth ../output/dm/chair/best-checkpoint.pth \
17
+ --data-pth ../data \
18
+ --replica 5
models/TriplaneVAE.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import sys,os
3
+ sys.path.append("..")
4
+ import torch
5
+ from datasets import build_dataset
6
+ from configs.config_utils import CONFIG
7
+ from torch.utils.data import DataLoader
8
+ from models.modules import PointEmbed
9
+ from models.modules import ConvPointnet_Encoder,ConvPointnet_Decoder
10
+ import numpy as np
11
+
12
+ class TriplaneVAE(nn.Module):
13
+ def __init__(self,opt):
14
+ super().__init__()
15
+ self.point_embedder=PointEmbed(hidden_dim=opt['point_emb_dim'])
16
+
17
+ encoder_args=opt['encoder']
18
+ decoder_args=opt['decoder']
19
+ self.encoder=ConvPointnet_Encoder(c_dim=encoder_args['plane_latent_dim'],dim=opt['point_emb_dim'],latent_dim=encoder_args['latent_dim'],
20
+ plane_resolution=encoder_args['plane_reso'],unet_kwargs=encoder_args['unet'],unet=True,padding=opt['padding'])
21
+ self.decoder=ConvPointnet_Decoder(latent_dim=decoder_args['latent_dim'],query_emb_dim=decoder_args['query_emb_dim'],
22
+ hidden_dim=decoder_args['hidden_dim'],unet_kwargs=decoder_args['unet'],n_blocks=decoder_args['n_blocks'],
23
+ plane_resolution=decoder_args['plane_reso'],padding=opt['padding'])
24
+
25
+ def forward(self,p,query):
26
+ '''
27
+ :param p: surface points cloud of shape B,N,3
28
+ :param query: sample points of shape B,N,3
29
+ :return:
30
+ '''
31
+ point_emb=self.point_embedder(p)
32
+ query_emb=self.point_embedder(query)
33
+ kl,plane_feat,means,logvars=self.encoder(p,point_emb)
34
+ if self.training:
35
+ if np.random.random()<0.5:
36
+ '''randomly sacle the triplane, and conduct triplane diffusion on 64x64x64 plane, promote robustness'''
37
+ plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=0.5,mode="bilinear")
38
+ plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=2,mode="bilinear")
39
+ # if self.training:
40
+ # if np.random.random()<0.5:
41
+ # means = torch.nn.functional.interpolate(means, scale_factor=0.5, mode="bilinear")
42
+ # vars=torch.exp(logvars)
43
+ # vars = torch.nn.functional.interpolate(vars, scale_factor=0.5, mode="bilinear")
44
+ # new_logvars=torch.log(vars)
45
+ # posterior = DiagonalGaussianDistribution(means, new_logvars)
46
+ # plane_feat=posterior.sample()
47
+ # plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=2,mode='bilinear')
48
+
49
+ # mean_scale=torch.nn.functional.interpolate(means, scale_factor=0.5, mode="bilinear")
50
+ # vars = torch.exp(logvars)
51
+ # vars_scale = torch.nn.functional.interpolate(vars, scale_factor=0.5, mode="bilinear")/4
52
+ # logvars_scale=torch.log(vars_scale)
53
+ # scale_noise=torch.randn(mean_scale.shape).to(mean_scale.device)
54
+ # plane_feat_scale2=mean_scale+torch.exp(0.5*logvars_scale)*scale_noise
55
+ # plane_feat=torch.nn.functional.interpolate(plane_feat_scale2,scale_factor=2,mode='bilinear')
56
+ o=self.decoder(plane_feat,query,query_emb)
57
+
58
+ return {'logits':o,'kl':kl}
59
+
60
+
61
+ def decode(self,plane_feature,query):
62
+ query_embedding=self.point_embedder(query)
63
+ o=self.decoder(plane_feature,query,query_embedding)
64
+
65
+ return o
66
+
67
+ def encode(self,p):
68
+ point_emb = self.point_embedder(p)
69
+ kl, plane_feat,mean,logvar = self.encoder(p, point_emb)
70
+ '''p is point cloud of B,N,3'''
71
+ return plane_feat,kl,mean,logvar
72
+
73
+ if __name__=="__main__":
74
+ configs=CONFIG("../configs/train_triplane_vae_64.yaml")
75
+ config=configs.config
76
+ dataset_config=config['datasets']
77
+ model_config=config["model"]
78
+ dataset=build_dataset("train",dataset_config)
79
+ dataset.__getitem__(0)
80
+ dataloader=DataLoader(
81
+ dataset=dataset,
82
+ batch_size=10,
83
+ shuffle=True,
84
+ num_workers=2,
85
+ )
86
+ net=TriplaneVAE(model_config).float().cuda()
87
+ for idx,data_batch in enumerate(dataloader):
88
+ if idx==1:
89
+ break
90
+ surface=data_batch['surface'].float().cuda()
91
+ query=data_batch['points'].float().cuda()
92
+ net(surface,query)
93
+
94
+
models/Triplane_Diffusion.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from models.modules.resunet import ResUnet_DirectAttenMultiImg_Cond
4
+ from models.modules.parpoints_encoder import ParPoint_Encoder
5
+ from models.modules.PointEMB import PointEmbed
6
+ from models.modules.utils import StackedRandomGenerator
7
+ from models.modules.diffusion_sampler import edm_sampler
8
+ from models.modules.encoder import DiagonalGaussianDistribution
9
+ import numpy as np
10
+ class EDMLoss_MultiImgCond:
11
+ def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5,use_par=False):
12
+ self.P_mean = P_mean
13
+ self.P_std = P_std
14
+ self.sigma_data = sigma_data
15
+ self.use_par=use_par
16
+
17
+ def __call__(self, net, data_batch, classifier_free=False):
18
+ inputs = data_batch['input']
19
+ image=data_batch['image']
20
+ proj_mat=data_batch['proj_mat']
21
+ valid_frames=data_batch['valid_frames']
22
+ par_points=data_batch["par_points"]
23
+ category_code=data_batch["category_code"]
24
+ rnd_normal = torch.randn([inputs.shape[0], 1, 1, 1], device=inputs.device)
25
+
26
+ sigma = (rnd_normal * self.P_std + self.P_mean).exp() #B,1,1,1
27
+ weight = (sigma ** 2 + self.sigma_data ** 2) / (self.sigma_data * sigma) ** 2
28
+ y=inputs
29
+
30
+ n = torch.randn_like(y) * sigma
31
+
32
+ # if classifier_free and np.random.random()<0.5:
33
+ # net.par_feat=torch.zeros((inputs.shape[0],32,inputs.shape[2],inputs.shape[3])).float().to(inputs.device)
34
+ if classifier_free and np.random.random()<0.5:
35
+ image=torch.zeros_like(image).float().cuda()
36
+ net.module.extract_img_feat(image)
37
+ net.module.set_proj_matrix(proj_mat)
38
+ net.module.set_valid_frames(valid_frames)
39
+ net.module.set_category_code(category_code)
40
+ if self.use_par:
41
+ net.module.extract_point_feat(par_points)
42
+
43
+ D_yn = net(y + n,sigma)
44
+ loss = weight * ((D_yn - y) ** 2)
45
+ return loss
46
+
47
+ class Triplane_Diff_MultiImgCond_EDM(nn.Module):
48
+ def __init__(self,opt):
49
+ super().__init__()
50
+ self.diff_reso=opt['diff_reso']
51
+ self.diff_dim=opt['output_channel']
52
+ self.use_cat_embedding=opt['use_cat_embedding']
53
+ self.use_fp16=False
54
+ self.sigma_data=0.5
55
+ self.sigma_max=float("inf")
56
+ self.sigma_min=0
57
+ self.use_par=opt['use_par']
58
+ self.triplane_padding=opt['triplane_padding']
59
+ self.block_type=opt['block_type']
60
+ #self.use_bn=opt['use_bn']
61
+ if opt['backbone']=="resunet_multiimg_direct_atten":
62
+ self.denoise_model=ResUnet_DirectAttenMultiImg_Cond(channel=opt['input_channel'],
63
+ output_channel=opt['output_channel'],use_par=opt['use_par'],par_channel=opt['par_channel'],
64
+ img_in_channels=opt['img_in_channels'],vit_reso=opt['vit_reso'],triplane_padding=self.triplane_padding,
65
+ norm=opt['norm'],use_cat_embedding=self.use_cat_embedding,block_type=self.block_type)
66
+ else:
67
+ raise NotImplementedError
68
+ if opt['use_par']: #use partial point cloud as inputs
69
+ par_emb_dim = opt['par_emb_dim']
70
+ par_args = opt['par_point_encoder']
71
+ self.point_embedder = PointEmbed(hidden_dim=par_emb_dim)
72
+ self.par_points_encoder = ParPoint_Encoder(c_dim=par_args['plane_latent_dim'], dim=par_emb_dim,
73
+ plane_resolution=par_args['plane_reso'],
74
+ unet_kwargs=par_args['unet'])
75
+ self.unflatten = torch.nn.Unflatten(1, (16, 16))
76
+ def prepare_data(self,data_batch):
77
+ #par_points = data_batch['par_points'].to(device, non_blocking=True)
78
+ device=torch.device("cuda")
79
+ means, logvars = data_batch['triplane_mean'].to(device, non_blocking=True), data_batch['triplane_logvar'].to(
80
+ device, non_blocking=True)
81
+ distribution = DiagonalGaussianDistribution(means, logvars)
82
+ plane_feat = distribution.sample()
83
+
84
+ image=data_batch["image"].to(device)
85
+ proj_mat = data_batch['proj_mat'].to(device, non_blocking=True)
86
+ valid_frames=data_batch["valid_frames"].to(device,non_blocking=True)
87
+ par_points=data_batch["par_points"].to(device,non_blocking=True)
88
+ category_code=data_batch["category_code"].to(device,non_blocking=True)
89
+ input_dict = {"input": plane_feat.float(),
90
+ "image": image.float(),
91
+ "par_points":par_points.float(),
92
+ "proj_mat":proj_mat.float(),
93
+ "category_code":category_code.float(),
94
+ "valid_frames":valid_frames.float()} # TODO: add image and proj matrix
95
+
96
+ return input_dict
97
+
98
+ def prepare_sample_data(self,data_batch):
99
+ device=torch.device("cuda")
100
+ image=data_batch['image'].to(device, non_blocking=True)
101
+ proj_mat = data_batch['proj_mat'].to(device, non_blocking=True)
102
+ valid_frames = data_batch["valid_frames"].to(device, non_blocking=True)
103
+ par_points = data_batch["par_points"].to(device, non_blocking=True)
104
+ category_code=data_batch["category_code"].to(device,non_blocking=True)
105
+ sample_dict={
106
+ "image":image.float(),
107
+ "proj_mat":proj_mat.float(),
108
+ "valid_frames":valid_frames.float(),
109
+ "category_code":category_code.float(),
110
+ "par_points":par_points.float(),
111
+ }
112
+ return sample_dict
113
+
114
+ def prepare_eval_data(self,data_batch):
115
+ device=torch.device("cuda")
116
+ samples=data_batch["points"].to(device, non_blocking=True)
117
+ labels=data_batch['labels'].to(device,non_blocking=True)
118
+
119
+ eval_dict={
120
+ "samples":samples,
121
+ "labels":labels,
122
+ }
123
+ return eval_dict
124
+
125
+ def extract_point_feat(self,par_points):
126
+ par_emb=self.point_embedder(par_points)
127
+ self.par_feat=self.par_points_encoder(par_points,par_emb)
128
+
129
+ def extract_img_feat(self,image):
130
+ self.image_emb=image
131
+
132
+ def set_proj_matrix(self,proj_matrix):
133
+ self.proj_matrix=proj_matrix
134
+
135
+ def set_valid_frames(self,valid_frames):
136
+ self.valid_frames=valid_frames
137
+
138
+ def set_category_code(self,category_code):
139
+ self.category_code=category_code
140
+
141
+ def forward(self, x, sigma,force_fp32=False):
142
+ x = x.to(torch.float32)
143
+ sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) #B,1,1,1
144
+ dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
145
+
146
+ c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
147
+ c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
148
+ c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
149
+ c_noise = sigma.log() / 4 #B,1,1,1, need to check how to add embedding into unet
150
+
151
+ if self.use_par:
152
+ F_x = self.denoise_model((c_in * x).to(dtype), c_noise.flatten(), self.image_emb, self.proj_matrix,
153
+ self.valid_frames,self.category_code,self.par_feat)
154
+ else:
155
+ F_x = self.denoise_model((c_in * x).to(dtype), c_noise.flatten(),self.image_emb,self.proj_matrix,
156
+ self.valid_frames,self.category_code)
157
+ assert F_x.dtype == dtype
158
+ D_x = c_skip * x + c_out * F_x.to(torch.float32)
159
+ return D_x
160
+
161
+ def round_sigma(self, sigma):
162
+ return torch.as_tensor(sigma)
163
+
164
+ @torch.no_grad()
165
+ def sample(self, input_batch, batch_seeds=None,ret_all=False,num_steps=18):
166
+ img_cond=input_batch['image']
167
+ proj_mat=input_batch['proj_mat']
168
+ valid_frames=input_batch["valid_frames"]
169
+ category_code=input_batch["category_code"]
170
+ if img_cond is not None:
171
+ batch_size, device = img_cond.shape[0], img_cond.device
172
+ if batch_seeds is None:
173
+ batch_seeds = torch.arange(batch_size)
174
+ else:
175
+ device = batch_seeds.device
176
+ batch_size = batch_seeds.shape[0]
177
+
178
+ self.extract_img_feat(img_cond)
179
+ self.set_proj_matrix(proj_mat)
180
+ self.set_valid_frames(valid_frames)
181
+ self.set_category_code(category_code)
182
+ if self.use_par:
183
+ par_points=input_batch["par_points"]
184
+ self.extract_point_feat(par_points)
185
+ rnd = StackedRandomGenerator(device, batch_seeds)
186
+ latents = rnd.randn([batch_size, self.diff_dim, self.diff_reso*3,self.diff_reso], device=device)
187
+
188
+ return edm_sampler(self, latents, randn_like=rnd.randn_like,ret_all=ret_all,sigma_min=0.002, sigma_max=80,num_steps=num_steps)
189
+
190
+
models/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .TriplaneVAE import TriplaneVAE
2
+ from .Triplane_Diffusion import Triplane_Diff_MultiImgCond_EDM
3
+ from .Triplane_Diffusion import EDMLoss_MultiImgCond
4
+ #from .Point_Diffusion_EDM import PointEDM,EDMLoss_PointAug
5
+
6
+ def get_model(model_args):
7
+ if model_args['type']=="TriVAE":
8
+ model=TriplaneVAE(model_args)
9
+ elif model_args['type']=="triplane_diff_multiimg_cond":
10
+ model=Triplane_Diff_MultiImgCond_EDM(model_args)
11
+ else:
12
+ raise NotImplementedError
13
+ return model
14
+
15
+ def get_criterion(cri_args):
16
+ if cri_args['type']=="EDMLoss_MultiImgCond":
17
+ criterion=EDMLoss_MultiImgCond(use_par=cri_args['use_par'])
18
+ else:
19
+ raise NotImplementedError
20
+ return criterion
models/modules/PointEMB.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ import numpy as np
4
+
5
+ class PointEmbed(nn.Module):
6
+ def __init__(self, hidden_dim=48):
7
+ super().__init__()
8
+
9
+ assert hidden_dim % 6 == 0
10
+
11
+ self.embedding_dim = hidden_dim
12
+ e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
13
+ e = torch.stack([
14
+ torch.cat([e, torch.zeros(self.embedding_dim // 6),
15
+ torch.zeros(self.embedding_dim // 6)]),
16
+ torch.cat([torch.zeros(self.embedding_dim // 6), e,
17
+ torch.zeros(self.embedding_dim // 6)]),
18
+ torch.cat([torch.zeros(self.embedding_dim // 6),
19
+ torch.zeros(self.embedding_dim // 6), e]),
20
+ ])
21
+ self.register_buffer('basis', e) # 3 x 24
22
+
23
+
24
+ @staticmethod
25
+ def embed(input, basis):
26
+ projections = torch.einsum(
27
+ 'bnd,de->bne', input, basis) # N,24
28
+ embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
29
+ return embeddings
30
+
31
+ def forward(self, input):
32
+ # input: B x N x 3
33
+ embed = self.embed(input, self.basis)
34
+ return embed
models/modules/Positional_Embedding.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ class PositionalEmbedding(torch.nn.Module):
3
+ def __init__(self, num_channels, max_positions=10000, endpoint=False):
4
+ super().__init__()
5
+ self.num_channels = num_channels
6
+ self.max_positions = max_positions
7
+ self.endpoint = endpoint
8
+
9
+ def forward(self, x):
10
+ freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device)
11
+ freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
12
+ freqs = (1 / self.max_positions) ** freqs
13
+ x = x.ger(freqs.to(x.dtype))
14
+ x = torch.cat([x.cos(), x.sin()], dim=1)
15
+ return x
models/modules/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .encoder import ConvPointnet_Encoder
2
+ from .resnet_block import ResnetBlockFC
3
+ from .unet import UNet,RollOut_Conv
4
+ from .PointEMB import PointEmbed
5
+ from .decoder import ConvPointnet_Decoder
models/modules/decoder.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch_scatter import scatter_mean, scatter_max
5
+ from .unet import UNet
6
+ from .resnet_block import ResnetBlockFC
7
+ import numpy as np
8
+
9
+ class ConvPointnet_Decoder(nn.Module):
10
+ ''' PointNet-based encoder network with ResNet blocks for each point.
11
+ Number of input points are fixed.
12
+
13
+ Args:
14
+ c_dim (int): dimension of latent code c
15
+ dim (int): input points dimension
16
+ hidden_dim (int): hidden dimension of the network
17
+ scatter_type (str): feature aggregation when doing local pooling
18
+ unet (bool): weather to use U-Net
19
+ unet_kwargs (str): U-Net parameters
20
+ plane_resolution (int): defined resolution for plane feature
21
+ plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
22
+ padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
23
+ n_blocks (int): number of blocks ResNetBlockFC layers
24
+ '''
25
+
26
+ def __init__(self, latent_dim=32,query_emb_dim=51,hidden_dim=128, unet_kwargs=None,
27
+ plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):
28
+ super().__init__()
29
+
30
+ self.latent_dim=32
31
+ self.actvn = nn.ReLU()
32
+
33
+ self.unet = UNet(unet_kwargs['output_dim'], in_channels=latent_dim, **unet_kwargs)
34
+
35
+ self.fc_c=nn.ModuleList
36
+ self.reso_plane = plane_resolution
37
+ self.plane_type = plane_type
38
+ self.padding = padding
39
+ self.n_blocks=n_blocks
40
+
41
+ self.fc_c = nn.ModuleList([
42
+ nn.Linear(latent_dim*3, hidden_dim) for i in range(n_blocks)
43
+ ])
44
+ self.fc_p=nn.Linear(query_emb_dim,hidden_dim)
45
+ self.fc_out=nn.Linear(hidden_dim,1)
46
+
47
+ self.blocks = nn.ModuleList([
48
+ ResnetBlockFC(hidden_dim) for i in range(n_blocks)
49
+ ])
50
+
51
+ def forward(self, plane_features,query,query_emb): # , query2):
52
+ plane_feature=self.unet(plane_features)
53
+ H,W=plane_feature.shape[2:4]
54
+ xz_feat,xy_feat,yz_feat=torch.split(plane_feature,dim=2,split_size_or_sections=H//3)
55
+ xz_sample_feat=self.sample_plane_feature(query,xz_feat,'xz')
56
+ xy_sample_feat=self.sample_plane_feature(query,xy_feat,'xy')
57
+ yz_sample_feat=self.sample_plane_feature(query,yz_feat,'yz')
58
+
59
+ sample_feat=torch.cat([xz_sample_feat,xy_sample_feat,yz_sample_feat],dim=1)
60
+ sample_feat=sample_feat.transpose(1,2)
61
+
62
+ net=self.fc_p(query_emb)
63
+ for i in range(self.n_blocks):
64
+ net=net+self.fc_c[i](sample_feat)
65
+ net=self.blocks[i](net)
66
+ out=self.fc_out(self.actvn(net)).squeeze(-1)
67
+ return out
68
+
69
+
70
+ def normalize_coordinate(self, p, padding=0.1, plane='xz'):
71
+ ''' Normalize coordinate to [0, 1] for unit cube experiments
72
+
73
+ Args:
74
+ p (tensor): point
75
+ padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
76
+ plane (str): plane feature type, ['xz', 'xy', 'yz']
77
+ '''
78
+ if plane == 'xz':
79
+ xy = p[:, :, [0, 2]]
80
+ elif plane == 'xy':
81
+ xy = p[:, :, [0, 1]]
82
+ else:
83
+ xy = p[:, :, [1, 2]]
84
+ #print("origin",torch.amin(xy), torch.amax(xy))
85
+ xy=xy/2 #xy is originally -1 ~ 1
86
+ xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
87
+ xy_new = xy_new + 0.5 # range (0, 1)
88
+ #print("scale",torch.amin(xy_new),torch.amax(xy_new))
89
+
90
+ # f there are outliers out of the range
91
+ if xy_new.max() >= 1:
92
+ xy_new[xy_new >= 1] = 1 - 10e-6
93
+ if xy_new.min() < 0:
94
+ xy_new[xy_new < 0] = 0.0
95
+ return xy_new
96
+
97
+ def coordinate2index(self, x, reso):
98
+ ''' Normalize coordinate to [0, 1] for unit cube experiments.
99
+ Corresponds to our 3D model
100
+
101
+ Args:
102
+ x (tensor): coordinate
103
+ reso (int): defined resolution
104
+ coord_type (str): coordinate type
105
+ '''
106
+ x = (x * reso).long()
107
+ index = x[:, :, 0] + reso * x[:, :, 1]
108
+ index = index[:, None, :]
109
+ return index
110
+
111
+ # uses values from plane_feature and pixel locations from vgrid to interpolate feature
112
+ def sample_plane_feature(self, query, plane_feature, plane):
113
+ xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding)
114
+ xy = xy[:, :, None].float()
115
+ vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
116
+ sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True,
117
+ mode='bilinear').squeeze(-1)
118
+ return sampled_feat
119
+
120
+
121
+
models/modules/diffusion_sampler.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ def edm_sampler(
5
+ net, latents, randn_like=torch.randn_like,
6
+ num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
7
+ # S_churn=40, S_min=0.05, S_max=50, S_noise=1.003,
8
+ S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, ret_all=False
9
+ ):
10
+ # Adjust noise levels based on what's supported by the network.
11
+ sigma_min = max(sigma_min, net.sigma_min)
12
+ sigma_max = min(sigma_max, net.sigma_max)
13
+
14
+ # Time step discretization.
15
+ step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
16
+ t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
17
+ t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
18
+
19
+ # Main sampling loop.
20
+ x_next = latents.to(torch.float64) * t_steps[0]
21
+ all_x=[]
22
+ for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
23
+ x_cur = x_next
24
+
25
+ # Increase noise temporarily.
26
+ gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
27
+ t_hat = net.round_sigma(t_cur + gamma * t_cur)
28
+ x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
29
+
30
+ # Euler step.
31
+ denoised = net(x_hat, t_hat).to(torch.float64)
32
+ d_cur = (x_hat - denoised) / t_hat
33
+ x_next = x_hat + (t_next - t_hat) * d_cur
34
+
35
+ # Apply 2nd order correction.
36
+ if i < num_steps - 1:
37
+ denoised = net(x_next, t_next).to(torch.float64)
38
+ d_prime = (x_next - denoised) / t_next
39
+ x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
40
+ all_x.append(x_next.clone()/(t_next**2+1).sqrt())
41
+
42
+ if ret_all:
43
+ return x_next,all_x
44
+
45
+ return x_next
46
+
47
+ def edm_sampler_cond(
48
+ net, latents,cond_points, randn_like=torch.randn_like,
49
+ num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
50
+ # S_churn=40, S_min=0.05, S_max=50, S_noise=1.003,
51
+ S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, ret_all=False
52
+ ):
53
+ # Adjust noise levels based on what's supported by the network.
54
+ sigma_min = max(sigma_min, net.sigma_min)
55
+ sigma_max = min(sigma_max, net.sigma_max)
56
+
57
+ # Time step discretization.
58
+ step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
59
+ t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
60
+ t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
61
+
62
+ # Main sampling loop.
63
+ x_next = latents.to(torch.float64) * t_steps[0]
64
+ all_x=[]
65
+ for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
66
+ x_cur = x_next
67
+
68
+ # Increase noise temporarily.
69
+ gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
70
+ t_hat = net.round_sigma(t_cur + gamma * t_cur)
71
+ x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
72
+
73
+ # Euler step.
74
+ denoised = net(x_hat, t_hat,cond_points).to(torch.float64)
75
+ d_cur = (x_hat - denoised) / t_hat
76
+ x_next = x_hat + (t_next - t_hat) * d_cur
77
+
78
+ # Apply 2nd order correction.
79
+ if i < num_steps - 1:
80
+ denoised = net(x_next, t_next,cond_points).to(torch.float64)
81
+ d_prime = (x_next - denoised) / t_next
82
+ x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
83
+ all_x.append(x_next.clone()/(t_next**2+1).sqrt())
84
+
85
+ if ret_all:
86
+ return x_next,all_x
87
+
88
+ return x_next
89
+
models/modules/encoder.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch_scatter import scatter_mean, scatter_max
5
+ from .unet import UNet
6
+ from .resnet_block import ResnetBlockFC
7
+ import numpy as np
8
+
9
+ class DiagonalGaussianDistribution(object):
10
+ def __init__(self, mean, logvar, deterministic=False):
11
+ self.mean = mean
12
+ self.logvar = logvar
13
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
14
+ self.deterministic = deterministic
15
+ self.std = torch.exp(0.5 * self.logvar)
16
+ self.var = torch.exp(self.logvar)
17
+ if self.deterministic:
18
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.mean.device)
19
+
20
+ def sample(self):
21
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.mean.device)
22
+ return x
23
+
24
+ def kl(self, other=None):
25
+ if self.deterministic:
26
+ return torch.Tensor([0.])
27
+ else:
28
+ if other is None:
29
+ return 0.5 * torch.mean(torch.pow(self.mean, 2)
30
+ + self.var - 1.0 - self.logvar,
31
+ dim=[1, 2,3])
32
+ else:
33
+ return 0.5 * torch.mean(
34
+ torch.pow(self.mean - other.mean, 2) / other.var
35
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
36
+ dim=[1, 2, 3])
37
+
38
+ def nll(self, sample, dims=[1,2,3]):
39
+ if self.deterministic:
40
+ return torch.Tensor([0.])
41
+ logtwopi = np.log(2.0 * np.pi)
42
+ return 0.5 * torch.sum(
43
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
44
+ dim=dims)
45
+
46
+ def mode(self):
47
+ return self.mean
48
+
49
+ class ConvPointnet_Encoder(nn.Module):
50
+ ''' PointNet-based encoder network with ResNet blocks for each point.
51
+ Number of input points are fixed.
52
+
53
+ Args:
54
+ c_dim (int): dimension of latent code c
55
+ dim (int): input points dimension
56
+ hidden_dim (int): hidden dimension of the network
57
+ scatter_type (str): feature aggregation when doing local pooling
58
+ unet (bool): weather to use U-Net
59
+ unet_kwargs (str): U-Net parameters
60
+ plane_resolution (int): defined resolution for plane feature
61
+ plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
62
+ padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
63
+ n_blocks (int): number of blocks ResNetBlockFC layers
64
+ '''
65
+
66
+ def __init__(self, c_dim=128, dim=3, hidden_dim=128,latent_dim=32, scatter_type='max',
67
+ unet=False, unet_kwargs=None,
68
+ plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):
69
+ super().__init__()
70
+ self.c_dim = c_dim
71
+
72
+ self.fc_pos = nn.Linear(dim, 2 * hidden_dim)
73
+ self.blocks = nn.ModuleList([
74
+ ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks)
75
+ ])
76
+ self.fc_c = nn.Linear(hidden_dim, c_dim)
77
+
78
+ self.actvn = nn.ReLU()
79
+ self.hidden_dim = hidden_dim
80
+
81
+ if unet:
82
+ self.unet = UNet(unet_kwargs['output_dim'], in_channels=c_dim, **unet_kwargs)
83
+ else:
84
+ self.unet = None
85
+
86
+ self.reso_plane = plane_resolution
87
+ self.plane_type = plane_type
88
+ self.padding = padding
89
+
90
+ if scatter_type == 'max':
91
+ self.scatter = scatter_max
92
+ elif scatter_type == 'mean':
93
+ self.scatter = scatter_mean
94
+
95
+ self.mean_fc = nn.Conv2d(unet_kwargs['output_dim'], latent_dim,kernel_size=1)
96
+ self.logvar_fc = nn.Conv2d(unet_kwargs['output_dim'], latent_dim,kernel_size=1)
97
+
98
+ # takes in "p": point cloud and "query": sdf_xyz
99
+ # sample plane features for unlabeled_query as well
100
+ def forward(self, p,point_emb): # , query2):
101
+ batch_size, T, D = p.size()
102
+ #print('origin',torch.amin(p[0],dim=0),torch.amax(p[0],dim=0))
103
+ # acquire the index for each point
104
+ coord = {}
105
+ index = {}
106
+ if 'xz' in self.plane_type:
107
+ coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
108
+ index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)
109
+ if 'xy' in self.plane_type:
110
+ coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
111
+ index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)
112
+ if 'yz' in self.plane_type:
113
+ coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
114
+ index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)
115
+ net = self.fc_pos(point_emb)
116
+
117
+ net = self.blocks[0](net)
118
+ for block in self.blocks[1:]:
119
+ pooled = self.pool_local(coord, index, net)
120
+ net = torch.cat([net, pooled], dim=2)
121
+ net = block(net)
122
+
123
+ c = self.fc_c(net)
124
+ #print(c.shape)
125
+
126
+ fea = {}
127
+ plane_feat_sum = 0
128
+ # second_sum = 0
129
+ if 'xz' in self.plane_type:
130
+ fea['xz'] = self.generate_plane_features(p, c,
131
+ plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
132
+ if 'xy' in self.plane_type:
133
+ fea['xy'] = self.generate_plane_features(p, c, plane='xy')
134
+ if 'yz' in self.plane_type:
135
+ fea['yz'] = self.generate_plane_features(p, c, plane='yz')
136
+ cat_feature = torch.cat([fea['xz'], fea['xy'], fea['yz']],
137
+ dim=2) # concat at row dimension
138
+ #print(cat_feature.shape)
139
+ plane_feat=self.unet(cat_feature)
140
+
141
+ mean=self.mean_fc(plane_feat)
142
+ logvar=self.logvar_fc(plane_feat)
143
+
144
+ posterior = DiagonalGaussianDistribution(mean, logvar)
145
+ x = posterior.sample()
146
+ kl = posterior.kl()
147
+
148
+ return kl, x, mean, logvar
149
+
150
+
151
+ def normalize_coordinate(self, p, padding=0.1, plane='xz'):
152
+ ''' Normalize coordinate to [0, 1] for unit cube experiments
153
+
154
+ Args:
155
+ p (tensor): point
156
+ padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
157
+ plane (str): plane feature type, ['xz', 'xy', 'yz']
158
+ '''
159
+ if plane == 'xz':
160
+ xy = p[:, :, [0, 2]]
161
+ elif plane == 'xy':
162
+ xy = p[:, :, [0, 1]]
163
+ else:
164
+ xy = p[:, :, [1, 2]]
165
+ #print("origin",torch.amin(xy), torch.amax(xy))
166
+ xy=xy/2 #xy is originally -1 ~ 1
167
+ xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
168
+ xy_new = xy_new + 0.5 # range (0, 1)
169
+ #print("scale",torch.amin(xy_new),torch.amax(xy_new))
170
+
171
+ # f there are outliers out of the range
172
+ if xy_new.max() >= 1:
173
+ xy_new[xy_new >= 1] = 1 - 10e-6
174
+ if xy_new.min() < 0:
175
+ xy_new[xy_new < 0] = 0.0
176
+ return xy_new
177
+
178
+ def coordinate2index(self, x, reso):
179
+ ''' Normalize coordinate to [0, 1] for unit cube experiments.
180
+ Corresponds to our 3D model
181
+
182
+ Args:
183
+ x (tensor): coordinate
184
+ reso (int): defined resolution
185
+ coord_type (str): coordinate type
186
+ '''
187
+ x = (x * reso).long()
188
+ index = x[:, :, 0] + reso * x[:, :, 1]
189
+ index = index[:, None, :]
190
+ return index
191
+
192
+ # xy is the normalized coordinates of the point cloud of each plane
193
+ # I'm pretty sure the keys of xy are the same as those of index, so xy isn't needed here as input
194
+ def pool_local(self, xy, index, c):
195
+ bs, fea_dim = c.size(0), c.size(2)
196
+ keys = xy.keys()
197
+
198
+ c_out = 0
199
+ for key in keys:
200
+ # scatter plane features from points
201
+ fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane ** 2)
202
+ if self.scatter == scatter_max:
203
+ fea = fea[0]
204
+ # gather feature back to points
205
+ fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
206
+ c_out += fea
207
+ return c_out.permute(0, 2, 1)
208
+
209
+ def generate_plane_features(self, p, c, plane='xz'):
210
+ # acquire indices of features in plane
211
+ xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
212
+ index = self.coordinate2index(xy, self.reso_plane)
213
+
214
+ # scatter plane features from points
215
+ fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane ** 2)
216
+ c = c.permute(0, 2, 1) # B x 512 x T
217
+ fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
218
+ fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane,
219
+ self.reso_plane) # sparce matrix (B x 512 x reso x reso)
220
+ #print(fea_plane.shape)
221
+
222
+ return fea_plane
223
+
224
+ # sample_plane_feature function copied from /src/conv_onet/models/decoder.py
225
+ # uses values from plane_feature and pixel locations from vgrid to interpolate feature
226
+ def sample_plane_feature(self, query, plane_feature, plane):
227
+ xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding)
228
+ xy = xy[:, :, None].float()
229
+ vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
230
+ sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True,
231
+ mode='bilinear').squeeze(-1)
232
+ return sampled_feat
233
+
234
+
235
+
models/modules/image_sampler.py ADDED
@@ -0,0 +1,1046 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('../..')
3
+ import torch
4
+ import torch.nn as nn
5
+ import math
6
+ from models.modules.unet import RollOut_Conv
7
+ from einops import rearrange, reduce
8
+ MB =1024.0*1024.0
9
+ def mask_kernel(x, sigma=1):
10
+ return torch.abs(x) < sigma #if the distance is smaller than the kernel size, return True
11
+
12
+ def mask_kernel_close_false(x, sigma=1):
13
+ return torch.abs(x) > sigma #if the distance is smaller than the kernel size, return False
14
+
15
+ class Image_Local_Sampler(nn.Module):
16
+ def __init__(self,reso,padding=0.1,in_channels=1280,out_channels=512):
17
+ super().__init__()
18
+ self.triplane_reso=reso
19
+ self.padding=padding
20
+ self.get_triplane_coord()
21
+ self.img_proj=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=1)
22
+ def get_triplane_coord(self):
23
+ '''xz plane firstly, z is at the '''
24
+ x=torch.arange(self.triplane_reso)
25
+ z=torch.arange(self.triplane_reso)
26
+ X,Z=torch.meshgrid(x,z,indexing='xy')
27
+ xz_coords=torch.cat([X[:,:,None],torch.ones_like(X[:,:,None])*(self.triplane_reso-1)/2,Z[:,:,None]],dim=-1) #in xyz order
28
+
29
+ '''xy plane'''
30
+ x = torch.arange(self.triplane_reso)
31
+ y = torch.arange(self.triplane_reso)
32
+ X, Y = torch.meshgrid(x, y, indexing='xy')
33
+ xy_coords = torch.cat([X[:, :, None], Y[:, :, None],torch.ones_like(X[:, :, None])*(self.triplane_reso-1)/2], dim=-1) # in xyz order
34
+
35
+ '''yz plane'''
36
+ y = torch.arange(self.triplane_reso)
37
+ z = torch.arange(self.triplane_reso)
38
+ Y,Z = torch.meshgrid(y,z,indexing='xy')
39
+ yz_coords= torch.cat([torch.ones_like(Y[:, :, None])*(self.triplane_reso-1)/2,Y[:,:,None],Z[:,:,None]], dim=-1)
40
+
41
+ triplane_coords=torch.cat([xz_coords,xy_coords,yz_coords],dim=0)
42
+ triplane_coords=triplane_coords/(self.triplane_reso-1)
43
+ triplane_coords=(triplane_coords-0.5)*2*(1 + self.padding + 10e-6)
44
+ self.triplane_coords=triplane_coords.float().cuda()
45
+
46
+ def forward(self,image_feat,proj_mat):
47
+ image_feat=self.img_proj(image_feat)
48
+ batch_size=image_feat.shape[0]
49
+ triplane_coords=self.triplane_coords.unsqueeze(0).expand(batch_size,-1,-1,-1) #B,192,64,3
50
+ #print(torch.amin(triplane_coords),torch.amax(triplane_coords))
51
+ coord_homo=torch.cat([triplane_coords,torch.ones((batch_size,triplane_coords.shape[1],triplane_coords.shape[2],1)).float().cuda()],dim=-1)
52
+ coord_inimg=torch.einsum('bhwc,bck->bhwk',coord_homo,proj_mat.transpose(1,2))
53
+ x=coord_inimg[:,:,:,0]/coord_inimg[:,:,:,2]
54
+ y=coord_inimg[:,:,:,1]/coord_inimg[:,:,:,2]
55
+ x=(x/(224.0-1.0)-0.5)*2 #-1~1
56
+ y=(y/(224.0-1.0)-0.5)*2 #-1~1
57
+ dist=coord_inimg[:,:,:,2]
58
+
59
+ xy=torch.cat([x[:,:,:,None],y[:,:,:,None]],dim=-1)
60
+ #print(image_feat.shape,xy.shape)
61
+ sample_feat=torch.nn.functional.grid_sample(image_feat,xy,align_corners=True,mode='bilinear')
62
+ return sample_feat
63
+
64
+ def position_encoding(d_model, length):
65
+ if d_model % 2 != 0:
66
+ raise ValueError("Cannot use sin/cos positional encoding with "
67
+ "odd dim (got dim={:d})".format(d_model))
68
+ pe = torch.zeros(length, d_model)
69
+ position = torch.arange(0, length).unsqueeze(1) #length,1
70
+ div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
71
+ -(math.log(10000.0) / d_model))) #d_model//2, this is the frequency
72
+ pe[:, 0::2] = torch.sin(position.float() * div_term) #length*(d_model//2)
73
+ pe[:, 1::2] = torch.cos(position.float() * div_term)
74
+
75
+ return pe
76
+
77
+ class Image_Vox_Local_Sampler(nn.Module):
78
+ def __init__(self,reso,padding=0.1,in_channels=1280,inner_channel=128,out_channels=64,n_heads=8):
79
+ super().__init__()
80
+ self.triplane_reso=reso
81
+ self.padding=padding
82
+ self.get_vox_coord()
83
+ self.out_channels=out_channels
84
+ self.img_proj=nn.Conv2d(in_channels=in_channels,out_channels=inner_channel,kernel_size=1)
85
+
86
+ self.vox_process=nn.Sequential(
87
+ nn.Conv3d(in_channels=inner_channel,out_channels=inner_channel,kernel_size=3,padding=1,),
88
+ )
89
+ self.k=nn.Linear(in_features=inner_channel,out_features=inner_channel)
90
+ self.q=nn.Linear(in_features=inner_channel,out_features=inner_channel)
91
+ self.v=nn.Linear(in_features=inner_channel,out_features=inner_channel)
92
+ self.attn = torch.nn.MultiheadAttention(
93
+ embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
94
+
95
+ self.proj_out=nn.Conv2d(in_channels=inner_channel,out_channels=out_channels,kernel_size=1)
96
+ self.condition_pe = position_encoding(inner_channel, self.triplane_reso).unsqueeze(0)
97
+ def get_vox_coord(self):
98
+ x = torch.arange(self.triplane_reso)
99
+ y = torch.arange(self.triplane_reso)
100
+ z = torch.arange(self.triplane_reso)
101
+
102
+ X,Y,Z=torch.meshgrid(x,y,z,indexing='ij')
103
+ vox_coor=torch.cat([X[:,:,:,None],Y[:,:,:,None],Z[:,:,:,None]],dim=-1)
104
+ vox_coor=vox_coor/(self.triplane_reso-1)
105
+ vox_coor=(vox_coor-0.5)*2*(1+self.padding+10e-6)
106
+ self.vox_coor=vox_coor.view(-1,3).float().cuda()
107
+
108
+
109
+ def forward(self,triplane_feat,image_feat,proj_mat):
110
+ xz_feat,xy_feat,yz_feat=torch.split(triplane_feat,triplane_feat.shape[2]//3,dim=2) #B,C,64,64
111
+ image_feat=self.img_proj(image_feat)
112
+ batch_size=image_feat.shape[0]
113
+ vox_coords=self.vox_coor.unsqueeze(0).expand(batch_size,-1,-1) #B,64*64*64,3
114
+ vox_homo=torch.cat([vox_coords,torch.ones((batch_size,self.triplane_reso**3,1)).float().cuda()],dim=-1)
115
+ coord_inimg=torch.einsum('bhc,bck->bhk',vox_homo,proj_mat.transpose(1,2))
116
+ x=coord_inimg[:,:,0]/coord_inimg[:,:,2]
117
+ y=coord_inimg[:,:,1]/coord_inimg[:,:,2]
118
+ x=(x/(224.0-1.0)-0.5)*2 #-1~1
119
+ y=(y/(224.0-1.0)-0.5)*2 #-1~1
120
+
121
+ xy=torch.cat([x[:,:,None],y[:,:,None]],dim=-1).unsqueeze(1).contiguous() #B, 1,64**3,2
122
+ #print(image_feat.shape,xy.shape)
123
+ grid_feat=torch.nn.functional.grid_sample(image_feat,xy,align_corners=True,mode='bilinear').squeeze(2).\
124
+ view(batch_size,-1,self.triplane_reso,self.triplane_reso,self.triplane_reso) #B,C,1,64**3
125
+
126
+ grid_feat=self.vox_process(grid_feat)
127
+ xzy_grid=grid_feat.permute(0,4,2,3,1)
128
+ xz_as_query=xz_feat.permute(0,2,3,1).reshape(batch_size*self.triplane_reso**2,1,-1)
129
+ xz_as_key=xzy_grid.reshape(batch_size*self.triplane_reso**2,self.triplane_reso,-1)
130
+
131
+ xyz_grid=grid_feat.permute(0,3,2,4,1)
132
+ xy_as_query=xy_feat.permute(0,2,3,1).reshape(batch_size*self.triplane_reso**2,1,-1)
133
+ xy_as_key = xyz_grid.reshape(batch_size * self.triplane_reso ** 2, self.triplane_reso, -1)
134
+
135
+ yzx_grid = grid_feat.permute(0, 4, 3, 2, 1)
136
+ yz_as_query = yz_feat.permute(0,2,3,1).reshape(batch_size*self.triplane_reso**2,1,-1)
137
+ yz_as_key = yzx_grid.reshape(batch_size * self.triplane_reso ** 2, self.triplane_reso, -1)
138
+
139
+ query=self.q(torch.cat([xz_as_query,xy_as_query,yz_as_query],dim=0))
140
+ key=self.k(torch.cat([xz_as_key,xy_as_key,yz_as_key],dim=0))+self.condition_pe.to(xz_as_key.device)
141
+ value=self.v(torch.cat([xz_as_key,xy_as_key,yz_as_key],dim=0))+self.condition_pe.to(xz_as_key.device)
142
+
143
+ attn,_=self.attn(query,key,value)
144
+ xz_plane,xy_plane,yz_plane=torch.split(attn,dim=0,split_size_or_sections=batch_size*self.triplane_reso**2)
145
+ xz_plane=xz_plane.reshape(batch_size,self.triplane_reso,self.triplane_reso,-1).permute(0,3,1,2)
146
+ xy_plane = xy_plane.reshape(batch_size, self.triplane_reso, self.triplane_reso, -1).permute(0, 3, 1, 2)
147
+ yz_plane = yz_plane.reshape(batch_size, self.triplane_reso, self.triplane_reso, -1).permute(0, 3, 1, 2)
148
+
149
+ triplane_wImg=torch.cat([xz_plane,xy_plane,yz_plane],dim=2)
150
+ triplane_wImg=self.proj_out(triplane_wImg)
151
+ #print(triplane_wImg.shape)
152
+
153
+ return triplane_wImg
154
+
155
+ class Image_Direct_AttenwMask_Sampler(nn.Module):
156
+ def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
157
+ img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8):
158
+ super().__init__()
159
+ self.triplane_reso=reso
160
+ self.vit_reso=vit_reso
161
+ self.padding=padding
162
+ self.n_heads=n_heads
163
+ self.get_plane_expand_coord()
164
+ self.get_vit_coords()
165
+ self.out_channels=out_channels
166
+ self.kernel_func=mask_kernel
167
+ self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
168
+ self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel)
169
+ self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
170
+ self.attn = torch.nn.MultiheadAttention(
171
+ embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
172
+
173
+ self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels)
174
+ self.image_pe = position_encoding(inner_channel, self.vit_reso**2+1).unsqueeze(0).cuda().float() #1,n_img*reso*reso,channel
175
+ self.triplane_pe = position_encoding(inner_channel, 3*self.triplane_reso**2).unsqueeze(0).cuda().float()
176
+ def get_plane_expand_coord(self):
177
+ x = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
178
+ y = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
179
+ z = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
180
+
181
+ first,second,third=torch.meshgrid(x,y,z,indexing='xy')
182
+ xyz_coords=torch.stack([first,second,third],dim=-1)#reso,reso,reso,3
183
+ xyz_coords=(xyz_coords-0.5)*2*(1+self.padding+10e-6) #ordering yxz ->xyz
184
+ xzy_coords=xyz_coords.clone().permute(2,1,0,3) #ordering zxy ->xzy
185
+ yzx_coords=xyz_coords.clone().permute(2,0,1,3) #ordering zyx ->yzx
186
+
187
+ # print(xyz_coords[0,0,0],xyz_coords[0,0,1],xyz_coords[1,0,0],xyz_coords[0,1,0])
188
+ # print(xzy_coords[0, 0, 0], xzy_coords[0, 0, 1], xzy_coords[1, 0, 0], xzy_coords[0, 1, 0])
189
+ # print(yzx_coords[0, 0, 0], yzx_coords[0, 0, 1], yzx_coords[1, 0, 0], yzx_coords[0, 1, 0])
190
+
191
+ xyz_coords=xyz_coords.reshape(self.triplane_reso**3,-1)
192
+ xzy_coords=xzy_coords.reshape(self.triplane_reso**3,-1)
193
+ yzx_coords=yzx_coords.reshape(self.triplane_reso**3,-1)
194
+
195
+ coords=torch.cat([xzy_coords,xyz_coords,yzx_coords],dim=0)
196
+ self.plane_coords=coords.cuda().float()
197
+ # self.xzy_coords=xzy_coords.cuda().float() #reso**3,3
198
+ # self.xyz_coords=xyz_coords.cuda().float() #reso**3,3
199
+ # self.yzx_coords=yzx_coords.cuda().float() #reso**3,3
200
+
201
+ def get_vit_coords(self):
202
+ x=torch.arange(self.vit_reso)
203
+ y=torch.arange(self.vit_reso)
204
+
205
+ X,Y=torch.meshgrid(x,y,indexing='xy')
206
+ vit_coords=torch.stack([X,Y],dim=-1)
207
+ self.vit_coords=vit_coords.view(self.vit_reso**2,2).cuda().float()
208
+
209
+ def get_attn_mask(self,coords_proj,vit_coords,kernel_size=1.0):
210
+ '''
211
+ :param coords_proj: B,reso**3,2, in range of 0~1
212
+ :param vit_coords: B,vit_reso**2,2, in range of 0~vit_reso
213
+ :param kernel_size: 0.5, so that only one pixel will be select
214
+ :return:
215
+ '''
216
+ bs=coords_proj.shape[0]
217
+ coords_proj=coords_proj*(self.vit_reso-1)
218
+ #print(torch.amin(coords_proj[0,0:self.triplane_reso**3]),torch.amax(coords_proj[0,0:self.triplane_reso**3]))
219
+ dist=torch.cdist(coords_proj.float(),vit_coords.float())
220
+ mask=self.kernel_func(dist,sigma=kernel_size).float() #True if valid, B,3*reso**3,vit_reso**2
221
+ mask=mask.reshape(bs,3*self.triplane_reso**2,self.triplane_reso,self.vit_reso**2)
222
+ mask=torch.sum(mask,dim=2)
223
+ attn_mask=(mask==0)
224
+ return attn_mask
225
+
226
+ def forward(self,triplane_feat,image_feat,proj_mat):
227
+ #xz_feat,xy_feat,yz_feat=torch.split(triplane_feat,triplane_feat.shape[2]//3,dim=2) #B,C,64,64
228
+ batch_size=image_feat.shape[0]
229
+ #print(self.plane_coords.shape)
230
+ coords=self.plane_coords.unsqueeze(0).expand(batch_size,-1,-1)
231
+
232
+ coords_homo=torch.cat([coords,torch.ones(batch_size,self.triplane_reso**3*3,1).float().cuda()],dim=-1)
233
+ coords_inimg=torch.einsum('bhc,bck->bhk',coords_homo,proj_mat.transpose(1,2))
234
+ coords_x=coords_inimg[:,:,0]/coords_inimg[:,:,2]/(224.0-1) #0~1
235
+ coords_y=coords_inimg[:,:,1]/coords_inimg[:,:,2]/(224.0-1) #0~1
236
+ coords_x=torch.clamp(coords_x,min=0.0,max=1.0)
237
+ coords_y=torch.clamp(coords_y,min=0.0,max=1.0)
238
+ #print(torch.amin(coords_x),torch.amax(coords_x))
239
+ coords_proj=torch.stack([coords_x,coords_y],dim=-1)
240
+ vit_coords=self.vit_coords.unsqueeze(0).expand(batch_size,-1,-1)
241
+ attn_mask=torch.repeat_interleave(
242
+ self.get_attn_mask(coords_proj,vit_coords,kernel_size=1.0),self.n_heads, 0
243
+ )
244
+ attn_mask = torch.cat([torch.zeros([attn_mask.shape[0], attn_mask.shape[1], 1]).cuda().bool(), attn_mask],
245
+ dim=-1) # add global token
246
+ #print(attn_mask.shape,torch.sum(attn_mask.float()))
247
+ triplane_feat=triplane_feat.permute(0,2,3,1).view(batch_size,3*self.triplane_reso**2,-1)
248
+ #print(triplane_feat.shape,self.triplane_pe.shape)
249
+ query=self.q(triplane_feat)+self.triplane_pe
250
+ key=self.k(image_feat)+self.image_pe
251
+ value=self.v(image_feat)+self.image_pe
252
+ #print(query.shape,key.shape,value.shape)
253
+ attn,_=self.attn(query,key,value,attn_mask=attn_mask)
254
+ #print(attn.shape)
255
+ output=self.proj_out(attn).transpose(1,2).reshape(batch_size,-1,3*self.triplane_reso,self.triplane_reso)
256
+
257
+ return output
258
+
259
+ class MultiImage_Direct_AttenwMask_Sampler(nn.Module):
260
+ def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
261
+ img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8,max_nimg=5):
262
+ super().__init__()
263
+ self.triplane_reso=reso
264
+ self.vit_reso=vit_reso
265
+ self.padding=padding
266
+ self.n_heads=n_heads
267
+ self.get_plane_expand_coord()
268
+ self.get_vit_coords()
269
+ self.out_channels=out_channels
270
+ self.kernel_func=mask_kernel
271
+ self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
272
+ self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel)
273
+ self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
274
+ self.attn = torch.nn.MultiheadAttention(
275
+ embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
276
+
277
+ self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels)
278
+ self.image_pe = position_encoding(inner_channel, max_nimg*(self.vit_reso**2+1)).unsqueeze(0).cuda().float()
279
+ self.triplane_pe = position_encoding(inner_channel, 3*self.triplane_reso**2).unsqueeze(0).cuda().float()
280
+ def get_plane_expand_coord(self):
281
+ x = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
282
+ y = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
283
+ z = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
284
+
285
+ first,second,third=torch.meshgrid(x,y,z,indexing='xy')
286
+ xyz_coords=torch.stack([first,second,third],dim=-1)#reso,reso,reso,3
287
+ xyz_coords=(xyz_coords-0.5)*2*(1+self.padding+10e-6) #ordering yxz ->xyz
288
+ xzy_coords=xyz_coords.clone().permute(2,1,0,3) #ordering zxy ->xzy
289
+ yzx_coords=xyz_coords.clone().permute(2,0,1,3) #ordering zyx ->yzx
290
+
291
+ xyz_coords=xyz_coords.reshape(self.triplane_reso**3,-1)
292
+ xzy_coords=xzy_coords.reshape(self.triplane_reso**3,-1)
293
+ yzx_coords=yzx_coords.reshape(self.triplane_reso**3,-1)
294
+
295
+ coords=torch.cat([xzy_coords,xyz_coords,yzx_coords],dim=0)
296
+ self.plane_coords=coords.cuda().float()
297
+ # self.xzy_coords=xzy_coords.cuda().float() #reso**3,3
298
+ # self.xyz_coords=xyz_coords.cuda().float() #reso**3,3
299
+ # self.yzx_coords=yzx_coords.cuda().float() #reso**3,3
300
+
301
+ def get_vit_coords(self):
302
+ x=torch.arange(self.vit_reso)
303
+ y=torch.arange(self.vit_reso)
304
+
305
+ X,Y=torch.meshgrid(x,y,indexing='xy')
306
+ vit_coords=torch.stack([X,Y],dim=-1)
307
+ self.vit_coords=vit_coords.view(self.vit_reso**2,2).cuda().float()
308
+
309
+ def get_attn_mask(self,coords_proj,vit_coords,valid_frames,kernel_size=1.0):
310
+ '''
311
+ :param coords_proj: B,n_img,3*reso**3,2, in range of 0~vit_reso
312
+ :param vit_coords: B,n_img,vit_reso**2,2, in range of 0~vit_reso
313
+ :param kernel_size: 0.5, so that only one pixel will be select
314
+ :return:
315
+ '''
316
+ bs,n_img=coords_proj.shape[0],coords_proj.shape[1]
317
+ coords_proj_flat=coords_proj.reshape(bs*n_img,3*self.triplane_reso**3,2)
318
+ vit_coords_flat=vit_coords.reshape(bs*n_img,self.vit_reso**2,2)
319
+ dist=torch.cdist(coords_proj_flat.float(),vit_coords_flat.float())
320
+ mask=self.kernel_func(dist,sigma=kernel_size).float() #True if valid, B*n_img,3*reso**3,vit_reso**2
321
+ mask=mask.reshape(bs,n_img,3*self.triplane_reso**2,self.triplane_reso,self.vit_reso**2)
322
+ mask=torch.sum(mask,dim=3) #B,n_img,3*reso**2,vit_reso**2
323
+ mask=torch.cat([torch.ones(size=mask.shape[0:3]).unsqueeze(3).float().cuda(),mask],dim=-1) #B,n_img,3*reso**2,vit_reso**2+1, add global mask
324
+ mask[valid_frames == 0, :, :] = False
325
+ mask=mask.permute(0,2,1,3).reshape(bs,3*self.triplane_reso**2,-1) #B,3*reso**2,n_img*(vit_resso**2+1)
326
+ attn_mask=(mask==0) #invert the mask, False indicates valid, True indicates invalid
327
+ return attn_mask
328
+
329
+ def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
330
+ '''image feat is bs,n_img,length,channel'''
331
+ batch_size,n_img=image_feat.shape[0],image_feat.shape[1]
332
+ img_length=image_feat.shape[2]
333
+ image_feat_flat=image_feat.view(batch_size,n_img*img_length,-1)
334
+ coords=self.plane_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1)
335
+
336
+ coord_homo=torch.cat([coords,torch.ones(batch_size,n_img,self.triplane_reso**3*3,1).float().cuda()],dim=-1)
337
+ #print(coord_homo.shape,proj_mat.shape)
338
+ coord_inimg = torch.einsum('bjnc,bjck->bjnk', coord_homo, proj_mat.transpose(2, 3))
339
+ x = coord_inimg[:, :, :, 0] / coord_inimg[:, :, :, 2]
340
+ y = coord_inimg[:, :, :, 1] / coord_inimg[:, :, :, 2]
341
+ x = x/(224.0-1)
342
+ y = y/(224.0-1)
343
+ coords_x=torch.clamp(x,min=0.0,max=1.0)*(self.vit_reso-1)
344
+ coords_y=torch.clamp(y,min=0.0,max=1.0)*(self.vit_reso-1)
345
+ coords_proj=torch.stack([coords_x,coords_y],dim=-1)
346
+ vit_coords=self.vit_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1)
347
+ attn_mask=torch.repeat_interleave(
348
+ self.get_attn_mask(coords_proj,vit_coords,valid_frames,kernel_size=1.0),self.n_heads, 0
349
+ )
350
+ triplane_feat=triplane_feat.permute(0,2,3,1).view(batch_size,3*self.triplane_reso**2,-1)
351
+ query=self.q(triplane_feat)+self.triplane_pe
352
+ key=self.k(image_feat_flat)+self.image_pe
353
+ value=self.v(image_feat_flat)+self.image_pe
354
+ attn,_=self.attn(query,key,value,attn_mask=attn_mask)
355
+ output=self.proj_out(attn).transpose(1,2).reshape(batch_size,-1,3*self.triplane_reso,self.triplane_reso)
356
+
357
+ return output
358
+
359
+ class MultiImage_Fuse_Sampler(nn.Module):
360
+ def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
361
+ img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8):
362
+ super().__init__()
363
+ self.triplane_reso=reso
364
+ self.vit_reso=vit_reso
365
+ self.inner_channel=inner_channel
366
+ self.padding=padding
367
+ self.n_heads=n_heads
368
+ self.get_vox_coord()
369
+ self.get_vit_coords()
370
+ self.out_channels=out_channels
371
+ self.kernel_func=mask_kernel
372
+ self.image_unflatten=nn.Unflatten(2,(vit_reso,vit_reso))
373
+ self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
374
+ self.q=nn.Linear(in_features=triplane_in_channels*3,out_features=inner_channel)
375
+ self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
376
+
377
+ #self.cross_attn=CrossAttention(query_dim=inner_channel,heads=8,dim_head=inner_channel//8)
378
+ self.cross_attn = torch.nn.MultiheadAttention(
379
+ embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
380
+ self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels)
381
+ self.image_pe = position_encoding(inner_channel, self.vit_reso**2)[None,None,:,:].cuda().float() #1,1,length,channel
382
+ #self.image_pe = self.image_pe.reshape(1,max_nimg,self.vit_reso,self.vit_reso,inner_channel)
383
+ self.triplane_pe = position_encoding(inner_channel, self.triplane_reso ** 3).unsqueeze(0).cuda().float()
384
+
385
+ def get_vit_coords(self):
386
+ x = torch.arange(self.vit_reso)
387
+ y = torch.arange(self.vit_reso)
388
+
389
+ X, Y = torch.meshgrid(x, y, indexing='xy')
390
+ vit_coords = torch.stack([X, Y], dim=-1)
391
+ self.vit_coords = vit_coords.cuda().float() #reso,reso,2
392
+
393
+ def get_vox_coord(self):
394
+ x = torch.arange(self.triplane_reso)
395
+ y = torch.arange(self.triplane_reso)
396
+ z = torch.arange(self.triplane_reso)
397
+
398
+ X, Y, Z = torch.meshgrid(x, y, z, indexing='ij')
399
+ vox_coor = torch.cat([X[:, :, :, None], Y[:, :, :, None], Z[:, :, :, None]], dim=-1)
400
+ self.vox_index = vox_coor.view(-1, 3).long().cuda()
401
+
402
+ vox_coor = self.vox_index.float() / (self.triplane_reso - 1)
403
+ vox_coor = (vox_coor - 0.5) * 2 * (1 + self.padding + 10e-6)
404
+ self.vox_coor = vox_coor.view(-1, 3).float().cuda()
405
+
406
+ def get_attn_mask(self,valid_frames):
407
+ '''
408
+ :param valid_frames: of shape B,n_img
409
+ '''
410
+ #print(valid_frames)
411
+ #bs,n_img=valid_frames.shape[0:2]
412
+ attn_mask=(valid_frames.float()==0)
413
+ #attn_mask=attn_mask.unsqueeze(1).unsqueeze(2).expand(-1,self.triplane_reso**3,-1,-1) #B,1,n_img
414
+ #attn_mask=attn_mask.reshape(bs*self.triplane_reso**3,-1,n_img).bool()
415
+ attn_mask=torch.repeat_interleave(attn_mask.unsqueeze(1),self.triplane_reso**3,0)
416
+ # print(attn_mask[self.triplane_reso**3*1+10])
417
+ # print(attn_mask[self.triplane_reso ** 3 * 2+10])
418
+ # print(attn_mask[self.triplane_reso ** 3 * 3+10])
419
+ return attn_mask
420
+
421
+ def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
422
+ '''image feat is bs,n_img,length,channel'''
423
+ batch_size,n_img=image_feat.shape[0],image_feat.shape[1]
424
+ image_feat=image_feat[:,:,1:,:] #discard global feature
425
+
426
+ #image_feat=image_feat.permute(0,1,3,4,2) #B,n_img,h,w,c
427
+ image_k=self.k(image_feat)+self.image_pe #B,n_img,h,w,c
428
+ image_v=self.v(image_feat)+self.image_pe #B,n_img,h,w,c
429
+ image_k_v=torch.cat([image_k,image_v],dim=-1) #B,n_img,h,w,c
430
+ unflat_k_v=self.image_unflatten(image_k_v).permute(0,4,1,2,3) #Bs,channel,n_img,reso,reso
431
+ #unflat_k_v=image_k_v.permute(0,4,1,2,3)
432
+ #vit_coords=self.vit_coords[None,None].expand(batch_size,n_img,-1,-1,-1) #Bs,n_img,reso,reso,2
433
+
434
+ coords=self.vox_coor.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1)
435
+ coord_homo=torch.cat([coords,torch.ones(batch_size,n_img,self.triplane_reso**3,1).float().cuda()],dim=-1)
436
+ coord_inimg = torch.einsum('bjnc,bjck->bjnk', coord_homo, proj_mat.transpose(2, 3))
437
+ x = coord_inimg[:, :, :, 0] / coord_inimg[:, :, :, 2]
438
+ y = coord_inimg[:, :, :, 1] / coord_inimg[:, :, :, 2]
439
+ x = x/(224.0-1) #0~1
440
+ y = y/(224.0-1)
441
+ coords_proj=torch.stack([x,y],dim=-1)
442
+ coords_proj=(coords_proj-0.5)*2
443
+ img_index=((torch.arange(n_img)[None,:,None,None].expand(
444
+ batch_size,-1,self.triplane_reso**3,-1).float().cuda()/(n_img-1))-0.5)*2 #Bs,n_img,64**3,1
445
+
446
+ # img_index_feat=torch.arange(n_img)[None,:,None,None,None].expand(
447
+ # batch_size,-1,self.vit_reso,self.vit_reso,-1).float().cuda() #Bs,n_img,reso,reso,1
448
+ #coords_feat=torch.cat([vit_coords,img_index_feat],dim=-1).permute(0,4,1,2,3)#Bs,n_img,reso,reso,3
449
+ grid=torch.cat([coords_proj,img_index],dim=-1) #x,y,index
450
+ grid=torch.clamp(grid,min=-1.0,max=1.0)
451
+ sample_k_v = torch.nn.functional.grid_sample(unflat_k_v, grid.unsqueeze(1), align_corners=True, mode='bilinear').squeeze(2) #B,C,n_img,64**3
452
+ xz_feat, xy_feat, yz_feat = torch.split(triplane_feat, split_size_or_sections=triplane_feat.shape[2] // 3,
453
+ dim=2) # B,C,64,64
454
+ xz_vox_feat=xz_feat.unsqueeze(4).expand(-1,-1,-1,-1,self.triplane_reso)#.permute(0,1,3,4,2).reshape(batch_size,-1,self.triplane_reso**3).transpose(1,2) #zxy
455
+ xz_vox_feat=rearrange(xz_vox_feat, 'b c z x y -> b (x y z) c')
456
+ xy_vox_feat=xy_feat.unsqueeze(4).expand(-1,-1,-1,-1,self.triplane_reso)#.permute(0,1,3,2,4).reshape(batch_size,-1,self.triplane_reso**3).transpose(1,2) #yxz
457
+ xy_vox_feat=rearrange(xy_vox_feat, 'b c y x z -> b (x y z) c')
458
+ yz_vox_feat=yz_feat.unsqueeze(4).expand(-1,-1,-1,-1,self.triplane_reso)#.permute(0,1,4,3,2).reshape(batch_size,-1,self.triplane_reso**3).transpose(1,2) #zyx
459
+ yz_vox_feat=rearrange(yz_vox_feat, 'b c z y x -> b (x y z) c')
460
+ #xz_vox_feat = xz_feat[:, :, vox_index[:, 2], vox_index[:, 0]].transpose(1, 2) # B,C,64*64*64
461
+ #xy_vox_feat = xy_feat[:, :, vox_index[:, 1], vox_index[:, 0]].transpose(1, 2)
462
+ #yz_vox_feat = yz_feat[:, :, vox_index[:, 2], vox_index[:, 1]].transpose(1, 2)
463
+
464
+ triplane_expand_feat = torch.cat([xz_vox_feat, xy_vox_feat, yz_vox_feat], dim=-1) # B,64*64*64,3*C
465
+ triplane_query = self.q(triplane_expand_feat) + self.triplane_pe
466
+ k_v=rearrange(sample_k_v, 'b c n k -> (b k) n c')
467
+ #k_v=sample_k_v.permute(0,3,2,1).reshape(batch_size*self.triplane_reso**3,n_img,-1) #B*64**3,n_img,C
468
+ k=k_v[:,:,0:self.inner_channel]
469
+ v=k_v[:,:,self.inner_channel:]
470
+ q=rearrange(triplane_query,'b k c -> (b k) 1 c')
471
+ #q=triplane_query.view(batch_size*self.triplane_reso**3,1,-1)
472
+ #k,v is of shape, B*reso**3,k,channel, q is of shape B*reso**3,1,channel
473
+ #attn mask should be B*reso**3*n_heads,1,k
474
+ #attn_mask=torch.repeat_interleave(self.get_attn_mask(valid_frames),self.n_heads,0)
475
+ #print(q.shape,k.shape,v.shape)
476
+ attn_out,_=self.cross_attn(q,k,v)#attn_mask=attn_mask) #fuse multi-view feature
477
+ #volume=attn_out.view(batch_size,self.triplane_reso,self.triplane_reso,self.triplane_reso,-1) #B,reso,reso,reso,channel
478
+ #print(attn_out.shape)
479
+ volume=rearrange(attn_out,'(b x y z) 1 c -> b x y z c',x=self.triplane_reso,y=self.triplane_reso,z=self.triplane_reso)
480
+ #xz_feat = torch.mean(volume, dim=2).transpose(1,2) #B,reso,reso,C
481
+ xz_feat = reduce(volume, "b x y z c -> b z x c", 'mean')
482
+ #xy_feat = torch.mean(volume, dim=3).transpose(1,2) #B,reso,reso,C
483
+ xy_feat= reduce(volume, 'b x y z c -> b y x c', 'mean')
484
+ #yz_feat = torch.mean(volume, dim=1).transpose(1,2) #B,reso,reso,C
485
+ yz_feat=reduce(volume, 'b x y z c -> b z y c', 'mean')
486
+ triplane_out = torch.cat([xz_feat, xy_feat, yz_feat], dim=1) #B,reso*3,reso,C
487
+ #print(triplane_out.shape)
488
+ triplane_out = self.proj_out(triplane_out)
489
+ triplane_out = triplane_out.permute(0,3,1,2)
490
+ #print(triplane_out.shape)
491
+ return triplane_out
492
+
493
+ class MultiImage_TriFuse_Sampler(nn.Module):
494
+ def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
495
+ img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8,max_nimg=5):
496
+ super().__init__()
497
+ self.triplane_reso=reso
498
+ self.vit_reso=vit_reso
499
+ self.inner_channel=inner_channel
500
+ self.padding=padding
501
+ self.n_heads=n_heads
502
+ self.get_triplane_coord()
503
+ self.get_vit_coords()
504
+ self.out_channels=out_channels
505
+ self.kernel_func=mask_kernel
506
+ self.image_unflatten=nn.Unflatten(2,(vit_reso,vit_reso))
507
+ self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
508
+ self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel)
509
+ self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
510
+
511
+ self.cross_attn = torch.nn.MultiheadAttention(
512
+ embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
513
+ self.proj_out=nn.Conv2d(in_channels=inner_channel,out_channels=out_channels,kernel_size=1)
514
+ self.image_pe = position_encoding(inner_channel, self.vit_reso**2)[None,None,:,:].expand(-1,max_nimg,-1,-1).cuda().float() #B,n_img,length,channel
515
+ self.triplane_pe = position_encoding(inner_channel, self.triplane_reso ** 2*3).unsqueeze(0).cuda().float()
516
+
517
+ def get_vit_coords(self):
518
+ x = torch.arange(self.vit_reso)
519
+ y = torch.arange(self.vit_reso)
520
+
521
+ X, Y = torch.meshgrid(x, y, indexing='xy')
522
+ vit_coords = torch.stack([X, Y], dim=-1)
523
+ self.vit_coords = vit_coords.cuda().float() #reso,reso,2
524
+
525
+ def get_triplane_coord(self):
526
+ '''xz plane firstly, z is at the '''
527
+ x = torch.arange(self.triplane_reso)
528
+ z = torch.arange(self.triplane_reso)
529
+ X, Z = torch.meshgrid(x, z, indexing='xy')
530
+ xz_coords = torch.cat(
531
+ [X[:, :, None], torch.ones_like(X[:, :, None]) * (self.triplane_reso - 1) / 2, Z[:, :, None]],
532
+ dim=-1) # in xyz order
533
+
534
+ '''xy plane'''
535
+ x = torch.arange(self.triplane_reso)
536
+ y = torch.arange(self.triplane_reso)
537
+ X, Y = torch.meshgrid(x, y, indexing='xy')
538
+ xy_coords = torch.cat(
539
+ [X[:, :, None], Y[:, :, None], torch.ones_like(X[:, :, None]) * (self.triplane_reso - 1) / 2],
540
+ dim=-1) # in xyz order
541
+
542
+ '''yz plane'''
543
+ y = torch.arange(self.triplane_reso)
544
+ z = torch.arange(self.triplane_reso)
545
+ Y, Z = torch.meshgrid(y, z, indexing='xy')
546
+ yz_coords = torch.cat(
547
+ [torch.ones_like(Y[:, :, None]) * (self.triplane_reso - 1) / 2, Y[:, :, None], Z[:, :, None]], dim=-1)
548
+
549
+ triplane_coords = torch.cat([xz_coords, xy_coords, yz_coords], dim=0)
550
+ triplane_coords = triplane_coords / (self.triplane_reso - 1)
551
+ triplane_coords = (triplane_coords - 0.5) * 2 * (1 + self.padding + 10e-6)
552
+ self.triplane_coords = triplane_coords.view(-1,3).float().cuda()
553
+ def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
554
+ '''image feat is bs,n_img,length,channel'''
555
+ batch_size,n_img=image_feat.shape[0],image_feat.shape[1]
556
+ image_feat=image_feat[:,:,1:,:] #discard global feature
557
+ #print(image_feat.shape)
558
+
559
+ #image_feat=image_feat.permute(0,1,3,4,2) #B,n_img,h,w,c
560
+ image_k=self.k(image_feat)+self.image_pe #B,n_img,h,w,c
561
+ image_v=self.v(image_feat)+self.image_pe #B,n_img,h,w,c
562
+ image_k_v=torch.cat([image_k,image_v],dim=-1) #B,n_img,h,w,c
563
+ unflat_k_v=self.image_unflatten(image_k_v).permute(0,4,1,2,3) #Bs,channel,n_img,reso,reso
564
+
565
+ coords=self.triplane_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1)
566
+ coord_homo=torch.cat([coords,torch.ones(batch_size,n_img,self.triplane_reso**2*3,1).float().cuda()],dim=-1)
567
+ coord_inimg = torch.einsum('bjnc,bjck->bjnk', coord_homo, proj_mat.transpose(2, 3))
568
+ x = coord_inimg[:, :, :, 0] / coord_inimg[:, :, :, 2]
569
+ y = coord_inimg[:, :, :, 1] / coord_inimg[:, :, :, 2]
570
+ x = x/(224.0-1) #0~1
571
+ y = y/(224.0-1)
572
+ coords_proj=torch.stack([x,y],dim=-1)
573
+ coords_proj=(coords_proj-0.5)*2
574
+ img_index=((torch.arange(n_img)[None,:,None,None].expand(
575
+ batch_size,-1,self.triplane_reso**2*3,-1).float().cuda()/(n_img-1))-0.5)*2 #Bs,n_img,64**3,1
576
+
577
+ grid=torch.cat([coords_proj,img_index],dim=-1) #x,y,index
578
+ grid=torch.clamp(grid,min=-1.0,max=1.0)
579
+ sample_k_v = torch.nn.functional.grid_sample(unflat_k_v, grid.unsqueeze(1), align_corners=True, mode='bilinear').squeeze(2) #B,C,n_img,64**3
580
+
581
+ triplane_flat_feat=rearrange(triplane_feat,'b c h w -> b (h w) c')
582
+ triplane_query = self.q(triplane_flat_feat) + self.triplane_pe
583
+
584
+ k_v=rearrange(sample_k_v, 'b c n k -> (b k) n c')
585
+ k=k_v[:,:,0:self.inner_channel]
586
+ v=k_v[:,:,self.inner_channel:]
587
+ q=rearrange(triplane_query,'b k c -> (b k) 1 c')
588
+ attn_out,_=self.cross_attn(q,k,v)
589
+ triplane_out=rearrange(attn_out,'(b h w) 1 c -> b c h w',b=batch_size,h=self.triplane_reso*3,w=self.triplane_reso)
590
+ triplane_out = self.proj_out(triplane_out)
591
+ return triplane_out
592
+
593
+
594
+ class MultiImage_Global_Sampler(nn.Module):
595
+ def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
596
+ img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8,max_nimg=5):
597
+ super().__init__()
598
+ self.triplane_reso=reso
599
+ self.vit_reso=vit_reso
600
+ self.inner_channel=inner_channel
601
+ self.padding=padding
602
+ self.n_heads=n_heads
603
+ self.out_channels=out_channels
604
+ self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
605
+ self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel)
606
+ self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
607
+
608
+ self.cross_attn = torch.nn.MultiheadAttention(
609
+ embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
610
+ self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels)
611
+ self.image_pe = position_encoding(inner_channel, self.vit_reso**2)[None,None,:,:].expand(-1,max_nimg,-1,-1).cuda().float() #B,n_img,length,channel
612
+ self.triplane_pe = position_encoding(inner_channel, self.triplane_reso**2*3).unsqueeze(0).cuda().float()
613
+ def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
614
+ '''image feat is bs,n_img,length,channel
615
+ triplane feat is bs,C,H*3,W
616
+ '''
617
+ batch_size,n_img=image_feat.shape[0],image_feat.shape[1]
618
+ L=image_feat.shape[2]-1
619
+ image_feat=image_feat[:,:,1:,:] #discard global feature
620
+
621
+ image_k=self.k(image_feat)+self.image_pe #B,n_img,h*w,c
622
+ image_v=self.v(image_feat)+self.image_pe #B,n_img,h*w,c
623
+ image_k=image_k.view(batch_size,n_img*L,-1)
624
+ image_v=image_v.view(batch_size,n_img*L,-1)
625
+
626
+ triplane_flat_feat=rearrange(triplane_feat,"b c h w -> b (h w) c")
627
+ triplane_query = self.q(triplane_flat_feat) + self.triplane_pe
628
+ #print(triplane_query.shape,image_k.shape,image_v.shape)
629
+ attn_out,_=self.cross_attn(triplane_query,image_k,image_v)
630
+ triplane_flat_out = self.proj_out(attn_out)
631
+ triplane_out=rearrange(triplane_flat_out,"b (h w) c -> b c h w",h=self.triplane_reso*3,w=self.triplane_reso)
632
+
633
+ return triplane_out
634
+
635
+ class CrossAttention(nn.Module):
636
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
637
+ super().__init__()
638
+ inner_dim = dim_head * heads
639
+
640
+ if context_dim is None:
641
+ context_dim = query_dim
642
+
643
+ self.scale = dim_head ** -0.5
644
+ self.heads = heads
645
+
646
+ self.to_out = nn.Sequential(
647
+ nn.Linear(inner_dim, query_dim),
648
+ nn.Dropout(dropout)
649
+ )
650
+
651
+ def forward(self, q,k,v):
652
+ h = self.heads
653
+
654
+ q, k, v = map(lambda t: rearrange(
655
+ t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
656
+
657
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
658
+
659
+ # attention, what we cannot get enough of
660
+ attn = sim.softmax(dim=-1)
661
+
662
+ out = torch.einsum('b i j, b j d -> b i d', attn, v)
663
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
664
+ return self.to_out(out)
665
+
666
+ class Image_Vox_Local_Sampler_Pooling(nn.Module):
667
+ def __init__(self,reso,padding=0.1,in_channels=1280,inner_channel=128,out_channels=64,stride=4):
668
+ super().__init__()
669
+ self.triplane_reso=reso
670
+ self.padding=padding
671
+ self.get_vox_coord()
672
+ self.out_channels=out_channels
673
+ self.img_proj=nn.Conv2d(in_channels=in_channels,out_channels=inner_channel,kernel_size=1)
674
+
675
+ self.vox_process=nn.Sequential(
676
+ nn.Conv3d(in_channels=inner_channel,out_channels=inner_channel,kernel_size=3,padding=1)
677
+ )
678
+ self.xz_conv=nn.Sequential(
679
+ nn.BatchNorm3d(inner_channel),
680
+ nn.ReLU(),
681
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
682
+ nn.AvgPool3d((1,stride,1),stride=(1,stride,1)), #8
683
+ nn.BatchNorm3d(inner_channel),
684
+ nn.ReLU(),
685
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
686
+ nn.AvgPool3d((1,stride,1), stride=(1,stride,1)), #2
687
+ nn.BatchNorm3d(inner_channel),
688
+ nn.ReLU(),
689
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
690
+ )
691
+ self.xy_conv = nn.Sequential(
692
+ nn.BatchNorm3d(inner_channel),
693
+ nn.ReLU(),
694
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
695
+ nn.AvgPool3d((1, 1, stride), stride=(1, 1, stride)), # 8
696
+ nn.BatchNorm3d(inner_channel),
697
+ nn.ReLU(),
698
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
699
+ nn.AvgPool3d((1, 1, stride), stride=(1, 1, stride)), # 2
700
+ nn.BatchNorm3d(inner_channel),
701
+ nn.ReLU(),
702
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
703
+ )
704
+ self.yz_conv = nn.Sequential(
705
+ nn.BatchNorm3d(inner_channel),
706
+ nn.ReLU(),
707
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
708
+ nn.AvgPool3d((stride, 1, 1), stride=(stride, 1, 1)), # 8
709
+ nn.BatchNorm3d(inner_channel),
710
+ nn.ReLU(),
711
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
712
+ nn.AvgPool3d((stride, 1, 1), stride=(stride, 1, 1)), # 2
713
+ nn.BatchNorm3d(inner_channel),
714
+ nn.ReLU(),
715
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
716
+ )
717
+ self.roll_out_conv=RollOut_Conv(in_channels=inner_channel,out_channels=out_channels)
718
+ #self.proj_out=nn.Conv2d(in_channels=inner_channel,out_channels=out_channels,kernel_size=1)
719
+ def get_vox_coord(self):
720
+ x = torch.arange(self.triplane_reso)
721
+ y = torch.arange(self.triplane_reso)
722
+ z = torch.arange(self.triplane_reso)
723
+
724
+ X,Y,Z=torch.meshgrid(x,y,z,indexing='ij')
725
+ vox_coor=torch.cat([X[:,:,:,None],Y[:,:,:,None],Z[:,:,:,None]],dim=-1)
726
+ vox_coor=vox_coor/(self.triplane_reso-1)
727
+ vox_coor=(vox_coor-0.5)*2*(1+self.padding+10e-6)
728
+ self.vox_coor=vox_coor.view(-1,3).float().cuda()
729
+
730
+
731
+ def forward(self,image_feat,proj_mat):
732
+ image_feat=self.img_proj(image_feat)
733
+ batch_size=image_feat.shape[0]
734
+ vox_coords=self.vox_coor.unsqueeze(0).expand(batch_size,-1,-1) #B,64*64*64,3
735
+ vox_homo=torch.cat([vox_coords,torch.ones((batch_size,self.triplane_reso**3,1)).float().cuda()],dim=-1)
736
+ coord_inimg=torch.einsum('bhc,bck->bhk',vox_homo,proj_mat.transpose(1,2))
737
+ x=coord_inimg[:,:,0]/coord_inimg[:,:,2]
738
+ y=coord_inimg[:,:,1]/coord_inimg[:,:,2]
739
+ x=(x/(224.0-1.0)-0.5)*2 #-1~1
740
+ y=(y/(224.0-1.0)-0.5)*2 #-1~1
741
+
742
+ xy=torch.cat([x[:,:,None],y[:,:,None]],dim=-1).unsqueeze(1).contiguous() #B, 1,64**3,2
743
+ #print(image_feat.shape,xy.shape)
744
+ grid_feat=torch.nn.functional.grid_sample(image_feat,xy,align_corners=True,mode='bilinear').squeeze(2).\
745
+ view(batch_size,-1,self.triplane_reso,self.triplane_reso,self.triplane_reso) #B,C,1,64**3
746
+
747
+ grid_feat=self.vox_process(grid_feat)
748
+ xz_feat=torch.mean(self.xz_conv(grid_feat),dim=3).transpose(2,3)
749
+ xy_feat=torch.mean(self.xy_conv(grid_feat),dim=4).transpose(2,3)
750
+ yz_feat=torch.mean(self.yz_conv(grid_feat),dim=2).transpose(2,3)
751
+ triplane_wImg=torch.cat([xz_feat,xy_feat,yz_feat],dim=2)
752
+ #print(triplane_wImg.shape)
753
+
754
+ return self.roll_out_conv(triplane_wImg)
755
+
756
+ class Image_ExpandVox_attn_Sampler(nn.Module):
757
+ def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8):
758
+ super().__init__()
759
+ self.triplane_reso=reso
760
+ self.padding=padding
761
+ self.vit_reso=vit_reso
762
+ self.get_vox_coord()
763
+ self.get_vit_coords()
764
+ self.out_channels=out_channels
765
+ self.n_heads=n_heads
766
+
767
+ self.kernel_func = mask_kernel_close_false
768
+ self.k = nn.Linear(in_features=img_in_channels, out_features=inner_channel)
769
+ # self.q_xz = nn.Conv2d(in_channels=triplane_in_channels,out_channels=inner_channel,kernel_size=1)
770
+ # self.q_xy = nn.Conv2d(in_channels=triplane_in_channels,out_channels=inner_channel,kernel_size=1)
771
+ # self.q_yz = nn.Conv2d(in_channels=triplane_in_channels,out_channels=inner_channel,kernel_size=1)
772
+ self.q=nn.Linear(in_features=triplane_in_channels*3,out_features=inner_channel)
773
+
774
+ self.v = nn.Linear(in_features=img_in_channels, out_features=inner_channel)
775
+ self.attn = torch.nn.MultiheadAttention(
776
+ embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
777
+ self.out_proj=nn.Linear(in_features=inner_channel,out_features=out_channels)
778
+
779
+ self.triplane_pe = position_encoding(inner_channel, self.triplane_reso ** 3).unsqueeze(0).cuda().float()
780
+ self.image_pe = position_encoding(inner_channel, self.vit_reso ** 2+1).unsqueeze(0).cuda().float()
781
+ def get_vox_coord(self):
782
+ x = torch.arange(self.triplane_reso)
783
+ y = torch.arange(self.triplane_reso)
784
+ z = torch.arange(self.triplane_reso)
785
+
786
+ X,Y,Z=torch.meshgrid(x,y,z,indexing='ij')
787
+ vox_coor=torch.cat([X[:,:,:,None],Y[:,:,:,None],Z[:,:,:,None]],dim=-1)
788
+ self.vox_index=vox_coor.view(-1,3).long().cuda()
789
+
790
+
791
+ vox_coor = self.vox_index.float() / (self.triplane_reso - 1)
792
+ vox_coor = (vox_coor - 0.5) * 2 * (1 + self.padding + 10e-6)
793
+ self.vox_coor = vox_coor.view(-1, 3).float().cuda()
794
+ # print(self.vox_coor[0])
795
+ # print(self.vox_coor[self.triplane_reso**2])#x should increase
796
+ # print(self.vox_coor[self.triplane_reso]) #y should increase
797
+ # print(self.vox_coor[1])#z should increase
798
+
799
+ def get_vit_coords(self):
800
+ x=torch.arange(self.vit_reso)
801
+ y=torch.arange(self.vit_reso)
802
+
803
+ X,Y=torch.meshgrid(x,y,indexing='xy')
804
+ vit_coords=torch.stack([X,Y],dim=-1)
805
+ self.vit_coords=vit_coords.view(self.vit_reso**2,2).cuda().float()
806
+
807
+ def compute_attn_mask(self,proj_coords,vit_coords,kernel_size=1.0):
808
+ dist = torch.cdist(proj_coords.float(), vit_coords.float())
809
+ mask = self.kernel_func(dist, sigma=kernel_size) # True if valid, B,reso**3,vit_reso**2
810
+ return mask
811
+
812
+
813
+ def forward(self,triplane_feat,image_feat,proj_mat):
814
+ xz_feat, xy_feat, yz_feat = torch.split(triplane_feat, split_size_or_sections=triplane_feat.shape[2] // 3, dim=2) # B,C,64,64
815
+ #xz_feat=self.q_xz(xz_feat)
816
+ #xy_feat=self.q_xy(xy_feat)
817
+ #yz_feat=self.q_yz(yz_feat)
818
+ batch_size=image_feat.shape[0]
819
+ vox_index=self.vox_index #64*64*64,3
820
+ xz_vox_feat=xz_feat[:,:,vox_index[:,2],vox_index[:,0]].transpose(1,2) #B,C,64*64*64
821
+ xy_vox_feat=xy_feat[:,:,vox_index[:,1],vox_index[:,0]].transpose(1,2)
822
+ yz_vox_feat=yz_feat[:,:,vox_index[:,2],vox_index[:,1]].transpose(1,2)
823
+ triplane_expand_feat=torch.cat([xz_vox_feat,xy_vox_feat,yz_vox_feat],dim=-1)#B,C,64*64*64,3
824
+ triplane_query=self.q(triplane_expand_feat)+self.triplane_pe
825
+
826
+ '''compute projection'''
827
+ vox_coords=self.vox_coor.unsqueeze(0).expand(batch_size,-1,-1) #
828
+ vox_homo = torch.cat([vox_coords, torch.ones((batch_size, self.triplane_reso ** 3, 1)).float().cuda()], dim=-1)
829
+ coord_inimg = torch.einsum('bhc,bck->bhk', vox_homo, proj_mat.transpose(1, 2))
830
+ x = coord_inimg[:, :, 0] / coord_inimg[:, :, 2]
831
+ y = coord_inimg[:, :, 1] / coord_inimg[:, :, 2]
832
+ #
833
+ x = x / (224.0 - 1.0) * (self.vit_reso-1) # 0~self.vit_reso-1
834
+ y = y / (224.0 - 1.0) * (self.vit_reso-1) # 0~self.vit_reso-1 #B,N
835
+ xy=torch.stack([x,y],dim=-1) #B,64*64*64,2
836
+ xy=torch.clamp(xy,min=0,max=self.vit_reso-1)
837
+ vit_coords=self.vit_coords.unsqueeze(0).expand(batch_size,-1,-1) #B, 16*16,2
838
+ attn_mask=torch.repeat_interleave(self.compute_attn_mask(xy,vit_coords,kernel_size=0.5),
839
+ self.n_heads,0) #B*n_heads, reso**3, vit_reso**2
840
+
841
+ k=self.k(image_feat)+self.image_pe
842
+ v=self.v(image_feat)+self.image_pe
843
+ attn_mask=torch.cat([torch.zeros([attn_mask.shape[0],attn_mask.shape[1],1]).cuda().bool(),attn_mask],dim=-1) #add empty token to each key and value
844
+ vox_feat,_=self.attn(triplane_query,k,v,attn_mask=attn_mask) #B,reso**3,C
845
+ feat_volume=self.out_proj(vox_feat).transpose(1,2).reshape(batch_size,-1,self.triplane_reso,
846
+ self.triplane_reso,self.triplane_reso)
847
+ xz_feat=torch.mean(feat_volume,dim=3).transpose(2,3)
848
+ xy_feat=torch.mean(feat_volume,dim=4).transpose(2,3)
849
+ yz_feat=torch.mean(feat_volume,dim=2).transpose(2,3)
850
+ triplane_out=torch.cat([xz_feat,xy_feat,yz_feat],dim=2)
851
+ return triplane_out
852
+
853
+ class Multi_Image_Fusion(nn.Module):
854
+ def __init__(self,reso,image_reso=16,padding=0.1,img_channels=1280,triplane_channel=64,inner_channels=128,output_channel=64,n_heads=8):
855
+ super().__init__()
856
+ self.triplane_reso=reso
857
+ self.image_reso=image_reso
858
+ self.padding=padding
859
+ self.get_triplane_coord()
860
+ self.get_vit_coords()
861
+ self.img_proj=nn.Conv3d(in_channels=img_channels,out_channels=512,kernel_size=1)
862
+ self.kernel_func=mask_kernel
863
+
864
+ self.q = nn.Linear(in_features=triplane_channel, out_features=inner_channels, bias=False)
865
+ self.k = nn.Linear(in_features=512, out_features=inner_channels)
866
+ self.v = nn.Linear(in_features=512, out_features=inner_channels)
867
+
868
+ self.attn = torch.nn.MultiheadAttention(
869
+ embed_dim=inner_channels, num_heads=n_heads, batch_first=True)
870
+ self.out_proj=nn.Linear(in_features=inner_channels,out_features=output_channel)
871
+ self.n_heads=n_heads
872
+
873
+ def get_triplane_coord(self):
874
+ '''xz plane firstly, z is at the '''
875
+ x=torch.arange(self.triplane_reso)
876
+ z=torch.arange(self.triplane_reso)
877
+ X,Z=torch.meshgrid(x,z,indexing='xy')
878
+ xz_coords=torch.cat([X[:,:,None],torch.ones_like(X[:,:,None])*(self.triplane_reso-1)/2,Z[:,:,None]],dim=-1) #in xyz order
879
+
880
+ '''xy plane'''
881
+ x = torch.arange(self.triplane_reso)
882
+ y = torch.arange(self.triplane_reso)
883
+ X, Y = torch.meshgrid(x, y, indexing='xy')
884
+ xy_coords = torch.cat([X[:, :, None], Y[:, :, None],torch.ones_like(X[:, :, None])*(self.triplane_reso-1)/2], dim=-1) # in xyz order
885
+
886
+ '''yz plane'''
887
+ y = torch.arange(self.triplane_reso)
888
+ z = torch.arange(self.triplane_reso)
889
+ Y,Z = torch.meshgrid(y,z,indexing='xy')
890
+ yz_coords= torch.cat([torch.ones_like(Y[:, :, None])*(self.triplane_reso-1)/2,Y[:,:,None],Z[:,:,None]], dim=-1)
891
+
892
+ triplane_coords=torch.cat([xz_coords,xy_coords,yz_coords],dim=0)
893
+ triplane_coords=triplane_coords/(self.triplane_reso-1)
894
+ triplane_coords=(triplane_coords-0.5)*2*(1 + self.padding + 10e-6)
895
+ self.triplane_coords=triplane_coords.float().cuda()
896
+
897
+ def get_vit_coords(self):
898
+ x=torch.arange(self.image_reso)
899
+ y=torch.arange(self.image_reso)
900
+ X,Y=torch.meshgrid(x,y,indexing='xy')
901
+ vit_coords=torch.cat([X[:,:,None],Y[:,:,None]],dim=-1)
902
+ self.vit_coords=vit_coords.float().cuda() #in x,y order
903
+
904
+ def compute_attn_mask(self,proj_coord,vit_coords,valid_frames,kernel_size=2.0):
905
+ '''
906
+ :param proj_coord: B,K,H,W,2
907
+ :param vit_coords: H,W,2
908
+ :return:
909
+ '''
910
+ B,K=proj_coord.shape[0:2]
911
+ vit_coords_expand=vit_coords[None,None,:,:,:].expand(B,K,-1,-1,-1)
912
+
913
+ proj_coord=proj_coord.view(B*K,proj_coord.shape[2]*proj_coord.shape[3],proj_coord.shape[4])
914
+ vit_coords_expand=vit_coords_expand.view(B*K,self.image_reso*self.image_reso,2)
915
+ attn_mask=self.kernel_func(torch.cdist(proj_coord,vit_coords_expand),sigma=float(kernel_size))
916
+ attn_mask=attn_mask.reshape(B,K,proj_coord.shape[1],vit_coords_expand.shape[1])
917
+ valid_expand=valid_frames[:,:,None,None]
918
+ attn_mask[valid_frames>0,:,:]=True
919
+ attn_mask=attn_mask.permute(0,2,1,3)
920
+ attn_mask=attn_mask.reshape(B,proj_coord.shape[1],K*vit_coords_expand.shape[1])
921
+ atten_index=torch.where(attn_mask[0,0]==False)
922
+ return attn_mask
923
+
924
+
925
+ def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
926
+ '''
927
+ :param image_feat: B,C,K,16,16
928
+ :param proj_mat: B,K,4,4
929
+ :param valid_frames: B,K, true if have image, used to compute attn_mask for transformer
930
+ :return:
931
+ '''
932
+ image_feat=self.img_proj(image_feat)
933
+ batch_size=image_feat.shape[0] #K is number of frames
934
+ K=image_feat.shape[2]
935
+ triplane_coords=self.triplane_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,K,-1,-1,-1) #B,K,192,64,3
936
+ #print(torch.amin(triplane_coords),torch.amax(triplane_coords))
937
+ coord_homo=torch.cat([triplane_coords,torch.ones((batch_size,K,triplane_coords.shape[2],triplane_coords.shape[3],1)).float().cuda()],dim=-1)
938
+ #print(coord_homo.shape,proj_mat.shape)
939
+ coord_inimg=torch.einsum('bjhwc,bjck->bjhwk',coord_homo,proj_mat.transpose(2,3))
940
+ x=coord_inimg[:,:,:,:,0]/coord_inimg[:,:,:,:,2]
941
+ y=coord_inimg[:,:,:,:,1]/coord_inimg[:,:,:,:,2]
942
+ x=x/(224.0-1.0)*(self.image_reso-1)
943
+ y=y/(224.0-1.0)*(self.image_reso-1)
944
+
945
+ xy=torch.cat([x[...,None],y[...,None]],dim=-1) #B,K,H,W,2
946
+ image_value=image_feat.view(image_feat.shape[0],image_feat.shape[1],-1).transpose(1,2)
947
+ triplane_query=triplane_feat.view(triplane_feat.shape[0],triplane_feat.shape[1],-1).transpose(1,2)
948
+ valid_frames=1.0-valid_frames.float()
949
+ attn_mask=torch.repeat_interleave(self.compute_attn_mask(xy,self.vit_coords,valid_frames),
950
+ self.n_heads,dim=0)
951
+
952
+ q=self.q(triplane_query)
953
+ k=self.k(image_value)
954
+ v=self.v(image_value)
955
+ #print(q.shape,k.shape,v.shape)
956
+
957
+ attn,_=self.attn(q,k,v,attn_mask=attn_mask)
958
+ #print(attn.shape)
959
+ output=self.out_proj(attn).transpose(1,2).reshape(batch_size,-1,triplane_feat.shape[2],triplane_feat.shape[3])
960
+ #print(output.shape)
961
+ return output
962
+
963
+
964
+ if __name__=="__main__":
965
+ # import sys
966
+ # sys.path.append("../..")
967
+ # from datasets.SingleView_dataset import Object_PartialPoints_Img
968
+ # from datasets.transforms import Aug_with_Tran
969
+ # #sampler=#Image_Vox_Local_Sampler_Pooling(reso=64,padding=0.1,out_channels=64,stride=4).cuda().float()
970
+ # sampler=Image_ExpandVox_attn_Sampler(reso=32,vit_reso=16,padding=0.1,img_in_channels=1280,triplane_in_channels=64,inner_channel=64
971
+ # ,out_channels=64,n_heads=8).cuda().float()
972
+ # # sampler=Image_Direct_AttenwMask_Sampler(reso=32,vit_reso=16,padding=0.1,img_in_channels=1280,triplane_in_channels=128,inner_channel=128
973
+ # # ,out_channels=64,n_heads=8).cuda().float()
974
+ # dataset_config = {
975
+ # "data_path": "/data1/haolin/datasets",
976
+ # "surface_size": 20000,
977
+ # "par_pc_size": 4096,
978
+ # "load_proj_mat": True,
979
+ # }
980
+ # transform = Aug_with_Tran()
981
+ # datasets = Object_PartialPoints_Img(dataset_config['data_path'], split_filename="val_par_img.json", split='val',
982
+ # transform=transform, sampling=False,
983
+ # num_samples=1024, return_surface=True, ret_sample=True,
984
+ # surface_sampling=True, par_pc_size=dataset_config['par_pc_size'],
985
+ # surface_size=dataset_config['surface_size'],
986
+ # load_proj_mat=dataset_config['load_proj_mat'], load_image=True,
987
+ # load_org_img=False, load_triplane=True, replica=1)
988
+ #
989
+ # dataloader = torch.utils.data.DataLoader(
990
+ # datasets=datasets,
991
+ # batch_size=64,
992
+ # shuffle=True
993
+ # )
994
+ # iterator = dataloader.__iter__()
995
+ # data_batch = iterator.next()
996
+ # unflatten = torch.nn.Unflatten(1, (16, 16))
997
+ # image = data_batch['image'][:,:,:].cuda().float()
998
+ # #image=unflatten(image).permute(0,3,1,2)
999
+ # proj_mat = data_batch['proj_mat'].cuda().float()
1000
+ # triplane_feat=torch.randn((64,64,32*3,32)).cuda().float()
1001
+ # sampler(triplane_feat,image,proj_mat)
1002
+ # memory_usage=torch.cuda.max_memory_allocated() / MB
1003
+ # print("memory usage %f mb"%(memory_usage))
1004
+
1005
+
1006
+ import sys
1007
+ sys.path.append("../..")
1008
+ from datasets.SingleView_dataset import Object_PartialPoints_MultiImg
1009
+ from datasets.transforms import Aug_with_Tran
1010
+
1011
+ dataset_config = {
1012
+ "data_path": "/data1/haolin/datasets",
1013
+ "surface_size": 20000,
1014
+ "par_pc_size": 4096,
1015
+ "load_proj_mat": True,
1016
+ }
1017
+ transform = Aug_with_Tran()
1018
+ dataset = Object_PartialPoints_MultiImg(dataset_config['data_path'], split_filename="train_par_img.json", split='train',
1019
+ transform=transform, sampling=False,
1020
+ num_samples=1024, return_surface=True, ret_sample=True,
1021
+ surface_sampling=True, par_pc_size=dataset_config['par_pc_size'],
1022
+ surface_size=dataset_config['surface_size'],
1023
+ load_proj_mat=dataset_config['load_proj_mat'], load_image=True,
1024
+ load_org_img=False, load_triplane=True, replica=1)
1025
+
1026
+ dataloader = torch.utils.data.DataLoader(
1027
+ dataset=dataset,
1028
+ batch_size=10,
1029
+ shuffle=False
1030
+ )
1031
+ iterator = dataloader.__iter__()
1032
+ data_batch = iterator.next()
1033
+ #unflatten = torch.nn.Unflatten(2, (16, 16))
1034
+ image = data_batch['image'][:,:,:,:].cuda().float()
1035
+ #image=unflatten(image).permute(0,4,1,2,3)
1036
+ proj_mat = data_batch['proj_mat'].cuda().float()
1037
+ valid_frames = data_batch['valid_frames'].cuda().float()
1038
+ triplane_feat=torch.randn((10,128,32*3,32)).cuda().float()
1039
+
1040
+ # fusion_module=MultiImage_Fuse_Sampler(reso=32,vit_reso=16,padding=0.1,img_in_channels=1280,triplane_in_channels=128,inner_channel=128
1041
+ # ,out_channels=64,n_heads=8).cuda().float()
1042
+ fusion_module=MultiImage_Global_Sampler(reso=32,vit_reso=16,padding=0.1,img_in_channels=1280,triplane_in_channels=128,inner_channel=128
1043
+ ,out_channels=64,n_heads=8).cuda().float()
1044
+ fusion_module(triplane_feat,image,proj_mat,valid_frames)
1045
+ memory_usage=torch.cuda.max_memory_allocated() / MB
1046
+ print("memory usage %f mb"%(memory_usage))
models/modules/parpoints_encoder.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch_scatter import scatter_mean, scatter_max
5
+ from .unet import UNet
6
+ from .resnet_block import ResnetBlockFC
7
+ from .PointEMB import PointEmbed
8
+ import numpy as np
9
+
10
+ class ParPoint_Encoder(nn.Module):
11
+ ''' PointNet-based encoder network with ResNet blocks for each point.
12
+ Number of input points are fixed.
13
+
14
+ Args:
15
+ c_dim (int): dimension of latent code c
16
+ dim (int): input points dimension
17
+ hidden_dim (int): hidden dimension of the network
18
+ scatter_type (str): feature aggregation when doing local pooling
19
+ unet (bool): weather to use U-Net
20
+ unet_kwargs (str): U-Net parameters
21
+ plane_resolution (int): defined resolution for plane feature
22
+ plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
23
+ padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
24
+ n_blocks (int): number of blocks ResNetBlockFC layers
25
+ '''
26
+
27
+ def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max', unet_kwargs=None,
28
+ plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):
29
+ super().__init__()
30
+ self.c_dim = c_dim
31
+
32
+ self.fc_pos = nn.Linear(dim, 2 * hidden_dim)
33
+ self.blocks = nn.ModuleList([
34
+ ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks)
35
+ ])
36
+ self.fc_c = nn.Linear(hidden_dim, c_dim)
37
+
38
+ self.actvn = nn.ReLU()
39
+ self.hidden_dim = hidden_dim
40
+
41
+ self.unet = UNet(unet_kwargs['output_dim'], in_channels=c_dim, **unet_kwargs)
42
+
43
+ self.reso_plane = plane_resolution
44
+ self.plane_type = plane_type
45
+ self.padding = padding
46
+
47
+ if scatter_type == 'max':
48
+ self.scatter = scatter_max
49
+ elif scatter_type == 'mean':
50
+ self.scatter = scatter_mean
51
+
52
+ # takes in "p": point cloud and "query": sdf_xyz
53
+ # sample plane features for unlabeled_query as well
54
+ def forward(self, p,point_emb): # , query2):
55
+ batch_size, T, D = p.size()
56
+ #print('origin',torch.amin(p[0],dim=0),torch.amax(p[0],dim=0))
57
+ # acquire the index for each point
58
+ coord = {}
59
+ index = {}
60
+ if 'xz' in self.plane_type:
61
+ coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
62
+ index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)
63
+ if 'xy' in self.plane_type:
64
+ coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
65
+ index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)
66
+ if 'yz' in self.plane_type:
67
+ coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
68
+ index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)
69
+ net = self.fc_pos(point_emb)
70
+
71
+ net = self.blocks[0](net)
72
+ for block in self.blocks[1:]:
73
+ pooled = self.pool_local(coord, index, net)
74
+ net = torch.cat([net, pooled], dim=2)
75
+ net = block(net)
76
+
77
+ c = self.fc_c(net)
78
+ #print(c.shape)
79
+
80
+ fea = {}
81
+ # second_sum = 0
82
+ if 'xz' in self.plane_type:
83
+ fea['xz'] = self.generate_plane_features(p, c,
84
+ plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
85
+ if 'xy' in self.plane_type:
86
+ fea['xy'] = self.generate_plane_features(p, c, plane='xy')
87
+ if 'yz' in self.plane_type:
88
+ fea['yz'] = self.generate_plane_features(p, c, plane='yz')
89
+ cat_feature = torch.cat([fea['xz'], fea['xy'], fea['yz']],
90
+ dim=2) # concat at row dimension
91
+ #print(cat_feature.shape)
92
+ plane_feat=self.unet(cat_feature)
93
+
94
+ return plane_feat
95
+
96
+
97
+ def normalize_coordinate(self, p, padding=0.1, plane='xz'):
98
+ ''' Normalize coordinate to [0, 1] for unit cube experiments
99
+
100
+ Args:
101
+ p (tensor): point
102
+ padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
103
+ plane (str): plane feature type, ['xz', 'xy', 'yz']
104
+ '''
105
+ if plane == 'xz':
106
+ xy = p[:, :, [0, 2]]
107
+ elif plane == 'xy':
108
+ xy = p[:, :, [0, 1]]
109
+ else:
110
+ xy = p[:, :, [1, 2]]
111
+ #print("origin",torch.amin(xy), torch.amax(xy))
112
+ xy=xy/2 #xy is originally -1 ~ 1
113
+ xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
114
+ xy_new = xy_new + 0.5 # range (0, 1)
115
+ #print("scale",torch.amin(xy_new),torch.amax(xy_new))
116
+
117
+ # f there are outliers out of the range
118
+ if xy_new.max() >= 1:
119
+ xy_new[xy_new >= 1] = 1 - 10e-6
120
+ if xy_new.min() < 0:
121
+ xy_new[xy_new < 0] = 0.0
122
+ return xy_new
123
+
124
+ def coordinate2index(self, x, reso):
125
+ ''' Normalize coordinate to [0, 1] for unit cube experiments.
126
+ Corresponds to our 3D model
127
+
128
+ Args:
129
+ x (tensor): coordinate
130
+ reso (int): defined resolution
131
+ coord_type (str): coordinate type
132
+ '''
133
+ x = (x * reso).long()
134
+ index = x[:, :, 0] + reso * x[:, :, 1]
135
+ index = index[:, None, :]
136
+ return index
137
+
138
+ # xy is the normalized coordinates of the point cloud of each plane
139
+ # I'm pretty sure the keys of xy are the same as those of index, so xy isn't needed here as input
140
+ def pool_local(self, xy, index, c):
141
+ bs, fea_dim = c.size(0), c.size(2)
142
+ keys = xy.keys()
143
+
144
+ c_out = 0
145
+ for key in keys:
146
+ # scatter plane features from points
147
+ fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane ** 2)
148
+ if self.scatter == scatter_max:
149
+ fea = fea[0]
150
+ # gather feature back to points
151
+ fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
152
+ c_out += fea
153
+ return c_out.permute(0, 2, 1)
154
+
155
+ def generate_plane_features(self, p, c, plane='xz'):
156
+ # acquire indices of features in plane
157
+ xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
158
+ index = self.coordinate2index(xy, self.reso_plane)
159
+
160
+ # scatter plane features from points
161
+ fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane ** 2)
162
+ c = c.permute(0, 2, 1) # B x 512 x T
163
+ fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
164
+ fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane,
165
+ self.reso_plane) # sparce matrix (B x 512 x reso x reso)
166
+ #print(fea_plane.shape)
167
+
168
+ return fea_plane
models/modules/point_transformer.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn, einsum
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from einops import rearrange,repeat
5
+ from timm.models.layers import DropPath
6
+ from torch_cluster import fps
7
+ import numpy as np
8
+
9
+ def zero_module(module):
10
+ """
11
+ Zero out the parameters of a module and return it.
12
+ """
13
+ for p in module.parameters():
14
+ p.detach().zero_()
15
+ return module
16
+
17
+ class PositionalEmbedding(torch.nn.Module):
18
+ def __init__(self, num_channels, max_positions=10000, endpoint=False):
19
+ super().__init__()
20
+ self.num_channels = num_channels
21
+ self.max_positions = max_positions
22
+ self.endpoint = endpoint
23
+
24
+ def forward(self, x):
25
+ freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device)
26
+ freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
27
+ freqs = (1 / self.max_positions) ** freqs
28
+ x = x.ger(freqs.to(x.dtype))
29
+ x = torch.cat([x.cos(), x.sin()], dim=1)
30
+ return x
31
+
32
+ class CrossAttention(nn.Module):
33
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
34
+ super().__init__()
35
+ inner_dim = dim_head * heads
36
+
37
+ if context_dim is None:
38
+ context_dim = query_dim
39
+
40
+ self.scale = dim_head ** -0.5
41
+ self.heads = heads
42
+
43
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
44
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
45
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
46
+
47
+ self.to_out = nn.Sequential(
48
+ nn.Linear(inner_dim, query_dim),
49
+ nn.Dropout(dropout)
50
+ )
51
+
52
+ def forward(self, x, context=None, mask=None):
53
+ h = self.heads
54
+
55
+ q = self.to_q(x)
56
+
57
+ if context is None:
58
+ context = x
59
+
60
+ k = self.to_k(context)
61
+ v = self.to_v(context)
62
+
63
+ q, k, v = map(lambda t: rearrange(
64
+ t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
65
+
66
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
67
+
68
+ # attention, what we cannot get enough of
69
+ attn = sim.softmax(dim=-1)
70
+
71
+ out = torch.einsum('b i j, b j d -> b i d', attn, v)
72
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
73
+ return self.to_out(out)
74
+
75
+
76
+ class LayerScale(nn.Module):
77
+ def __init__(self, dim, init_values=1e-5, inplace=False):
78
+ super().__init__()
79
+ self.inplace = inplace
80
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
81
+
82
+ def forward(self, x):
83
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
84
+
85
+ class GEGLU(nn.Module):
86
+ def __init__(self, dim_in, dim_out):
87
+ super().__init__()
88
+ self.proj = nn.Linear(dim_in, dim_out * 2)
89
+
90
+ def forward(self, x):
91
+ x, gate = self.proj(x).chunk(2, dim=-1)
92
+ return x * F.gelu(gate)
93
+
94
+ class FeedForward(nn.Module):
95
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
96
+ super().__init__()
97
+ inner_dim = int(dim * mult)
98
+ if dim_out is None:
99
+ dim_out = dim
100
+
101
+ project_in = nn.Sequential(
102
+ nn.Linear(dim, inner_dim),
103
+ nn.GELU()
104
+ ) if not glu else GEGLU(dim, inner_dim)
105
+
106
+ self.net = nn.Sequential(
107
+ project_in,
108
+ nn.Dropout(dropout),
109
+ nn.Linear(inner_dim, dim_out)
110
+ )
111
+
112
+ def forward(self, x):
113
+ return self.net(x)
114
+
115
+ class AdaLayerNorm(nn.Module):
116
+ def __init__(self, n_embd):
117
+ super().__init__()
118
+
119
+ self.silu = nn.SiLU()
120
+ self.linear = nn.Linear(n_embd, n_embd*2)
121
+ self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False)
122
+
123
+ def forward(self, x, timestep):
124
+ emb = self.linear(timestep)
125
+ scale, shift = torch.chunk(emb, 2, dim=2)
126
+ x = self.layernorm(x) * (1 + scale) + shift
127
+ return x
128
+
129
+ class BasicTransformerBlock(nn.Module):
130
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
131
+ super().__init__()
132
+ self.attn1 = CrossAttention(
133
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
134
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
135
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
136
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
137
+ self.norm1 = AdaLayerNorm(dim)
138
+ self.norm2 = AdaLayerNorm(dim)
139
+ self.norm3 = AdaLayerNorm(dim)
140
+ self.checkpoint = checkpoint
141
+
142
+ init_values = 0
143
+ drop_path = 0.0
144
+
145
+
146
+ self.ls1 = LayerScale(
147
+ dim, init_values=init_values) if init_values else nn.Identity()
148
+ self.drop_path1 = DropPath(
149
+ drop_path) if drop_path > 0. else nn.Identity()
150
+
151
+ self.ls2 = LayerScale(
152
+ dim, init_values=init_values) if init_values else nn.Identity()
153
+ self.drop_path2 = DropPath(
154
+ drop_path) if drop_path > 0. else nn.Identity()
155
+
156
+ self.ls3 = LayerScale(
157
+ dim, init_values=init_values) if init_values else nn.Identity()
158
+ self.drop_path3 = DropPath(
159
+ drop_path) if drop_path > 0. else nn.Identity()
160
+
161
+ def forward(self, x, t, context=None):
162
+ x = self.drop_path1(self.ls1(self.attn1(self.norm1(x, t)))) + x
163
+ x = self.drop_path2(self.ls2(self.attn2(self.norm2(x, t), context=context))) + x
164
+ x = self.drop_path3(self.ls3(self.ff(self.norm3(x, t)))) + x
165
+ return x
166
+
167
+ class LatentArrayTransformer(nn.Module):
168
+ """
169
+ Transformer block for image-like data.
170
+ First, project the input (aka embedding)
171
+ and reshape to b, t, d.
172
+ Then apply standard transformer action.
173
+ Finally, reshape to image
174
+ """
175
+
176
+ def __init__(self, in_channels, t_channels, n_heads, d_head,
177
+ depth=1, dropout=0., context_dim=None, out_channels=None, context_dim2=None,
178
+ block=BasicTransformerBlock):
179
+ super().__init__()
180
+ self.in_channels = in_channels
181
+ inner_dim = n_heads * d_head
182
+
183
+ self.t_channels = t_channels
184
+
185
+ self.proj_in = nn.Linear(in_channels, inner_dim, bias=False)
186
+
187
+ self.transformer_blocks = nn.ModuleList(
188
+ [block(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
189
+ for _ in range(depth)]
190
+ )
191
+
192
+ self.norm = nn.LayerNorm(inner_dim)
193
+
194
+ if out_channels is None:
195
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels, bias=False))
196
+ else:
197
+ self.num_cls = out_channels
198
+ self.proj_out = zero_module(nn.Linear(inner_dim, out_channels, bias=False))
199
+
200
+ self.context_dim = context_dim
201
+
202
+ self.map_noise = PositionalEmbedding(t_channels)
203
+
204
+ self.map_layer0 = nn.Linear(in_features=t_channels, out_features=inner_dim)
205
+ self.map_layer1 = nn.Linear(in_features=inner_dim, out_features=inner_dim)
206
+
207
+ # ###
208
+ # self.pos_emb = nn.Embedding(512, inner_dim)
209
+ # ###
210
+
211
+ def forward(self, x, t, cond, class_emb):
212
+
213
+ t_emb = self.map_noise(t)[:, None]
214
+ t_emb = F.silu(self.map_layer0(t_emb))
215
+ t_emb = F.silu(self.map_layer1(t_emb))
216
+
217
+ x = self.proj_in(x)
218
+ #print(class_emb.shape,t_emb.shape)
219
+ for block in self.transformer_blocks:
220
+ x = block(x, t_emb+class_emb[:,None,:], context=cond)
221
+
222
+ x = self.norm(x)
223
+
224
+ x = self.proj_out(x)
225
+ return x
226
+
227
+ class PointTransformer(nn.Module):
228
+ """
229
+ Transformer block for image-like data.
230
+ First, project the input (aka embedding)
231
+ and reshape to b, t, d.
232
+ Then apply standard transformer action.
233
+ Finally, reshape to image
234
+ """
235
+
236
+ def __init__(self, in_channels, t_channels, n_heads, d_head,
237
+ depth=1, dropout=0., context_dim=None, out_channels=None, context_dim2=None,
238
+ block=BasicTransformerBlock):
239
+ super().__init__()
240
+ self.in_channels = in_channels
241
+ inner_dim = n_heads * d_head
242
+
243
+ self.t_channels = t_channels
244
+
245
+ self.proj_in = nn.Linear(in_channels, inner_dim, bias=False)
246
+
247
+ self.transformer_blocks = nn.ModuleList(
248
+ [block(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
249
+ for _ in range(depth)]
250
+ )
251
+
252
+ self.norm = nn.LayerNorm(inner_dim)
253
+
254
+ if out_channels is None:
255
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels, bias=False))
256
+ else:
257
+ self.num_cls = out_channels
258
+ self.proj_out = zero_module(nn.Linear(inner_dim, out_channels, bias=False))
259
+
260
+ self.context_dim = context_dim
261
+
262
+ self.map_noise = PositionalEmbedding(t_channels)
263
+
264
+ self.map_layer0 = nn.Linear(in_features=t_channels, out_features=inner_dim)
265
+ self.map_layer1 = nn.Linear(in_features=inner_dim, out_features=inner_dim)
266
+
267
+ # ###
268
+ # self.pos_emb = nn.Embedding(512, inner_dim)
269
+ # ###
270
+
271
+ def forward(self, x, t, cond=None):
272
+
273
+ t_emb = self.map_noise(t)[:, None]
274
+ t_emb = F.silu(self.map_layer0(t_emb))
275
+ t_emb = F.silu(self.map_layer1(t_emb))
276
+
277
+ x = self.proj_in(x)
278
+
279
+ for block in self.transformer_blocks:
280
+ x = block(x, t_emb, context=cond)
281
+
282
+ x = self.norm(x)
283
+
284
+ x = self.proj_out(x)
285
+ return x
286
+ def exists(val):
287
+ return val is not None
288
+
289
+ def default(val, d):
290
+ return val if exists(val) else d
291
+
292
+ def cache_fn(f):
293
+ cache = None
294
+ @wraps(f)
295
+ def cached_fn(*args, _cache = True, **kwargs):
296
+ if not _cache:
297
+ return f(*args, **kwargs)
298
+ nonlocal cache
299
+ if cache is not None:
300
+ return cache
301
+ cache = f(*args, **kwargs)
302
+ return cache
303
+ return cached_fn
304
+
305
+ class PreNorm(nn.Module):
306
+ def __init__(self, dim, fn, context_dim = None):
307
+ super().__init__()
308
+ self.fn = fn
309
+ self.norm = nn.LayerNorm(dim)
310
+ self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None
311
+
312
+ def forward(self, x, **kwargs):
313
+ x = self.norm(x)
314
+
315
+ if exists(self.norm_context):
316
+ context = kwargs['context']
317
+ normed_context = self.norm_context(context)
318
+ kwargs.update(context = normed_context)
319
+
320
+ return self.fn(x, **kwargs)
321
+
322
+ class Attention(nn.Module):
323
+ def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, drop_path_rate = 0.0):
324
+ super().__init__()
325
+ inner_dim = dim_head * heads
326
+ context_dim = default(context_dim, query_dim)
327
+ self.scale = dim_head ** -0.5
328
+ self.heads = heads
329
+
330
+ self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
331
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
332
+ self.to_out = nn.Linear(inner_dim, query_dim)
333
+
334
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
335
+
336
+ def forward(self, x, context = None, mask = None):
337
+ h = self.heads
338
+
339
+ q = self.to_q(x)
340
+ context = default(context, x)
341
+ k, v = self.to_kv(context).chunk(2, dim = -1)
342
+
343
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
344
+
345
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
346
+
347
+ if exists(mask):
348
+ mask = rearrange(mask, 'b ... -> b (...)')
349
+ max_neg_value = -torch.finfo(sim.dtype).max
350
+ mask = repeat(mask, 'b j -> (b h) () j', h = h)
351
+ sim.masked_fill_(~mask, max_neg_value)
352
+
353
+ # attention, what we cannot get enough of
354
+ attn = sim.softmax(dim = -1)
355
+
356
+ out = einsum('b i j, b j d -> b i d', attn, v)
357
+ out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
358
+ return self.drop_path(self.to_out(out))
359
+
360
+
361
+ class PointEmbed(nn.Module):
362
+ def __init__(self, hidden_dim=48, dim=128):
363
+ super().__init__()
364
+
365
+ assert hidden_dim % 6 == 0
366
+
367
+ self.embedding_dim = hidden_dim
368
+ e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
369
+ e = torch.stack([
370
+ torch.cat([e, torch.zeros(self.embedding_dim // 6),
371
+ torch.zeros(self.embedding_dim // 6)]),
372
+ torch.cat([torch.zeros(self.embedding_dim // 6), e,
373
+ torch.zeros(self.embedding_dim // 6)]),
374
+ torch.cat([torch.zeros(self.embedding_dim // 6),
375
+ torch.zeros(self.embedding_dim // 6), e]),
376
+ ])
377
+ self.register_buffer('basis', e) # 3 x 16
378
+
379
+ self.mlp = nn.Linear(self.embedding_dim + 3, dim)
380
+
381
+ @staticmethod
382
+ def embed(input, basis):
383
+ projections = torch.einsum(
384
+ 'bnd,de->bne', input, basis)
385
+ embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
386
+ return embeddings
387
+
388
+ def forward(self, input):
389
+ # input: B x N x 3
390
+ embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2)) # B x N x C
391
+ return embed
392
+
393
+
394
+ class PointEncoder(nn.Module):
395
+ def __init__(self,
396
+ dim=512,
397
+ num_inputs = 2048,
398
+ num_latents = 512,
399
+ latent_dim = 512):
400
+ super().__init__()
401
+
402
+ self.num_inputs = num_inputs
403
+ self.num_latents = num_latents
404
+
405
+ self.cross_attend_blocks = nn.ModuleList([
406
+ PreNorm(dim, Attention(dim, dim, heads=1, dim_head=dim), context_dim=dim),
407
+ PreNorm(dim, FeedForward(dim))
408
+ ])
409
+
410
+ self.point_embed = PointEmbed(dim=dim)
411
+ self.proj=nn.Linear(dim,latent_dim)
412
+ def encode(self, pc):
413
+ # pc: B x N x 3
414
+ B, N, D = pc.shape
415
+ assert N == self.num_inputs
416
+
417
+ ###### fps
418
+ flattened = pc.view(B * N, D)
419
+
420
+ batch = torch.arange(B).to(pc.device)
421
+ batch = torch.repeat_interleave(batch, N)
422
+
423
+ pos = flattened
424
+
425
+ ratio = 1.0 * self.num_latents / self.num_inputs
426
+
427
+ idx = fps(pos, batch, ratio=ratio)
428
+
429
+ sampled_pc = pos[idx]
430
+ sampled_pc = sampled_pc.view(B, -1, 3)
431
+ ######
432
+
433
+ sampled_pc_embeddings = self.point_embed(sampled_pc)
434
+
435
+ pc_embeddings = self.point_embed(pc)
436
+
437
+ cross_attn, cross_ff = self.cross_attend_blocks
438
+
439
+ x = cross_attn(sampled_pc_embeddings, context=pc_embeddings, mask=None) + sampled_pc_embeddings
440
+ x = cross_ff(x) + x
441
+
442
+ return self.proj(x)
models/modules/pointnet2_backbone.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import sys
6
+ import os
7
+ from external.pointnet2.pointnet2_modules import PointnetSAModuleVotes, PointnetFPModule
8
+ from .utils import zero_module
9
+ from .Positional_Embedding import PositionalEmbedding
10
+
11
+ class Pointnet2Encoder(nn.Module):
12
+ def __init__(self,input_feature_dim=0,npoints=[2048,1024,512,256],radius=[0.2,0.4,0.6,1.2],nsample=[64,32,16,8]):
13
+ super().__init__()
14
+ self.sa1 = PointnetSAModuleVotes(
15
+ npoint=npoints[0],
16
+ radius=radius[0],
17
+ nsample=nsample[0],
18
+ mlp=[input_feature_dim, 64, 64, 128],
19
+ use_xyz=True,
20
+ normalize_xyz=True
21
+ )
22
+
23
+ self.sa2 = PointnetSAModuleVotes(
24
+ npoint=npoints[1],
25
+ radius=radius[1],
26
+ nsample=nsample[1],
27
+ mlp=[128, 128, 128, 256],
28
+ use_xyz=True,
29
+ normalize_xyz=True
30
+ )
31
+
32
+ self.sa3 = PointnetSAModuleVotes(
33
+ npoint=npoints[2],
34
+ radius=radius[2],
35
+ nsample=nsample[2],
36
+ mlp=[256, 256, 256, 512],
37
+ use_xyz=True,
38
+ normalize_xyz=True
39
+ )
40
+
41
+ self.sa4 = PointnetSAModuleVotes(
42
+ npoint=npoints[3],
43
+ radius=radius[3],
44
+ nsample=nsample[3],
45
+ mlp=[512, 512, 512, 512],
46
+ use_xyz=True,
47
+ normalize_xyz=True
48
+ )
49
+ def _break_up_pc(self, pc):
50
+ xyz = pc[..., 0:3].contiguous()
51
+ features = (
52
+ pc[..., 3:].transpose(1, 2).contiguous()
53
+ if pc.size(-1) > 3 else None
54
+ )
55
+
56
+ return xyz, features
57
+ def forward(self,pointcloud,end_points=None):
58
+ if not end_points: end_points = {}
59
+ batch_size = pointcloud.shape[0]
60
+
61
+ xyz, features = self._break_up_pc(pointcloud)
62
+
63
+ end_points['org_xyz'] = xyz
64
+ # --------- 4 SET ABSTRACTION LAYERS ---------
65
+ xyz1, features1, _ = self.sa1(xyz, features)
66
+ end_points['sa1_xyz'] = xyz1
67
+ end_points['sa1_features'] = features1
68
+
69
+ xyz2, features2, _ = self.sa2(xyz1, features1) # this fps_inds is just 0,1,...,1023
70
+ end_points['sa2_xyz'] = xyz2
71
+ end_points['sa2_features'] = features2
72
+
73
+ xyz3, features3, _ = self.sa3(xyz2, features2) # this fps_inds is just 0,1,...,511
74
+ end_points['sa3_xyz'] = xyz3
75
+ end_points['sa3_features'] = features3
76
+ #print(xyz3.shape,features3.shape)
77
+ xyz4, features4, _ = self.sa4(xyz3, features3) # this fps_inds is just 0,1,...,255
78
+ end_points['sa4_xyz'] = xyz4
79
+ end_points['sa4_features'] = features4
80
+ #print(xyz4.shape,features4.shape)
81
+ return end_points
82
+
83
+
84
+
85
+ class PointUNet(nn.Module):
86
+ r"""
87
+ Backbone network for point cloud feature learning.
88
+ Based on Pointnet++ single-scale grouping network.
89
+
90
+ Parameters
91
+ ----------
92
+ input_feature_dim: int
93
+ Number of input channels in the feature descriptor for each point.
94
+ e.g. 3 for RGB.
95
+ """
96
+
97
+ def __init__(self):
98
+ super().__init__()
99
+
100
+ self.noisy_encoder=Pointnet2Encoder()
101
+ self.cond_encoder=Pointnet2Encoder()
102
+ self.fp1_cross = PointnetFPModule(mlp=[512 + 512, 512, 512])
103
+ self.fp1 = PointnetFPModule(mlp=[512 + 512, 512, 512])
104
+ #self.fp1 = PointnetFPModule(mlp=[512 + 512, 512, 512])
105
+ self.fp2_cross = PointnetFPModule(mlp=[512 + 512, 512, 256])
106
+ self.fp2 = PointnetFPModule(mlp=[256 + 256, 512, 256])
107
+ #self.fp2=PointnetFPModule(mlp=[512 + 256, 512, 256])
108
+ self.fp3_cross= PointnetFPModule(mlp=[256 + 256, 256, 128])
109
+ self.fp3 = PointnetFPModule(mlp=[128 + 128, 256, 128])
110
+ #self.fp3 = PointnetFPModule(mlp=[256 + 128, 256, 128])
111
+ self.fp4_cross=PointnetFPModule(mlp=[128+128, 128, 128])
112
+ self.fp4 = PointnetFPModule(mlp=[128, 128, 128])
113
+ #self.fp4 = PointnetFPModule(mlp=[128, 128, 128])
114
+
115
+ self.output_layer=nn.Sequential(
116
+ nn.LayerNorm(128),
117
+ zero_module(nn.Linear(in_features=128,out_features=3,bias=False))
118
+ )
119
+ self.t_emb_layer = PositionalEmbedding(256)
120
+ self.map_layer0 = nn.Linear(in_features=256, out_features=512)
121
+ self.map_layer1 = nn.Linear(in_features=512, out_features=512)
122
+
123
+ def forward(self, noise_points, t,cond_points):
124
+ r"""
125
+ Forward pass of the network
126
+
127
+ Parameters
128
+ ----------
129
+ pointcloud: Variable(torch.cuda.FloatTensor)
130
+ (B, N, 3 + input_feature_dim) tensor
131
+ Point cloud to run predicts on
132
+ Each point in the point-cloud MUST
133
+ be formated as (x, y, z, features...)
134
+
135
+ Returns
136
+ ----------
137
+ end_points: {XXX_xyz, XXX_features, XXX_inds}
138
+ XXX_xyz: float32 Tensor of shape (B,K,3)
139
+ XXX_features: float32 Tensor of shape (B,K,D)
140
+ XXX-inds: int64 Tensor of shape (B,K) values in [0,N-1]
141
+ """
142
+ t_emb = self.t_emb_layer(t)
143
+ t_emb = F.silu(self.map_layer0(t_emb))
144
+ t_emb = F.silu(self.map_layer1(t_emb))#B,512
145
+ t_emb = t_emb[:, :, None] #B,512,K
146
+ noise_end_points=self.noisy_encoder(noise_points)
147
+ cond=self.cond_encoder(cond_points)
148
+ # --------- 2 FEATURE UPSAMPLING LAYERS --------
149
+ features = self.fp1_cross(noise_end_points['sa4_xyz'],cond['sa4_xyz'],noise_end_points['sa4_features']+t_emb,
150
+ cond['sa4_features'])
151
+ features = self.fp1(noise_end_points['sa3_xyz'], noise_end_points['sa4_xyz'], noise_end_points['sa3_features'],
152
+ features)
153
+ features = self.fp2_cross(noise_end_points['sa3_xyz'],cond['sa3_xyz'],features,
154
+ cond["sa3_features"])
155
+ features = self.fp2(noise_end_points['sa2_xyz'], noise_end_points['sa3_xyz'], noise_end_points['sa2_features'],
156
+ features)
157
+ features = self.fp3_cross(noise_end_points['sa2_xyz'],cond['sa2_xyz'],features,
158
+ cond['sa2_features'])
159
+ features = self.fp3(noise_end_points['sa1_xyz'],noise_end_points['sa2_xyz'],noise_end_points['sa1_features'],features)
160
+ features = self.fp4_cross(noise_end_points['sa1_xyz'],cond['sa1_xyz'],features,
161
+ cond['sa1_features'])
162
+ features = self.fp4(noise_end_points['org_xyz'], noise_end_points['sa1_xyz'], None, features)
163
+ features=features.transpose(1,2)
164
+
165
+ # features = self.fp1_cross(noise_end_points['sa4_xyz'], cond_end_points['sa4_xyz'],
166
+ # noise_end_points['sa4_features']+t_emb, cond_end_points['sa4_features'])
167
+ # features = self.fp1(noise_end_points['sa3_xyz'].clone(), noise_end_points['sa4_xyz'].clone(), noise_end_points['sa3_features'],
168
+ # features)
169
+ # features = self.fp2(noise_end_points['sa2_xyz'], noise_end_points['sa3_xyz'], noise_end_points['sa2_features'],
170
+ # features)
171
+ # features = self.fp3(noise_end_points['sa1_xyz'],noise_end_points['sa2_xyz'],noise_end_points['sa1_features'],features)
172
+ # features = self.fp4(noise_end_points['org_xyz'], noise_end_points['sa1_xyz'], None, features)
173
+ # features = features.transpose(1,2)
174
+ output_points=self.output_layer(features)
175
+
176
+ return output_points
177
+
178
+
179
+ if __name__ == '__main__':
180
+ net=PointUNet().cuda().float()
181
+ net=net.eval()
182
+ noise_points=torch.randn(16,4096,3).cuda().float()
183
+ cond_points=torch.randn(16,4096,3).cuda().float()
184
+ t=torch.randn(16).cuda().float()
185
+ cond_encoder=Pointnet2Encoder().cuda().float()
186
+
187
+ out = net(noise_points,cond_points)
188
+ print(out.shape)
models/modules/resnet_block.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ # Resnet Blocks
6
+ class ResnetBlockFC(nn.Module):
7
+ ''' Fully connected ResNet Block class.
8
+ Args:
9
+ size_in (int): input dimension
10
+ size_out (int): output dimension
11
+ size_h (int): hidden dimension
12
+ '''
13
+
14
+ def __init__(self, size_in, size_out=None, size_h=None):
15
+ super().__init__()
16
+ # Attributes
17
+ if size_out is None:
18
+ size_out = size_in
19
+
20
+ if size_h is None:
21
+ size_h = min(size_in, size_out)
22
+
23
+ self.size_in = size_in
24
+ self.size_h = size_h
25
+ self.size_out = size_out
26
+ # Submodules
27
+ self.fc_0 = nn.Linear(size_in, size_h)
28
+ self.fc_1 = nn.Linear(size_h, size_out)
29
+ self.actvn = nn.ReLU()
30
+
31
+ if size_in == size_out:
32
+ self.shortcut = None
33
+ else:
34
+ self.shortcut = nn.Linear(size_in, size_out, bias=False)
35
+ # Initialization
36
+ nn.init.zeros_(self.fc_1.weight)
37
+
38
+ def forward(self, x):
39
+ net = self.fc_0(self.actvn(x))
40
+ dx = self.fc_1(self.actvn(net))
41
+
42
+ if self.shortcut is not None:
43
+ x_s = self.shortcut(x)
44
+ else:
45
+ x_s = x
46
+
47
+ return x_s + dx
models/modules/resunet.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .unet import RollOut_Conv
4
+ from .Positional_Embedding import PositionalEmbedding
5
+ import torch.nn.functional as F
6
+ from .utils import zero_module
7
+ from .image_sampler import MultiImage_Fuse_Sampler, MultiImage_Global_Sampler,MultiImage_TriFuse_Sampler
8
+
9
+ class ResidualConv_MultiImgAtten(nn.Module):
10
+ def __init__(self, input_dim, output_dim, stride, padding, reso=64,
11
+ vit_reso=16,t_input_dim=256,img_in_channels=1280,use_attn=True,triplane_padding=0.1,
12
+ norm="batch"):
13
+ super(ResidualConv_MultiImgAtten, self).__init__()
14
+ self.use_attn=use_attn
15
+
16
+ if norm=="batch":
17
+ norm_layer=nn.BatchNorm2d
18
+ elif norm==None:
19
+ norm_layer=nn.Identity
20
+
21
+ self.conv_block = nn.Sequential(
22
+ norm_layer(input_dim),
23
+ nn.ReLU(),
24
+ nn.Conv2d(
25
+ input_dim, output_dim, kernel_size=3, padding=padding
26
+ )
27
+ )
28
+ self.out_layer=nn.Sequential(
29
+ norm_layer(output_dim),
30
+ nn.ReLU(),
31
+ nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
32
+ )
33
+ self.conv_skip = nn.Sequential(
34
+ nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1),
35
+ norm_layer(output_dim),
36
+ )
37
+ self.roll_out_conv=nn.Sequential(
38
+ norm_layer(output_dim),
39
+ nn.ReLU(),
40
+ RollOut_Conv(output_dim, output_dim),
41
+ )
42
+ if self.use_attn:
43
+ self.img_sampler = MultiImage_Fuse_Sampler(inner_channel=output_dim, triplane_in_channels=output_dim,
44
+ img_in_channels=img_in_channels,reso=reso,vit_reso=vit_reso,
45
+ out_channels=output_dim,padding=triplane_padding)
46
+ self.down_conv=nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=stride, padding=padding)
47
+
48
+ self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
49
+ self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)
50
+ def forward(self, x,t_emb,img_feat,proj_mat,valid_frames):
51
+ t_emb = F.silu(self.map_layer0(t_emb))
52
+ t_emb = F.silu(self.map_layer1(t_emb))
53
+ t_emb = t_emb[:,:,None,None]
54
+
55
+ out=self.conv_block(x)+t_emb
56
+ out=self.out_layer(out)
57
+ feature=out+self.conv_skip(x)
58
+ feature = self.roll_out_conv(feature)
59
+ if self.use_attn:
60
+ feature=self.img_sampler(feature,img_feat,proj_mat,valid_frames)+feature #skip connect
61
+ feature=self.down_conv(feature)
62
+
63
+ return feature
64
+
65
+ class ResidualConv_TriMultiImgAtten(nn.Module):
66
+ def __init__(self, input_dim, output_dim, stride, padding, reso=64,
67
+ vit_reso=16,t_input_dim=256,img_in_channels=1280,use_attn=True,triplane_padding=0.1,
68
+ norm="batch"):
69
+ super(ResidualConv_TriMultiImgAtten, self).__init__()
70
+ self.use_attn=use_attn
71
+
72
+ if norm=="batch":
73
+ norm_layer=nn.BatchNorm2d
74
+ elif norm==None:
75
+ norm_layer=nn.Identity
76
+
77
+ self.conv_block = nn.Sequential(
78
+ norm_layer(input_dim),
79
+ nn.ReLU(),
80
+ nn.Conv2d(
81
+ input_dim, output_dim, kernel_size=3, padding=padding
82
+ )
83
+ )
84
+ self.out_layer=nn.Sequential(
85
+ norm_layer(output_dim),
86
+ nn.ReLU(),
87
+ nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
88
+ )
89
+ self.conv_skip = nn.Sequential(
90
+ nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1),
91
+ norm_layer(output_dim),
92
+ )
93
+ self.roll_out_conv=nn.Sequential(
94
+ norm_layer(output_dim),
95
+ nn.ReLU(),
96
+ RollOut_Conv(output_dim, output_dim),
97
+ )
98
+ if self.use_attn:
99
+ self.img_sampler = MultiImage_TriFuse_Sampler(inner_channel=output_dim, triplane_in_channels=output_dim,
100
+ img_in_channels=img_in_channels,reso=reso,vit_reso=vit_reso,
101
+ out_channels=output_dim,max_nimg=5,padding=triplane_padding)
102
+ self.down_conv=nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=stride, padding=padding)
103
+
104
+ self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
105
+ self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)
106
+ def forward(self, x,t_emb,img_feat,proj_mat,valid_frames):
107
+ t_emb = F.silu(self.map_layer0(t_emb))
108
+ t_emb = F.silu(self.map_layer1(t_emb))
109
+ t_emb = t_emb[:,:,None,None]
110
+
111
+ out=self.conv_block(x)+t_emb
112
+ out=self.out_layer(out)
113
+ feature=out+self.conv_skip(x)
114
+ feature = self.roll_out_conv(feature)
115
+ if self.use_attn:
116
+ feature=self.img_sampler(feature,img_feat,proj_mat,valid_frames)+feature #skip connect
117
+ feature=self.down_conv(feature)
118
+
119
+ return feature
120
+
121
+
122
+ class ResidualConv_GlobalAtten(nn.Module):
123
+ def __init__(self, input_dim, output_dim, stride, padding, reso=64,
124
+ vit_reso=16,t_input_dim=256,img_in_channels=1280,use_attn=True,triplane_padding=0.1,
125
+ norm="batch"):
126
+ super(ResidualConv_GlobalAtten, self).__init__()
127
+ self.use_attn=use_attn
128
+
129
+ if norm=="batch":
130
+ norm_layer=nn.BatchNorm2d
131
+ elif norm==None:
132
+ norm_layer=nn.Identity
133
+
134
+ self.conv_block = nn.Sequential(
135
+ norm_layer(input_dim),
136
+ nn.ReLU(),
137
+ nn.Conv2d(
138
+ input_dim, output_dim, kernel_size=3, padding=padding
139
+ )
140
+ )
141
+ self.out_layer=nn.Sequential(
142
+ norm_layer(output_dim),
143
+ nn.ReLU(),
144
+ nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
145
+ )
146
+ self.conv_skip = nn.Sequential(
147
+ nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1),
148
+ norm_layer(output_dim),
149
+ )
150
+ self.roll_out_conv=nn.Sequential(
151
+ norm_layer(output_dim),
152
+ nn.ReLU(),
153
+ RollOut_Conv(output_dim, output_dim),
154
+ )
155
+ if self.use_attn:
156
+ self.img_sampler = MultiImage_Global_Sampler(inner_channel=output_dim, triplane_in_channels=output_dim,
157
+ img_in_channels=img_in_channels,reso=reso,vit_reso=vit_reso,
158
+ out_channels=output_dim,max_nimg=5,padding=triplane_padding)
159
+ self.down_conv=nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=stride, padding=padding)
160
+
161
+ self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
162
+ self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)
163
+ def forward(self, x,t_emb,img_feat,proj_mat,valid_frames):
164
+ t_emb = F.silu(self.map_layer0(t_emb))
165
+ t_emb = F.silu(self.map_layer1(t_emb))
166
+ t_emb = t_emb[:,:,None,None]
167
+
168
+ out=self.conv_block(x)+t_emb
169
+ out=self.out_layer(out)
170
+ feature=out+self.conv_skip(x)
171
+ feature = self.roll_out_conv(feature)
172
+ if self.use_attn:
173
+ feature=self.img_sampler(feature,img_feat,proj_mat,valid_frames)+feature #skip connect
174
+ feature=self.down_conv(feature)
175
+
176
+ return feature
177
+
178
+ class ResidualConv(nn.Module):
179
+ def __init__(self, input_dim, output_dim, stride, padding, t_input_dim=256):
180
+ super(ResidualConv, self).__init__()
181
+
182
+ self.conv_block = nn.Sequential(
183
+ nn.BatchNorm2d(input_dim),
184
+ nn.ReLU(),
185
+ nn.Conv2d(
186
+ input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
187
+ ),
188
+ nn.BatchNorm2d(output_dim),
189
+ nn.ReLU(),
190
+ RollOut_Conv(output_dim,output_dim),
191
+ )
192
+ self.out_layer=nn.Sequential(
193
+ nn.BatchNorm2d(output_dim),
194
+ nn.ReLU(),
195
+ nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
196
+ )
197
+ self.conv_skip = nn.Sequential(
198
+ nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
199
+ nn.BatchNorm2d(output_dim),
200
+ )
201
+
202
+ self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
203
+ self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)
204
+ def forward(self, x,t_emb):
205
+ t_emb = F.silu(self.map_layer0(t_emb))
206
+ t_emb = F.silu(self.map_layer1(t_emb))
207
+ t_emb = t_emb[:,:,None,None]
208
+
209
+ out=self.conv_block(x)+t_emb
210
+ out=self.out_layer(out)
211
+
212
+ return out + self.conv_skip(x)
213
+
214
+ class Upsample(nn.Module):
215
+ def __init__(self, input_dim, output_dim, kernel, stride):
216
+ super(Upsample, self).__init__()
217
+
218
+ self.upsample = nn.ConvTranspose2d(
219
+ input_dim, output_dim, kernel_size=kernel, stride=stride
220
+ )
221
+
222
+ def forward(self, x):
223
+ return self.upsample(x)
224
+
225
+
226
+
227
+ class ResUnet_Par_cond(nn.Module):
228
+ def __init__(self, channel, filters=[64, 128, 256, 512, 1024],output_channel=32,par_channel=32):
229
+ super(ResUnet_Par_cond, self).__init__()
230
+
231
+ self.input_layer = nn.Sequential(
232
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
233
+ nn.BatchNorm2d(filters[0]),
234
+ nn.ReLU(),
235
+ nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
236
+ )
237
+ self.input_skip = nn.Sequential(
238
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
239
+ )
240
+
241
+ self.residual_conv_1 = ResidualConv(filters[0]+par_channel, filters[1], 2, 1)
242
+ self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1)
243
+ self.residual_conv_3 = ResidualConv(filters[2], filters[3], 2, 1)
244
+ self.bridge = ResidualConv(filters[3],filters[4],2,1)
245
+
246
+
247
+ self.upsample_1 = Upsample(filters[4], filters[4], 2, 2)
248
+ self.up_residual_conv1 = ResidualConv(filters[4] + filters[3], filters[3], 1, 1)
249
+
250
+ self.upsample_2 = Upsample(filters[3], filters[3], 2, 2)
251
+ self.up_residual_conv2 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1)
252
+
253
+ self.upsample_3 = Upsample(filters[2], filters[2], 2, 2)
254
+ self.up_residual_conv3 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1)
255
+
256
+ self.upsample_4 = Upsample(filters[1], filters[1], 2, 2)
257
+ self.up_residual_conv4 = ResidualConv(filters[1] + filters[0]+par_channel, filters[0], 1, 1)
258
+
259
+ self.output_layer = nn.Sequential(
260
+ #nn.LayerNorm(filters[0]),
261
+ nn.LayerNorm(64),#normalize along width dimension, usually it should normalize along channel dimension,
262
+ # I don't know why, but the finetuning performance increase significantly
263
+ zero_module(nn.Conv2d(filters[0], output_channel, 1, 1,bias=False)),
264
+ )
265
+ self.par_channel=par_channel
266
+ self.par_conv=nn.Sequential(
267
+ nn.Conv2d(par_channel, par_channel, kernel_size=3, padding=1),
268
+ )
269
+ self.t_emb_layer=PositionalEmbedding(256)
270
+ self.cat_emb=nn.Linear(
271
+ in_features=6,
272
+ out_features=256,
273
+ )
274
+
275
+ def forward(self, x,t,category_code,par_point_feat):
276
+ # Encode
277
+ t_emb=self.t_emb_layer(t)
278
+ cat_emb=self.cat_emb(category_code)
279
+ t_emb=t_emb+cat_emb
280
+ #print(t_emb.shape)
281
+ x1 = self.input_layer(x) + self.input_skip(x)
282
+ if par_point_feat is not None:
283
+ par_point_feat=self.par_conv(par_point_feat)
284
+ else:
285
+ bs,_,H,W=x1.shape
286
+ #print(x1.shape)
287
+ par_point_feat=torch.zeros((bs,self.par_channel,H,W)).float().to(x1.device)
288
+ x1 = torch.cat([x1, par_point_feat], dim=1)
289
+ x2 = self.residual_conv_1(x1,t_emb)
290
+ x3 = self.residual_conv_2(x2,t_emb)
291
+ # Bridge
292
+ x4 = self.residual_conv_3(x3,t_emb)
293
+ x5 = self.bridge(x4,t_emb)
294
+
295
+ x6=self.upsample_1(x5)
296
+ x6=torch.cat([x6,x4],dim=1)
297
+ x7=self.up_residual_conv1(x6,t_emb)
298
+
299
+ x7=self.upsample_2(x7)
300
+ x7=torch.cat([x7,x3],dim=1)
301
+ x8=self.up_residual_conv2(x7,t_emb)
302
+
303
+ x8 = self.upsample_3(x8)
304
+ x8 = torch.cat([x8, x2], dim=1)
305
+ #print(x8.shape)
306
+ x9 = self.up_residual_conv3(x8,t_emb)
307
+
308
+ x9 = self.upsample_4(x9)
309
+ x9 = torch.cat([x9, x1], dim=1)
310
+ x10 = self.up_residual_conv4(x9,t_emb)
311
+
312
+ output=self.output_layer(x10)
313
+
314
+ return output
315
+
316
+ class ResUnet_DirectAttenMultiImg_Cond(nn.Module):
317
+ def __init__(self, channel, filters=[64, 128, 256, 512, 1024],
318
+ img_in_channels=1024,vit_reso=16,output_channel=32,
319
+ use_par=False,par_channel=32,triplane_padding=0.1,norm='batch',
320
+ use_cat_embedding=False,
321
+ block_type="multiview_local"):
322
+ super(ResUnet_DirectAttenMultiImg_Cond, self).__init__()
323
+
324
+ if block_type == "multiview_local":
325
+ block=ResidualConv_MultiImgAtten
326
+ elif block_type =="multiview_global":
327
+ block=ResidualConv_GlobalAtten
328
+ elif block_type =="multiview_tri":
329
+ block=ResidualConv_TriMultiImgAtten
330
+ else:
331
+ raise NotImplementedError
332
+
333
+ if norm=="batch":
334
+ norm_layer=nn.BatchNorm2d
335
+ elif norm==None:
336
+ norm_layer=nn.Identity
337
+
338
+ self.use_cat_embedding=use_cat_embedding
339
+ self.input_layer = nn.Sequential(
340
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
341
+ norm_layer(filters[0]),
342
+ nn.ReLU(),
343
+ nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
344
+ )
345
+ self.input_skip = nn.Sequential(
346
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
347
+ )
348
+ self.use_par=use_par
349
+ input_1_channels=filters[0]
350
+ if self.use_par:
351
+ self.par_conv = nn.Sequential(
352
+ nn.Conv2d(par_channel, par_channel, kernel_size=3, padding=1),
353
+ )
354
+ input_1_channels=filters[0]+par_channel
355
+ self.residual_conv_1 = block(input_1_channels, filters[1], 2, 1,reso=64
356
+ ,use_attn=False,triplane_padding=triplane_padding,norm=norm)
357
+ self.residual_conv_2 = block(filters[1], filters[2], 2, 1, reso=32,
358
+ use_attn=False,triplane_padding=triplane_padding,norm=norm)
359
+ self.residual_conv_3 = block(filters[2], filters[3], 2, 1,reso=16,
360
+ use_attn=False,triplane_padding=triplane_padding,norm=norm)
361
+ self.bridge = block(filters[3] , filters[4], 2, 1, reso=8
362
+ ,use_attn=False,triplane_padding=triplane_padding,norm=norm) #input reso is 8, output reso is 4
363
+
364
+
365
+ self.upsample_1 = Upsample(filters[4], filters[4], 2, 2)
366
+ self.up_residual_conv1 = block(filters[4] + filters[3], filters[3], 1, 1,reso=8,img_in_channels=img_in_channels,vit_reso=vit_reso,
367
+ use_attn=True,triplane_padding=triplane_padding,norm=norm)
368
+
369
+ self.upsample_2 = Upsample(filters[3], filters[3], 2, 2)
370
+ self.up_residual_conv2 = block(filters[3] + filters[2], filters[2], 1, 1,reso=16,img_in_channels=img_in_channels,vit_reso=vit_reso,
371
+ use_attn=True,triplane_padding=triplane_padding,norm=norm)
372
+
373
+ self.upsample_3 = Upsample(filters[2], filters[2], 2, 2)
374
+ self.up_residual_conv3 = block(filters[2] + filters[1], filters[1], 1, 1,reso=32,img_in_channels=img_in_channels,vit_reso=vit_reso,
375
+ use_attn=True,triplane_padding=triplane_padding,norm=norm)
376
+
377
+ self.upsample_4 = Upsample(filters[1], filters[1], 2, 2)
378
+ self.up_residual_conv4 = block(filters[1] + input_1_channels, filters[0], 1, 1, reso=64,
379
+ use_attn=False,triplane_padding=triplane_padding,norm=norm)
380
+
381
+ self.output_layer = nn.Sequential(
382
+ nn.LayerNorm(64), #normalize along width dimension, usually it should normalize along channel dimension,
383
+ # I don't know why, but the finetuning performance increase significantly
384
+ #nn.LayerNorm([filters[0], 192, 64]),
385
+ zero_module(nn.Conv2d(filters[0], output_channel, 1, 1,bias=False)),
386
+ )
387
+ self.t_emb_layer=PositionalEmbedding(256)
388
+ if use_cat_embedding:
389
+ self.cat_emb = nn.Linear(
390
+ in_features=6,
391
+ out_features=256,
392
+ )
393
+
394
+ def forward(self, x,t,image_emb,proj_mat,valid_frames,category_code,par_point_feat=None):
395
+ # Encode
396
+ t_emb=self.t_emb_layer(t)
397
+ if self.use_cat_embedding:
398
+ cat_emb=self.cat_emb(category_code)
399
+ t_emb=t_emb+cat_emb
400
+ x1 = self.input_layer(x) + self.input_skip(x)
401
+ if self.use_par:
402
+ par_point_feat=self.par_conv(par_point_feat)
403
+ x1 = torch.cat([x1, par_point_feat], dim=1)
404
+ x2 = self.residual_conv_1(x1,t_emb,image_emb,proj_mat,valid_frames)
405
+ x3 = self.residual_conv_2(x2,t_emb,image_emb,proj_mat,valid_frames)
406
+ x4 = self.residual_conv_3(x3,t_emb,image_emb,proj_mat,valid_frames)
407
+ x5 = self.bridge(x4,t_emb,image_emb,proj_mat,valid_frames)
408
+
409
+ x6=self.upsample_1(x5)
410
+ x6=torch.cat([x6,x4],dim=1)
411
+ x7=self.up_residual_conv1(x6,t_emb,image_emb,proj_mat,valid_frames)
412
+
413
+ x7=self.upsample_2(x7)
414
+ x7=torch.cat([x7,x3],dim=1)
415
+ x8=self.up_residual_conv2(x7,t_emb,image_emb,proj_mat,valid_frames)
416
+
417
+ x8 = self.upsample_3(x8)
418
+ x8 = torch.cat([x8, x2], dim=1)
419
+ #print(x8.shape)
420
+ x9 = self.up_residual_conv3(x8,t_emb,image_emb,proj_mat,valid_frames)
421
+
422
+ x9 = self.upsample_4(x9)
423
+ x9 = torch.cat([x9, x1], dim=1)
424
+ x10 = self.up_residual_conv4(x9,t_emb,image_emb,proj_mat,valid_frames)
425
+
426
+ output=self.output_layer(x10)
427
+
428
+ return output
429
+
430
+
431
+ if __name__=="__main__":
432
+ net=ResUnet(32,output_channel=32).float().cuda()
433
+ n_parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)
434
+ print("Model = %s" % str(net))
435
+ print('number of params (M): %.2f' % (n_parameters / 1.e6))
436
+ par_point_feat=torch.randn((10,32,64*3,64)).float().cuda()
437
+ input=torch.randn((10,32,64*3,64)).float().cuda()
438
+ t=torch.randn((10,1,1,1)).float().cuda()
439
+ output=net(input,t.flatten(),par_point_feat)
440
+ #print(output.shape)
models/modules/unet.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Codes are from:
3
+ https://github.com/jaxony/unet-pytorch/blob/master/model.py
4
+ '''
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.autograd import Variable
10
+ from collections import OrderedDict
11
+ from torch.nn import init
12
+ import numpy as np
13
+
14
+
15
+ def conv3x3(in_channels, out_channels, stride=1,
16
+ padding=1, bias=True, groups=1):
17
+ return nn.Conv2d(
18
+ in_channels,
19
+ out_channels,
20
+ kernel_size=3,
21
+ stride=stride,
22
+ padding=padding,
23
+ bias=bias,
24
+ groups=groups)
25
+
26
+
27
+ def upconv2x2(in_channels, out_channels, mode='transpose'):
28
+ if mode == 'transpose':
29
+ return nn.ConvTranspose2d(
30
+ in_channels,
31
+ out_channels,
32
+ kernel_size=2,
33
+ stride=2)
34
+ else:
35
+ # out_channels is always going to be the same
36
+ # as in_channels
37
+ return nn.Sequential(
38
+ nn.Upsample(mode='bilinear', scale_factor=2),
39
+ conv1x1(in_channels, out_channels))
40
+
41
+
42
+ def conv1x1(in_channels, out_channels, groups=1):
43
+ return nn.Conv2d(
44
+ in_channels,
45
+ out_channels,
46
+ kernel_size=1,
47
+ groups=groups,
48
+ stride=1)
49
+
50
+ class RollOut_Conv(nn.Module):
51
+ def __init__(self,in_channels,out_channels):
52
+ super(RollOut_Conv,self).__init__()
53
+ #pass
54
+ self.in_channels=in_channels
55
+ self.out_channels=out_channels
56
+ self.conv = conv3x3(self.in_channels*3, self.out_channels)
57
+
58
+ def forward(self,row_features):
59
+ H,W=row_features.shape[2],row_features.shape[3]
60
+ H_per=H//3
61
+ xz_feature,xy_feature,yz_feature=torch.split(row_features,dim=2,split_size_or_sections=H_per)
62
+ xy_row_pool=torch.mean(xy_feature,dim=2,keepdim=True).expand(-1,-1,H_per,-1)
63
+ yz_col_pool=torch.mean(yz_feature,dim=3,keepdim=True).expand(-1,-1,-1,W)
64
+ cat_xz_feat=torch.cat([xz_feature,xy_row_pool,yz_col_pool],dim=1)
65
+
66
+ xz_row_pool=torch.mean(xz_feature,dim=2,keepdim=True).expand(-1,-1,H_per,-1)
67
+ zy_feature=yz_feature.transpose(2,3) #switch z y axis, for reduced confusion
68
+ zy_col_pool=torch.mean(zy_feature,dim=3,keepdim=True).expand(-1,-1,-1,W)
69
+ cat_xy_feat=torch.cat([xy_feature,xz_row_pool,zy_col_pool],dim=1)
70
+
71
+ xz_col_pool=torch.mean(xz_feature,dim=3,keepdim=True).expand(-1,-1,-1,W)
72
+ yx_feature=xy_feature.transpose(2,3)
73
+ yx_row_pool=torch.mean(yx_feature,dim=2,keepdim=True).expand(-1,-1,H_per,-1)
74
+ cat_yz_feat=torch.cat([yz_feature,yx_row_pool,xz_col_pool],dim=1)
75
+
76
+ fuse_row_feat=torch.cat([cat_xz_feat,cat_xy_feat,cat_yz_feat],dim=2) #concat at row dimension
77
+
78
+ x = self.conv(fuse_row_feat)
79
+
80
+ return x
81
+
82
+
83
+ class DownConv(nn.Module):
84
+ """
85
+ A helper Module that performs 2 convolutions and 1 MaxPool.
86
+ A ReLU activation follows each convolution.
87
+ """
88
+
89
+ def __init__(self, in_channels, out_channels, pooling=True):
90
+ super(DownConv, self).__init__()
91
+
92
+ self.in_channels = in_channels
93
+ self.out_channels = out_channels
94
+ self.pooling = pooling
95
+
96
+ self.conv1 = conv3x3(self.in_channels, self.out_channels)
97
+ self.Rollout_conv=RollOut_Conv(self.out_channels,self.out_channels)
98
+ self.conv2 = conv3x3(self.out_channels, self.out_channels)
99
+
100
+ if self.pooling:
101
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
102
+
103
+ def forward(self, x):
104
+ x = F.relu(self.conv1(x))
105
+ x = F.relu(self.Rollout_conv(x))
106
+ x = F.relu(self.conv2(x))
107
+ before_pool = x
108
+ if self.pooling:
109
+ x = self.pool(x)
110
+ return x, before_pool
111
+
112
+
113
+ class UpConv(nn.Module):
114
+ """
115
+ A helper Module that performs 2 convolutions and 1 UpConvolution.
116
+ A ReLU activation follows each convolution.
117
+ """
118
+
119
+ def __init__(self, in_channels, out_channels,
120
+ merge_mode='concat', up_mode='transpose'):
121
+ super(UpConv, self).__init__()
122
+
123
+ self.in_channels = in_channels
124
+ self.out_channels = out_channels
125
+ self.merge_mode = merge_mode
126
+ self.up_mode = up_mode
127
+
128
+ self.upconv = upconv2x2(self.in_channels, self.out_channels,
129
+ mode=self.up_mode)
130
+
131
+ if self.merge_mode == 'concat':
132
+ self.conv1 = conv3x3(
133
+ 2 * self.out_channels, self.out_channels)
134
+ else:
135
+ # num of input channels to conv2 is same
136
+ self.conv1 = conv3x3(self.out_channels, self.out_channels)
137
+ self.Rollout_conv = RollOut_Conv(self.out_channels, self.out_channels)
138
+ self.conv2 = conv3x3(self.out_channels, self.out_channels)
139
+
140
+ def forward(self, from_down, from_up):
141
+ """ Forward pass
142
+ Arguments:
143
+ from_down: tensor from the encoder pathway
144
+ from_up: upconv'd tensor from the decoder pathway
145
+ """
146
+ from_up = self.upconv(from_up)
147
+ if self.merge_mode == 'concat':
148
+ x = torch.cat((from_up, from_down), 1)
149
+ else:
150
+ x = from_up + from_down
151
+ x = F.relu(self.conv1(x))
152
+ x = F.relu(self.Rollout_conv(x))
153
+ x = F.relu(self.conv2(x))
154
+ return x
155
+
156
+
157
+ class UNet(nn.Module):
158
+ """ `UNet` class is based on https://arxiv.org/abs/1505.04597
159
+
160
+ The U-Net is a convolutional encoder-decoder neural network.
161
+ Contextual spatial information (from the decoding,
162
+ expansive pathway) about an input tensor is merged with
163
+ information representing the localization of details
164
+ (from the encoding, compressive pathway).
165
+
166
+ Modifications to the original paper:
167
+ (1) padding is used in 3x3 convolutions to prevent loss
168
+ of border pixels
169
+ (2) merging outputs does not require cropping due to (1)
170
+ (3) residual connections can be used by specifying
171
+ UNet(merge_mode='add')
172
+ (4) if non-parametric upsampling is used in the decoder
173
+ pathway (specified by upmode='upsample'), then an
174
+ additional 1x1 2d convolution occurs after upsampling
175
+ to reduce channel dimensionality by a factor of 2.
176
+ This channel halving happens with the convolution in
177
+ the tranpose convolution (specified by upmode='transpose')
178
+ """
179
+
180
+ def __init__(self, num_classes, in_channels=3, depth=5,
181
+ start_filts=64, up_mode='transpose',
182
+ merge_mode='concat', **kwargs):
183
+ """
184
+ Arguments:
185
+ in_channels: int, number of channels in the input tensor.
186
+ Default is 3 for RGB images.
187
+ depth: int, number of MaxPools in the U-Net.
188
+ start_filts: int, number of convolutional filters for the
189
+ first conv.
190
+ up_mode: string, type of upconvolution. Choices: 'transpose'
191
+ for transpose convolution or 'upsample' for nearest neighbour
192
+ upsampling.
193
+ """
194
+ super(UNet, self).__init__()
195
+
196
+ if up_mode in ('transpose', 'upsample'):
197
+ self.up_mode = up_mode
198
+ else:
199
+ raise ValueError("\"{}\" is not a valid mode for "
200
+ "upsampling. Only \"transpose\" and "
201
+ "\"upsample\" are allowed.".format(up_mode))
202
+
203
+ if merge_mode in ('concat', 'add'):
204
+ self.merge_mode = merge_mode
205
+ else:
206
+ raise ValueError("\"{}\" is not a valid mode for"
207
+ "merging up and down paths. "
208
+ "Only \"concat\" and "
209
+ "\"add\" are allowed.".format(up_mode))
210
+
211
+ # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
212
+ if self.up_mode == 'upsample' and self.merge_mode == 'add':
213
+ raise ValueError("up_mode \"upsample\" is incompatible "
214
+ "with merge_mode \"add\" at the moment "
215
+ "because it doesn't make sense to use "
216
+ "nearest neighbour to reduce "
217
+ "depth channels (by half).")
218
+
219
+ self.num_classes = num_classes
220
+ self.in_channels = in_channels
221
+ self.start_filts = start_filts
222
+ self.depth = depth
223
+
224
+ self.down_convs = []
225
+ self.up_convs = []
226
+
227
+ # create the encoder pathway and add to a list
228
+ for i in range(depth):
229
+ ins = self.in_channels if i == 0 else outs
230
+ outs = self.start_filts * (2 ** i)
231
+ pooling = True if i < depth - 1 else False
232
+
233
+ down_conv = DownConv(ins, outs, pooling=pooling)
234
+ self.down_convs.append(down_conv)
235
+
236
+ # create the decoder pathway and add to a list
237
+ # - careful! decoding only requires depth-1 blocks
238
+ for i in range(depth - 1):
239
+ ins = outs
240
+ outs = ins // 2
241
+ up_conv = UpConv(ins, outs, up_mode=up_mode,
242
+ merge_mode=merge_mode)
243
+ self.up_convs.append(up_conv)
244
+
245
+ # add the list of modules to current module
246
+ self.down_convs = nn.ModuleList(self.down_convs)
247
+ self.up_convs = nn.ModuleList(self.up_convs)
248
+ self.conv_final = conv1x1(outs, self.num_classes)
249
+
250
+ self.reset_params()
251
+
252
+ @staticmethod
253
+ def weight_init(m):
254
+ if isinstance(m, nn.Conv2d):
255
+ init.xavier_normal_(m.weight)
256
+ init.constant_(m.bias, 0)
257
+
258
+ def reset_params(self):
259
+ for i, m in enumerate(self.modules()):
260
+ self.weight_init(m)
261
+
262
+ def forward(self, feature_plane):
263
+ #cat_feature=torch.cat([feature_plane['xz'],feature_plane['xy'],feature_plane,feature_plane['yz']],dim=2) #concat at row dimension
264
+ x=feature_plane
265
+ encoder_outs = []
266
+ # encoder pathway, save outputs for merging
267
+ for i, module in enumerate(self.down_convs):
268
+ x, before_pool = module(x)
269
+ encoder_outs.append(before_pool)
270
+ for i, module in enumerate(self.up_convs):
271
+ before_pool = encoder_outs[-(i + 2)]
272
+ x = module(before_pool, x)
273
+
274
+ # No softmax is used. This means you need to use
275
+ # nn.CrossEntropyLoss is your training script,
276
+ # as this module includes a softmax already.
277
+ x = self.conv_final(x)
278
+ return x
279
+
280
+
281
+ if __name__ == "__main__":
282
+ # """
283
+ # testing
284
+ # """
285
+ # model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32)
286
+ # print(model)
287
+ # print(sum(p.numel() for p in model.parameters()))
288
+ #
289
+ # reso = 176
290
+ # x = np.zeros((1, 1, reso, reso))
291
+ # x[:, :, int(reso / 2 - 1), int(reso / 2 - 1)] = np.nan
292
+ # x = torch.FloatTensor(x)
293
+ #
294
+ # out = model(x)
295
+ # print('%f' % (torch.sum(torch.isnan(out)).detach().cpu().numpy() / (reso * reso)))
296
+ #
297
+ # # loss = torch.sum(out)
298
+ # # loss.backward()
299
+ #roll_out_conv=RollOut_Conv(in_channels=32,out_channels=32).cuda().float()
300
+ model=UNet(32, depth=5, merge_mode='concat', in_channels=32, start_filts=32).cuda().float()
301
+ row_feature=torch.randn((10,32,128*3,128)).cuda().float()
302
+ output=model(row_feature)
303
+ #output_feature=roll_out_conv(row_feature)
304
+ #print(output_feature.shape)
models/modules/utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def zero_module(module):
4
+ """
5
+ Zero out the parameters of a module and return it.
6
+ """
7
+ for p in module.parameters():
8
+ p.detach().zero_()
9
+ return module
10
+
11
+ class StackedRandomGenerator:
12
+ def __init__(self, device, seeds):
13
+ super().__init__()
14
+ self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]
15
+
16
+ def randn(self, size, **kwargs):
17
+ assert size[0] == len(self.generators)
18
+ return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])
19
+
20
+ def randn_like(self, input):
21
+ return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)
22
+
23
+ def randint(self, *args, size, **kwargs):
24
+ assert size[0] == len(self.generators)
25
+ return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])
output/put_checkpoints_here ADDED
@@ -0,0 +1 @@
 
 
1
+ 1
process_scripts/augment_arkit_partial_point.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy
3
+ import os
4
+ import trimesh
5
+ from sklearn.cluster import KMeans
6
+ import random
7
+ import glob
8
+ import tqdm
9
+ import argparse
10
+ import multiprocessing as mp
11
+ import sys
12
+ sys.path.append("..")
13
+ from datasets.taxonomy import arkit_category
14
+
15
+ parser=argparse.ArgumentParser()
16
+ parser.add_argument('--category',nargs="+",type=str)
17
+ parser.add_argument("--keyword",type=str,default="lowres") #augment only the low resolution points
18
+ parser.add_argument("--data_root",type=str,default="../data/other_data")
19
+ args=parser.parse_args()
20
+ category=args.category
21
+ if category[0]=="all":
22
+ category=arkit_category["all"]
23
+ kmeans=KMeans(
24
+ init="random",
25
+ n_clusters=20,
26
+ n_init=10,
27
+ max_iter=300,
28
+ random_state=42
29
+ )
30
+
31
+ def process_data(src_point_path,save_folder,keyword):
32
+ src_point_tri = trimesh.load(src_point_path)
33
+ src_point = np.asarray(src_point_tri.vertices)
34
+ kmeans.fit(src_point)
35
+ point_cluster_index = kmeans.labels_
36
+
37
+ '''choose 10~19 clusters to form the augmented new point'''
38
+ for i in range(10):
39
+ n_cluster = random.randint(14, 19) # 14,19 for lowres, 10,19 for highres
40
+ choose_cluster = np.random.choice(20, n_cluster, replace=False)
41
+ aug_point_list = []
42
+ for cluster_index in choose_cluster:
43
+ cluster_point = src_point[point_cluster_index == cluster_index]
44
+ aug_point_list.append(cluster_point)
45
+ aug_point = np.concatenate(aug_point_list, axis=0)
46
+ save_path = os.path.join(save_folder, "%s_partial_points_%d.ply" % (keyword, i + 1))
47
+ print("saving to %s"%(save_path))
48
+ aug_point_tri = trimesh.PointCloud(vertices=aug_point)
49
+ aug_point_tri.export(save_path)
50
+
51
+ pool=mp.Pool(10)
52
+ for cat in category[0:]:
53
+ keyword=args.keyword
54
+ point_dir = os.path.join(args.data_root,cat,"5_partial_points")
55
+ folder_list=os.listdir(point_dir)
56
+ for folder in tqdm.tqdm(folder_list[0:]):
57
+ folder_path=os.path.join(point_dir,folder)
58
+ src_point_path=os.path.join(point_dir,folder,"%s_partial_points_0.ply"%(keyword))
59
+ if os.path.exists(src_point_path)==False:
60
+ continue
61
+ save_folder=folder_path
62
+ pool.apply_async(process_data,(src_point_path,save_folder,keyword))
63
+ pool.close()
64
+ pool.join()
process_scripts/augment_synthetic_partial_points.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy
3
+ import os
4
+ import trimesh
5
+ from sklearn.cluster import KMeans
6
+ import random
7
+ import glob
8
+ import tqdm
9
+ import multiprocessing as mp
10
+ import sys
11
+ sys.path.append("..")
12
+ from datasets.taxonomy import synthetic_category_combined
13
+
14
+ import argparse
15
+ parser=argparse.ArgumentParser()
16
+ parser.add_argument("--category",nargs="+",type=str)
17
+ parser.add_argument("--root_dir",type=str,default="../data/other_data")
18
+ args=parser.parse_args()
19
+ categories=args.category
20
+ if categories[0]=="all":
21
+ categories=synthetic_category_combined["all"]
22
+
23
+ kmeans=KMeans(
24
+ init="random",
25
+ n_clusters=7,
26
+ n_init=10,
27
+ max_iter=300,
28
+ random_state=42
29
+ )
30
+
31
+ def process_data(src_filepath,save_path):
32
+ #print("processing %s"%(src_filepath))
33
+ src_point_tri = trimesh.load(src_filepath)
34
+ src_point = np.asarray(src_point_tri.vertices)
35
+ kmeans.fit(src_point)
36
+ point_cluster_index = kmeans.labels_
37
+
38
+ n_cluster = random.randint(3, 6)
39
+ choose_cluster = np.random.choice(7, n_cluster, replace=False)
40
+ aug_point_list = []
41
+ for cluster_index in choose_cluster:
42
+ cluster_point = src_point[point_cluster_index == cluster_index]
43
+ aug_point_list.append(cluster_point)
44
+ aug_point = np.concatenate(aug_point_list, axis=0)
45
+ aug_point_tri = trimesh.PointCloud(vertices=aug_point)
46
+ print("saving to %s"%(save_path))
47
+ aug_point_tri.export(save_path)
48
+
49
+ pool=mp.Pool(10)
50
+ for cat in categories:
51
+ print("processing %s"%cat)
52
+ point_dir=os.path.join(args.root_dir,cat,"5_partial_points")
53
+ folder_list=os.listdir(point_dir)
54
+ for folder in folder_list[:]:
55
+ folder_path=os.path.join(point_dir,folder)
56
+ src_filelist=glob.glob(folder_path+"/partial_points_*.ply")
57
+ for src_filepath in src_filelist:
58
+ basename=os.path.basename(src_filepath)
59
+ save_path = os.path.join(point_dir, folder, "aug7_" + basename)
60
+ pool.apply_async(process_data,(src_filepath,save_path))
61
+ pool.close()
62
+ pool.join()
63
+
64
+
process_scripts/dist_export_triplane_features.sh ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES='0,1,2,3' torchrun --master_port 15002 --nproc_per_node=2 \
2
+ export_triplane_features.py \
3
+ --configs ../configs/train_triplane_vae.yaml \
4
+ --batch_size 10 \
5
+ --ae-pth ../output/ae/chair/best-checkpoint.pth \
6
+ --data-pth ../data \
7
+ --category arkit_chair 03001627 future_chair arkit_stool future_stool ABO_chair
8
+ #sub category
process_scripts/dist_extract_vit.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES='0,1,2,3' torchrun --master_port 15000 --nproc_per_node=2 \
2
+ extract_img_vit_features.py \
3
+ --batch_size 24 \
4
+ --ckpt_path ../data/open_clip_pytorch_model.bin \
5
+ --category arkit_chair 03001627 future_chair arkit_stool future_stool ABO_chair #sub category
6
+ #--category 02871439 future_shelf ABO_shelf arkit_shelf \
process_scripts/export_triplane_features.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import sys
4
+ sys.path.append("..")
5
+ import numpy as np
6
+ import os
7
+ import torch
8
+
9
+ import trimesh
10
+
11
+ from datasets import Object_Occ,Scale_Shift_Rotate
12
+ from models import get_model
13
+ from pathlib import Path
14
+ import open3d as o3d
15
+ from configs.config_utils import CONFIG
16
+ import tqdm
17
+ from util import misc
18
+ from datasets.taxonomy import synthetic_arkit_category_combined
19
+
20
+ if __name__ == "__main__":
21
+
22
+ parser = argparse.ArgumentParser('', add_help=False)
23
+ parser.add_argument('--configs',type=str,required=True)
24
+ parser.add_argument('--ae-pth',type=str)
25
+ parser.add_argument("--category",nargs='+', type=str)
26
+ parser.add_argument('--world_size', default=1, type=int,
27
+ help='number of distributed processes')
28
+ parser.add_argument('--local_rank', default=-1, type=int)
29
+ parser.add_argument('--dist_on_itp', action='store_true')
30
+ parser.add_argument('--dist_url', default='env://',
31
+ help='url used to set up distributed training')
32
+ parser.add_argument('--device', default='cuda',
33
+ help='device to use for training / testing')
34
+ parser.add_argument("--batch_size", default=1, type=int)
35
+ parser.add_argument("--data-pth",default="../data",type=str)
36
+
37
+ args = parser.parse_args()
38
+ misc.init_distributed_mode(args)
39
+ device = torch.device(args.device)
40
+
41
+ config_path=args.configs
42
+ config=CONFIG(config_path)
43
+ dataset_config=config.config['dataset']
44
+ dataset_config['data_path']=args.data_pth
45
+ #transform = AxisScaling((0.75, 1.25), True)
46
+ transform=Scale_Shift_Rotate(rot_shift_surface=True,use_scale=True)
47
+ if len(args.category)==1 and args.category[0]=="all":
48
+ category=synthetic_arkit_category_combined["all"]
49
+ else:
50
+ category=args.category
51
+ train_dataset = Object_Occ(dataset_config['data_path'], split="train",
52
+ categories=category,
53
+ transform=transform, sampling=True,
54
+ num_samples=1024, return_surface=True,
55
+ surface_sampling=True, surface_size=dataset_config['surface_size'],replica=1)
56
+ val_dataset = Object_Occ(dataset_config['data_path'], split="val",
57
+ categories=category,
58
+ transform=transform, sampling=True,
59
+ num_samples=1024, return_surface=True,
60
+ surface_sampling=True, surface_size=dataset_config['surface_size'],replica=1)
61
+ num_tasks = misc.get_world_size()
62
+ global_rank = misc.get_rank()
63
+ train_sampler = torch.utils.data.DistributedSampler(
64
+ train_dataset, num_replicas=num_tasks, rank=global_rank,
65
+ shuffle=False) # shuffle=True to reduce monitor bias
66
+ val_sampler=torch.utils.data.DistributedSampler(
67
+ val_dataset, num_replicas=num_tasks, rank=global_rank,
68
+ shuffle=False) # shu
69
+ #dataset=val_dataset
70
+ batch_size=args.batch_size
71
+ train_dataloader=torch.utils.data.DataLoader(
72
+ train_dataset,sampler=train_sampler,
73
+ batch_size=batch_size,
74
+ num_workers=10,
75
+ shuffle=False,
76
+ drop_last=False,
77
+ )
78
+ val_dataloader = torch.utils.data.DataLoader(
79
+ val_dataset, sampler=val_sampler,
80
+ batch_size=batch_size,
81
+ num_workers=10,
82
+ shuffle=False,
83
+ drop_last=False,
84
+ )
85
+ dataloader_list=[train_dataloader,val_dataloader]
86
+ #dataloader_list=[val_dataloader]
87
+ output_dir=os.path.join(dataset_config['data_path'],"other_data")
88
+ #output_dir="/data1/haolin/datasets/ShapeNetV2_watertight"
89
+
90
+ model_config=config.config['model']
91
+ model=get_model(model_config)
92
+ model.load_state_dict(torch.load(args.ae_pth)['model'])
93
+ model.eval().float().to(device)
94
+ #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
95
+
96
+ with torch.no_grad():
97
+ for e in range(5):
98
+ for dataloader in dataloader_list:
99
+ for data_iter_step, data_batch in tqdm.tqdm(enumerate(dataloader)):
100
+ surface = data_batch['surface'].to(device, non_blocking=True)
101
+ model_ids=data_batch['model_id']
102
+ tran_mats=data_batch['tran_mat']
103
+ categories=data_batch['category']
104
+ with torch.no_grad():
105
+ plane_feat,_,means,logvars=model.encode(surface)
106
+ plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=0.5,mode='bilinear')
107
+ vars=torch.exp(logvars)
108
+ means=torch.nn.functional.interpolate(means,scale_factor=0.5,mode="bilinear")
109
+ vars=torch.nn.functional.interpolate(vars,scale_factor=0.5,mode="bilinear")/4
110
+ sample_logvars=torch.log(vars)
111
+
112
+ for j in range(means.shape[0]):
113
+ #plane_dist=plane_feat[j].float().cpu().numpy()
114
+ mean=means[j].float().cpu().numpy()
115
+ logvar=sample_logvars[j].float().cpu().numpy()
116
+ tran_mat=tran_mats[j].float().cpu().numpy()
117
+
118
+ output_folder=os.path.join(output_dir,categories[j],'9_triplane_kl25_64',model_ids[j])
119
+ Path(output_folder).mkdir(parents=True, exist_ok=True)
120
+ exist_len=len(os.listdir(output_folder))
121
+ save_filepath=os.path.join(output_folder,"triplane_feat_%d.npz"%(exist_len))
122
+ np.savez_compressed(save_filepath,mean=mean,logvar=logvar,tran_mat=tran_mat)
process_scripts/extract_img_vit_features.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os,sys
2
+ sys.path.append("..")
3
+ from util.simple_image_loader import Image_dataset
4
+ from torch.utils.data import DataLoader
5
+ import timm
6
+ import torch
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+ from transformers import DPTForDepthEstimation, DPTFeatureExtractor
10
+ import argparse
11
+ from util import misc
12
+ from datasets.taxonomy import synthetic_arkit_category_combined
13
+ parser=argparse.ArgumentParser()
14
+
15
+ parser.add_argument("--category",nargs="+",type=str)
16
+ parser.add_argument("--root_dir",type=str, default="../data")
17
+ parser.add_argument("--ckpt_path",type=str,default="../open_clip_pytorch_model.bin")
18
+ parser.add_argument("--batch_size",type=int,default=24)
19
+ parser.add_argument('--world_size', default=1, type=int,
20
+ help='number of distributed processes')
21
+ parser.add_argument('--local_rank', default=-1, type=int)
22
+ parser.add_argument('--dist_on_itp', action='store_true')
23
+ parser.add_argument('--dist_url', default='env://',
24
+ help='url used to set up distributed training')
25
+ args= parser.parse_args()
26
+ misc.init_distributed_mode(args)
27
+ category=args.category
28
+
29
+ #dataset=Image_dataset(categories=['03001627','ABO_chair','future_chair'])
30
+ if args.category[0]=="all":
31
+ category=synthetic_arkit_category_combined["all"]
32
+ print("loading dataset")
33
+ dataset=Image_dataset(dataset_folder=args.root_dir,categories=category,n_px=224)
34
+ num_tasks = misc.get_world_size()
35
+ global_rank = misc.get_rank()
36
+ sampler = torch.utils.data.DistributedSampler(
37
+ dataset, num_replicas=num_tasks, rank=global_rank,
38
+ shuffle=False) # shuffle=True to reduce monitor bias
39
+
40
+ dataloader=DataLoader(
41
+ dataset,
42
+ sampler=sampler,
43
+ batch_size=args.batch_size,
44
+ num_workers=4,
45
+ pin_memory=True,
46
+ drop_last=False
47
+ )
48
+ print("loading model")
49
+ VIT_MODEL = 'vit_huge_patch14_224_clip_laion2b'
50
+ model=timm.create_model(VIT_MODEL, pretrained=True,pretrained_cfg_overlay=dict(file=args.ckpt_path))
51
+ model=model.eval().float().cuda()
52
+ save_dir=os.path.join(args.root_dir,"other_data")
53
+ for idx,data_batch in enumerate(dataloader):
54
+ if idx%50==0:
55
+ print("{}/{}".format(dataloader.__len__(),idx))
56
+ images = data_batch["images"].cuda().float()
57
+ model_id= data_batch["model_id"]
58
+ image_name=data_batch["image_name"]
59
+ category=data_batch["category"]
60
+ with torch.no_grad():
61
+ #output=model(images,output_hidden_states=True)
62
+ output_features=model.forward_features(images)
63
+ #predict_depth=output.predicted_depth
64
+ #print(predict_depth.shape)
65
+ for j in range(output_features.shape[0]):
66
+ save_folder=os.path.join(save_dir,category[j],"7_img_features",model_id[j])
67
+ os.makedirs(save_folder,exist_ok=True)
68
+ save_path=os.path.join(save_folder,image_name[j]+".npz")
69
+ #print("saving to",save_path)
70
+ np.savez_compressed(save_path,img_features=output_features[j].detach().cpu().numpy().astype(np.float32))
71
+
72
+
73
+
process_scripts/generate_split_for_arkit.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import glob
4
+ import open3d as o3d
5
+ import json
6
+ import argparse
7
+ import glob
8
+
9
+ parser=argparse.ArgumentParser()
10
+ parser.add_argument("--cat",required=True,type=str,nargs="+")
11
+ parser.add_argument("--keyword",default="lowres",type=str)
12
+ parser.add_argument("--root_dir",type=str,default="../data")
13
+ args=parser.parse_args()
14
+
15
+ keyword=args.keyword
16
+ sdf_folder="occ_data"
17
+ other_folder="other_data"
18
+ data_dir=args.root_dir
19
+
20
+ align_dir=os.path.join(args.root_dir,"align_mat_all") # this alignment matrix is aligned from highres scan to lowres scan
21
+ # the alignment matrix is still under cleaning, not all the data have proper alignment matrix yet.
22
+ align_filelist=glob.glob(align_dir+"/*/*.txt")
23
+ valid_model_list=[]
24
+ for align_filepath in align_filelist:
25
+ if "-v" in align_filepath:
26
+ align_mat=np.loadtxt(align_filepath)
27
+ if align_mat.shape[0]!=4:
28
+ continue
29
+ model_id=os.path.basename(align_filepath).split("-")[0]
30
+ valid_model_list.append(model_id)
31
+
32
+ print("there are %d valid lowres models"%(len(valid_model_list)))
33
+
34
+ category_list=args.cat
35
+ for category in category_list:
36
+ train_path=os.path.join(data_dir,sdf_folder,category,"train.lst")
37
+ with open(train_path,'r') as f:
38
+ train_list=f.readlines()
39
+ train_list=[item.rstrip() for item in train_list]
40
+ if ".npz" in train_list[0]:
41
+ train_list=[item[:-4] for item in train_list]
42
+ val_path=os.path.join(data_dir,sdf_folder,category,"val.lst")
43
+ with open(val_path,'r') as f:
44
+ val_list=f.readlines()
45
+ val_list=[item.rstrip() for item in val_list]
46
+ if ".npz" in val_list[0]:
47
+ val_list=[item[:-4] for item in val_list]
48
+
49
+
50
+ sdf_dir=os.path.join(data_dir,sdf_folder,category)
51
+ filelist=os.listdir(sdf_dir)
52
+ model_id_list=[item[:-4] for item in filelist if ".npz" in item]
53
+
54
+ train_par_img_list=[]
55
+ val_par_img_list=[]
56
+ for model_id in model_id_list:
57
+ if model_id not in valid_model_list:
58
+ continue
59
+ image_dir=os.path.join(data_dir,other_folder,category,"6_images",model_id)
60
+ partial_dir=os.path.join(data_dir,other_folder,category,"5_partial_points",model_id)
61
+ if os.path.exists(image_dir)==False and os.path.exists(partial_dir)==False:
62
+ continue
63
+ if os.path.exists(image_dir):
64
+ image_list=glob.glob(image_dir+"/*.jpg")+glob.glob(image_dir+"/*.png")
65
+ image_list=[os.path.basename(image_path) for image_path in image_list]
66
+ else:
67
+ image_list=[]
68
+
69
+ if os.path.exists(partial_dir):
70
+ partial_list=glob.glob(partial_dir+"/%s_partial_points_*.ply"%(keyword))
71
+ else:
72
+ partial_list=[]
73
+ partial_valid_list=[]
74
+ for partial_filepath in partial_list:
75
+ par_o3d=o3d.io.read_point_cloud(partial_filepath)
76
+ par_xyz=np.asarray(par_o3d.points)
77
+ if par_xyz.shape[0]>2048:
78
+ partial_valid_list.append(os.path.basename(partial_filepath))
79
+ if model_id in val_list:
80
+ if "%s_partial_points_0.ply"%(keyword) in partial_valid_list:
81
+ partial_valid_list=["%s_partial_points_0.ply"%(keyword)]
82
+ else:
83
+ partial_valid_list=[]
84
+ if len(image_list)==0 and len(partial_valid_list)==0:
85
+ continue
86
+ ret_dict={
87
+ "model_id":model_id,
88
+ "image_filenames":image_list[:],
89
+ "partial_filenames":partial_valid_list[:]
90
+ }
91
+ if model_id in train_list:
92
+ train_par_img_list.append(ret_dict)
93
+ elif model_id in val_list:
94
+ val_par_img_list.append(ret_dict)
95
+
96
+ train_save_path=os.path.join(sdf_dir,"%s_train_par_img.json"%(keyword))
97
+ with open(train_save_path,'w') as f:
98
+ json.dump(train_par_img_list,f,indent=4)
99
+
100
+ val_save_path=os.path.join(sdf_dir,"%s_val_par_img.json"%(keyword))
101
+ with open(val_save_path,'w') as f:
102
+ json.dump(val_par_img_list,f,indent=4)
process_scripts/generate_split_for_synthetic_data.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os,sys
2
+ import numpy as np
3
+ import glob
4
+ import open3d as o3d
5
+ import json
6
+ import argparse
7
+
8
+ parser=argparse.ArgumentParser()
9
+ parser.add_argument("--cat",required=True,type=str,nargs="+")
10
+ parser.add_argument("--root_dir",type=str,default="../data")
11
+ args=parser.parse_args()
12
+
13
+ sdf_folder="occ_data"
14
+ other_folder="other_folder"
15
+ data_dir=args.root_dir
16
+ category=args.cat
17
+ train_path=os.path.join(data_dir,sdf_folder,category,"train.lst")
18
+ with open(train_path,'r') as f:
19
+ train_list=f.readlines()
20
+ train_list=[item.rstrip() for item in train_list]
21
+ if ".npz" in train_list[0]:
22
+ train_list=[item[:-4] for item in train_list]
23
+ val_path=os.path.join(data_dir,sdf_folder,category,"val.lst")
24
+ with open(val_path,'r') as f:
25
+ val_list=f.readlines()
26
+ val_list=[item.rstrip() for item in val_list]
27
+ if ".npz" in val_list[0]:
28
+ val_list=[item[:-4] for item in val_list]
29
+
30
+ category_list=args.cat
31
+ for category in category_list:
32
+ sdf_dir=os.path.join(data_dir,sdf_folder,category)
33
+ filelist=os.listdir(sdf_dir)
34
+ model_id_list=[item[:-4] for item in filelist if ".npz" in item]
35
+
36
+ train_par_img_list=[]
37
+ val_par_img_list=[]
38
+ for model_id in model_id_list:
39
+ image_dir=os.path.join(data_dir,other_folder,category,"6_images",model_id)
40
+ partial_dir=os.path.join(data_dir,other_folder,category,"5_partial_points",model_id)
41
+ if os.path.exists(image_dir)==False and os.path.exists(partial_dir)==False:
42
+ continue
43
+ if os.path.exists(image_dir):
44
+ image_list=glob.glob(image_dir+"/*.jpg")+glob.glob(image_dir+"/*.png")
45
+ image_list=[os.path.basename(image_path) for image_path in image_list]
46
+ else:
47
+ image_list=[]
48
+
49
+ if os.path.exists(partial_dir):
50
+ partial_list=glob.glob(partial_dir+"/partial_points_*.ply")
51
+ else:
52
+ partial_list=[]
53
+ partial_valid_list=[]
54
+ for partial_filepath in partial_list:
55
+ par_o3d=o3d.io.read_point_cloud(partial_filepath)
56
+ par_xyz=np.asarray(par_o3d.points)
57
+ if par_xyz.shape[0]>2048:
58
+ partial_valid_list.append(os.path.basename(partial_filepath))
59
+ if len(image_list)==0 and len(partial_valid_list)==0:
60
+ continue
61
+ ret_dict={
62
+ "model_id":model_id,
63
+ "image_filenames":image_list[:],
64
+ "partial_filenames":partial_valid_list[:]
65
+ }
66
+ if model_id in train_list:
67
+ train_par_img_list.append(ret_dict)
68
+ elif model_id in val_list:
69
+ val_par_img_list.append(ret_dict)
70
+
71
+ #print(train_par_img_list)
72
+ train_save_path=os.path.join(sdf_dir,"train_par_img.json")
73
+ with open(train_save_path,'w') as f:
74
+ json.dump(train_par_img_list,f,indent=4)
75
+
76
+ val_save_path=os.path.join(sdf_dir,"val_par_img.json")
77
+ with open(val_save_path,'w') as f:
78
+ json.dump(val_par_img_list,f,indent=4)
process_scripts/unzip_all_data.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import argparse
4
+
5
+ parser = argparse.ArgumentParser("unzip the prepared data")
6
+ parser.add_argument("--occ_root", type=str, default="../data/occ_data")
7
+ parser.add_argument("--other_root", type=str,default="../data/other_data")
8
+ parser.add_argument("--unzip_occ",default=False,action="store_true")
9
+ parser.add_argument("--unzip_other",default=False,action="store_true")
10
+
11
+ args=parser.parse_args()
12
+ if args.unzip_occ:
13
+ filelist=os.listdir(args.occ_root)
14
+ for filename in filelist:
15
+ filepath=os.path.join(args.occ_root,filename)
16
+ if ".rar" in filename:
17
+ unrar_command="unrar x %s %s"%(filepath,args.occ_root)
18
+ os.system(unrar_command)
19
+ elif ".zip" in filename:
20
+ unzip_command="7z x %s -o%s"%(filepath,args.occ_root)
21
+ os.system(unzip_command)
22
+
23
+
24
+ if args.unzip_other:
25
+ category_list=os.listdir(args.other_root)
26
+ for category in category_list:
27
+ category_folder=os.path.join(args.other_root,category)
28
+ #print(category_folder)
29
+ rar_filelist=glob.glob(category_folder+"/*.rar")
30
+ zip_filelist=glob.glob(category_folder+"/*.zip")
31
+
32
+ for rar_filepath in rar_filelist:
33
+ unrar_command="unrar x %s %s"%(rar_filepath,category_folder)
34
+ os.system(unrar_command)
35
+ for zip_filepath in zip_filelist:
36
+ unzip_command="7z x %s -o%s"%(zip_filepath,category_folder)
37
+ os.system(unzip_command)
38
+