project-monai committed on
Commit 41c525d · verified · 1 Parent(s): 3d672fd

Upload wholeBrainSeg_Large_UNEST_segmentation version 0.2.6

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ docs/3DSlicer_use.png filter=lfs diff=lfs merge=lfs -text
37
+ docs/demo.png filter=lfs diff=lfs merge=lfs -text
38
+ docs/wholebrain.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
configs/inference.json ADDED
@@ -0,0 +1,136 @@
1
+ {
2
+ "imports": [
3
+ "$import glob",
4
+ "$import os"
5
+ ],
6
+ "bundle_root": ".",
7
+ "output_dir": "$@bundle_root + '/eval'",
8
+ "dataset_dir": "$@bundle_root + '/dataset/images'",
9
+ "datalist": "$list(sorted(glob.glob(@dataset_dir + '/*.nii.gz')))",
10
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
11
+ "network_def": {
12
+ "_target_": "scripts.networks.unest_base_patch_4.UNesT",
13
+ "in_channels": 1,
14
+ "out_channels": 133,
15
+ "patch_size": 4,
16
+ "depths": [
17
+ 2,
18
+ 2,
19
+ 8
20
+ ],
21
+ "embed_dim": [
22
+ 128,
23
+ 256,
24
+ 512
25
+ ],
26
+ "num_heads": [
27
+ 4,
28
+ 8,
29
+ 16
30
+ ]
31
+ },
32
+ "network": "$@network_def.to(@device)",
33
+ "preprocessing": {
34
+ "_target_": "Compose",
35
+ "transforms": [
36
+ {
37
+ "_target_": "LoadImaged",
38
+ "keys": "image"
39
+ },
40
+ {
41
+ "_target_": "EnsureChannelFirstd",
42
+ "keys": "image"
43
+ },
44
+ {
45
+ "_target_": "NormalizeIntensityd",
46
+ "keys": "image",
47
+ "nonzero": "True",
48
+ "channel_wise": "True"
49
+ },
50
+ {
51
+ "_target_": "EnsureTyped",
52
+ "keys": "image"
53
+ }
54
+ ]
55
+ },
56
+ "dataset": {
57
+ "_target_": "Dataset",
58
+ "data": "$[{'image': i} for i in @datalist]",
59
+ "transform": "@preprocessing"
60
+ },
61
+ "dataloader": {
62
+ "_target_": "DataLoader",
63
+ "dataset": "@dataset",
64
+ "batch_size": 1,
65
+ "shuffle": false,
66
+ "num_workers": 4
67
+ },
68
+ "inferer": {
69
+ "_target_": "SlidingWindowInferer",
70
+ "roi_size": [
71
+ 96,
72
+ 96,
73
+ 96
74
+ ],
75
+ "sw_batch_size": 4,
76
+ "overlap": 0.7
77
+ },
78
+ "postprocessing": {
79
+ "_target_": "Compose",
80
+ "transforms": [
81
+ {
82
+ "_target_": "Activationsd",
83
+ "keys": "pred",
84
+ "softmax": true
85
+ },
86
+ {
87
+ "_target_": "Invertd",
88
+ "keys": "pred",
89
+ "transform": "@preprocessing",
90
+ "orig_keys": "image",
91
+ "meta_key_postfix": "meta_dict",
92
+ "nearest_interp": false,
93
+ "to_tensor": true
94
+ },
95
+ {
96
+ "_target_": "AsDiscreted",
97
+ "keys": "pred",
98
+ "argmax": true
99
+ },
100
+ {
101
+ "_target_": "SaveImaged",
102
+ "keys": "pred",
103
+ "meta_keys": "pred_meta_dict",
104
+ "output_dir": "@output_dir"
105
+ }
106
+ ]
107
+ },
108
+ "handlers": [
109
+ {
110
+ "_target_": "CheckpointLoader",
111
+ "load_path": "$@bundle_root + '/models/model.pt'",
112
+ "load_dict": {
113
+ "model": "@network"
114
+ },
115
+ "strict": "True"
116
+ },
117
+ {
118
+ "_target_": "StatsHandler",
119
+ "iteration_log": false
120
+ }
121
+ ],
122
+ "evaluator": {
123
+ "_target_": "SupervisedEvaluator",
124
+ "device": "@device",
125
+ "val_data_loader": "@dataloader",
126
+ "network": "@network",
127
+ "inferer": "@inferer",
128
+ "postprocessing": "@postprocessing",
129
+ "val_handlers": "@handlers",
130
+ "amp": false
131
+ },
132
+ "evaluating": [
133
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
134
+ "[email protected]()"
135
+ ]
136
+ }
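
The inference config above declaratively wires together the preprocessing chain, a SlidingWindowInferer (96^3 ROI, 0.7 overlap), the post-processing chain, and a SupervisedEvaluator. Besides the `monai.bundle run` command shown in the README, the same file can be driven programmatically; the following is a minimal sketch under stated assumptions (MONAI installed, the bundle root as the working directory and on PYTHONPATH so `scripts.*` resolves, and a checkpoint at models/model.pt), not part of the bundle itself.

```
# Sketch: run configs/inference.json programmatically via monai.bundle.ConfigParser.
from monai.bundle import ConfigParser

parser = ConfigParser()
parser.read_config("configs/inference.json")
parser["bundle_root"] = "."                          # override entries before instantiation
evaluator = parser.get_parsed_content("evaluator")   # builds network, dataloader, inferer, handlers
evaluator.run()                                      # CheckpointLoader restores models/model.pt, then sliding-window inference runs
```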
configs/logging.conf ADDED
@@ -0,0 +1,21 @@
1
+ [loggers]
2
+ keys=root
3
+
4
+ [handlers]
5
+ keys=consoleHandler
6
+
7
+ [formatters]
8
+ keys=fullFormatter
9
+
10
+ [logger_root]
11
+ level=INFO
12
+ handlers=consoleHandler
13
+
14
+ [handler_consoleHandler]
15
+ class=StreamHandler
16
+ level=INFO
17
+ formatter=fullFormatter
18
+ args=(sys.stdout,)
19
+
20
+ [formatter_fullFormatter]
21
+ format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
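
This is a standard Python logging fileConfig that `monai.bundle run` applies when passed via `--logging_file`. A small sketch of loading it directly outside the bundle runner (assuming it is saved at configs/logging.conf):

```
# Configure console logging from the file above without going through monai.bundle.
import logging
import logging.config

logging.config.fileConfig("configs/logging.conf", disable_existing_loggers=False)
logging.getLogger("wholeBrainSeg").info("console logging configured at INFO level")
```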
configs/metadata.json ADDED
@@ -0,0 +1,223 @@
1
+ {
2
+ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
3
+ "version": "0.2.6",
4
+ "changelog": {
5
+ "0.2.6": "update to huggingface hosting",
6
+ "0.2.5": "update large files",
7
+ "0.2.4": "fix black 24.1 format error",
8
+ "0.2.3": "fix PYTHONPATH in readme.md",
9
+ "0.2.2": "add name tag",
10
+ "0.2.1": "fix license Copyright error",
11
+ "0.2.0": "update license files",
12
+ "0.1.2": "Add training support for whole brain segmentation, users can use active learning in the MONAI Label",
13
+ "0.1.1": "Fix dimension according to MONAI 1.0 and fix readme file",
14
+ "0.1.0": "complete the model package"
15
+ },
16
+ "monai_version": "1.4.0",
17
+ "pytorch_version": "2.4.0",
18
+ "numpy_version": "1.24.4",
19
+ "required_packages_version": {
20
+ "nibabel": "5.2.1",
21
+ "pytorch-ignite": "0.4.11",
22
+ "einops": "0.7.0",
23
+ "fire": "0.6.0",
24
+ "timm": "0.6.7",
25
+ "torchvision": "0.19.0",
26
+ "tensorboard": "2.17.0"
27
+ },
28
+ "supported_apps": {},
29
+ "name": "Whole brain large UNEST segmentation",
30
+ "task": "Whole Brain Segmentation",
31
+ "description": "A 3D transformer-based model for whole brain segmentation from T1W MRI image",
32
+ "authors": "Vanderbilt University + MONAI team",
33
+ "copyright": "Copyright (c) MONAI Consortium",
34
+ "data_source": "",
35
+ "data_type": "nibabel",
36
+ "image_classes": "single channel data, intensity scaled to [0, 1]",
37
+ "label_classes": "133 Classes",
38
+ "pred_classes": "133 Classes",
39
+ "eval_metrics": {
40
+ "mean_dice": 0.71
41
+ },
42
+ "intended_use": "This is an example, not to be used for diagnostic purposes",
43
+ "references": [
44
+ "Xin, et al. Characterizing Renal Structures with 3D Block Aggregate Transformers. arXiv preprint arXiv:2203.02430 (2022). https://arxiv.org/pdf/2203.02430.pdf"
45
+ ],
46
+ "network_data_format": {
47
+ "inputs": {
48
+ "image": {
49
+ "type": "image",
50
+ "format": "hounsfield",
51
+ "modality": "MRI",
52
+ "num_channels": 1,
53
+ "spatial_shape": [
54
+ 96,
55
+ 96,
56
+ 96
57
+ ],
58
+ "dtype": "float32",
59
+ "value_range": [
60
+ 0,
61
+ 1
62
+ ],
63
+ "is_patch_data": true,
64
+ "channel_def": {
65
+ "0": "image"
66
+ }
67
+ }
68
+ },
69
+ "outputs": {
70
+ "pred": {
71
+ "type": "image",
72
+ "format": "segmentation",
73
+ "num_channels": 133,
74
+ "spatial_shape": [
75
+ 96,
76
+ 96,
77
+ 96
78
+ ],
79
+ "dtype": "float32",
80
+ "value_range": [
81
+ 0,
82
+ 1
83
+ ],
84
+ "is_patch_data": true,
85
+ "channel_def": {
86
+ "0": "background",
87
+ "1": "3rd-Ventricle",
88
+ "2": "4th-Ventricle",
89
+ "3": "Right-Accumbens-Area",
90
+ "4": "Left-Accumbens-Area",
91
+ "5": "Right-Amygdala",
92
+ "6": "Left-Amygdala",
93
+ "7": "Brain-Stem",
94
+ "8": "Right-Caudate",
95
+ "9": "Left-Caudate",
96
+ "10": "Right-Cerebellum-Exterior",
97
+ "11": "Left-Cerebellum-Exterior",
98
+ "12": "Right-Cerebellum-White-Matter",
99
+ "13": "Left-Cerebellum-White-Matter",
100
+ "14": "Right-Cerebral-White-Matter",
101
+ "15": "Left-Cerebral-White-Matter",
102
+ "16": "Right-Hippocampus",
103
+ "17": "Left-Hippocampus",
104
+ "18": "Right-Inf-Lat-Vent",
105
+ "19": "Left-Inf-Lat-Vent",
106
+ "20": "Right-Lateral-Ventricle",
107
+ "21": "Left-Lateral-Ventricle",
108
+ "22": "Right-Pallidum",
109
+ "23": "Left-Pallidum",
110
+ "24": "Right-Putamen",
111
+ "25": "Left-Putamen",
112
+ "26": "Right-Thalamus-Proper",
113
+ "27": "Left-Thalamus-Proper",
114
+ "28": "Right-Ventral-DC",
115
+ "29": "Left-Ventral-DC",
116
+ "30": "Cerebellar-Vermal-Lobules-I-V",
117
+ "31": "Cerebellar-Vermal-Lobules-VI-VII",
118
+ "32": "Cerebellar-Vermal-Lobules-VIII-X",
119
+ "33": "Left-Basal-Forebrain",
120
+ "34": "Right-Basal-Forebrain",
121
+ "35": "Right-ACgG--anterior-cingulate-gyrus",
122
+ "36": "Left-ACgG--anterior-cingulate-gyrus",
123
+ "37": "Right-AIns--anterior-insula",
124
+ "38": "Left-AIns--anterior-insula",
125
+ "39": "Right-AOrG--anterior-orbital-gyrus",
126
+ "40": "Left-AOrG--anterior-orbital-gyrus",
127
+ "41": "Right-AnG---angular-gyrus",
128
+ "42": "Left-AnG---angular-gyrus",
129
+ "43": "Right-Calc--calcarine-cortex",
130
+ "44": "Left-Calc--calcarine-cortex",
131
+ "45": "Right-CO----central-operculum",
132
+ "46": "Left-CO----central-operculum",
133
+ "47": "Right-Cun---cuneus",
134
+ "48": "Left-Cun---cuneus",
135
+ "49": "Right-Ent---entorhinal-area",
136
+ "50": "Left-Ent---entorhinal-area",
137
+ "51": "Right-FO----frontal-operculum",
138
+ "52": "Left-FO----frontal-operculum",
139
+ "53": "Right-FRP---frontal-pole",
140
+ "54": "Left-FRP---frontal-pole",
141
+ "55": "Right-FuG---fusiform-gyrus ",
142
+ "56": "Left-FuG---fusiform-gyrus",
143
+ "57": "Right-GRe---gyrus-rectus",
144
+ "58": "Left-GRe---gyrus-rectus",
145
+ "59": "Right-IOG---inferior-occipital-gyrus",
146
+ "60": "Left-IOG---inferior-occipital-gyrus",
147
+ "61": "Right-ITG---inferior-temporal-gyrus",
148
+ "62": "Left-ITG---inferior-temporal-gyrus",
149
+ "63": "Right-LiG---lingual-gyrus",
150
+ "64": "Left-LiG---lingual-gyrus",
151
+ "65": "Right-LOrG--lateral-orbital-gyrus",
152
+ "66": "Left-LOrG--lateral-orbital-gyrus",
153
+ "67": "Right-MCgG--middle-cingulate-gyrus",
154
+ "68": "Left-MCgG--middle-cingulate-gyrus",
155
+ "69": "Right-MFC---medial-frontal-cortex",
156
+ "70": "Left-MFC---medial-frontal-cortex",
157
+ "71": "Right-MFG---middle-frontal-gyrus",
158
+ "72": "Left-MFG---middle-frontal-gyrus",
159
+ "73": "Right-MOG---middle-occipital-gyrus",
160
+ "74": "Left-MOG---middle-occipital-gyrus",
161
+ "75": "Right-MOrG--medial-orbital-gyrus",
162
+ "76": "Left-MOrG--medial-orbital-gyrus",
163
+ "77": "Right-MPoG--postcentral-gyrus",
164
+ "78": "Left-MPoG--postcentral-gyrus",
165
+ "79": "Right-MPrG--precentral-gyrus",
166
+ "80": "Left-MPrG--precentral-gyrus",
167
+ "81": "Right-MSFG--superior-frontal-gyrus",
168
+ "82": "Left-MSFG--superior-frontal-gyrus",
169
+ "83": "Right-MTG---middle-temporal-gyrus",
170
+ "84": "Left-MTG---middle-temporal-gyrus",
171
+ "85": "Right-OCP---occipital-pole",
172
+ "86": "Left-OCP---occipital-pole",
173
+ "87": "Right-OFuG--occipital-fusiform-gyrus",
174
+ "88": "Left-OFuG--occipital-fusiform-gyrus",
175
+ "89": "Right-OpIFG-opercular-part-of-the-IFG",
176
+ "90": "Left-OpIFG-opercular-part-of-the-IFG",
177
+ "91": "Right-OrIFG-orbital-part-of-the-IFG",
178
+ "92": "Left-OrIFG-orbital-part-of-the-IFG",
179
+ "93": "Right-PCgG--posterior-cingulate-gyrus",
180
+ "94": "Left-PCgG--posterior-cingulate-gyrus",
181
+ "95": "Right-PCu---precuneus",
182
+ "96": "Left-PCu---precuneus",
183
+ "97": "Right-PHG---parahippocampal-gyrus",
184
+ "98": "Left-PHG---parahippocampal-gyrus",
185
+ "99": "Right-PIns--posterior-insula",
186
+ "100": "Left-PIns--posterior-insula",
187
+ "101": "Right-PO----parietal-operculum",
188
+ "102": "Left-PO----parietal-operculum",
189
+ "103": "Right-PoG---postcentral-gyrus",
190
+ "104": "Left-PoG---postcentral-gyrus",
191
+ "105": "Right-POrG--posterior-orbital-gyrus",
192
+ "106": "Left-POrG--posterior-orbital-gyrus",
193
+ "107": "Right-PP----planum-polare",
194
+ "108": "Left-PP----planum-polare",
195
+ "109": "Right-PrG---precentral-gyrus",
196
+ "110": "Left-PrG---precentral-gyrus",
197
+ "111": "Right-PT----planum-temporale",
198
+ "112": "Left-PT----planum-temporale",
199
+ "113": "Right-SCA---subcallosal-area",
200
+ "114": "Left-SCA---subcallosal-area",
201
+ "115": "Right-SFG---superior-frontal-gyrus",
202
+ "116": "Left-SFG---superior-frontal-gyrus",
203
+ "117": "Right-SMC---supplementary-motor-cortex",
204
+ "118": "Left-SMC---supplementary-motor-cortex",
205
+ "119": "Right-SMG---supramarginal-gyrus",
206
+ "120": "Left-SMG---supramarginal-gyrus",
207
+ "121": "Right-SOG---superior-occipital-gyrus",
208
+ "122": "Left-SOG---superior-occipital-gyrus",
209
+ "123": "Right-SPL---superior-parietal-lobule",
210
+ "124": "Left-SPL---superior-parietal-lobule",
211
+ "125": "Right-STG---superior-temporal-gyrus",
212
+ "126": "Left-STG---superior-temporal-gyrus",
213
+ "127": "Right-TMP---temporal-pole",
214
+ "128": "Left-TMP---temporal-pole",
215
+ "129": "Right-TrIFG-triangular-part-of-the-IFG",
216
+ "130": "Left-TrIFG-triangular-part-of-the-IFG",
217
+ "131": "Right-TTG---transverse-temporal-gyrus",
218
+ "132": "Left-TTG---transverse-temporal-gyrus"
219
+ }
220
+ }
221
+ }
222
+ }
223
+ }
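
metadata.json is what `monai.bundle` uses to describe the model: version, framework and package requirements, and the data contract (single-channel 96x96x96 T1w patches in, 133-channel predictions out). A small sanity-check sketch, assuming the file is saved at configs/metadata.json:

```
# Read the bundle metadata and verify the declared output contract.
import json

with open("configs/metadata.json") as f:
    meta = json.load(f)

pred = meta["network_data_format"]["outputs"]["pred"]
assert pred["num_channels"] == len(pred["channel_def"]) == 133  # background + 132 structures
print(meta["name"], meta["version"], "requires MONAI", meta["monai_version"])
```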
configs/multi_gpu_train.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "device": "$torch.device(f'cuda:{dist.get_rank()}')",
3
+ "network": {
4
+ "_target_": "torch.nn.parallel.DistributedDataParallel",
5
+ "module": "$@network_def.to(@device)",
6
+ "device_ids": [
7
+ "@device"
8
+ ]
9
+ },
10
+ "train#sampler": {
11
+ "_target_": "DistributedSampler",
12
+ "dataset": "@train#dataset",
13
+ "even_divisible": true,
14
+ "shuffle": true
15
+ },
16
+ "train#dataloader#sampler": "@train#sampler",
17
+ "train#dataloader#shuffle": false,
18
+ "train#trainer#train_handlers": "$@train#handlers[: -2 if dist.get_rank() > 0 else None]",
19
+ "validate#sampler": {
20
+ "_target_": "DistributedSampler",
21
+ "dataset": "@validate#dataset",
22
+ "even_divisible": false,
23
+ "shuffle": false
24
+ },
25
+ "validate#dataloader#sampler": "@validate#sampler",
26
+ "validate#evaluator#val_handlers": "$None if dist.get_rank() > 0 else @validate#handlers",
27
+ "training": [
28
+ "$import torch.distributed as dist",
29
+ "$dist.init_process_group(backend='nccl')",
30
+ "$torch.cuda.set_device(@device)",
31
+ "$monai.utils.set_determinism(seed=123)",
32
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
33
+ "$@train#trainer.run()",
34
+ "$dist.destroy_process_group()"
35
+ ]
36
+ }
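
multi_gpu_train.json only overrides entries from train.json: it wraps the network in DistributedDataParallel, attaches DistributedSamplers to the train/validation dataloaders, and restricts logging and checkpoint handlers to rank 0. A typical launch command, following the usual MONAI bundle pattern (shown here for 2 GPUs on a single node; adjust `--nproc_per_node` to your hardware):

```
torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run training \
    --meta_file configs/metadata.json \
    --config_file "['configs/train.json','configs/multi_gpu_train.json']" \
    --logging_file configs/logging.conf
```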
configs/train.json ADDED
@@ -0,0 +1,299 @@
1
+ {
2
+ "imports": [
3
+ "$import glob",
4
+ "$import os",
5
+ "$import ignite"
6
+ ],
7
+ "bundle_root": ".",
8
+ "ckpt_dir": "$@bundle_root + '/models'",
9
+ "output_dir": "$@bundle_root + '/eval'",
10
+ "dataset_dir": "$@bundle_root + '/dataset/brain'",
11
+ "images": "$list(sorted(glob.glob(@dataset_dir + '/images/*.nii.gz')))",
12
+ "labels": "$list(sorted(glob.glob(@dataset_dir + '/labels/*.nii.gz')))",
13
+ "val_interval": 5,
14
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
15
+ "network_def": {
16
+ "_target_": "scripts.networks.unest_base_patch_4.UNesT",
17
+ "in_channels": 1,
18
+ "out_channels": 133,
19
+ "patch_size": 4,
20
+ "depths": [
21
+ 2,
22
+ 2,
23
+ 8
24
+ ],
25
+ "embed_dim": [
26
+ 128,
27
+ 256,
28
+ 512
29
+ ],
30
+ "num_heads": [
31
+ 4,
32
+ 8,
33
+ 16
34
+ ]
35
+ },
36
+ "network": "$@network_def.to(@device)",
37
+ "loss": {
38
+ "_target_": "DiceCELoss",
39
+ "to_onehot_y": true,
40
+ "softmax": true,
41
+ "squared_pred": true,
42
+ "batch": true
43
+ },
44
+ "optimizer": {
45
+ "_target_": "torch.optim.Adam",
46
+ "params": "$@network.parameters()",
47
+ "lr": 0.0001
48
+ },
49
+ "train": {
50
+ "deterministic_transforms": [
51
+ {
52
+ "_target_": "LoadImaged",
53
+ "keys": [
54
+ "image",
55
+ "label"
56
+ ]
57
+ },
58
+ {
59
+ "_target_": "EnsureChannelFirstd",
60
+ "keys": [
61
+ "image",
62
+ "label"
63
+ ]
64
+ },
65
+ {
66
+ "_target_": "EnsureTyped",
67
+ "keys": [
68
+ "image",
69
+ "label"
70
+ ]
71
+ }
72
+ ],
73
+ "random_transforms": [
74
+ {
75
+ "_target_": "RandSpatialCropd",
76
+ "keys": [
77
+ "image",
78
+ "label"
79
+ ],
80
+ "roi_size": [
81
+ 96,
82
+ 96,
83
+ 96
84
+ ],
85
+ "random_size": false
86
+ },
87
+ {
88
+ "_target_": "RandFlipd",
89
+ "keys": [
90
+ "image",
91
+ "label"
92
+ ],
93
+ "spatial_axis": [
94
+ 0
95
+ ],
96
+ "prob": 0.1
97
+ },
98
+ {
99
+ "_target_": "RandFlipd",
100
+ "keys": [
101
+ "image",
102
+ "label"
103
+ ],
104
+ "spatial_axis": [
105
+ 1
106
+ ],
107
+ "prob": 0.1
108
+ },
109
+ {
110
+ "_target_": "RandFlipd",
111
+ "keys": [
112
+ "image",
113
+ "label"
114
+ ],
115
+ "spatial_axis": [
116
+ 2
117
+ ],
118
+ "prob": 0.1
119
+ },
120
+ {
121
+ "_target_": "RandRotate90d",
122
+ "keys": [
123
+ "image",
124
+ "label"
125
+ ],
126
+ "max_k": 3,
127
+ "prob": 0.1
128
+ },
129
+ {
130
+ "_target_": "NormalizeIntensityd",
131
+ "keys": "image",
132
+ "nonzero": true,
133
+ "channel_wise": true
134
+ }
135
+ ],
136
+ "preprocessing": {
137
+ "_target_": "Compose",
138
+ "transforms": "$@train#deterministic_transforms + @train#random_transforms"
139
+ },
140
+ "dataset": {
141
+ "_target_": "CacheDataset",
142
+ "data": "$[{'image': i, 'label': l} for i, l in zip(@images[:-2], @labels[:-2])]",
143
+ "transform": "@train#preprocessing",
144
+ "cache_rate": 1.0,
145
+ "num_workers": 2
146
+ },
147
+ "dataloader": {
148
+ "_target_": "DataLoader",
149
+ "dataset": "@train#dataset",
150
+ "batch_size": 1,
151
+ "shuffle": true,
152
+ "num_workers": 1
153
+ },
154
+ "inferer": {
155
+ "_target_": "SimpleInferer"
156
+ },
157
+ "postprocessing": {
158
+ "_target_": "Compose",
159
+ "transforms": [
160
+ {
161
+ "_target_": "Activationsd",
162
+ "keys": "pred",
163
+ "softmax": true
164
+ },
165
+ {
166
+ "_target_": "AsDiscreted",
167
+ "keys": [
168
+ "pred",
169
+ "label"
170
+ ],
171
+ "argmax": [
172
+ true,
173
+ false
174
+ ],
175
+ "to_onehot": 133
176
+ }
177
+ ]
178
+ },
179
+ "handlers": [
180
+ {
181
+ "_target_": "ValidationHandler",
182
+ "validator": "@validate#evaluator",
183
+ "epoch_level": true,
184
+ "interval": "@val_interval"
185
+ },
186
+ {
187
+ "_target_": "StatsHandler",
188
+ "tag_name": "train_loss",
189
+ "output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
190
+ },
191
+ {
192
+ "_target_": "TensorBoardStatsHandler",
193
+ "log_dir": "@output_dir",
194
+ "tag_name": "train_loss",
195
+ "output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
196
+ }
197
+ ],
198
+ "key_metric": {
199
+ "train_accuracy": {
200
+ "_target_": "ignite.metrics.Accuracy",
201
+ "output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
202
+ }
203
+ },
204
+ "trainer": {
205
+ "_target_": "SupervisedTrainer",
206
+ "max_epochs": 2000,
207
+ "device": "@device",
208
+ "train_data_loader": "@train#dataloader",
209
+ "network": "@network",
210
+ "loss_function": "@loss",
211
+ "optimizer": "@optimizer",
212
+ "inferer": "@train#inferer",
213
+ "postprocessing": "@train#postprocessing",
214
+ "key_train_metric": "@train#key_metric",
215
+ "train_handlers": "@train#handlers",
216
+ "amp": true
217
+ }
218
+ },
219
+ "validate": {
220
+ "preprocessing": {
221
+ "_target_": "Compose",
222
+ "transforms": "%train#deterministic_transforms"
223
+ },
224
+ "dataset": {
225
+ "_target_": "CacheDataset",
226
+ "data": "$[{'image': i, 'label': l} for i, l in zip(@images[-2:], @labels[-2:])]",
227
+ "transform": "@validate#preprocessing",
228
+ "cache_rate": 1.0
229
+ },
230
+ "dataloader": {
231
+ "_target_": "DataLoader",
232
+ "dataset": "@validate#dataset",
233
+ "batch_size": 2,
234
+ "shuffle": false,
235
+ "num_workers": 1
236
+ },
237
+ "inferer": {
238
+ "_target_": "SlidingWindowInferer",
239
+ "roi_size": [
240
+ 96,
241
+ 96,
242
+ 96
243
+ ],
244
+ "sw_batch_size": 4,
245
+ "overlap": 0.5
246
+ },
247
+ "postprocessing": "%train#postprocessing",
248
+ "handlers": [
249
+ {
250
+ "_target_": "StatsHandler",
251
+ "iteration_log": false
252
+ },
253
+ {
254
+ "_target_": "TensorBoardStatsHandler",
255
+ "log_dir": "@output_dir",
256
+ "iteration_log": false
257
+ },
258
+ {
259
+ "_target_": "CheckpointSaver",
260
+ "save_dir": "@ckpt_dir",
261
+ "save_dict": {
262
+ "model": "@network"
263
+ },
264
+ "save_key_metric": true,
265
+ "key_metric_filename": "model.pt"
266
+ }
267
+ ],
268
+ "key_metric": {
269
+ "val_mean_dice": {
270
+ "_target_": "MeanDice",
271
+ "include_background": false,
272
+ "output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
273
+ }
274
+ },
275
+ "additional_metrics": {
276
+ "val_accuracy": {
277
+ "_target_": "ignite.metrics.Accuracy",
278
+ "output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
279
+ }
280
+ },
281
+ "evaluator": {
282
+ "_target_": "SupervisedEvaluator",
283
+ "device": "@device",
284
+ "val_data_loader": "@validate#dataloader",
285
+ "network": "@network",
286
+ "inferer": "@validate#inferer",
287
+ "postprocessing": "@validate#postprocessing",
288
+ "key_val_metric": "@validate#key_metric",
289
+ "additional_metrics": "@validate#additional_metrics",
290
+ "val_handlers": "@validate#handlers",
291
+ "amp": true
292
+ }
293
+ },
294
+ "training": [
295
+ "$monai.utils.set_determinism(seed=123)",
296
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
297
+ "$@train#trainer.run()"
298
+ ]
299
+ }
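
Paths such as `bundle_root` and `dataset_dir` are plain top-level entries in train.json, so they can be overridden from the command line rather than by editing the file. A hypothetical example (the paths are placeholders):

```
python -m monai.bundle run training \
    --meta_file configs/metadata.json \
    --config_file configs/train.json \
    --logging_file configs/logging.conf \
    --bundle_root . \
    --dataset_dir ./dataset/brain
```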
docs/3DSlicer_use.png ADDED

Git LFS Details

  • SHA256: d6fe205d20ef8895b8ac2420e5c7682091ce11341027846e93b735c97ceba6b2
  • Pointer size: 131 Bytes
  • Size of remote file: 609 kB
docs/README.md ADDED
@@ -0,0 +1,189 @@
1
+ # Description
2
+ Detailed whole brain segmentation is an essential quantitative technique in medical image analysis; it provides a non-invasive way of measuring brain regions from clinically acquired structural magnetic resonance imaging (MRI).
3
+ We provide a pre-trained model for training and inference of whole brain segmentation with 133 structures.
4
+ A training pipeline is provided to support active learning in MONAI Label and training with the bundle.
5
+
6
+ A tutorial and model release for whole brain segmentation using the 3D transformer-based segmentation model UNEST.
7
+
8
+ Authors:
9
+ Xin Yu ([email protected])
10
+
11
+ Yinchi Zhou ([email protected]) | Yucheng Tang ([email protected])
12
+
13
+ <p align="center">
14
+ -------------------------------------------------------------------------------------
15
+ </p>
16
+
17
+ ![](./demo.png) <br>
18
+ <p align="center">
19
+ Fig.1 - Demonstration of T1w MRI images registered in the MNI space and the whole brain segmentation labels with 133 classes</p>
20
+
21
+
22
+
23
+ # Model Overview
24
+ A pre-trained UNEST base model [1] for volumetric (3D) whole brain segmentation with T1w MR images.
25
+ To leverage information across embedded sequences, "shifted window" transformers
26
+ were proposed for dense prediction and multi-scale feature modeling. However, these
27
+ attempts to widen the self-attention range often incur high computational
28
+ complexity and data inefficiency. Inspired by the aggregation function in the nested
29
+ ViT, we propose a new design for a 3D U-shaped medical segmentation model with
30
+ Nested Transformers (UNesT), built hierarchically around the 3D block aggregation function,
31
+ which learns locality behaviors for small structures or small datasets. This design retains
32
+ the original global self-attention mechanism and achieves information communication
33
+ across patches by stacking transformer encoders hierarchically.
34
+
35
+ ![](./unest.png) <br>
36
+ <p align="center">
37
+ Fig.2 - The network architecture of UNEST Base model
38
+ </p>
39
+
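
For reference, the network definition used throughout this bundle (see configs/train.json and configs/inference.json) corresponds to the instantiation below; this is a sketch that assumes the bundle root is on PYTHONPATH so the bundled `scripts` package is importable.

```
# Instantiate the bundled UNesT base model with the same arguments as the configs.
import torch
from scripts.networks.unest_base_patch_4 import UNesT

model = UNesT(
    in_channels=1,
    out_channels=133,
    patch_size=4,
    depths=[2, 2, 8],
    embed_dim=[128, 256, 512],
    num_heads=[4, 8, 16],
)
with torch.no_grad():
    logits = model(torch.zeros(1, 1, 96, 96, 96))  # one 96^3 T1w patch
print(logits.shape)  # expected: torch.Size([1, 133, 96, 96, 96])
```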
40
+
41
+ ## Data
42
+ The training data is from Vanderbilt University and Vanderbilt University Medical Center, together with the publicly released OASIS and CANDI datasets.
43
+ Training and testing data are MRI T1-weighted (T1w) 3D volumes coming from 3 different sites. There are a total of 133 classes in the whole brain segmentation task.
44
+ Among 50 T1w MRI scans from the Open Access Series of Imaging Studies (OASIS) (Marcus et al., 2007) dataset, 45 scans are used for training and the other 5 for validation.
45
+ The testing cohort contains the Colin27 T1w scan (Aubert-Broche et al., 2006) and 13 T1w MRI scans from the Child and Adolescent Neuro Development Initiative (CANDI)
46
+ (Kennedy et al., 2012). All data are registered to the MNI space using the MNI305 (Evans et al., 1993) template and preprocessed following the method in (Huo et al., 2019). Input images are randomly cropped to the size of 96 × 96 × 96.
47
+
48
+ ### Important
49
+
50
+ The brain MRI images used for training are affinely registered from the target image to the MNI305 template using NiftyReg.
51
+ The data should be in the MNI305 space before inference.
52
+
53
+ If your images are already in MNI space, skip the registration step.
54
+
55
+ You can use any registration tool to register the images to MNI space. Here is an example using ANTs.
56
+ Registration to MNI space (sample suggestion): e.g., use ANTs or other tools to register a T1 MRI image to the MNI305 space.
57
+
58
+ ```
59
+ pip install antspyx
60
+
61
+ # Sample ANTs affine registration to the MNI305 template
62
+
63
+ import ants
64
+ import sys
65
+ import os
66
+
67
+ fixed_image = ants.image_read('<fixed_image_path>')
68
+ moving_image = ants.image_read('<moving_image_path>')
69
+ transform = ants.registration(fixed_image, moving_image, 'Affine')
70
+
71
+ reg3t = ants.apply_transforms(fixed_image, moving_image, transform['fwdtransforms'][0])
72
+ ants.image_write(reg3t, '<output_image_path>')
73
+ ```
74
+
75
+ ## Training configuration
76
+ Training and inference were performed with at least one 24 GB GPU.
77
+
78
+ Actual Model Input: 96 x 96 x 96
79
+
80
+ ## Input and output formats
81
+ Input: 1-channel T1w MRI image in MNI305 space.
82
+
83
+ Output: 133-channel segmentation (one channel per label, including background); argmax is applied in post-processing to produce the final label map.
84
+ ## Commands example
85
+ Download the trained checkpoint to ./models/model.pt:
86
+
87
+
88
+ Add the scripts component: to run the workflow with customized components, PYTHONPATH should be revised to include the path to the customized components:
89
+
90
+ ```
91
+ export PYTHONPATH=$PYTHONPATH:'<path to the bundle root dir>/'
92
+ ```
93
+
94
+ Execute Training:
95
+
96
+ ```
97
+ python -m monai.bundle run training --meta_file configs/metadata.json --config_file configs/train.json --logging_file configs/logging.conf
98
+ ```
99
+
100
+ Execute inference:
101
+
102
+ ```
103
+ python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file configs/inference.json --logging_file configs/logging.conf
104
+ ```
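
Entries such as `dataset_dir` and `output_dir` can also be overridden on the command line instead of editing configs/inference.json, for example (the input path is a placeholder):

```
python -m monai.bundle run evaluating \
    --meta_file configs/metadata.json \
    --config_file configs/inference.json \
    --logging_file configs/logging.conf \
    --dataset_dir '<path to MNI-registered T1w images>' \
    --output_dir ./eval
```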
105
+
106
+
107
+ ## More example outputs
108
+ ![](./wholebrain.png) <br>
109
+ <p align="center">
110
+ Fig.3 - The output prediction compared with a model variant and the ground truth
111
+ </p>
112
+
113
+ ## Training/Validation Benchmarking
114
+ A graph showing the training accuracy when fine-tuning for 600 epochs.
115
+
116
+ ![](./training.png) <br>
117
+
118
+ With 10 fine-tuned labels, the training process converges quickly.
119
+
120
+ ## Complete ROI of the whole brain segmentation
121
+ 133 brain structures are segmented.
122
+
123
+ | #1 | #2 | #3 | #4 |
124
+ | :------------ | :---------- | :-------- | :-------- |
125
+ | 0: background | 1 : 3rd-Ventricle | 2 : 4th-Ventricle | 3 : Right-Accumbens-Area |
126
+ | 4 : Left-Accumbens-Area | 5 : Right-Amygdala | 6 : Left-Amygdala | 7 : Brain-Stem |
127
+ | 8 : Right-Caudate | 9 : Left-Caudate | 10 : Right-Cerebellum-Exterior | 11 : Left-Cerebellum-Exterior |
128
+ | 12 : Right-Cerebellum-White-Matter | 13 : Left-Cerebellum-White-Matter | 14 : Right-Cerebral-White-Matter | 15 : Left-Cerebral-White-Matter |
129
+ | 16 : Right-Hippocampus | 17 : Left-Hippocampus | 18 : Right-Inf-Lat-Vent | 19 : Left-Inf-Lat-Vent |
130
+ | 20 : Right-Lateral-Ventricle | 21 : Left-Lateral-Ventricle | 22 : Right-Pallidum | 23 : Left-Pallidum |
131
+ | 24 : Right-Putamen | 25 : Left-Putamen | 26 : Right-Thalamus-Proper | 27 : Left-Thalamus-Proper |
132
+ | 28 : Right-Ventral-DC | 29 : Left-Ventral-DC | 30 : Cerebellar-Vermal-Lobules-I-V | 31 : Cerebellar-Vermal-Lobules-VI-VII |
133
+ | 32 : Cerebellar-Vermal-Lobules-VIII-X | 33 : Left-Basal-Forebrain | 34 : Right-Basal-Forebrain | 35 : Right-ACgG--anterior-cingulate-gyrus |
134
+ | 36 : Left-ACgG--anterior-cingulate-gyrus | 37 : Right-AIns--anterior-insula | 38 : Left-AIns--anterior-insula | 39 : Right-AOrG--anterior-orbital-gyrus |
135
+ | 40 : Left-AOrG--anterior-orbital-gyrus | 41 : Right-AnG---angular-gyrus | 42 : Left-AnG---angular-gyrus | 43 : Right-Calc--calcarine-cortex |
136
+ | 44 : Left-Calc--calcarine-cortex | 45 : Right-CO----central-operculum | 46 : Left-CO----central-operculum | 47 : Right-Cun---cuneus |
137
+ | 48 : Left-Cun---cuneus | 49 : Right-Ent---entorhinal-area | 50 : Left-Ent---entorhinal-area | 51 : Right-FO----frontal-operculum |
138
+ | 52 : Left-FO----frontal-operculum | 53 : Right-FRP---frontal-pole | 54 : Left-FRP---frontal-pole | 55 : Right-FuG---fusiform-gyrus |
139
+ | 56 : Left-FuG---fusiform-gyrus | 57 : Right-GRe---gyrus-rectus | 58 : Left-GRe---gyrus-rectus | 59 : Right-IOG---inferior-occipital-gyrus |
140
+ | 60 : Left-IOG---inferior-occipital-gyrus | 61 : Right-ITG---inferior-temporal-gyrus | 62 : Left-ITG---inferior-temporal-gyrus | 63 : Right-LiG---lingual-gyrus |
141
+ | 64 : Left-LiG---lingual-gyrus | 65 : Right-LOrG--lateral-orbital-gyrus | 66 : Left-LOrG--lateral-orbital-gyrus | 67 : Right-MCgG--middle-cingulate-gyrus |
142
+ | 68 : Left-MCgG--middle-cingulate-gyrus | 69 : Right-MFC---medial-frontal-cortex | 70 : Left-MFC---medial-frontal-cortex | 71 : Right-MFG---middle-frontal-gyrus |
143
+ | 72 : Left-MFG---middle-frontal-gyrus | 73 : Right-MOG---middle-occipital-gyrus | 74 : Left-MOG---middle-occipital-gyrus | 75 : Right-MOrG--medial-orbital-gyrus |
144
+ | 76 : Left-MOrG--medial-orbital-gyrus | 77 : Right-MPoG--postcentral-gyrus | 78 : Left-MPoG--postcentral-gyrus | 79 : Right-MPrG--precentral-gyrus |
145
+ | 80 : Left-MPrG--precentral-gyrus | 81 : Right-MSFG--superior-frontal-gyrus | 82 : Left-MSFG--superior-frontal-gyrus | 83 : Right-MTG---middle-temporal-gyrus |
146
+ | 84 : Left-MTG---middle-temporal-gyrus | 85 : Right-OCP---occipital-pole | 86 : Left-OCP---occipital-pole | 87 : Right-OFuG--occipital-fusiform-gyrus |
147
+ | 88 : Left-OFuG--occipital-fusiform-gyrus | 89 : Right-OpIFG-opercular-part-of-the-IFG | 90 : Left-OpIFG-opercular-part-of-the-IFG | 91 : Right-OrIFG-orbital-part-of-the-IFG |
148
+ | 92 : Left-OrIFG-orbital-part-of-the-IFG | 93 : Right-PCgG--posterior-cingulate-gyrus | 94 : Left-PCgG--posterior-cingulate-gyrus | 95 : Right-PCu---precuneus |
149
+ | 96 : Left-PCu---precuneus | 97 : Right-PHG---parahippocampal-gyrus | 98 : Left-PHG---parahippocampal-gyrus | 99 : Right-PIns--posterior-insula |
150
+ | 100 : Left-PIns--posterior-insula | 101 : Right-PO----parietal-operculum | 102 : Left-PO----parietal-operculum | 103 : Right-PoG---postcentral-gyrus |
151
+ | 104 : Left-PoG---postcentral-gyrus | 105 : Right-POrG--posterior-orbital-gyrus | 106 : Left-POrG--posterior-orbital-gyrus | 107 : Right-PP----planum-polare |
152
+ | 108 : Left-PP----planum-polare | 109 : Right-PrG---precentral-gyrus | 110 : Left-PrG---precentral-gyrus | 111 : Right-PT----planum-temporale |
153
+ | 112 : Left-PT----planum-temporale | 113 : Right-SCA---subcallosal-area | 114 : Left-SCA---subcallosal-area | 115 : Right-SFG---superior-frontal-gyrus |
154
+ | 116 : Left-SFG---superior-frontal-gyrus | 117 : Right-SMC---supplementary-motor-cortex | 118 : Left-SMC---supplementary-motor-cortex | 119 : Right-SMG---supramarginal-gyrus |
155
+ | 120 : Left-SMG---supramarginal-gyrus | 121 : Right-SOG---superior-occipital-gyrus | 122 : Left-SOG---superior-occipital-gyrus | 123 : Right-SPL---superior-parietal-lobule |
156
+ | 124 : Left-SPL---superior-parietal-lobule | 125 : Right-STG---superior-temporal-gyrus | 126 : Left-STG---superior-temporal-gyrus | 127 : Right-TMP---temporal-pole |
157
+ | 128 : Left-TMP---temporal-pole | 129 : Right-TrIFG-triangular-part-of-the-IFG | 130 : Left-TrIFG-triangular-part-of-the-IFG | 131 : Right-TTG---transverse-temporal-gyrus |
158
+ | 132 : Left-TTG---transverse-temporal-gyrus | | | |
159
+
160
+
161
+ ## Bundle Integration in MONAI Label
162
+ The inference and training pipelines can be easily used by the MONAI Label server and 3D Slicer for fast labeling of T1w MRI images in MNI space.
163
+
164
+ ![](./3DSlicer_use.png) <br>
165
+
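
One possible way to serve this bundle through MONAI Label is via its monaibundle app (an illustration, not part of the original instructions; it assumes MONAI Label is installed and the monaibundle app has been downloaded to ./apps/monaibundle, and the exact flags may differ for your setup):

```
monailabel start_server --app apps/monaibundle \
    --studies '<folder with MNI-registered T1w images>' \
    --conf models wholeBrainSeg_Large_UNEST_segmentation
```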
166
+ # Disclaimer
167
+ This is an example, not to be used for diagnostic purposes.
168
+
169
+ # References
170
+ [1] Yu, Xin, Yinchi Zhou, Yucheng Tang et al. Characterizing Renal Structures with 3D Block Aggregate Transformers. arXiv preprint arXiv:2203.02430 (2022). https://arxiv.org/pdf/2203.02430.pdf
171
+
172
+ [2] Zizhao Zhang et al. Nested Hierarchical Transformer: Towards Accurate, Data-Efficient and Interpretable Visual Understanding. AAAI Conference on Artificial Intelligence (AAAI) 2022
173
+
174
+ [3] Huo, Yuankai, et al. 3D whole brain segmentation using spatially localized atlas network tiles. NeuroImage 194 (2019): 105-119.
175
+
176
+ # License
177
+ Copyright (c) MONAI Consortium
178
+
179
+ Licensed under the Apache License, Version 2.0 (the "License");
180
+ you may not use this file except in compliance with the License.
181
+ You may obtain a copy of the License at
182
+
183
+ http://www.apache.org/licenses/LICENSE-2.0
184
+
185
+ Unless required by applicable law or agreed to in writing, software
186
+ distributed under the License is distributed on an "AS IS" BASIS,
187
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
188
+ See the License for the specific language governing permissions and
189
+ limitations under the License.
docs/demo.png ADDED

Git LFS Details

  • SHA256: 92aae2d9b2901de18b445d6e6efdf48b6c3d8bb5e66ee55c3fde152e13f952f7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
docs/training.png ADDED
docs/unest.png ADDED
docs/wholebrain.png ADDED

Git LFS Details

  • SHA256: bb2e981296ea8f1ae12ab4e7cda15c3694ea78d151287879bdfd257f1ca7c587
  • Pointer size: 131 Bytes
  • Size of remote file: 132 kB
models/model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79a52ccd77bc35d05410f39788a1b063af3eb3b809b42241335c18aed27ec422
3
+ size 348901503
scripts/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
scripts/networks/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
scripts/networks/nest/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ #!/usr/bin/env python3
2
+ from .utils import (
3
+ Conv3dSame,
4
+ DropPath,
5
+ Linear,
6
+ Mlp,
7
+ _assert,
8
+ conv3d_same,
9
+ create_conv3d,
10
+ create_pool3d,
11
+ get_padding,
12
+ get_same_padding,
13
+ pad_same,
14
+ to_ntuple,
15
+ trunc_normal_,
16
+ )
scripts/networks/nest/utils.py ADDED
@@ -0,0 +1,481 @@
1
+ #!/usr/bin/env python3
2
+
3
+
4
+ import collections.abc
5
+ import math
6
+ import warnings
7
+ from itertools import repeat
8
+ from typing import List, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ try:
15
+ from torch import _assert
16
+ except ImportError:
17
+
18
+ def _assert(condition: bool, message: str):
19
+ assert condition, message
20
+
21
+
22
+ def drop_block_2d(
23
+ x,
24
+ drop_prob: float = 0.1,
25
+ block_size: int = 7,
26
+ gamma_scale: float = 1.0,
27
+ with_noise: bool = False,
28
+ inplace: bool = False,
29
+ batchwise: bool = False,
30
+ ):
31
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
32
+
33
+ DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
34
+ runs with success, but needs further validation and possibly optimization for lower runtime impact.
35
+ """
36
+ b, c, h, w = x.shape
37
+ total_size = w * h
38
+ clipped_block_size = min(block_size, min(w, h))
39
+ # seed_drop_rate, the gamma parameter
40
+ gamma = gamma_scale * drop_prob * total_size / clipped_block_size**2 / ((w - block_size + 1) * (h - block_size + 1))
41
+
42
+ # Forces the block to be inside the feature map.
43
+ w_i, h_i = torch.meshgrid(torch.arange(w).to(x.device), torch.arange(h).to(x.device))
44
+ valid_block = ((w_i >= clipped_block_size // 2) & (w_i < w - (clipped_block_size - 1) // 2)) & (
45
+ (h_i >= clipped_block_size // 2) & (h_i < h - (clipped_block_size - 1) // 2)
46
+ )
47
+ valid_block = torch.reshape(valid_block, (1, 1, h, w)).to(dtype=x.dtype)
48
+
49
+ if batchwise:
50
+ # one mask for whole batch, quite a bit faster
51
+ uniform_noise = torch.rand((1, c, h, w), dtype=x.dtype, device=x.device)
52
+ else:
53
+ uniform_noise = torch.rand_like(x)
54
+ block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
55
+ block_mask = -F.max_pool2d(
56
+ -block_mask, kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2 # block_size,
57
+ )
58
+
59
+ if with_noise:
60
+ normal_noise = torch.randn((1, c, h, w), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
61
+ if inplace:
62
+ x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
63
+ else:
64
+ x = x * block_mask + normal_noise * (1 - block_mask)
65
+ else:
66
+ normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
67
+ if inplace:
68
+ x.mul_(block_mask * normalize_scale)
69
+ else:
70
+ x = x * block_mask * normalize_scale
71
+ return x
72
+
73
+
74
+ def drop_block_fast_2d(
75
+ x: torch.Tensor,
76
+ drop_prob: float = 0.1,
77
+ block_size: int = 7,
78
+ gamma_scale: float = 1.0,
79
+ with_noise: bool = False,
80
+ inplace: bool = False,
81
+ ):
82
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
83
+
84
+ DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
85
+ block mask at edges.
86
+ """
87
+ b, c, h, w = x.shape
88
+ total_size = w * h
89
+ clipped_block_size = min(block_size, min(w, h))
90
+ gamma = gamma_scale * drop_prob * total_size / clipped_block_size**2 / ((w - block_size + 1) * (h - block_size + 1))
91
+
92
+ block_mask = torch.empty_like(x).bernoulli_(gamma)
93
+ block_mask = F.max_pool2d(
94
+ block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2
95
+ )
96
+
97
+ if with_noise:
98
+ normal_noise = torch.empty_like(x).normal_()
99
+ if inplace:
100
+ x.mul_(1.0 - block_mask).add_(normal_noise * block_mask)
101
+ else:
102
+ x = x * (1.0 - block_mask) + normal_noise * block_mask
103
+ else:
104
+ block_mask = 1 - block_mask
105
+ normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
106
+ if inplace:
107
+ x.mul_(block_mask * normalize_scale)
108
+ else:
109
+ x = x * block_mask * normalize_scale
110
+ return x
111
+
112
+
113
+ class DropBlock2d(nn.Module):
114
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf"""
115
+
116
+ def __init__(
117
+ self, drop_prob=0.1, block_size=7, gamma_scale=1.0, with_noise=False, inplace=False, batchwise=False, fast=True
118
+ ):
119
+ super(DropBlock2d, self).__init__()
120
+ self.drop_prob = drop_prob
121
+ self.gamma_scale = gamma_scale
122
+ self.block_size = block_size
123
+ self.with_noise = with_noise
124
+ self.inplace = inplace
125
+ self.batchwise = batchwise
126
+ self.fast = fast # FIXME finish comparisons of fast vs not
127
+
128
+ def forward(self, x):
129
+ if not self.training or not self.drop_prob:
130
+ return x
131
+ if self.fast:
132
+ return drop_block_fast_2d(
133
+ x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace
134
+ )
135
+ else:
136
+ return drop_block_2d(
137
+ x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise
138
+ )
139
+
140
+
141
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
142
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
143
+
144
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
145
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
146
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
147
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
148
+ 'survival rate' as the argument.
149
+
150
+ """
151
+ if drop_prob == 0.0 or not training:
152
+ return x
153
+ keep_prob = 1 - drop_prob
154
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
155
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
156
+ if keep_prob > 0.0 and scale_by_keep:
157
+ random_tensor.div_(keep_prob)
158
+ return x * random_tensor
159
+
160
+
161
+ class DropPath(nn.Module):
162
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
163
+
164
+ def __init__(self, drop_prob=None, scale_by_keep=True):
165
+ super(DropPath, self).__init__()
166
+ self.drop_prob = drop_prob
167
+ self.scale_by_keep = scale_by_keep
168
+
169
+ def forward(self, x):
170
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
171
+
172
+
173
+ def create_conv3d(in_channels, out_channels, kernel_size, **kwargs):
174
+ """Select a 2d convolution implementation based on arguments
175
+ Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv3d, or CondConv2d.
176
+
177
+ Used extensively by EfficientNet, MobileNetv3 and related networks.
178
+ """
179
+
180
+ depthwise = kwargs.pop("depthwise", False)
181
+ # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0
182
+ groups = in_channels if depthwise else kwargs.pop("groups", 1)
183
+
184
+ m = create_conv3d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
185
+ return m
186
+
187
+
188
+ def conv3d_same(
189
+ x,
190
+ weight: torch.Tensor,
191
+ bias: Optional[torch.Tensor] = None,
192
+ stride: Tuple[int, int, int] = (1, 1, 1),
193
+ padding: Tuple[int, int, int] = (0, 0, 0),
194
+ dilation: Tuple[int, int, int] = (1, 1, 1),
195
+ groups: int = 1,
196
+ ):
197
+ x = pad_same(x, weight.shape[-3:], stride, dilation)
198
+ return F.conv3d(x, weight, bias, stride, (0, 0, 0), dilation, groups)
199
+
200
+
201
+ class Conv3dSame(nn.Conv3d):
202
+ """Tensorflow like 'SAME' convolution wrapper for 3D convolutions"""
203
+
204
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
205
+ super(Conv3dSame, self).__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
206
+
207
+ def forward(self, x):
208
+ return conv3d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
209
+
210
+
211
+ def create_conv3d_pad(in_chs, out_chs, kernel_size, **kwargs):
212
+ padding = kwargs.pop("padding", "")
213
+ kwargs.setdefault("bias", False)
214
+ padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
215
+ if is_dynamic:
216
+ return Conv3dSame(in_chs, out_chs, kernel_size, **kwargs)
217
+ else:
218
+ return nn.Conv3d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
219
+
220
+
221
+ # Calculate symmetric padding for a convolution
222
+ def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
223
+ padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
224
+ return padding
225
+
226
+
227
+ # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
228
+ def get_same_padding(x: int, k: int, s: int, d: int):
229
+ return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
230
+
231
+
232
+ # Can SAME padding for given args be done statically?
233
+ def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
234
+ return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
235
+
236
+
237
+ # Dynamically pad input x with 'SAME' padding for conv with specified args
238
+ def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1, 1), value: float = 0):
239
+ id, ih, iw = x.size()[-3:]
240
+ pad_d, pad_h, pad_w = (
241
+ get_same_padding(id, k[0], s[0], d[0]),
242
+ get_same_padding(ih, k[1], s[1], d[1]),
243
+ get_same_padding(iw, k[2], s[2], d[2]),
244
+ )
245
+ if pad_d > 0 or pad_h > 0 or pad_w > 0:
246
+ x = F.pad(
247
+ x,
248
+ [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2, pad_d // 2, pad_d - pad_d // 2],  # F.pad pads the last (W) dim first
249
+ value=value,
250
+ )
251
+ return x
252
+
253
+
254
+ def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
255
+ dynamic = False
256
+ if isinstance(padding, str):
257
+ # for any string padding, the padding will be calculated for you, one of three ways
258
+ padding = padding.lower()
259
+ if padding == "same":
260
+ # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
261
+ if is_static_pad(kernel_size, **kwargs):
262
+ # static case, no extra overhead
263
+ padding = get_padding(kernel_size, **kwargs)
264
+ else:
265
+ # dynamic 'SAME' padding, has runtime/GPU memory overhead
266
+ padding = 0
267
+ dynamic = True
268
+ elif padding == "valid":
269
+ # 'VALID' padding, same as padding=0
270
+ padding = 0
271
+ else:
272
+ # Default to PyTorch style 'same'-ish symmetric padding
273
+ padding = get_padding(kernel_size, **kwargs)
274
+ return padding, dynamic
275
+
276
+
277
+ # From PyTorch internals
278
+ def _ntuple(n):
279
+ def parse(x):
280
+ if isinstance(x, collections.abc.Iterable):
281
+ return x
282
+ return tuple(repeat(x, n))
283
+
284
+ return parse
285
+
286
+
287
+ to_1tuple = _ntuple(1)
288
+ to_2tuple = _ntuple(2)
289
+ to_3tuple = _ntuple(3)
290
+ to_4tuple = _ntuple(4)
291
+ to_ntuple = _ntuple
292
+
293
+
294
+ def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
295
+ min_value = min_value or divisor
296
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
297
+ # Make sure that round down does not go down by more than 10%.
298
+ if new_v < round_limit * v:
299
+ new_v += divisor
300
+ return new_v
301
+
302
+
303
+ class Linear(nn.Linear):
304
+ r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
305
+
306
+ Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
307
+ weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
308
+ """
309
+
310
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
311
+ if torch.jit.is_scripting():
312
+ bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
313
+ return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
314
+ else:
315
+ return F.linear(input, self.weight, self.bias)
316
+
317
+
318
+ class Mlp(nn.Module):
319
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
320
+
321
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
322
+ super().__init__()
323
+ out_features = out_features or in_features
324
+ hidden_features = hidden_features or in_features
325
+ drop_probs = to_2tuple(drop)
326
+
327
+ self.fc1 = nn.Linear(in_features, hidden_features)
328
+ self.act = act_layer()
329
+ self.drop1 = nn.Dropout(drop_probs[0])
330
+ self.fc2 = nn.Linear(hidden_features, out_features)
331
+ self.drop2 = nn.Dropout(drop_probs[1])
332
+
333
+ def forward(self, x):
334
+ x = self.fc1(x)
335
+ x = self.act(x)
336
+ x = self.drop1(x)
337
+ x = self.fc2(x)
338
+ x = self.drop2(x)
339
+ return x
340
+
341
+
342
+ def avg_pool3d_same(
343
+ x,
344
+ kernel_size: List[int],
345
+ stride: List[int],
346
+ padding: List[int] = (0, 0, 0),
347
+ ceil_mode: bool = False,
348
+ count_include_pad: bool = True,
349
+ ):
350
+ # FIXME how to deal with count_include_pad vs not for external padding?
351
+ x = pad_same(x, kernel_size, stride)
352
+ return F.avg_pool3d(x, kernel_size, stride, (0, 0, 0), ceil_mode, count_include_pad)
353
+
354
+
355
+ class AvgPool3dSame(nn.AvgPool3d):
356
+ """Tensorflow like 'SAME' wrapper for 3D average pooling"""
357
+
358
+ def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
359
+ kernel_size = to_3tuple(kernel_size)
360
+ stride = to_3tuple(stride)
361
+ super(AvgPool3dSame, self).__init__(kernel_size, stride, (0, 0, 0), ceil_mode, count_include_pad)
362
+
363
+ def forward(self, x):
364
+ x = pad_same(x, self.kernel_size, self.stride)
365
+ return F.avg_pool3d(x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
366
+
367
+
368
+ def max_pool3d_same(
369
+ x,
370
+ kernel_size: List[int],
371
+ stride: List[int],
372
+ padding: List[int] = (0, 0, 0),
373
+ dilation: List[int] = (1, 1, 1),
374
+ ceil_mode: bool = False,
375
+ ):
376
+ x = pad_same(x, kernel_size, stride, value=-float("inf"))
377
+ return F.max_pool3d(x, kernel_size, stride, (0, 0, 0), dilation, ceil_mode)
378
+
379
+
380
+ class MaxPool3dSame(nn.MaxPool3d):
381
+ """Tensorflow like 'SAME' wrapper for 3D max pooling"""
382
+
383
+ def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
384
+ kernel_size = to_3tuple(kernel_size)
385
+ stride = to_3tuple(stride)
386
+ dilation = to_3tuple(dilation)
387
+ super(MaxPool3dSame, self).__init__(kernel_size, stride, (0, 0, 0), dilation, ceil_mode)
388
+
389
+ def forward(self, x):
390
+ x = pad_same(x, self.kernel_size, self.stride, value=-float("inf"))
391
+ return F.max_pool3d(x, self.kernel_size, self.stride, (0, 0, 0), self.dilation, self.ceil_mode)
392
+
393
+
394
+ def create_pool3d(pool_type, kernel_size, stride=None, **kwargs):
395
+ stride = stride or kernel_size
396
+ padding = kwargs.pop("padding", "")
397
+ padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
398
+ if is_dynamic:
399
+ if pool_type == "avg":
400
+ return AvgPool3dSame(kernel_size, stride=stride, **kwargs)
401
+ elif pool_type == "max":
402
+ return MaxPool3dSame(kernel_size, stride=stride, **kwargs)
403
+ else:
404
+ raise AssertionError()
405
+
406
+ # assert False, f"Unsupported pool type {pool_type}"
407
+ else:
408
+ if pool_type == "avg":
409
+ return nn.AvgPool3d(kernel_size, stride=stride, padding=padding, **kwargs)
410
+ elif pool_type == "max":
411
+ return nn.MaxPool3d(kernel_size, stride=stride, padding=padding, **kwargs)
412
+ else:
413
+ raise AssertionError()
414
+
415
+ # assert False, f"Unsupported pool type {pool_type}"
416
+
417
+
418
+ def _float_to_int(x: float) -> int:
419
+ """
420
+ Symbolic tracing helper to substitute for inbuilt `int`.
421
+ Hint: Inbuilt `int` can't accept an argument of type `Proxy`
422
+ """
423
+ return int(x)
424
+
425
+
426
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
427
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
428
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
429
+ def norm_cdf(x):
430
+ # Computes standard normal cumulative distribution function
431
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
432
+
433
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
434
+ warnings.warn(
435
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
436
+ "The distribution of values may be incorrect.",
437
+ stacklevel=2,
438
+ )
439
+
440
+ with torch.no_grad():
441
+ # Values are generated by using a truncated uniform distribution and
442
+ # then using the inverse CDF for the normal distribution.
443
+ # Get upper and lower cdf values
444
+ l = norm_cdf((a - mean) / std)
445
+ u = norm_cdf((b - mean) / std)
446
+
447
+ # Uniformly fill tensor with values from [l, u], then translate to
448
+ # [2l-1, 2u-1].
449
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
450
+
451
+ # Use inverse cdf transform for normal distribution to get truncated
452
+ # standard normal
453
+ tensor.erfinv_()
454
+
455
+ # Transform to proper mean, std
456
+ tensor.mul_(std * math.sqrt(2.0))
457
+ tensor.add_(mean)
458
+
459
+ # Clamp to ensure it's in the proper range
460
+ tensor.clamp_(min=a, max=b)
461
+ return tensor
462
+
463
+
464
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
465
+ r"""Fills the input Tensor with values drawn from a truncated
466
+ normal distribution. The values are effectively drawn from the
467
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
468
+ with values outside :math:`[a, b]` redrawn until they are within
469
+ the bounds. The method used for generating the random values works
470
+ best when :math:`a \leq \text{mean} \leq b`.
471
+ Args:
472
+ tensor: an n-dimensional `torch.Tensor`
473
+ mean: the mean of the normal distribution
474
+ std: the standard deviation of the normal distribution
475
+ a: the minimum cutoff value
476
+ b: the maximum cutoff value
477
+ Examples:
478
+ >>> w = torch.empty(3, 5)
479
+ >>> nn.init.trunc_normal_(w)
480
+ """
481
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
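The helpers above (DropPath, the TF-'SAME' conv/pool factories, and trunc_normal_) are 3D adaptations of timm layer utilities. A minimal usage sketch, assuming this helper module is importable as scripts.networks.nest (the name used by the relative import in nest_transformer_3D.py below); shapes are illustrative only:

import torch
import torch.nn as nn

from scripts.networks.nest import DropPath, create_conv3d, create_pool3d, trunc_normal_

x = torch.randn(2, 16, 24, 24, 24)  # (B, C, D, H, W), arbitrary example shape

# padding="" selects PyTorch-style symmetric padding, so stride-1 convs keep the spatial size.
conv = create_conv3d(16, 32, kernel_size=3, padding="")
pool = create_pool3d("max", kernel_size=3, stride=2, padding="")
y = pool(conv(x))  # (2, 32, 12, 12, 12)

# Stochastic depth: identity in eval mode, randomly zeroes whole samples during training.
dp = DropPath(drop_prob=0.1)
dp.eval()
assert torch.equal(dp(y), y)

# Truncated-normal init, as used for the positional embeddings in the transformer below.
pos = nn.Parameter(torch.zeros(1, 64, 27, 128))
trunc_normal_(pos, std=0.02, a=-2, b=2)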
scripts/networks/nest_transformer_3D.py ADDED
@@ -0,0 +1,489 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # =========================================================================
4
+ # Adapted from https://github.com/google-research/nested-transformer.
5
+ # which has the following license...
6
+ # https://github.com/pytorch/vision/blob/main/LICENSE
7
+ #
8
+ # BSD 3-Clause License
9
+
10
+
11
+ # Redistribution and use in source and binary forms, with or without
12
+ # modification, are permitted provided that the following conditions are met:
13
+
14
+ # * Redistributions of source code must retain the above copyright notice, this
15
+ # list of conditions and the following disclaimer.
16
+
17
+ # * Redistributions in binary form must reproduce the above copyright notice,
18
+ # this list of conditions and the following disclaimer in the documentation
19
+ # and/or other materials provided with the distribution.
20
+
21
+ # * Neither the name of the copyright holder nor the names of its
22
+ # contributors may be used to endorse or promote products derived from
23
+ # this software without specific prior written permission.
24
+
25
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
29
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
30
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
31
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
33
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
34
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35
+
36
+ """ Nested Transformer (NesT) in PyTorch
37
+ A PyTorch implementation of Aggregating Nested Transformers as described in:
38
+ 'Aggregating Nested Transformers'
39
+ - https://arxiv.org/abs/2105.12723
40
+ The official Jax code is released and available at https://github.com/google-research/nested-transformer.
41
+ The weights have been converted with convert/convert_nest_flax.py
42
+ Acknowledgments:
43
+ * The paper authors for sharing their research, code, and model weights
44
+ * Ross Wightman's existing code off which I based this
45
+ Copyright 2021 Alexander Soare
46
+
47
+ """
48
+
49
+ import collections.abc
50
+ import logging
51
+ import math
52
+ from functools import partial
53
+ from typing import Callable, Sequence
54
+
55
+ import torch
56
+ import torch.nn.functional as F
57
+ from torch import nn
58
+
59
+ from .nest import DropPath, Mlp, _assert, create_conv3d, create_pool3d, to_ntuple, trunc_normal_
60
+ from .patchEmbed3D import PatchEmbed3D
61
+
62
+ _logger = logging.getLogger(__name__)
63
+
64
+
65
+ class Attention(nn.Module):
66
+ """
67
+ This is much like `.vision_transformer.Attention` but uses *localised* self attention by accepting an input with
68
+ an extra "image block" dim
69
+ """
70
+
71
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
72
+ super().__init__()
73
+ self.num_heads = num_heads
74
+ head_dim = dim // num_heads
75
+ self.scale = head_dim**-0.5
76
+
77
+ self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
78
+ self.attn_drop = nn.Dropout(attn_drop)
79
+ self.proj = nn.Linear(dim, dim)
80
+ self.proj_drop = nn.Dropout(proj_drop)
81
+
82
+ def forward(self, x):
83
+ """
84
+ x is shape: B (batch_size), T (image blocks), N (seq length per image block), C (embed dim)
85
+ """
86
+ b, t, n, c = x.shape
87
+ # result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
88
+ qkv = self.qkv(x).reshape(b, t, n, 3, self.num_heads, c // self.num_heads).permute(3, 0, 4, 1, 2, 5)
89
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
90
+
91
+ attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, N, N)
92
+ attn = attn.softmax(dim=-1)
93
+ attn = self.attn_drop(attn)
94
+
95
+ x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(b, t, n, c)
96
+ x = self.proj(x)
97
+ x = self.proj_drop(x)
98
+ return x # (B, T, N, C)
99
+
100
+
101
+ class TransformerLayer(nn.Module):
102
+ """
103
+ This is much like `.vision_transformer.Block` but:
104
+ - Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
105
+ - Uses modified Attention layer that handles the "block" dimension
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ dim,
111
+ num_heads,
112
+ mlp_ratio=4.0,
113
+ qkv_bias=False,
114
+ drop=0.0,
115
+ attn_drop=0.0,
116
+ drop_path=0.0,
117
+ act_layer=nn.GELU,
118
+ norm_layer=nn.LayerNorm,
119
+ ):
120
+ super().__init__()
121
+ self.norm1 = norm_layer(dim)
122
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
123
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
124
+ self.norm2 = norm_layer(dim)
125
+ mlp_hidden_dim = int(dim * mlp_ratio)
126
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
127
+
128
+ def forward(self, x):
129
+ y = self.norm1(x)
130
+ x = x + self.drop_path(self.attn(y))
131
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
132
+ return x
133
+
134
+
135
+ class ConvPool(nn.Module):
136
+ def __init__(self, in_channels, out_channels, norm_layer, pad_type=""):
137
+ super().__init__()
138
+ self.conv = create_conv3d(in_channels, out_channels, kernel_size=3, padding=pad_type, bias=True)
139
+ self.norm = norm_layer(out_channels)
140
+ self.pool = create_pool3d("max", kernel_size=3, stride=2, padding=pad_type)
141
+
142
+ def forward(self, x):
143
+ """
144
+ x is expected to have shape (B, C, D, H, W)
145
+ """
146
+ _assert(x.shape[-3] % 2 == 0, "BlockAggregation requires even input spatial dims")
147
+ _assert(x.shape[-2] % 2 == 0, "BlockAggregation requires even input spatial dims")
148
+ _assert(x.shape[-1] % 2 == 0, "BlockAggregation requires even input spatial dims")
149
+
150
+ # print('In ConvPool x : {}'.format(x.shape))
151
+ x = self.conv(x)
152
+ # Layer norm done over channel dim only
153
+ x = self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
154
+ x = self.pool(x)
155
+ return x # (B, C, D//2, H//2, W//2)
156
+
157
+
158
+ def blockify(x, block_size: int):
159
+ """image to blocks
160
+ Args:
161
+ x (Tensor): with shape (B, D, H, W, C)
162
+ block_size (int): edge length of a single square block in units of D, H, W
163
+ """
164
+ b, d, h, w, c = x.shape
165
+ _assert(d % block_size == 0, "`block_size` must divide input depth evenly")
166
+ _assert(h % block_size == 0, "`block_size` must divide input height evenly")
167
+ _assert(w % block_size == 0, "`block_size` must divide input width evenly")
168
+ grid_depth = d // block_size
169
+ grid_height = h // block_size
170
+ grid_width = w // block_size
171
+ x = x.reshape(b, grid_depth, block_size, grid_height, block_size, grid_width, block_size, c)
172
+
173
+ x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(
174
+ b, grid_depth * grid_height * grid_width, -1, c
175
+ ) # shape [2, 512, 27, 128]
176
+
177
+ return x # (B, T, N, C)
178
+
179
+
180
+ # @register_notrace_function # reason: int receives Proxy
181
+ def deblockify(x, block_size: int):
182
+ """blocks to image
183
+ Args:
184
+ x (Tensor): with shape (B, T, N, C) where T is number of blocks and N is sequence size per block
185
+ block_size (int): edge length of a single square block in units of desired D, H, W
186
+ """
187
+ b, t, _, c = x.shape
188
+ grid_size = round(math.pow(t, 1 / 3))
189
+ depth = height = width = grid_size * block_size
190
+ x = x.reshape(b, grid_size, grid_size, grid_size, block_size, block_size, block_size, c)
191
+
192
+ x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).reshape(b, depth, height, width, c)
193
+
194
+ return x # (B, D, H, W, C)
195
+
196
+
197
+ class NestLevel(nn.Module):
198
+ """Single hierarchical level of a Nested Transformer"""
199
+
200
+ def __init__(
201
+ self,
202
+ num_blocks,
203
+ block_size,
204
+ seq_length,
205
+ num_heads,
206
+ depth,
207
+ embed_dim,
208
+ prev_embed_dim=None,
209
+ mlp_ratio=4.0,
210
+ qkv_bias=True,
211
+ drop_rate=0.0,
212
+ attn_drop_rate=0.0,
213
+ drop_path_rates: Sequence[float] = (),
214
+ norm_layer=None,
215
+ act_layer=None,
216
+ pad_type="",
217
+ ):
218
+ super().__init__()
219
+ self.block_size = block_size
220
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim))
221
+
222
+ if prev_embed_dim is not None:
223
+ self.pool = ConvPool(prev_embed_dim, embed_dim, norm_layer=norm_layer, pad_type=pad_type)
224
+ else:
225
+ self.pool = nn.Identity()
226
+
227
+ # Transformer encoder
228
+ if len(drop_path_rates):
229
+ assert len(drop_path_rates) == depth, "Must provide as many drop path rates as there are transformer layers"
230
+ self.transformer_encoder = nn.Sequential(
231
+ *[
232
+ TransformerLayer(
233
+ dim=embed_dim,
234
+ num_heads=num_heads,
235
+ mlp_ratio=mlp_ratio,
236
+ qkv_bias=qkv_bias,
237
+ drop=drop_rate,
238
+ attn_drop=attn_drop_rate,
239
+ drop_path=drop_path_rates[i],
240
+ norm_layer=norm_layer,
241
+ act_layer=act_layer,
242
+ )
243
+ for i in range(depth)
244
+ ]
245
+ )
246
+
247
+ def forward(self, x):
248
+ """
249
+ expects x as (B, C, D, H, W)
250
+ """
251
+ x = self.pool(x)
252
+ x = x.permute(0, 2, 3, 4, 1)  # (B, D', H', W', C), switch to channels last for transformer
253
+
254
+ x = blockify(x, self.block_size) # (B, T, N, C')
255
+ x = x + self.pos_embed
256
+
257
+ x = self.transformer_encoder(x)  # (B, T, N, C')
258
+
259
+ x = deblockify(x, self.block_size) # (B, D', H', W', C') [2, 24, 24, 24, 128]
260
+ # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
261
+ return x.permute(0, 4, 1, 2, 3) # (B, C, D', H', W')
262
+
263
+
264
+ class NestTransformer3D(nn.Module):
265
+ """Nested Transformer (NesT)
266
+ A PyTorch impl of : `Aggregating Nested Transformers`
267
+ - https://arxiv.org/abs/2105.12723
268
+ """
269
+
270
+ def __init__(
271
+ self,
272
+ img_size=96,
273
+ in_chans=1,
274
+ patch_size=2,
275
+ num_levels=3,
276
+ embed_dims=(128, 256, 512),
277
+ num_heads=(4, 8, 16),
278
+ depths=(2, 2, 20),
279
+ num_classes=1000,
280
+ mlp_ratio=4.0,
281
+ qkv_bias=True,
282
+ drop_rate=0.0,
283
+ attn_drop_rate=0.0,
284
+ drop_path_rate=0.5,
285
+ norm_layer=None,
286
+ act_layer=None,
287
+ pad_type="",
288
+ weight_init="",
289
+ global_pool="avg",
290
+ ):
291
+ """
292
+ Args:
293
+ img_size (int, tuple): input image size
294
+ in_chans (int): number of input channels
295
+ patch_size (int): patch size
296
+ num_levels (int): number of block hierarchies (T_d in the paper)
297
+ embed_dims (int, tuple): embedding dimensions of each level
298
+ num_heads (int, tuple): number of attention heads for each level
299
+ depths (int, tuple): number of transformer layers for each level
300
+ num_classes (int): number of classes for classification head
301
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim for MLP of transformer layers
302
+ qkv_bias (bool): enable bias for qkv if True
303
+ drop_rate (float): dropout rate for MLP of transformer layers, MSA final projection layer, and classifier
304
+ attn_drop_rate (float): attention dropout rate
305
+ drop_path_rate (float): stochastic depth rate
306
+ norm_layer: (nn.Module): normalization layer for transformer layers
307
+ act_layer: (nn.Module): activation layer in MLP of transformer layers
308
+ pad_type: str: Type of padding to use '' for PyTorch symmetric, 'same' for TF SAME
309
+ weight_init: (str): weight init scheme
310
+ global_pool: (str): type of pooling operation to apply to final feature map
311
+ Notes:
312
+ - Default values follow NesT-B from the original Jax code.
313
+ - `embed_dims`, `num_heads`, `depths` should be ints or tuples with length `num_levels`.
314
+ - For those following the paper, Table A1 may have errors!
315
+ - https://github.com/google-research/nested-transformer/issues/2
316
+ """
317
+ super().__init__()
318
+
319
+ for param_name in ["embed_dims", "num_heads", "depths"]:
320
+ param_value = locals()[param_name]
321
+ if isinstance(param_value, collections.abc.Sequence):
322
+ assert len(param_value) == num_levels, f"Require `len({param_name}) == num_levels`"
323
+
324
+ embed_dims = to_ntuple(num_levels)(embed_dims)
325
+ num_heads = to_ntuple(num_levels)(num_heads)
326
+ depths = to_ntuple(num_levels)(depths)
327
+ self.num_classes = num_classes
328
+ self.num_features = embed_dims[-1]
329
+ self.feature_info = []
330
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
331
+ act_layer = act_layer or nn.GELU
332
+ self.drop_rate = drop_rate
333
+ self.num_levels = num_levels
334
+ if isinstance(img_size, collections.abc.Sequence):
335
+ assert img_size[0] == img_size[1] == img_size[2], "Model only handles cubic inputs"
336
+ img_size = img_size[0]
337
+ assert img_size % patch_size == 0, "`patch_size` must divide `img_size` evenly"
338
+ self.patch_size = patch_size
339
+
340
+ # Number of blocks at each level
341
+ self.num_blocks = (8 ** torch.arange(num_levels)).flip(0).tolist()
342
+ assert (img_size // patch_size) % round(
343
+ math.pow(self.num_blocks[0], 1 / 3)
344
+ ) == 0, "First level blocks don't fit evenly. Check `img_size`, `patch_size`, and `num_levels`"
345
+
346
+ # Block edge size in units of patches
347
+ # Hint: (img_size // patch_size) gives the number of patches along each edge of the volume.
348
+ # The cube root of self.num_blocks[0] is the number of blocks along each edge.
349
+ self.block_size = int((img_size // patch_size) // round(math.pow(self.num_blocks[0], 1 / 3)))
350
+
351
+ # Patch embedding
352
+ self.patch_embed = PatchEmbed3D(
353
+ img_size=[img_size, img_size, img_size],
354
+ patch_size=[patch_size, patch_size, patch_size],
355
+ in_chans=in_chans,
356
+ embed_dim=embed_dims[0],
357
+ )
358
+ self.num_patches = self.patch_embed.num_patches
359
+ self.seq_length = self.num_patches // self.num_blocks[0]
360
+ # Build up each hierarchical level
361
+ levels = []
362
+
363
+ dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
364
+ prev_dim = None
365
+ curr_stride = 4
366
+ for i in range(len(self.num_blocks)):
367
+ dim = embed_dims[i]
368
+ levels.append(
369
+ NestLevel(
370
+ self.num_blocks[i],
371
+ self.block_size,
372
+ self.seq_length,
373
+ num_heads[i],
374
+ depths[i],
375
+ dim,
376
+ prev_dim,
377
+ mlp_ratio,
378
+ qkv_bias,
379
+ drop_rate,
380
+ attn_drop_rate,
381
+ dp_rates[i],
382
+ norm_layer,
383
+ act_layer,
384
+ pad_type=pad_type,
385
+ )
386
+ )
387
+ self.feature_info += [dict(num_chs=dim, reduction=curr_stride, module=f"levels.{i}")]
388
+ prev_dim = dim
389
+ curr_stride *= 2
390
+
391
+ self.levels = nn.ModuleList([levels[i] for i in range(num_levels)])
392
+
393
+ # Final normalization layer
394
+ self.norm = norm_layer(embed_dims[-1])
395
+
396
+ self.init_weights(weight_init)
397
+
398
+ def init_weights(self, mode=""):
399
+ assert mode in ("nlhb", "")
400
+ head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
401
+ for level in self.levels:
402
+ trunc_normal_(level.pos_embed, std=0.02, a=-2, b=2)
403
+ named_apply(partial(_init_nest_weights, head_bias=head_bias), self)
404
+
405
+ @torch.jit.ignore
406
+ def no_weight_decay(self):
407
+ return {f"level.{i}.pos_embed" for i in range(len(self.levels))}
408
+
409
+ def get_classifier(self):
410
+ return self.head
411
+
412
+ def forward_features(self, x):
413
+ """x shape (B, C, D, H, W)"""
414
+ x = self.patch_embed(x)
415
+
416
+ hidden_states_out = [x]
417
+
418
+ for _, level in enumerate(self.levels):
419
+ x = level(x)
420
+ hidden_states_out.append(x)
421
+ # Layer norm done over channel dim only (to NDHWC and back)
422
+ x = self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
423
+ return x, hidden_states_out
424
+
425
+ def forward(self, x):
426
+ """x shape (B, C, D, H, W)"""
427
+ x, hidden_states_out = self.forward_features(x)
428
+
429
+ if self.drop_rate > 0.0:
430
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
431
+ return x, hidden_states_out
432
+
433
+
434
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
435
+ if not depth_first and include_root:
436
+ fn(module=module, name=name)
437
+ for child_name, child_module in module.named_children():
438
+ child_name = ".".join((name, child_name)) if name else child_name
439
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
440
+ if depth_first and include_root:
441
+ fn(module=module, name=name)
442
+ return module
443
+
444
+
445
+ def _init_nest_weights(module: nn.Module, name: str = "", head_bias: float = 0.0):
446
+ """NesT weight initialization
447
+ Can replicate Jax implementation. Otherwise follows vision_transformer.py
448
+ """
449
+ if isinstance(module, nn.Linear):
450
+ if name.startswith("head"):
451
+ trunc_normal_(module.weight, std=0.02, a=-2, b=2)
452
+ nn.init.constant_(module.bias, head_bias)
453
+ else:
454
+ trunc_normal_(module.weight, std=0.02, a=-2, b=2)
455
+ if module.bias is not None:
456
+ nn.init.zeros_(module.bias)
457
+ elif isinstance(module, (nn.Conv2d, nn.Conv3d)):
458
+ trunc_normal_(module.weight, std=0.02, a=-2, b=2)
459
+ if module.bias is not None:
460
+ nn.init.zeros_(module.bias)
461
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
462
+ nn.init.zeros_(module.bias)
463
+ nn.init.ones_(module.weight)
464
+
465
+
466
+ def resize_pos_embed(posemb, posemb_new):
467
+ """
468
+ Rescale the grid of position embeddings when loading from state_dict
469
+ Expected shape of position embeddings is (1, T, N, C), and considers only square images
470
+ """
471
+ _logger.info("Resized position embedding: %s to %s", posemb.shape, posemb_new.shape)
472
+ seq_length_old = posemb.shape[2]
473
+ num_blocks_new, seq_length_new = posemb_new.shape[1:3]
474
+ size_new = int(math.sqrt(num_blocks_new * seq_length_new))
475
+ # First change to (1, C, H, W)
476
+ posemb = deblockify(posemb, int(math.sqrt(seq_length_old))).permute(0, 3, 1, 2)
477
+ posemb = F.interpolate(posemb, size=[size_new, size_new], mode="bicubic", align_corners=False)
478
+ # Now change to new (1, T, N, C)
479
+ posemb = blockify(posemb.permute(0, 2, 3, 1), int(math.sqrt(seq_length_new)))
480
+ return posemb
481
+
482
+
483
+ def checkpoint_filter_fn(state_dict, model):
484
+ """resize positional embeddings of pretrained weights"""
485
+ pos_embed_keys = [k for k in state_dict.keys() if k.startswith("pos_embed_")]
486
+ for k in pos_embed_keys:
487
+ if state_dict[k].shape != getattr(model, k).shape:
488
+ state_dict[k] = resize_pos_embed(state_dict[k], getattr(model, k))
489
+ return state_dict
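A small sketch of the backbone and the blockify/deblockify helpers. The tiny configuration below is an assumption for illustration only; it is not this bundle's configuration (the segmentation model further down instantiates the backbone with img_size=96, patch_size=2 and embed_dims=(128, 256, 512)):

import torch

from scripts.networks.nest_transformer_3D import NestTransformer3D, blockify, deblockify

# Tiny illustrative configuration: 48^3 input, 4^3 patches, two hierarchy levels.
net = NestTransformer3D(
    img_size=48, in_chans=1, patch_size=4, num_levels=2,
    embed_dims=(64, 128), num_heads=(2, 4), depths=(1, 1),
)
net.eval()

with torch.no_grad():
    feats, hidden = net(torch.randn(1, 1, 48, 48, 48))
print(feats.shape)                # torch.Size([1, 128, 6, 6, 6])
print([h.shape for h in hidden])  # patch-embedding output followed by each level's output

# blockify/deblockify are inverses for cubic inputs whose edge is divisible by block_size.
x = torch.randn(2, 12, 12, 12, 32)  # (B, D, H, W, C)
assert torch.equal(deblockify(blockify(x, 6), 6), x)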
scripts/networks/patchEmbed3D.py ADDED
@@ -0,0 +1,190 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2020 - 2021 MONAI Consortium
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+
14
+
15
+ import math
16
+ from typing import Sequence, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from monai.utils import optional_import
22
+
23
+ Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
24
+
25
+
26
+ class PatchEmbeddingBlock(nn.Module):
27
+ """
28
+ A patch embedding block, based on: "Dosovitskiy et al.,
29
+ An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ in_channels: int,
35
+ img_size: Tuple[int, int, int],
36
+ patch_size: Tuple[int, int, int],
37
+ hidden_size: int,
38
+ num_heads: int,
39
+ pos_embed: str,
40
+ dropout_rate: float = 0.0,
41
+ ) -> None:
42
+ """
43
+ Args:
44
+ in_channels: dimension of input channels.
45
+ img_size: dimension of input image.
46
+ patch_size: dimension of patch size.
47
+ hidden_size: dimension of hidden layer.
48
+ num_heads: number of attention heads.
49
+ pos_embed: position embedding layer type.
50
+ dropout_rate: faction of the input units to drop.
51
+
52
+ """
53
+
54
+ super().__init__()
55
+
56
+ if not (0 <= dropout_rate <= 1):
57
+ raise AssertionError("dropout_rate should be between 0 and 1.")
58
+
59
+ if hidden_size % num_heads != 0:
60
+ raise AssertionError("hidden size should be divisible by num_heads.")
61
+
62
+ for m, p in zip(img_size, patch_size):
63
+ if m < p:
64
+ raise AssertionError("patch_size should be smaller than img_size.")
65
+
66
+ if pos_embed not in ["conv", "perceptron"]:
67
+ raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")
68
+
69
+ if pos_embed == "perceptron":
70
+ if img_size[0] % patch_size[0] != 0:
71
+ raise AssertionError("img_size should be divisible by patch_size for perceptron patch embedding.")
72
+
73
+ self.n_patches = (
74
+ (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) * (img_size[2] // patch_size[2])
75
+ )
76
+ self.patch_dim = in_channels * patch_size[0] * patch_size[1] * patch_size[2]
77
+
78
+ self.pos_embed = pos_embed
79
+ self.patch_embeddings: Union[nn.Conv3d, nn.Sequential]
80
+ if self.pos_embed == "conv":
81
+ self.patch_embeddings = nn.Conv3d(
82
+ in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
83
+ )
84
+ elif self.pos_embed == "perceptron":
85
+ self.patch_embeddings = nn.Sequential(
86
+ Rearrange(
87
+ "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)",
88
+ p1=patch_size[0],
89
+ p2=patch_size[1],
90
+ p3=patch_size[2],
91
+ ),
92
+ nn.Linear(self.patch_dim, hidden_size),
93
+ )
94
+ self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
95
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
96
+ self.dropout = nn.Dropout(dropout_rate)
97
+ self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
98
+ self.apply(self._init_weights)
99
+
100
+ def _init_weights(self, m):
101
+ if isinstance(m, nn.Linear):
102
+ self.trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)
103
+ if isinstance(m, nn.Linear) and m.bias is not None:
104
+ nn.init.constant_(m.bias, 0)
105
+ elif isinstance(m, nn.LayerNorm):
106
+ nn.init.constant_(m.bias, 0)
107
+ nn.init.constant_(m.weight, 1.0)
108
+
109
+ def trunc_normal_(self, tensor, mean, std, a, b):
110
+ # From PyTorch official master until it's in a few official releases - RW
111
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
112
+ def norm_cdf(x):
113
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
114
+
115
+ with torch.no_grad():
116
+ l = norm_cdf((a - mean) / std)
117
+ u = norm_cdf((b - mean) / std)
118
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
119
+ tensor.erfinv_()
120
+ tensor.mul_(std * math.sqrt(2.0))
121
+ tensor.add_(mean)
122
+ tensor.clamp_(min=a, max=b)
123
+ return tensor
124
+
125
+ def forward(self, x):
126
+ if self.pos_embed == "conv":
127
+ x = self.patch_embeddings(x)
128
+ x = x.flatten(2)
129
+ x = x.transpose(-1, -2)
130
+ elif self.pos_embed == "perceptron":
131
+ x = self.patch_embeddings(x)
132
+ embeddings = x + self.position_embeddings
133
+ embeddings = self.dropout(embeddings)
134
+ return embeddings
135
+
136
+
137
+ class PatchEmbed3D(nn.Module):
138
+ """Video to Patch Embedding.
139
+
140
+ Args:
141
+ patch_size (int): Patch token size. Default: (2,4,4).
142
+ in_chans (int): Number of input video channels. Default: 3.
143
+ embed_dim (int): Number of linear projection output channels. Default: 96.
144
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
145
+ """
146
+
147
+ def __init__(
148
+ self,
149
+ img_size: Sequence[int] = (96, 96, 96),
150
+ patch_size=(4, 4, 4),
151
+ in_chans: int = 1,
152
+ embed_dim: int = 96,
153
+ norm_layer=None,
154
+ ):
155
+ super().__init__()
156
+ self.patch_size = patch_size
157
+
158
+ self.in_chans = in_chans
159
+ self.embed_dim = embed_dim
160
+
161
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1], img_size[2] // patch_size[2])
162
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
163
+
164
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
165
+
166
+ if norm_layer is not None:
167
+ self.norm = norm_layer(embed_dim)
168
+ else:
169
+ self.norm = None
170
+
171
+ def forward(self, x):
172
+ """Forward function."""
173
+ # padding
174
+ _, _, d, h, w = x.size()
175
+ if w % self.patch_size[2] != 0:
176
+ x = F.pad(x, (0, self.patch_size[2] - w % self.patch_size[2]))
177
+ if h % self.patch_size[1] != 0:
178
+ x = F.pad(x, (0, 0, 0, self.patch_size[1] - h % self.patch_size[1]))
179
+ if d % self.patch_size[0] != 0:
180
+ x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - d % self.patch_size[0]))
181
+
182
+ x = self.proj(x) # B C D Wh Ww
183
+ if self.norm is not None:
184
+ d, wh, ww = x.size(2), x.size(3), x.size(4)
185
+ x = x.flatten(2).transpose(1, 2)
186
+ x = self.norm(x)
187
+ x = x.transpose(1, 2).view(-1, self.embed_dim, d, wh, ww)
188
+ # pdb.set_trace()
189
+
190
+ return x
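A short sketch of PatchEmbed3D using its defaults, showing the patch projection and the right-padding applied when an input edge is not a multiple of the patch size; shapes are illustrative:

import torch

from scripts.networks.patchEmbed3D import PatchEmbed3D

# Projects non-overlapping 4x4x4 patches of a single-channel volume to a 96-dim embedding.
embed = PatchEmbed3D(img_size=(96, 96, 96), patch_size=(4, 4, 4), in_chans=1, embed_dim=96)

with torch.no_grad():
    y = embed(torch.randn(1, 1, 96, 96, 96))
print(y.shape)  # torch.Size([1, 96, 24, 24, 24]) -- (B, embed_dim, D/4, H/4, W/4)

# Edges that are not multiples of the patch size are zero-padded before projection.
with torch.no_grad():
    y = embed(torch.randn(1, 1, 95, 96, 94))
print(y.shape)  # torch.Size([1, 96, 24, 24, 24])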
scripts/networks/unest_base_patch_4.py ADDED
@@ -0,0 +1,249 @@
1
+ #!/usr/bin/env python3
2
+ # limitations under the License.
3
+ """
4
+ The 3D NEST transformer based segmentation model
5
+
6
+ MASI Lab, Vanderbilt University
7
+
8
+
9
+ Authors: Xin Yu, Yinchi Zhou, Yucheng Tang, Bennett Landman
10
+
11
+
12
+ The NEST code is partly from
13
+
14
+ Nested Hierarchical Transformer: Towards Accurate, Data-Efficient and
15
+ Interpretable Visual Understanding
16
+ https://arxiv.org/pdf/2105.12723.pdf
17
+
18
+ """
19
+ from typing import Sequence, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from monai.networks.blocks import Convolution
24
+ from monai.networks.blocks.dynunet_block import UnetOutBlock
25
+ from scripts.networks.nest_transformer_3D import NestTransformer3D
26
+ from scripts.networks.unest_block import UNesTBlock, UNesTConvBlock, UNestUpBlock
27
+
28
+
29
+ class UNesT(nn.Module):
30
+ """
31
+ UNesT model implementation
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ in_channels: int,
37
+ out_channels: int,
38
+ img_size: Sequence[int] = (96, 96, 96),
39
+ feature_size: int = 16,
40
+ patch_size: int = 2,
41
+ depths: Sequence[int] = (2, 2, 2, 2),
42
+ num_heads: Sequence[int] = (3, 6, 12, 24),
43
+ embed_dim: Sequence[int] = (128, 256, 512),
44
+ window_size: Sequence[int] = (7, 7, 7),
45
+ norm_name: Union[Tuple, str] = "instance",
46
+ conv_block: bool = False,
47
+ res_block: bool = True,
48
+ dropout_rate: float = 0.0,
49
+ ) -> None:
50
+ """
51
+ Args:
52
+ in_channels: dimension of input channels.
53
+ out_channels: dimension of output channels.
54
+ img_size: dimension of input image.
55
+ feature_size: dimension of network feature size.
56
+ hidden_size: dimension of hidden layer.
57
+ mlp_dim: dimension of feedforward layer.
58
+ num_heads: number of attention heads.
59
+ pos_embed: position embedding layer type.
60
+ norm_name: feature normalization type and arguments.
61
+ conv_block: bool argument to determine if convolutional block is used.
62
+ res_block: bool argument to determine if residual block is used.
63
+ dropout_rate: faction of the input units to drop.
64
+
65
+ Examples:
66
+
67
+ # for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm
68
+ >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch')
69
+
70
+ # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm
71
+ >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance')
72
+
73
+ """
74
+
75
+ super().__init__()
76
+ if not (0 <= dropout_rate <= 1):
77
+ raise AssertionError("dropout_rate should be between 0 and 1.")
78
+ self.embed_dim = embed_dim
79
+ self.nestViT = NestTransformer3D(
80
+ img_size=96,
81
+ in_chans=1,
82
+ patch_size=patch_size,
83
+ num_levels=3,
84
+ embed_dims=embed_dim,
85
+ num_heads=num_heads,
86
+ depths=depths,
87
+ num_classes=1000,
88
+ mlp_ratio=4.0,
89
+ qkv_bias=True,
90
+ drop_rate=0.0,
91
+ attn_drop_rate=0.0,
92
+ drop_path_rate=0.5,
93
+ norm_layer=None,
94
+ act_layer=None,
95
+ pad_type="",
96
+ weight_init="",
97
+ global_pool="avg",
98
+ )
99
+ self.encoder1 = UNesTConvBlock(
100
+ spatial_dims=3,
101
+ in_channels=1,
102
+ out_channels=feature_size * 2,
103
+ kernel_size=3,
104
+ stride=1,
105
+ norm_name=norm_name,
106
+ res_block=res_block,
107
+ )
108
+ self.encoder2 = UNestUpBlock(
109
+ spatial_dims=3,
110
+ in_channels=self.embed_dim[0],
111
+ out_channels=feature_size * 4,
112
+ num_layer=1,
113
+ kernel_size=3,
114
+ stride=1,
115
+ upsample_kernel_size=2,
116
+ norm_name=norm_name,
117
+ conv_block=False,
118
+ res_block=False,
119
+ )
120
+
121
+ self.encoder3 = UNesTConvBlock(
122
+ spatial_dims=3,
123
+ in_channels=self.embed_dim[0],
124
+ out_channels=8 * feature_size,
125
+ kernel_size=3,
126
+ stride=1,
127
+ norm_name=norm_name,
128
+ res_block=res_block,
129
+ )
130
+ self.encoder4 = UNesTConvBlock(
131
+ spatial_dims=3,
132
+ in_channels=self.embed_dim[1],
133
+ out_channels=16 * feature_size,
134
+ kernel_size=3,
135
+ stride=1,
136
+ norm_name=norm_name,
137
+ res_block=res_block,
138
+ )
139
+ self.decoder5 = UNesTBlock(
140
+ spatial_dims=3,
141
+ in_channels=2 * self.embed_dim[2],
142
+ out_channels=feature_size * 32,
143
+ stride=1,
144
+ kernel_size=3,
145
+ upsample_kernel_size=2,
146
+ norm_name=norm_name,
147
+ res_block=res_block,
148
+ )
149
+ self.decoder4 = UNesTBlock(
150
+ spatial_dims=3,
151
+ in_channels=self.embed_dim[2],
152
+ out_channels=feature_size * 16,
153
+ stride=1,
154
+ kernel_size=3,
155
+ upsample_kernel_size=2,
156
+ norm_name=norm_name,
157
+ res_block=res_block,
158
+ )
159
+ self.decoder3 = UNesTBlock(
160
+ spatial_dims=3,
161
+ in_channels=feature_size * 16,
162
+ out_channels=feature_size * 8,
163
+ stride=1,
164
+ kernel_size=3,
165
+ upsample_kernel_size=2,
166
+ norm_name=norm_name,
167
+ res_block=res_block,
168
+ )
169
+ self.decoder2 = UNesTBlock(
170
+ spatial_dims=3,
171
+ in_channels=feature_size * 8,
172
+ out_channels=feature_size * 4,
173
+ stride=1,
174
+ kernel_size=3,
175
+ upsample_kernel_size=2,
176
+ norm_name=norm_name,
177
+ res_block=res_block,
178
+ )
179
+ self.decoder1 = UNesTBlock(
180
+ spatial_dims=3,
181
+ in_channels=feature_size * 4,
182
+ out_channels=feature_size * 2,
183
+ stride=1,
184
+ kernel_size=3,
185
+ upsample_kernel_size=2,
186
+ norm_name=norm_name,
187
+ res_block=res_block,
188
+ )
189
+ self.encoder10 = Convolution(
190
+ spatial_dims=3,
191
+ in_channels=32 * feature_size,
192
+ out_channels=64 * feature_size,
193
+ strides=2,
194
+ adn_ordering="ADN",
195
+ dropout=0.0,
196
+ )
197
+ self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) # type: ignore
198
+
199
+ def proj_feat(self, x, hidden_size, feat_size):
200
+ x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
201
+ x = x.permute(0, 4, 1, 2, 3).contiguous()
202
+ return x
203
+
204
+ def load_from(self, weights):
205
+ with torch.no_grad():
206
+ # copy weights from patch embedding
207
+ for i in weights["state_dict"]:
208
+ print(i)
209
+ self.vit.patch_embedding.position_embeddings.copy_(
210
+ weights["state_dict"]["module.transformer.patch_embedding.position_embeddings_3d"]
211
+ )
212
+ self.vit.patch_embedding.cls_token.copy_(
213
+ weights["state_dict"]["module.transformer.patch_embedding.cls_token"]
214
+ )
215
+ self.vit.patch_embedding.patch_embeddings[1].weight.copy_(
216
+ weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings_3d.1.weight"]
217
+ )
218
+ self.vit.patch_embedding.patch_embeddings[1].bias.copy_(
219
+ weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings_3d.1.bias"]
220
+ )
221
+
222
+ # copy weights from encoding blocks (default: num of blocks: 12)
223
+ for bname, block in self.vit.blocks.named_children():
224
+ print(block)
225
+ block.loadFrom(weights, n_block=bname)
226
+ # last norm layer of transformer
227
+ self.vit.norm.weight.copy_(weights["state_dict"]["module.transformer.norm.weight"])
228
+ self.vit.norm.bias.copy_(weights["state_dict"]["module.transformer.norm.bias"])
229
+
230
+ def forward(self, x_in):
231
+ x, hidden_states_out = self.nestViT(x_in)
232
+ enc0 = self.encoder1(x_in) # 2, 32, 96, 96, 96
233
+ x1 = hidden_states_out[0] # 2, 128, 24, 24, 24 2, 128, 12, 12, 12
234
+ enc1 = self.encoder2(x1) # 2, 64, 48, 48, 48 torch.Size([2, 64, 24, 24, 24])
235
+ x2 = hidden_states_out[1] # 2, 128, 24, 24, 24
236
+ enc2 = self.encoder3(x2) # 2, 128, 24, 24, 24 torch.Size([2, 128, 12, 12, 12])
237
+ x3 = hidden_states_out[2] # 2, 256, 12, 12, 12 torch.Size([2, 256, 6, 6, 6])
238
+ enc3 = self.encoder4(x3) # 2, 256, 12, 12, 12 torch.Size([2, 256, 6, 6, 6])
239
+ x4 = hidden_states_out[3]
240
+ enc4 = x4 # 2, 512, 6, 6, 6 torch.Size([2, 512, 3, 3, 3])
241
+ dec4 = x # 2, 512, 6, 6, 6 torch.Size([2, 512, 3, 3, 3])
242
+ dec4 = self.encoder10(dec4) # 2, 1024, 3, 3, 3 torch.Size([2, 1024, 2, 2, 2])
243
+ dec3 = self.decoder5(dec4, enc4) # 2, 512, 6, 6, 6
244
+ dec2 = self.decoder4(dec3, enc3) # 2, 256, 12, 12, 12
245
+ dec1 = self.decoder3(dec2, enc2) # 2, 128, 24, 24, 24
246
+ dec0 = self.decoder2(dec1, enc1) # 2, 64, 48, 48, 48
247
+ out = self.decoder1(dec0, enc0) # 2, 32, 96, 96, 96
248
+ logits = self.out(out)
249
+ return logits
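An illustrative construction of UNesT. num_heads and depths must be length-3 sequences to match the three backbone levels (the length-4 class defaults would fail the backbone's length check), and out_channels below is a placeholder; the values actually used by this bundle are defined in its config files, not here:

import torch

from scripts.networks.unest_base_patch_4 import UNesT

net = UNesT(
    in_channels=1,
    out_channels=133,          # placeholder label count for illustration
    img_size=(96, 96, 96),
    num_heads=(4, 8, 16),      # illustrative; must have length 3
    depths=(2, 2, 2),          # illustrative; must have length 3
    embed_dim=(128, 256, 512),
)
n_params = sum(p.numel() for p in net.parameters())
print(f"{n_params / 1e6:.1f}M parameters")

# The transformer backbone is hard-coded to a 96^3 input, so whole-brain volumes are
# typically segmented with sliding-window inference over 96x96x96 ROIs.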
scripts/networks/unest_block.py ADDED
@@ -0,0 +1,245 @@
1
+ #!/usr/bin/env python3
2
+
3
+ from typing import Sequence, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer
8
+
9
+
10
+ class UNesTBlock(nn.Module):
11
+ """ """
12
+
13
+ def __init__(
14
+ self,
15
+ spatial_dims: int,
16
+ in_channels: int,
17
+ out_channels: int, # type: ignore
18
+ kernel_size: Union[Sequence[int], int],
19
+ stride: Union[Sequence[int], int],
20
+ upsample_kernel_size: Union[Sequence[int], int],
21
+ norm_name: Union[Tuple, str],
22
+ res_block: bool = False,
23
+ ) -> None:
24
+ """
25
+ Args:
26
+ spatial_dims: number of spatial dimensions.
27
+ in_channels: number of input channels.
28
+ out_channels: number of output channels.
29
+ kernel_size: convolution kernel size.
30
+ stride: convolution stride.
31
+ upsample_kernel_size: convolution kernel size for transposed convolution layers.
32
+ norm_name: feature normalization type and arguments.
33
+ res_block: bool argument to determine if residual block is used.
34
+
35
+ """
36
+
37
+ super(UNesTBlock, self).__init__()
38
+ upsample_stride = upsample_kernel_size
39
+ self.transp_conv = get_conv_layer(
40
+ spatial_dims,
41
+ in_channels,
42
+ out_channels,
43
+ kernel_size=upsample_kernel_size,
44
+ stride=upsample_stride,
45
+ conv_only=True,
46
+ is_transposed=True,
47
+ )
48
+
49
+ if res_block:
50
+ self.conv_block = UnetResBlock(
51
+ spatial_dims,
52
+ out_channels + out_channels,
53
+ out_channels,
54
+ kernel_size=kernel_size,
55
+ stride=1,
56
+ norm_name=norm_name,
57
+ )
58
+ else:
59
+ self.conv_block = UnetBasicBlock( # type: ignore
60
+ spatial_dims,
61
+ out_channels + out_channels,
62
+ out_channels,
63
+ kernel_size=kernel_size,
64
+ stride=1,
65
+ norm_name=norm_name,
66
+ )
67
+
68
+ def forward(self, inp, skip):
69
+ # number of channels for skip should equals to out_channels
70
+ out = self.transp_conv(inp)
71
+ # print(out.shape)
72
+ # print(skip.shape)
73
+ out = torch.cat((out, skip), dim=1)
74
+ out = self.conv_block(out)
75
+ return out
76
+
77
+
78
+ class UNestUpBlock(nn.Module):
79
+ """ """
80
+
81
+ def __init__(
82
+ self,
83
+ spatial_dims: int,
84
+ in_channels: int,
85
+ out_channels: int,
86
+ num_layer: int,
87
+ kernel_size: Union[Sequence[int], int],
88
+ stride: Union[Sequence[int], int],
89
+ upsample_kernel_size: Union[Sequence[int], int],
90
+ norm_name: Union[Tuple, str],
91
+ conv_block: bool = False,
92
+ res_block: bool = False,
93
+ ) -> None:
94
+ """
95
+ Args:
96
+ spatial_dims: number of spatial dimensions.
97
+ in_channels: number of input channels.
98
+ out_channels: number of output channels.
99
+ num_layer: number of upsampling blocks.
100
+ kernel_size: convolution kernel size.
101
+ stride: convolution stride.
102
+ upsample_kernel_size: convolution kernel size for transposed convolution layers.
103
+ norm_name: feature normalization type and arguments.
104
+ conv_block: bool argument to determine if convolutional block is used.
105
+ res_block: bool argument to determine if residual block is used.
106
+
107
+ """
108
+
109
+ super().__init__()
110
+
111
+ upsample_stride = upsample_kernel_size
112
+ self.transp_conv_init = get_conv_layer(
113
+ spatial_dims,
114
+ in_channels,
115
+ out_channels,
116
+ kernel_size=upsample_kernel_size,
117
+ stride=upsample_stride,
118
+ conv_only=True,
119
+ is_transposed=True,
120
+ )
121
+ if conv_block:
122
+ if res_block:
123
+ self.blocks = nn.ModuleList(
124
+ [
125
+ nn.Sequential(
126
+ get_conv_layer(
127
+ spatial_dims,
128
+ out_channels,
129
+ out_channels,
130
+ kernel_size=upsample_kernel_size,
131
+ stride=upsample_stride,
132
+ conv_only=True,
133
+ is_transposed=True,
134
+ ),
135
+ UnetResBlock(
136
+ spatial_dims=3,
137
+ in_channels=out_channels,
138
+ out_channels=out_channels,
139
+ kernel_size=kernel_size,
140
+ stride=stride,
141
+ norm_name=norm_name,
142
+ ),
143
+ )
144
+ for i in range(num_layer)
145
+ ]
146
+ )
147
+ else:
148
+ self.blocks = nn.ModuleList(
149
+ [
150
+ nn.Sequential(
151
+ get_conv_layer(
152
+ spatial_dims,
153
+ out_channels,
154
+ out_channels,
155
+ kernel_size=upsample_kernel_size,
156
+ stride=upsample_stride,
157
+ conv_only=True,
158
+ is_transposed=True,
159
+ ),
160
+ UnetBasicBlock(
161
+ spatial_dims=3,
162
+ in_channels=out_channels,
163
+ out_channels=out_channels,
164
+ kernel_size=kernel_size,
165
+ stride=stride,
166
+ norm_name=norm_name,
167
+ ),
168
+ )
169
+ for i in range(num_layer)
170
+ ]
171
+ )
172
+ else:
173
+ self.blocks = nn.ModuleList(
174
+ [
175
+ get_conv_layer(
176
+ spatial_dims,
177
+ out_channels,
178
+ out_channels,
179
+ kernel_size=1,
180
+ stride=1,
181
+ conv_only=True,
182
+ is_transposed=True,
183
+ )
184
+ for i in range(num_layer)
185
+ ]
186
+ )
187
+
188
+ def forward(self, x):
189
+ x = self.transp_conv_init(x)
190
+ for blk in self.blocks:
191
+ x = blk(x)
192
+ return x
193
+
194
+
195
+ class UNesTConvBlock(nn.Module):
196
+ """
197
+ UNesT block with skip connections
198
+ """
199
+
200
+ def __init__(
201
+ self,
202
+ spatial_dims: int,
203
+ in_channels: int,
204
+ out_channels: int,
205
+ kernel_size: Union[Sequence[int], int],
206
+ stride: Union[Sequence[int], int],
207
+ norm_name: Union[Tuple, str],
208
+ res_block: bool = False,
209
+ ) -> None:
210
+ """
211
+ Args:
212
+ spatial_dims: number of spatial dimensions.
213
+ in_channels: number of input channels.
214
+ out_channels: number of output channels.
215
+ kernel_size: convolution kernel size.
216
+ stride: convolution stride.
217
+ norm_name: feature normalization type and arguments.
218
+ res_block: bool argument to determine if residual block is used.
219
+
220
+ """
221
+
222
+ super().__init__()
223
+
224
+ if res_block:
225
+ self.layer = UnetResBlock(
226
+ spatial_dims=spatial_dims,
227
+ in_channels=in_channels,
228
+ out_channels=out_channels,
229
+ kernel_size=kernel_size,
230
+ stride=stride,
231
+ norm_name=norm_name,
232
+ )
233
+ else:
234
+ self.layer = UnetBasicBlock( # type: ignore
235
+ spatial_dims=spatial_dims,
236
+ in_channels=in_channels,
237
+ out_channels=out_channels,
238
+ kernel_size=kernel_size,
239
+ stride=stride,
240
+ norm_name=norm_name,
241
+ )
242
+
243
+ def forward(self, inp):
244
+ out = self.layer(inp)
245
+ return out
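A minimal sketch of the blocks above with arbitrary example shapes. The skip tensor passed to UNesTBlock must already have out_channels channels and twice the spatial size of the low-resolution input:

import torch

from scripts.networks.unest_block import UNesTBlock, UNesTConvBlock

# Decoder block: transposed-conv upsampling by 2, concatenation with the skip, then convolution.
blk = UNesTBlock(
    spatial_dims=3, in_channels=64, out_channels=32,
    kernel_size=3, stride=1, upsample_kernel_size=2,
    norm_name="instance", res_block=True,
)
x = torch.randn(1, 64, 8, 8, 8)         # low-resolution decoder feature
skip = torch.randn(1, 32, 16, 16, 16)   # encoder skip at twice the resolution
with torch.no_grad():
    print(blk(x, skip).shape)  # torch.Size([1, 32, 16, 16, 16])

# Plain encoder block applied to the raw input and intermediate features.
enc = UNesTConvBlock(
    spatial_dims=3, in_channels=1, out_channels=32,
    kernel_size=3, stride=1, norm_name="instance", res_block=True,
)
with torch.no_grad():
    print(enc(torch.randn(1, 1, 16, 16, 16)).shape)  # torch.Size([1, 32, 16, 16, 16])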