Upload wholeBrainSeg_Large_UNEST_segmentation version 0.2.6
Browse files- .gitattributes +3 -0
- LICENSE +201 -0
- configs/inference.json +136 -0
- configs/logging.conf +21 -0
- configs/metadata.json +223 -0
- configs/multi_gpu_train.json +36 -0
- configs/train.json +299 -0
- docs/3DSlicer_use.png +3 -0
- docs/README.md +189 -0
- docs/demo.png +3 -0
- docs/training.png +0 -0
- docs/unest.png +0 -0
- docs/wholebrain.png +3 -0
- models/model.pt +3 -0
- scripts/__init__.py +10 -0
- scripts/networks/__init__.py +10 -0
- scripts/networks/nest/__init__.py +16 -0
- scripts/networks/nest/utils.py +481 -0
- scripts/networks/nest_transformer_3D.py +489 -0
- scripts/networks/patchEmbed3D.py +190 -0
- scripts/networks/unest_base_patch_4.py +249 -0
- scripts/networks/unest_block.py +245 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
docs/3DSlicer_use.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
docs/demo.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
docs/wholebrain.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
configs/inference.json
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"imports": [
|
3 |
+
"$import glob",
|
4 |
+
"$import os"
|
5 |
+
],
|
6 |
+
"bundle_root": ".",
|
7 |
+
"output_dir": "$@bundle_root + '/eval'",
|
8 |
+
"dataset_dir": "$@bundle_root + '/dataset/images'",
|
9 |
+
"datalist": "$list(sorted(glob.glob(@dataset_dir + '/*.nii.gz')))",
|
10 |
+
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
|
11 |
+
"network_def": {
|
12 |
+
"_target_": "scripts.networks.unest_base_patch_4.UNesT",
|
13 |
+
"in_channels": 1,
|
14 |
+
"out_channels": 133,
|
15 |
+
"patch_size": 4,
|
16 |
+
"depths": [
|
17 |
+
2,
|
18 |
+
2,
|
19 |
+
8
|
20 |
+
],
|
21 |
+
"embed_dim": [
|
22 |
+
128,
|
23 |
+
256,
|
24 |
+
512
|
25 |
+
],
|
26 |
+
"num_heads": [
|
27 |
+
4,
|
28 |
+
8,
|
29 |
+
16
|
30 |
+
]
|
31 |
+
},
|
32 |
+
"network": "$@network_def.to(@device)",
|
33 |
+
"preprocessing": {
|
34 |
+
"_target_": "Compose",
|
35 |
+
"transforms": [
|
36 |
+
{
|
37 |
+
"_target_": "LoadImaged",
|
38 |
+
"keys": "image"
|
39 |
+
},
|
40 |
+
{
|
41 |
+
"_target_": "EnsureChannelFirstd",
|
42 |
+
"keys": "image"
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"_target_": "NormalizeIntensityd",
|
46 |
+
"keys": "image",
|
47 |
+
"nonzero": "True",
|
48 |
+
"channel_wise": "True"
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"_target_": "EnsureTyped",
|
52 |
+
"keys": "image"
|
53 |
+
}
|
54 |
+
]
|
55 |
+
},
|
56 |
+
"dataset": {
|
57 |
+
"_target_": "Dataset",
|
58 |
+
"data": "$[{'image': i} for i in @datalist]",
|
59 |
+
"transform": "@preprocessing"
|
60 |
+
},
|
61 |
+
"dataloader": {
|
62 |
+
"_target_": "DataLoader",
|
63 |
+
"dataset": "@dataset",
|
64 |
+
"batch_size": 1,
|
65 |
+
"shuffle": false,
|
66 |
+
"num_workers": 4
|
67 |
+
},
|
68 |
+
"inferer": {
|
69 |
+
"_target_": "SlidingWindowInferer",
|
70 |
+
"roi_size": [
|
71 |
+
96,
|
72 |
+
96,
|
73 |
+
96
|
74 |
+
],
|
75 |
+
"sw_batch_size": 4,
|
76 |
+
"overlap": 0.7
|
77 |
+
},
|
78 |
+
"postprocessing": {
|
79 |
+
"_target_": "Compose",
|
80 |
+
"transforms": [
|
81 |
+
{
|
82 |
+
"_target_": "Activationsd",
|
83 |
+
"keys": "pred",
|
84 |
+
"softmax": true
|
85 |
+
},
|
86 |
+
{
|
87 |
+
"_target_": "Invertd",
|
88 |
+
"keys": "pred",
|
89 |
+
"transform": "@preprocessing",
|
90 |
+
"orig_keys": "image",
|
91 |
+
"meta_key_postfix": "meta_dict",
|
92 |
+
"nearest_interp": false,
|
93 |
+
"to_tensor": true
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"_target_": "AsDiscreted",
|
97 |
+
"keys": "pred",
|
98 |
+
"argmax": true
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"_target_": "SaveImaged",
|
102 |
+
"keys": "pred",
|
103 |
+
"meta_keys": "pred_meta_dict",
|
104 |
+
"output_dir": "@output_dir"
|
105 |
+
}
|
106 |
+
]
|
107 |
+
},
|
108 |
+
"handlers": [
|
109 |
+
{
|
110 |
+
"_target_": "CheckpointLoader",
|
111 |
+
"load_path": "$@bundle_root + '/models/model.pt'",
|
112 |
+
"load_dict": {
|
113 |
+
"model": "@network"
|
114 |
+
},
|
115 |
+
"strict": "True"
|
116 |
+
},
|
117 |
+
{
|
118 |
+
"_target_": "StatsHandler",
|
119 |
+
"iteration_log": false
|
120 |
+
}
|
121 |
+
],
|
122 |
+
"evaluator": {
|
123 |
+
"_target_": "SupervisedEvaluator",
|
124 |
+
"device": "@device",
|
125 |
+
"val_data_loader": "@dataloader",
|
126 |
+
"network": "@network",
|
127 |
+
"inferer": "@inferer",
|
128 |
+
"postprocessing": "@postprocessing",
|
129 |
+
"val_handlers": "@handlers",
|
130 |
+
"amp": false
|
131 |
+
},
|
132 |
+
"evaluating": [
|
133 |
+
"$setattr(torch.backends.cudnn, 'benchmark', True)",
|
134 |
+
"[email protected]()"
|
135 |
+
]
|
136 |
+
}
|
configs/logging.conf
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[loggers]
|
2 |
+
keys=root
|
3 |
+
|
4 |
+
[handlers]
|
5 |
+
keys=consoleHandler
|
6 |
+
|
7 |
+
[formatters]
|
8 |
+
keys=fullFormatter
|
9 |
+
|
10 |
+
[logger_root]
|
11 |
+
level=INFO
|
12 |
+
handlers=consoleHandler
|
13 |
+
|
14 |
+
[handler_consoleHandler]
|
15 |
+
class=StreamHandler
|
16 |
+
level=INFO
|
17 |
+
formatter=fullFormatter
|
18 |
+
args=(sys.stdout,)
|
19 |
+
|
20 |
+
[formatter_fullFormatter]
|
21 |
+
format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
|
configs/metadata.json
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
|
3 |
+
"version": "0.2.6",
|
4 |
+
"changelog": {
|
5 |
+
"0.2.6": "update to huggingface hosting",
|
6 |
+
"0.2.5": "update large files",
|
7 |
+
"0.2.4": "fix black 24.1 format error",
|
8 |
+
"0.2.3": "fix PYTHONPATH in readme.md",
|
9 |
+
"0.2.2": "add name tag",
|
10 |
+
"0.2.1": "fix license Copyright error",
|
11 |
+
"0.2.0": "update license files",
|
12 |
+
"0.1.2": "Add training support for whole brain segmentation, users can use active learning in the MONAI Label",
|
13 |
+
"0.1.1": "Fix dimension according to MONAI 1.0 and fix readme file",
|
14 |
+
"0.1.0": "complete the model package"
|
15 |
+
},
|
16 |
+
"monai_version": "1.4.0",
|
17 |
+
"pytorch_version": "2.4.0",
|
18 |
+
"numpy_version": "1.24.4",
|
19 |
+
"required_packages_version": {
|
20 |
+
"nibabel": "5.2.1",
|
21 |
+
"pytorch-ignite": "0.4.11",
|
22 |
+
"einops": "0.7.0",
|
23 |
+
"fire": "0.6.0",
|
24 |
+
"timm": "0.6.7",
|
25 |
+
"torchvision": "0.19.0",
|
26 |
+
"tensorboard": "2.17.0"
|
27 |
+
},
|
28 |
+
"supported_apps": {},
|
29 |
+
"name": "Whole brain large UNEST segmentation",
|
30 |
+
"task": "Whole Brain Segmentation",
|
31 |
+
"description": "A 3D transformer-based model for whole brain segmentation from T1W MRI image",
|
32 |
+
"authors": "Vanderbilt University + MONAI team",
|
33 |
+
"copyright": "Copyright (c) MONAI Consortium",
|
34 |
+
"data_source": "",
|
35 |
+
"data_type": "nibabel",
|
36 |
+
"image_classes": "single channel data, intensity scaled to [0, 1]",
|
37 |
+
"label_classes": "133 Classes",
|
38 |
+
"pred_classes": "133 Classes",
|
39 |
+
"eval_metrics": {
|
40 |
+
"mean_dice": 0.71
|
41 |
+
},
|
42 |
+
"intended_use": "This is an example, not to be used for diagnostic purposes",
|
43 |
+
"references": [
|
44 |
+
"Xin, et al. Characterizing Renal Structures with 3D Block Aggregate Transformers. arXiv preprint arXiv:2203.02430 (2022). https://arxiv.org/pdf/2203.02430.pdf"
|
45 |
+
],
|
46 |
+
"network_data_format": {
|
47 |
+
"inputs": {
|
48 |
+
"image": {
|
49 |
+
"type": "image",
|
50 |
+
"format": "hounsfield",
|
51 |
+
"modality": "MRI",
|
52 |
+
"num_channels": 1,
|
53 |
+
"spatial_shape": [
|
54 |
+
96,
|
55 |
+
96,
|
56 |
+
96
|
57 |
+
],
|
58 |
+
"dtype": "float32",
|
59 |
+
"value_range": [
|
60 |
+
0,
|
61 |
+
1
|
62 |
+
],
|
63 |
+
"is_patch_data": true,
|
64 |
+
"channel_def": {
|
65 |
+
"0": "image"
|
66 |
+
}
|
67 |
+
}
|
68 |
+
},
|
69 |
+
"outputs": {
|
70 |
+
"pred": {
|
71 |
+
"type": "image",
|
72 |
+
"format": "segmentation",
|
73 |
+
"num_channels": 133,
|
74 |
+
"spatial_shape": [
|
75 |
+
96,
|
76 |
+
96,
|
77 |
+
96
|
78 |
+
],
|
79 |
+
"dtype": "float32",
|
80 |
+
"value_range": [
|
81 |
+
0,
|
82 |
+
1
|
83 |
+
],
|
84 |
+
"is_patch_data": true,
|
85 |
+
"channel_def": {
|
86 |
+
"0": "background",
|
87 |
+
"1": "3rd-Ventricle",
|
88 |
+
"2": "4th-Ventricle",
|
89 |
+
"3": "Right-Accumbens-Area",
|
90 |
+
"4": "Left-Accumbens-Area",
|
91 |
+
"5": "Right-Amygdala",
|
92 |
+
"6": "Left-Amygdala",
|
93 |
+
"7": "Brain-Stem",
|
94 |
+
"8": "Right-Caudate",
|
95 |
+
"9": "Left-Caudate",
|
96 |
+
"10": "Right-Cerebellum-Exterior",
|
97 |
+
"11": "Left-Cerebellum-Exterior",
|
98 |
+
"12": "Right-Cerebellum-White-Matter",
|
99 |
+
"13": "Left-Cerebellum-White-Matter",
|
100 |
+
"14": "Right-Cerebral-White-Matter",
|
101 |
+
"15": "Left-Cerebral-White-Matter",
|
102 |
+
"16": "Right-Hippocampus",
|
103 |
+
"17": "Left-Hippocampus",
|
104 |
+
"18": "Right-Inf-Lat-Vent",
|
105 |
+
"19": "Left-Inf-Lat-Vent",
|
106 |
+
"20": "Right-Lateral-Ventricle",
|
107 |
+
"21": "Left-Lateral-Ventricle",
|
108 |
+
"22": "Right-Pallidum",
|
109 |
+
"23": "Left-Pallidum",
|
110 |
+
"24": "Right-Putamen",
|
111 |
+
"25": "Left-Putamen",
|
112 |
+
"26": "Right-Thalamus-Proper",
|
113 |
+
"27": "Left-Thalamus-Proper",
|
114 |
+
"28": "Right-Ventral-DC",
|
115 |
+
"29": "Left-Ventral-DC",
|
116 |
+
"30": "Cerebellar-Vermal-Lobules-I-V",
|
117 |
+
"31": "Cerebellar-Vermal-Lobules-VI-VII",
|
118 |
+
"32": "Cerebellar-Vermal-Lobules-VIII-X",
|
119 |
+
"33": "Left-Basal-Forebrain",
|
120 |
+
"34": "Right-Basal-Forebrain",
|
121 |
+
"35": "Right-ACgG--anterior-cingulate-gyrus",
|
122 |
+
"36": "Left-ACgG--anterior-cingulate-gyrus",
|
123 |
+
"37": "Right-AIns--anterior-insula",
|
124 |
+
"38": "Left-AIns--anterior-insula",
|
125 |
+
"39": "Right-AOrG--anterior-orbital-gyrus",
|
126 |
+
"40": "Left-AOrG--anterior-orbital-gyrus",
|
127 |
+
"41": "Right-AnG---angular-gyrus",
|
128 |
+
"42": "Left-AnG---angular-gyrus",
|
129 |
+
"43": "Right-Calc--calcarine-cortex",
|
130 |
+
"44": "Left-Calc--calcarine-cortex",
|
131 |
+
"45": "Right-CO----central-operculum",
|
132 |
+
"46": "Left-CO----central-operculum",
|
133 |
+
"47": "Right-Cun---cuneus",
|
134 |
+
"48": "Left-Cun---cuneus",
|
135 |
+
"49": "Right-Ent---entorhinal-area",
|
136 |
+
"50": "Left-Ent---entorhinal-area",
|
137 |
+
"51": "Right-FO----frontal-operculum",
|
138 |
+
"52": "Left-FO----frontal-operculum",
|
139 |
+
"53": "Right-FRP---frontal-pole",
|
140 |
+
"54": "Left-FRP---frontal-pole",
|
141 |
+
"55": "Right-FuG---fusiform-gyrus ",
|
142 |
+
"56": "Left-FuG---fusiform-gyrus",
|
143 |
+
"57": "Right-GRe---gyrus-rectus",
|
144 |
+
"58": "Left-GRe---gyrus-rectus",
|
145 |
+
"59": "Right-IOG---inferior-occipital-gyrus",
|
146 |
+
"60": "Left-IOG---inferior-occipital-gyrus",
|
147 |
+
"61": "Right-ITG---inferior-temporal-gyrus",
|
148 |
+
"62": "Left-ITG---inferior-temporal-gyrus",
|
149 |
+
"63": "Right-LiG---lingual-gyrus",
|
150 |
+
"64": "Left-LiG---lingual-gyrus",
|
151 |
+
"65": "Right-LOrG--lateral-orbital-gyrus",
|
152 |
+
"66": "Left-LOrG--lateral-orbital-gyrus",
|
153 |
+
"67": "Right-MCgG--middle-cingulate-gyrus",
|
154 |
+
"68": "Left-MCgG--middle-cingulate-gyrus",
|
155 |
+
"69": "Right-MFC---medial-frontal-cortex",
|
156 |
+
"70": "Left-MFC---medial-frontal-cortex",
|
157 |
+
"71": "Right-MFG---middle-frontal-gyrus",
|
158 |
+
"72": "Left-MFG---middle-frontal-gyrus",
|
159 |
+
"73": "Right-MOG---middle-occipital-gyrus",
|
160 |
+
"74": "Left-MOG---middle-occipital-gyrus",
|
161 |
+
"75": "Right-MOrG--medial-orbital-gyrus",
|
162 |
+
"76": "Left-MOrG--medial-orbital-gyrus",
|
163 |
+
"77": "Right-MPoG--postcentral-gyrus",
|
164 |
+
"78": "Left-MPoG--postcentral-gyrus",
|
165 |
+
"79": "Right-MPrG--precentral-gyrus",
|
166 |
+
"80": "Left-MPrG--precentral-gyrus",
|
167 |
+
"81": "Right-MSFG--superior-frontal-gyrus",
|
168 |
+
"82": "Left-MSFG--superior-frontal-gyrus",
|
169 |
+
"83": "Right-MTG---middle-temporal-gyrus",
|
170 |
+
"84": "Left-MTG---middle-temporal-gyrus",
|
171 |
+
"85": "Right-OCP---occipital-pole",
|
172 |
+
"86": "Left-OCP---occipital-pole",
|
173 |
+
"87": "Right-OFuG--occipital-fusiform-gyrus",
|
174 |
+
"88": "Left-OFuG--occipital-fusiform-gyrus",
|
175 |
+
"89": "Right-OpIFG-opercular-part-of-the-IFG",
|
176 |
+
"90": "Left-OpIFG-opercular-part-of-the-IFG",
|
177 |
+
"91": "Right-OrIFG-orbital-part-of-the-IFG",
|
178 |
+
"92": "Left-OrIFG-orbital-part-of-the-IFG",
|
179 |
+
"93": "Right-PCgG--posterior-cingulate-gyrus",
|
180 |
+
"94": "Left-PCgG--posterior-cingulate-gyrus",
|
181 |
+
"95": "Right-PCu---precuneus",
|
182 |
+
"96": "Left-PCu---precuneus",
|
183 |
+
"97": "Right-PHG---parahippocampal-gyrus",
|
184 |
+
"98": "Left-PHG---parahippocampal-gyrus",
|
185 |
+
"99": "Right-PIns--posterior-insula",
|
186 |
+
"100": "Left-PIns--posterior-insula",
|
187 |
+
"101": "Right-PO----parietal-operculum",
|
188 |
+
"102": "Left-PO----parietal-operculum",
|
189 |
+
"103": "Right-PoG---postcentral-gyrus",
|
190 |
+
"104": "Left-PoG---postcentral-gyrus",
|
191 |
+
"105": "Right-POrG--posterior-orbital-gyrus",
|
192 |
+
"106": "Left-POrG--posterior-orbital-gyrus",
|
193 |
+
"107": "Right-PP----planum-polare",
|
194 |
+
"108": "Left-PP----planum-polare",
|
195 |
+
"109": "Right-PrG---precentral-gyrus",
|
196 |
+
"110": "Left-PrG---precentral-gyrus",
|
197 |
+
"111": "Right-PT----planum-temporale",
|
198 |
+
"112": "Left-PT----planum-temporale",
|
199 |
+
"113": "Right-SCA---subcallosal-area",
|
200 |
+
"114": "Left-SCA---subcallosal-area",
|
201 |
+
"115": "Right-SFG---superior-frontal-gyrus",
|
202 |
+
"116": "Left-SFG---superior-frontal-gyrus",
|
203 |
+
"117": "Right-SMC---supplementary-motor-cortex",
|
204 |
+
"118": "Left-SMC---supplementary-motor-cortex",
|
205 |
+
"119": "Right-SMG---supramarginal-gyrus",
|
206 |
+
"120": "Left-SMG---supramarginal-gyrus",
|
207 |
+
"121": "Right-SOG---superior-occipital-gyrus",
|
208 |
+
"122": "Left-SOG---superior-occipital-gyrus",
|
209 |
+
"123": "Right-SPL---superior-parietal-lobule",
|
210 |
+
"124": "Left-SPL---superior-parietal-lobule",
|
211 |
+
"125": "Right-STG---superior-temporal-gyrus",
|
212 |
+
"126": "Left-STG---superior-temporal-gyrus",
|
213 |
+
"127": "Right-TMP---temporal-pole",
|
214 |
+
"128": "Left-TMP---temporal-pole",
|
215 |
+
"129": "Right-TrIFG-triangular-part-of-the-IFG",
|
216 |
+
"130": "Left-TrIFG-triangular-part-of-the-IFG",
|
217 |
+
"131": "Right-TTG---transverse-temporal-gyrus",
|
218 |
+
"132": "Left-TTG---transverse-temporal-gyrus"
|
219 |
+
}
|
220 |
+
}
|
221 |
+
}
|
222 |
+
}
|
223 |
+
}
|
configs/multi_gpu_train.json
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"device": "$torch.device(f'cuda:{dist.get_rank()}')",
|
3 |
+
"network": {
|
4 |
+
"_target_": "torch.nn.parallel.DistributedDataParallel",
|
5 |
+
"module": "$@network_def.to(@device)",
|
6 |
+
"device_ids": [
|
7 |
+
"@device"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
"train#sampler": {
|
11 |
+
"_target_": "DistributedSampler",
|
12 |
+
"dataset": "@train#dataset",
|
13 |
+
"even_divisible": true,
|
14 |
+
"shuffle": true
|
15 |
+
},
|
16 |
+
"train#dataloader#sampler": "@train#sampler",
|
17 |
+
"train#dataloader#shuffle": false,
|
18 |
+
"train#trainer#train_handlers": "$@train#handlers[: -2 if dist.get_rank() > 0 else None]",
|
19 |
+
"validate#sampler": {
|
20 |
+
"_target_": "DistributedSampler",
|
21 |
+
"dataset": "@validate#dataset",
|
22 |
+
"even_divisible": false,
|
23 |
+
"shuffle": false
|
24 |
+
},
|
25 |
+
"validate#dataloader#sampler": "@validate#sampler",
|
26 |
+
"validate#evaluator#val_handlers": "$None if dist.get_rank() > 0 else @validate#handlers",
|
27 |
+
"training": [
|
28 |
+
"$import torch.distributed as dist",
|
29 |
+
"$dist.init_process_group(backend='nccl')",
|
30 |
+
"$torch.cuda.set_device(@device)",
|
31 |
+
"$monai.utils.set_determinism(seed=123)",
|
32 |
+
"$setattr(torch.backends.cudnn, 'benchmark', True)",
|
33 |
+
"$@train#trainer.run()",
|
34 |
+
"$dist.destroy_process_group()"
|
35 |
+
]
|
36 |
+
}
|
configs/train.json
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"imports": [
|
3 |
+
"$import glob",
|
4 |
+
"$import os",
|
5 |
+
"$import ignite"
|
6 |
+
],
|
7 |
+
"bundle_root": ".",
|
8 |
+
"ckpt_dir": "$@bundle_root + '/models'",
|
9 |
+
"output_dir": "$@bundle_root + '/eval'",
|
10 |
+
"dataset_dir": "$@bundle_root + '/dataset/brain'",
|
11 |
+
"images": "$list(sorted(glob.glob(@dataset_dir + '/images/*.nii.gz')))",
|
12 |
+
"labels": "$list(sorted(glob.glob(@dataset_dir + '/labels/*.nii.gz')))",
|
13 |
+
"val_interval": 5,
|
14 |
+
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
|
15 |
+
"network_def": {
|
16 |
+
"_target_": "scripts.networks.unest_base_patch_4.UNesT",
|
17 |
+
"in_channels": 1,
|
18 |
+
"out_channels": 133,
|
19 |
+
"patch_size": 4,
|
20 |
+
"depths": [
|
21 |
+
2,
|
22 |
+
2,
|
23 |
+
8
|
24 |
+
],
|
25 |
+
"embed_dim": [
|
26 |
+
128,
|
27 |
+
256,
|
28 |
+
512
|
29 |
+
],
|
30 |
+
"num_heads": [
|
31 |
+
4,
|
32 |
+
8,
|
33 |
+
16
|
34 |
+
]
|
35 |
+
},
|
36 |
+
"network": "$@network_def.to(@device)",
|
37 |
+
"loss": {
|
38 |
+
"_target_": "DiceCELoss",
|
39 |
+
"to_onehot_y": true,
|
40 |
+
"softmax": true,
|
41 |
+
"squared_pred": true,
|
42 |
+
"batch": true
|
43 |
+
},
|
44 |
+
"optimizer": {
|
45 |
+
"_target_": "torch.optim.Adam",
|
46 |
+
"params": "[email protected]()",
|
47 |
+
"lr": 0.0001
|
48 |
+
},
|
49 |
+
"train": {
|
50 |
+
"deterministic_transforms": [
|
51 |
+
{
|
52 |
+
"_target_": "LoadImaged",
|
53 |
+
"keys": [
|
54 |
+
"image",
|
55 |
+
"label"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"_target_": "EnsureChannelFirstd",
|
60 |
+
"keys": [
|
61 |
+
"image",
|
62 |
+
"label"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"_target_": "EnsureTyped",
|
67 |
+
"keys": [
|
68 |
+
"image",
|
69 |
+
"label"
|
70 |
+
]
|
71 |
+
}
|
72 |
+
],
|
73 |
+
"random_transforms": [
|
74 |
+
{
|
75 |
+
"_target_": "RandSpatialCropd",
|
76 |
+
"keys": [
|
77 |
+
"image",
|
78 |
+
"label"
|
79 |
+
],
|
80 |
+
"roi_size": [
|
81 |
+
96,
|
82 |
+
96,
|
83 |
+
96
|
84 |
+
],
|
85 |
+
"random_size": false
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"_target_": "RandFlipd",
|
89 |
+
"keys": [
|
90 |
+
"image",
|
91 |
+
"label"
|
92 |
+
],
|
93 |
+
"spatial_axis": [
|
94 |
+
0
|
95 |
+
],
|
96 |
+
"prob": 0.1
|
97 |
+
},
|
98 |
+
{
|
99 |
+
"_target_": "RandFlipd",
|
100 |
+
"keys": [
|
101 |
+
"image",
|
102 |
+
"label"
|
103 |
+
],
|
104 |
+
"spatial_axis": [
|
105 |
+
1
|
106 |
+
],
|
107 |
+
"prob": 0.1
|
108 |
+
},
|
109 |
+
{
|
110 |
+
"_target_": "RandFlipd",
|
111 |
+
"keys": [
|
112 |
+
"image",
|
113 |
+
"label"
|
114 |
+
],
|
115 |
+
"spatial_axis": [
|
116 |
+
2
|
117 |
+
],
|
118 |
+
"prob": 0.1
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"_target_": "RandRotate90d",
|
122 |
+
"keys": [
|
123 |
+
"image",
|
124 |
+
"label"
|
125 |
+
],
|
126 |
+
"max_k": 3,
|
127 |
+
"prob": 0.1
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"_target_": "NormalizeIntensityd",
|
131 |
+
"keys": "image",
|
132 |
+
"nonzero": true,
|
133 |
+
"channel_wise": true
|
134 |
+
}
|
135 |
+
],
|
136 |
+
"preprocessing": {
|
137 |
+
"_target_": "Compose",
|
138 |
+
"transforms": "$@train#deterministic_transforms + @train#random_transforms"
|
139 |
+
},
|
140 |
+
"dataset": {
|
141 |
+
"_target_": "CacheDataset",
|
142 |
+
"data": "$[{'image': i, 'label': l} for i, l in zip(@images[:-2], @labels[:-2])]",
|
143 |
+
"transform": "@train#preprocessing",
|
144 |
+
"cache_rate": 1.0,
|
145 |
+
"num_workers": 2
|
146 |
+
},
|
147 |
+
"dataloader": {
|
148 |
+
"_target_": "DataLoader",
|
149 |
+
"dataset": "@train#dataset",
|
150 |
+
"batch_size": 1,
|
151 |
+
"shuffle": true,
|
152 |
+
"num_workers": 1
|
153 |
+
},
|
154 |
+
"inferer": {
|
155 |
+
"_target_": "SimpleInferer"
|
156 |
+
},
|
157 |
+
"postprocessing": {
|
158 |
+
"_target_": "Compose",
|
159 |
+
"transforms": [
|
160 |
+
{
|
161 |
+
"_target_": "Activationsd",
|
162 |
+
"keys": "pred",
|
163 |
+
"softmax": true
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"_target_": "AsDiscreted",
|
167 |
+
"keys": [
|
168 |
+
"pred",
|
169 |
+
"label"
|
170 |
+
],
|
171 |
+
"argmax": [
|
172 |
+
true,
|
173 |
+
false
|
174 |
+
],
|
175 |
+
"to_onehot": 133
|
176 |
+
}
|
177 |
+
]
|
178 |
+
},
|
179 |
+
"handlers": [
|
180 |
+
{
|
181 |
+
"_target_": "ValidationHandler",
|
182 |
+
"validator": "@validate#evaluator",
|
183 |
+
"epoch_level": true,
|
184 |
+
"interval": "@val_interval"
|
185 |
+
},
|
186 |
+
{
|
187 |
+
"_target_": "StatsHandler",
|
188 |
+
"tag_name": "train_loss",
|
189 |
+
"output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
|
190 |
+
},
|
191 |
+
{
|
192 |
+
"_target_": "TensorBoardStatsHandler",
|
193 |
+
"log_dir": "@output_dir",
|
194 |
+
"tag_name": "train_loss",
|
195 |
+
"output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
|
196 |
+
}
|
197 |
+
],
|
198 |
+
"key_metric": {
|
199 |
+
"train_accuracy": {
|
200 |
+
"_target_": "ignite.metrics.Accuracy",
|
201 |
+
"output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
|
202 |
+
}
|
203 |
+
},
|
204 |
+
"trainer": {
|
205 |
+
"_target_": "SupervisedTrainer",
|
206 |
+
"max_epochs": 2000,
|
207 |
+
"device": "@device",
|
208 |
+
"train_data_loader": "@train#dataloader",
|
209 |
+
"network": "@network",
|
210 |
+
"loss_function": "@loss",
|
211 |
+
"optimizer": "@optimizer",
|
212 |
+
"inferer": "@train#inferer",
|
213 |
+
"postprocessing": "@train#postprocessing",
|
214 |
+
"key_train_metric": "@train#key_metric",
|
215 |
+
"train_handlers": "@train#handlers",
|
216 |
+
"amp": true
|
217 |
+
}
|
218 |
+
},
|
219 |
+
"validate": {
|
220 |
+
"preprocessing": {
|
221 |
+
"_target_": "Compose",
|
222 |
+
"transforms": "%train#deterministic_transforms"
|
223 |
+
},
|
224 |
+
"dataset": {
|
225 |
+
"_target_": "CacheDataset",
|
226 |
+
"data": "$[{'image': i, 'label': l} for i, l in zip(@images[-2:], @labels[-2:])]",
|
227 |
+
"transform": "@validate#preprocessing",
|
228 |
+
"cache_rate": 1.0
|
229 |
+
},
|
230 |
+
"dataloader": {
|
231 |
+
"_target_": "DataLoader",
|
232 |
+
"dataset": "@validate#dataset",
|
233 |
+
"batch_size": 2,
|
234 |
+
"shuffle": false,
|
235 |
+
"num_workers": 1
|
236 |
+
},
|
237 |
+
"inferer": {
|
238 |
+
"_target_": "SlidingWindowInferer",
|
239 |
+
"roi_size": [
|
240 |
+
96,
|
241 |
+
96,
|
242 |
+
96
|
243 |
+
],
|
244 |
+
"sw_batch_size": 4,
|
245 |
+
"overlap": 0.5
|
246 |
+
},
|
247 |
+
"postprocessing": "%train#postprocessing",
|
248 |
+
"handlers": [
|
249 |
+
{
|
250 |
+
"_target_": "StatsHandler",
|
251 |
+
"iteration_log": false
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"_target_": "TensorBoardStatsHandler",
|
255 |
+
"log_dir": "@output_dir",
|
256 |
+
"iteration_log": false
|
257 |
+
},
|
258 |
+
{
|
259 |
+
"_target_": "CheckpointSaver",
|
260 |
+
"save_dir": "@ckpt_dir",
|
261 |
+
"save_dict": {
|
262 |
+
"model": "@network"
|
263 |
+
},
|
264 |
+
"save_key_metric": true,
|
265 |
+
"key_metric_filename": "model.pt"
|
266 |
+
}
|
267 |
+
],
|
268 |
+
"key_metric": {
|
269 |
+
"val_mean_dice": {
|
270 |
+
"_target_": "MeanDice",
|
271 |
+
"include_background": false,
|
272 |
+
"output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
|
273 |
+
}
|
274 |
+
},
|
275 |
+
"additional_metrics": {
|
276 |
+
"val_accuracy": {
|
277 |
+
"_target_": "ignite.metrics.Accuracy",
|
278 |
+
"output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
|
279 |
+
}
|
280 |
+
},
|
281 |
+
"evaluator": {
|
282 |
+
"_target_": "SupervisedEvaluator",
|
283 |
+
"device": "@device",
|
284 |
+
"val_data_loader": "@validate#dataloader",
|
285 |
+
"network": "@network",
|
286 |
+
"inferer": "@validate#inferer",
|
287 |
+
"postprocessing": "@validate#postprocessing",
|
288 |
+
"key_val_metric": "@validate#key_metric",
|
289 |
+
"additional_metrics": "@validate#additional_metrics",
|
290 |
+
"val_handlers": "@validate#handlers",
|
291 |
+
"amp": true
|
292 |
+
}
|
293 |
+
},
|
294 |
+
"training": [
|
295 |
+
"$monai.utils.set_determinism(seed=123)",
|
296 |
+
"$setattr(torch.backends.cudnn, 'benchmark', True)",
|
297 |
+
"$@train#trainer.run()"
|
298 |
+
]
|
299 |
+
}
|
docs/3DSlicer_use.png
ADDED
![]() |
Git LFS Details
|
docs/README.md
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Description
|
2 |
+
Detailed whole brain segmentation is an essential quantitative technique in medical image analysis, which provides a non-invasive way of measuring brain regions from a clinical acquired structural magnetic resonance imaging (MRI).
|
3 |
+
We provide the pre-trained model for training and inferencing whole brain segmentation with 133 structures.
|
4 |
+
Training pipeline is provided to support active learning in MONAI Label and training with bundle.
|
5 |
+
|
6 |
+
A tutorial and release of model for whole brain segmentation using the 3D transformer-based segmentation model UNEST.
|
7 |
+
|
8 |
+
Authors:
|
9 |
+
Xin Yu ([email protected])
|
10 |
+
|
11 |
+
Yinchi Zhou ([email protected]) | Yucheng Tang ([email protected])
|
12 |
+
|
13 |
+
<p align="center">
|
14 |
+
-------------------------------------------------------------------------------------
|
15 |
+
</p>
|
16 |
+
|
17 |
+
 <br>
|
18 |
+
<p align="center">
|
19 |
+
Fig.1 - The demonstration of T1w MRI images registered in MNI space and the whole brain segmentation labels with 133 classes</p>
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
# Model Overview
|
24 |
+
A pre-trained UNEST base model [1] for volumetric (3D) whole brain segmentation with T1w MR images.
|
25 |
+
To leverage information across embedded sequences, ”shifted window” transformers
|
26 |
+
are proposed for dense predictions and modeling multi-scale features. However, these
|
27 |
+
attempts that aim to complicate the self-attention range often yield high computation
|
28 |
+
complexity and data inefficiency. Inspired by the aggregation function in the nested
|
29 |
+
ViT, we propose a new design of a 3D U-shape medical segmentation model with
|
30 |
+
Nested Transformers (UNesT) hierarchically with the 3D block aggregation function,
|
31 |
+
that learn locality behaviors for small structures or small dataset. This design retains
|
32 |
+
the original global self-attention mechanism and achieves information communication
|
33 |
+
across patches by stacking transformer encoders hierarchically.
|
34 |
+
|
35 |
+
 <br>
|
36 |
+
<p align="center">
|
37 |
+
Fig.2 - The network architecture of UNEST Base model
|
38 |
+
</p>
|
39 |
+
|
40 |
+
|
41 |
+
## Data
|
42 |
+
The training data is from the Vanderbilt University and Vanderbilt University Medical Center with public released OASIS and CANDI datsets.
|
43 |
+
Training and testing data are MRI T1-weighted (T1w) 3D volumes coming from 3 different sites. There are a total of 133 classes in the whole brain segmentation task.
|
44 |
+
Among 50 T1w MRI scans from Open Access Series on Imaging Studies (OASIS) (Marcus et al., 2007) dataset, 45 scans are used for training and the other 5 for validation.
|
45 |
+
The testing cohort contains Colin27 T1w scan (Aubert-Broche et al., 2006) and 13 T1w MRI scans from the Child and Adolescent Neuro Development Initiative (CANDI)
|
46 |
+
(Kennedy et al., 2012). All data are registered to the MNI space using the MNI305 (Evans et al., 1993) template and preprocessed follow the method in (Huo et al., 2019). Input images are randomly cropped to the size of 96 × 96 × 96.
|
47 |
+
|
48 |
+
### Important
|
49 |
+
|
50 |
+
The brain MRI images for training are registered to Affine registration from the target image to the MNI305 template using NiftyReg.
|
51 |
+
The data should be in the MNI305 space before inference.
|
52 |
+
|
53 |
+
If your images are already in MNI space, skip the registration step.
|
54 |
+
|
55 |
+
You could use any resitration tool to register image to MNI space. Here is an example using ants.
|
56 |
+
Registration to MNI Space: Sample suggestion. E.g., use ANTS or other tools for registering T1 MRI image to MNI305 Space.
|
57 |
+
|
58 |
+
```
|
59 |
+
pip install antspyx
|
60 |
+
|
61 |
+
#Sample ANTS registration
|
62 |
+
|
63 |
+
import ants
|
64 |
+
import sys
|
65 |
+
import os
|
66 |
+
|
67 |
+
fixed_image = ants.image_read('<fixed_image_path>')
|
68 |
+
moving_image = ants.image_read('<moving_image_path>')
|
69 |
+
transform = ants.registration(fixed_image,moving_image,'Affine')
|
70 |
+
|
71 |
+
reg3t = ants.apply_transforms(fixed_image,moving_image,transform['fwdtransforms'][0])
|
72 |
+
ants.image_write(reg3t,output_image_path)
|
73 |
+
```
|
74 |
+
|
75 |
+
## Training configuration
|
76 |
+
The training and inference was performed with at least one 24GB-memory GPU.
|
77 |
+
|
78 |
+
Actual Model Input: 96 x 96 x 96
|
79 |
+
|
80 |
+
## Input and output formats
|
81 |
+
Input: 1 channel T1w MRI image in MNI305 Space.
|
82 |
+
|
83 |
+
|
84 |
+
## commands example
|
85 |
+
Download trained checkpoint model to ./model/model.pt:
|
86 |
+
|
87 |
+
|
88 |
+
Add scripts component: To run the workflow with customized components, PYTHONPATH should be revised to include the path to the customized component:
|
89 |
+
|
90 |
+
```
|
91 |
+
export PYTHONPATH=$PYTHONPATH: '<path to the bundle root dir>/'
|
92 |
+
```
|
93 |
+
|
94 |
+
Execute Training:
|
95 |
+
|
96 |
+
```
|
97 |
+
python -m monai.bundle run training --meta_file configs/metadata.json --config_file configs/train.json --logging_file configs/logging.conf
|
98 |
+
```
|
99 |
+
|
100 |
+
Execute inference:
|
101 |
+
|
102 |
+
```
|
103 |
+
python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file configs/inference.json --logging_file configs/logging.conf
|
104 |
+
```
|
105 |
+
|
106 |
+
|
107 |
+
## More examples output
|
108 |
+
 <br>
|
109 |
+
<p align="center">
|
110 |
+
Fig.3 - The output prediction comparison with variant and ground truth
|
111 |
+
</p>
|
112 |
+
|
113 |
+
## Training/Validation Benchmarking
|
114 |
+
A graph showing the training accuracy for fine-tuning 600 epochs.
|
115 |
+
|
116 |
+
 <br>
|
117 |
+
|
118 |
+
With 10 fine-tuned labels, the training process converges fast.
|
119 |
+
|
120 |
+
## Complete ROI of the whole brain segmentation
|
121 |
+
133 brain structures are segmented.
|
122 |
+
|
123 |
+
| #1 | #2 | #3 | #4 |
|
124 |
+
| :------------ | :---------- | :-------- | :-------- |
|
125 |
+
| 0: background | 1 : 3rd-Ventricle | 2 : 4th-Ventricle | 3 : Right-Accumbens-Area |
|
126 |
+
| 4 : Left-Accumbens-Area | 5 : Right-Amygdala | 6 : Left-Amygdala | 7 : Brain-Stem |
|
127 |
+
| 8 : Right-Caudate | 9 : Left-Caudate | 10 : Right-Cerebellum-Exterior | 11 : Left-Cerebellum-Exterior |
|
128 |
+
| 12 : Right-Cerebellum-White-Matter | 13 : Left-Cerebellum-White-Matter | 14 : Right-Cerebral-White-Matter | 15 : Left-Cerebral-White-Matter |
|
129 |
+
| 16 : Right-Hippocampus | 17 : Left-Hippocampus | 18 : Right-Inf-Lat-Vent | 19 : Left-Inf-Lat-Vent |
|
130 |
+
| 20 : Right-Lateral-Ventricle | 21 : Left-Lateral-Ventricle | 22 : Right-Pallidum | 23 : Left-Pallidum |
|
131 |
+
| 24 : Right-Putamen | 25 : Left-Putamen | 26 : Right-Thalamus-Proper | 27 : Left-Thalamus-Proper |
|
132 |
+
| 28 : Right-Ventral-DC | 29 : Left-Ventral-DC | 30 : Cerebellar-Vermal-Lobules-I-V | 31 : Cerebellar-Vermal-Lobules-VI-VII |
|
133 |
+
| 32 : Cerebellar-Vermal-Lobules-VIII-X | 33 : Left-Basal-Forebrain | 34 : Right-Basal-Forebrain | 35 : Right-ACgG--anterior-cingulate-gyrus |
|
134 |
+
| 36 : Left-ACgG--anterior-cingulate-gyrus | 37 : Right-AIns--anterior-insula | 38 : Left-AIns--anterior-insula | 39 : Right-AOrG--anterior-orbital-gyrus |
|
135 |
+
| 40 : Left-AOrG--anterior-orbital-gyrus | 41 : Right-AnG---angular-gyrus | 42 : Left-AnG---angular-gyrus | 43 : Right-Calc--calcarine-cortex |
|
136 |
+
| 44 : Left-Calc--calcarine-cortex | 45 : Right-CO----central-operculum | 46 : Left-CO----central-operculum | 47 : Right-Cun---cuneus |
|
137 |
+
| 48 : Left-Cun---cuneus | 49 : Right-Ent---entorhinal-area | 50 : Left-Ent---entorhinal-area | 51 : Right-FO----frontal-operculum |
|
138 |
+
| 52 : Left-FO----frontal-operculum | 53 : Right-FRP---frontal-pole | 54 : Left-FRP---frontal-pole | 55 : Right-FuG---fusiform-gyrus |
|
139 |
+
| 56 : Left-FuG---fusiform-gyrus | 57 : Right-GRe---gyrus-rectus | 58 : Left-GRe---gyrus-rectus | 59 : Right-IOG---inferior-occipital-gyrus ,
|
140 |
+
| 60 : Left-IOG---inferior-occipital-gyrus | 61 : Right-ITG---inferior-temporal-gyrus | 62 : Left-ITG---inferior-temporal-gyrus | 63 : Right-LiG---lingual-gyrus |
|
141 |
+
| 64 : Left-LiG---lingual-gyrus | 65 : Right-LOrG--lateral-orbital-gyrus | 66 : Left-LOrG--lateral-orbital-gyrus | 67 : Right-MCgG--middle-cingulate-gyrus |
|
142 |
+
| 68 : Left-MCgG--middle-cingulate-gyrus | 69 : Right-MFC---medial-frontal-cortex | 70 : Left-MFC---medial-frontal-cortex | 71 : Right-MFG---middle-frontal-gyrus |
|
143 |
+
| 72 : Left-MFG---middle-frontal-gyrus | 73 : Right-MOG---middle-occipital-gyrus | 74 : Left-MOG---middle-occipital-gyrus | 75 : Right-MOrG--medial-orbital-gyrus |
|
144 |
+
| 76 : Left-MOrG--medial-orbital-gyrus | 77 : Right-MPoG--postcentral-gyrus | 78 : Left-MPoG--postcentral-gyrus | 79 : Right-MPrG--precentral-gyrus |
|
145 |
+
| 80 : Left-MPrG--precentral-gyrus | 81 : Right-MSFG--superior-frontal-gyrus | 82 : Left-MSFG--superior-frontal-gyrus | 83 : Right-MTG---middle-temporal-gyrus |
|
146 |
+
| 84 : Left-MTG---middle-temporal-gyrus | 85 : Right-OCP---occipital-pole | 86 : Left-OCP---occipital-pole | 87 : Right-OFuG--occipital-fusiform-gyrus |
|
147 |
+
| 88 : Left-OFuG--occipital-fusiform-gyrus | 89 : Right-OpIFG-opercular-part-of-the-IFG | 90 : Left-OpIFG-opercular-part-of-the-IFG | 91 : Right-OrIFG-orbital-part-of-the-IFG |
|
148 |
+
| 92 : Left-OrIFG-orbital-part-of-the-IFG | 93 : Right-PCgG--posterior-cingulate-gyrus | 94 : Left-PCgG--posterior-cingulate-gyrus | 95 : Right-PCu---precuneus |
|
149 |
+
| 96 : Left-PCu---precuneus | 97 : Right-PHG---parahippocampal-gyrus | 98 : Left-PHG---parahippocampal-gyrus | 99 : Right-PIns--posterior-insula |
|
150 |
+
| 100 : Left-PIns--posterior-insula | 101 : Right-PO----parietal-operculum | 102 : Left-PO----parietal-operculum | 103 : Right-PoG---postcentral-gyrus |
|
151 |
+
| 104 : Left-PoG---postcentral-gyrus | 105 : Right-POrG--posterior-orbital-gyrus | 106 : Left-POrG--posterior-orbital-gyrus | 107 : Right-PP----planum-polare |
|
152 |
+
| 108 : Left-PP----planum-polare | 109 : Right-PrG---precentral-gyrus | 110 : Left-PrG---precentral-gyrus | 111 : Right-PT----planum-temporale |
|
153 |
+
| 112 : Left-PT----planum-temporale | 113 : Right-SCA---subcallosal-area | 114 : Left-SCA---subcallosal-area | 115 : Right-SFG---superior-frontal-gyrus |
|
154 |
+
| 116 : Left-SFG---superior-frontal-gyrus | 117 : Right-SMC---supplementary-motor-cortex | 118 : Left-SMC---supplementary-motor-cortex | 119 : Right-SMG---supramarginal-gyrus |
|
155 |
+
| 120 : Left-SMG---supramarginal-gyrus | 121 : Right-SOG---superior-occipital-gyrus | 122 : Left-SOG---superior-occipital-gyrus | 123 : Right-SPL---superior-parietal-lobule |
|
156 |
+
| 124 : Left-SPL---superior-parietal-lobule | 125 : Right-STG---superior-temporal-gyrus | 126 : Left-STG---superior-temporal-gyrus | 127 : Right-TMP---temporal-pole |
|
157 |
+
| 128 : Left-TMP---temporal-pole | 129 : Right-TrIFG-triangular-part-of-the-IFG | 130 : Left-TrIFG-triangular-part-of-the-IFG | 131 : Right-TTG---transverse-temporal-gyrus |
|
158 |
+
| 132 : Left-TTG---transverse-temporal-gyrus |
|
159 |
+
|
160 |
+
|
161 |
+
## Bundle Integration in MONAI Lable
|
162 |
+
The inference and training pipleine can be easily used by the MONAI Label server and 3D Slicer for fast labeling T1w MRI images in MNI space.
|
163 |
+
|
164 |
+
 <br>
|
165 |
+
|
166 |
+
# Disclaimer
|
167 |
+
This is an example, not to be used for diagnostic purposes.
|
168 |
+
|
169 |
+
# References
|
170 |
+
[1] Yu, Xin, Yinchi Zhou, Yucheng Tang et al. Characterizing Renal Structures with 3D Block Aggregate Transformers. arXiv preprint arXiv:2203.02430 (2022). https://arxiv.org/pdf/2203.02430.pdf
|
171 |
+
|
172 |
+
[2] Zizhao Zhang et al. Nested Hierarchical Transformer: Towards Accurate, Data-Efficient and Interpretable Visual Understanding. AAAI Conference on Artificial Intelligence (AAAI) 2022
|
173 |
+
|
174 |
+
[3] Huo, Yuankai, et al. 3D whole brain segmentation using spatially localized atlas network tiles. NeuroImage 194 (2019): 105-119.
|
175 |
+
|
176 |
+
# License
|
177 |
+
Copyright (c) MONAI Consortium
|
178 |
+
|
179 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
180 |
+
you may not use this file except in compliance with the License.
|
181 |
+
You may obtain a copy of the License at
|
182 |
+
|
183 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
184 |
+
|
185 |
+
Unless required by applicable law or agreed to in writing, software
|
186 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
187 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
188 |
+
See the License for the specific language governing permissions and
|
189 |
+
limitations under the License.
|
docs/demo.png
ADDED
![]() |
Git LFS Details
|
docs/training.png
ADDED
![]() |
docs/unest.png
ADDED
![]() |
docs/wholebrain.png
ADDED
![]() |
Git LFS Details
|
models/model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:79a52ccd77bc35d05410f39788a1b063af3eb3b809b42241335c18aed27ec422
|
3 |
+
size 348901503
|
scripts/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) MONAI Consortium
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6 |
+
# Unless required by applicable law or agreed to in writing, software
|
7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9 |
+
# See the License for the specific language governing permissions and
|
10 |
+
# limitations under the License.
|
scripts/networks/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) MONAI Consortium
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6 |
+
# Unless required by applicable law or agreed to in writing, software
|
7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9 |
+
# See the License for the specific language governing permissions and
|
10 |
+
# limitations under the License.
|
scripts/networks/nest/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
from .utils import (
|
3 |
+
Conv3dSame,
|
4 |
+
DropPath,
|
5 |
+
Linear,
|
6 |
+
Mlp,
|
7 |
+
_assert,
|
8 |
+
conv3d_same,
|
9 |
+
create_conv3d,
|
10 |
+
create_pool3d,
|
11 |
+
get_padding,
|
12 |
+
get_same_padding,
|
13 |
+
pad_same,
|
14 |
+
to_ntuple,
|
15 |
+
trunc_normal_,
|
16 |
+
)
|
scripts/networks/nest/utils.py
ADDED
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
|
4 |
+
import collections.abc
|
5 |
+
import math
|
6 |
+
import warnings
|
7 |
+
from itertools import repeat
|
8 |
+
from typing import List, Optional, Tuple
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
try:
|
15 |
+
from torch import _assert
|
16 |
+
except ImportError:
|
17 |
+
|
18 |
+
def _assert(condition: bool, message: str):
|
19 |
+
assert condition, message
|
20 |
+
|
21 |
+
|
22 |
+
def drop_block_2d(
|
23 |
+
x,
|
24 |
+
drop_prob: float = 0.1,
|
25 |
+
block_size: int = 7,
|
26 |
+
gamma_scale: float = 1.0,
|
27 |
+
with_noise: bool = False,
|
28 |
+
inplace: bool = False,
|
29 |
+
batchwise: bool = False,
|
30 |
+
):
|
31 |
+
"""DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
|
32 |
+
|
33 |
+
DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
|
34 |
+
runs with success, but needs further validation and possibly optimization for lower runtime impact.
|
35 |
+
"""
|
36 |
+
b, c, h, w = x.shape
|
37 |
+
total_size = w * h
|
38 |
+
clipped_block_size = min(block_size, min(w, h))
|
39 |
+
# seed_drop_rate, the gamma parameter
|
40 |
+
gamma = gamma_scale * drop_prob * total_size / clipped_block_size**2 / ((w - block_size + 1) * (h - block_size + 1))
|
41 |
+
|
42 |
+
# Forces the block to be inside the feature map.
|
43 |
+
w_i, h_i = torch.meshgrid(torch.arange(w).to(x.device), torch.arange(h).to(x.device))
|
44 |
+
valid_block = ((w_i >= clipped_block_size // 2) & (w_i < w - (clipped_block_size - 1) // 2)) & (
|
45 |
+
(h_i >= clipped_block_size // 2) & (h_i < h - (clipped_block_size - 1) // 2)
|
46 |
+
)
|
47 |
+
valid_block = torch.reshape(valid_block, (1, 1, h, w)).to(dtype=x.dtype)
|
48 |
+
|
49 |
+
if batchwise:
|
50 |
+
# one mask for whole batch, quite a bit faster
|
51 |
+
uniform_noise = torch.rand((1, c, h, w), dtype=x.dtype, device=x.device)
|
52 |
+
else:
|
53 |
+
uniform_noise = torch.rand_like(x)
|
54 |
+
block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
|
55 |
+
block_mask = -F.max_pool2d(
|
56 |
+
-block_mask, kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2 # block_size,
|
57 |
+
)
|
58 |
+
|
59 |
+
if with_noise:
|
60 |
+
normal_noise = torch.randn((1, c, h, w), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
|
61 |
+
if inplace:
|
62 |
+
x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
|
63 |
+
else:
|
64 |
+
x = x * block_mask + normal_noise * (1 - block_mask)
|
65 |
+
else:
|
66 |
+
normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
|
67 |
+
if inplace:
|
68 |
+
x.mul_(block_mask * normalize_scale)
|
69 |
+
else:
|
70 |
+
x = x * block_mask * normalize_scale
|
71 |
+
return x
|
72 |
+
|
73 |
+
|
74 |
+
def drop_block_fast_2d(
|
75 |
+
x: torch.Tensor,
|
76 |
+
drop_prob: float = 0.1,
|
77 |
+
block_size: int = 7,
|
78 |
+
gamma_scale: float = 1.0,
|
79 |
+
with_noise: bool = False,
|
80 |
+
inplace: bool = False,
|
81 |
+
):
|
82 |
+
"""DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
|
83 |
+
|
84 |
+
DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
|
85 |
+
block mask at edges.
|
86 |
+
"""
|
87 |
+
b, c, h, w = x.shape
|
88 |
+
total_size = w * h
|
89 |
+
clipped_block_size = min(block_size, min(w, h))
|
90 |
+
gamma = gamma_scale * drop_prob * total_size / clipped_block_size**2 / ((w - block_size + 1) * (h - block_size + 1))
|
91 |
+
|
92 |
+
block_mask = torch.empty_like(x).bernoulli_(gamma)
|
93 |
+
block_mask = F.max_pool2d(
|
94 |
+
block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2
|
95 |
+
)
|
96 |
+
|
97 |
+
if with_noise:
|
98 |
+
normal_noise = torch.empty_like(x).normal_()
|
99 |
+
if inplace:
|
100 |
+
x.mul_(1.0 - block_mask).add_(normal_noise * block_mask)
|
101 |
+
else:
|
102 |
+
x = x * (1.0 - block_mask) + normal_noise * block_mask
|
103 |
+
else:
|
104 |
+
block_mask = 1 - block_mask
|
105 |
+
normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
|
106 |
+
if inplace:
|
107 |
+
x.mul_(block_mask * normalize_scale)
|
108 |
+
else:
|
109 |
+
x = x * block_mask * normalize_scale
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
class DropBlock2d(nn.Module):
|
114 |
+
"""DropBlock. See https://arxiv.org/pdf/1810.12890.pdf"""
|
115 |
+
|
116 |
+
def __init__(
|
117 |
+
self, drop_prob=0.1, block_size=7, gamma_scale=1.0, with_noise=False, inplace=False, batchwise=False, fast=True
|
118 |
+
):
|
119 |
+
super(DropBlock2d, self).__init__()
|
120 |
+
self.drop_prob = drop_prob
|
121 |
+
self.gamma_scale = gamma_scale
|
122 |
+
self.block_size = block_size
|
123 |
+
self.with_noise = with_noise
|
124 |
+
self.inplace = inplace
|
125 |
+
self.batchwise = batchwise
|
126 |
+
self.fast = fast # FIXME finish comparisons of fast vs not
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
if not self.training or not self.drop_prob:
|
130 |
+
return x
|
131 |
+
if self.fast:
|
132 |
+
return drop_block_fast_2d(
|
133 |
+
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace
|
134 |
+
)
|
135 |
+
else:
|
136 |
+
return drop_block_2d(
|
137 |
+
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise
|
138 |
+
)
|
139 |
+
|
140 |
+
|
141 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
|
142 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
143 |
+
|
144 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
145 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
146 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
147 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
148 |
+
'survival rate' as the argument.
|
149 |
+
|
150 |
+
"""
|
151 |
+
if drop_prob == 0.0 or not training:
|
152 |
+
return x
|
153 |
+
keep_prob = 1 - drop_prob
|
154 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
155 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
156 |
+
if keep_prob > 0.0 and scale_by_keep:
|
157 |
+
random_tensor.div_(keep_prob)
|
158 |
+
return x * random_tensor
|
159 |
+
|
160 |
+
|
161 |
+
class DropPath(nn.Module):
|
162 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
163 |
+
|
164 |
+
def __init__(self, drop_prob=None, scale_by_keep=True):
|
165 |
+
super(DropPath, self).__init__()
|
166 |
+
self.drop_prob = drop_prob
|
167 |
+
self.scale_by_keep = scale_by_keep
|
168 |
+
|
169 |
+
def forward(self, x):
|
170 |
+
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
171 |
+
|
172 |
+
|
173 |
+
def create_conv3d(in_channels, out_channels, kernel_size, **kwargs):
|
174 |
+
"""Select a 2d convolution implementation based on arguments
|
175 |
+
Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv3d, or CondConv2d.
|
176 |
+
|
177 |
+
Used extensively by EfficientNet, MobileNetv3 and related networks.
|
178 |
+
"""
|
179 |
+
|
180 |
+
depthwise = kwargs.pop("depthwise", False)
|
181 |
+
# for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0
|
182 |
+
groups = in_channels if depthwise else kwargs.pop("groups", 1)
|
183 |
+
|
184 |
+
m = create_conv3d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
|
185 |
+
return m
|
186 |
+
|
187 |
+
|
188 |
+
def conv3d_same(
|
189 |
+
x,
|
190 |
+
weight: torch.Tensor,
|
191 |
+
bias: Optional[torch.Tensor] = None,
|
192 |
+
stride: Tuple[int, int] = (1, 1, 1),
|
193 |
+
padding: Tuple[int, int] = (0, 0, 0),
|
194 |
+
dilation: Tuple[int, int] = (1, 1, 1),
|
195 |
+
groups: int = 1,
|
196 |
+
):
|
197 |
+
x = pad_same(x, weight.shape[-3:], stride, dilation)
|
198 |
+
return F.conv3d(x, weight, bias, stride, (0, 0, 0), dilation, groups)
|
199 |
+
|
200 |
+
|
201 |
+
class Conv3dSame(nn.Conv2d):
|
202 |
+
"""Tensorflow like 'SAME' convolution wrapper for 2D convolutions"""
|
203 |
+
|
204 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
205 |
+
super(Conv3dSame, self).__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
|
206 |
+
|
207 |
+
def forward(self, x):
|
208 |
+
return conv3d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
209 |
+
|
210 |
+
|
211 |
+
def create_conv3d_pad(in_chs, out_chs, kernel_size, **kwargs):
|
212 |
+
padding = kwargs.pop("padding", "")
|
213 |
+
kwargs.setdefault("bias", False)
|
214 |
+
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
|
215 |
+
if is_dynamic:
|
216 |
+
return Conv3dSame(in_chs, out_chs, kernel_size, **kwargs)
|
217 |
+
else:
|
218 |
+
return nn.Conv3d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
|
219 |
+
|
220 |
+
|
221 |
+
# Calculate symmetric padding for a convolution
|
222 |
+
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
|
223 |
+
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
224 |
+
return padding
|
225 |
+
|
226 |
+
|
227 |
+
# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
|
228 |
+
def get_same_padding(x: int, k: int, s: int, d: int):
|
229 |
+
return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
|
230 |
+
|
231 |
+
|
232 |
+
# Can SAME padding for given args be done statically?
|
233 |
+
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
|
234 |
+
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
|
235 |
+
|
236 |
+
|
237 |
+
# Dynamically pad input x with 'SAME' padding for conv with specified args
|
238 |
+
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1, 1), value: float = 0):
|
239 |
+
id, ih, iw = x.size()[-3:]
|
240 |
+
pad_d, pad_h, pad_w = (
|
241 |
+
get_same_padding(id, k[0], s[0], d[0]),
|
242 |
+
get_same_padding(ih, k[1], s[1], d[1]),
|
243 |
+
get_same_padding(iw, k[2], s[2], d[2]),
|
244 |
+
)
|
245 |
+
if pad_d > 0 or pad_h > 0 or pad_w > 0:
|
246 |
+
x = F.pad(
|
247 |
+
x,
|
248 |
+
[pad_d // 2, pad_d - pad_d // 2, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
|
249 |
+
value=value,
|
250 |
+
)
|
251 |
+
return x
|
252 |
+
|
253 |
+
|
254 |
+
def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
|
255 |
+
dynamic = False
|
256 |
+
if isinstance(padding, str):
|
257 |
+
# for any string padding, the padding will be calculated for you, one of three ways
|
258 |
+
padding = padding.lower()
|
259 |
+
if padding == "same":
|
260 |
+
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
|
261 |
+
if is_static_pad(kernel_size, **kwargs):
|
262 |
+
# static case, no extra overhead
|
263 |
+
padding = get_padding(kernel_size, **kwargs)
|
264 |
+
else:
|
265 |
+
# dynamic 'SAME' padding, has runtime/GPU memory overhead
|
266 |
+
padding = 0
|
267 |
+
dynamic = True
|
268 |
+
elif padding == "valid":
|
269 |
+
# 'VALID' padding, same as padding=0
|
270 |
+
padding = 0
|
271 |
+
else:
|
272 |
+
# Default to PyTorch style 'same'-ish symmetric padding
|
273 |
+
padding = get_padding(kernel_size, **kwargs)
|
274 |
+
return padding, dynamic
|
275 |
+
|
276 |
+
|
277 |
+
# From PyTorch internals
|
278 |
+
def _ntuple(n):
|
279 |
+
def parse(x):
|
280 |
+
if isinstance(x, collections.abc.Iterable):
|
281 |
+
return x
|
282 |
+
return tuple(repeat(x, n))
|
283 |
+
|
284 |
+
return parse
|
285 |
+
|
286 |
+
|
287 |
+
to_1tuple = _ntuple(1)
|
288 |
+
to_2tuple = _ntuple(2)
|
289 |
+
to_3tuple = _ntuple(3)
|
290 |
+
to_4tuple = _ntuple(4)
|
291 |
+
to_ntuple = _ntuple
|
292 |
+
|
293 |
+
|
294 |
+
def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
|
295 |
+
min_value = min_value or divisor
|
296 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
297 |
+
# Make sure that round down does not go down by more than 10%.
|
298 |
+
if new_v < round_limit * v:
|
299 |
+
new_v += divisor
|
300 |
+
return new_v
|
301 |
+
|
302 |
+
|
303 |
+
class Linear(nn.Linear):
|
304 |
+
r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
|
305 |
+
|
306 |
+
Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
|
307 |
+
weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
|
308 |
+
"""
|
309 |
+
|
310 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
311 |
+
if torch.jit.is_scripting():
|
312 |
+
bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
|
313 |
+
return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
|
314 |
+
else:
|
315 |
+
return F.linear(input, self.weight, self.bias)
|
316 |
+
|
317 |
+
|
318 |
+
class Mlp(nn.Module):
|
319 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
320 |
+
|
321 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
|
322 |
+
super().__init__()
|
323 |
+
out_features = out_features or in_features
|
324 |
+
hidden_features = hidden_features or in_features
|
325 |
+
drop_probs = to_2tuple(drop)
|
326 |
+
|
327 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
328 |
+
self.act = act_layer()
|
329 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
330 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
331 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
332 |
+
|
333 |
+
def forward(self, x):
|
334 |
+
x = self.fc1(x)
|
335 |
+
x = self.act(x)
|
336 |
+
x = self.drop1(x)
|
337 |
+
x = self.fc2(x)
|
338 |
+
x = self.drop2(x)
|
339 |
+
return x
|
340 |
+
|
341 |
+
|
342 |
+
def avg_pool3d_same(
|
343 |
+
x,
|
344 |
+
kernel_size: List[int],
|
345 |
+
stride: List[int],
|
346 |
+
padding: List[int] = (0, 0, 0),
|
347 |
+
ceil_mode: bool = False,
|
348 |
+
count_include_pad: bool = True,
|
349 |
+
):
|
350 |
+
# FIXME how to deal with count_include_pad vs not for external padding?
|
351 |
+
x = pad_same(x, kernel_size, stride)
|
352 |
+
return F.avg_pool3d(x, kernel_size, stride, (0, 0, 0), ceil_mode, count_include_pad)
|
353 |
+
|
354 |
+
|
355 |
+
class AvgPool3dSame(nn.AvgPool2d):
|
356 |
+
"""Tensorflow like 'SAME' wrapper for 2D average pooling"""
|
357 |
+
|
358 |
+
def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
|
359 |
+
kernel_size = to_2tuple(kernel_size)
|
360 |
+
stride = to_2tuple(stride)
|
361 |
+
super(AvgPool3dSame, self).__init__(kernel_size, stride, (0, 0, 0), ceil_mode, count_include_pad)
|
362 |
+
|
363 |
+
def forward(self, x):
|
364 |
+
x = pad_same(x, self.kernel_size, self.stride)
|
365 |
+
return F.avg_pool3d(x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
|
366 |
+
|
367 |
+
|
368 |
+
def max_pool3d_same(
|
369 |
+
x,
|
370 |
+
kernel_size: List[int],
|
371 |
+
stride: List[int],
|
372 |
+
padding: List[int] = (0, 0, 0),
|
373 |
+
dilation: List[int] = (1, 1, 1),
|
374 |
+
ceil_mode: bool = False,
|
375 |
+
):
|
376 |
+
x = pad_same(x, kernel_size, stride, value=-float("inf"))
|
377 |
+
return F.max_pool3d(x, kernel_size, stride, (0, 0, 0), dilation, ceil_mode)
|
378 |
+
|
379 |
+
|
380 |
+
class MaxPool3dSame(nn.MaxPool2d):
|
381 |
+
"""Tensorflow like 'SAME' wrapper for 3D max pooling"""
|
382 |
+
|
383 |
+
def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
|
384 |
+
kernel_size = to_2tuple(kernel_size)
|
385 |
+
stride = to_2tuple(stride)
|
386 |
+
dilation = to_2tuple(dilation)
|
387 |
+
super(MaxPool3dSame, self).__init__(kernel_size, stride, (0, 0, 0), dilation, ceil_mode)
|
388 |
+
|
389 |
+
def forward(self, x):
|
390 |
+
x = pad_same(x, self.kernel_size, self.stride, value=-float("inf"))
|
391 |
+
return F.max_pool3d(x, self.kernel_size, self.stride, (0, 0, 0), self.dilation, self.ceil_mode)
|
392 |
+
|
393 |
+
|
394 |
+
def create_pool3d(pool_type, kernel_size, stride=None, **kwargs):
|
395 |
+
stride = stride or kernel_size
|
396 |
+
padding = kwargs.pop("padding", "")
|
397 |
+
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
|
398 |
+
if is_dynamic:
|
399 |
+
if pool_type == "avg":
|
400 |
+
return AvgPool3dSame(kernel_size, stride=stride, **kwargs)
|
401 |
+
elif pool_type == "max":
|
402 |
+
return MaxPool3dSame(kernel_size, stride=stride, **kwargs)
|
403 |
+
else:
|
404 |
+
raise AssertionError()
|
405 |
+
|
406 |
+
# assert False, f"Unsupported pool type {pool_type}"
|
407 |
+
else:
|
408 |
+
if pool_type == "avg":
|
409 |
+
return nn.AvgPool3d(kernel_size, stride=stride, padding=padding, **kwargs)
|
410 |
+
elif pool_type == "max":
|
411 |
+
return nn.MaxPool3d(kernel_size, stride=stride, padding=padding, **kwargs)
|
412 |
+
else:
|
413 |
+
raise AssertionError()
|
414 |
+
|
415 |
+
# assert False, f"Unsupported pool type {pool_type}"
|
416 |
+
|
417 |
+
|
418 |
+
def _float_to_int(x: float) -> int:
|
419 |
+
"""
|
420 |
+
Symbolic tracing helper to substitute for inbuilt `int`.
|
421 |
+
Hint: Inbuilt `int` can't accept an argument of type `Proxy`
|
422 |
+
"""
|
423 |
+
return int(x)
|
424 |
+
|
425 |
+
|
426 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
427 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
428 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
429 |
+
def norm_cdf(x):
|
430 |
+
# Computes standard normal cumulative distribution function
|
431 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
432 |
+
|
433 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
434 |
+
warnings.warn(
|
435 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
436 |
+
"The distribution of values may be incorrect.",
|
437 |
+
stacklevel=2,
|
438 |
+
)
|
439 |
+
|
440 |
+
with torch.no_grad():
|
441 |
+
# Values are generated by using a truncated uniform distribution and
|
442 |
+
# then using the inverse CDF for the normal distribution.
|
443 |
+
# Get upper and lower cdf values
|
444 |
+
l = norm_cdf((a - mean) / std)
|
445 |
+
u = norm_cdf((b - mean) / std)
|
446 |
+
|
447 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
448 |
+
# [2l-1, 2u-1].
|
449 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
450 |
+
|
451 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
452 |
+
# standard normal
|
453 |
+
tensor.erfinv_()
|
454 |
+
|
455 |
+
# Transform to proper mean, std
|
456 |
+
tensor.mul_(std * math.sqrt(2.0))
|
457 |
+
tensor.add_(mean)
|
458 |
+
|
459 |
+
# Clamp to ensure it's in the proper range
|
460 |
+
tensor.clamp_(min=a, max=b)
|
461 |
+
return tensor
|
462 |
+
|
463 |
+
|
464 |
+
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
465 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
466 |
+
normal distribution. The values are effectively drawn from the
|
467 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
468 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
469 |
+
the bounds. The method used for generating the random values works
|
470 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
471 |
+
Args:
|
472 |
+
tensor: an n-dimensional `torch.Tensor`
|
473 |
+
mean: the mean of the normal distribution
|
474 |
+
std: the standard deviation of the normal distribution
|
475 |
+
a: the minimum cutoff value
|
476 |
+
b: the maximum cutoff value
|
477 |
+
Examples:
|
478 |
+
>>> w = torch.empty(3, 5)
|
479 |
+
>>> nn.init.trunc_normal_(w)
|
480 |
+
"""
|
481 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
scripts/networks/nest_transformer_3D.py
ADDED
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# =========================================================================
|
4 |
+
# Adapted from https://github.com/google-research/nested-transformer.
|
5 |
+
# which has the following license...
|
6 |
+
# https://github.com/pytorch/vision/blob/main/LICENSE
|
7 |
+
#
|
8 |
+
# BSD 3-Clause License
|
9 |
+
|
10 |
+
|
11 |
+
# Redistribution and use in source and binary forms, with or without
|
12 |
+
# modification, are permitted provided that the following conditions are met:
|
13 |
+
|
14 |
+
# * Redistributions of source code must retain the above copyright notice, this
|
15 |
+
# list of conditions and the following disclaimer.
|
16 |
+
|
17 |
+
# * Redistributions in binary form must reproduce the above copyright notice,
|
18 |
+
# this list of conditions and the following disclaimer in the documentation
|
19 |
+
# and/or other materials provided with the distribution.
|
20 |
+
|
21 |
+
# * Neither the name of the copyright holder nor the names of its
|
22 |
+
# contributors may be used to endorse or promote products derived from
|
23 |
+
# this software without specific prior written permission.
|
24 |
+
|
25 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
26 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
27 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
28 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
29 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
30 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
31 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
32 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
33 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
34 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
35 |
+
|
36 |
+
""" Nested Transformer (NesT) in PyTorch
|
37 |
+
A PyTorch implement of Aggregating Nested Transformers as described in:
|
38 |
+
'Aggregating Nested Transformers'
|
39 |
+
- https://arxiv.org/abs/2105.12723
|
40 |
+
The official Jax code is released and available at https://github.com/google-research/nested-transformer.
|
41 |
+
The weights have been converted with convert/convert_nest_flax.py
|
42 |
+
Acknowledgments:
|
43 |
+
* The paper authors for sharing their research, code, and model weights
|
44 |
+
* Ross Wightman's existing code off which I based this
|
45 |
+
Copyright 2021 Alexander Soare
|
46 |
+
|
47 |
+
"""
|
48 |
+
|
49 |
+
import collections.abc
|
50 |
+
import logging
|
51 |
+
import math
|
52 |
+
from functools import partial
|
53 |
+
from typing import Callable, Sequence
|
54 |
+
|
55 |
+
import torch
|
56 |
+
import torch.nn.functional as F
|
57 |
+
from torch import nn
|
58 |
+
|
59 |
+
from .nest import DropPath, Mlp, _assert, create_conv3d, create_pool3d, to_ntuple, trunc_normal_
|
60 |
+
from .patchEmbed3D import PatchEmbed3D
|
61 |
+
|
62 |
+
_logger = logging.getLogger(__name__)
|
63 |
+
|
64 |
+
|
65 |
+
class Attention(nn.Module):
|
66 |
+
"""
|
67 |
+
This is much like `.vision_transformer.Attention` but uses *localised* self attention by accepting an input with
|
68 |
+
an extra "image block" dim
|
69 |
+
"""
|
70 |
+
|
71 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
|
72 |
+
super().__init__()
|
73 |
+
self.num_heads = num_heads
|
74 |
+
head_dim = dim // num_heads
|
75 |
+
self.scale = head_dim**-0.5
|
76 |
+
|
77 |
+
self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
|
78 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
79 |
+
self.proj = nn.Linear(dim, dim)
|
80 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
"""
|
84 |
+
x is shape: B (batch_size), T (image blocks), N (seq length per image block), C (embed dim)
|
85 |
+
"""
|
86 |
+
b, t, n, c = x.shape
|
87 |
+
# result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
|
88 |
+
qkv = self.qkv(x).reshape(b, t, n, 3, self.num_heads, c // self.num_heads).permute(3, 0, 4, 1, 2, 5)
|
89 |
+
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
90 |
+
|
91 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, N, N)
|
92 |
+
attn = attn.softmax(dim=-1)
|
93 |
+
attn = self.attn_drop(attn)
|
94 |
+
|
95 |
+
x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(b, t, n, c)
|
96 |
+
x = self.proj(x)
|
97 |
+
x = self.proj_drop(x)
|
98 |
+
return x # (B, T, N, C)
|
99 |
+
|
100 |
+
|
101 |
+
class TransformerLayer(nn.Module):
|
102 |
+
"""
|
103 |
+
This is much like `.vision_transformer.Block` but:
|
104 |
+
- Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
|
105 |
+
- Uses modified Attention layer that handles the "block" dimension
|
106 |
+
"""
|
107 |
+
|
108 |
+
def __init__(
|
109 |
+
self,
|
110 |
+
dim,
|
111 |
+
num_heads,
|
112 |
+
mlp_ratio=4.0,
|
113 |
+
qkv_bias=False,
|
114 |
+
drop=0.0,
|
115 |
+
attn_drop=0.0,
|
116 |
+
drop_path=0.0,
|
117 |
+
act_layer=nn.GELU,
|
118 |
+
norm_layer=nn.LayerNorm,
|
119 |
+
):
|
120 |
+
super().__init__()
|
121 |
+
self.norm1 = norm_layer(dim)
|
122 |
+
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
123 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
124 |
+
self.norm2 = norm_layer(dim)
|
125 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
126 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
y = self.norm1(x)
|
130 |
+
x = x + self.drop_path(self.attn(y))
|
131 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
132 |
+
return x
|
133 |
+
|
134 |
+
|
135 |
+
class ConvPool(nn.Module):
|
136 |
+
def __init__(self, in_channels, out_channels, norm_layer, pad_type=""):
|
137 |
+
super().__init__()
|
138 |
+
self.conv = create_conv3d(in_channels, out_channels, kernel_size=3, padding=pad_type, bias=True)
|
139 |
+
self.norm = norm_layer(out_channels)
|
140 |
+
self.pool = create_pool3d("max", kernel_size=3, stride=2, padding=pad_type)
|
141 |
+
|
142 |
+
def forward(self, x):
|
143 |
+
"""
|
144 |
+
x is expected to have shape (B, C, D, H, W)
|
145 |
+
"""
|
146 |
+
_assert(x.shape[-3] % 2 == 0, "BlockAggregation requires even input spatial dims")
|
147 |
+
_assert(x.shape[-2] % 2 == 0, "BlockAggregation requires even input spatial dims")
|
148 |
+
_assert(x.shape[-1] % 2 == 0, "BlockAggregation requires even input spatial dims")
|
149 |
+
|
150 |
+
# print('In ConvPool x : {}'.format(x.shape))
|
151 |
+
x = self.conv(x)
|
152 |
+
# Layer norm done over channel dim only
|
153 |
+
x = self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
|
154 |
+
x = self.pool(x)
|
155 |
+
return x # (B, C, D//2, H//2, W//2)
|
156 |
+
|
157 |
+
|
158 |
+
def blockify(x, block_size: int):
|
159 |
+
"""image to blocks
|
160 |
+
Args:
|
161 |
+
x (Tensor): with shape (B, D, H, W, C)
|
162 |
+
block_size (int): edge length of a single square block in units of D, H, W
|
163 |
+
"""
|
164 |
+
b, d, h, w, c = x.shape
|
165 |
+
_assert(d % block_size == 0, "`block_size` must divide input depth evenly")
|
166 |
+
_assert(h % block_size == 0, "`block_size` must divide input height evenly")
|
167 |
+
_assert(w % block_size == 0, "`block_size` must divide input width evenly")
|
168 |
+
grid_depth = d // block_size
|
169 |
+
grid_height = h // block_size
|
170 |
+
grid_width = w // block_size
|
171 |
+
x = x.reshape(b, grid_depth, block_size, grid_height, block_size, grid_width, block_size, c)
|
172 |
+
|
173 |
+
x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(
|
174 |
+
b, grid_depth * grid_height * grid_width, -1, c
|
175 |
+
) # shape [2, 512, 27, 128]
|
176 |
+
|
177 |
+
return x # (B, T, N, C)
|
178 |
+
|
179 |
+
|
180 |
+
# @register_notrace_function # reason: int receives Proxy
|
181 |
+
def deblockify(x, block_size: int):
|
182 |
+
"""blocks to image
|
183 |
+
Args:
|
184 |
+
x (Tensor): with shape (B, T, N, C) where T is number of blocks and N is sequence size per block
|
185 |
+
block_size (int): edge length of a single square block in units of desired D, H, W
|
186 |
+
"""
|
187 |
+
b, t, _, c = x.shape
|
188 |
+
grid_size = round(math.pow(t, 1 / 3))
|
189 |
+
depth = height = width = grid_size * block_size
|
190 |
+
x = x.reshape(b, grid_size, grid_size, grid_size, block_size, block_size, block_size, c)
|
191 |
+
|
192 |
+
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).reshape(b, depth, height, width, c)
|
193 |
+
|
194 |
+
return x # (B, D, H, W, C)
|
195 |
+
|
196 |
+
|
197 |
+
class NestLevel(nn.Module):
|
198 |
+
"""Single hierarchical level of a Nested Transformer"""
|
199 |
+
|
200 |
+
def __init__(
|
201 |
+
self,
|
202 |
+
num_blocks,
|
203 |
+
block_size,
|
204 |
+
seq_length,
|
205 |
+
num_heads,
|
206 |
+
depth,
|
207 |
+
embed_dim,
|
208 |
+
prev_embed_dim=None,
|
209 |
+
mlp_ratio=4.0,
|
210 |
+
qkv_bias=True,
|
211 |
+
drop_rate=0.0,
|
212 |
+
attn_drop_rate=0.0,
|
213 |
+
drop_path_rates: Sequence[int] = (),
|
214 |
+
norm_layer=None,
|
215 |
+
act_layer=None,
|
216 |
+
pad_type="",
|
217 |
+
):
|
218 |
+
super().__init__()
|
219 |
+
self.block_size = block_size
|
220 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim))
|
221 |
+
|
222 |
+
if prev_embed_dim is not None:
|
223 |
+
self.pool = ConvPool(prev_embed_dim, embed_dim, norm_layer=norm_layer, pad_type=pad_type)
|
224 |
+
else:
|
225 |
+
self.pool = nn.Identity()
|
226 |
+
|
227 |
+
# Transformer encoder
|
228 |
+
if len(drop_path_rates):
|
229 |
+
assert len(drop_path_rates) == depth, "Must provide as many drop path rates as there are transformer layers"
|
230 |
+
self.transformer_encoder = nn.Sequential(
|
231 |
+
*[
|
232 |
+
TransformerLayer(
|
233 |
+
dim=embed_dim,
|
234 |
+
num_heads=num_heads,
|
235 |
+
mlp_ratio=mlp_ratio,
|
236 |
+
qkv_bias=qkv_bias,
|
237 |
+
drop=drop_rate,
|
238 |
+
attn_drop=attn_drop_rate,
|
239 |
+
drop_path=drop_path_rates[i],
|
240 |
+
norm_layer=norm_layer,
|
241 |
+
act_layer=act_layer,
|
242 |
+
)
|
243 |
+
for i in range(depth)
|
244 |
+
]
|
245 |
+
)
|
246 |
+
|
247 |
+
def forward(self, x):
|
248 |
+
"""
|
249 |
+
expects x as (B, C, D, H, W)
|
250 |
+
"""
|
251 |
+
x = self.pool(x)
|
252 |
+
x = x.permute(0, 2, 3, 4, 1) # (B, H', W', C), switch to channels last for transformer
|
253 |
+
|
254 |
+
x = blockify(x, self.block_size) # (B, T, N, C')
|
255 |
+
x = x + self.pos_embed
|
256 |
+
|
257 |
+
x = self.transformer_encoder(x) # (B, ,T, N, C')
|
258 |
+
|
259 |
+
x = deblockify(x, self.block_size) # (B, D', H', W', C') [2, 24, 24, 24, 128]
|
260 |
+
# Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
|
261 |
+
return x.permute(0, 4, 1, 2, 3) # (B, C, D', H', W')
|
262 |
+
|
263 |
+
|
264 |
+
class NestTransformer3D(nn.Module):
|
265 |
+
"""Nested Transformer (NesT)
|
266 |
+
A PyTorch impl of : `Aggregating Nested Transformers`
|
267 |
+
- https://arxiv.org/abs/2105.12723
|
268 |
+
"""
|
269 |
+
|
270 |
+
def __init__(
|
271 |
+
self,
|
272 |
+
img_size=96,
|
273 |
+
in_chans=1,
|
274 |
+
patch_size=2,
|
275 |
+
num_levels=3,
|
276 |
+
embed_dims=(128, 256, 512),
|
277 |
+
num_heads=(4, 8, 16),
|
278 |
+
depths=(2, 2, 20),
|
279 |
+
num_classes=1000,
|
280 |
+
mlp_ratio=4.0,
|
281 |
+
qkv_bias=True,
|
282 |
+
drop_rate=0.0,
|
283 |
+
attn_drop_rate=0.0,
|
284 |
+
drop_path_rate=0.5,
|
285 |
+
norm_layer=None,
|
286 |
+
act_layer=None,
|
287 |
+
pad_type="",
|
288 |
+
weight_init="",
|
289 |
+
global_pool="avg",
|
290 |
+
):
|
291 |
+
"""
|
292 |
+
Args:
|
293 |
+
img_size (int, tuple): input image size
|
294 |
+
in_chans (int): number of input channels
|
295 |
+
patch_size (int): patch size
|
296 |
+
num_levels (int): number of block hierarchies (T_d in the paper)
|
297 |
+
embed_dims (int, tuple): embedding dimensions of each level
|
298 |
+
num_heads (int, tuple): number of attention heads for each level
|
299 |
+
depths (int, tuple): number of transformer layers for each level
|
300 |
+
num_classes (int): number of classes for classification head
|
301 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim for MLP of transformer layers
|
302 |
+
qkv_bias (bool): enable bias for qkv if True
|
303 |
+
drop_rate (float): dropout rate for MLP of transformer layers, MSA final projection layer, and classifier
|
304 |
+
attn_drop_rate (float): attention dropout rate
|
305 |
+
drop_path_rate (float): stochastic depth rate
|
306 |
+
norm_layer: (nn.Module): normalization layer for transformer layers
|
307 |
+
act_layer: (nn.Module): activation layer in MLP of transformer layers
|
308 |
+
pad_type: str: Type of padding to use '' for PyTorch symmetric, 'same' for TF SAME
|
309 |
+
weight_init: (str): weight init scheme
|
310 |
+
global_pool: (str): type of pooling operation to apply to final feature map
|
311 |
+
Notes:
|
312 |
+
- Default values follow NesT-B from the original Jax code.
|
313 |
+
- `embed_dims`, `num_heads`, `depths` should be ints or tuples with length `num_levels`.
|
314 |
+
- For those following the paper, Table A1 may have errors!
|
315 |
+
- https://github.com/google-research/nested-transformer/issues/2
|
316 |
+
"""
|
317 |
+
super().__init__()
|
318 |
+
|
319 |
+
for param_name in ["embed_dims", "num_heads", "depths"]:
|
320 |
+
param_value = locals()[param_name]
|
321 |
+
if isinstance(param_value, collections.abc.Sequence):
|
322 |
+
assert len(param_value) == num_levels, f"Require `len({param_name}) == num_levels`"
|
323 |
+
|
324 |
+
embed_dims = to_ntuple(num_levels)(embed_dims)
|
325 |
+
num_heads = to_ntuple(num_levels)(num_heads)
|
326 |
+
depths = to_ntuple(num_levels)(depths)
|
327 |
+
self.num_classes = num_classes
|
328 |
+
self.num_features = embed_dims[-1]
|
329 |
+
self.feature_info = []
|
330 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
331 |
+
act_layer = act_layer or nn.GELU
|
332 |
+
self.drop_rate = drop_rate
|
333 |
+
self.num_levels = num_levels
|
334 |
+
if isinstance(img_size, collections.abc.Sequence):
|
335 |
+
assert img_size[0] == img_size[1], "Model only handles square inputs"
|
336 |
+
img_size = img_size[0]
|
337 |
+
assert img_size % patch_size == 0, "`patch_size` must divide `img_size` evenly"
|
338 |
+
self.patch_size = patch_size
|
339 |
+
|
340 |
+
# Number of blocks at each level
|
341 |
+
self.num_blocks = (8 ** torch.arange(num_levels)).flip(0).tolist()
|
342 |
+
assert (img_size // patch_size) % round(
|
343 |
+
math.pow(self.num_blocks[0], 1 / 3)
|
344 |
+
) == 0, "First level blocks don't fit evenly. Check `img_size`, `patch_size`, and `num_levels`"
|
345 |
+
|
346 |
+
# Block edge size in units of patches
|
347 |
+
# Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
|
348 |
+
# number of blocks along edge of image
|
349 |
+
self.block_size = int((img_size // patch_size) // round(math.pow(self.num_blocks[0], 1 / 3)))
|
350 |
+
|
351 |
+
# Patch embedding
|
352 |
+
self.patch_embed = PatchEmbed3D(
|
353 |
+
img_size=[img_size, img_size, img_size],
|
354 |
+
patch_size=[patch_size, patch_size, patch_size],
|
355 |
+
in_chans=in_chans,
|
356 |
+
embed_dim=embed_dims[0],
|
357 |
+
)
|
358 |
+
self.num_patches = self.patch_embed.num_patches
|
359 |
+
self.seq_length = self.num_patches // self.num_blocks[0]
|
360 |
+
# Build up each hierarchical level
|
361 |
+
levels = []
|
362 |
+
|
363 |
+
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
364 |
+
prev_dim = None
|
365 |
+
curr_stride = 4
|
366 |
+
for i in range(len(self.num_blocks)):
|
367 |
+
dim = embed_dims[i]
|
368 |
+
levels.append(
|
369 |
+
NestLevel(
|
370 |
+
self.num_blocks[i],
|
371 |
+
self.block_size,
|
372 |
+
self.seq_length,
|
373 |
+
num_heads[i],
|
374 |
+
depths[i],
|
375 |
+
dim,
|
376 |
+
prev_dim,
|
377 |
+
mlp_ratio,
|
378 |
+
qkv_bias,
|
379 |
+
drop_rate,
|
380 |
+
attn_drop_rate,
|
381 |
+
dp_rates[i],
|
382 |
+
norm_layer,
|
383 |
+
act_layer,
|
384 |
+
pad_type=pad_type,
|
385 |
+
)
|
386 |
+
)
|
387 |
+
self.feature_info += [dict(num_chs=dim, reduction=curr_stride, module=f"levels.{i}")]
|
388 |
+
prev_dim = dim
|
389 |
+
curr_stride *= 2
|
390 |
+
|
391 |
+
self.levels = nn.ModuleList([levels[i] for i in range(num_levels)])
|
392 |
+
|
393 |
+
# Final normalization layer
|
394 |
+
self.norm = norm_layer(embed_dims[-1])
|
395 |
+
|
396 |
+
self.init_weights(weight_init)
|
397 |
+
|
398 |
+
def init_weights(self, mode=""):
|
399 |
+
assert mode in ("nlhb", "")
|
400 |
+
head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
|
401 |
+
for level in self.levels:
|
402 |
+
trunc_normal_(level.pos_embed, std=0.02, a=-2, b=2)
|
403 |
+
named_apply(partial(_init_nest_weights, head_bias=head_bias), self)
|
404 |
+
|
405 |
+
@torch.jit.ignore
|
406 |
+
def no_weight_decay(self):
|
407 |
+
return {f"level.{i}.pos_embed" for i in range(len(self.levels))}
|
408 |
+
|
409 |
+
def get_classifier(self):
|
410 |
+
return self.head
|
411 |
+
|
412 |
+
def forward_features(self, x):
|
413 |
+
"""x shape (B, C, D, H, W)"""
|
414 |
+
x = self.patch_embed(x)
|
415 |
+
|
416 |
+
hidden_states_out = [x]
|
417 |
+
|
418 |
+
for _, level in enumerate(self.levels):
|
419 |
+
x = level(x)
|
420 |
+
hidden_states_out.append(x)
|
421 |
+
# Layer norm done over channel dim only (to NDHWC and back)
|
422 |
+
x = self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
|
423 |
+
return x, hidden_states_out
|
424 |
+
|
425 |
+
def forward(self, x):
|
426 |
+
"""x shape (B, C, D, H, W)"""
|
427 |
+
x = self.forward_features(x)
|
428 |
+
|
429 |
+
if self.drop_rate > 0.0:
|
430 |
+
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
431 |
+
return x
|
432 |
+
|
433 |
+
|
434 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
435 |
+
if not depth_first and include_root:
|
436 |
+
fn(module=module, name=name)
|
437 |
+
for child_name, child_module in module.named_children():
|
438 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
439 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
440 |
+
if depth_first and include_root:
|
441 |
+
fn(module=module, name=name)
|
442 |
+
return module
|
443 |
+
|
444 |
+
|
445 |
+
def _init_nest_weights(module: nn.Module, name: str = "", head_bias: float = 0.0):
|
446 |
+
"""NesT weight initialization
|
447 |
+
Can replicate Jax implementation. Otherwise follows vision_transformer.py
|
448 |
+
"""
|
449 |
+
if isinstance(module, nn.Linear):
|
450 |
+
if name.startswith("head"):
|
451 |
+
trunc_normal_(module.weight, std=0.02, a=-2, b=2)
|
452 |
+
nn.init.constant_(module.bias, head_bias)
|
453 |
+
else:
|
454 |
+
trunc_normal_(module.weight, std=0.02, a=-2, b=2)
|
455 |
+
if module.bias is not None:
|
456 |
+
nn.init.zeros_(module.bias)
|
457 |
+
elif isinstance(module, nn.Conv2d):
|
458 |
+
trunc_normal_(module.weight, std=0.02, a=-2, b=2)
|
459 |
+
if module.bias is not None:
|
460 |
+
nn.init.zeros_(module.bias)
|
461 |
+
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
|
462 |
+
nn.init.zeros_(module.bias)
|
463 |
+
nn.init.ones_(module.weight)
|
464 |
+
|
465 |
+
|
466 |
+
def resize_pos_embed(posemb, posemb_new):
|
467 |
+
"""
|
468 |
+
Rescale the grid of position embeddings when loading from state_dict
|
469 |
+
Expected shape of position embeddings is (1, T, N, C), and considers only square images
|
470 |
+
"""
|
471 |
+
_logger.info("Resized position embedding: %s to %s", posemb.shape, posemb_new.shape)
|
472 |
+
seq_length_old = posemb.shape[2]
|
473 |
+
num_blocks_new, seq_length_new = posemb_new.shape[1:3]
|
474 |
+
size_new = int(math.sqrt(num_blocks_new * seq_length_new))
|
475 |
+
# First change to (1, C, H, W)
|
476 |
+
posemb = deblockify(posemb, int(math.sqrt(seq_length_old))).permute(0, 3, 1, 2)
|
477 |
+
posemb = F.interpolate(posemb, size=[size_new, size_new], mode="bicubic", align_corners=False)
|
478 |
+
# Now change to new (1, T, N, C)
|
479 |
+
posemb = blockify(posemb.permute(0, 2, 3, 1), int(math.sqrt(seq_length_new)))
|
480 |
+
return posemb
|
481 |
+
|
482 |
+
|
483 |
+
def checkpoint_filter_fn(state_dict, model):
|
484 |
+
"""resize positional embeddings of pretrained weights"""
|
485 |
+
pos_embed_keys = [k for k in state_dict.keys() if k.startswith("pos_embed_")]
|
486 |
+
for k in pos_embed_keys:
|
487 |
+
if state_dict[k].shape != getattr(model, k).shape:
|
488 |
+
state_dict[k] = resize_pos_embed(state_dict[k], getattr(model, k))
|
489 |
+
return state_dict
|
scripts/networks/patchEmbed3D.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Copyright 2020 - 2021 MONAI Consortium
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
# Unless required by applicable law or agreed to in writing, software
|
9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11 |
+
# See the License for the specific language governing permissions and
|
12 |
+
# limitations under the License.
|
13 |
+
|
14 |
+
|
15 |
+
import math
|
16 |
+
from typing import Sequence, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from monai.utils import optional_import
|
22 |
+
|
23 |
+
Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
|
24 |
+
|
25 |
+
|
26 |
+
class PatchEmbeddingBlock(nn.Module):
|
27 |
+
"""
|
28 |
+
A patch embedding block, based on: "Dosovitskiy et al.,
|
29 |
+
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
in_channels: int,
|
35 |
+
img_size: Tuple[int, int, int],
|
36 |
+
patch_size: Tuple[int, int, int],
|
37 |
+
hidden_size: int,
|
38 |
+
num_heads: int,
|
39 |
+
pos_embed: str,
|
40 |
+
dropout_rate: float = 0.0,
|
41 |
+
) -> None:
|
42 |
+
"""
|
43 |
+
Args:
|
44 |
+
in_channels: dimension of input channels.
|
45 |
+
img_size: dimension of input image.
|
46 |
+
patch_size: dimension of patch size.
|
47 |
+
hidden_size: dimension of hidden layer.
|
48 |
+
num_heads: number of attention heads.
|
49 |
+
pos_embed: position embedding layer type.
|
50 |
+
dropout_rate: faction of the input units to drop.
|
51 |
+
|
52 |
+
"""
|
53 |
+
|
54 |
+
super().__init__()
|
55 |
+
|
56 |
+
if not (0 <= dropout_rate <= 1):
|
57 |
+
raise AssertionError("dropout_rate should be between 0 and 1.")
|
58 |
+
|
59 |
+
if hidden_size % num_heads != 0:
|
60 |
+
raise AssertionError("hidden size should be divisible by num_heads.")
|
61 |
+
|
62 |
+
for m, p in zip(img_size, patch_size):
|
63 |
+
if m < p:
|
64 |
+
raise AssertionError("patch_size should be smaller than img_size.")
|
65 |
+
|
66 |
+
if pos_embed not in ["conv", "perceptron"]:
|
67 |
+
raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")
|
68 |
+
|
69 |
+
if pos_embed == "perceptron":
|
70 |
+
if img_size[0] % patch_size[0] != 0:
|
71 |
+
raise AssertionError("img_size should be divisible by patch_size for perceptron patch embedding.")
|
72 |
+
|
73 |
+
self.n_patches = (
|
74 |
+
(img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) * (img_size[2] // patch_size[2])
|
75 |
+
)
|
76 |
+
self.patch_dim = in_channels * patch_size[0] * patch_size[1] * patch_size[2]
|
77 |
+
|
78 |
+
self.pos_embed = pos_embed
|
79 |
+
self.patch_embeddings: Union[nn.Conv3d, nn.Sequential]
|
80 |
+
if self.pos_embed == "conv":
|
81 |
+
self.patch_embeddings = nn.Conv3d(
|
82 |
+
in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
|
83 |
+
)
|
84 |
+
elif self.pos_embed == "perceptron":
|
85 |
+
self.patch_embeddings = nn.Sequential(
|
86 |
+
Rearrange(
|
87 |
+
"b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)",
|
88 |
+
p1=patch_size[0],
|
89 |
+
p2=patch_size[1],
|
90 |
+
p3=patch_size[2],
|
91 |
+
),
|
92 |
+
nn.Linear(self.patch_dim, hidden_size),
|
93 |
+
)
|
94 |
+
self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
|
95 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
|
96 |
+
self.dropout = nn.Dropout(dropout_rate)
|
97 |
+
self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
|
98 |
+
self.apply(self._init_weights)
|
99 |
+
|
100 |
+
def _init_weights(self, m):
|
101 |
+
if isinstance(m, nn.Linear):
|
102 |
+
self.trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)
|
103 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
104 |
+
nn.init.constant_(m.bias, 0)
|
105 |
+
elif isinstance(m, nn.LayerNorm):
|
106 |
+
nn.init.constant_(m.bias, 0)
|
107 |
+
nn.init.constant_(m.weight, 1.0)
|
108 |
+
|
109 |
+
def trunc_normal_(self, tensor, mean, std, a, b):
|
110 |
+
# From PyTorch official master until it's in a few official releases - RW
|
111 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
112 |
+
def norm_cdf(x):
|
113 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
114 |
+
|
115 |
+
with torch.no_grad():
|
116 |
+
l = norm_cdf((a - mean) / std)
|
117 |
+
u = norm_cdf((b - mean) / std)
|
118 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
119 |
+
tensor.erfinv_()
|
120 |
+
tensor.mul_(std * math.sqrt(2.0))
|
121 |
+
tensor.add_(mean)
|
122 |
+
tensor.clamp_(min=a, max=b)
|
123 |
+
return tensor
|
124 |
+
|
125 |
+
def forward(self, x):
|
126 |
+
if self.pos_embed == "conv":
|
127 |
+
x = self.patch_embeddings(x)
|
128 |
+
x = x.flatten(2)
|
129 |
+
x = x.transpose(-1, -2)
|
130 |
+
elif self.pos_embed == "perceptron":
|
131 |
+
x = self.patch_embeddings(x)
|
132 |
+
embeddings = x + self.position_embeddings
|
133 |
+
embeddings = self.dropout(embeddings)
|
134 |
+
return embeddings
|
135 |
+
|
136 |
+
|
137 |
+
class PatchEmbed3D(nn.Module):
|
138 |
+
"""Video to Patch Embedding.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
patch_size (int): Patch token size. Default: (2,4,4).
|
142 |
+
in_chans (int): Number of input video channels. Default: 3.
|
143 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
144 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
145 |
+
"""
|
146 |
+
|
147 |
+
def __init__(
|
148 |
+
self,
|
149 |
+
img_size: Sequence[int] = (96, 96, 96),
|
150 |
+
patch_size=(4, 4, 4),
|
151 |
+
in_chans: int = 1,
|
152 |
+
embed_dim: int = 96,
|
153 |
+
norm_layer=None,
|
154 |
+
):
|
155 |
+
super().__init__()
|
156 |
+
self.patch_size = patch_size
|
157 |
+
|
158 |
+
self.in_chans = in_chans
|
159 |
+
self.embed_dim = embed_dim
|
160 |
+
|
161 |
+
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1], img_size[2] // patch_size[2])
|
162 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
|
163 |
+
|
164 |
+
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
165 |
+
|
166 |
+
if norm_layer is not None:
|
167 |
+
self.norm = norm_layer(embed_dim)
|
168 |
+
else:
|
169 |
+
self.norm = None
|
170 |
+
|
171 |
+
def forward(self, x):
|
172 |
+
"""Forward function."""
|
173 |
+
# padding
|
174 |
+
_, _, d, h, w = x.size()
|
175 |
+
if w % self.patch_size[2] != 0:
|
176 |
+
x = F.pad(x, (0, self.patch_size[2] - w % self.patch_size[2]))
|
177 |
+
if h % self.patch_size[1] != 0:
|
178 |
+
x = F.pad(x, (0, 0, 0, self.patch_size[1] - h % self.patch_size[1]))
|
179 |
+
if d % self.patch_size[0] != 0:
|
180 |
+
x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - d % self.patch_size[0]))
|
181 |
+
|
182 |
+
x = self.proj(x) # B C D Wh Ww
|
183 |
+
if self.norm is not None:
|
184 |
+
d, wh, ww = x.size(2), x.size(3), x.size(4)
|
185 |
+
x = x.flatten(2).transpose(1, 2)
|
186 |
+
x = self.norm(x)
|
187 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, d, wh, ww)
|
188 |
+
# pdb.set_trace()
|
189 |
+
|
190 |
+
return x
|
scripts/networks/unest_base_patch_4.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# limitations under the License.
|
3 |
+
"""
|
4 |
+
The 3D NEST transformer based segmentation model
|
5 |
+
|
6 |
+
MASI Lab, Vanderbilty University
|
7 |
+
|
8 |
+
|
9 |
+
Authors: Xin Yu, Yinchi Zhou, Yucheng Tang, Bennett Landman
|
10 |
+
|
11 |
+
|
12 |
+
The NEST code is partly from
|
13 |
+
|
14 |
+
Nested Hierarchical Transformer: Towards Accurate, Data-Efficient and
|
15 |
+
Interpretable Visual Understanding
|
16 |
+
https://arxiv.org/pdf/2105.12723.pdf
|
17 |
+
|
18 |
+
"""
|
19 |
+
from typing import Sequence, Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
from monai.networks.blocks import Convolution
|
24 |
+
from monai.networks.blocks.dynunet_block import UnetOutBlock
|
25 |
+
from scripts.networks.nest_transformer_3D import NestTransformer3D
|
26 |
+
from scripts.networks.unest_block import UNesTBlock, UNesTConvBlock, UNestUpBlock
|
27 |
+
|
28 |
+
|
29 |
+
class UNesT(nn.Module):
|
30 |
+
"""
|
31 |
+
UNesT model implementation
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(
|
35 |
+
self,
|
36 |
+
in_channels: int,
|
37 |
+
out_channels: int,
|
38 |
+
img_size: Sequence[int] = (96, 96, 96),
|
39 |
+
feature_size: int = 16,
|
40 |
+
patch_size: int = 2,
|
41 |
+
depths: Sequence[int] = (2, 2, 2, 2),
|
42 |
+
num_heads: Sequence[int] = (3, 6, 12, 24),
|
43 |
+
embed_dim: Sequence[int] = (128, 256, 512),
|
44 |
+
window_size: Sequence[int] = (7, 7, 7),
|
45 |
+
norm_name: Union[Tuple, str] = "instance",
|
46 |
+
conv_block: bool = False,
|
47 |
+
res_block: bool = True,
|
48 |
+
dropout_rate: float = 0.0,
|
49 |
+
) -> None:
|
50 |
+
"""
|
51 |
+
Args:
|
52 |
+
in_channels: dimension of input channels.
|
53 |
+
out_channels: dimension of output channels.
|
54 |
+
img_size: dimension of input image.
|
55 |
+
feature_size: dimension of network feature size.
|
56 |
+
hidden_size: dimension of hidden layer.
|
57 |
+
mlp_dim: dimension of feedforward layer.
|
58 |
+
num_heads: number of attention heads.
|
59 |
+
pos_embed: position embedding layer type.
|
60 |
+
norm_name: feature normalization type and arguments.
|
61 |
+
conv_block: bool argument to determine if convolutional block is used.
|
62 |
+
res_block: bool argument to determine if residual block is used.
|
63 |
+
dropout_rate: faction of the input units to drop.
|
64 |
+
|
65 |
+
Examples:
|
66 |
+
|
67 |
+
# for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm
|
68 |
+
>>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch')
|
69 |
+
|
70 |
+
# for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm
|
71 |
+
>>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance')
|
72 |
+
|
73 |
+
"""
|
74 |
+
|
75 |
+
super().__init__()
|
76 |
+
if not (0 <= dropout_rate <= 1):
|
77 |
+
raise AssertionError("dropout_rate should be between 0 and 1.")
|
78 |
+
self.embed_dim = embed_dim
|
79 |
+
self.nestViT = NestTransformer3D(
|
80 |
+
img_size=96,
|
81 |
+
in_chans=1,
|
82 |
+
patch_size=patch_size,
|
83 |
+
num_levels=3,
|
84 |
+
embed_dims=embed_dim,
|
85 |
+
num_heads=num_heads,
|
86 |
+
depths=depths,
|
87 |
+
num_classes=1000,
|
88 |
+
mlp_ratio=4.0,
|
89 |
+
qkv_bias=True,
|
90 |
+
drop_rate=0.0,
|
91 |
+
attn_drop_rate=0.0,
|
92 |
+
drop_path_rate=0.5,
|
93 |
+
norm_layer=None,
|
94 |
+
act_layer=None,
|
95 |
+
pad_type="",
|
96 |
+
weight_init="",
|
97 |
+
global_pool="avg",
|
98 |
+
)
|
99 |
+
self.encoder1 = UNesTConvBlock(
|
100 |
+
spatial_dims=3,
|
101 |
+
in_channels=1,
|
102 |
+
out_channels=feature_size * 2,
|
103 |
+
kernel_size=3,
|
104 |
+
stride=1,
|
105 |
+
norm_name=norm_name,
|
106 |
+
res_block=res_block,
|
107 |
+
)
|
108 |
+
self.encoder2 = UNestUpBlock(
|
109 |
+
spatial_dims=3,
|
110 |
+
in_channels=self.embed_dim[0],
|
111 |
+
out_channels=feature_size * 4,
|
112 |
+
num_layer=1,
|
113 |
+
kernel_size=3,
|
114 |
+
stride=1,
|
115 |
+
upsample_kernel_size=2,
|
116 |
+
norm_name=norm_name,
|
117 |
+
conv_block=False,
|
118 |
+
res_block=False,
|
119 |
+
)
|
120 |
+
|
121 |
+
self.encoder3 = UNesTConvBlock(
|
122 |
+
spatial_dims=3,
|
123 |
+
in_channels=self.embed_dim[0],
|
124 |
+
out_channels=8 * feature_size,
|
125 |
+
kernel_size=3,
|
126 |
+
stride=1,
|
127 |
+
norm_name=norm_name,
|
128 |
+
res_block=res_block,
|
129 |
+
)
|
130 |
+
self.encoder4 = UNesTConvBlock(
|
131 |
+
spatial_dims=3,
|
132 |
+
in_channels=self.embed_dim[1],
|
133 |
+
out_channels=16 * feature_size,
|
134 |
+
kernel_size=3,
|
135 |
+
stride=1,
|
136 |
+
norm_name=norm_name,
|
137 |
+
res_block=res_block,
|
138 |
+
)
|
139 |
+
self.decoder5 = UNesTBlock(
|
140 |
+
spatial_dims=3,
|
141 |
+
in_channels=2 * self.embed_dim[2],
|
142 |
+
out_channels=feature_size * 32,
|
143 |
+
stride=1,
|
144 |
+
kernel_size=3,
|
145 |
+
upsample_kernel_size=2,
|
146 |
+
norm_name=norm_name,
|
147 |
+
res_block=res_block,
|
148 |
+
)
|
149 |
+
self.decoder4 = UNesTBlock(
|
150 |
+
spatial_dims=3,
|
151 |
+
in_channels=self.embed_dim[2],
|
152 |
+
out_channels=feature_size * 16,
|
153 |
+
stride=1,
|
154 |
+
kernel_size=3,
|
155 |
+
upsample_kernel_size=2,
|
156 |
+
norm_name=norm_name,
|
157 |
+
res_block=res_block,
|
158 |
+
)
|
159 |
+
self.decoder3 = UNesTBlock(
|
160 |
+
spatial_dims=3,
|
161 |
+
in_channels=feature_size * 16,
|
162 |
+
out_channels=feature_size * 8,
|
163 |
+
stride=1,
|
164 |
+
kernel_size=3,
|
165 |
+
upsample_kernel_size=2,
|
166 |
+
norm_name=norm_name,
|
167 |
+
res_block=res_block,
|
168 |
+
)
|
169 |
+
self.decoder2 = UNesTBlock(
|
170 |
+
spatial_dims=3,
|
171 |
+
in_channels=feature_size * 8,
|
172 |
+
out_channels=feature_size * 4,
|
173 |
+
stride=1,
|
174 |
+
kernel_size=3,
|
175 |
+
upsample_kernel_size=2,
|
176 |
+
norm_name=norm_name,
|
177 |
+
res_block=res_block,
|
178 |
+
)
|
179 |
+
self.decoder1 = UNesTBlock(
|
180 |
+
spatial_dims=3,
|
181 |
+
in_channels=feature_size * 4,
|
182 |
+
out_channels=feature_size * 2,
|
183 |
+
stride=1,
|
184 |
+
kernel_size=3,
|
185 |
+
upsample_kernel_size=2,
|
186 |
+
norm_name=norm_name,
|
187 |
+
res_block=res_block,
|
188 |
+
)
|
189 |
+
self.encoder10 = Convolution(
|
190 |
+
spatial_dims=3,
|
191 |
+
in_channels=32 * feature_size,
|
192 |
+
out_channels=64 * feature_size,
|
193 |
+
strides=2,
|
194 |
+
adn_ordering="ADN",
|
195 |
+
dropout=0.0,
|
196 |
+
)
|
197 |
+
self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) # type: ignore
|
198 |
+
|
199 |
+
def proj_feat(self, x, hidden_size, feat_size):
|
200 |
+
x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
|
201 |
+
x = x.permute(0, 4, 1, 2, 3).contiguous()
|
202 |
+
return x
|
203 |
+
|
204 |
+
def load_from(self, weights):
|
205 |
+
with torch.no_grad():
|
206 |
+
# copy weights from patch embedding
|
207 |
+
for i in weights["state_dict"]:
|
208 |
+
print(i)
|
209 |
+
self.vit.patch_embedding.position_embeddings.copy_(
|
210 |
+
weights["state_dict"]["module.transformer.patch_embedding.position_embeddings_3d"]
|
211 |
+
)
|
212 |
+
self.vit.patch_embedding.cls_token.copy_(
|
213 |
+
weights["state_dict"]["module.transformer.patch_embedding.cls_token"]
|
214 |
+
)
|
215 |
+
self.vit.patch_embedding.patch_embeddings[1].weight.copy_(
|
216 |
+
weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings_3d.1.weight"]
|
217 |
+
)
|
218 |
+
self.vit.patch_embedding.patch_embeddings[1].bias.copy_(
|
219 |
+
weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings_3d.1.bias"]
|
220 |
+
)
|
221 |
+
|
222 |
+
# copy weights from encoding blocks (default: num of blocks: 12)
|
223 |
+
for bname, block in self.vit.blocks.named_children():
|
224 |
+
print(block)
|
225 |
+
block.loadFrom(weights, n_block=bname)
|
226 |
+
# last norm layer of transformer
|
227 |
+
self.vit.norm.weight.copy_(weights["state_dict"]["module.transformer.norm.weight"])
|
228 |
+
self.vit.norm.bias.copy_(weights["state_dict"]["module.transformer.norm.bias"])
|
229 |
+
|
230 |
+
def forward(self, x_in):
|
231 |
+
x, hidden_states_out = self.nestViT(x_in)
|
232 |
+
enc0 = self.encoder1(x_in) # 2, 32, 96, 96, 96
|
233 |
+
x1 = hidden_states_out[0] # 2, 128, 24, 24, 24 2, 128, 12, 12, 12
|
234 |
+
enc1 = self.encoder2(x1) # 2, 64, 48, 48, 48 torch.Size([2, 64, 24, 24, 24])
|
235 |
+
x2 = hidden_states_out[1] # 2, 128, 24, 24, 24
|
236 |
+
enc2 = self.encoder3(x2) # 2, 128, 24, 24, 24 torch.Size([2, 128, 12, 12, 12])
|
237 |
+
x3 = hidden_states_out[2] # 2, 256, 12, 12, 12 torch.Size([2, 256, 6, 6, 6])
|
238 |
+
enc3 = self.encoder4(x3) # 2, 256, 12, 12, 12 torch.Size([2, 256, 6, 6, 6])
|
239 |
+
x4 = hidden_states_out[3]
|
240 |
+
enc4 = x4 # 2, 512, 6, 6, 6 torch.Size([2, 512, 3, 3, 3])
|
241 |
+
dec4 = x # 2, 512, 6, 6, 6 torch.Size([2, 512, 3, 3, 3])
|
242 |
+
dec4 = self.encoder10(dec4) # 2, 1024, 3, 3, 3 torch.Size([2, 1024, 2, 2, 2])
|
243 |
+
dec3 = self.decoder5(dec4, enc4) # 2, 512, 6, 6, 6
|
244 |
+
dec2 = self.decoder4(dec3, enc3) # 2, 256, 12, 12, 12
|
245 |
+
dec1 = self.decoder3(dec2, enc2) # 2, 128, 24, 24, 24
|
246 |
+
dec0 = self.decoder2(dec1, enc1) # 2, 64, 48, 48, 48
|
247 |
+
out = self.decoder1(dec0, enc0) # 2, 32, 96, 96, 96
|
248 |
+
logits = self.out(out)
|
249 |
+
return logits
|
scripts/networks/unest_block.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
from typing import Sequence, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer
|
8 |
+
|
9 |
+
|
10 |
+
class UNesTBlock(nn.Module):
|
11 |
+
""" """
|
12 |
+
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
spatial_dims: int,
|
16 |
+
in_channels: int,
|
17 |
+
out_channels: int, # type: ignore
|
18 |
+
kernel_size: Union[Sequence[int], int],
|
19 |
+
stride: Union[Sequence[int], int],
|
20 |
+
upsample_kernel_size: Union[Sequence[int], int],
|
21 |
+
norm_name: Union[Tuple, str],
|
22 |
+
res_block: bool = False,
|
23 |
+
) -> None:
|
24 |
+
"""
|
25 |
+
Args:
|
26 |
+
spatial_dims: number of spatial dimensions.
|
27 |
+
in_channels: number of input channels.
|
28 |
+
out_channels: number of output channels.
|
29 |
+
kernel_size: convolution kernel size.
|
30 |
+
stride: convolution stride.
|
31 |
+
upsample_kernel_size: convolution kernel size for transposed convolution layers.
|
32 |
+
norm_name: feature normalization type and arguments.
|
33 |
+
res_block: bool argument to determine if residual block is used.
|
34 |
+
|
35 |
+
"""
|
36 |
+
|
37 |
+
super(UNesTBlock, self).__init__()
|
38 |
+
upsample_stride = upsample_kernel_size
|
39 |
+
self.transp_conv = get_conv_layer(
|
40 |
+
spatial_dims,
|
41 |
+
in_channels,
|
42 |
+
out_channels,
|
43 |
+
kernel_size=upsample_kernel_size,
|
44 |
+
stride=upsample_stride,
|
45 |
+
conv_only=True,
|
46 |
+
is_transposed=True,
|
47 |
+
)
|
48 |
+
|
49 |
+
if res_block:
|
50 |
+
self.conv_block = UnetResBlock(
|
51 |
+
spatial_dims,
|
52 |
+
out_channels + out_channels,
|
53 |
+
out_channels,
|
54 |
+
kernel_size=kernel_size,
|
55 |
+
stride=1,
|
56 |
+
norm_name=norm_name,
|
57 |
+
)
|
58 |
+
else:
|
59 |
+
self.conv_block = UnetBasicBlock( # type: ignore
|
60 |
+
spatial_dims,
|
61 |
+
out_channels + out_channels,
|
62 |
+
out_channels,
|
63 |
+
kernel_size=kernel_size,
|
64 |
+
stride=1,
|
65 |
+
norm_name=norm_name,
|
66 |
+
)
|
67 |
+
|
68 |
+
def forward(self, inp, skip):
|
69 |
+
# number of channels for skip should equals to out_channels
|
70 |
+
out = self.transp_conv(inp)
|
71 |
+
# print(out.shape)
|
72 |
+
# print(skip.shape)
|
73 |
+
out = torch.cat((out, skip), dim=1)
|
74 |
+
out = self.conv_block(out)
|
75 |
+
return out
|
76 |
+
|
77 |
+
|
78 |
+
class UNestUpBlock(nn.Module):
|
79 |
+
""" """
|
80 |
+
|
81 |
+
def __init__(
|
82 |
+
self,
|
83 |
+
spatial_dims: int,
|
84 |
+
in_channels: int,
|
85 |
+
out_channels: int,
|
86 |
+
num_layer: int,
|
87 |
+
kernel_size: Union[Sequence[int], int],
|
88 |
+
stride: Union[Sequence[int], int],
|
89 |
+
upsample_kernel_size: Union[Sequence[int], int],
|
90 |
+
norm_name: Union[Tuple, str],
|
91 |
+
conv_block: bool = False,
|
92 |
+
res_block: bool = False,
|
93 |
+
) -> None:
|
94 |
+
"""
|
95 |
+
Args:
|
96 |
+
spatial_dims: number of spatial dimensions.
|
97 |
+
in_channels: number of input channels.
|
98 |
+
out_channels: number of output channels.
|
99 |
+
num_layer: number of upsampling blocks.
|
100 |
+
kernel_size: convolution kernel size.
|
101 |
+
stride: convolution stride.
|
102 |
+
upsample_kernel_size: convolution kernel size for transposed convolution layers.
|
103 |
+
norm_name: feature normalization type and arguments.
|
104 |
+
conv_block: bool argument to determine if convolutional block is used.
|
105 |
+
res_block: bool argument to determine if residual block is used.
|
106 |
+
|
107 |
+
"""
|
108 |
+
|
109 |
+
super().__init__()
|
110 |
+
|
111 |
+
upsample_stride = upsample_kernel_size
|
112 |
+
self.transp_conv_init = get_conv_layer(
|
113 |
+
spatial_dims,
|
114 |
+
in_channels,
|
115 |
+
out_channels,
|
116 |
+
kernel_size=upsample_kernel_size,
|
117 |
+
stride=upsample_stride,
|
118 |
+
conv_only=True,
|
119 |
+
is_transposed=True,
|
120 |
+
)
|
121 |
+
if conv_block:
|
122 |
+
if res_block:
|
123 |
+
self.blocks = nn.ModuleList(
|
124 |
+
[
|
125 |
+
nn.Sequential(
|
126 |
+
get_conv_layer(
|
127 |
+
spatial_dims,
|
128 |
+
out_channels,
|
129 |
+
out_channels,
|
130 |
+
kernel_size=upsample_kernel_size,
|
131 |
+
stride=upsample_stride,
|
132 |
+
conv_only=True,
|
133 |
+
is_transposed=True,
|
134 |
+
),
|
135 |
+
UnetResBlock(
|
136 |
+
spatial_dims=3,
|
137 |
+
in_channels=out_channels,
|
138 |
+
out_channels=out_channels,
|
139 |
+
kernel_size=kernel_size,
|
140 |
+
stride=stride,
|
141 |
+
norm_name=norm_name,
|
142 |
+
),
|
143 |
+
)
|
144 |
+
for i in range(num_layer)
|
145 |
+
]
|
146 |
+
)
|
147 |
+
else:
|
148 |
+
self.blocks = nn.ModuleList(
|
149 |
+
[
|
150 |
+
nn.Sequential(
|
151 |
+
get_conv_layer(
|
152 |
+
spatial_dims,
|
153 |
+
out_channels,
|
154 |
+
out_channels,
|
155 |
+
kernel_size=upsample_kernel_size,
|
156 |
+
stride=upsample_stride,
|
157 |
+
conv_only=True,
|
158 |
+
is_transposed=True,
|
159 |
+
),
|
160 |
+
UnetBasicBlock(
|
161 |
+
spatial_dims=3,
|
162 |
+
in_channels=out_channels,
|
163 |
+
out_channels=out_channels,
|
164 |
+
kernel_size=kernel_size,
|
165 |
+
stride=stride,
|
166 |
+
norm_name=norm_name,
|
167 |
+
),
|
168 |
+
)
|
169 |
+
for i in range(num_layer)
|
170 |
+
]
|
171 |
+
)
|
172 |
+
else:
|
173 |
+
self.blocks = nn.ModuleList(
|
174 |
+
[
|
175 |
+
get_conv_layer(
|
176 |
+
spatial_dims,
|
177 |
+
out_channels,
|
178 |
+
out_channels,
|
179 |
+
kernel_size=1,
|
180 |
+
stride=1,
|
181 |
+
conv_only=True,
|
182 |
+
is_transposed=True,
|
183 |
+
)
|
184 |
+
for i in range(num_layer)
|
185 |
+
]
|
186 |
+
)
|
187 |
+
|
188 |
+
def forward(self, x):
|
189 |
+
x = self.transp_conv_init(x)
|
190 |
+
for blk in self.blocks:
|
191 |
+
x = blk(x)
|
192 |
+
return x
|
193 |
+
|
194 |
+
|
195 |
+
class UNesTConvBlock(nn.Module):
|
196 |
+
"""
|
197 |
+
UNesT block with skip connections
|
198 |
+
"""
|
199 |
+
|
200 |
+
def __init__(
|
201 |
+
self,
|
202 |
+
spatial_dims: int,
|
203 |
+
in_channels: int,
|
204 |
+
out_channels: int,
|
205 |
+
kernel_size: Union[Sequence[int], int],
|
206 |
+
stride: Union[Sequence[int], int],
|
207 |
+
norm_name: Union[Tuple, str],
|
208 |
+
res_block: bool = False,
|
209 |
+
) -> None:
|
210 |
+
"""
|
211 |
+
Args:
|
212 |
+
spatial_dims: number of spatial dimensions.
|
213 |
+
in_channels: number of input channels.
|
214 |
+
out_channels: number of output channels.
|
215 |
+
kernel_size: convolution kernel size.
|
216 |
+
stride: convolution stride.
|
217 |
+
norm_name: feature normalization type and arguments.
|
218 |
+
res_block: bool argument to determine if residual block is used.
|
219 |
+
|
220 |
+
"""
|
221 |
+
|
222 |
+
super().__init__()
|
223 |
+
|
224 |
+
if res_block:
|
225 |
+
self.layer = UnetResBlock(
|
226 |
+
spatial_dims=spatial_dims,
|
227 |
+
in_channels=in_channels,
|
228 |
+
out_channels=out_channels,
|
229 |
+
kernel_size=kernel_size,
|
230 |
+
stride=stride,
|
231 |
+
norm_name=norm_name,
|
232 |
+
)
|
233 |
+
else:
|
234 |
+
self.layer = UnetBasicBlock( # type: ignore
|
235 |
+
spatial_dims=spatial_dims,
|
236 |
+
in_channels=in_channels,
|
237 |
+
out_channels=out_channels,
|
238 |
+
kernel_size=kernel_size,
|
239 |
+
stride=stride,
|
240 |
+
norm_name=norm_name,
|
241 |
+
)
|
242 |
+
|
243 |
+
def forward(self, inp):
|
244 |
+
out = self.layer(inp)
|
245 |
+
return out
|