Venn committed
Commit 02aa18d · verified · 1 Parent(s): 594437e

Upload maisi_ct_generative version 1.0.0

LICENSE ADDED
@@ -0,0 +1,247 @@
+ Code License
+
+ This license applies to all files except the model weights in the directory.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+ ------------------------------------------------------------------------------
+
+ Model Weights License
+
+ This license applies to model weights in the directory.
+
+ NVIDIA License
+
+ 1. Definitions
+
+ “Licensor” means any person or entity that distributes its Work.
+ “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.
+ The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
+ Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.
+
+ 2. License Grant
+
+ 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
+
+ 3. Limitations
+
+ 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
+
+ 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
+
+ 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
+
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
+
+ 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.
+
+ 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.
+
+ 4. Disclaimer of Warranty.
+
+ THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
+
+ 5. Limitation of Liability.
+
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
configs/all_anatomy_size_condtions.json ADDED
The diff for this file is too large to render. See raw diff
 
configs/candidate_masks_flexible_size_and_spacing_3000.json ADDED
The diff for this file is too large to render. See raw diff
 
configs/image_median_statistics.json ADDED
@@ -0,0 +1,72 @@
+ {
+ "liver": {
+ "min_median": -14.0,
+ "max_median": 1000.0,
+ "percentile_0_5": 9.530000000000001,
+ "percentile_99_5": 162.0,
+ "sigma_6_low": -21.596463547885904,
+ "sigma_6_high": 156.27881534763367,
+ "sigma_12_low": -110.53410299564568,
+ "sigma_12_high": 245.21645479539342
+ },
+ "spleen": {
+ "min_median": -69.0,
+ "max_median": 1000.0,
+ "percentile_0_5": 16.925000000000004,
+ "percentile_99_5": 184.07500000000073,
+ "sigma_6_low": -43.133891656525165,
+ "sigma_6_high": 177.40494997185993,
+ "sigma_12_low": -153.4033124707177,
+ "sigma_12_high": 287.6743707860525
+ },
+ "pancreas": {
+ "min_median": -124.0,
+ "max_median": 1000.0,
+ "percentile_0_5": -29.0,
+ "percentile_99_5": 145.92000000000007,
+ "sigma_6_low": -56.59382515620725,
+ "sigma_6_high": 149.50627399318438,
+ "sigma_12_low": -159.64387473090306,
+ "sigma_12_high": 252.5563235678802
+ },
+ "kidney": {
+ "min_median": -165.5,
+ "max_median": 819.0,
+ "percentile_0_5": -40.0,
+ "percentile_99_5": 254.61999999999898,
+ "sigma_6_low": -130.56375604853028,
+ "sigma_6_high": 267.28163511081016,
+ "sigma_12_low": -329.4864516282005,
+ "sigma_12_high": 466.20433069048045
+ },
+ "lung": {
+ "min_median": -1000.0,
+ "max_median": 65.0,
+ "percentile_0_5": -937.0,
+ "percentile_99_5": -366.9500000000007,
+ "sigma_6_low": -1088.5583843889117,
+ "sigma_6_high": -551.8503346949108,
+ "sigma_12_low": -1356.912409235912,
+ "sigma_12_high": -283.4963098479103
+ },
+ "bone": {
+ "min_median": 77.5,
+ "max_median": 1000.0,
+ "percentile_0_5": 136.45499999999998,
+ "percentile_99_5": 551.6350000000002,
+ "sigma_6_low": 71.39901958080469,
+ "sigma_6_high": 471.9957615639765,
+ "sigma_12_low": -128.8993514107812,
+ "sigma_12_high": 672.2941325555623
+ },
+ "brain": {
+ "min_median": -1000.0,
+ "max_median": 238.0,
+ "percentile_0_5": -951.0,
+ "percentile_99_5": 126.25,
+ "sigma_6_low": -304.8208236135867,
+ "sigma_6_high": 369.5118535139189,
+ "sigma_12_low": -641.9871621773394,
+ "sigma_12_high": 706.6781920776717
+ }
+ }
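Note: these per-organ intensity statistics (medians in HU over real CT scans, with 0.5/99.5 percentile and ±6σ/±12σ bands) are referenced by `configs/inference.json` below as `real_img_median_statistics_file`, which the sampler consumes for its image quality check. As a minimal illustration of how such a file can be used — a hedged sketch, not the bundle's own quality-check code — one could flag a generated volume whose organ-wise median falls outside the recorded percentile band:

```python
import json

import numpy as np

with open("configs/image_median_statistics.json") as f:
    stats = json.load(f)


def organ_median_in_band(image: np.ndarray, organ_mask: np.ndarray, organ: str) -> bool:
    """Check the median HU inside `organ_mask` against the 0.5-99.5 percentile
    band recorded for `organ` (e.g. "liver") in the statistics file."""
    median = float(np.median(image[organ_mask > 0]))
    band = stats[organ]
    return band["percentile_0_5"] <= median <= band["percentile_99_5"]
```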
configs/inference.json ADDED
@@ -0,0 +1,312 @@
+ {
+ "imports": [
+ "$import torch",
+ "$from pathlib import Path",
+ "$import scripts"
+ ],
+ "bundle_root": ".",
+ "model_dir": "$@bundle_root + '/models'",
+ "output_dir": "$@bundle_root + '/output'",
+ "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)",
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
+ "trained_autoencoder_path": "$@model_dir + '/autoencoder.pt'",
+ "trained_diffusion_path": "$@model_dir + '/diffusion_unet.pt'",
+ "trained_controlnet_path": "$@model_dir + '/controlnet.pt'",
+ "trained_mask_generation_autoencoder_path": "$@model_dir + '/mask_generation_autoencoder.pt'",
+ "trained_mask_generation_diffusion_path": "$@model_dir + '/mask_generation_diffusion_unet.pt'",
+ "all_mask_files_base_dir": "$@bundle_root + '/datasets/all_masks_flexible_size_and_spacing_3000'",
+ "all_mask_files_json": "$@bundle_root + '/configs/candidate_masks_flexible_size_and_spacing_3000.json'",
+ "all_anatomy_size_condtions_json": "$@bundle_root + '/configs/all_anatomy_size_condtions.json'",
+ "label_dict_json": "$@bundle_root + '/configs/label_dict.json'",
+ "label_dict_remap_json": "$@bundle_root + '/configs/label_dict_124_to_132.json'",
+ "real_img_median_statistics_file": "$@bundle_root + '/configs/image_median_statistics.json'",
+ "num_output_samples": 1,
+ "body_region": [],
+ "anatomy_list": [
+ "liver"
+ ],
+ "modality": "ct",
+ "controllable_anatomy_size": [],
+ "num_inference_steps": 30,
+ "mask_generation_num_inference_steps": 1000,
+ "random_seed": null,
+ "spatial_dims": 3,
+ "image_channels": 1,
+ "latent_channels": 4,
+ "output_size_xy": 512,
+ "output_size_z": 512,
+ "output_size": [
+ "@output_size_xy",
+ "@output_size_xy",
+ "@output_size_z"
+ ],
+ "image_output_ext": ".nii.gz",
+ "label_output_ext": ".nii.gz",
+ "spacing_xy": 1.0,
+ "spacing_z": 1.0,
+ "spacing": [
+ "@spacing_xy",
+ "@spacing_xy",
+ "@spacing_z"
+ ],
+ "latent_shape": [
+ "@latent_channels",
+ "$@output_size[0]//4",
+ "$@output_size[1]//4",
+ "$@output_size[2]//4"
+ ],
+ "mask_generation_latent_shape": [
+ 4,
+ 64,
+ 64,
+ 64
+ ],
+ "autoencoder_sliding_window_infer_size": [
+ 80,
+ 80,
+ 80
+ ],
+ "autoencoder_sliding_window_infer_overlap": 0.4,
+ "autoencoder_def": {
+ "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
+ "spatial_dims": "@spatial_dims",
+ "in_channels": "@image_channels",
+ "out_channels": "@image_channels",
+ "latent_channels": "@latent_channels",
+ "num_channels": [
+ 64,
+ 128,
+ 256
+ ],
+ "num_res_blocks": [
+ 2,
+ 2,
+ 2
+ ],
+ "norm_num_groups": 32,
+ "norm_eps": 1e-06,
+ "attention_levels": [
+ false,
+ false,
+ false
+ ],
+ "with_encoder_nonlocal_attn": false,
+ "with_decoder_nonlocal_attn": false,
+ "use_checkpointing": false,
+ "use_convtranspose": false,
+ "norm_float16": true,
+ "num_splits": 2,
+ "dim_split": 1
+ },
+ "diffusion_unet_def": {
+ "_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi",
+ "spatial_dims": "@spatial_dims",
+ "in_channels": "@latent_channels",
+ "out_channels": "@latent_channels",
+ "num_channels": [
+ 64,
+ 128,
+ 256,
+ 512
+ ],
+ "attention_levels": [
+ false,
+ false,
+ true,
+ true
+ ],
+ "num_head_channels": [
+ 0,
+ 0,
+ 32,
+ 32
+ ],
+ "num_res_blocks": 2,
+ "use_flash_attention": true,
+ "include_top_region_index_input": false,
+ "include_bottom_region_index_input": false,
+ "include_spacing_input": true,
+ "num_class_embeds": 128,
+ "resblock_updown": true,
+ "include_fc": true
+ },
+ "controlnet_def": {
+ "_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi",
+ "spatial_dims": "@spatial_dims",
+ "in_channels": "@latent_channels",
+ "num_channels": [
+ 64,
+ 128,
+ 256,
+ 512
+ ],
+ "attention_levels": [
+ false,
+ false,
+ true,
+ true
+ ],
+ "num_head_channels": [
+ 0,
+ 0,
+ 32,
+ 32
+ ],
+ "num_res_blocks": 2,
+ "use_flash_attention": true,
+ "conditioning_embedding_in_channels": 8,
+ "conditioning_embedding_num_channels": [
+ 8,
+ 32,
+ 64
+ ],
+ "num_class_embeds": 128,
+ "resblock_updown": true,
+ "include_fc": true
+ },
+ "mask_generation_autoencoder_def": {
+ "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
+ "spatial_dims": "@spatial_dims",
+ "in_channels": 8,
+ "out_channels": 125,
+ "latent_channels": "@latent_channels",
+ "num_channels": [
+ 32,
+ 64,
+ 128
+ ],
+ "num_res_blocks": [
+ 1,
+ 2,
+ 2
+ ],
+ "norm_num_groups": 32,
+ "norm_eps": 1e-06,
+ "attention_levels": [
+ false,
+ false,
+ false
+ ],
+ "with_encoder_nonlocal_attn": false,
+ "with_decoder_nonlocal_attn": false,
+ "use_flash_attention": false,
+ "use_checkpointing": true,
+ "use_convtranspose": true,
+ "norm_float16": true,
+ "num_splits": 8,
+ "dim_split": 1
+ },
+ "mask_generation_diffusion_def": {
+ "_target_": "monai.networks.nets.diffusion_model_unet.DiffusionModelUNet",
+ "spatial_dims": "@spatial_dims",
+ "in_channels": "@latent_channels",
+ "out_channels": "@latent_channels",
+ "channels": [
+ 64,
+ 128,
+ 256,
+ 512
+ ],
+ "attention_levels": [
+ false,
+ false,
+ true,
+ true
+ ],
+ "num_head_channels": [
+ 0,
+ 0,
+ 32,
+ 32
+ ],
+ "num_res_blocks": 2,
+ "use_flash_attention": true,
+ "with_conditioning": true,
+ "upcast_attention": true,
+ "cross_attention_dim": 10
+ },
+ "autoencoder": "$@autoencoder_def.to(@device)",
+ "checkpoint_autoencoder": "$torch.load(@trained_autoencoder_path, weights_only=True)",
+ "load_autoencoder": "[email protected]_state_dict(@checkpoint_autoencoder)",
+ "diffusion_unet": "$@diffusion_unet_def.to(@device)",
+ "checkpoint_diffusion_unet": "$torch.load(@trained_diffusion_path, weights_only=False)",
+ "load_diffusion": "$@diffusion_unet.load_state_dict(@checkpoint_diffusion_unet['unet_state_dict'])",
+ "controlnet": "$@controlnet_def.to(@device)",
+ "copy_controlnet_state": "$monai.networks.utils.copy_model_state(@controlnet, @diffusion_unet.state_dict())",
+ "checkpoint_controlnet": "$torch.load(@trained_controlnet_path, weights_only=False)",
+ "load_controlnet": "[email protected]_state_dict(@checkpoint_controlnet['controlnet_state_dict'], strict=True)",
+ "scale_factor": "$@checkpoint_diffusion_unet['scale_factor'].to(@device)",
+ "mask_generation_autoencoder": "$@mask_generation_autoencoder_def.to(@device)",
+ "checkpoint_mask_generation_autoencoder": "$torch.load(@trained_mask_generation_autoencoder_path, weights_only=True)",
+ "load_mask_generation_autoencoder": "$@mask_generation_autoencoder.load_state_dict(@checkpoint_mask_generation_autoencoder, strict=True)",
+ "mask_generation_diffusion_unet": "$@mask_generation_diffusion_def.to(@device)",
+ "checkpoint_mask_generation_diffusion_unet": "$torch.load(@trained_mask_generation_diffusion_path, weights_only=True)",
+ "load_mask_generation_diffusion": "$@mask_generation_diffusion_unet.load_state_dict(@checkpoint_mask_generation_diffusion_unet['unet_state_dict'], strict=True)",
+ "mask_generation_scale_factor": "$@checkpoint_mask_generation_diffusion_unet['scale_factor']",
+ "noise_scheduler": {
+ "_target_": "scripts.rectified_flow.RFlowScheduler",
+ "num_train_timesteps": 1000,
+ "use_discrete_timesteps": false,
+ "use_timestep_transform": true,
+ "sample_method": "uniform"
+ },
+ "mask_generation_noise_scheduler": {
+ "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
+ "num_train_timesteps": 1000,
+ "beta_start": 0.0015,
+ "beta_end": 0.0195,
+ "schedule": "scaled_linear_beta",
+ "clip_sample": false
+ },
+ "check_input": "$scripts.sample.check_input(@body_region,@anatomy_list,@label_dict_json,@output_size,@spacing,@controllable_anatomy_size)",
+ "ldm_sampler": {
+ "_target_": "scripts.sample.LDMSampler",
+ "_requires_": [
+ "@create_output_dir",
+ "@load_diffusion",
+ "@load_autoencoder",
+ "@copy_controlnet_state",
+ "@load_controlnet",
+ "@load_mask_generation_autoencoder",
+ "@load_mask_generation_diffusion",
+ "@check_input"
+ ],
+ "body_region": "@body_region",
+ "anatomy_list": "@anatomy_list",
+ "modality": "@modality",
+ "all_mask_files_json": "@all_mask_files_json",
+ "all_anatomy_size_condtions_json": "@all_anatomy_size_condtions_json",
+ "all_mask_files_base_dir": "@all_mask_files_base_dir",
+ "label_dict_json": "@label_dict_json",
+ "label_dict_remap_json": "@label_dict_remap_json",
+ "autoencoder": "@autoencoder",
+ "diffusion_unet": "@diffusion_unet",
+ "controlnet": "@controlnet",
+ "scale_factor": "@scale_factor",
+ "noise_scheduler": "@noise_scheduler",
+ "mask_generation_autoencoder": "@mask_generation_autoencoder",
+ "mask_generation_diffusion_unet": "@mask_generation_diffusion_unet",
+ "mask_generation_scale_factor": "@mask_generation_scale_factor",
+ "mask_generation_noise_scheduler": "@mask_generation_noise_scheduler",
+ "controllable_anatomy_size": "@controllable_anatomy_size",
+ "image_output_ext": "@image_output_ext",
+ "label_output_ext": "@label_output_ext",
+ "real_img_median_statistics": "@real_img_median_statistics_file",
+ "device": "@device",
+ "latent_shape": "@latent_shape",
+ "mask_generation_latent_shape": "@mask_generation_latent_shape",
+ "output_size": "@output_size",
+ "spacing": "@spacing",
+ "output_dir": "@output_dir",
+ "num_inference_steps": "@num_inference_steps",
+ "mask_generation_num_inference_steps": "@mask_generation_num_inference_steps",
+ "random_seed": "@random_seed",
+ "autoencoder_sliding_window_infer_size": "@autoencoder_sliding_window_infer_size",
+ "autoencoder_sliding_window_infer_overlap": "@autoencoder_sliding_window_infer_overlap"
+ },
+ "run": [
+ "$monai.utils.set_determinism(seed=@random_seed)",
+ "$@ldm_sampler.sample_multiple_images(@num_output_samples)"
+ ],
+ "evaluator": null
+ }
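This config follows the MONAI bundle convention: `@name` strings reference other config entries, `$...` strings are evaluated as Python, `_target_` dictionaries are instantiated, and the expressions under `run` are executed last. A minimal sketch of launching it through the bundle runner (the keyword overrides shown are illustrative, not required):

```python
from monai.bundle import run

# Parses the config, instantiates the networks, loads the checkpoints listed
# under "_requires_", then evaluates the "run" expressions, i.e.
# set_determinism(...) followed by LDMSampler.sample_multiple_images(...).
run(
    config_file="configs/inference.json",
    logging_file="configs/logging.conf",
    meta_file="configs/metadata.json",
    num_output_samples=1,              # keyword arguments override config entries
    anatomy_list=["liver", "spleen"],  # illustrative override
)
```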
configs/inference_trt.json ADDED
@@ -0,0 +1,19 @@
+ {
+ "+imports": [
+ "$from monai.networks import trt_compile"
+ ],
+ "c_trt_args": {
+ "export_args": {
+ "dynamo": "$False",
+ "report": "$True"
+ },
+ "output_lists": [
+ [
+ -1
+ ],
+ []
+ ]
+ },
+ "controlnet": "$trt_compile(@controlnet_def.to(@device), @trained_controlnet_path, @c_trt_args)",
+ "diffusion_unet": "$trt_compile(@diffusion_unet_def.to(@device), @trained_diffusion_path)"
+ }
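The `+imports` key appends to the base config's `imports` list, and the `controlnet`/`diffusion_unet` entries override the definitions in `inference.json`, so this file is meant to be layered on top of the base inference config rather than run alone. A sketch, assuming both configs are passed together in order:

```python
from monai.bundle import run

# Later config files override earlier ones, so the TensorRT-compiled
# controlnet and diffusion_unet replace the plain PyTorch modules while
# the rest of the sampling pipeline stays unchanged.
run(
    config_file=["configs/inference.json", "configs/inference_trt.json"],
    logging_file="configs/logging.conf",
)
```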
configs/integration_test_masks.json ADDED
@@ -0,0 +1,98 @@
+ [
+ {
+ "bottom_region_index": [
+ 0,
+ 0,
+ 0,
+ 1
+ ],
+ "dim": [
+ 512,
+ 512,
+ 512
+ ],
+ "label_list": [
+ 1,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 17,
+ 19,
+ 25,
+ 28,
+ 29,
+ 31,
+ 32,
+ 33,
+ 34,
+ 35,
+ 36,
+ 37,
+ 38,
+ 39,
+ 40,
+ 41,
+ 42,
+ 58,
+ 59,
+ 60,
+ 61,
+ 62,
+ 69,
+ 70,
+ 71,
+ 72,
+ 73,
+ 74,
+ 81,
+ 82,
+ 83,
+ 84,
+ 85,
+ 86,
+ 93,
+ 94,
+ 95,
+ 96,
+ 97,
+ 98,
+ 99,
+ 100,
+ 101,
+ 102,
+ 103,
+ 104,
+ 105,
+ 106,
+ 107,
+ 114,
+ 115,
+ 118,
+ 121,
+ 122,
+ 127
+ ],
+ "pseudo_label_filename": "./IntegrationTest-AbdomenCT.nii.gz",
+ "spacing": [
+ 1.0,
+ 1.0,
+ 1.0
+ ],
+ "top_region_index": [
+ 0,
+ 1,
+ 0,
+ 0
+ ]
+ }
+ ]
configs/label_dict.json ADDED
@@ -0,0 +1,134 @@
+ {
+ "liver": 1,
+ "dummy1": 2,
+ "spleen": 3,
+ "pancreas": 4,
+ "right kidney": 5,
+ "aorta": 6,
+ "inferior vena cava": 7,
+ "right adrenal gland": 8,
+ "left adrenal gland": 9,
+ "gallbladder": 10,
+ "esophagus": 11,
+ "stomach": 12,
+ "duodenum": 13,
+ "left kidney": 14,
+ "bladder": 15,
+ "dummy2": 16,
+ "portal vein and splenic vein": 17,
+ "dummy3": 18,
+ "small bowel": 19,
+ "dummy4": 20,
+ "dummy5": 21,
+ "brain": 22,
+ "lung tumor": 23,
+ "pancreatic tumor": 24,
+ "hepatic vessel": 25,
+ "hepatic tumor": 26,
+ "colon cancer primaries": 27,
+ "left lung upper lobe": 28,
+ "left lung lower lobe": 29,
+ "right lung upper lobe": 30,
+ "right lung middle lobe": 31,
+ "right lung lower lobe": 32,
+ "vertebrae L5": 33,
+ "vertebrae L4": 34,
+ "vertebrae L3": 35,
+ "vertebrae L2": 36,
+ "vertebrae L1": 37,
+ "vertebrae T12": 38,
+ "vertebrae T11": 39,
+ "vertebrae T10": 40,
+ "vertebrae T9": 41,
+ "vertebrae T8": 42,
+ "vertebrae T7": 43,
+ "vertebrae T6": 44,
+ "vertebrae T5": 45,
+ "vertebrae T4": 46,
+ "vertebrae T3": 47,
+ "vertebrae T2": 48,
+ "vertebrae T1": 49,
+ "vertebrae C7": 50,
+ "vertebrae C6": 51,
+ "vertebrae C5": 52,
+ "vertebrae C4": 53,
+ "vertebrae C3": 54,
+ "vertebrae C2": 55,
+ "vertebrae C1": 56,
+ "trachea": 57,
+ "left iliac artery": 58,
+ "right iliac artery": 59,
+ "left iliac vena": 60,
+ "right iliac vena": 61,
+ "colon": 62,
+ "left rib 1": 63,
+ "left rib 2": 64,
+ "left rib 3": 65,
+ "left rib 4": 66,
+ "left rib 5": 67,
+ "left rib 6": 68,
+ "left rib 7": 69,
+ "left rib 8": 70,
+ "left rib 9": 71,
+ "left rib 10": 72,
+ "left rib 11": 73,
+ "left rib 12": 74,
+ "right rib 1": 75,
+ "right rib 2": 76,
+ "right rib 3": 77,
+ "right rib 4": 78,
+ "right rib 5": 79,
+ "right rib 6": 80,
+ "right rib 7": 81,
+ "right rib 8": 82,
+ "right rib 9": 83,
+ "right rib 10": 84,
+ "right rib 11": 85,
+ "right rib 12": 86,
+ "left humerus": 87,
+ "right humerus": 88,
+ "left scapula": 89,
+ "right scapula": 90,
+ "left clavicula": 91,
+ "right clavicula": 92,
+ "left femur": 93,
+ "right femur": 94,
+ "left hip": 95,
+ "right hip": 96,
+ "sacrum": 97,
+ "left gluteus maximus": 98,
+ "right gluteus maximus": 99,
+ "left gluteus medius": 100,
+ "right gluteus medius": 101,
+ "left gluteus minimus": 102,
+ "right gluteus minimus": 103,
+ "left autochthon": 104,
+ "right autochthon": 105,
+ "left iliopsoas": 106,
+ "right iliopsoas": 107,
+ "left atrial appendage": 108,
+ "brachiocephalic trunk": 109,
+ "left brachiocephalic vein": 110,
+ "right brachiocephalic vein": 111,
+ "left common carotid artery": 112,
+ "right common carotid artery": 113,
+ "costal cartilages": 114,
+ "heart": 115,
+ "left kidney cyst": 116,
+ "right kidney cyst": 117,
+ "prostate": 118,
+ "pulmonary vein": 119,
+ "skull": 120,
+ "spinal cord": 121,
+ "sternum": 122,
+ "left subclavian artery": 123,
+ "right subclavian artery": 124,
+ "superior vena cava": 125,
+ "thyroid gland": 126,
+ "vertebrae S1": 127,
+ "bone lesion": 128,
+ "dummy6": 129,
+ "dummy7": 130,
+ "dummy8": 131,
+ "airway": 132
+ }
configs/label_dict_124_to_132.json ADDED
@@ -0,0 +1,502 @@
+ {
+ "background": [
+ 0,
+ 0
+ ],
+ "liver": [
+ 1,
+ 1
+ ],
+ "spleen": [
+ 2,
+ 3
+ ],
+ "pancreas": [
+ 3,
+ 4
+ ],
+ "right kidney": [
+ 4,
+ 5
+ ],
+ "aorta": [
+ 5,
+ 6
+ ],
+ "inferior vena cava": [
+ 6,
+ 7
+ ],
+ "right adrenal gland": [
+ 7,
+ 8
+ ],
+ "left adrenal gland": [
+ 8,
+ 9
+ ],
+ "gallbladder": [
+ 9,
+ 10
+ ],
+ "esophagus": [
+ 10,
+ 11
+ ],
+ "stomach": [
+ 11,
+ 12
+ ],
+ "duodenum": [
+ 12,
+ 13
+ ],
+ "left kidney": [
+ 13,
+ 14
+ ],
+ "bladder": [
+ 14,
+ 15
+ ],
+ "portal vein and splenic vein": [
+ 15,
+ 17
+ ],
+ "small bowel": [
+ 16,
+ 19
+ ],
+ "brain": [
+ 17,
+ 22
+ ],
+ "lung tumor": [
+ 18,
+ 23
+ ],
+ "pancreatic tumor": [
+ 19,
+ 24
+ ],
+ "hepatic vessel": [
+ 20,
+ 25
+ ],
+ "hepatic tumor": [
+ 21,
+ 26
+ ],
+ "colon cancer primaries": [
+ 22,
+ 27
+ ],
+ "left lung upper lobe": [
+ 23,
+ 28
+ ],
+ "left lung lower lobe": [
+ 24,
+ 29
+ ],
+ "right lung upper lobe": [
+ 25,
+ 30
+ ],
+ "right lung middle lobe": [
+ 26,
+ 31
+ ],
+ "right lung lower lobe": [
+ 27,
+ 32
+ ],
+ "vertebrae L5": [
+ 28,
+ 33
+ ],
+ "vertebrae L4": [
+ 29,
+ 34
+ ],
+ "vertebrae L3": [
+ 30,
+ 35
+ ],
+ "vertebrae L2": [
+ 31,
+ 36
+ ],
+ "vertebrae L1": [
+ 32,
+ 37
+ ],
+ "vertebrae T12": [
+ 33,
+ 38
+ ],
+ "vertebrae T11": [
+ 34,
+ 39
+ ],
+ "vertebrae T10": [
+ 35,
+ 40
+ ],
+ "vertebrae T9": [
+ 36,
+ 41
+ ],
+ "vertebrae T8": [
+ 37,
+ 42
+ ],
+ "vertebrae T7": [
+ 38,
+ 43
+ ],
+ "vertebrae T6": [
+ 39,
+ 44
+ ],
+ "vertebrae T5": [
+ 40,
+ 45
+ ],
+ "vertebrae T4": [
+ 41,
+ 46
+ ],
+ "vertebrae T3": [
+ 42,
+ 47
+ ],
+ "vertebrae T2": [
+ 43,
+ 48
+ ],
+ "vertebrae T1": [
+ 44,
+ 49
+ ],
+ "vertebrae C7": [
+ 45,
+ 50
+ ],
+ "vertebrae C6": [
+ 46,
+ 51
+ ],
+ "vertebrae C5": [
+ 47,
+ 52
+ ],
+ "vertebrae C4": [
+ 48,
+ 53
+ ],
+ "vertebrae C3": [
+ 49,
+ 54
+ ],
+ "vertebrae C2": [
+ 50,
+ 55
+ ],
+ "vertebrae C1": [
+ 51,
+ 56
+ ],
+ "trachea": [
+ 52,
+ 57
+ ],
+ "left iliac artery": [
+ 53,
+ 58
+ ],
+ "right iliac artery": [
+ 54,
+ 59
+ ],
+ "left iliac vena": [
+ 55,
+ 60
+ ],
+ "right iliac vena": [
+ 56,
+ 61
+ ],
+ "colon": [
+ 57,
+ 62
+ ],
+ "left rib 1": [
+ 58,
+ 63
+ ],
+ "left rib 2": [
+ 59,
+ 64
+ ],
+ "left rib 3": [
+ 60,
+ 65
+ ],
+ "left rib 4": [
+ 61,
+ 66
+ ],
+ "left rib 5": [
+ 62,
+ 67
+ ],
+ "left rib 6": [
+ 63,
+ 68
+ ],
+ "left rib 7": [
+ 64,
+ 69
+ ],
+ "left rib 8": [
+ 65,
+ 70
+ ],
+ "left rib 9": [
+ 66,
+ 71
+ ],
+ "left rib 10": [
+ 67,
+ 72
+ ],
+ "left rib 11": [
+ 68,
+ 73
+ ],
+ "left rib 12": [
+ 69,
+ 74
+ ],
+ "right rib 1": [
+ 70,
+ 75
+ ],
+ "right rib 2": [
+ 71,
+ 76
+ ],
+ "right rib 3": [
+ 72,
+ 77
+ ],
+ "right rib 4": [
+ 73,
+ 78
+ ],
+ "right rib 5": [
+ 74,
+ 79
+ ],
+ "right rib 6": [
+ 75,
+ 80
+ ],
+ "right rib 7": [
+ 76,
+ 81
+ ],
+ "right rib 8": [
+ 77,
+ 82
+ ],
+ "right rib 9": [
+ 78,
+ 83
+ ],
+ "right rib 10": [
+ 79,
+ 84
+ ],
+ "right rib 11": [
+ 80,
+ 85
+ ],
+ "right rib 12": [
+ 81,
+ 86
+ ],
+ "left humerus": [
+ 82,
+ 87
+ ],
+ "right humerus": [
+ 83,
+ 88
+ ],
+ "left scapula": [
+ 84,
+ 89
+ ],
+ "right scapula": [
+ 85,
+ 90
+ ],
+ "left clavicula": [
+ 86,
+ 91
+ ],
+ "right clavicula": [
+ 87,
+ 92
+ ],
+ "left femur": [
+ 88,
+ 93
+ ],
+ "right femur": [
+ 89,
+ 94
+ ],
+ "left hip": [
+ 90,
+ 95
+ ],
+ "right hip": [
+ 91,
+ 96
+ ],
+ "sacrum": [
+ 92,
+ 97
+ ],
+ "left gluteus maximus": [
+ 93,
+ 98
+ ],
+ "right gluteus maximus": [
+ 94,
+ 99
+ ],
+ "left gluteus medius": [
+ 95,
+ 100
+ ],
+ "right gluteus medius": [
+ 96,
+ 101
+ ],
+ "left gluteus minimus": [
+ 97,
+ 102
+ ],
+ "right gluteus minimus": [
+ 98,
+ 103
+ ],
+ "left autochthon": [
+ 99,
+ 104
+ ],
+ "right autochthon": [
+ 100,
+ 105
+ ],
+ "left iliopsoas": [
+ 101,
+ 106
+ ],
+ "right iliopsoas": [
+ 102,
+ 107
+ ],
+ "left atrial appendage": [
+ 103,
+ 108
+ ],
+ "brachiocephalic trunk": [
+ 104,
+ 109
+ ],
+ "left brachiocephalic vein": [
+ 105,
+ 110
+ ],
+ "right brachiocephalic vein": [
+ 106,
+ 111
+ ],
+ "left common carotid artery": [
+ 107,
+ 112
+ ],
+ "right common carotid artery": [
+ 108,
+ 113
+ ],
+ "costal cartilages": [
+ 109,
+ 114
+ ],
+ "heart": [
+ 110,
+ 115
+ ],
+ "prostate": [
+ 111,
+ 118
+ ],
+ "pulmonary vein": [
+ 112,
+ 119
+ ],
+ "skull": [
+ 113,
+ 120
+ ],
+ "spinal cord": [
+ 114,
+ 121
+ ],
+ "sternum": [
+ 115,
+ 122
+ ],
+ "left subclavian artery": [
+ 116,
+ 123
+ ],
+ "right subclavian artery": [
+ 117,
+ 124
+ ],
+ "superior vena cava": [
+ 118,
+ 125
+ ],
+ "thyroid gland": [
+ 119,
+ 126
+ ],
+ "vertebrae S1": [
+ 120,
+ 127
+ ],
+ "bone lesion": [
+ 121,
+ 128
+ ],
+ "kidney mass": [
+ 122,
+ 129
+ ],
+ "airway": [
+ 123,
+ 132
+ ],
+ "body": [
+ 124,
+ 200
+ ]
+ }
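Each entry pairs two integer IDs per class; judging by the file name, the first is the index in the 124-class label space and the second the corresponding index in the 132-class space used elsewhere in this bundle (note the non-contiguous targets, e.g. `"body"` maps to 200). A hedged sketch of applying such a mapping to an integer segmentation mask via a lookup table:

```python
import json

import numpy as np

with open("configs/label_dict_124_to_132.json") as f:
    remap = json.load(f)


def remap_labels(mask: np.ndarray) -> np.ndarray:
    """Rewrite a mask from the 124-class IDs to the 132-class IDs.
    Assumes every voxel value in `mask` appears as a [src, dst] pair above."""
    lut = np.zeros(max(src for src, _ in remap.values()) + 1, dtype=np.int64)
    for src, dst in remap.values():
        lut[src] = dst
    return lut[mask]  # mask must be an integer array
```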
configs/logging.conf ADDED
@@ -0,0 +1,21 @@
+ [loggers]
+ keys=root
+
+ [handlers]
+ keys=consoleHandler
+
+ [formatters]
+ keys=fullFormatter
+
+ [logger_root]
+ level=INFO
+ handlers=consoleHandler
+
+ [handler_consoleHandler]
+ class=StreamHandler
+ level=INFO
+ formatter=fullFormatter
+ args=(sys.stdout,)
+
+ [formatter_fullFormatter]
+ format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
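This is a standard `logging.config` INI file; the bundle runner applies it when a `logging_file` argument is supplied (as in the sketches above). Outside the runner it can also be loaded directly with the standard library:

```python
import logging
import logging.config

# Installs the consoleHandler/fullFormatter configuration defined above
# on the root logger (INFO level, stdout).
logging.config.fileConfig("configs/logging.conf")
logging.getLogger(__name__).info("logging configured")
```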
configs/metadata.json ADDED
@@ -0,0 +1,269 @@
+ {
+ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_generator_ldm_20240318.json",
+ "version": "1.0.0",
+ "changelog": {
+ "1.0.0": "accelerated maisi, inference only, is not compartible with previous maisi diffusion model weights",
+ "0.4.6": "add TensorRT support",
+ "0.4.5": "update README",
+ "0.4.4": "update issue for IgniteInfo",
+ "0.4.3": "remove download large files, add weights_only when loading weights and add label_dict to large files",
+ "0.4.2": "update train.json to fix finetune ckpt bug",
+ "0.4.1": "update large files",
+ "0.4.0": "update to use monai 1.4, model ckpt updated, rm GenerativeAI repo, add quality check",
+ "0.3.6": "first oss version"
+ },
+ "monai_version": "1.4.0",
+ "pytorch_version": "2.4.0",
+ "numpy_version": "1.24.4",
+ "optional_packages_version": {
+ "fire": "0.6.0",
+ "nibabel": "5.2.1",
+ "tqdm": "4.66.4"
+ },
+ "supported_apps": {
+ "maisi-nim": ""
+ },
+ "name": "CT image latent diffusion generation",
+ "task": "CT image synthesis",
+ "description": "A generative model for creating 3D CT from Gaussian noise",
+ "authors": "MONAI team",
+ "copyright": "Copyright (c) MONAI Consortium",
+ "data_source": "http://medicaldecathlon.com/",
+ "data_type": "nibabel",
+ "image_classes": "Flair brain MRI with 1.1x1.1x1.1 mm voxel size",
+ "eval_metrics": {},
+ "intended_use": "This is a research tool/prototype and not to be used clinically",
+ "references": [],
+ "autoencoder_data_format": {
+ "inputs": {
+ "image": {
+ "type": "feature",
+ "format": "image",
+ "num_channels": 4,
+ "spatial_shape": [
+ 128,
+ 128,
+ 128
+ ],
+ "dtype": "float16",
+ "value_range": [
+ 0,
+ 1
+ ],
+ "is_patch_data": true
+ },
+ "body_region": {
+ "type": "array",
+ "value_range": [
+ "head",
+ "abdomen",
+ "chest/thorax",
+ "pelvis/lower"
+ ]
+ },
+ "anatomy_list": {
+ "type": "array",
+ "value_range": [
+ "liver",
+ "spleen",
+ "pancreas",
+ "right kidney",
+ "aorta",
+ "inferior vena cava",
+ "right adrenal gland",
+ "left adrenal gland",
+ "gallbladder",
+ "esophagus",
+ "stomach",
+ "duodenum",
+ "left kidney",
+ "bladder",
+ "portal vein and splenic vein",
+ "small bowel",
+ "brain",
+ "lung tumor",
+ "pancreatic tumor",
+ "hepatic vessel",
+ "hepatic tumor",
+ "colon cancer primaries",
+ "left lung upper lobe",
+ "left lung lower lobe",
+ "right lung upper lobe",
+ "right lung middle lobe",
+ "right lung lower lobe",
+ "vertebrae L5",
+ "vertebrae L4",
+ "vertebrae L3",
+ "vertebrae L2",
+ "vertebrae L1",
+ "vertebrae T12",
+ "vertebrae T11",
+ "vertebrae T10",
+ "vertebrae T9",
+ "vertebrae T8",
+ "vertebrae T7",
+ "vertebrae T6",
+ "vertebrae T5",
+ "vertebrae T4",
+ "vertebrae T3",
+ "vertebrae T2",
+ "vertebrae T1",
+ "vertebrae C7",
+ "vertebrae C6",
+ "vertebrae C5",
+ "vertebrae C4",
+ "vertebrae C3",
+ "vertebrae C2",
+ "vertebrae C1",
+ "trachea",
+ "left iliac artery",
+ "right iliac artery",
+ "left iliac vena",
+ "right iliac vena",
+ "colon",
+ "left rib 1",
+ "left rib 2",
+ "left rib 3",
+ "left rib 4",
+ "left rib 5",
+ "left rib 6",
+ "left rib 7",
+ "left rib 8",
+ "left rib 9",
+ "left rib 10",
+ "left rib 11",
+ "left rib 12",
+ "right rib 1",
+ "right rib 2",
+ "right rib 3",
+ "right rib 4",
+ "right rib 5",
+ "right rib 6",
+ "right rib 7",
+ "right rib 8",
+ "right rib 9",
+ "right rib 10",
+ "right rib 11",
+ "right rib 12",
+ "left humerus",
+ "right humerus",
+ "left scapula",
+ "right scapula",
+ "left clavicula",
+ "right clavicula",
+ "left femur",
+ "right femur",
+ "left hip",
+ "right hip",
+ "sacrum",
+ "left gluteus maximus",
+ "right gluteus maximus",
+ "left gluteus medius",
+ "right gluteus medius",
+ "left gluteus minimus",
+ "right gluteus minimus",
+ "left autochthon",
+ "right autochthon",
+ "left iliopsoas",
+ "right iliopsoas",
+ "left atrial appendage",
+ "brachiocephalic trunk",
+ "left brachiocephalic vein",
+ "right brachiocephalic vein",
+ "left common carotid artery",
+ "right common carotid artery",
+ "costal cartilages",
+ "heart",
+ "left kidney cyst",
+ "right kidney cyst",
+ "prostate",
+ "pulmonary vein",
+ "skull",
+ "spinal cord",
+ "sternum",
+ "left subclavian artery",
+ "right subclavian artery",
+ "superior vena cava",
+ "thyroid gland",
+ "vertebrae S1",
+ "bone lesion",
+ "airway"
+ ]
+ }
+ },
+ "outputs": {
+ "pred": {
+ "type": "image",
+ "format": "image",
+ "num_channels": 1,
+ "spatial_shape": [
+ 512,
+ 512,
+ 512
+ ],
+ "dtype": "float16",
+ "value_range": [
+ 0,
+ 1
+ ],
+ "is_patch_data": true,
+ "channel_def": {
+ "0": "image"
+ }
+ }
+ }
+ },
+ "generator_data_format": {
+ "inputs": {
+ "latent": {
+ "type": "noise",
+ "format": "image",
+ "num_channels": 4,
+ "spatial_shape": [
+ 128,
+ 128,
+ 128
+ ],
+ "dtype": "float16",
+ "value_range": [
+ 0,
+ 1
+ ],
+ "is_patch_data": true
+ },
+ "condition": {
+ "type": "timesteps",
+ "format": "timesteps",
+ "num_channels": 1,
+ "spatial_shape": [],
+ "dtype": "long",
+ "value_range": [
+ 0,
+ 1000
+ ],
+ "is_patch_data": false
+ }
+ },
+ "outputs": {
+ "pred": {
+ "type": "feature",
+ "format": "image",
+ "num_channels": 4,
+ "spatial_shape": [
+ 128,
+ 128,
+ 128
+ ],
+ "dtype": "float16",
+ "value_range": [
+ 0,
+ 1
+ ],
+ "is_patch_data": true,
+ "channel_def": {
+ "0": "image"
+ }
+ }
+ }
+ }
+ }
configs/multi_gpu_train.json ADDED
@@ -0,0 +1,34 @@
+ {
+ "device": "$torch.device('cuda:' + os.environ['LOCAL_RANK'])",
+ "use_tensorboard": "$dist.get_rank() == 0",
+ "controlnet": {
+ "_target_": "torch.nn.parallel.DistributedDataParallel",
+ "module": "$@controlnet_def.to(@device)",
+ "find_unused_parameters": true,
+ "device_ids": [
+ "@device"
+ ]
+ },
+ "load_controlnet": "[email protected]_state_dict(@checkpoint_controlnet['controlnet_state_dict'], strict=True)",
+ "train#sampler": {
+ "_target_": "DistributedSampler",
+ "dataset": "@train#dataset",
+ "even_divisible": true,
+ "shuffle": true
+ },
+ "train#dataloader#sampler": "@train#sampler",
+ "train#dataloader#shuffle": false,
+ "train#trainer#train_handlers": "$@train#handlers[: -1 if dist.get_rank() > 0 else None]",
+ "initialize": [
+ "$import torch.distributed as dist",
+ "$dist.is_initialized() or dist.init_process_group(backend='nccl')",
+ "$torch.cuda.set_device(@device)",
+ "$monai.utils.set_determinism(seed=123)"
+ ],
+ "run": [
+ "$@train#trainer.run()"
+ ],
+ "finalize": [
+ "$dist.is_initialized() and dist.destroy_process_group()"
+ ]
+ }
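This override file expects one process per GPU (it reads `LOCAL_RANK` and initializes the NCCL process group in `initialize`), so it is layered over `train.json` below and started with a distributed launcher. A minimal sketch, assuming a single node and a hypothetical `train_dist.py` wrapper launched via `torchrun --nproc_per_node=<num_gpus> train_dist.py`:

```python
# train_dist.py -- hypothetical launcher; torchrun sets LOCAL_RANK per process.
from monai.bundle import run

# multi_gpu_train.json is listed last so its DDP-wrapped controlnet, per-rank
# device, and DistributedSampler override the single-GPU entries in train.json.
run(
    config_file=["configs/train.json", "configs/multi_gpu_train.json"],
    logging_file="configs/logging.conf",
    meta_file="configs/metadata.json",
)
```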
configs/train.json ADDED
@@ -0,0 +1,271 @@
+ {
+ "imports": [
+ "$import glob",
+ "$import os",
+ "$import scripts",
+ "$import ignite"
+ ],
+ "bundle_root": ".",
+ "ckpt_dir": "$@bundle_root + '/models'",
+ "output_dir": "$@bundle_root + '/output'",
+ "data_list_file_path": "$@bundle_root + '/datasets/C4KC-KiTS_subset.json'",
+ "dataset_dir": "$@bundle_root + '/datasets/C4KC-KiTS_subset'",
+ "trained_diffusion_path": "$@ckpt_dir + '/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt'",
+ "trained_controlnet_path": "$@ckpt_dir + '/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt'",
+ "use_tensorboard": true,
+ "fold": 0,
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
+ "epochs": 100,
+ "batch_size": 1,
+ "val_at_start": false,
+ "learning_rate": 0.0001,
+ "weighted_loss_label": [
+ 129
+ ],
+ "weighted_loss": 100,
+ "amp": true,
+ "train_datalist": "$scripts.utils.maisi_datafold_read(json_list=@data_list_file_path, data_base_dir=@dataset_dir, fold=@fold)[0]",
+ "spatial_dims": 3,
+ "image_channels": 1,
+ "latent_channels": 4,
+ "diffusion_unet_def": {
+ "_target_": "monai.apps.generation.maisi.networks.diffusion_model_unet_maisi.DiffusionModelUNetMaisi",
+ "spatial_dims": "@spatial_dims",
+ "in_channels": "@latent_channels",
+ "out_channels": "@latent_channels",
+ "num_channels": [
+ 64,
+ 128,
+ 256,
+ 512
+ ],
+ "attention_levels": [
+ false,
+ false,
+ true,
+ true
+ ],
+ "num_head_channels": [
+ 0,
+ 0,
+ 32,
+ 32
+ ],
+ "num_res_blocks": 2,
+ "use_flash_attention": true,
+ "include_top_region_index_input": true,
+ "include_bottom_region_index_input": true,
+ "include_spacing_input": true
+ },
+ "controlnet_def": {
+ "_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi",
+ "spatial_dims": "@spatial_dims",
+ "in_channels": "@latent_channels",
+ "num_channels": [
+ 64,
+ 128,
+ 256,
+ 512
+ ],
+ "attention_levels": [
+ false,
+ false,
+ true,
+ true
+ ],
+ "num_head_channels": [
+ 0,
+ 0,
+ 32,
+ 32
+ ],
+ "num_res_blocks": 2,
+ "use_flash_attention": true,
+ "conditioning_embedding_in_channels": 8,
+ "conditioning_embedding_num_channels": [
+ 8,
+ 32,
+ 64
+ ]
+ },
+ "noise_scheduler": {
+ "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
+ "num_train_timesteps": 1000,
+ "beta_start": 0.0015,
+ "beta_end": 0.0195,
+ "schedule": "scaled_linear_beta",
+ "clip_sample": false
+ },
+ "unzip_dataset": "$scripts.utils.unzip_dataset(@dataset_dir)",
+ "diffusion_unet": "$@diffusion_unet_def.to(@device)",
+ "checkpoint_diffusion_unet": "$torch.load(@trained_diffusion_path, weights_only=False)",
+ "load_diffusion": "$@diffusion_unet.load_state_dict(@checkpoint_diffusion_unet['unet_state_dict'])",
+ "controlnet": "$@controlnet_def.to(@device)",
+ "copy_controlnet_state": "$monai.networks.utils.copy_model_state(@controlnet, @diffusion_unet.state_dict())",
+ "checkpoint_controlnet": "$torch.load(@trained_controlnet_path, weights_only=False)",
+ "load_controlnet": "[email protected]_state_dict(@checkpoint_controlnet['controlnet_state_dict'], strict=True)",
+ "scale_factor": "$@checkpoint_diffusion_unet['scale_factor'].to(@device)",
+ "loss": {
+ "_target_": "torch.nn.L1Loss",
+ "reduction": "none"
+ },
+ "optimizer": {
+ "_target_": "torch.optim.AdamW",
+ "params": "[email protected]()",
+ "lr": "@learning_rate",
+ "weight_decay": 1e-05
+ },
+ "lr_schedule": {
+ "activate": true,
+ "lr_scheduler": {
+ "_target_": "torch.optim.lr_scheduler.PolynomialLR",
+ "optimizer": "@optimizer",
+ "total_iters": "$(@epochs * len(@train#dataloader.dataset)) / @batch_size",
+ "power": 2.0
+ }
+ },
+ "train": {
+ "deterministic_transforms": [
+ {
+ "_target_": "LoadImaged",
+ "keys": [
+ "image",
+ "label"
+ ],
+ "image_only": true,
+ "ensure_channel_first": true
+ },
+ {
+ "_target_": "Orientationd",
+ "keys": [
+ "label"
+ ],
+ "axcodes": "RAS"
+ },
+ {
+ "_target_": "EnsureTyped",
+ "keys": [
+ "label"
+ ],
+ "dtype": "$torch.uint8",
+ "track_meta": true
+ },
+ {
+ "_target_": "Lambdad",
+ "keys": "top_region_index",
+ "func": "$lambda x: torch.FloatTensor(x)"
+ },
+ {
+ "_target_": "Lambdad",
+ "keys": "bottom_region_index",
+ "func": "$lambda x: torch.FloatTensor(x)"
+ },
+ {
+ "_target_": "Lambdad",
+ "keys": "spacing",
+ "func": "$lambda x: torch.FloatTensor(x)"
+ },
+ {
+ "_target_": "Lambdad",
+ "keys": "top_region_index",
+ "func": "$lambda x: x * 1e2"
+ },
+ {
+ "_target_": "Lambdad",
+ "keys": "bottom_region_index",
+ "func": "$lambda x: x * 1e2"
+ },
+ {
+ "_target_": "Lambdad",
+ "keys": "spacing",
+ "func": "$lambda x: x * 1e2"
+ }
+ ],
+ "inferer": {
+ "_target_": "SimpleInferer"
+ },
+ "preprocessing": {
+ "_target_": "Compose",
+ "transforms": "$@train#deterministic_transforms"
+ },
+ "dataset": {
+ "_target_": "Dataset",
+ "data": "@train_datalist",
+ "transform": "@train#preprocessing"
+ },
+ "dataloader": {
+ "_target_": "DataLoader",
+ "dataset": "@train#dataset",
+ "batch_size": "@batch_size",
+ "shuffle": true,
+ "num_workers": 4,
+ "pin_memory": true,
+ "persistent_workers": true
+ },
+ "handlers": [
+ {
+ "_target_": "LrScheduleHandler",
+ "_disabled_": "$not @lr_schedule#activate",
+ "lr_scheduler": "@lr_schedule#lr_scheduler",
+ "epoch_level": false,
+ "print_lr": true
+ },
+ {
+ "_target_": "CheckpointSaver",
+ "save_dir": "@ckpt_dir",
+ "save_dict": {
+ "controlnet_state_dict": "@controlnet",
+ "optimizer": "@optimizer"
+ },
+ "save_interval": 1,
+ "n_saved": 5
+ },
+ {
+ "_target_": "TensorBoardStatsHandler",
+ "_disabled_": "$not @use_tensorboard",
+ "log_dir": "@output_dir",
+ "tag_name": "train_loss",
+ "output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
+ },
+ {
+ "_target_": "StatsHandler",
+ "tag_name": "train_loss",
+ "name": "StatsHandler",
+ "output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
+ }
+ ],
+ "trainer": {
+ "_target_": "scripts.trainer.MAISIControlNetTrainer",
+ "_requires_": [
+ "@load_diffusion",
+ "@copy_controlnet_state",
+ "@load_controlnet",
+ "@unzip_dataset"
+ ],
+ "max_epochs": "@epochs",
+ "device": "@device",
+ "train_data_loader": "@train#dataloader",
+ "diffusion_unet": "@diffusion_unet",
+ "controlnet": "@controlnet",
+ "noise_scheduler": "@noise_scheduler",
+ "loss_function": "@loss",
+ "optimizer": "@optimizer",
+ "inferer": "@train#inferer",
+ "key_train_metric": null,
+ "train_handlers": "@train#handlers",
+ "amp": "@amp",
+ "hyper_kwargs": {
+ "weighted_loss": "@weighted_loss",
+ "weighted_loss_label": "@weighted_loss_label",
260
+ "scale_factor": "@scale_factor"
261
+ }
262
+ }
263
+ },
264
+ "initialize": [
265
+ "$monai.utils.set_determinism(seed=0)"
266
+ ],
267
+ "run": [
268
+ "$@train#trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, ignite.handlers.TerminateOnNan())",
269
+ "$@train#trainer.run()"
270
+ ]
271
+ }
datasets/C4KC-KiTS_subset.json ADDED
@@ -0,0 +1,814 @@
1
+ {
2
+ "training": [
3
+ {
4
+ "image": "KiTS-00186/2_arterial_emb_zs99.nii.gz",
5
+ "label": "KiTS-00186/mask_combined_label_zs99_wbdm.nii.gz",
6
+ "fold": 0,
7
+ "dim": [
8
+ 512,
9
+ 512,
10
+ 512
11
+ ],
12
+ "spacing": [
13
+ 1.0,
14
+ 1.0,
15
+ 1.0
16
+ ],
17
+ "top_region_index": [
18
+ 0,
19
+ 1,
20
+ 0,
21
+ 0
22
+ ],
23
+ "bottom_region_index": [
24
+ 0,
25
+ 0,
26
+ 0,
27
+ 1
28
+ ]
29
+ },
30
+ {
31
+ "image": "KiTS-00066/3_arterial_emb_zs11.nii.gz",
32
+ "label": "KiTS-00066/mask_combined_label_zs11_wbdm.nii.gz",
33
+ "fold": 0,
34
+ "dim": [
35
+ 512,
36
+ 512,
37
+ 512
38
+ ],
39
+ "spacing": [
40
+ 1.0,
41
+ 1.0,
42
+ 1.0
43
+ ],
44
+ "top_region_index": [
45
+ 0,
46
+ 1,
47
+ 0,
48
+ 0
49
+ ],
50
+ "bottom_region_index": [
51
+ 0,
52
+ 0,
53
+ 0,
54
+ 1
55
+ ]
56
+ },
57
+ {
58
+ "image": "KiTS-00012/2_arterial_emb_zs0.nii.gz",
59
+ "label": "KiTS-00012/mask_combined_label_zs0_wbdm.nii.gz",
60
+ "fold": 0,
61
+ "dim": [
62
+ 512,
63
+ 512,
64
+ 512
65
+ ],
66
+ "spacing": [
67
+ 1.0,
68
+ 1.0,
69
+ 1.0
70
+ ],
71
+ "top_region_index": [
72
+ 0,
73
+ 1,
74
+ 0,
75
+ 0
76
+ ],
77
+ "bottom_region_index": [
78
+ 0,
79
+ 0,
80
+ 0,
81
+ 1
82
+ ]
83
+ },
84
+ {
85
+ "image": "KiTS-00055/2_arterial_emb_zs0.nii.gz",
86
+ "label": "KiTS-00055/mask_combined_label_zs0_wbdm.nii.gz",
87
+ "fold": 0,
88
+ "dim": [
89
+ 512,
90
+ 512,
91
+ 512
92
+ ],
93
+ "spacing": [
94
+ 1.0,
95
+ 1.0,
96
+ 1.0
97
+ ],
98
+ "top_region_index": [
99
+ 0,
100
+ 1,
101
+ 0,
102
+ 0
103
+ ],
104
+ "bottom_region_index": [
105
+ 0,
106
+ 0,
107
+ 0,
108
+ 1
109
+ ]
110
+ },
111
+ {
112
+ "image": "KiTS-00193/100_arterial_emb_zs0.nii.gz",
113
+ "label": "KiTS-00193/mask_combined_label_zs0_wbdm.nii.gz",
114
+ "fold": 0,
115
+ "dim": [
116
+ 512,
117
+ 512,
118
+ 512
119
+ ],
120
+ "spacing": [
121
+ 1.0,
122
+ 1.0,
123
+ 1.0
124
+ ],
125
+ "top_region_index": [
126
+ 0,
127
+ 1,
128
+ 0,
129
+ 0
130
+ ],
131
+ "bottom_region_index": [
132
+ 0,
133
+ 0,
134
+ 1,
135
+ 0
136
+ ]
137
+ },
138
+ {
139
+ "image": "KiTS-00142/7_arterial_emb_zs0.nii.gz",
140
+ "label": "KiTS-00142/mask_combined_label_zs0_wbdm.nii.gz",
141
+ "fold": 0,
142
+ "dim": [
143
+ 512,
144
+ 512,
145
+ 512
146
+ ],
147
+ "spacing": [
148
+ 1.0,
149
+ 1.0,
150
+ 1.0
151
+ ],
152
+ "top_region_index": [
153
+ 0,
154
+ 1,
155
+ 0,
156
+ 0
157
+ ],
158
+ "bottom_region_index": [
159
+ 0,
160
+ 0,
161
+ 1,
162
+ 0
163
+ ]
164
+ },
165
+ {
166
+ "image": "KiTS-00069/3_arterial_emb_zs0.nii.gz",
167
+ "label": "KiTS-00069/mask_combined_label_zs0_wbdm.nii.gz",
168
+ "fold": 0,
169
+ "dim": [
170
+ 512,
171
+ 512,
172
+ 512
173
+ ],
174
+ "spacing": [
175
+ 1.0,
176
+ 1.0,
177
+ 1.0
178
+ ],
179
+ "top_region_index": [
180
+ 0,
181
+ 1,
182
+ 0,
183
+ 0
184
+ ],
185
+ "bottom_region_index": [
186
+ 0,
187
+ 0,
188
+ 0,
189
+ 1
190
+ ]
191
+ },
192
+ {
193
+ "image": "KiTS-00124/2_arterial_emb_zs0.nii.gz",
194
+ "label": "KiTS-00124/mask_combined_label_zs0_wbdm.nii.gz",
195
+ "fold": 0,
196
+ "dim": [
197
+ 512,
198
+ 512,
199
+ 512
200
+ ],
201
+ "spacing": [
202
+ 1.0,
203
+ 1.0,
204
+ 1.0
205
+ ],
206
+ "top_region_index": [
207
+ 0,
208
+ 1,
209
+ 0,
210
+ 0
211
+ ],
212
+ "bottom_region_index": [
213
+ 0,
214
+ 0,
215
+ 0,
216
+ 1
217
+ ]
218
+ },
219
+ {
220
+ "image": "KiTS-00208/2_arterial_emb_zs0.nii.gz",
221
+ "label": "KiTS-00208/mask_combined_label_zs0_wbdm.nii.gz",
222
+ "fold": 0,
223
+ "dim": [
224
+ 512,
225
+ 512,
226
+ 512
227
+ ],
228
+ "spacing": [
229
+ 1.0,
230
+ 1.0,
231
+ 1.0
232
+ ],
233
+ "top_region_index": [
234
+ 0,
235
+ 1,
236
+ 0,
237
+ 0
238
+ ],
239
+ "bottom_region_index": [
240
+ 0,
241
+ 0,
242
+ 0,
243
+ 1
244
+ ]
245
+ },
246
+ {
247
+ "image": "KiTS-00116/9_arterial_emb_zs0.nii.gz",
248
+ "label": "KiTS-00116/mask_combined_label_zs0_wbdm.nii.gz",
249
+ "fold": 0,
250
+ "dim": [
251
+ 512,
252
+ 512,
253
+ 512
254
+ ],
255
+ "spacing": [
256
+ 1.0,
257
+ 1.0,
258
+ 1.0
259
+ ],
260
+ "top_region_index": [
261
+ 0,
262
+ 1,
263
+ 0,
264
+ 0
265
+ ],
266
+ "bottom_region_index": [
267
+ 0,
268
+ 0,
269
+ 1,
270
+ 0
271
+ ]
272
+ },
273
+ {
274
+ "image": "KiTS-00061/4_arterial_emb_zs0.nii.gz",
275
+ "label": "KiTS-00061/mask_combined_label_zs0_wbdm.nii.gz",
276
+ "fold": 1,
277
+ "dim": [
278
+ 512,
279
+ 512,
280
+ 512
281
+ ],
282
+ "spacing": [
283
+ 1.0,
284
+ 1.0,
285
+ 1.0
286
+ ],
287
+ "top_region_index": [
288
+ 0,
289
+ 0,
290
+ 1,
291
+ 0
292
+ ],
293
+ "bottom_region_index": [
294
+ 0,
295
+ 0,
296
+ 1,
297
+ 0
298
+ ]
299
+ },
300
+ {
301
+ "image": "KiTS-00040/3_arterial_emb_zs0.nii.gz",
302
+ "label": "KiTS-00040/mask_combined_label_zs0_wbdm.nii.gz",
303
+ "fold": 1,
304
+ "dim": [
305
+ 512,
306
+ 512,
307
+ 512
308
+ ],
309
+ "spacing": [
310
+ 1.0,
311
+ 1.0,
312
+ 1.0
313
+ ],
314
+ "top_region_index": [
315
+ 0,
316
+ 1,
317
+ 0,
318
+ 0
319
+ ],
320
+ "bottom_region_index": [
321
+ 0,
322
+ 0,
323
+ 0,
324
+ 1
325
+ ]
326
+ },
327
+ {
328
+ "image": "KiTS-00068/7_arterial_emb_zs0.nii.gz",
329
+ "label": "KiTS-00068/mask_combined_label_zs0_wbdm.nii.gz",
330
+ "fold": 1,
331
+ "dim": [
332
+ 512,
333
+ 512,
334
+ 512
335
+ ],
336
+ "spacing": [
337
+ 1.0,
338
+ 1.0,
339
+ 1.0
340
+ ],
341
+ "top_region_index": [
342
+ 0,
343
+ 1,
344
+ 0,
345
+ 0
346
+ ],
347
+ "bottom_region_index": [
348
+ 0,
349
+ 0,
350
+ 1,
351
+ 0
352
+ ]
353
+ },
354
+ {
355
+ "image": "KiTS-00036/2_arterial_emb_zs0.nii.gz",
356
+ "label": "KiTS-00036/mask_combined_label_zs0_wbdm.nii.gz",
357
+ "fold": 1,
358
+ "dim": [
359
+ 512,
360
+ 512,
361
+ 512
362
+ ],
363
+ "spacing": [
364
+ 1.0,
365
+ 1.0,
366
+ 1.0
367
+ ],
368
+ "top_region_index": [
369
+ 0,
370
+ 1,
371
+ 0,
372
+ 0
373
+ ],
374
+ "bottom_region_index": [
375
+ 0,
376
+ 0,
377
+ 0,
378
+ 1
379
+ ]
380
+ },
381
+ {
382
+ "image": "KiTS-00153/8_arterial_emb_zs0.nii.gz",
383
+ "label": "KiTS-00153/mask_combined_label_zs0_wbdm.nii.gz",
384
+ "fold": 1,
385
+ "dim": [
386
+ 512,
387
+ 512,
388
+ 512
389
+ ],
390
+ "spacing": [
391
+ 1.0,
392
+ 1.0,
393
+ 1.0
394
+ ],
395
+ "top_region_index": [
396
+ 0,
397
+ 1,
398
+ 0,
399
+ 0
400
+ ],
401
+ "bottom_region_index": [
402
+ 0,
403
+ 0,
404
+ 1,
405
+ 0
406
+ ]
407
+ },
408
+ {
409
+ "image": "KiTS-00189/2_arterial_emb_zs107.nii.gz",
410
+ "label": "KiTS-00189/mask_combined_label_zs107_wbdm.nii.gz",
411
+ "fold": 1,
412
+ "dim": [
413
+ 512,
414
+ 512,
415
+ 512
416
+ ],
417
+ "spacing": [
418
+ 1.0,
419
+ 1.0,
420
+ 1.0
421
+ ],
422
+ "top_region_index": [
423
+ 0,
424
+ 1,
425
+ 0,
426
+ 0
427
+ ],
428
+ "bottom_region_index": [
429
+ 0,
430
+ 0,
431
+ 0,
432
+ 1
433
+ ]
434
+ },
435
+ {
436
+ "image": "KiTS-00091/7_arterial_emb_zs0.nii.gz",
437
+ "label": "KiTS-00091/mask_combined_label_zs0_wbdm.nii.gz",
438
+ "fold": 1,
439
+ "dim": [
440
+ 512,
441
+ 512,
442
+ 512
443
+ ],
444
+ "spacing": [
445
+ 1.0,
446
+ 1.0,
447
+ 1.0
448
+ ],
449
+ "top_region_index": [
450
+ 0,
451
+ 1,
452
+ 0,
453
+ 0
454
+ ],
455
+ "bottom_region_index": [
456
+ 0,
457
+ 0,
458
+ 1,
459
+ 0
460
+ ]
461
+ },
462
+ {
463
+ "image": "KiTS-00110/3_arterial_emb_zs0.nii.gz",
464
+ "label": "KiTS-00110/mask_combined_label_zs0_wbdm.nii.gz",
465
+ "fold": 1,
466
+ "dim": [
467
+ 512,
468
+ 512,
469
+ 512
470
+ ],
471
+ "spacing": [
472
+ 1.0,
473
+ 1.0,
474
+ 1.0
475
+ ],
476
+ "top_region_index": [
477
+ 0,
478
+ 1,
479
+ 0,
480
+ 0
481
+ ],
482
+ "bottom_region_index": [
483
+ 0,
484
+ 0,
485
+ 1,
486
+ 0
487
+ ]
488
+ },
489
+ {
490
+ "image": "KiTS-00046/2_arterial_emb_zs0.nii.gz",
491
+ "label": "KiTS-00046/mask_combined_label_zs0_wbdm.nii.gz",
492
+ "fold": 1,
493
+ "dim": [
494
+ 512,
495
+ 512,
496
+ 512
497
+ ],
498
+ "spacing": [
499
+ 1.0,
500
+ 1.0,
501
+ 1.0
502
+ ],
503
+ "top_region_index": [
504
+ 0,
505
+ 1,
506
+ 0,
507
+ 0
508
+ ],
509
+ "bottom_region_index": [
510
+ 0,
511
+ 0,
512
+ 0,
513
+ 1
514
+ ]
515
+ },
516
+ {
517
+ "image": "KiTS-00178/3_arterial_emb_zs0.nii.gz",
518
+ "label": "KiTS-00178/mask_combined_label_zs0_wbdm.nii.gz",
519
+ "fold": 1,
520
+ "dim": [
521
+ 512,
522
+ 512,
523
+ 512
524
+ ],
525
+ "spacing": [
526
+ 1.0,
527
+ 1.0,
528
+ 1.0
529
+ ],
530
+ "top_region_index": [
531
+ 0,
532
+ 1,
533
+ 0,
534
+ 0
535
+ ],
536
+ "bottom_region_index": [
537
+ 0,
538
+ 0,
539
+ 0,
540
+ 1
541
+ ]
542
+ },
543
+ {
544
+ "image": "KiTS-00075/2_arterial_emb_zs0.nii.gz",
545
+ "label": "KiTS-00075/mask_combined_label_zs0_wbdm.nii.gz",
546
+ "fold": 1,
547
+ "dim": [
548
+ 512,
549
+ 512,
550
+ 512
551
+ ],
552
+ "spacing": [
553
+ 1.0,
554
+ 1.0,
555
+ 1.0
556
+ ],
557
+ "top_region_index": [
558
+ 0,
559
+ 1,
560
+ 0,
561
+ 0
562
+ ],
563
+ "bottom_region_index": [
564
+ 0,
565
+ 0,
566
+ 0,
567
+ 1
568
+ ]
569
+ },
570
+ {
571
+ "image": "KiTS-00037/6_arterial_emb_zs0.nii.gz",
572
+ "label": "KiTS-00037/mask_combined_label_zs0_wbdm.nii.gz",
573
+ "fold": 1,
574
+ "dim": [
575
+ 512,
576
+ 512,
577
+ 512
578
+ ],
579
+ "spacing": [
580
+ 1.0,
581
+ 1.0,
582
+ 1.0
583
+ ],
584
+ "top_region_index": [
585
+ 0,
586
+ 1,
587
+ 0,
588
+ 0
589
+ ],
590
+ "bottom_region_index": [
591
+ 0,
592
+ 0,
593
+ 0,
594
+ 1
595
+ ]
596
+ },
597
+ {
598
+ "image": "KiTS-00130/9_arterial_emb_zs0.nii.gz",
599
+ "label": "KiTS-00130/mask_combined_label_zs0_wbdm.nii.gz",
600
+ "fold": 1,
601
+ "dim": [
602
+ 512,
603
+ 512,
604
+ 512
605
+ ],
606
+ "spacing": [
607
+ 1.0,
608
+ 1.0,
609
+ 1.0
610
+ ],
611
+ "top_region_index": [
612
+ 0,
613
+ 1,
614
+ 0,
615
+ 0
616
+ ],
617
+ "bottom_region_index": [
618
+ 0,
619
+ 0,
620
+ 1,
621
+ 0
622
+ ]
623
+ },
624
+ {
625
+ "image": "KiTS-00063/6_arterial_emb_zs0.nii.gz",
626
+ "label": "KiTS-00063/mask_combined_label_zs0_wbdm.nii.gz",
627
+ "fold": 1,
628
+ "dim": [
629
+ 512,
630
+ 512,
631
+ 512
632
+ ],
633
+ "spacing": [
634
+ 1.0,
635
+ 1.0,
636
+ 1.0
637
+ ],
638
+ "top_region_index": [
639
+ 0,
640
+ 1,
641
+ 0,
642
+ 0
643
+ ],
644
+ "bottom_region_index": [
645
+ 0,
646
+ 0,
647
+ 1,
648
+ 0
649
+ ]
650
+ },
651
+ {
652
+ "image": "KiTS-00205/4_arterial_emb_zs0.nii.gz",
653
+ "label": "KiTS-00205/mask_combined_label_zs0_wbdm.nii.gz",
654
+ "fold": 1,
655
+ "dim": [
656
+ 512,
657
+ 512,
658
+ 512
659
+ ],
660
+ "spacing": [
661
+ 1.0,
662
+ 1.0,
663
+ 1.0
664
+ ],
665
+ "top_region_index": [
666
+ 0,
667
+ 1,
668
+ 0,
669
+ 0
670
+ ],
671
+ "bottom_region_index": [
672
+ 0,
673
+ 0,
674
+ 0,
675
+ 1
676
+ ]
677
+ },
678
+ {
679
+ "image": "KiTS-00167/2_arterial_emb_zs0.nii.gz",
680
+ "label": "KiTS-00167/mask_combined_label_zs0_wbdm.nii.gz",
681
+ "fold": 1,
682
+ "dim": [
683
+ 512,
684
+ 512,
685
+ 512
686
+ ],
687
+ "spacing": [
688
+ 1.0,
689
+ 1.0,
690
+ 1.0
691
+ ],
692
+ "top_region_index": [
693
+ 0,
694
+ 1,
695
+ 0,
696
+ 0
697
+ ],
698
+ "bottom_region_index": [
699
+ 0,
700
+ 0,
701
+ 0,
702
+ 1
703
+ ]
704
+ },
705
+ {
706
+ "image": "KiTS-00059/8_arterial_emb_zs0.nii.gz",
707
+ "label": "KiTS-00059/mask_combined_label_zs0_wbdm.nii.gz",
708
+ "fold": 1,
709
+ "dim": [
710
+ 512,
711
+ 512,
712
+ 512
713
+ ],
714
+ "spacing": [
715
+ 1.0,
716
+ 1.0,
717
+ 1.0
718
+ ],
719
+ "top_region_index": [
720
+ 0,
721
+ 1,
722
+ 0,
723
+ 0
724
+ ],
725
+ "bottom_region_index": [
726
+ 0,
727
+ 0,
728
+ 1,
729
+ 0
730
+ ]
731
+ },
732
+ {
733
+ "image": "KiTS-00172/3_arterial_emb_zs0.nii.gz",
734
+ "label": "KiTS-00172/mask_combined_label_zs0_wbdm.nii.gz",
735
+ "fold": 1,
736
+ "dim": [
737
+ 512,
738
+ 512,
739
+ 512
740
+ ],
741
+ "spacing": [
742
+ 1.0,
743
+ 1.0,
744
+ 1.0
745
+ ],
746
+ "top_region_index": [
747
+ 0,
748
+ 1,
749
+ 0,
750
+ 0
751
+ ],
752
+ "bottom_region_index": [
753
+ 0,
754
+ 0,
755
+ 0,
756
+ 1
757
+ ]
758
+ },
759
+ {
760
+ "image": "KiTS-00093/7_arterial_emb_zs0.nii.gz",
761
+ "label": "KiTS-00093/mask_combined_label_zs0_wbdm.nii.gz",
762
+ "fold": 1,
763
+ "dim": [
764
+ 512,
765
+ 512,
766
+ 512
767
+ ],
768
+ "spacing": [
769
+ 1.0,
770
+ 1.0,
771
+ 1.0
772
+ ],
773
+ "top_region_index": [
774
+ 0,
775
+ 1,
776
+ 0,
777
+ 0
778
+ ],
779
+ "bottom_region_index": [
780
+ 0,
781
+ 0,
782
+ 0,
783
+ 1
784
+ ]
785
+ },
786
+ {
787
+ "image": "KiTS-00197/2_arterial_emb_zs0.nii.gz",
788
+ "label": "KiTS-00197/mask_combined_label_zs0_wbdm.nii.gz",
789
+ "fold": 1,
790
+ "dim": [
791
+ 512,
792
+ 512,
793
+ 512
794
+ ],
795
+ "spacing": [
796
+ 1.0,
797
+ 1.0,
798
+ 1.0
799
+ ],
800
+ "top_region_index": [
801
+ 0,
802
+ 1,
803
+ 0,
804
+ 0
805
+ ],
806
+ "bottom_region_index": [
807
+ 0,
808
+ 0,
809
+ 0,
810
+ 1
811
+ ]
812
+ }
813
+ ]
814
+ }
datasets/C4KC-KiTS_subset.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bb65d194571c8db8e26ac911b04898cd54376f3c76a0303be70c4f883102088
3
+ size 3155140827
datasets/IntegrationTest-AbdomenCT.nii.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50b4a84769a31aeadd5f6d1a2bece82ba138bfb0eabe94ab13894fc8eb5dac90
3
+ size 7493659
datasets/all_masks_flexible_size_and_spacing_3000.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d89ebad0762448eca5b01e7b2e3199439111af50519fd4b8f124bc9e62968eb
3
+ size 9028952285
docs/README.md ADDED
@@ -0,0 +1,129 @@
1
+ # Model Overview
2
+ This bundle is for NVIDIA MAISI (Medical AI for Synthetic Imaging), a 3D latent diffusion model that can generate large CT images with paired segmentation masks, supporting variable volume and voxel sizes as well as controllable organ/tumor size.
3
+
4
+ The inference workflow of MAISI is depicted in the figure below. It first generates latent features from random noise by applying multiple denoising steps using the trained diffusion model. Then it decodes the denoised latent features into images using the trained autoencoder.
5
+
6
+ <p align="center">
7
+ <img src="https://developer.download.nvidia.com/assets/Clara/Images/maisi_workflow_1.0.1.png" alt="MAISI inference scheme">
8
+ </p>
9
+
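+ In code terms, that denoising stage is an iterative scheduler loop. The following is a minimal, self-contained sketch of the control flow only, not the bundle's implementation: the small `Conv3d` is a stand-in for the trained diffusion UNet, and the scheduler settings are copied from `configs/train.json`:
+
+ ```python
+ import torch
+ from monai.networks.schedulers.ddpm import DDPMScheduler
+
+ denoiser = torch.nn.Conv3d(4, 4, kernel_size=3, padding=1)  # placeholder for the trained diffusion UNet
+ scheduler = DDPMScheduler(
+     num_train_timesteps=1000, schedule="scaled_linear_beta",
+     beta_start=0.0015, beta_end=0.0195, clip_sample=False,
+ )
+
+ latent = torch.randn(1, 4, 16, 16, 16)  # random noise in latent space (tiny size for illustration)
+ with torch.no_grad():
+     for t in scheduler.timesteps:  # multiple denoising steps
+         noise_pred = denoiser(latent)  # a real UNet would also take t and conditioning inputs
+         latent, _ = scheduler.step(noise_pred, int(t), latent)
+ # In the bundle, the trained autoencoder then decodes the denoised latent into a CT volume.
+ ```
+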
10
+ MAISI is based on the following papers:
11
+
12
+ [**Latent Diffusion:** Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." CVPR 2022.](https://openaccess.thecvf.com/content/CVPR2022/papers/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.pdf)
13
+
14
+ [**ControlNet:** Lvmin Zhang, Anyi Rao, Maneesh Agrawala; “Adding Conditional Control to Text-to-Image Diffusion Models.” ICCV 2023.](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhang_Adding_Conditional_Control_to_Text-to-Image_Diffusion_Models_ICCV_2023_paper.pdf)
15
+
16
+ [**Rectified Flow:** Liu, Xingchao, and Chengyue Gong. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow." ICLR 2023.](https://arxiv.org/pdf/2209.03003)
17
+
18
+ #### Example synthetic image
19
+ An example result from inference is shown below:
20
+ ![Example synthetic image](https://developer.download.nvidia.com/assets/Clara/Images/monai_maisi_ct_generative_example_synthetic_data.png)
21
+
22
+ ### Inference configuration
23
+ The inference requires:
24
+ - GPU: at least 58GB GPU memory for 512 x 512 x 512
25
+ - Disk space: at least 21GB of free disk space
26
+
27
+ #### Inference parameters:
28
+ The information for the inference input, like body region and anatomy to generate, is stored in [./configs/inference.json](../configs/inference.json). Please feel free to play with it. Here are the details of the parameters.
29
+
30
+ - `"num_output_samples"`: int, the number of output image/mask pairs it will generate.
31
+ - `"spacing"`: voxel size of generated images. E.g., if set to `[1.5, 1.5, 2.0]`, it will generate images with a resolution of 1.5&times;1.5&times;2.0 mm. The spacing for x and y axes has to be between 0.5 and 3.0 mm and the spacing for the z axis has to be between 0.5 and 5.0 mm.
32
+ - `"output_size"`: volume size of generated images. E.g., if set to `[512, 512, 256]`, it will generate images with a size of 512&times;512&times;256. The values need to be divisible by 16; if GPU memory is limited, reduce them. Note that `"spacing"` and `"output_size"` together decide the output field of view (FOV). For example, if they are set to `[1.5, 1.5, 2.0]` mm and `[512, 512, 256]`, the FOV is 768&times;768&times;512 mm. We recommend that the FOV be the same in the x and y axes, at least 256mm for the head, at least 384mm for other body regions such as the abdomen, and no larger than 640mm. The output size for the x and y axes can be selected from [256, 384, 512], while for the z axis, it can be chosen from [128, 256, 384, 512, 640, 768].
33
+ - `"controllable_anatomy_size"`: a list of controllable anatomies and their size scales (0--1). E.g., if set to `[["liver", 0.5],["hepatic tumor", 0.3]]`, the generated image will contain a liver of median size (around the 50th percentile) and a relatively small hepatic tumor (around the 30th percentile). In addition, a size scale of -1 indicates that the organ does not exist or should be removed. The output will contain paired images and segmentation masks for the controllable anatomies.
34
+ The following organs support generation with a controllable size: ``["liver", "gallbladder", "stomach", "pancreas", "colon", "lung tumor", "bone lesion", "hepatic tumor", "colon cancer primaries", "pancreatic tumor"]``.
35
+ The raw output of the current mask generation model has a fixed size of $256^3$ voxels with a spacing of $1.5^3$ mm. If the "output_size" differs from this default, the generated masks will be resampled to the desired `"output_size"` and `"spacing"`. Note that resampling may degrade the quality of the generated masks and could trigger multiple inference attempts if the images fail to pass the [image quality check](../scripts/quality_check.py).
36
+ - `"body_region"`: Deprecated, please leave it as empty `"[]"`.
37
+ - `"anatomy_list"`: If "controllable_anatomy_size" is not specified, the output will contain paired image and segmentation mask for the anatomy in "./configs/label_dict.json".
38
+ - `"autoencoder_sliding_window_infer_size"`: to save GPU memory, sliding window inference is used when decoding latents into an image if `"output_size"` is large. This parameter is the patch size of the sliding window. A smaller value reduces GPU memory usage but increases inference time. The values need to be divisible by 16.
39
+ - `"autoencoder_sliding_window_infer_overlap"`: float between 0 and 1. A larger value reduces stitching artifacts between patches during sliding window inference, but increases inference time. If you do not observe seam lines in the generated images, you can use a smaller value to save inference time. A combined example of these settings is shown below.
40
+
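+ Putting these together, a plausible set of entries in `configs/inference.json` could look like the following (illustrative values only; the sliding-window settings in particular are assumptions to be tuned to your GPU):
+
+ ```
+ "num_output_samples": 1,
+ "spacing": [1.0, 1.0, 1.0],
+ "output_size": [512, 512, 512],
+ "controllable_anatomy_size": [["liver", 0.5], ["hepatic tumor", 0.3]],
+ "body_region": [],
+ "autoencoder_sliding_window_infer_size": [96, 96, 96],
+ "autoencoder_sliding_window_infer_overlap": 0.6666
+ ```
+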
41
+ To generate images with substantial dimensions, such as 512 &times; 512 &times; 512 or larger, using GPUs with 80GB of memory, it is advisable to configure the `"num_splits"` parameter in [the auto-encoder configuration](./configs/config_maisi.json#L11-L37) to 16. This adjustment is crucial to avoid out-of-memory issues during inference.
42
+
43
+ #### Recommended spacing for different output sizes:
44
+
45
+ |`"output_size"`| Recommended `"spacing"`|
46
+ |:-----:|:-----:|
47
+ |[256, 256, 256] | [1.5, 1.5, 1.5] |
48
+ |[512, 512, 128] | [0.8, 0.8, 2.5] |
49
+ |[512, 512, 512] | [1.0, 1.0, 1.0] |
50
+
51
+ ### Execute inference
52
+ The following command generates a synthetic image from randomly sampled noise.
53
+ ```
54
+ python -m monai.bundle run --config_file configs/inference.json
55
+ ```
56
+
57
+ ## Execute Finetuning
58
+
59
+ ### Training configuration
60
+ The training was performed with the following:
61
+ - GPU: at least 60GB GPU memory for 512 x 512 x 512 volume
62
+ - Actual Model Input (the size of image embedding in latent space): 128 x 128 x 128
63
+ - AMP: True
64
+
65
+ ### Run finetuning:
66
+ This config executes finetuning of the pretrained ControlNet with a new class (i.e., Kidney Tumor). When finetuning with new class names, please update `weighted_loss_label` in `configs/train.json` and `configs/label_dict.json` accordingly. There are 8 dummy labels as placeholders in the default `configs/label_dict.json` that can be used for finetuning.
67
+ ```
68
+ python -m monai.bundle run --config_file configs/train.json
69
+ ```
70
+
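+ For reference, the shipped `configs/train.json` already maps the Kidney Tumor class this way (adjust the label index for your own class):
+
+ ```
+ "weighted_loss_label": [129],
+ "weighted_loss": 100,
+ ```
+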
71
+ ### Override the `train` config to execute multi-GPU training:
72
+
73
+ ```
74
+ torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run --config_file "['configs/train.json','configs/multi_gpu_train.json']"
75
+ ```
76
+
77
+ ### Data:
78
+ The preprocessed subset of the [C4KC-KiTS](https://www.cancerimagingarchive.net/collection/c4kc-kits/) dataset used in this finetuning config is provided in `./datasets/C4KC-KiTS_subset`.
79
+ ```
80
+             |-*arterial*.nii.gz           # original image
81
+             |-*arterial_emb*.nii.gz       # encoded image embedding
82
+ KiTS-000* --|-mask*.nii.gz                # original labels
83
+             |-mask_pseudo_label*.nii.gz   # pseudo labels
84
+             |-mask_combined_label*.nii.gz # combined mask of original and pseudo labels
85
+
86
+ ```
87
+ An example combined mask of original and pseudo labels is shown below:
88
+ ![example_combined_mask](https://developer.download.nvidia.com/assets/Clara/Images/monai_maisi_ct_generative_example_combined_mask.png)
89
+
90
+ Please note that the label of Kidney Tumor is mapped to index `129` in this preprocessed dataset. The encoded image embedding is generated by the provided `Autoencoder` in `./models/autoencoder.pt` during preprocessing, to save memory usage for training. The pseudo labels are generated by [VISTA 3D](https://github.com/Project-MONAI/VISTA). In addition, the dimension of each volume and its corresponding pseudo label is resampled to the closest multiple of 128 (e.g., 128, 256, 384, 512, ...); for example, an axis with 300 voxels is resampled to 256.
91
+
92
+ The training workflow requires one JSON file that specifies the image embedding and segmentation pairs. An example file is located at `./datasets/C4KC-KiTS_subset.json`.
93
+
94
+ The JSON file has the following structure:
95
+ ```python
96
+ {
97
+ "training": [
98
+ {
99
+ "image": "*/*arterial_emb*.nii.gz", # relative path to the image embedding file
100
+ "label": "*/mask_combined_label*.nii.gz", # relative path to the combined label file
101
+ "dim": [512, 512, 512], # the dimension of image
102
+ "spacing": [1.0, 1.0, 1.0], # the spacing of image
103
+ "top_region_index": [0, 1, 0, 0], # the top region index of the image
104
+ "bottom_region_index": [0, 0, 0, 1], # the bottom region index of the image
105
+ "fold": 0 # fold index for cross validation, fold 0 is used for training
106
+ },
107
+
108
+ ...
109
+ ]
110
+ }
111
+ ```
112
+
113
+ # References
114
+ [1] Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022. https://openaccess.thecvf.com/content/CVPR2022/papers/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.pdf
115
+
116
+ # License
117
+
118
+ ## Code License
119
+
120
+ This project includes code licensed under the Apache License 2.0.
121
+ You may obtain a copy of the License at
122
+
123
+ http://www.apache.org/licenses/LICENSE-2.0
124
+
125
+ ## Model Weights License
126
+
127
+ The model weights included in this project are licensed under the NCLS v1 License.
128
+
129
+ Both licenses' full texts have been combined into a single `LICENSE` file. Please refer to this `LICENSE` file for more details about the terms and conditions of both licenses.
docs/data_license.txt ADDED
@@ -0,0 +1,49 @@
1
+ Third Party Licenses
2
+ -----------------------------------------------------------------------
3
+
4
+ /*********************************************************************/
5
+ i. Multimodal Brain Tumor Segmentation Challenge 2018
6
+ https://www.med.upenn.edu/sbia/brats2018/data.html
7
+ /*********************************************************************/
8
+
9
+ Data Usage Agreement / Citations
10
+
11
+ You are free to use and/or refer to the BraTS datasets in your own
12
+ research, provided that you always cite the following two manuscripts:
13
+
14
+ [1] Menze BH, Jakab A, Bauer S, Kalpathy-Cramer J, Farahani K, Kirby
15
+ J, Burren Y, Porz N, Slotboom J, Wiest R, Lanczi L, Gerstner E, Weber
16
+ MA, Arbel T, Avants BB, Ayache N, Buendia P, Collins DL, Cordier N,
17
+ Corso JJ, Criminisi A, Das T, Delingette H, Demiralp Ç, Durst CR,
18
+ Dojat M, Doyle S, Festa J, Forbes F, Geremia E, Glocker B, Golland P,
19
+ Guo X, Hamamci A, Iftekharuddin KM, Jena R, John NM, Konukoglu E,
20
+ Lashkari D, Mariz JA, Meier R, Pereira S, Precup D, Price SJ, Raviv
21
+ TR, Reza SM, Ryan M, Sarikaya D, Schwartz L, Shin HC, Shotton J,
22
+ Silva CA, Sousa N, Subbanna NK, Szekely G, Taylor TJ, Thomas OM,
23
+ Tustison NJ, Unal G, Vasseur F, Wintermark M, Ye DH, Zhao L, Zhao B,
24
+ Zikic D, Prastawa M, Reyes M, Van Leemput K. "The Multimodal Brain
25
+ Tumor Image Segmentation Benchmark (BRATS)", IEEE Transactions on
26
+ Medical Imaging 34(10), 1993-2024 (2015) DOI:
27
+ 10.1109/TMI.2014.2377694
28
+
29
+ [2] Bakas S, Akbari H, Sotiras A, Bilello M, Rozycki M, Kirby JS,
30
+ Freymann JB, Farahani K, Davatzikos C. "Advancing The Cancer Genome
31
+ Atlas glioma MRI collections with expert segmentation labels and
32
+ radiomic features", Nature Scientific Data, 4:170117 (2017) DOI:
33
+ 10.1038/sdata.2017.117
34
+
35
+ In addition, if there are no restrictions imposed from the
36
+ journal/conference you submit your paper about citing "Data
37
+ Citations", please be specific and also cite the following:
38
+
39
+ [3] Bakas S, Akbari H, Sotiras A, Bilello M, Rozycki M, Kirby J,
40
+ Freymann J, Farahani K, Davatzikos C. "Segmentation Labels and
41
+ Radiomic Features for the Pre-operative Scans of the TCGA-GBM
42
+ collection", The Cancer Imaging Archive, 2017. DOI:
43
+ 10.7937/K9/TCIA.2017.KLXWJJ1Q
44
+
45
+ [4] Bakas S, Akbari H, Sotiras A, Bilello M, Rozycki M, Kirby J,
46
+ Freymann J, Farahani K, Davatzikos C. "Segmentation Labels and
47
+ Radiomic Features for the Pre-operative Scans of the TCGA-LGG
48
+ collection", The Cancer Imaging Archive, 2017. DOI:
49
+ 10.7937/K9/TCIA.2017.GJQ7R0EF
models/autoencoder.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f8a7a056d0ebc00486edc43c26768bf1c12eaa6df9dd172e34598003be95eb3
3
+ size 83831868
models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05fda7bfffde524d56cc2dc5b990f901216bc46c4b5e261404aebc409d27b78b
3
+ size 278366962
models/controlnet.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:468c7c1d358530b9ebbdd643e4c1b1c1e4037df154e0bc15d21fc49e56a57f75
3
+ size 288255799
models/diffusion_unet.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cfc1ec59782f5ac7c0d22dd501654732109a971b93cbaa6607c4106a7f29066f
3
+ size 2166600232
models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b32be13118f9d6a077d42dd250c09c5e41673b48dbf2f35f2c587a7a9ebe5686
3
+ size 685298858
models/mask_generation_autoencoder.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:539175f6ede3cb1e6f01bfd6347cd446d601cf4a508fc632a1e36362b1428a5d
3
+ size 21072774
models/mask_generation_diffusion_unet.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7d668b1356e9b94b8269decadf8f8116dc4ee2d365580d05349b4ddf6739155
3
+ size 788941780
scripts/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from . import sample, utils
scripts/augmentation.py ADDED
@@ -0,0 +1,373 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from monai.transforms import Rand3DElastic, RandAffine, RandZoom
16
+ from monai.utils import ensure_tuple_rep
17
+
18
+
19
+ def erode3d(input_tensor, erosion=3):
20
+ # Define the structuring element
21
+ erosion = ensure_tuple_rep(erosion, 3)
22
+ structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to(input_tensor.device)
23
+
24
+ # Pad the input tensor to handle border pixels
25
+ input_padded = F.pad(
26
+ input_tensor.float().unsqueeze(0).unsqueeze(0),
27
+ (erosion[0] // 2, erosion[0] // 2, erosion[1] // 2, erosion[1] // 2, erosion[2] // 2, erosion[2] // 2),
28
+ mode="constant",
29
+ value=1.0,
30
+ )
31
+
32
+ # Apply erosion operation
33
+ output = F.conv3d(input_padded, structuring_element, padding=0)
34
+
35
+ # Set output values based on the minimum value within the structuring element
36
+ output = torch.where(output == torch.sum(structuring_element), 1.0, 0.0)
37
+
38
+ return output.squeeze(0).squeeze(0)
39
+
40
+
41
+ def dilate3d(input_tensor, erosion=3):
42
+ # Define the structuring element
43
+ erosion = ensure_tuple_rep(erosion, 3)
44
+ structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to(input_tensor.device)
45
+
46
+ # Pad the input tensor to handle border pixels
47
+ input_padded = F.pad(
48
+ input_tensor.float().unsqueeze(0).unsqueeze(0),
49
+ (erosion[0] // 2, erosion[0] // 2, erosion[1] // 2, erosion[1] // 2, erosion[2] // 2, erosion[2] // 2),
50
+ mode="constant",
51
+ value=1.0,
52
+ )
53
+
54
+ # Apply dilation operation
55
+ output = F.conv3d(input_padded, structuring_element, padding=0)
56
+
57
+ # Set output to 1 wherever the structuring element overlaps any foreground voxel
58
+ output = torch.where(output > 0, 1.0, 0.0)
59
+
60
+ return output.squeeze(0).squeeze(0)
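+
+ # Note: applying `dilate3d(x, erosion=k)` followed by `erode3d(x, erosion=k)`, as done
+ # repeatedly in the augmentation functions below, is a morphological closing: it fills
+ # small holes and smooths the boundaries of the distorted tumor/organ masks.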
61
+
62
+
63
+ def augmentation_tumor_bone(pt_nda, output_size, random_seed):
64
+ volume = pt_nda.squeeze(0)
65
+ real_l_volume_ = torch.zeros_like(volume)
66
+ real_l_volume_[volume == 128] = 1
67
+ real_l_volume_ = real_l_volume_.to(torch.uint8)
68
+
69
+ elastic = RandAffine(
70
+ mode="nearest",
71
+ prob=1.0,
72
+ translate_range=(5, 5, 0),
73
+ rotate_range=(0, 0, 0.1),
74
+ scale_range=(0.15, 0.15, 0),
75
+ padding_mode="zeros",
76
+ )
77
+ elastic.set_random_state(seed=random_seed)
78
+
79
+ tumor_size = torch.sum((real_l_volume_ > 0).float())
80
+ ###########################
81
+ # remove predictions in the pseudo label within the real lesion region
82
+ volume[real_l_volume_ > 0] = 200
83
+ ###########################
84
+ if tumor_size > 0:
85
+ # get organ mask
86
+ organ_mask = (
87
+ torch.logical_and(33 <= volume, volume <= 56).float()
88
+ + torch.logical_and(63 <= volume, volume <= 97).float()
89
+ + (volume == 127).float()
90
+ + (volume == 114).float()
91
+ + real_l_volume_
92
+ )
93
+ organ_mask = (organ_mask > 0).float()
94
+ cnt = 0
95
+ while True:
96
+ threshold = 0.8 if cnt < 40 else 0.75
97
+ real_l_volume = real_l_volume_
98
+ # randomly distort the mask
99
+ distorted_mask = elastic((real_l_volume > 0).cuda(), spatial_size=tuple(output_size)).as_tensor()
100
+ real_l_volume = distorted_mask * organ_mask
101
+ cnt += 1
102
+ print(torch.sum(real_l_volume), "|", tumor_size * threshold)
103
+ if torch.sum(real_l_volume) >= tumor_size * threshold:
104
+ real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5)
105
+ real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0).to(torch.uint8)
106
+ break
107
+ else:
108
+ real_l_volume = real_l_volume_
109
+
110
+ volume[real_l_volume == 1] = 128
111
+
112
+ pt_nda = volume.unsqueeze(0)
113
+ return pt_nda
114
+
115
+
116
+ def augmentation_tumor_liver(pt_nda, output_size, random_seed):
117
+ volume = pt_nda.squeeze(0)
118
+ real_l_volume_ = torch.zeros_like(volume)
119
+ real_l_volume_[volume == 1] = 1
120
+ real_l_volume_[volume == 26] = 2
121
+ real_l_volume_ = real_l_volume_.to(torch.uint8)
122
+
123
+ elastic = Rand3DElastic(
124
+ mode="nearest",
125
+ prob=1.0,
126
+ sigma_range=(5, 8),
127
+ magnitude_range=(100, 200),
128
+ translate_range=(10, 10, 10),
129
+ rotate_range=(np.pi / 36, np.pi / 36, np.pi / 36),
130
+ scale_range=(0.2, 0.2, 0.2),
131
+ padding_mode="zeros",
132
+ )
133
+ elastic.set_random_state(seed=random_seed)
134
+
135
+ tumor_size = torch.sum(real_l_volume_ == 2)
136
+ ###########################
137
+ # remove predicted organ labels
138
+ volume[volume == 1] = 0
139
+ volume[volume == 26] = 0
140
+ # before moving the tumor mask, fill its original location with organ labels
141
+ volume[real_l_volume_ == 1] = 1
142
+ volume[real_l_volume_ == 2] = 1
143
+ ###########################
144
+ while True:
145
+ real_l_volume = real_l_volume_
146
+ # randomly distort the mask
147
+ real_l_volume = elastic((real_l_volume == 2).cuda(), spatial_size=tuple(output_size)).as_tensor()
148
+ # get organ mask
149
+ organ_mask = (real_l_volume_ == 1).float() + (real_l_volume_ == 2).float()
150
+
151
+ organ_mask = dilate3d(organ_mask.squeeze(0), erosion=5)
152
+ organ_mask = erode3d(organ_mask, erosion=5).unsqueeze(0)
153
+ real_l_volume = real_l_volume * organ_mask
154
+ print(torch.sum(real_l_volume), "|", tumor_size * 0.80)
155
+ if torch.sum(real_l_volume) >= tumor_size * 0.80:
156
+ real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5)
157
+ real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0)
158
+ break
159
+
160
+ volume[real_l_volume == 1] = 26
161
+
162
+ pt_nda = volume.unsqueeze(0)
163
+ return pt_nda
164
+
165
+
166
+ def augmentation_tumor_lung(pt_nda, output_size, random_seed):
167
+ volume = pt_nda.squeeze(0)
168
+ real_l_volume_ = torch.zeros_like(volume)
169
+ real_l_volume_[volume == 23] = 1
170
+ real_l_volume_ = real_l_volume_.to(torch.uint8)
171
+
172
+ elastic = Rand3DElastic(
173
+ mode="nearest",
174
+ prob=1.0,
175
+ sigma_range=(5, 8),
176
+ magnitude_range=(100, 200),
177
+ translate_range=(20, 20, 20),
178
+ rotate_range=(np.pi / 36, np.pi / 36, np.pi),
179
+ scale_range=(0.15, 0.15, 0.15),
180
+ padding_mode="zeros",
181
+ )
182
+ elastic.set_random_state(seed=random_seed)
183
+
184
+ tumor_size = torch.sum(real_l_volume_)
185
+ # before moving the lung tumor mask, fill its original location with lung labels
186
+ new_real_l_volume_ = dilate3d(real_l_volume_.squeeze(0), erosion=3)
187
+ new_real_l_volume_ = new_real_l_volume_.unsqueeze(0)
188
+ new_real_l_volume_[real_l_volume_ > 0] = 0
189
+ new_real_l_volume_[volume < 28] = 0
190
+ new_real_l_volume_[volume > 32] = 0
191
+ tmp = volume[(volume * new_real_l_volume_).nonzero(as_tuple=True)].view(-1)
192
+
193
+ mode = torch.mode(tmp, 0)[0].item()
194
+ print(mode)
195
+ assert 28 <= mode <= 32
196
+ volume[real_l_volume_.bool()] = mode
197
+ ###########################
198
+ if tumor_size > 0:
199
+ # aug
200
+ while True:
201
+ real_l_volume = real_l_volume_
202
+ # randomly distort the mask
203
+ real_l_volume = elastic(real_l_volume, spatial_size=tuple(output_size)).as_tensor()
204
+ # get lung mask v2 (133 order)
205
+ lung_mask = (
206
+ (volume == 28).float()
207
+ + (volume == 29).float()
208
+ + (volume == 30).float()
209
+ + (volume == 31).float()
210
+ + (volume == 32).float()
211
+ )
212
+
213
+ lung_mask = dilate3d(lung_mask.squeeze(0), erosion=5)
214
+ lung_mask = erode3d(lung_mask, erosion=5).unsqueeze(0)
215
+ real_l_volume = real_l_volume * lung_mask
216
+ print(torch.sum(real_l_volume), "|", tumor_size * 0.85)
217
+ if torch.sum(real_l_volume) >= tumor_size * 0.85:
218
+ real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5)
219
+ real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0).to(torch.uint8)
220
+ break
221
+ else:
222
+ real_l_volume = real_l_volume_
223
+
224
+ volume[real_l_volume == 1] = 23
225
+
226
+ pt_nda = volume.unsqueeze(0)
227
+ return pt_nda
228
+
229
+
230
+ def augmentation_tumor_pancreas(pt_nda, output_size, random_seed):
231
+ volume = pt_nda.squeeze(0)
232
+ real_l_volume_ = torch.zeros_like(volume)
233
+ real_l_volume_[volume == 4] = 1
234
+ real_l_volume_[volume == 24] = 2
235
+ real_l_volume_ = real_l_volume_.to(torch.uint8)
236
+
237
+ elastic = Rand3DElastic(
238
+ mode="nearest",
239
+ prob=1.0,
240
+ sigma_range=(5, 8),
241
+ magnitude_range=(100, 200),
242
+ translate_range=(15, 15, 15),
243
+ rotate_range=(np.pi / 36, np.pi / 36, np.pi / 36),
244
+ scale_range=(0.1, 0.1, 0.1),
245
+ padding_mode="zeros",
246
+ )
247
+ elastic.set_random_state(seed=random_seed)
248
+
249
+ tumor_size = torch.sum(real_l_volume_ == 2)
250
+ ###########################
251
+ # remove predicted organ labels
252
+ volume[volume == 24] = 0
253
+ volume[volume == 4] = 0
254
+ # before moving the tumor mask, fill its original location with organ labels
255
+ volume[real_l_volume_ == 1] = 4
256
+ volume[real_l_volume_ == 2] = 4
257
+ ###########################
258
+ while True:
259
+ real_l_volume = real_l_volume_
260
+ # randomly distort the mask
261
+ real_l_volume = elastic((real_l_volume == 2).cuda(), spatial_size=tuple(output_size)).as_tensor()
262
+ # get organ mask
263
+ organ_mask = (real_l_volume_ == 1).float() + (real_l_volume_ == 2).float()
264
+
265
+ organ_mask = dilate3d(organ_mask.squeeze(0), erosion=5)
266
+ organ_mask = erode3d(organ_mask, erosion=5).unsqueeze(0)
267
+ real_l_volume = real_l_volume * organ_mask
268
+ print(torch.sum(real_l_volume), "|", tumor_size * 0.80)
269
+ if torch.sum(real_l_volume) >= tumor_size * 0.80:
270
+ real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5)
271
+ real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0)
272
+ break
273
+
274
+ volume[real_l_volume == 1] = 24
275
+
276
+ pt_nda = volume.unsqueeze(0)
277
+ return pt_nda
278
+
279
+
280
+ def augmentation_tumor_colon(pt_nda, output_size, random_seed):
281
+ volume = pt_nda.squeeze(0)
282
+ real_l_volume_ = torch.zeros_like(volume)
283
+ real_l_volume_[volume == 27] = 1
284
+ real_l_volume_ = real_l_volume_.to(torch.uint8)
285
+
286
+ elastic = Rand3DElastic(
287
+ mode="nearest",
288
+ prob=1.0,
289
+ sigma_range=(5, 8),
290
+ magnitude_range=(100, 200),
291
+ translate_range=(5, 5, 5),
292
+ rotate_range=(np.pi / 36, np.pi / 36, np.pi / 36),
293
+ scale_range=(0.1, 0.1, 0.1),
294
+ padding_mode="zeros",
295
+ )
296
+ elastic.set_random_state(seed=random_seed)
297
+
298
+ tumor_size = torch.sum(real_l_volume_)
299
+ ###########################
300
+ # before moving the tumor mask, fill its original location with organ labels
301
+ volume[real_l_volume_.bool()] = 62
302
+ ###########################
303
+ if tumor_size > 0:
304
+ # get organ mask
305
+ organ_mask = (volume == 62).float()
306
+ organ_mask = dilate3d(organ_mask.squeeze(0), erosion=5)
307
+ organ_mask = erode3d(organ_mask, erosion=5).unsqueeze(0)
308
+ # cnt = 0
309
+ cnt = 0
310
+ while True:
311
+ threshold = 0.8
312
+ real_l_volume = real_l_volume_
313
+ if cnt < 20:
314
+ # randomly distort the mask
315
+ distorted_mask = elastic((real_l_volume == 1).cuda(), spatial_size=tuple(output_size)).as_tensor()
316
+ real_l_volume = distorted_mask * organ_mask
317
+ elif 20 <= cnt < 40:
318
+ threshold = 0.75
319
+ else:
320
+ break
321
+
322
+ real_l_volume = real_l_volume * organ_mask
323
+ print(torch.sum(real_l_volume), "|", tumor_size * threshold)
324
+ cnt += 1
325
+ if torch.sum(real_l_volume) >= tumor_size * threshold:
326
+ real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5)
327
+ real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0).to(torch.uint8)
328
+ break
329
+ else:
330
+ real_l_volume = real_l_volume_
331
+ # break
332
+ volume[real_l_volume == 1] = 27
333
+
334
+ pt_nda = volume.unsqueeze(0)
335
+ return pt_nda
336
+
337
+
338
+ def augmentation_body(pt_nda, random_seed):
339
+ volume = pt_nda.squeeze(0)
340
+
341
+ zoom = RandZoom(min_zoom=0.99, max_zoom=1.01, mode="nearest", align_corners=None, prob=1.0)
342
+ zoom.set_random_state(seed=random_seed)
343
+
344
+ volume = zoom(volume)
345
+
346
+ pt_nda = volume.unsqueeze(0)
347
+ return pt_nda
348
+
349
+
350
+ def augmentation(pt_nda, output_size, random_seed):
351
+ label_list = torch.unique(pt_nda)
352
+ label_list = list(label_list.cpu().numpy())
353
+
354
+ if 128 in label_list:
355
+ print("augmenting bone lesion/tumor")
356
+ pt_nda = augmentation_tumor_bone(pt_nda, output_size, random_seed)
357
+ elif 26 in label_list:
358
+ print("augmenting liver tumor")
359
+ pt_nda = augmentation_tumor_liver(pt_nda, output_size, random_seed)
360
+ elif 23 in label_list:
361
+ print("augmenting lung tumor")
362
+ pt_nda = augmentation_tumor_lung(pt_nda, output_size, random_seed)
363
+ elif 24 in label_list:
364
+ print("augmenting pancreas tumor")
365
+ pt_nda = augmentation_tumor_pancreas(pt_nda, output_size, random_seed)
366
+ elif 27 in label_list:
367
+ print("augmenting colon tumor")
368
+ pt_nda = augmentation_tumor_colon(pt_nda, output_size, random_seed)
369
+ else:
370
+ print("augmenting body")
371
+ pt_nda = augmentation_body(pt_nda, random_seed)
372
+
373
+ return pt_nda
scripts/find_masks.py ADDED
@@ -0,0 +1,137 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations  # needed for the X | Y annotation syntax on Python < 3.10
13
+ import json
14
+ import os
15
+ from typing import Sequence
16
+
17
+ from monai.apps.utils import extractall
18
+ from monai.utils import ensure_tuple_rep
19
+
20
+
21
+ def convert_body_region(body_region: str | Sequence[str]) -> Sequence[int]:
22
+ """
23
+ Convert body region string to body region index.
24
+ Args:
25
+ body_region: list of input body region string. If single str, will be converted to list of str.
26
+ Return:
27
+ body_region_indices, list of input body region index.
28
+ """
29
+ if isinstance(body_region, str):
30
+ body_region = [body_region]
31
+
32
+ # body region mapping for maisi
33
+ region_mapping_maisi = {
34
+ "head": 0,
35
+ "chest": 1,
36
+ "thorax": 1,
37
+ "chest/thorax": 1,
38
+ "abdomen": 2,
39
+ "pelvis": 3,
40
+ "lower": 3,
41
+ "pelvis/lower": 3,
42
+ }
43
+
44
+ # perform mapping
45
+ body_region_indices = []
46
+ for region in body_region:
47
+ normalized_region = region.lower() # norm str to lower case
48
+ if normalized_region not in region_mapping_maisi:
49
+ raise ValueError(f"Invalid region: {normalized_region}")
50
+ body_region_indices.append(region_mapping_maisi[normalized_region])
51
+
52
+ return body_region_indices
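+
+ # Example: convert_body_region(["chest", "abdomen"]) returns [1, 2].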
53
+
54
+
55
+ def find_masks(
56
+ anatomy_list: int | Sequence[int],
57
+ spacing: Sequence[float] | float = 1.0,
58
+ output_size: Sequence[int] = (512, 512, 512),
59
+ check_spacing_and_output_size: bool = False,
60
+ database_filepath: str = "./configs/database.json",
61
+ mask_foldername: str = "./datasets/masks/",
62
+ ):
63
+ """
64
+ Find candidate masks that fulfill all the requirements.
65
+ They should contain all the anatomies in `anatomy_list`.
66
+ If there is no tumor specified in `anatomy_list`, we also expect the candidate masks to be tumor free.
67
+ If check_spacing_and_output_size is True, the candidate masks need to have the expected `spacing` and `output_size`.
68
+ Args:
69
+ anatomy_list: list of input anatomy. The found candidate mask will include these anatomies.
70
+ spacing: list of three floats, voxel spacing. If providing a single number, will use it for all the three dimensions.
71
+ output_size: list of three int, expected candidate mask spatial size.
72
+ check_spacing_and_output_size: whether we expect candidate mask to have spatial size of `output_size`
73
+ and voxel size of `spacing`.
74
+ database_filepath: path for the json file that stores the information of all the candidate masks.
75
+ mask_foldername: directory that saves all the candidate masks.
76
+ Return:
77
+ candidate_masks, list of dict, each dict contains information of one candidate mask that fulfills all the requirements.
78
+ """
79
+ # check and preprocess input
80
+ if isinstance(anatomy_list, int):
81
+ anatomy_list = [anatomy_list]
82
+
83
+ spacing = ensure_tuple_rep(spacing, 3)
84
+
85
+ if not os.path.exists(mask_foldername):
86
+ zip_file_path = mask_foldername + ".zip"
87
+
88
+ if not os.path.isfile(zip_file_path):
89
+ raise ValueError(f"Please download {zip_file_path} following the instruction in ./datasets/README.md.")
90
+
91
+ print(f"Extracting {zip_file_path} to {os.path.dirname(zip_file_path)}")
92
+ extractall(filepath=zip_file_path, output_dir=os.path.dirname(zip_file_path), file_type="zip")
93
+ print(f"Unzipped {zip_file_path} to {mask_foldername}.")
94
+
95
+ if not os.path.isfile(database_filepath):
96
+ raise ValueError(f"Please download {database_filepath} following the instruction in ./datasets/README.md.")
97
+ with open(database_filepath, "r") as f:
98
+ db = json.load(f)
99
+
100
+ # select candidate_masks
101
+ candidate_masks = []
102
+ for _item in db:
103
+ if not set(anatomy_list).issubset(_item["label_list"]):
104
+ continue
105
+
106
+ # whether to keep this mask, default to be True.
107
+ keep_mask = True
108
+
109
+ for tumor_label in [23, 24, 26, 27, 128]:
110
+ # we skip those mask with tumors if users do not provide tumor label in anatomy_list
111
+ if tumor_label not in anatomy_list and tumor_label in _item["label_list"]:
112
+ keep_mask = False
113
+
114
+ if check_spacing_and_output_size:
115
+ # if the output_size and spacing are different with user's input, skip it
116
+ for axis in range(3):
117
+ if _item["dim"][axis] != output_size[axis] or _item["spacing"][axis] != spacing[axis]:
118
+ keep_mask = False
119
+
120
+ if keep_mask:
121
+ # if decide to keep this mask, we pack the information of this mask and add to final output.
122
+ candidate = {
123
+ "pseudo_label": os.path.join(mask_foldername, _item["pseudo_label_filename"]),
124
+ "spacing": _item["spacing"],
125
+ "dim": _item["dim"],
126
+ }
127
+
128
+ # Conditionally add the label to the candidate dictionary
129
+ if "label_filename" in _item:
130
+ candidate["label"] = os.path.join(mask_foldername, _item["label_filename"])
131
+
132
+ candidate_masks.append(candidate)
133
+
134
+ if len(candidate_masks) == 0 and not check_spacing_and_output_size:
135
+ raise ValueError("Cannot find body region with given anatomy list.")
136
+
137
+ return candidate_masks
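+
+ # Example usage (hypothetical label indices; see configs/label_dict.json for the full mapping):
+ # candidates = find_masks(anatomy_list=[1], spacing=1.0, output_size=(512, 512, 512))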
scripts/quality_check.py ADDED
@@ -0,0 +1,149 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import numpy as np
+
+
+ def get_masked_data(label_data, image_data, labels):
+     """
+     Extract and return the image data corresponding to specified labels within a 3D volume.
+
+     This function masks the `image_data` array based on the provided `labels` in the `label_data` array.
+     It handles both large and small numbers of labels, choosing the faster strategy accordingly.
+
+     Args:
+         label_data (np.ndarray): A NumPy array containing label data, representing different anatomical
+             regions or classes in a 3D medical image.
+         image_data (np.ndarray): A NumPy array containing the image data from which the relevant regions
+             will be extracted.
+         labels (list of int): A list of integers representing the label values to be used for masking.
+
+     Returns:
+         np.ndarray: A NumPy array containing the elements of `image_data` that correspond to the specified
+             labels in `label_data`. If no labels are provided, an empty array is returned.
+
+     Raises:
+         ValueError: If `image_data` and `label_data` do not have the same shape.
+
+     Example:
+         label_int_dict = {"liver": [1], "kidney": [5, 14]}
+         masked_data = get_masked_data(label_data, image_data, label_int_dict["kidney"])
+     """
+
+     # Check that the shapes of image_data and label_data match
+     if image_data.shape != label_data.shape:
+         raise ValueError(
+             f"Shape mismatch: image_data has shape {image_data.shape}, "
+             f"but label_data has shape {label_data.shape}. They must be the same."
+         )
+
+     if not labels:
+         return np.array([])  # Return an empty array if no labels are provided
+
+     labels = list(set(labels))  # remove duplicate items
+
+     # Optimize performance based on the number of labels
+     num_label_acceleration_thresh = 3
+     if len(labels) >= num_label_acceleration_thresh:
+         # with many labels, np.isin is faster
+         mask = np.isin(label_data, labels)
+     else:
+         # with few labels, combining equality masks via logical OR is faster
+         mask = np.zeros_like(label_data, dtype=bool)
+         for label in labels:
+             mask = np.logical_or(mask, label_data == label)
+
+     # Retrieve the masked data (mask is already boolean)
+     masked_data = image_data[mask]
+
+     return masked_data
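A quick self-contained sanity check of `get_masked_data`, using tiny synthetic arrays (purely illustrative):

```python
import numpy as np

label_data = np.zeros((4, 4, 4), dtype=np.int32)
label_data[0, 0, :2] = 5  # two pretend-kidney voxels
image_data = np.full((4, 4, 4), 40.0)

values = get_masked_data(label_data, image_data, labels=[5, 14])
assert values.shape == (2,) and np.allclose(values, 40.0)
```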
+
+
+ def is_outlier(statistics, image_data, label_data, label_int_dict):
+     """
+     Perform a quality check on the generated image by comparing its statistics with precomputed thresholds.
+
+     Args:
+         statistics (dict): Dictionary containing precomputed statistics, including mean +/- 3*sigma ranges
+             and percentile values.
+         image_data (np.ndarray): The image data to be checked, typically a 3D NumPy array.
+         label_data (np.ndarray): The label data corresponding to the image, used for masking regions of interest.
+         label_int_dict (dict): Dictionary mapping label names to their corresponding integer lists,
+             e.g., label_int_dict = {"liver": [1], "kidney": [5, 14]}.
+
+     Returns:
+         dict: A dictionary with labels as keys, each containing the quality check result,
+             including whether it's an outlier, the median value, and the thresholds used.
+             If no data is found for a label, the median value will be `None` and `is_outlier` will be `False`.
+
+     Example:
+         # Example input data (each entry needs both the sigma and the percentile keys)
+         statistics = {
+             "liver": {
+                 "sigma_6_low": -21.596463547885904,
+                 "sigma_6_high": 156.27881534763367,
+                 "percentile_0_5": -10.0,
+                 "percentile_99_5": 120.0,
+             },
+             "kidney": {
+                 "sigma_6_low": -15.0,
+                 "sigma_6_high": 120.0,
+                 "percentile_0_5": -10.0,
+                 "percentile_99_5": 100.0,
+             },
+         }
+         label_int_dict = {"liver": [1], "kidney": [5, 14]}
+         image_data = np.random.rand(100, 100, 100)  # Replace with actual image data
+         label_data = np.zeros((100, 100, 100))  # Replace with actual label data
+         label_data[40:60, 40:60, 40:60] = 1  # Example region for liver
+         label_data[70:90, 70:90, 70:90] = 5  # Example region for kidney
+         result = is_outlier(statistics, image_data, label_data, label_int_dict)
+     """
+     outlier_results = {}
+
+     for label_name, stats in statistics.items():
+         # Get the thresholds from the statistics
+         low_thresh = min(stats["sigma_6_low"], stats["percentile_0_5"])  # or "sigma_12_low" depending on your needs
+         high_thresh = max(stats["sigma_6_high"], stats["percentile_99_5"])  # or "sigma_12_high" depending on your needs
+
+         if label_name == "bone":
+             high_thresh = 1000.0
+
+         # Retrieve the corresponding label integers
+         labels = label_int_dict.get(label_name, [])
+         masked_data = get_masked_data(label_data, image_data, labels)
+         masked_data = masked_data[~np.isnan(masked_data)]
+
+         if masked_data.size == 0:
+             outlier_results[label_name] = {
+                 "is_outlier": False,
+                 "median_value": None,
+                 "low_thresh": low_thresh,
+                 "high_thresh": high_thresh,
+             }
+             continue
+
+         # Compute the median of the masked region
+         median_value = np.nanmedian(masked_data)
+
+         if np.isnan(median_value):
+             median_value = None
+             is_outlier = False
+         else:
+             # Determine whether the median value is an outlier
+             is_outlier = median_value < low_thresh or median_value > high_thresh
+
+         outlier_results[label_name] = {
+             "is_outlier": is_outlier,
+             "median_value": median_value,
+             "low_thresh": low_thresh,
+             "high_thresh": high_thresh,
+         }
+
+     return outlier_results
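Interpreting the returned dictionary, a short sketch (assumes `statistics`, `image_data`, `label_data`, and `label_int_dict` prepared as in the docstring example above):

```python
result = is_outlier(statistics, image_data, label_data, label_int_dict)
failed = [name for name, r in result.items() if r["is_outlier"]]
if failed:
    print("labels failing the quality check:", failed)
```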
scripts/rectified_flow.py ADDED
@@ -0,0 +1,163 @@
+ from typing import Any
+
+ import numpy as np
+ import torch
+ from monai.networks.schedulers import Scheduler
+ from torch.distributions import LogisticNormal
+
+ # code modified from https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py
+
+
+ def timestep_transform(
+     t, input_img_size, base_img_size=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3
+ ):
+     t = t / num_train_timesteps
+     ratio_space = (input_img_size / base_img_size).pow(1.0 / spatial_dim)
+
+     ratio = ratio_space * scale
+     new_t = ratio * t / (1 + (ratio - 1) * t)
+
+     new_t = new_t * num_train_timesteps
+     return new_t
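A small numeric check of the transform: larger volumes get their timesteps shifted toward the noisy end of the schedule, while a volume at the base size is unchanged. The values below are illustrative only:

```python
import torch

t = torch.tensor([500.0])  # halfway through a 1000-step schedule
same = timestep_transform(t, input_img_size=torch.tensor(32 * 32 * 32))  # ratio 1 -> unchanged, 500.0
big = timestep_transform(t, input_img_size=torch.tensor(64 * 64 * 64))   # ratio 2 -> 2*0.5/(1+0.5)*1000 ~ 666.7
print(same.item(), big.item())
```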
+
+
+ class RFlowScheduler(Scheduler):
+     def __init__(
+         self,
+         num_train_timesteps=1000,
+         num_inference_steps=10,
+         use_discrete_timesteps=False,
+         sample_method="uniform",
+         loc=0.0,
+         scale=1.0,
+         use_timestep_transform=False,
+         transform_scale=1.0,
+         steps_offset: int = 0,
+     ):
+         self.num_train_timesteps = num_train_timesteps
+         self.num_inference_steps = num_inference_steps
+         self.use_discrete_timesteps = use_discrete_timesteps
+
+         # sample method
+         assert sample_method in ["uniform", "logit-normal"]
+         # assert (
+         #     sample_method == "uniform" or not use_discrete_timesteps
+         # ), "Only uniform sampling is supported for discrete timesteps"
+         self.sample_method = sample_method
+         if sample_method == "logit-normal":
+             self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
+             self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)
+
+         # timestep transform
+         self.use_timestep_transform = use_timestep_transform
+         self.transform_scale = transform_scale
+         self.steps_offset = steps_offset
+
+     def add_noise(
+         self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+     ) -> torch.FloatTensor:
+         """
+         Compatible with the diffusers `add_noise()` API.
+         """
+         timepoints = timesteps.float() / self.num_train_timesteps
+         timepoints = 1 - timepoints  # maps t in [0, T-1] to an interpolation weight in (0, 1]
+
+         # timepoints: (bsz,); noise: (bsz, channel, frame, w, h)
+         # expand timepoints to the noise shape
+         timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
+         timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4])
+
+         return timepoints * original_samples + (1 - timepoints) * noise
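An endpoint sanity check of the linear (rectified-flow) interpolation, a sketch with toy tensors: at the largest timestep the output is almost pure noise, and at t = 0 it is the clean sample:

```python
import torch

sched = RFlowScheduler(num_train_timesteps=1000)
x0 = torch.ones(1, 4, 8, 8, 8)
noise = torch.zeros_like(x0)

almost_noise = sched.add_noise(x0, noise, torch.tensor([999]))  # weight 0.001 -> ~noise
clean = sched.add_noise(x0, noise, torch.tensor([0]))           # weight 1.0 -> x0
print(almost_noise.mean().item(), clean.mean().item())  # ~0.001 and 1.0
```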
+
+     def set_timesteps(
+         self,
+         num_inference_steps: int,
+         device: str | torch.device | None = None,
+         input_img_size: int | None = None,
+         base_img_size: int = 32 * 32 * 32,
+     ) -> None:
+         """
+         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+         Args:
+             num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
+             device: target device to put the data.
+             input_img_size: int, H*W*D of the image, used when `self.use_timestep_transform` is True.
+             base_img_size: int, reference H*W*D size, used when `self.use_timestep_transform` is True.
+         """
+         if num_inference_steps > self.num_train_timesteps:
+             raise ValueError(
+                 f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
+                 f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                 f" maximal {self.num_train_timesteps} timesteps."
+             )
+
+         self.num_inference_steps = num_inference_steps
+         # prepare timesteps
+         timesteps = [
+             (1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps)
+         ]
+         if self.use_discrete_timesteps:
+             timesteps = [int(round(t)) for t in timesteps]
+         if self.use_timestep_transform:
+             timesteps = [
+                 timestep_transform(
+                     t,
+                     input_img_size=input_img_size,
+                     base_img_size=base_img_size,
+                     num_train_timesteps=self.num_train_timesteps,
+                 )
+                 for t in timesteps
+             ]
+         timesteps = np.array(timesteps).astype(np.float16)
+         if self.use_discrete_timesteps:
+             timesteps = timesteps.astype(np.int64)
+         self.timesteps = torch.from_numpy(timesteps).to(device)
+         self.timesteps += self.steps_offset
+         print(self.timesteps)
+
+     def sample_timesteps(self, x_start):
+         if self.sample_method == "uniform":
+             t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_train_timesteps
+         elif self.sample_method == "logit-normal":
+             t = self.sample_t(x_start) * self.num_train_timesteps
+
+         if self.use_discrete_timesteps:
+             t = t.long()
+
+         if self.use_timestep_transform:
+             input_img_size = torch.prod(torch.tensor(x_start.shape[-3:]))
+             base_img_size = 32 * 32 * 32
+             t = timestep_transform(
+                 t,
+                 input_img_size=input_img_size,
+                 base_img_size=base_img_size,
+                 num_train_timesteps=self.num_train_timesteps,
+             )
+
+         return t
+
+     def step(
+         self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep=None
+     ) -> tuple[torch.Tensor, Any]:
+         """
+         Predict the sample at the previous timestep. Core function to propagate the diffusion
+         process from the learned model outputs.
+
+         Args:
+             model_output: direct output from the learned diffusion model (the predicted velocity).
+             timestep: current discrete timestep in the diffusion chain.
+             sample: current instance of the sample being created by the diffusion process.
+             next_timestep: optional next timestep; when given, the Euler step size is
+                 `(timestep - next_timestep) / num_train_timesteps` instead of `1 / num_inference_steps`.
+
+         Returns:
+             pred_prev_sample: predicted previous sample
+             None
+         """
+         v_pred = model_output
+         if next_timestep is None:
+             dt = 1.0 / self.num_inference_steps
+         else:
+             dt = timestep - next_timestep
+             dt = dt / self.num_train_timesteps
+         z = sample + v_pred * dt
+
+         return z, None
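Putting the scheduler together, a minimal Euler sampling loop mirroring the one in `scripts/sample.py` (here `velocity_model` is a hypothetical stand-in for the trained diffusion UNet; shapes are illustrative):

```python
import torch

scheduler = RFlowScheduler(num_train_timesteps=1000, use_timestep_transform=True)
latents = torch.randn(1, 4, 16, 16, 16)
scheduler.set_timesteps(
    num_inference_steps=10, input_img_size=torch.prod(torch.tensor(latents.shape[-3:]))
)

next_ts = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
for t, next_t in zip(scheduler.timesteps, next_ts):
    v = velocity_model(latents, t)  # hypothetical model call
    latents, _ = scheduler.step(v, t, latents, next_timestep=next_t)
```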
scripts/sample.py ADDED
@@ -0,0 +1,1036 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import logging
+ import os
+ import random
+ import time
+ from datetime import datetime
+
+ import monai
+ import torch
+ from monai.data import MetaTensor
+ from monai.inferers.inferer import DiffusionInferer, SlidingWindowInferer
+ from monai.transforms import Compose, SaveImage
+ from monai.utils import set_determinism
+ from tqdm import tqdm
+
+ from .augmentation import augmentation
+ from .find_masks import find_masks
+ from .quality_check import is_outlier
+ from .utils import binarize_labels, dynamic_infer, general_mask_generation_post_process, remap_labels
+
+ modality_mapping = {
+     "unknown": 0,
+     "ct": 1,
+     "ct_wo_contrast": 2,
+     "ct_contrast": 3,
+     "mri": 8,
+     "mri_t1": 9,
+     "mri_t2": 10,
+     "mri_flair": 11,
+     "mri_pd": 12,
+     "mri_dwi": 13,
+     "mri_adc": 14,
+     "mri_ssfp": 15,
+     "mri_mra": 16,
+ }  # the current version only supports "ct"
+
+
+ class ReconModel(torch.nn.Module):
+     """
+     A PyTorch module for reconstructing images from latent representations.
+
+     Attributes:
+         autoencoder: The autoencoder model used for decoding.
+         scale_factor: Scaling factor applied to the input before decoding.
+     """
+
+     def __init__(self, autoencoder, scale_factor):
+         super().__init__()
+         self.autoencoder = autoencoder
+         self.scale_factor = scale_factor
+
+     def forward(self, z):
+         """
+         Decode the input latent representation to an image.
+
+         Args:
+             z (torch.Tensor): The input latent representation.
+
+         Returns:
+             torch.Tensor: The reconstructed image.
+         """
+         recon_pt_nda = self.autoencoder.decode_stage_2_outputs(z / self.scale_factor)
+         return recon_pt_nda
+
+
+ def initialize_noise_latents(latent_shape, device):
+     """
+     Initialize random noise latents for image generation with float16.
+
+     Args:
+         latent_shape (tuple): The shape of the latent space.
+         device (torch.device): The device to create the tensor on.
+
+     Returns:
+         torch.Tensor: Initialized noise latents.
+     """
+     return torch.randn([1] + list(latent_shape)).half().to(device)
+
+
+ def ldm_conditional_sample_one_mask(
+     autoencoder,
+     diffusion_unet,
+     noise_scheduler,
+     scale_factor,
+     anatomy_size,
+     device,
+     latent_shape,
+     label_dict_remap_json,
+     num_inference_steps=1000,
+     autoencoder_sliding_window_infer_size=(96, 96, 96),
+     autoencoder_sliding_window_infer_overlap=0.6667,
+ ):
+     """
+     Generate a single synthetic mask using a latent diffusion model.
+
+     Args:
+         autoencoder (nn.Module): The autoencoder model.
+         diffusion_unet (nn.Module): The diffusion U-Net model.
+         noise_scheduler: The noise scheduler for the diffusion process.
+         scale_factor (float): Scaling factor for the latent space.
+         anatomy_size (torch.Tensor): Tensor specifying the desired anatomy sizes.
+         device (torch.device): The device to run the computation on.
+         latent_shape (tuple): The shape of the latent space.
+         label_dict_remap_json (str): Path to the JSON file for label remapping.
+         num_inference_steps (int): Number of inference steps for the diffusion process.
+         autoencoder_sliding_window_infer_size (list, optional): Size of the sliding window for inference. Defaults to [96, 96, 96].
+         autoencoder_sliding_window_infer_overlap (float, optional): Overlap ratio for sliding window inference. Defaults to 0.6667.
+
+     Returns:
+         torch.Tensor: The generated synthetic mask.
+     """
+     recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device)
+
+     with torch.no_grad(), torch.amp.autocast("cuda"):
+         # Generate random noise
+         latents = initialize_noise_latents(latent_shape, device)
+         anatomy_size = torch.FloatTensor(anatomy_size).unsqueeze(0).unsqueeze(0).half().to(device)
+         # synthesize latents
+         noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps)
+         inferer_ddpm = DiffusionInferer(noise_scheduler)
+         latents = inferer_ddpm.sample(
+             input_noise=latents,
+             diffusion_model=diffusion_unet,
+             scheduler=noise_scheduler,
+             verbose=True,
+             conditioning=anatomy_size.to(device),
+         )
+         # decode latents into synthesized masks
+         inferer = SlidingWindowInferer(
+             roi_size=autoencoder_sliding_window_infer_size,
+             sw_batch_size=1,
+             progress=True,
+             mode="gaussian",
+             overlap=autoencoder_sliding_window_infer_overlap,
+             device=torch.device("cpu"),
+             sw_device=device,
+         )
+         synthetic_mask = dynamic_infer(inferer, recon_model, latents)
+         synthetic_mask = torch.softmax(synthetic_mask, dim=1)
+         synthetic_mask = torch.argmax(synthetic_mask, dim=1, keepdim=True)
+         # map raw indices to the 132-label dictionary
+         synthetic_mask = remap_labels(synthetic_mask, label_dict_remap_json)
+
+         # post-processing
+         data = synthetic_mask.squeeze().cpu().detach().numpy()
+
+         labels = [23, 24, 26, 27, 128]
+         target_tumor_label = None
+         for index, size in enumerate(anatomy_size[0, 0, 5:10]):
+             if size.item() != -1.0:
+                 target_tumor_label = labels[index]
+
+         logging.info(f"target_tumor_label for postprocess: {target_tumor_label}")
+         data = general_mask_generation_post_process(data, target_tumor_label=target_tumor_label, device=device)
+         synthetic_mask = torch.from_numpy(data).unsqueeze(0).unsqueeze(0).to(device)
+
+     return synthetic_mask
+
+
+ def ldm_conditional_sample_one_image(
+     autoencoder,
+     diffusion_unet,
+     controlnet,
+     noise_scheduler,
+     scale_factor,
+     device,
+     combine_label_or,
+     modality_tensor,
+     spacing_tensor,
+     latent_shape,
+     output_size,
+     noise_factor,
+     num_inference_steps=1000,
+     autoencoder_sliding_window_infer_size=(96, 96, 96),
+     autoencoder_sliding_window_infer_overlap=0.6667,
+ ):
+     """
+     Generate a single synthetic image using a latent diffusion model with controlnet.
+
+     Args:
+         autoencoder (nn.Module): The autoencoder model.
+         diffusion_unet (nn.Module): The diffusion U-Net model.
+         controlnet (nn.Module): The controlnet model.
+         noise_scheduler: The noise scheduler for the diffusion process.
+         scale_factor (float): Scaling factor for the latent space.
+         device (torch.device): The device to run the computation on.
+         combine_label_or (torch.Tensor): The combined label tensor.
+         modality_tensor (torch.Tensor): Tensor specifying the imaging modality.
+         spacing_tensor (torch.Tensor): Tensor specifying the spacing.
+         latent_shape (tuple): The shape of the latent space.
+         output_size (tuple): The desired output size of the image.
+         noise_factor (float): Factor to scale the initial noise.
+         num_inference_steps (int): Number of inference steps for the diffusion process.
+         autoencoder_sliding_window_infer_size (list, optional): Size of the sliding window for inference. Defaults to [96, 96, 96].
+         autoencoder_sliding_window_infer_overlap (float, optional): Overlap ratio for sliding window inference. Defaults to 0.6667.
+
+     Returns:
+         tuple: A tuple containing the synthetic image and its corresponding label.
+     """
+     # CT image intensity range
+     a_min = -1000
+     a_max = 1000
+     # autoencoder output intensity range
+     b_min = 0.0
+     b_max = 1.0
+
+     recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device)
+
+     with torch.no_grad(), torch.amp.autocast("cuda", enabled=True):
+         logging.info("---- Start generating latent features... ----")
+         start_time = time.time()
+         # prepare the segmentation mask condition
+         combine_label = combine_label_or.to(device)
+         if (
+             output_size[0] != combine_label.shape[2]
+             or output_size[1] != combine_label.shape[3]
+             or output_size[2] != combine_label.shape[4]
+         ):
+             logging.info(
+                 "The mask shape does not match output_size; the mask will be interpolated to "
+                 "output_size. The resulting image may be of low quality."
+             )
+             combine_label = torch.nn.functional.interpolate(combine_label, size=output_size, mode="nearest")
+
+         controlnet_cond_vis = binarize_labels(combine_label.as_tensor().long()).half()
+
+         # Generate random noise
+         latents = initialize_noise_latents(latent_shape, device) * noise_factor
+
+         # synthesize latents
+         noise_scheduler.set_timesteps(
+             num_inference_steps=num_inference_steps, input_img_size=torch.prod(torch.tensor(latent_shape[-3:]))
+         )
+         guidance_scale = 0  # API for classifier-free guidance, not used in this version
+         all_next_timesteps = torch.cat(
+             (noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype))
+         )
+         for t, next_t in tqdm(
+             zip(noise_scheduler.timesteps, all_next_timesteps),
+             total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)),
+         ):
+             timesteps = torch.Tensor((t,)).to(device)
+             if guidance_scale == 0:
+                 down_block_res_samples, mid_block_res_sample = controlnet(
+                     x=latents, timesteps=timesteps, controlnet_cond=controlnet_cond_vis, class_labels=modality_tensor
+                 )
+                 predicted_velocity = diffusion_unet(
+                     x=latents,
+                     timesteps=timesteps,
+                     spacing_tensor=spacing_tensor,
+                     class_labels=modality_tensor,
+                     down_block_additional_residuals=down_block_res_samples,
+                     mid_block_additional_residual=mid_block_res_sample,
+                 )
+             else:
+                 down_block_res_samples, mid_block_res_sample = controlnet(
+                     x=torch.cat([latents] * 2),
+                     timesteps=torch.cat([timesteps] * 2),
+                     controlnet_cond=torch.cat([controlnet_cond_vis] * 2),
+                     class_labels=torch.cat([modality_tensor, torch.zeros_like(modality_tensor)]),
+                 )
+                 model_t, model_uncond = diffusion_unet(
+                     x=torch.cat([latents] * 2),
+                     timesteps=torch.cat([timesteps] * 2),
+                     spacing_tensor=torch.cat([spacing_tensor] * 2),
+                     class_labels=torch.cat([modality_tensor, torch.zeros_like(modality_tensor)]),
+                     down_block_additional_residuals=down_block_res_samples,
+                     mid_block_additional_residual=mid_block_res_sample,
+                 ).chunk(2)
+                 predicted_velocity = model_uncond + guidance_scale * (model_t - model_uncond)
+             latents, _ = noise_scheduler.step(predicted_velocity, t, latents, next_timestep=next_t)
+         end_time = time.time()
+         logging.info(f"---- Latent features generation time: {end_time - start_time} seconds ----")
+         del predicted_velocity
+         torch.cuda.empty_cache()
+
+         # decode latents into synthesized images
+         logging.info("---- Start decoding latent features into images... ----")
+         inferer = SlidingWindowInferer(
+             roi_size=autoencoder_sliding_window_infer_size,
+             sw_batch_size=1,
+             progress=True,
+             mode="gaussian",
+             overlap=autoencoder_sliding_window_infer_overlap,
+             device=torch.device("cpu"),
+             sw_device=device,
+         )
+         start_time = time.time()
+         synthetic_images = dynamic_infer(inferer, recon_model, latents)
+         synthetic_images = torch.clip(synthetic_images, b_min, b_max).cpu()
+         end_time = time.time()
+         logging.info(f"---- Image decoding time: {end_time - start_time} seconds ----")
+
+         # post-processing:
+         # project output to [0, 1]
+         synthetic_images = (synthetic_images - b_min) / (b_max - b_min)
+         # project output to [-1000, 1000]
+         synthetic_images = synthetic_images * (a_max - a_min) + a_min
+         # regularize background intensities
+         synthetic_images = crop_img_body_mask(synthetic_images, combine_label)
+         torch.cuda.empty_cache()
+
+     return synthetic_images, combine_label
+
+
+ def filter_mask_with_organs(combine_label, anatomy_list):
+     """
+     Filter a mask to only include specified organs.
+
+     Args:
+         combine_label (torch.Tensor): The input mask.
+         anatomy_list (list): List of organ labels to keep.
+
+     Returns:
+         torch.Tensor: The filtered mask.
+     """
+     # the final output mask file has shape output_size and contains only labels in anatomy_list;
+     # it is already interpolated to the target size
+     combine_label = combine_label.long()
+     # filter out the organs that are not in anatomy_list
+     for i in range(len(anatomy_list)):
+         organ = anatomy_list[i]
+         # replace the kept organ with a negative value so it survives the zero-out step below
+         combine_label[combine_label == organ] = -(i + 1)
+     # zero out voxels whose value is not in anatomy_list
+     combine_label[combine_label > 0] = 0
+     # flip the kept organs back to positive values
+     combine_label = -combine_label
+     return combine_label
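Note that the negative-value trick above both filters and re-indexes: kept organs come back as their 1-based position in `anatomy_list`, not their original label. A tiny illustration:

```python
import torch

mask = torch.tensor([[0, 1, 5, 9]])  # voxels labeled background, 1, 5, 9
out = filter_mask_with_organs(mask.clone(), anatomy_list=[5, 1])
print(out)  # tensor([[0, 2, 1, 0]]): label 5 -> 1, label 1 -> 2, label 9 dropped
```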
+
+
+ def crop_img_body_mask(synthetic_images, combine_label):
+     """
+     Mask the synthetic image with the body mask: voxels outside the body are set to air (-1000 HU).
+
+     Args:
+         synthetic_images (torch.Tensor): The synthetic images.
+         combine_label (torch.Tensor): The body mask.
+
+     Returns:
+         torch.Tensor: The masked synthetic images.
+     """
+     synthetic_images[combine_label == 0] = -1000
+     return synthetic_images
+
+
+ def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing, controllable_anatomy_size):
+     """
+     Validate input parameters for image generation.
+
+     Args:
+         body_region (list): List of body regions.
+         anatomy_list (list): List of anatomical structures.
+         label_dict_json (str): Path to the label dictionary JSON file.
+         output_size (tuple): Desired output size of the image.
+         spacing (tuple): Desired voxel spacing.
+         controllable_anatomy_size (list): List of tuples specifying controllable anatomy sizes.
+
+     Raises:
+         ValueError: If any input parameter is invalid.
+     """
+     # check output_size and spacing format
+     if output_size[0] != output_size[1]:
+         raise ValueError(f"The first two components of output_size must be equal, yet got {output_size}.")
+     if (output_size[0] not in [256, 384, 512]) or (output_size[2] not in [128, 256, 384, 512, 640, 768]):
+         raise ValueError(
+             (
+                 "output_size[0] must be chosen from [256, 384, 512], and output_size[2] "
+                 f"must be chosen from [128, 256, 384, 512, 640, 768], yet got {output_size}."
+             )
+         )
+
+     if spacing[0] != spacing[1]:
+         raise ValueError(f"The first two components of spacing must be equal, yet got {spacing}.")
+     if spacing[0] < 0.5 or spacing[0] > 3.0 or spacing[2] < 0.5 or spacing[2] > 5.0:
+         raise ValueError(
+             f"spacing[0] must be between 0.5 and 3.0 mm, and spacing[2] must be between 0.5 and 5.0 mm, yet got {spacing}."
+         )
+
+     if (
+         output_size[0] * spacing[0] < 256
+         or output_size[2] * spacing[2] < 128
+         or output_size[0] * spacing[0] > 640
+         or output_size[2] * spacing[2] > 2000
+     ):
+         fov = [output_size[axis] * spacing[axis] for axis in range(3)]
+         raise ValueError(
+             (
+                 f"'spacing' ({spacing} mm) and 'output_size' ({output_size}) together determine the output field of view (FOV). "
+                 f"The FOV will be {fov} mm. We recommend the FOV in the x and y axes to be at least 256 mm for head, at least "
+                 "384 mm for other body regions like abdomen, and less than 640 mm. "
+                 "For the z-axis, we require it to be at least 128 mm and less than 2000 mm."
+             )
+         )
+
+     # check controllable_anatomy_size format
+     if len(controllable_anatomy_size) > 10:
+         raise ValueError(
+             f"controllable_anatomy_size can have at most 10 entries, yet got {len(controllable_anatomy_size)}."
+         )
+     available_controllable_organ = ["liver", "gallbladder", "stomach", "pancreas", "colon"]
+     available_controllable_tumor = [
+         "hepatic tumor",
+         "bone lesion",
+         "lung tumor",
+         "colon cancer primaries",
+         "pancreatic tumor",
+     ]
+     available_controllable_anatomy = available_controllable_organ + available_controllable_tumor
+     controllable_tumor = []
+     controllable_organ = []
+     for controllable_anatomy_size_pair in controllable_anatomy_size:
+         if controllable_anatomy_size_pair[0] not in available_controllable_anatomy:
+             raise ValueError(
+                 (
+                     f"The controllable anatomy must be chosen from {available_controllable_anatomy}, "
+                     f"yet got {controllable_anatomy_size_pair[0]}."
+                 )
+             )
+         if controllable_anatomy_size_pair[0] in available_controllable_tumor:
+             controllable_tumor += [controllable_anatomy_size_pair[0]]
+         if controllable_anatomy_size_pair[0] in available_controllable_organ:
+             controllable_organ += [controllable_anatomy_size_pair[0]]
+         if controllable_anatomy_size_pair[1] == -1:
+             continue
+         if controllable_anatomy_size_pair[1] < 0 or controllable_anatomy_size_pair[1] > 1.0:
+             raise ValueError(
+                 (
+                     "The controllable size scale must be between 0 and 1.0, or equal to -1, "
+                     f"yet got {controllable_anatomy_size_pair[1]}."
+                 )
+             )
+     if len(controllable_tumor + controllable_organ) != len(list(set(controllable_tumor + controllable_organ))):
+         raise ValueError(f"Please do not repeat controllable_anatomy. Got {controllable_tumor + controllable_organ}.")
+     if len(controllable_tumor) > 1:
+         raise ValueError(f"Only one controllable tumor is supported, yet got {controllable_tumor}.")
+
+     if len(controllable_anatomy_size) > 0:
+         logging.info(
+             (
+                 "`controllable_anatomy_size` is not empty.\nWe will ignore `body_region` and `anatomy_list` "
+                 f"and synthesize based on `controllable_anatomy_size`: ({controllable_anatomy_size})."
+             )
+         )
+     else:
+         logging.info(
+             (f"`controllable_anatomy_size` is empty.\nWe will synthesize based on `anatomy_list`: ({anatomy_list}).")
+         )
+     # check body_region format
+     available_body_region = ["head", "chest", "thorax", "abdomen", "pelvis", "lower"]
+     for region in body_region:
+         if region not in available_body_region:
+             raise ValueError(
+                 f"The components in body_region must be chosen from {available_body_region}, yet got {region}."
+             )
+
+     # check anatomy_list format
+     with open(label_dict_json) as f:
+         label_dict = json.load(f)
+     for anatomy in anatomy_list:
+         if anatomy not in label_dict.keys():
+             raise ValueError(
+                 f"The components in anatomy_list must be chosen from {label_dict.keys()}, yet got {anatomy}."
+             )
+     logging.info(f"The generated results will have voxel size {spacing} mm and volume size {output_size}.")
+
+     return
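A valid call, for reference (label names must exist in the bundle's label dictionary; the JSON path is assumed to be the config shipped with the bundle):

```python
check_input(
    body_region=["abdomen"],
    anatomy_list=["liver", "spleen"],
    label_dict_json="./configs/label_dict.json",  # assumed config path
    output_size=(512, 512, 512),
    spacing=(1.0, 1.0, 1.0),
    controllable_anatomy_size=[],
)
```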
+
+
+ class LDMSampler:
+     """
+     A sampler class for generating synthetic medical images and masks using latent diffusion models.
+
+     Attributes:
+         Various attributes related to model configuration, input parameters, and generation settings.
+     """
+
+     def __init__(
+         self,
+         body_region,
+         anatomy_list,
+         modality,
+         all_mask_files_json,
+         all_anatomy_size_condtions_json,
+         all_mask_files_base_dir,
+         label_dict_json,
+         label_dict_remap_json,
+         autoencoder,
+         diffusion_unet,
+         controlnet,
+         noise_scheduler,
+         scale_factor,
+         mask_generation_autoencoder,
+         mask_generation_diffusion_unet,
+         mask_generation_scale_factor,
+         mask_generation_noise_scheduler,
+         device,
+         latent_shape,
+         mask_generation_latent_shape,
+         output_size,
+         output_dir,
+         controllable_anatomy_size,
+         image_output_ext=".nii.gz",
+         label_output_ext=".nii.gz",
+         real_img_median_statistics="./configs/image_median_statistics.json",
+         spacing=(1, 1, 1),
+         num_inference_steps=None,
+         mask_generation_num_inference_steps=None,
+         random_seed=None,
+         autoencoder_sliding_window_infer_size=(96, 96, 96),
+         autoencoder_sliding_window_infer_overlap=0.6667,
+     ) -> None:
+         """
+         Initialize the LDMSampler with various parameters and models.
+
+         Args:
+             Various parameters related to model configuration, input settings, and output specifications.
+         """
+         self.random_seed = random_seed
+         if random_seed is not None:
+             set_determinism(seed=random_seed)
+
+         with open(label_dict_json, "r") as f:
+             label_dict = json.load(f)
+         self.all_anatomy_size_condtions_json = all_anatomy_size_condtions_json
+
+         # initialize variables
+         self.body_region = body_region
+         self.anatomy_list = [label_dict[organ] for organ in anatomy_list]
+         self.modality_int = modality_mapping[modality]
+         self.all_mask_files_json = all_mask_files_json
+         self.data_root = all_mask_files_base_dir
+         self.label_dict_remap_json = label_dict_remap_json
+         self.autoencoder = autoencoder
+         self.diffusion_unet = diffusion_unet
+         self.controlnet = controlnet
+         self.noise_scheduler = noise_scheduler
+         self.scale_factor = scale_factor
+         self.mask_generation_autoencoder = mask_generation_autoencoder
+         self.mask_generation_diffusion_unet = mask_generation_diffusion_unet
+         self.mask_generation_scale_factor = mask_generation_scale_factor
+         self.mask_generation_noise_scheduler = mask_generation_noise_scheduler
+         self.device = device
+         self.latent_shape = latent_shape
+         self.mask_generation_latent_shape = mask_generation_latent_shape
+         self.output_size = output_size
+         self.output_dir = output_dir
+         self.noise_factor = 1.0
+         self.controllable_anatomy_size = controllable_anatomy_size
+         if len(self.controllable_anatomy_size):
+             logging.info("controllable_anatomy_size is given, mask generation is triggered!")
+             # overwrite the anatomy_list with the organs given in self.controllable_anatomy_size
+             self.anatomy_list = [label_dict[organ_and_size[0]] for organ_and_size in self.controllable_anatomy_size]
+         self.image_output_ext = image_output_ext
+         self.label_output_ext = label_output_ext
+         # Set the default number of inference steps to 1000
+         self.num_inference_steps = num_inference_steps if num_inference_steps is not None else 1000
+         self.mask_generation_num_inference_steps = (
+             mask_generation_num_inference_steps if mask_generation_num_inference_steps is not None else 1000
+         )
+
+         if any(size % 16 != 0 for size in autoencoder_sliding_window_infer_size):
+             raise ValueError(
+                 f"autoencoder_sliding_window_infer_size must be divisible by 16.\n Got {autoencoder_sliding_window_infer_size}"
+             )
+         if not (0 <= autoencoder_sliding_window_infer_overlap <= 1):
+             raise ValueError(
+                 (
+                     "Value of autoencoder_sliding_window_infer_overlap must be between 0 "
+                     f"and 1.\n Got {autoencoder_sliding_window_infer_overlap}"
+                 )
+             )
+         self.autoencoder_sliding_window_infer_size = autoencoder_sliding_window_infer_size
+         self.autoencoder_sliding_window_infer_overlap = autoencoder_sliding_window_infer_overlap
+
+         # quality check args
+         self.max_try_time = 3  # if the quality check fails, retry up to self.max_try_time times
+         with open(real_img_median_statistics, "r") as json_file:
+             self.median_statistics = json.load(json_file)
+         self.label_int_dict = {
+             "liver": [1],
+             "spleen": [3],
+             "pancreas": [4],
+             "kidney": [5, 14],
+             "lung": [28, 29, 30, 31, 32],
+             "brain": [22],
+             "hepatic tumor": [26],
+             "bone lesion": [128],
+             "lung tumor": [23],
+             "colon cancer primaries": [27],
+             "pancreatic tumor": [24],
+             "bone": list(range(33, 57)) + list(range(63, 98)) + [120, 122, 127],
+         }
+
+         # networks
+         self.autoencoder.eval()
+         self.diffusion_unet.eval()
+         self.controlnet.eval()
+         self.mask_generation_autoencoder.eval()
+         self.mask_generation_diffusion_unet.eval()
+
+         self.spacing = spacing
+
+         self.val_transforms = Compose(
+             [
+                 monai.transforms.LoadImaged(keys=["pseudo_label"]),
+                 monai.transforms.EnsureChannelFirstd(keys=["pseudo_label"]),
+                 monai.transforms.Orientationd(keys=["pseudo_label"], axcodes="RAS"),
+                 monai.transforms.EnsureTyped(keys=["pseudo_label"], dtype=torch.uint8),
+                 monai.transforms.Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)),
+                 monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2),
+             ]
+         )
+         logging.info("LDM sampler initialized.")
+
+     def sample_multiple_images(self, num_img):
+         """
+         Generate multiple synthetic images and masks.
+
+         Args:
+             num_img (int): Number of images to generate.
+         """
+         output_filenames = []
+         if len(self.controllable_anatomy_size) > 0:
+             # we will use mask generation instead of finding candidate masks
+             # create a dummy selected_mask_files as placeholder
+             selected_mask_files = list(range(num_img))
+             # prepare organ size conditions
+             anatomy_size_condtion = self.prepare_anatomy_size_condtion(self.controllable_anatomy_size)
+         else:
+             need_resample = False
+             # find candidate masks and save them to candidate_mask_files
+             candidate_mask_files = find_masks(
+                 self.anatomy_list, self.spacing, self.output_size, True, self.all_mask_files_json, self.data_root
+             )
+             if len(candidate_mask_files) < num_img:
+                 # if we cannot find enough masks based on an exact match of anatomy list, spacing, and output size,
+                 # we will try to find the closest masks in terms of spacing and output size.
+                 logging.info("Resample mask file to get desired output size and spacing")
+                 candidate_mask_files = self.find_closest_masks(num_img)
+                 need_resample = True
+
+             selected_mask_files = self.select_mask(candidate_mask_files, num_img)
+             if len(selected_mask_files) < num_img:
+                 raise ValueError(
+                     (
+                         f"len(selected_mask_files) ({len(selected_mask_files)}) < num_img ({num_img}). "
+                         "This should not happen. Please revisit function select_mask(self, candidate_mask_files, num_img)."
+                     )
+                 )
+         num_generated_img = 0
+         for index_s in range(len(selected_mask_files)):
+             item = selected_mask_files[index_s]
+             if num_generated_img >= num_img:
+                 break
+             logging.info("---- Start preparing masks... ----")
+             start_time = time.time()
+             logging.info(f"Image will be generated based on {item}.")
+             if len(self.controllable_anatomy_size) > 0:
+                 # generate a synthetic mask
+                 (combine_label_or, spacing_tensor) = self.prepare_one_mask_and_meta_info(anatomy_size_condtion)
+             else:
+                 # read in the mask file
+                 mask_file = item["mask_file"]
+                 if_aug = item["if_aug"]
+                 (combine_label_or, spacing_tensor) = self.read_mask_information(mask_file)
+                 if need_resample:
+                     combine_label_or = self.ensure_output_size_and_spacing(combine_label_or)
+                 # mask augmentation
+                 if if_aug:
+                     combine_label_or = augmentation(combine_label_or, self.output_size, random_seed=self.random_seed)
+             end_time = time.time()
+             logging.info(f"---- Mask preparation time: {end_time - start_time} seconds ----")
+             torch.cuda.empty_cache()
+             # generate image/label pairs
+             modality_tensor = torch.ones_like(spacing_tensor[:, 0]).long() * self.modality_int
+             # start generation
+             synthetic_images, synthetic_labels = self.sample_one_pair(combine_label_or, modality_tensor, spacing_tensor)
+             # synthetic image quality check
+             pass_quality_check = self.quality_check(
+                 synthetic_images.cpu().detach().numpy(), combine_label_or.cpu().detach().numpy()
+             )
+             if pass_quality_check or (num_img - num_generated_img) >= (len(selected_mask_files) - index_s):
+                 if not pass_quality_check:
+                     logging.info(
+                         "Generated image/label pair did not pass the quality check, but will still be saved. "
+                         "Please consider changing spacing and output_size to facilitate a more realistic setting."
+                     )
+                 num_generated_img = num_generated_img + 1
+                 # save image/label pairs
+                 output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+                 synthetic_labels.meta["filename_or_obj"] = "sample.nii.gz"
+                 synthetic_images = MetaTensor(synthetic_images, meta=synthetic_labels.meta)
+                 img_saver = SaveImage(
+                     output_dir=self.output_dir,
+                     output_postfix=output_postfix + "_image",
+                     output_ext=self.image_output_ext,
+                     separate_folder=False,
+                 )
+                 img_saver(synthetic_images[0])
+                 synthetic_images_filename = os.path.join(
+                     self.output_dir, "sample_" + output_postfix + "_image" + self.image_output_ext
+                 )
+                 # filter out the organs that are not in anatomy_list
+                 synthetic_labels = filter_mask_with_organs(synthetic_labels, self.anatomy_list)
+                 label_saver = SaveImage(
+                     output_dir=self.output_dir,
+                     output_postfix=output_postfix + "_label",
+                     output_ext=self.label_output_ext,
+                     separate_folder=False,
+                 )
+                 label_saver(synthetic_labels[0])
+                 synthetic_labels_filename = os.path.join(
+                     self.output_dir, "sample_" + output_postfix + "_label" + self.label_output_ext
+                 )
+                 output_filenames.append([synthetic_images_filename, synthetic_labels_filename])
+             else:
+                 logging.info("Generated image/label pair did not pass the quality check; re-generating another pair.")
+         return output_filenames
+
+     def select_mask(self, candidate_mask_files, num_img):
+         """
+         Select mask files for image generation.
+
+         Args:
+             candidate_mask_files (list): List of candidate mask files.
+             num_img (int): Number of images to generate.
+
+         Returns:
+             list: Selected mask files with augmentation flags.
+         """
+         selected_mask_files = []
+         random.shuffle(candidate_mask_files)
+
+         # oversample by max_try_time so that failed quality checks can fall back to other masks
+         for n in range(num_img * self.max_try_time):
+             mask_file = candidate_mask_files[n % len(candidate_mask_files)]
+             selected_mask_files.append({"mask_file": mask_file, "if_aug": True})
+         return selected_mask_files
+
+     def sample_one_pair(self, combine_label_or_aug, modality_tensor, spacing_tensor):
+         """
+         Generate a single pair of synthetic image and mask.
+
+         Args:
+             combine_label_or_aug (torch.Tensor): Combined label tensor, possibly augmented.
+             modality_tensor (torch.Tensor): Tensor specifying the image modality.
+             spacing_tensor (torch.Tensor): Tensor specifying the spacing.
+
+         Returns:
+             tuple: A tuple containing the synthetic image and its corresponding label.
+         """
+         # generate image/label pairs
+         synthetic_images, synthetic_labels = ldm_conditional_sample_one_image(
+             autoencoder=self.autoencoder,
+             diffusion_unet=self.diffusion_unet,
+             controlnet=self.controlnet,
+             noise_scheduler=self.noise_scheduler,
+             scale_factor=self.scale_factor,
+             device=self.device,
+             combine_label_or=combine_label_or_aug,
+             modality_tensor=modality_tensor,
+             spacing_tensor=spacing_tensor,
+             latent_shape=self.latent_shape,
+             output_size=self.output_size,
+             noise_factor=self.noise_factor,
+             num_inference_steps=self.num_inference_steps,
+             autoencoder_sliding_window_infer_size=self.autoencoder_sliding_window_infer_size,
+             autoencoder_sliding_window_infer_overlap=self.autoencoder_sliding_window_infer_overlap,
+         )
+         return synthetic_images, synthetic_labels
+
+     def prepare_anatomy_size_condtion(self, controllable_anatomy_size):
+         """
+         Prepare anatomy size conditions for mask generation.
+
+         Args:
+             controllable_anatomy_size (list): List of tuples specifying controllable anatomy sizes.
+
+         Returns:
+             list: Prepared anatomy size conditions.
+         """
+         anatomy_size_idx = {
+             "gallbladder": 0,
+             "liver": 1,
+             "stomach": 2,
+             "pancreas": 3,
+             "colon": 4,
+             "lung tumor": 5,
+             "pancreatic tumor": 6,
+             "hepatic tumor": 7,
+             "colon cancer primaries": 8,
+             "bone lesion": 9,
+         }
+         provide_anatomy_size = [None for _ in range(10)]
+         logging.info(f"controllable_anatomy_size: {controllable_anatomy_size}")
+         for element in controllable_anatomy_size:
+             anatomy_name, anatomy_size = element
+             provide_anatomy_size[anatomy_size_idx[anatomy_name]] = anatomy_size
+
+         with open(self.all_anatomy_size_condtions_json, "r") as f:
+             all_anatomy_size_condtions = json.load(f)
+
+         # loop through the database and find the closest combinations
+         candidate_list = []
+         for anatomy_size in all_anatomy_size_condtions:
+             size = anatomy_size["organ_size"]
+             diff = 0
+             for db_size, provide_size in zip(size, provide_anatomy_size):
+                 if provide_size is None:
+                     continue
+                 diff += abs(provide_size - db_size)
+             candidate_list.append((size, diff))
+         candidate_condition = sorted(candidate_list, key=lambda x: x[1])[0][0]
+
+         # overwrite the anatomy sizes provided by the user
+         for element in controllable_anatomy_size:
+             anatomy_name, anatomy_size = element
+             candidate_condition[anatomy_size_idx[anatomy_name]] = anatomy_size
+
+         return candidate_condition
+
+     def prepare_one_mask_and_meta_info(self, anatomy_size_condtion):
+         """
+         Prepare a single mask and its associated meta information.
+
+         Args:
+             anatomy_size_condtion (list): Anatomy size conditions.
+
+         Returns:
+             tuple: A tuple containing the prepared mask and associated tensors.
+         """
+         combine_label_or = self.sample_one_mask(anatomy_size=anatomy_size_condtion)
+         # TODO: the current mask generation model can only generate 256^3 volumes with 1.5 mm spacing.
+         affine = torch.zeros((4, 4))
+         affine[0, 0] = 1.5
+         affine[1, 1] = 1.5
+         affine[2, 2] = 1.5
+         affine[3, 3] = 1.0  # dummy
+         combine_label_or = MetaTensor(combine_label_or, affine=affine)
+         combine_label_or = self.ensure_output_size_and_spacing(combine_label_or)
+
+         spacing_tensor = torch.FloatTensor(self.spacing).unsqueeze(0).half().to(self.device) * 1e2
+
+         return combine_label_or, spacing_tensor
+
+     def sample_one_mask(self, anatomy_size):
+         """
+         Generate a single synthetic mask.
+
+         Args:
+             anatomy_size (list): Anatomy size specifications.
+
+         Returns:
+             torch.Tensor: The generated synthetic mask.
+         """
+         # generate one synthetic mask
+         synthetic_mask = ldm_conditional_sample_one_mask(
+             self.mask_generation_autoencoder,
+             self.mask_generation_diffusion_unet,
+             self.mask_generation_noise_scheduler,
+             self.mask_generation_scale_factor,
+             anatomy_size,
+             self.device,
+             self.mask_generation_latent_shape,
+             label_dict_remap_json=self.label_dict_remap_json,
+             num_inference_steps=self.mask_generation_num_inference_steps,
+             autoencoder_sliding_window_infer_size=self.autoencoder_sliding_window_infer_size,
+             autoencoder_sliding_window_infer_overlap=self.autoencoder_sliding_window_infer_overlap,
+         )
+         return synthetic_mask
+
+     def ensure_output_size_and_spacing(self, labels, check_contains_target_labels=True):
+         """
+         Ensure the output mask has the correct size and spacing.
+
+         Args:
+             labels (torch.Tensor): Input label tensor.
+             check_contains_target_labels (bool): Whether to check if the resampled mask contains target labels.
+
+         Returns:
+             torch.Tensor: Resampled label tensor.
+
+         Raises:
+             ValueError: If the resampled mask does not contain the required class labels.
+         """
+         current_spacing = [labels.affine[0, 0], labels.affine[1, 1], labels.affine[2, 2]]
+         current_shape = list(labels.squeeze().shape)
+
+         need_resample = False
+         # check spacing
+         for i, j in zip(current_spacing, self.spacing):
+             if i != j:
+                 need_resample = True
+         # check output size
+         for i, j in zip(current_shape, self.output_size):
+             if i != j:
+                 need_resample = True
+         # resample to the target size and spacing
+         if need_resample:
+             logging.info("Resampling mask to the target shape and spacing")
+             logging.info(f"Resize spacing: {current_spacing} -> {self.spacing}")
+             logging.info(f"Output size: {current_shape} -> {self.output_size}")
+             spacing = monai.transforms.Spacing(pixdim=tuple(self.spacing), mode="nearest")
+             pad_crop = monai.transforms.ResizeWithPadOrCrop(spatial_size=tuple(self.output_size))
+             labels = pad_crop(spacing(labels.squeeze(0))).unsqueeze(0).to(labels.dtype)
+
+         contained_labels = torch.unique(labels)
+         if check_contains_target_labels:
+             # check whether the resampled mask still contains the target labels
+             for anatomy_label in self.anatomy_list:
+                 if anatomy_label not in contained_labels:
+                     raise ValueError(
+                         (
+                             f"Resampled mask does not contain required class labels {anatomy_label}. "
+                             "Please consider increasing the output spacing or specifying a larger output size."
+                         )
+                     )
+         return labels
+
+     def read_mask_information(self, mask_file):
+         """
+         Read mask information from a file.
+
+         Args:
+             mask_file (str): Path to the mask file.
+
+         Returns:
+             tuple: A tuple containing the mask tensor and associated information.
+         """
+         val_data = self.val_transforms(mask_file)
+
+         for key in ["pseudo_label", "spacing"]:
+             val_data[key] = val_data[key].unsqueeze(0).to(self.device)
+
+         return (val_data["pseudo_label"], val_data["spacing"])
+
+     def find_closest_masks(self, num_img):
+         """
+         Find the closest matching masks from the database.
+
+         Args:
+             num_img (int): Number of images to generate.
+
+         Returns:
+             list: List of closest matching mask candidates.
+
+         Raises:
+             ValueError: If suitable candidates cannot be found.
+         """
+         # first query the database based on the anatomy list
+         candidates = find_masks(
+             self.anatomy_list, self.spacing, self.output_size, False, self.all_mask_files_json, self.data_root
+         )
+
+         if len(candidates) < num_img:
+             raise ValueError(f"The number of candidate masks is less than {num_img}.")
+
+         # loop through the database and find the closest combinations
+         new_candidates = []
+         for c in candidates:
+             diff = 0
+             include_c = True
+             for axis in range(3):
+                 if abs(c["dim"][axis]) < self.output_size[axis] - 64:
+                     # we cannot upsample the mask too much
+                     include_c = False
+                     break
+                 # difference in FOV, the major metric
+                 diff += abs(
+                     (abs(c["dim"][axis] * c["spacing"][axis]) - self.output_size[axis] * self.spacing[axis]) / 10
+                 )
+                 # difference in dim
+                 diff += abs((abs(c["dim"][axis]) - self.output_size[axis]) / 100)
+                 # difference in spacing
+                 diff += abs(abs(c["spacing"][axis]) - self.spacing[axis])
+             if include_c:
+                 new_candidates.append((c, diff))
+
+         # choose the top 2*num_img candidates (at least 5)
+         new_candidates = sorted(new_candidates, key=lambda x: x[1])[: max(2 * num_img, 5)]
+         final_candidates = []
+
+         # check the top candidates and update spacing after resampling
+         image_loader = monai.transforms.LoadImage(image_only=True, ensure_channel_first=True)
+         for c, _ in new_candidates:
+             label = image_loader(c["pseudo_label"])
+             try:
+                 label = self.ensure_output_size_and_spacing(label.unsqueeze(0))
+             except ValueError as e:
+                 if "Resampled mask does not contain required class labels" in str(e):
+                     continue
+                 else:
+                     raise e
+             # update meta information after resampling
+             c["spacing"] = self.spacing
+             c["dim"] = self.output_size
+
+             final_candidates.append(c)
+         if len(final_candidates) == 0:
+             raise ValueError("Cannot find body region with given anatomy list.")
+         return final_candidates
+
+     def quality_check(self, image_data, label_data):
+         """
+         Perform a quality check on the generated image.
+
+         Args:
+             image_data (np.ndarray): The generated image.
+             label_data (np.ndarray): The corresponding whole-body mask.
+
+         Returns:
+             bool: True if the image passes the quality check, False otherwise.
+         """
+         outlier_results = is_outlier(self.median_statistics, image_data, label_data, self.label_int_dict)
+         for label, result in outlier_results.items():
+             if result.get("is_outlier", False):
+                 logging.info(
+                     (
+                         f"Generated image quality check for label '{label}' failed: median value {result['median_value']} "
+                         f"is outside the acceptable range ({result['low_thresh']} - {result['high_thresh']})."
+                     )
+                 )
+                 return False
+         return True
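For reference, how the check is consumed after a generation step (a sketch; `sampler` is an initialized `LDMSampler`, and `synthetic_images`/`combine_label_or` come from `sample_one_pair` as in `sample_multiple_images` above):

```python
ok = sampler.quality_check(
    synthetic_images.cpu().detach().numpy(),
    combine_label_or.cpu().detach().numpy(),
)
if not ok:
    print("median HU outside the precomputed per-organ range; consider re-generating")
```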
scripts/trainer.py ADDED
@@ -0,0 +1,246 @@
1
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
+
+ import torch
+ import torch.nn.functional as F
+ from monai.engines.trainer import Trainer
+ from monai.engines.utils import IterationEvents, PrepareBatchExtraInput, default_metric_cmp_fn
+ from monai.inferers import Inferer
+ from monai.networks.schedulers import Scheduler
+ from monai.transforms import Transform
+ from monai.utils import IgniteInfo, RankFilter, min_version, optional_import
+ from monai.utils.enums import CommonKeys as Keys
+ from torch.optim.optimizer import Optimizer
+ from torch.utils.data import DataLoader
+
+ from .utils import binarize_labels
+
+ if TYPE_CHECKING:
+     from ignite.engine import Engine, EventEnum
+     from ignite.metrics import Metric
+ else:
+     Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+     Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
+     EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
+
+ __all__ = ["MAISIControlNetTrainer"]
+
+ # Module-level variable for prepare_batch default value
+ DEFAULT_PREPARE_BATCH = PrepareBatchExtraInput(extra_keys=("dim", "spacing", "top_region_index", "bottom_region_index"))
+
+
+ class MAISIControlNetTrainer(Trainer):
+     """
+     Supervised training method with image and label, inherits from ``Trainer`` and ``Workflow``.
+
+     Args:
+         device: an object representing the device on which to run.
+         max_epochs: the total epoch number for trainer to run.
+         train_data_loader: Ignite engine uses data_loader to run, must be Iterable or torch.DataLoader.
+         controlnet: controlnet to train in the trainer, should be a regular PyTorch `torch.nn.Module`.
+         diffusion_unet: diffusion_unet used in the trainer, should be a regular PyTorch `torch.nn.Module`.
+         optimizer: the optimizer associated with the controlnet, should be a regular PyTorch optimizer
+             from `torch.optim` or its subclass.
+         epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
+         non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
+             with respect to the host. For other cases, this argument has no effect.
+         prepare_batch: function to parse expected data (usually `image`, `label` and other model args)
+             from `engine.state.batch` for every iteration, for more details please refer to:
+             https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
+         iteration_update: the callable function for every iteration, expected to accept `engine`
+             and `engine.state.batch` as inputs; the returned data will be stored in `engine.state.output`.
+             If not provided, use `self._iteration()` instead. For more details please refer to:
+             https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
+         inferer: inference method that executes the model forward pass on input data, like: SlidingWindow, etc.
+         postprocessing: execute additional transformation for the model output data.
+             Typically, several Tensor based transforms composed by `Compose`.
+         key_train_metric: compute metric when every iteration completed, and save average value to
+             engine.state.metrics when epoch completed. key_train_metric is the main metric to compare
+             and save the checkpoint into files.
+         additional_metrics: more Ignite metrics that also attach to Ignite Engine.
+         metric_cmp_fn: function to compare current key metric with previous best key metric value,
+             it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
+             `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
+         train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
+             CheckpointHandler, StatsHandler, etc.
+         amp: whether to enable auto-mixed-precision training, default is False.
+         event_names: additional custom ignite events that will register to the engine.
+             New events can be a list of str or `ignite.engine.events.EventEnum`.
+         event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
+             For more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
+             #ignite.engine.engine.Engine.register_events.
+         decollate: whether to decollate the batch-first data to a list of data after model computation,
+             recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
+             Default to `True`.
+         optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
+             More details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
+         to_kwargs: dict of other args for the `prepare_batch` API when converting the input data, except for
+             `device`, `non_blocking`.
+         amp_kwargs: dict of the args for the `torch.cuda.amp.autocast()` API, for more details:
+             https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+     """
+
+     def __init__(
+         self,
+         device: torch.device,
+         max_epochs: int,
+         train_data_loader: Iterable | DataLoader,
+         controlnet: torch.nn.Module,
+         diffusion_unet: torch.nn.Module,
+         optimizer: Optimizer,
+         loss_function: Callable,
+         inferer: Inferer,
+         noise_scheduler: Scheduler,
+         epoch_length: int | None = None,
+         non_blocking: bool = False,
+         prepare_batch: Callable = DEFAULT_PREPARE_BATCH,
+         iteration_update: Callable[[Engine, Any], Any] | None = None,
+         postprocessing: Transform | None = None,
+         key_train_metric: dict[str, Metric] | None = None,
+         additional_metrics: dict[str, Metric] | None = None,
+         metric_cmp_fn: Callable = default_metric_cmp_fn,
+         train_handlers: Sequence | None = None,
+         amp: bool = False,
+         event_names: list[str | EventEnum] | None = None,
+         event_to_attr: dict | None = None,
+         decollate: bool = True,
+         optim_set_to_none: bool = False,
+         to_kwargs: dict | None = None,
+         amp_kwargs: dict | None = None,
+         hyper_kwargs: dict | None = None,
+     ) -> None:
+         super().__init__(
+             device=device,
+             max_epochs=max_epochs,
+             data_loader=train_data_loader,
+             epoch_length=epoch_length,
+             non_blocking=non_blocking,
+             prepare_batch=prepare_batch,
+             iteration_update=iteration_update,
+             postprocessing=postprocessing,
+             key_metric=key_train_metric,
+             additional_metrics=additional_metrics,
+             metric_cmp_fn=metric_cmp_fn,
+             handlers=train_handlers,
+             amp=amp,
+             event_names=event_names,
+             event_to_attr=event_to_attr,
+             decollate=decollate,
+             to_kwargs=to_kwargs,
+             amp_kwargs=amp_kwargs,
+         )
+
+         self.controlnet = controlnet
+         self.diffusion_unet = diffusion_unet
+         self.optimizer = optimizer
+         self.loss_function = loss_function
+         self.inferer = inferer
+         self.optim_set_to_none = optim_set_to_none
+         self.hyper_kwargs = hyper_kwargs
+         self.noise_scheduler = noise_scheduler
+         self.logger.addFilter(RankFilter())
+         for p in self.diffusion_unet.parameters():
+             p.requires_grad = False
+         print("freeze the parameters of the diffusion unet model.")
+
+     def _iteration(self, engine, batchdata: dict[str, torch.Tensor]):
+         """
+         Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
+         Return below items in a dictionary:
+             - IMAGE: image Tensor data for model input, already moved to device.
+         Args:
+             engine: `MAISIControlNetTrainer` to execute operation for an iteration.
+             batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
+         Raises:
+             ValueError: When ``batchdata`` is None.
+         """
+
+         if batchdata is None:
+             raise ValueError("Must provide batch data for current iteration.")
+
+         inputs, labels, (dim, spacing, top_region_index, bottom_region_index), _ = engine.prepare_batch(
+             batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs
+         )
+         engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: labels}
+         weighted_loss_label = engine.hyper_kwargs["weighted_loss_label"]
+         weighted_loss = engine.hyper_kwargs["weighted_loss"]
+         scale_factor = engine.hyper_kwargs["scale_factor"]
+         # scale image embedding by the provided scale_factor
+         inputs = inputs * scale_factor
+
+         def _compute_pred_loss():
+             # generate random noise
+             noise_shape = list(inputs.shape)
+             noise = torch.randn(noise_shape, dtype=inputs.dtype).to(inputs.device)
+
+             # use binary encoding to encode segmentation mask
+             controlnet_cond = binarize_labels(labels.as_tensor().to(torch.uint8)).float()
+
+             # create timesteps
+             timesteps = torch.randint(
+                 0, engine.noise_scheduler.num_train_timesteps, (inputs.shape[0],), device=inputs.device
+             ).long()
+
+             # create noisy latent
+             noisy_latent = engine.noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+
+             # get controlnet output
+             down_block_res_samples, mid_block_res_sample = engine.controlnet(
+                 x=noisy_latent, timesteps=timesteps, controlnet_cond=controlnet_cond
+             )
+             noise_pred = engine.diffusion_unet(
+                 x=noisy_latent,
+                 timesteps=timesteps,
+                 top_region_index_tensor=top_region_index,
+                 bottom_region_index_tensor=bottom_region_index,
+                 spacing_tensor=spacing,
+                 down_block_additional_residuals=down_block_res_samples,
+                 mid_block_additional_residual=mid_block_res_sample,
+             )
+
+             engine.state.output[Keys.PRED] = noise_pred
+             engine.fire_event(IterationEvents.FORWARD_COMPLETED)
+
+             if weighted_loss > 1.0:
+                 weights = torch.ones_like(inputs).to(inputs.device)
+                 roi = torch.zeros([noise_shape[0]] + [1] + noise_shape[2:]).to(inputs.device)
+                 interpolate_label = F.interpolate(labels, size=inputs.shape[2:], mode="nearest")
+                 # assign larger weights for ROI (tumor)
+                 for label in weighted_loss_label:
+                     roi[interpolate_label == label] = 1
+                 weights[roi.repeat(1, inputs.shape[1], 1, 1, 1) == 1] = weighted_loss
+                 loss = (F.l1_loss(noise_pred.float(), noise.float(), reduction="none") * weights).mean()
+             else:
+                 loss = F.l1_loss(noise_pred.float(), noise.float())
+
+             engine.state.output[Keys.LOSS] = loss
+             engine.fire_event(IterationEvents.LOSS_COMPLETED)
+
+         engine.controlnet.train()
+         engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
+
+         if engine.amp and engine.scaler is not None:
+             with torch.amp.autocast("cuda", **engine.amp_kwargs):
+                 _compute_pred_loss()
+             engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
+             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
+             engine.scaler.step(engine.optimizer)
+             engine.scaler.update()
+         else:
+             _compute_pred_loss()
+             engine.state.output[Keys.LOSS].backward()
+             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
+             engine.optimizer.step()
+         engine.fire_event(IterationEvents.MODEL_COMPLETED)
+         return engine.state.output
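The `_iteration` step above reads three keys from `hyper_kwargs`; a sketch of a plausible configuration follows (the values are illustrative, not the bundle defaults):

# hyper_kwargs consumed by MAISIControlNetTrainer._iteration (illustrative values):
hyper_kwargs = {
    "weighted_loss": 100.0,        # > 1.0 switches on the ROI-weighted L1 loss branch
    "weighted_loss_label": [129],  # label values up-weighted in the loss (e.g. a tumor label)
    "scale_factor": 1.0,           # scaling applied to the latent image embedding
}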
scripts/utils.py ADDED
@@ -0,0 +1,696 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import copy
+ import json
+ import math
+ import os
+ import zipfile
+ from argparse import Namespace
+ from datetime import timedelta
+ from typing import Any, Sequence
+
+ import numpy as np
+ import skimage.measure
+ import torch
+ import torch.distributed as dist
+ from monai.bundle import ConfigParser
+ from monai.config import DtypeLike, NdarrayOrTensor
+ from monai.data import CacheDataset, DataLoader, partition_dataset
+ from monai.transforms import Compose, EnsureTyped, Lambdad, LoadImaged, Orientationd
+ from monai.transforms.utils_morphological_ops import dilate, erode
+ from monai.utils import TransformBackends, convert_data_type, convert_to_dst_type, get_equivalent_dtype
+ from scipy import stats
+ from torch import Tensor
+
+
+ def unzip_dataset(dataset_dir):
+     if dist.is_available() and dist.is_initialized():
+         rank = dist.get_rank()
+     else:
+         rank = 0
+
+     if rank == 0:
+         if not os.path.exists(dataset_dir):
+             zip_file_path = dataset_dir + ".zip"
+             if not os.path.isfile(zip_file_path):
+                 raise ValueError(f"Please download {zip_file_path}.")
+             with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
+                 zip_ref.extractall(path=os.path.dirname(dataset_dir))
+             print(f"Unzipped {zip_file_path} to {dataset_dir}.")
+
+     if dist.is_available() and dist.is_initialized():
+         dist.barrier()  # synchronize all processes
+
+     return
+
+
+ def add_data_dir2path(list_files: list, data_dir: str, fold: int | None = None) -> tuple[list, list]:
+     """
+     Read a list of data dictionaries and prepend `data_dir` to the image/label paths.
+
+     Args:
+         list_files (list): input data to load and transform to generate dataset for model.
+         data_dir (str): directory of files.
+         fold (int, optional): fold index for cross validation. Defaults to None.
+
+     Returns:
+         tuple[list, list]: A tuple of two lists (training, validation).
+     """
+     new_list_files = copy.deepcopy(list_files)
+     if fold is not None:
+         new_list_files_train = []
+         new_list_files_val = []
+     for d in new_list_files:
+         d["image"] = os.path.join(data_dir, d["image"])
+
+         if "label" in d:
+             d["label"] = os.path.join(data_dir, d["label"])
+
+         if fold is not None:
+             if d["fold"] == fold:
+                 new_list_files_val.append(copy.deepcopy(d))
+             else:
+                 new_list_files_train.append(copy.deepcopy(d))
+
+     if fold is not None:
+         return new_list_files_train, new_list_files_val
+     else:
+         return new_list_files, []
+
+
+ def maisi_datafold_read(json_list, data_base_dir, fold=None):
+     with open(json_list, "r") as f:
+         filenames_train = json.load(f)["training"]
+     # training data
+     train_files, val_files = add_data_dir2path(filenames_train, data_base_dir, fold=fold)
+     print(f"dataset: {data_base_dir}, num_training_files: {len(train_files)}, num_val_files: {len(val_files)}")
+     return train_files, val_files
+
+
+ def remap_labels(mask, label_dict_remap_json):
+     """
+     Remap labels in the mask according to the provided label dictionary.
+
+     This function reads a JSON file containing label mapping information and applies
+     the mapping to the input mask.
+
+     Args:
+         mask (Tensor): The input mask tensor to be remapped.
+         label_dict_remap_json (str): Path to the JSON file containing the label mapping dictionary.
+
+     Returns:
+         Tensor: The remapped mask tensor.
+     """
+     with open(label_dict_remap_json, "r") as f:
+         mapping_dict = json.load(f)
+     mapper = MapLabelValue(
+         orig_labels=[pair[0] for pair in mapping_dict.values()],
+         target_labels=[pair[1] for pair in mapping_dict.values()],
+         dtype=torch.uint8,
+     )
+     return mapper(mask[0, ...])[None, ...].to(mask.device)
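The remap JSON is expected to map each class name to an `[original_label, target_label]` pair, matching how `mapping_dict.values()` is indexed above. A self-contained sketch with made-up class names and labels:

# Hypothetical mapping file content: {"liver": [6, 1], "spleen": [5, 3]}
import json
import tempfile

import torch

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"liver": [6, 1], "spleen": [5, 3]}, f)

mask = torch.tensor([[[5, 6], [6, 0]]], dtype=torch.uint8)  # channel-first toy mask
print(remap_labels(mask, f.name))  # 5 -> 3, 6 -> 1, other values unchanged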
+
+
+ def get_index_arr(img):
+     """
+     Generate an index array for the given image.
+
+     This function creates a 3D array of indices corresponding to the dimensions of the input image.
+
+     Args:
+         img (ndarray): The input image array.
+
+     Returns:
+         ndarray: A 3D array containing the indices for each dimension of the input image.
+     """
+     return np.moveaxis(
+         np.moveaxis(
+             np.stack(np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]), np.arange(img.shape[2]))), 0, 3
+         ),
+         0,
+         1,
+     )
+
+
+ def supress_non_largest_components(img, target_label, default_val=0):
+     """
+     Suppress all components except the largest one(s) for specified target labels.
+
+     This function identifies the largest component(s) for each target label and
+     suppresses all other smaller components.
+
+     Args:
+         img (ndarray): The input image array.
+         target_label (list): List of label values to process.
+         default_val (int, optional): Value to assign to suppressed voxels. Defaults to 0.
+
+     Returns:
+         tuple: A tuple containing:
+             - ndarray: Modified image with non-largest components suppressed.
+             - int: Number of voxels that were changed.
+     """
+     index_arr = get_index_arr(img)
+     img_mod = copy.deepcopy(img)
+     new_background = np.zeros(img.shape, dtype=np.bool_)
+     for label in target_label:
+         label_cc = skimage.measure.label(img == label, connectivity=3)
+         uv, uc = np.unique(label_cc, return_counts=True)
+         dominant_vals = uv[np.argsort(uc)[::-1][:2]]
+         if len(dominant_vals) >= 2:  # background plus at least one component; otherwise the label is absent
+             new_background = np.logical_or(
+                 new_background,
+                 np.logical_not(np.logical_or(label_cc == dominant_vals[0], label_cc == dominant_vals[1])),
+             )
+
+     for voxel in index_arr[new_background]:
+         img_mod[tuple(voxel)] = default_val
+     diff = np.sum((img - img_mod) > 0)
+
+     return img_mod, diff
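A toy sanity check of the suppression behavior (a sketch, assuming `scripts.utils` is importable): two disconnected blobs of the same label, where only the larger one survives.

import numpy as np

img = np.zeros((1, 8, 8), dtype=np.uint8)
img[0, 1:5, 1:5] = 1  # large component, 16 voxels
img[0, 6, 6] = 1      # small component, 1 voxel
img_mod, n_changed = supress_non_largest_components(img, target_label=[1])
print(n_changed)  # 1: only the single-voxel component is reset to default_val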
+
+
+ def erode_one_img(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> Tensor:
+     """
+     Erode a 2D/3D binary mask stored as a torch tensor.
+
+     Args:
+         mask_t: input 2D/3D binary mask, [M,N] or [M,N,P] torch tensor.
+         filter_size: erosion filter size; must be an odd number, defaults to 3.
+         pad_value: the fill value for padding. The input is padded before filtering
+             so the output keeps the same size as the input; the default usually
+             does not need to be changed.
+
+     Return:
+         Tensor: eroded mask, same shape as input.
+     """
+     return erode(mask_t.float().unsqueeze(0).unsqueeze(0), filter_size, pad_value=pad_value).squeeze(0).squeeze(0)
+
+
+ def dilate_one_img(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> Tensor:
+     """
+     Dilate a 2D/3D binary mask stored as a torch tensor.
+
+     Args:
+         mask_t: input 2D/3D binary mask, [M,N] or [M,N,P] torch tensor.
+         filter_size: dilation filter size; must be an odd number, defaults to 3.
+         pad_value: the fill value for padding. The input is padded before filtering
+             so the output keeps the same size as the input; the default usually
+             does not need to be changed.
+
+     Return:
+         Tensor: dilated mask, same shape as input.
+     """
+     return dilate(mask_t.float().unsqueeze(0).unsqueeze(0), filter_size, pad_value=pad_value).squeeze(0).squeeze(0)
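Chaining the two gives a morphological closing (dilate, then erode), which is how `organ_fill_by_closing` below fills holes. A minimal check:

import torch

mask = torch.ones(7, 7)
mask[3, 3] = 0  # punch a one-voxel hole
closed = erode_one_img(dilate_one_img(mask, filter_size=3), filter_size=3)
print(int(closed[3, 3]))  # 1: the hole is filled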
+
+
+ def binarize_labels(x: Tensor, bits: int = 8) -> Tensor:
+     """
+     Convert an integer label tensor to its binary (bitwise) representation.
+
+     This function takes an input tensor and converts it to a binary representation
+     using the specified number of bits.
+
+     Args:
+         x (Tensor): Input tensor with shape (B, 1, H, W, D).
+         bits (int, optional): Number of bits to use for the binary representation. Defaults to 8.
+
+     Returns:
+         Tensor: Binary representation of the input tensor with shape (B, bits, H, W, D).
+     """
+     mask = 2 ** torch.arange(bits).to(x.device, x.dtype)
+     return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte().squeeze(1).permute(0, 4, 1, 2, 3)
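Each voxel's integer label is split into `bits` binary channels, least-significant bit first, so label 5 (0b101) activates channels 0 and 2:

import torch

x = torch.full((1, 1, 2, 2, 2), 5, dtype=torch.uint8)
b = binarize_labels(x, bits=8)
print(b.shape)           # torch.Size([1, 8, 2, 2, 2])
print(b[0, :, 0, 0, 0])  # tensor([1, 0, 1, 0, 0, 0, 0, 0], dtype=torch.uint8)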
+
+
+ def setup_ddp(rank: int, world_size: int) -> torch.device:
+     """
+     Initialize the distributed process group.
+
+     Args:
+         rank (int): rank of the current process.
+         world_size (int): number of processes participating in the job.
+
+     Returns:
+         torch.device: device of the current process.
+     """
+     dist.init_process_group(
+         backend="nccl", init_method="env://", timeout=timedelta(seconds=36000), rank=rank, world_size=world_size
+     )
+     dist.barrier()
+     device = torch.device(f"cuda:{rank}")
+     return device
+
+
+ def define_instance(args: Namespace, instance_def_key: str) -> Any:
+     """
+     Define and instantiate an object based on the provided arguments and instance definition key.
+
+     This function uses a ConfigParser to parse the arguments and instantiate an object
+     defined by the instance_def_key.
+
+     Args:
+         args: An object containing the arguments to be parsed.
+         instance_def_key (str): The key used to retrieve the instance definition from the parsed content.
+
+     Returns:
+         The instantiated object as defined by the instance_def_key in the parsed configuration.
+     """
+     parser = ConfigParser(vars(args))
+     parser.parse(True)
+     return parser.get_parsed_content(instance_def_key, instantiate=True)
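`define_instance` resolves a bundle-style config entry via its `_target_` reference; a self-contained sketch (the `noise_scheduler` key and values here are illustrative, not the bundle's actual config):

from argparse import Namespace

args = Namespace(
    noise_scheduler={"_target_": "monai.networks.schedulers.DDPMScheduler", "num_train_timesteps": 1000}
)
scheduler = define_instance(args, "noise_scheduler")
print(type(scheduler).__name__)  # DDPMScheduler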
+
+
+ def prepare_maisi_controlnet_json_dataloader(
+     json_data_list: list | str,
+     data_base_dir: list | str,
+     batch_size: int = 1,
+     fold: int = 0,
+     cache_rate: float = 0.0,
+     rank: int = 0,
+     world_size: int = 1,
+ ) -> tuple[DataLoader, DataLoader]:
+     """
+     Prepare dataloaders for training and validation.
+
+     Args:
+         json_data_list (list | str): the name(s) of JSON files listing the data.
+         data_base_dir (list | str): directory of files.
+         batch_size (int, optional): how many samples per batch to load. Defaults to 1.
+         fold (int, optional): fold index for cross validation. Defaults to 0.
+         cache_rate (float, optional): percentage of cached data in total. Defaults to 0.0.
+         rank (int, optional): rank of the current process. Defaults to 0.
+         world_size (int, optional): number of processes participating in the job. Defaults to 1.
+
+     Returns:
+         tuple[DataLoader, DataLoader]: A tuple of two dataloaders (training, validation).
+     """
+     use_ddp = world_size > 1
+     if isinstance(json_data_list, list):
+         assert isinstance(data_base_dir, list)
+         list_train = []
+         list_valid = []
+         for data_list, data_root in zip(json_data_list, data_base_dir):
+             with open(data_list, "r") as f:
+                 json_data = json.load(f)["training"]
+             train, val = add_data_dir2path(json_data, data_root, fold)
+             list_train += train
+             list_valid += val
+     else:
+         with open(json_data_list, "r") as f:
+             json_data = json.load(f)["training"]
+         list_train, list_valid = add_data_dir2path(json_data, data_base_dir, fold)
+
+     common_transform = [
+         LoadImaged(keys=["image", "label"], image_only=True, ensure_channel_first=True),
+         Orientationd(keys=["label"], axcodes="RAS"),
+         EnsureTyped(keys=["label"], dtype=torch.uint8, track_meta=True),
+         Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x)),
+         Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x)),
+         Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)),
+         Lambdad(keys=["top_region_index", "bottom_region_index", "spacing"], func=lambda x: x * 1e2),
+     ]
+     train_transforms, val_transforms = Compose(common_transform), Compose(common_transform)
+
+     train_loader = None
+
+     if use_ddp:
+         list_train = partition_dataset(data=list_train, shuffle=True, num_partitions=world_size, even_divisible=True)[
+             rank
+         ]
+     train_ds = CacheDataset(data=list_train, transform=train_transforms, cache_rate=cache_rate, num_workers=8)
+     train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
+     if use_ddp:
+         list_valid = partition_dataset(data=list_valid, shuffle=True, num_partitions=world_size, even_divisible=False)[
+             rank
+         ]
+     val_ds = CacheDataset(data=list_valid, transform=val_transforms, cache_rate=cache_rate, num_workers=8)
+     val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=False)
+     return train_loader, val_loader
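For reference, the datalist layout these loaders expect: a `training` list whose entries carry the image/label paths plus the conditioning fields consumed by the transforms and `prepare_batch` extra keys above. Paths and values below are hypothetical:

datalist = {
    "training": [
        {
            "image": "embeddings/case_0001.nii.gz",  # latent embedding of the CT volume
            "label": "masks/case_0001.nii.gz",       # whole-body segmentation mask
            "fold": 0,
            "spacing": [1.0, 1.0, 1.0],
            "top_region_index": [0, 1, 0, 0],
            "bottom_region_index": [0, 0, 1, 0],
            "dim": [512, 512, 512],
        }
    ]
}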
+
+
+ def organ_fill_by_closing(data, target_label, device, close_times=2, filter_size=3, pad_value=0.0):
+     """
+     Fill holes in an organ mask using morphological closing operations.
+
+     This function performs a series of dilation and erosion operations to fill holes
+     in the organ mask identified by the target label.
+
+     Args:
+         data (ndarray): The input data containing organ labels.
+         target_label (int): The label of the organ to be processed.
+         device (str): The device to perform the operations on (e.g., 'cuda:0').
+         close_times (int, optional): Number of times to perform the closing operation. Defaults to 2.
+         filter_size (int, optional): Size of the filter for dilation and erosion. Defaults to 3.
+         pad_value (float, optional): Value used for padding in dilation and erosion. Defaults to 0.0.
+
+     Returns:
+         ndarray: Boolean mask of the filled organ.
+     """
+     mask = (data == target_label).astype(np.uint8)
+     mask = torch.from_numpy(mask).to(device)
+     for _ in range(close_times):
+         mask = dilate_one_img(mask, filter_size=filter_size, pad_value=pad_value)
+         mask = erode_one_img(mask, filter_size=filter_size, pad_value=pad_value)
+     return mask.cpu().numpy().astype(np.bool_)
+
+
+ def organ_fill_by_removed_mask(data, target_label, remove_mask, device):
+     """
+     Fill an organ mask in regions where it was previously removed.
+
+     Args:
+         data (ndarray): The input data containing organ labels.
+         target_label (int): The label of the organ to be processed.
+         remove_mask (ndarray): Boolean mask indicating regions where the organ was removed.
+         device (str): The device to perform the operations on (e.g., 'cuda:0').
+
+     Returns:
+         ndarray: Boolean mask of the filled organ in previously removed regions.
+     """
+     mask = (data == target_label).astype(np.uint8)
+     mask = dilate_one_img(torch.from_numpy(mask).to(device), filter_size=3, pad_value=0.0)
+     mask = dilate_one_img(mask, filter_size=3, pad_value=0.0)
+     roi_organ_mask = dilate_one_img(mask, filter_size=3, pad_value=0.0).cpu().numpy()
+     return (roi_organ_mask * remove_mask).astype(np.bool_)
+
+
+ def get_body_region_index_from_mask(input_mask):
+     """
+     Determine the top and bottom body region indices from an input mask.
+
+     Args:
+         input_mask (Tensor): Input mask tensor containing body region labels.
+
+     Returns:
+         tuple: Two one-hot lists representing the top and bottom body regions.
+     """
+     region_indices = {}
+     # head and neck
+     region_indices["region_0"] = [22, 120]
+     # thorax
+     region_indices["region_1"] = [28, 29, 30, 31, 32]
+     # abdomen
+     region_indices["region_2"] = [1, 2, 3, 4, 5, 14]
+     # pelvis and lower
+     region_indices["region_3"] = [93, 94]
+
+     nda = input_mask.cpu().numpy().squeeze()
+     unique_elements = list(np.unique(nda))
+     overlap_array = np.zeros(len(region_indices), dtype=np.uint8)
+     for _j in range(len(region_indices)):
+         overlap = any(element in region_indices[f"region_{_j}"] for element in unique_elements)
+         overlap_array[_j] = np.uint8(overlap)
+     overlap_array_indices = np.nonzero(overlap_array)[0]
+     top_region_index = np.eye(len(region_indices), dtype=np.uint8)[np.amin(overlap_array_indices), ...]
+     top_region_index = [int(_k) for _k in top_region_index]
+     bottom_region_index = np.eye(len(region_indices), dtype=np.uint8)[np.amax(overlap_array_indices), ...]
+     bottom_region_index = [int(_k) for _k in bottom_region_index]
+     return top_region_index, bottom_region_index
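Both return values are one-hot over the four regions; for a toy mask containing lung-lobe and abdominal labels:

import torch

mask = torch.tensor([28, 30, 1, 5]).reshape(1, 2, 2)  # lung lobes (28, 30) + liver (1), kidney (5)
top, bottom = get_body_region_index_from_mask(mask)
print(top)     # [0, 1, 0, 0]: topmost region present is the thorax
print(bottom)  # [0, 0, 1, 0]: bottommost region present is the abdomen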
+
+
+ def general_mask_generation_post_process(volume_t, target_tumor_label=None, device="cuda:0"):
+     """
+     Perform post-processing on a generated mask volume.
+
+     This function applies various refinement steps to improve the quality of the generated mask,
+     including body mask refinement, tumor prediction refinement, and organ-specific processing.
+
+     Args:
+         volume_t (ndarray): Input volume containing organ and tumor labels.
+         target_tumor_label (int, optional): Label of the target tumor. Defaults to None.
+         device (str, optional): Device to perform operations on. Defaults to "cuda:0".
+
+     Returns:
+         ndarray: Post-processed volume with refined organ and tumor labels.
+     """
+     # assume volume_t is a np array with shape (H,W,D)
+     hepatic_vessel = volume_t == 25
+     airway = volume_t == 132
+
+     # ------------ refine body mask pred
+     body_region_mask = (
+         erode_one_img(torch.from_numpy((volume_t > 0)).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
+     )
+     body_region_mask, _ = supress_non_largest_components(body_region_mask, [1])
+     body_region_mask = (
+         dilate_one_img(torch.from_numpy(body_region_mask).to(device), filter_size=3, pad_value=0.0)
+         .cpu()
+         .numpy()
+         .astype(np.uint8)
+     )
+     volume_t = volume_t * body_region_mask
+
+     # ------------ refine tumor pred
+     tumor_organ_dict = {23: 28, 24: 4, 26: 1, 27: 62, 128: 200}
+     for t in [23, 24, 26, 27, 128]:
+         if t != target_tumor_label:
+             volume_t[volume_t == t] = tumor_organ_dict[t]
+         else:
+             volume_t[organ_fill_by_closing(volume_t, target_label=t, device=device)] = t
+             volume_t[organ_fill_by_closing(volume_t, target_label=t, device=device)] = t
+     # only keep the largest connected component for tumors, except hepatic tumor and bone lesion
+     if target_tumor_label != 26 and target_tumor_label != 128:
+         volume_t, _ = supress_non_largest_components(volume_t, [target_tumor_label], default_val=200)
+     target_tumor = volume_t == target_tumor_label
+
+     # ------------ remove undesired organ pred
+     # general post-process non-largest components suppression
+     # process 4 ROI organs + spleen + 2 kidneys + 5 lung lobes + duodenum + inferior vena cava
+     organ_list = [1, 4, 10, 12, 3, 28, 29, 30, 31, 32, 5, 14, 13, 6, 7, 8, 9, 10]
+     if target_tumor_label != 128:
+         organ_list += list(range(33, 60))  # + list(range(63,87))
+     data, _ = supress_non_largest_components(volume_t, organ_list, default_val=200)  # 200 is body region
+     organ_remove_mask = (volume_t - data).astype(np.bool_)
+     # process intestinal system (stomach 12, duodenum 13, small bowel 19, colon 62)
+     intestinal_mask_ = (
+         (data == 12).astype(np.uint8)
+         + (data == 13).astype(np.uint8)
+         + (data == 19).astype(np.uint8)
+         + (data == 62).astype(np.uint8)
+     )
+     intestinal_mask, _ = supress_non_largest_components(intestinal_mask_, [1], default_val=0)
+     # process small bowel 19
+     small_bowel_remove_mask = (data == 19).astype(np.uint8) - (data == 19).astype(np.uint8) * intestinal_mask
+     # process colon 62
+     colon_remove_mask = (data == 62).astype(np.uint8) - (data == 62).astype(np.uint8) * intestinal_mask
+     intestinal_remove_mask = (small_bowel_remove_mask + colon_remove_mask).astype(np.bool_)
+     data[intestinal_remove_mask] = 200
+
+     # ------------ fill corresponding organs in removed regions
+     for organ_label in organ_list:
+         data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
+
+     if target_tumor_label == 23 and np.sum(target_tumor) > 0:
+         # special process for cases with lung tumor
+         dia_lung_tumor_mask = (
+             dilate_one_img(torch.from_numpy((data == 23)).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
+         )
+         tmp = (
+             (data * (dia_lung_tumor_mask.astype(np.uint8) - (data == 23).astype(np.uint8))).astype(np.float32).flatten()
+         )
+         tmp[tmp == 0] = float("nan")
+         mode = int(stats.mode(tmp.flatten(), nan_policy="omit")[0])
+         if mode in [28, 29, 30, 31, 32]:
+             dia_lung_tumor_mask = (
+                 dilate_one_img(torch.from_numpy(dia_lung_tumor_mask).to(device), filter_size=3, pad_value=0.0)
+                 .cpu()
+                 .numpy()
+             )
+             lung_remove_mask = dia_lung_tumor_mask.astype(np.uint8) - (data == 23).astype(np.uint8)
+             data[organ_fill_by_removed_mask(data, target_label=mode, remove_mask=lung_remove_mask, device=device)] = (
+                 mode
+             )
+         dia_lung_tumor_mask = (
+             dilate_one_img(torch.from_numpy(dia_lung_tumor_mask).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
+         )
+         data[
+             organ_fill_by_removed_mask(
+                 data, target_label=23, remove_mask=dia_lung_tumor_mask * organ_remove_mask, device=device
+             )
+         ] = 23
+         for organ_label in [28, 29, 30, 31, 32]:
+             data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
+             data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
+             data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
+
+     if target_tumor_label == 26 and np.sum(target_tumor) > 0:
+         # special process for cases with hepatic tumor
+         # process liver 1
+         data[organ_fill_by_removed_mask(data, target_label=1, remove_mask=intestinal_remove_mask, device=device)] = 1
+         data[organ_fill_by_removed_mask(data, target_label=1, remove_mask=intestinal_remove_mask, device=device)] = 1
+         # process spleen 3
+         data[organ_fill_by_removed_mask(data, target_label=3, remove_mask=organ_remove_mask, device=device)] = 3
+         data[organ_fill_by_removed_mask(data, target_label=3, remove_mask=organ_remove_mask, device=device)] = 3
+         dia_tumor_mask = (
+             dilate_one_img(torch.from_numpy((data == target_tumor_label)).to(device), filter_size=3, pad_value=0.0)
+             .cpu()
+             .numpy()
+         )
+         dia_tumor_mask = (
+             dilate_one_img(torch.from_numpy(dia_tumor_mask).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
+         )
+         data[
+             organ_fill_by_removed_mask(
+                 data, target_label=target_tumor_label, remove_mask=dia_tumor_mask * organ_remove_mask, device=device
+             )
+         ] = target_tumor_label
+         # refine hepatic tumor
+         hepatic_tumor_vessel_liver_mask_ = (
+             (data == 26).astype(np.uint8) + (data == 25).astype(np.uint8) + (data == 1).astype(np.uint8)
+         )
+         hepatic_tumor_vessel_liver_mask_ = (hepatic_tumor_vessel_liver_mask_ > 1).astype(np.uint8)
+         hepatic_tumor_vessel_liver_mask, _ = supress_non_largest_components(
+             hepatic_tumor_vessel_liver_mask_, [1], default_val=0
+         )
+         removed_region = (hepatic_tumor_vessel_liver_mask_ - hepatic_tumor_vessel_liver_mask).astype(np.bool_)
+         data[removed_region] = 200
+         target_tumor = (target_tumor * hepatic_tumor_vessel_liver_mask).astype(np.bool_)
+         # refine liver
+         data[organ_fill_by_closing(data, target_label=1, device=device)] = 1
+         data[organ_fill_by_closing(data, target_label=1, device=device)] = 1
+         data[organ_fill_by_closing(data, target_label=1, device=device)] = 1
+
+     if target_tumor_label == 27 and np.sum(target_tumor) > 0:
+         # special process for cases with colon tumor
+         dia_tumor_mask = (
+             dilate_one_img(torch.from_numpy((data == target_tumor_label)).to(device), filter_size=3, pad_value=0.0)
+             .cpu()
+             .numpy()
+         )
+         dia_tumor_mask = (
+             dilate_one_img(torch.from_numpy(dia_tumor_mask).to(device), filter_size=3, pad_value=0.0).cpu().numpy()
+         )
+         data[
+             organ_fill_by_removed_mask(
+                 data, target_label=target_tumor_label, remove_mask=dia_tumor_mask * organ_remove_mask, device=device
+             )
+         ] = target_tumor_label
+
+     if target_tumor_label == 129 and np.sum(target_tumor) > 0:
+         # special process for cases with kidney tumor
+         for organ_label in [5, 14]:
+             data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
+             data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
+             data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label
+     # TODO: the current model does not support hepatic vessel by size control,
+     # so the generated hepatic vessel is treated as part of the liver for better visualization
+     print(
+         "Current model does not support hepatic vessel by size control, "
+         "so we treat generated hepatic vessel as part of liver for better visualization."
+     )
+     data[hepatic_vessel] = 1
+     data[airway] = 132
+     if target_tumor_label is not None:
+         data[target_tumor] = target_tumor_label
+
+     return data
+
+
+ class MapLabelValue:
+     """
+     Utility to map label values to another set of values.
+     For example, map [3, 2, 1] to [0, 1, 2], [1, 2, 3] -> [0.5, 1.5, 2.5], ["label3", "label2", "label1"] -> [0, 1, 2],
+     [3.5, 2.5, 1.5] -> ["label0", "label1", "label2"], etc.
+     The label data must be a numpy array or array-like data, and the output data will be a numpy array.
+     """
+
+     backend = [TransformBackends.NUMPY, TransformBackends.TORCH]
+
+     def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeLike = np.float32) -> None:
+         """
+         Args:
+             orig_labels: original labels that map to others.
+             target_labels: expected label values, mapped 1:1 to the `orig_labels`.
+             dtype: convert the output data to dtype, default to float32.
+                 If dtype is from PyTorch, the transform will use the pytorch backend, else the numpy backend.
+         """
+         if len(orig_labels) != len(target_labels):
+             raise ValueError("orig_labels and target_labels must have the same length.")
+
+         self.orig_labels = orig_labels
+         self.target_labels = target_labels
+         self.pair = tuple((o, t) for o, t in zip(self.orig_labels, self.target_labels) if o != t)
+         type_dtype = type(dtype)
+         if getattr(type_dtype, "__module__", "") == "torch":
+             self.use_numpy = False
+             self.dtype = get_equivalent_dtype(dtype, data_type=torch.Tensor)
+         else:
+             self.use_numpy = True
+             self.dtype = get_equivalent_dtype(dtype, data_type=np.ndarray)
+
+     def __call__(self, img: NdarrayOrTensor):
+         """
+         Apply the label mapping to the input image.
+
+         Args:
+             img (NdarrayOrTensor): Input image to be remapped.
+
+         Returns:
+             NdarrayOrTensor: Remapped image.
+         """
+         if self.use_numpy:
+             img_np, *_ = convert_data_type(img, np.ndarray)
+             _out_shape = img_np.shape
+             img_flat = img_np.flatten()
+             try:
+                 out_flat = img_flat.astype(self.dtype)
+             except ValueError:
+                 # can't copy unchanged labels as the expected dtype is not supported, must map all the label values
+                 out_flat = np.zeros(shape=img_flat.shape, dtype=self.dtype)
+             for o, t in self.pair:
+                 out_flat[img_flat == o] = t
+             out_t = out_flat.reshape(_out_shape)
+         else:
+             img_t, *_ = convert_data_type(img, torch.Tensor)
+             out_t = img_t.detach().clone().to(self.dtype)  # type: ignore
+             for o, t in self.pair:
+                 out_t[img_t == o] = t
+         out, *_ = convert_to_dst_type(src=out_t, dst=img, dtype=self.dtype)
+         return out
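Usage matches the docstring's first example; with the default numpy float32 backend:

import numpy as np

mapper = MapLabelValue(orig_labels=[3, 2, 1], target_labels=[0, 1, 2])
print(mapper(np.array([[3, 2], [1, 3]])))  # [[0. 1.] [2. 0.]]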
+
+
+ def dynamic_infer(inferer, model, images):
+     """
+     Perform dynamic inference using a model and an inferer, typically a MONAI SlidingWindowInferer.
+
+     This function determines whether to use the model directly or to use the provided inferer
+     (such as a sliding window inferer) based on the size of the input images.
+
+     Args:
+         inferer: An inference object, typically a MONAI SlidingWindowInferer, which handles patch-based inference.
+         model (torch.nn.Module): The model used for inference.
+         images (torch.Tensor): The input images for inference, shape [N,C,H,W,D] or [N,C,H,W].
+
+     Returns:
+         torch.Tensor: The output from the model or the inferer, depending on the input size.
+     """
+     if torch.numel(images[0:1, 0:1, ...]) < math.prod(inferer.roi_size):
+         return model(images)
+     else:
+         # extract the spatial dimensions from the images tensor (H, W, D)
+         spatial_dims = images.shape[2:]
+         orig_roi = inferer.roi_size
+
+         # check that the ROI has the same number of dimensions as spatial_dims
+         if len(orig_roi) != len(spatial_dims):
+             raise ValueError(f"ROI length ({len(orig_roi)}) does not match spatial dimensions ({len(spatial_dims)}).")
+
+         # clamp each ROI dimension to the image size
+         adjusted_roi = [min(roi_dim, img_dim) for roi_dim, img_dim in zip(orig_roi, spatial_dims)]
+         inferer.roi_size = adjusted_roi
+         output = inferer(network=model, inputs=images)
+         inferer.roi_size = orig_roi
+         return output
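A minimal end-to-end check of the two code paths, using a toy network rather than the bundle's models:

import torch
from monai.inferers import SlidingWindowInferer

model = torch.nn.Conv3d(1, 2, kernel_size=3, padding=1)
inferer = SlidingWindowInferer(roi_size=(64, 64, 64), sw_batch_size=1, overlap=0.25)

small = torch.rand(1, 1, 32, 32, 32)  # fits inside one ROI -> direct model call
large = torch.rand(1, 1, 96, 96, 96)  # larger than the ROI -> sliding-window path
print(dynamic_infer(inferer, model, small).shape)  # torch.Size([1, 2, 32, 32, 32])
print(dynamic_infer(inferer, model, large).shape)  # torch.Size([1, 2, 96, 96, 96])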