project-monai committed
Commit e217a55 · verified · 1 Parent(s): ccbbae4

Upload brats_mri_generative_diffusion version 1.1.3
LICENSE ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
configs/inference.json ADDED
@@ -0,0 +1,102 @@
{
    "imports": [
        "$import torch",
        "$from datetime import datetime",
        "$from pathlib import Path"
    ],
    "bundle_root": ".",
    "model_dir": "$@bundle_root + '/models'",
    "output_dir": "$@bundle_root + '/output'",
    "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)",
    "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
    "output_postfix": "$datetime.now().strftime('sample_%Y%m%d_%H%M%S')",
    "spatial_dims": 3,
    "image_channels": 1,
    "latent_channels": 8,
    "latent_shape": [8, 36, 44, 28],
    "autoencoder_def": {
        "_target_": "monai.networks.nets.autoencoderkl.AutoencoderKL",
        "spatial_dims": "@spatial_dims",
        "in_channels": "@image_channels",
        "out_channels": "@image_channels",
        "latent_channels": "@latent_channels",
        "channels": [64, 128, 256],
        "num_res_blocks": 2,
        "norm_num_groups": 32,
        "norm_eps": 1e-06,
        "attention_levels": [false, false, false],
        "with_encoder_nonlocal_attn": false,
        "with_decoder_nonlocal_attn": false,
        "include_fc": false
    },
    "network_def": {
        "_target_": "monai.networks.nets.diffusion_model_unet.DiffusionModelUNet",
        "spatial_dims": "@spatial_dims",
        "in_channels": "@latent_channels",
        "out_channels": "@latent_channels",
        "channels": [256, 256, 512],
        "attention_levels": [false, true, true],
        "num_head_channels": [0, 64, 64],
        "num_res_blocks": 2,
        "include_fc": false,
        "use_combined_linear": false
    },
    "load_autoencoder_path": "$@bundle_root + '/models/model_autoencoder.pt'",
    "load_autoencoder": "$@autoencoder_def.load_old_state_dict(torch.load(@load_autoencoder_path))",
    "autoencoder": "$@autoencoder_def.to(@device)",
    "load_diffusion_path": "$@model_dir + '/model.pt'",
    "load_diffusion": "$@network_def.load_old_state_dict(torch.load(@load_diffusion_path))",
    "diffusion": "$@network_def.to(@device)",
    "noise_scheduler": {
        "_target_": "monai.networks.schedulers.ddim.DDIMScheduler",
        "_requires_": ["@load_diffusion", "@load_autoencoder"],
        "num_train_timesteps": 1000,
        "beta_start": 0.0015,
        "beta_end": 0.0195,
        "schedule": "scaled_linear_beta",
        "clip_sample": false
    },
    "noise": "$torch.randn([1]+@latent_shape).to(@device)",
    "set_timesteps": "$@noise_scheduler.set_timesteps(num_inference_steps=50)",
    "inferer": {
        "_target_": "scripts.ldm_sampler.LDMSampler",
        "_requires_": "@set_timesteps"
    },
    "saver": {
        "_target_": "SaveImage",
        "_requires_": "@create_output_dir",
        "output_dir": "@output_dir",
        "output_postfix": "@output_postfix"
    },
    "run": [
        "$@inferer.sampling_fn(@noise, @autoencoder, @diffusion, @noise_scheduler, @saver)"
    ]
}
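The DDIM scheduler above is configured with a `scaled_linear_beta` schedule from `beta_start=0.0015` to `beta_end=0.0195` over 1000 training timesteps. As a rough illustration (a minimal sketch of the common "scaled linear" convention, which interpolates linearly in sqrt-beta space; MONAI's scheduler implementation may differ in detail), the betas can be computed like this:

```python
import math

def scaled_linear_betas(beta_start, beta_end, num_timesteps):
    """Interpolate linearly between sqrt(beta_start) and sqrt(beta_end), then square."""
    start, end = math.sqrt(beta_start), math.sqrt(beta_end)
    step = (end - start) / (num_timesteps - 1)
    return [(start + i * step) ** 2 for i in range(num_timesteps)]

# Same endpoints as the config above: betas rise monotonically from 0.0015 to 0.0195.
betas = scaled_linear_betas(0.0015, 0.0195, 1000)
```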
configs/inference_autoencoder.json ADDED
@@ -0,0 +1,151 @@
{
    "imports": [
        "$import torch",
        "$from datetime import datetime",
        "$from pathlib import Path"
    ],
    "bundle_root": ".",
    "model_dir": "$@bundle_root + '/models'",
    "dataset_dir": "/workspace/data/medical",
    "output_dir": "$@bundle_root + '/output'",
    "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)",
    "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
    "output_orig_postfix": "orig",
    "output_recon_postfix": "recon",
    "channel": 0,
    "spacing": [1.1, 1.1, 1.1],
    "spatial_dims": 3,
    "image_channels": 1,
    "latent_channels": 8,
    "infer_patch_size": [144, 176, 112],
    "autoencoder_def": {
        "_target_": "monai.networks.nets.autoencoderkl.AutoencoderKL",
        "spatial_dims": "@spatial_dims",
        "in_channels": "@image_channels",
        "out_channels": "@image_channels",
        "latent_channels": "@latent_channels",
        "channels": [64, 128, 256],
        "num_res_blocks": 2,
        "norm_num_groups": 32,
        "norm_eps": 1e-06,
        "attention_levels": [false, false, false],
        "with_encoder_nonlocal_attn": false,
        "with_decoder_nonlocal_attn": false,
        "include_fc": false
    },
    "load_autoencoder_path": "$@bundle_root + '/models/model_autoencoder.pt'",
    "load_autoencoder": "$@autoencoder_def.load_old_state_dict(torch.load(@load_autoencoder_path))",
    "autoencoder": "$@autoencoder_def.to(@device)",
    "preprocessing_transforms": [
        {
            "_target_": "LoadImaged",
            "keys": "image"
        },
        {
            "_target_": "EnsureChannelFirstd",
            "keys": "image"
        },
        {
            "_target_": "Lambdad",
            "keys": "image",
            "func": "$lambda x: x[@channel, :, :, :]"
        },
        {
            "_target_": "EnsureChannelFirstd",
            "keys": "image",
            "channel_dim": "no_channel"
        },
        {
            "_target_": "EnsureTyped",
            "keys": "image"
        },
        {
            "_target_": "Orientationd",
            "keys": "image",
            "axcodes": "RAS"
        },
        {
            "_target_": "Spacingd",
            "keys": "image",
            "pixdim": "@spacing",
            "mode": "bilinear"
        }
    ],
    "crop_transforms": [
        {
            "_target_": "CenterSpatialCropd",
            "keys": "image",
            "roi_size": "@infer_patch_size"
        }
    ],
    "final_transforms": [
        {
            "_target_": "ScaleIntensityRangePercentilesd",
            "keys": "image",
            "lower": 0,
            "upper": 99.5,
            "b_min": 0,
            "b_max": 1
        }
    ],
    "preprocessing": {
        "_target_": "Compose",
        "transforms": "$@preprocessing_transforms + @crop_transforms + @final_transforms"
    },
    "dataset": {
        "_target_": "monai.apps.DecathlonDataset",
        "root_dir": "@dataset_dir",
        "task": "Task01_BrainTumour",
        "section": "validation",
        "cache_rate": 0.0,
        "num_workers": 8,
        "download": false,
        "transform": "@preprocessing"
    },
    "dataloader": {
        "_target_": "DataLoader",
        "dataset": "@dataset",
        "batch_size": 1,
        "shuffle": true,
        "num_workers": 0
    },
    "saver_orig": {
        "_target_": "SaveImage",
        "_requires_": "@create_output_dir",
        "output_dir": "@output_dir",
        "output_postfix": "@output_orig_postfix",
        "resample": false,
        "padding_mode": "zeros"
    },
    "saver_recon": {
        "_target_": "SaveImage",
        "_requires_": "@create_output_dir",
        "output_dir": "@output_dir",
        "output_postfix": "@output_recon_postfix",
        "resample": false,
        "padding_mode": "zeros"
    },
    "input_img": "$monai.utils.first(@dataloader)['image'].to(@device)",
    "recon_img": "$@autoencoder(@input_img)[0][0]",
    "run": [
        "$@load_autoencoder",
        "$@saver_orig(@input_img[0][0])",
        "$@saver_recon(@recon_img)"
    ]
}
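The `latent_shape` of `[8, 36, 44, 28]` used by the diffusion configs follows directly from the `infer_patch_size` of `[144, 176, 112]` above: the autoencoder has three channel levels, so its encoder downsamples twice by a factor of 2 (4x per spatial axis) and emits `latent_channels=8` channels. A quick sanity check of that relationship (a simplified sketch; the exact factor depends on AutoencoderKL internals):

```python
def expected_latent_shape(patch_size, latent_channels, num_levels):
    # Each level after the first halves the spatial dims (stride-2 downsampling),
    # so three channel levels give a 2**2 = 4x reduction per axis.
    factor = 2 ** (num_levels - 1)
    return [latent_channels] + [s // factor for s in patch_size]

shape = expected_latent_shape([144, 176, 112], latent_channels=8, num_levels=3)
# shape == [8, 36, 44, 28], matching latent_shape in the diffusion configs
```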
configs/inference_trt.json ADDED
@@ -0,0 +1,7 @@
{
    "+imports": [
        "$from monai.networks import trt_compile"
    ],
    "diffusion": "$trt_compile(@network_def.to(@device), @load_diffusion_path)",
    "autoencoder": "$trt_compile(@autoencoder_def.to(@device), @load_autoencoder_path)"
}
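inference_trt.json is an overlay applied on top of inference.json: plain keys such as `"diffusion"` replace the base entries, while the `"+imports"` key appends to the base `"imports"` list. The merge rule can be sketched as follows (assumed semantics for illustration only, not MONAI's actual config parser):

```python
def merge_overlay(base, overlay):
    """Apply an overlay config: plain keys replace, '+key' appends to a base list."""
    merged = dict(base)
    for key, value in overlay.items():
        if key.startswith("+"):  # '+key' appends to the existing list under 'key'
            target = key[1:]
            merged[target] = list(merged.get(target, [])) + list(value)
        else:  # plain keys override the base entry
            merged[key] = value
    return merged

base = {"imports": ["$import torch"], "diffusion": "$@network_def.to(@device)"}
overlay = {
    "+imports": ["$from monai.networks import trt_compile"],
    "diffusion": "$trt_compile(@network_def.to(@device), @load_diffusion_path)",
}
merged = merge_overlay(base, overlay)
```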
configs/logging.conf ADDED
@@ -0,0 +1,21 @@
[loggers]
keys=root

[handlers]
keys=consoleHandler

[formatters]
keys=fullFormatter

[logger_root]
level=INFO
handlers=consoleHandler

[handler_consoleHandler]
class=StreamHandler
level=INFO
formatter=fullFormatter
args=(sys.stdout,)

[formatter_fullFormatter]
format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
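This is a standard configparser-style logging configuration, loadable with the Python standard library. A minimal sketch (writing the config to a temporary file, since `logging.config.fileConfig` expects a path, file-like object, or parser):

```python
import logging
import logging.config
import os
import tempfile

CONF = """\
[loggers]
keys=root

[handlers]
keys=consoleHandler

[formatters]
keys=fullFormatter

[logger_root]
level=INFO
handlers=consoleHandler

[handler_consoleHandler]
class=StreamHandler
level=INFO
formatter=fullFormatter
args=(sys.stdout,)

[formatter_fullFormatter]
format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
"""

# Write the config to a temp file and hand the path to fileConfig.
with tempfile.NamedTemporaryFile("w", suffix=".conf", delete=False) as f:
    f.write(CONF)
    path = f.name
logging.config.fileConfig(path, disable_existing_loggers=False)
os.unlink(path)

logging.getLogger().info("logging configured")
```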
configs/metadata.json ADDED
@@ -0,0 +1,136 @@
{
    "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
    "version": "1.1.3",
    "changelog": {
        "1.1.3": "update to huggingface hosting and fix missing dependencies",
        "1.1.2": "update issue for IgniteInfo",
        "1.1.1": "enable tensorrt",
        "1.1.0": "update to use monai 1.4, model ckpt not changed, rm GenerativeAI repo",
        "1.0.9": "update to use monai 1.3.1",
        "1.0.8": "update run section",
        "1.0.7": "update with EnsureChannelFirstd",
        "1.0.6": "update with new lr scheduler api in inference",
        "1.0.5": "fix the wrong GPU index issue of multi-node",
        "1.0.4": "update with new lr scheduler api",
        "1.0.3": "update required packages",
        "1.0.2": "unify dataset dir in different configs",
        "1.0.1": "update dependency, update trained model weights",
        "1.0.0": "Initial release"
    },
    "monai_version": "1.4.0",
    "pytorch_version": "2.4.0",
    "numpy_version": "1.24.4",
    "required_packages_version": {
        "nibabel": "5.2.1",
        "lpips": "0.1.4",
        "einops": "0.7.0",
        "pytorch-ignite": "0.4.11",
        "tensorboard": "2.17.0"
    },
    "supported_apps": {},
    "name": "BraTS MRI image latent diffusion generation",
    "task": "BraTS MRI image synthesis",
    "description": "A generative model for creating 3D brain MRI from Gaussian noise based on BraTS dataset",
    "authors": "MONAI team",
    "copyright": "Copyright (c) MONAI Consortium",
    "data_source": "http://medicaldecathlon.com/",
    "data_type": "nibabel",
    "image_classes": "Flair brain MRI with 1.1x1.1x1.1 mm voxel size",
    "eval_metrics": {},
    "intended_use": "This is a research tool/prototype and not to be used clinically",
    "references": [],
    "autoencoder_data_format": {
        "inputs": {
            "image": {
                "type": "image",
                "format": "image",
                "num_channels": 1,
                "spatial_shape": [112, 128, 80],
                "dtype": "float32",
                "value_range": [0, 1],
                "is_patch_data": true
            }
        },
        "outputs": {
            "pred": {
                "type": "image",
                "format": "image",
                "num_channels": 1,
                "spatial_shape": [112, 128, 80],
                "dtype": "float32",
                "value_range": [0, 1],
                "is_patch_data": true,
                "channel_def": {
                    "0": "image"
                }
            }
        }
    },
    "network_data_format": {
        "inputs": {
            "latent": {
                "type": "noise",
                "format": "image",
                "num_channels": 8,
                "spatial_shape": [36, 44, 28],
                "dtype": "float32",
                "value_range": [0, 1],
                "is_patch_data": true
            },
            "condition": {
                "type": "timesteps",
                "format": "timesteps",
                "num_channels": 1,
                "spatial_shape": [],
                "dtype": "long",
                "value_range": [0, 1000],
                "is_patch_data": false
            }
        },
        "outputs": {
            "pred": {
                "type": "feature",
                "format": "image",
                "num_channels": 8,
                "spatial_shape": [36, 44, 28],
                "dtype": "float32",
                "value_range": [0, 1],
                "is_patch_data": true,
                "channel_def": {
                    "0": "image"
                }
            }
        }
    }
}
configs/multi_gpu_train_autoencoder.json ADDED
@@ -0,0 +1,41 @@
{
    "device": "$torch.device('cuda:' + os.environ['LOCAL_RANK'])",
    "gnetwork": {
        "_target_": "torch.nn.parallel.DistributedDataParallel",
        "module": "$@autoencoder_def.to(@device)",
        "device_ids": ["@device"]
    },
    "dnetwork": {
        "_target_": "torch.nn.parallel.DistributedDataParallel",
        "module": "$@discriminator_def.to(@device)",
        "device_ids": ["@device"]
    },
    "train#sampler": {
        "_target_": "DistributedSampler",
        "dataset": "@train#dataset",
        "even_divisible": true,
        "shuffle": true
    },
    "train#dataloader#sampler": "@train#sampler",
    "train#dataloader#shuffle": false,
    "train#trainer#train_handlers": "$@train#handlers[: -2 if dist.get_rank() > 0 else None]",
    "initialize": [
        "$import torch.distributed as dist",
        "$import os",
        "$dist.is_initialized() or dist.init_process_group(backend='nccl')",
        "$torch.cuda.set_device(@device)",
        "$monai.utils.set_determinism(seed=123)",
        "$import logging",
        "$@train#trainer.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)"
    ],
    "run": [
        "$@train#trainer.run()"
    ],
    "finalize": [
        "$dist.is_initialized() and dist.destroy_process_group()"
    ]
}
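Keys like `"train#dataloader#sampler"` address nested entries of the base config: `#` separates dictionary keys (and list indices) along the path, letting this overlay override a single nested value such as the dataloader's sampler. A simplified sketch of that addressing scheme (assumed semantics for illustration):

```python
def set_by_id(config, ref, value):
    """Set a nested config entry addressed by a '#'-separated path of keys/indices."""
    *parents, last = ref.split("#")
    node = config
    for part in parents:
        node = node[int(part)] if isinstance(node, list) else node[part]
    if isinstance(node, list):
        node[int(last)] = value
    else:
        node[last] = value

config = {"train": {"dataloader": {"shuffle": True}}}
set_by_id(config, "train#dataloader#shuffle", False)
set_by_id(config, "train#dataloader#sampler", "@train#sampler")
```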
configs/multi_gpu_train_diffusion.json ADDED
@@ -0,0 +1,16 @@
{
    "diffusion": {
        "_target_": "torch.nn.parallel.DistributedDataParallel",
        "module": "$@network_def.to(@device)",
        "device_ids": ["@device"],
        "find_unused_parameters": true
    },
    "run": [
        "@load_autoencoder",
        "$print('scale factor:',@scale_factor)",
        "$@train#trainer.run()"
    ]
}
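The `@scale_factor` printed above is computed by `scripts.utils.compute_scale_factor` from the autoencoder and the training dataloader. Conventionally for latent diffusion this is 1/std of a batch of encoded latents, so the diffusion model sees roughly unit-variance inputs. A plain-Python sketch of that convention (an assumption for illustration; the bundle's actual script may differ):

```python
import statistics

def compute_scale_factor(latent_values):
    """Return 1/std of the encoded latents so that scaled latents have unit variance."""
    return 1.0 / statistics.pstdev(latent_values)

latents = [-2.0, 0.0, 2.0, -2.0, 0.0, 2.0]
sf = compute_scale_factor(latents)
scaled = [v * sf for v in latents]  # now has (population) standard deviation 1.0
```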
configs/train_autoencoder.json ADDED
@@ -0,0 +1,208 @@
{
    "imports": [
        "$import functools",
        "$import glob",
        "$import scripts"
    ],
    "bundle_root": ".",
    "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
    "ckpt_dir": "$@bundle_root + '/models'",
    "tf_dir": "$@bundle_root + '/eval'",
    "dataset_dir": "/workspace/data/medical",
    "pretrained": false,
    "perceptual_loss_model_weights_path": null,
    "train_batch_size": 2,
    "lr": 1e-05,
    "train_patch_size": [112, 128, 80],
    "channel": 0,
    "spacing": [1.1, 1.1, 1.1],
    "spatial_dims": 3,
    "image_channels": 1,
    "latent_channels": 8,
    "discriminator_def": {
        "_target_": "monai.networks.nets.patchgan_discriminator.PatchDiscriminator",
        "spatial_dims": "@spatial_dims",
        "num_layers_d": 3,
        "channels": 32,
        "in_channels": 1,
        "out_channels": 1,
        "norm": "INSTANCE"
    },
    "autoencoder_def": {
        "_target_": "monai.networks.nets.autoencoderkl.AutoencoderKL",
        "spatial_dims": "@spatial_dims",
        "in_channels": "@image_channels",
        "out_channels": "@image_channels",
        "latent_channels": "@latent_channels",
        "channels": [64, 128, 256],
        "num_res_blocks": 2,
        "norm_num_groups": 32,
        "norm_eps": 1e-06,
        "attention_levels": [false, false, false],
        "with_encoder_nonlocal_attn": false,
        "with_decoder_nonlocal_attn": false,
        "include_fc": false
    },
    "perceptual_loss_def": {
        "_target_": "monai.losses.perceptual.PerceptualLoss",
        "spatial_dims": "@spatial_dims",
        "network_type": "resnet50",
        "is_fake_3d": true,
        "fake_3d_ratio": 0.2,
        "pretrained": "@pretrained",
        "pretrained_path": "@perceptual_loss_model_weights_path",
        "pretrained_state_dict_key": "state_dict"
    },
    "dnetwork": "$@discriminator_def.to(@device)",
    "gnetwork": "$@autoencoder_def.to(@device)",
    "loss_perceptual": "$@perceptual_loss_def.to(@device)",
    "doptimizer": {
        "_target_": "torch.optim.Adam",
        "params": "$@dnetwork.parameters()",
        "lr": "@lr"
    },
    "goptimizer": {
        "_target_": "torch.optim.Adam",
        "params": "$@gnetwork.parameters()",
        "lr": "@lr"
    },
    "preprocessing_transforms": [
        {
            "_target_": "LoadImaged",
            "keys": "image"
        },
        {
            "_target_": "EnsureChannelFirstd",
            "keys": "image"
        },
        {
            "_target_": "Lambdad",
            "keys": "image",
            "func": "$lambda x: x[@channel, :, :, :]"
        },
        {
            "_target_": "EnsureChannelFirstd",
            "keys": "image",
            "channel_dim": "no_channel"
        },
        {
            "_target_": "EnsureTyped",
            "keys": "image"
        },
        {
            "_target_": "Orientationd",
            "keys": "image",
            "axcodes": "RAS"
        },
        {
            "_target_": "Spacingd",
            "keys": "image",
            "pixdim": "@spacing",
            "mode": "bilinear"
        }
    ],
    "final_transforms": [
        {
            "_target_": "ScaleIntensityRangePercentilesd",
            "keys": "image",
            "lower": 0,
            "upper": 99.5,
            "b_min": 0,
            "b_max": 1
        }
    ],
    "train": {
        "crop_transforms": [
            {
                "_target_": "RandSpatialCropd",
                "keys": "image",
                "roi_size": "@train_patch_size",
                "random_size": false
            }
        ],
        "preprocessing": {
            "_target_": "Compose",
            "transforms": "$@preprocessing_transforms + @train#crop_transforms + @final_transforms"
        },
        "dataset": {
            "_target_": "monai.apps.DecathlonDataset",
            "root_dir": "@dataset_dir",
            "task": "Task01_BrainTumour",
            "section": "training",
            "cache_rate": 1.0,
            "num_workers": 8,
            "download": false,
            "transform": "@train#preprocessing"
        },
        "dataloader": {
            "_target_": "DataLoader",
            "dataset": "@train#dataset",
            "batch_size": "@train_batch_size",
            "shuffle": true,
            "num_workers": 0
        },
        "handlers": [
            {
                "_target_": "CheckpointSaver",
                "save_dir": "@ckpt_dir",
                "save_dict": {
                    "model": "@gnetwork"
                },
                "save_interval": 0,
                "save_final": true,
                "epoch_level": true,
                "final_filename": "model_autoencoder.pt"
            },
            {
                "_target_": "StatsHandler",
                "tag_name": "train_loss",
                "output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]"
            },
            {
                "_target_": "TensorBoardStatsHandler",
                "log_dir": "@tf_dir",
                "tag_name": "train_loss",
                "output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]"
            }
        ],
        "trainer": {
            "_target_": "scripts.ldm_trainer.VaeGanTrainer",
            "device": "@device",
            "max_epochs": 1500,
            "train_data_loader": "@train#dataloader",
            "g_network": "@gnetwork",
            "g_optimizer": "@goptimizer",
            "g_loss_function": "$functools.partial(scripts.losses.generator_loss, disc_net=@dnetwork, loss_perceptual=@loss_perceptual)",
            "d_network": "@dnetwork",
            "d_optimizer": "@doptimizer",
            "d_loss_function": "$functools.partial(scripts.losses.discriminator_loss, disc_net=@dnetwork)",
            "d_train_steps": 5,
            "g_update_latents": true,
            "latent_shape": "@latent_channels",
            "key_train_metric": "$None",
            "train_handlers": "@train#handlers"
        }
    },
    "initialize": [
        "$monai.utils.set_determinism(seed=0)"
    ],
    "run": [
        "$@train#trainer.run()"
    ]
}
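The `ScaleIntensityRangePercentilesd` transform above maps each image's [0th, 99.5th] percentile intensity range linearly onto [0, 1]. The underlying arithmetic, sketched in plain Python (a simplified illustration; MONAI's transform additionally supports clipping, relative scaling, and channel-wise options):

```python
def scale_intensity_range_percentiles(values, lower, upper, b_min, b_max):
    """Linearly map the [lower, upper] percentile range of `values` onto [b_min, b_max]."""
    ordered = sorted(values)

    def percentile(p):  # linear-interpolation percentile, like numpy's default
        idx = (len(ordered) - 1) * p / 100.0
        lo = int(idx)
        hi = min(lo + 1, len(ordered) - 1)
        return ordered[lo] + (ordered[hi] - ordered[lo]) * (idx - lo)

    a_min, a_max = percentile(lower), percentile(upper)
    scale = (b_max - b_min) / (a_max - a_min)
    return [(v - a_min) * scale + b_min for v in values]

# Same parameters as the config: lower=0, upper=99.5, b_min=0, b_max=1.
scaled = scale_intensity_range_percentiles(list(range(101)), 0, 99.5, 0, 1)
```

Note that without clipping, values above the 99.5th percentile map slightly above 1.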
configs/train_diffusion.json ADDED
@@ -0,0 +1,159 @@
+ {
+     "ckpt_dir": "$@bundle_root + '/models'",
+     "train_batch_size": 4,
+     "lr": 1e-05,
+     "train_patch_size": [
+         144,
+         176,
+         112
+     ],
+     "latent_shape": [
+         "@latent_channels",
+         36,
+         44,
+         28
+     ],
+     "load_autoencoder_path": "$@bundle_root + '/models/model_autoencoder.pt'",
+     "load_autoencoder": "$@autoencoder_def.load_old_state_dict(torch.load(@load_autoencoder_path))",
+     "autoencoder": "$@autoencoder_def.to(@device)",
+     "network_def": {
+         "_target_": "monai.networks.nets.diffusion_model_unet.DiffusionModelUNet",
+         "spatial_dims": "@spatial_dims",
+         "in_channels": "@latent_channels",
+         "out_channels": "@latent_channels",
+         "channels": [
+             256,
+             256,
+             512
+         ],
+         "attention_levels": [
+             false,
+             true,
+             true
+         ],
+         "num_head_channels": [
+             0,
+             64,
+             64
+         ],
+         "num_res_blocks": 2,
+         "include_fc": false,
+         "use_combined_linear": false
+     },
+     "diffusion": "$@network_def.to(@device)",
+     "optimizer": {
+         "_target_": "torch.optim.Adam",
+ "params": "[email protected]()",
+         "lr": "@lr"
+     },
+     "lr_scheduler": {
+         "_target_": "torch.optim.lr_scheduler.MultiStepLR",
+         "optimizer": "@optimizer",
+         "milestones": [
+             100,
+             1000
+         ],
+         "gamma": 0.1
+     },
+     "scale_factor": "$scripts.utils.compute_scale_factor(@autoencoder,@train#dataloader,@device)",
+     "noise_scheduler": {
+         "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
+         "_requires_": [
+             "@load_autoencoder"
+         ],
+         "schedule": "scaled_linear_beta",
+         "num_train_timesteps": 1000,
+         "beta_start": 0.0015,
+         "beta_end": 0.0195
+     },
+     "loss": {
+         "_target_": "torch.nn.MSELoss"
+     },
+     "train": {
+         "inferer": {
+             "_target_": "monai.inferers.LatentDiffusionInferer",
+             "scheduler": "@noise_scheduler",
+             "scale_factor": "@scale_factor"
+         },
+         "crop_transforms": [
+             {
+                 "_target_": "CenterSpatialCropd",
+                 "keys": "image",
+                 "roi_size": "@train_patch_size"
+             }
+         ],
+         "preprocessing": {
+             "_target_": "Compose",
+             "transforms": "$@preprocessing_transforms + @train#crop_transforms + @final_transforms"
+         },
+         "dataset": {
+             "_target_": "monai.apps.DecathlonDataset",
+             "root_dir": "@dataset_dir",
+             "task": "Task01_BrainTumour",
+             "section": "training",
+             "cache_rate": 1.0,
+             "num_workers": 8,
+             "download": false,
+             "transform": "@train#preprocessing"
+         },
+         "dataloader": {
+             "_target_": "DataLoader",
+             "dataset": "@train#dataset",
+             "batch_size": "@train_batch_size",
+             "shuffle": true,
+             "num_workers": 0
+         },
+         "handlers": [
+             {
+                 "_target_": "LrScheduleHandler",
+                 "lr_scheduler": "@lr_scheduler",
+                 "print_lr": true
+             },
+             {
+                 "_target_": "CheckpointSaver",
+                 "save_dir": "@ckpt_dir",
+                 "save_dict": {
+                     "model": "@diffusion"
+                 },
+                 "save_interval": 0,
+                 "save_final": true,
+                 "epoch_level": true,
+                 "final_filename": "model.pt"
+             },
+             {
+                 "_target_": "StatsHandler",
+                 "tag_name": "train_diffusion_loss",
+                 "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)"
+             },
+             {
+                 "_target_": "TensorBoardStatsHandler",
+                 "log_dir": "@tf_dir",
+                 "tag_name": "train_diffusion_loss",
+                 "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)"
+             }
+         ],
+         "trainer": {
+             "_target_": "scripts.ldm_trainer.LDMTrainer",
+             "device": "@device",
+             "max_epochs": 5000,
+             "train_data_loader": "@train#dataloader",
+             "network": "@diffusion",
+             "autoencoder_model": "@autoencoder",
+             "optimizer": "@optimizer",
+             "loss_function": "@loss",
+             "latent_shape": "@latent_shape",
+             "inferer": "@train#inferer",
+             "key_train_metric": "$None",
+             "train_handlers": "@train#handlers"
+         }
+     },
+     "initialize": [
+         "$monai.utils.set_determinism(seed=0)"
+     ],
+     "run": [
+         "@load_autoencoder",
+         "$print('scale factor:',@scale_factor)",
+         "$@train#trainer.run()"
+     ]
+ }
docs/README.md ADDED
@@ -0,0 +1,195 @@
+ # Model Overview
+ A pre-trained volumetric (3D) latent diffusion generative model for BraTS MRI.
+
+ This model is trained on BraTS 2016 and 2017 data from [Medical Decathlon](http://medicaldecathlon.com/), using the latent diffusion model [1].
+
+ ![model workflow](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_network.png)
+
+ This model is a generator that creates images resembling the FLAIR MRIs in the BraTS 2016 and 2017 data. It was trained as a 3D latent diffusion model and accepts Gaussian random noise as input to produce an image output. The `train_autoencoder.json` file describes the training process of the variational autoencoder with GAN loss. The `train_diffusion.json` file describes the training process of the 3D latent diffusion model.
+
+ In this bundle, the autoencoder uses a perceptual loss based on ResNet50 with pre-trained weights (the network is frozen and is not trained in the bundle). By default, the `pretrained` parameter is set to `False` in `train_autoencoder.json`. To ensure correct training, you must change this default. There are two ways to use pretrained weights:
+ 1. If `pretrained` is set to `True`, ImageNet pretrained weights from [torchvision](https://pytorch.org/vision/stable/_modules/torchvision/models/resnet.html#ResNet50_Weights) are used. Note that these weights are for non-commercial use only.
+ 2. If `pretrained` is set to `True` and the `perceptual_loss_model_weights_path` parameter is specified, weights are loaded from a local path. This is how this bundle was trained; its pre-trained weights come from internal data.
+
+ Please note that each user is responsible for checking the data source of the pre-trained models, the applicable licenses, and determining if they are suitable for the intended use.
+
+ #### Example synthetic image
+ An example result from inference is shown below:
+ ![Example synthetic image](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_example_generation_v2.png)
+
+ **This is a demonstration network meant to show the training process for this kind of network with MONAI. To achieve better performance, users should use a larger dataset such as [BraTS 2021](https://www.synapse.org/#!Synapse:syn25829067/wiki/610865) and a GPU with more than 32 GB of memory to enable larger networks and attention layers.**
+
+ ## Data
+ The training data is BraTS 2016 and 2017 from the Medical Segmentation Decathlon. Users can find more details on the dataset (`Task01_BrainTumour`) at http://medicaldecathlon.com/.
+
+ - Target: Image Generation
+ - Task: Synthesis
+ - Modality: MRI
+ - Size: 388 3D volumes (1 channel used)
+
+ ## Training Configuration
+ If you have a GPU with less than 32 GB of memory, you may need to decrease the batch size when training. To do so, modify the `train_batch_size` parameter in the [configs/train_autoencoder.json](../configs/train_autoencoder.json) and [configs/train_diffusion.json](../configs/train_diffusion.json) configuration files.
+
+ ### Training Configuration of Autoencoder
+ The autoencoder was trained using the following configuration:
+
+ - GPU: at least 32 GB GPU memory
+ - Actual Model Input: 112 x 128 x 80
+ - AMP: False
+ - Optimizer: Adam
+ - Learning Rate: 1e-5
+ - Loss: L1 loss, perceptual loss, KL divergence loss, adversarial loss, GAN BCE loss
+
+ #### Input
+ 1 channel 3D MRI FLAIR patches
+
+ #### Output
+ - 1 channel 3D MRI reconstructed patches
+ - 8 channel mean of latent features
+ - 8 channel standard deviation of latent features
+
+ ### Training Configuration of Diffusion Model
+ The latent diffusion model was trained using the following configuration:
+
+ - GPU: at least 32 GB GPU memory
+ - Actual Model Input: 36 x 44 x 28
+ - AMP: False
+ - Optimizer: Adam
+ - Learning Rate: 1e-5
+ - Loss: MSE loss
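
The diffusion model's 36 x 44 x 28 input corresponds to the autoencoder's 4x per-axis downsampling of the 144 x 176 x 112 training patch. A quick sanity check (the 4x factor is inferred from the configured sizes, not stated explicitly in the configs):

```python
# Each spatial axis of the training patch is downsampled 4x by the
# autoencoder, giving the diffusion model's latent spatial size.
train_patch_size = (144, 176, 112)  # "train_patch_size" in train_diffusion.json
latent_spatial = tuple(d // 4 for d in train_patch_size)
assert latent_spatial == (36, 44, 28)  # matches "latent_shape" in train_diffusion.json
```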
+
+ #### Training Input
+ - 8 channel noisy latent features
+ - a long int that indicates the time step
+
+ #### Training Output
+ 8 channel predicted added noise
+
+ #### Inference Input
+ 8 channel noise
+
+ #### Inference Output
+ 8 channel denoised latent features
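
The input/output pairing above (noisy latents plus a timestep in, predicted noise out) is the standard epsilon-prediction DDPM objective. A minimal PyTorch sketch of one training step, as an illustration only and not the bundle's `LDMTrainer` code (`model` and `alphas_cumprod` are assumed inputs):

```python
import torch

def ddpm_training_step(model, latents, alphas_cumprod):
    """One epsilon-prediction step: noise the latents at a random
    timestep and regress the model output against the added noise."""
    t = torch.randint(0, len(alphas_cumprod), (latents.shape[0],))
    noise = torch.randn_like(latents)
    # Broadcast the per-sample cumulative alpha over all spatial dims.
    a = alphas_cumprod[t].view(-1, *([1] * (latents.dim() - 1)))
    noisy_latents = a.sqrt() * latents + (1 - a).sqrt() * noise
    predicted_noise = model(noisy_latents, t)
    return torch.nn.functional.mse_loss(predicted_noise, noise)
```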
+
+ ### Memory Consumption Warning
+
+ If you face memory issues with data loading, you can lower the caching rate `cache_rate` in the configurations (within the range [0, 1]) to reduce the system RAM requirements.
+
+ ## Performance
+
+ #### Training Loss
+ ![A graph showing the autoencoder training curve](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_train_autoencoder_loss_v2.png)
+
+ ![A graph showing the latent diffusion training curve](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_train_diffusion_loss_v2.png)
+
+ #### TensorRT speedup
+ This bundle supports acceleration with TensorRT. The table below displays the speedup ratios observed on an A100 80G GPU. Please note that 32-bit precision models are benchmarked with the TF32 weight format.
+
+ | method | torch_tf32(ms) | torch_amp(ms) | trt_tf32(ms) | trt_fp16(ms) | speedup amp | speedup tf32 | speedup fp16 | amp vs fp16 |
+ | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+ | model computation (diffusion) | 44.57 | 44.59 | 40.89 | 18.79 | 1.00 | 1.09 | 2.37 | 2.37 |
+ | model computation (autoencoder) | 96.29 | 97.01 | 78.51 | 44.03 | 0.99 | 1.23 | 2.19 | 2.20 |
+ | end2end | 2826 | 2538 | 2759 | 1472 | 1.11 | 1.02 | 1.92 | 1.72 |
+
+ Where:
+ - `model computation` is the speedup ratio of the model's inference with a random input, without preprocessing and postprocessing.
+ - `end2end` means running the bundle end-to-end with the TensorRT-based model.
+ - `torch_tf32` and `torch_amp` are for the PyTorch models with or without `amp` mode.
+ - `trt_tf32` and `trt_fp16` are for the TensorRT-based models converted to the corresponding precision.
+ - `speedup amp`, `speedup tf32` and `speedup fp16` are the speedup ratios of the corresponding models versus the PyTorch float32 model.
+ - `amp vs fp16` is the speedup ratio between the PyTorch amp model and the TensorRT float16-based model.
+
+ This result is benchmarked under:
+ - TensorRT: 10.3.0+cuda12.6
+ - Torch-TensorRT Version: 2.4.0
+ - CPU Architecture: x86-64
+ - OS: Ubuntu 20.04
+ - Python version: 3.10.12
+ - CUDA version: 12.6
+ - GPU models and configuration: A100 80G
+
+ ## MONAI Bundle Commands
+
+ In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle. The CLI supports flexible use cases, such as overriding configs at runtime and predefining arguments in a file.
+
+ For more detailed usage instructions, visit the [MONAI Bundle Configuration Page](https://docs.monai.io/en/latest/config_syntax.html).
+
+ ### Execute Autoencoder Training
+
+ #### Execute Autoencoder Training on single GPU
+
+ ```
+ python -m monai.bundle run --config_file configs/train_autoencoder.json
+ ```
+
+ Please note that if the default dataset path in the bundle config files has not been updated to the actual path (the path containing `Task01_BrainTumour`), you can override it with `--dataset_dir`:
+
+ ```
+ python -m monai.bundle run --config_file configs/train_autoencoder.json --dataset_dir <actual dataset path>
+ ```
+
+ #### Override the `train` config to execute multi-GPU training for Autoencoder
+ To train with multiple GPUs, use the following command, which requires scaling up the learning rate according to the number of GPUs.
+
+ ```
+ torchrun --standalone --nnodes=1 --nproc_per_node=8 -m monai.bundle run --config_file "['configs/train_autoencoder.json','configs/multi_gpu_train_autoencoder.json']" --lr 8e-5
+ ```
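
The `--lr 8e-5` override follows the linear scaling rule for data-parallel training: multiply the single-GPU learning rate by the number of processes (a common convention; the bundle does not compute this automatically):

```python
base_lr = 1e-5  # single-GPU learning rate from train_autoencoder.json
n_gpus = 8      # --nproc_per_node in the torchrun command
scaled_lr = base_lr * n_gpus
assert scaled_lr == 8e-5  # the value passed via --lr
```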
+
+ #### Check the Autoencoder Training result
+ The following code generates a reconstructed image from a random input image.
+ We can visualize it to see if the autoencoder is trained correctly.
+ ```
+ python -m monai.bundle run --config_file configs/inference_autoencoder.json
+ ```
+
+ An example of a reconstructed image from inference is shown below. If the autoencoder is trained correctly, the reconstructed image should look similar to the original image.
+
+ ![Example reconstructed image](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_recon_example.jpg)
+
+ ### Execute Latent Diffusion Training
+
+ #### Execute Latent Diffusion Model Training on single GPU
+ After training the autoencoder, run the following command to train the latent diffusion model. This command will print out the scale factor of the latent feature space. If your autoencoder is well trained, this value should be close to 1.0.
+
+ ```
+ python -m monai.bundle run --config_file "['configs/train_autoencoder.json','configs/train_diffusion.json']"
+ ```
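
The scale factor is produced by the bundle's `scripts.utils.compute_scale_factor`. A hedged sketch of the standard latent-diffusion recipe it follows (the actual helper may differ in detail): encode one batch and take the reciprocal of the latent standard deviation, so the scaled latents have roughly unit variance, which is why a well-trained KL autoencoder yields a value near 1.0.

```python
import torch

@torch.no_grad()
def compute_scale_factor(autoencoder, dataloader, device):
    # Encode one batch into latent space; 1 / std(latents) rescales
    # the latents to roughly unit variance for diffusion training.
    batch = next(iter(dataloader))
    z = autoencoder.encode_stage_2_inputs(batch["image"].to(device))
    return 1.0 / torch.std(z)
```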
+
+ #### Override the `train` config to execute multi-GPU training for Latent Diffusion Model
+ To train with multiple GPUs, use the following command, which requires scaling up the learning rate according to the number of GPUs.
+
+ ```
+ torchrun --standalone --nnodes=1 --nproc_per_node=8 -m monai.bundle run --config_file "['configs/train_autoencoder.json','configs/train_diffusion.json','configs/multi_gpu_train_autoencoder.json','configs/multi_gpu_train_diffusion.json']" --lr 8e-5
+ ```
+
+ #### Execute inference
+ The following code generates a synthetic image from randomly sampled noise.
+ ```
+ python -m monai.bundle run --config_file configs/inference.json
+ ```
+
+ #### Execute inference with the TensorRT model
+
+ ```
+ python -m monai.bundle run --config_file "['configs/inference.json', 'configs/inference_trt.json']"
+ ```
+
+ # References
+ [1] Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022. https://openaccess.thecvf.com/content/CVPR2022/papers/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.pdf
+
+ # License
+ Copyright (c) MONAI Consortium
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
docs/data_license.txt ADDED
@@ -0,0 +1,49 @@
+ Third Party Licenses
+ -----------------------------------------------------------------------
+
+ /*********************************************************************/
+ i. Multimodal Brain Tumor Segmentation Challenge 2018
+ https://www.med.upenn.edu/sbia/brats2018/data.html
+ /*********************************************************************/
+
+ Data Usage Agreement / Citations
+
+ You are free to use and/or refer to the BraTS datasets in your own
+ research, provided that you always cite the following two manuscripts:
+
+ [1] Menze BH, Jakab A, Bauer S, Kalpathy-Cramer J, Farahani K, Kirby
+ J, Burren Y, Porz N, Slotboom J, Wiest R, Lanczi L, Gerstner E, Weber
+ MA, Arbel T, Avants BB, Ayache N, Buendia P, Collins DL, Cordier N,
+ Corso JJ, Criminisi A, Das T, Delingette H, Demiralp Ç, Durst CR,
+ Dojat M, Doyle S, Festa J, Forbes F, Geremia E, Glocker B, Golland P,
+ Guo X, Hamamci A, Iftekharuddin KM, Jena R, John NM, Konukoglu E,
+ Lashkari D, Mariz JA, Meier R, Pereira S, Precup D, Price SJ, Raviv
+ TR, Reza SM, Ryan M, Sarikaya D, Schwartz L, Shin HC, Shotton J,
+ Silva CA, Sousa N, Subbanna NK, Szekely G, Taylor TJ, Thomas OM,
+ Tustison NJ, Unal G, Vasseur F, Wintermark M, Ye DH, Zhao L, Zhao B,
+ Zikic D, Prastawa M, Reyes M, Van Leemput K. "The Multimodal Brain
+ Tumor Image Segmentation Benchmark (BRATS)", IEEE Transactions on
+ Medical Imaging 34(10), 1993-2024 (2015) DOI:
+ 10.1109/TMI.2014.2377694
+
+ [2] Bakas S, Akbari H, Sotiras A, Bilello M, Rozycki M, Kirby JS,
+ Freymann JB, Farahani K, Davatzikos C. "Advancing The Cancer Genome
+ Atlas glioma MRI collections with expert segmentation labels and
+ radiomic features", Nature Scientific Data, 4:170117 (2017) DOI:
+ 10.1038/sdata.2017.117
+
+ In addition, if there are no restrictions imposed from the
+ journal/conference you submit your paper about citing "Data
+ Citations", please be specific and also cite the following:
+
+ [3] Bakas S, Akbari H, Sotiras A, Bilello M, Rozycki M, Kirby J,
+ Freymann J, Farahani K, Davatzikos C. "Segmentation Labels and
+ Radiomic Features for the Pre-operative Scans of the TCGA-GBM
+ collection", The Cancer Imaging Archive, 2017. DOI:
+ 10.7937/K9/TCIA.2017.KLXWJJ1Q
+
+ [4] Bakas S, Akbari H, Sotiras A, Bilello M, Rozycki M, Kirby J,
+ Freymann J, Farahani K, Davatzikos C. "Segmentation Labels and
+ Radiomic Features for the Pre-operative Scans of the TCGA-LGG
+ collection", The Cancer Imaging Archive, 2017. DOI:
+ 10.7937/K9/TCIA.2017.GJQ7R0EF
models/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4af4a27015291c0a6390b45b0d39e4d54924c2250cadd6d5c1bb9717d76a26fd
+ size 765020741
models/model_autoencoder.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:85ec0986beb902b93beb5e27ea5c39429e8ae02c3bde8ca581bef0cac83014bc
+ size 84050405
scripts/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from . import ldm_sampler, ldm_trainer, losses, utils
scripts/ldm_sampler.py ADDED
@@ -0,0 +1,73 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+ from monai.transforms import Transform
+ from monai.utils import optional_import
+ from torch.cuda.amp import autocast
+
+ tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
+
+
+ class LDMSampler:
+     def __init__(self) -> None:
+         super().__init__()
+
+     @torch.no_grad()
+     def sampling_fn(
+         self,
+         input_noise: torch.Tensor,
+         autoencoder_model: nn.Module,
+         diffusion_model: nn.Module,
+         scheduler: nn.Module,
+         conditioning: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         if has_tqdm:
+             progress_bar = tqdm(scheduler.timesteps)
+         else:
+             progress_bar = iter(scheduler.timesteps)
+
+         image = input_noise
+         if conditioning is not None:
+             cond_concat = conditioning.squeeze(1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+             cond_concat = cond_concat.expand(list(cond_concat.shape[0:2]) + list(input_noise.shape[2:]))
+
+         for t in progress_bar:
+             with torch.no_grad():
+                 if conditioning is not None:
+                     input_t = torch.cat((image, cond_concat), dim=1)
+                 else:
+                     input_t = image
+                 model_output = diffusion_model(
+                     input_t, timesteps=torch.Tensor((t,)).to(input_noise.device).long(), context=conditioning
+                 )
+                 image, _ = scheduler.step(model_output, t, image)
+
+         with torch.no_grad():
+             with autocast():
+                 sample = autoencoder_model.decode_stage_2_outputs(image)
+
+         return sample
+
+     def run(
+         self,
+         input_noise: torch.Tensor,
+         autoencoder_model: nn.Module,
+         diffusion_model: nn.Module,
+         scheduler: nn.Module,
+         saver: Transform,
+         conditioning: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         sample = self.sampling_fn(input_noise, autoencoder_model, diffusion_model, scheduler, conditioning)
+         saver(sample[0])
scripts/ldm_trainer.py ADDED
@@ -0,0 +1,379 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
15
+
16
+ import torch
17
+ from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
18
+ from monai.inferers import Inferer, SimpleInferer
19
+ from monai.transforms import Transform
20
+ from monai.utils import IgniteInfo, min_version, optional_import
21
+ from monai.utils.enums import CommonKeys, GanKeys
22
+ from torch.optim.optimizer import Optimizer
23
+ from torch.utils.data import DataLoader
24
+
25
+ if TYPE_CHECKING:
26
+ from ignite.engine import Engine, EventEnum
27
+ from ignite.metrics import Metric
28
+ else:
29
+ Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
30
+ Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
31
+ EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
32
+ from monai.engines.trainer import SupervisedTrainer, Trainer
33
+
34
+
35
+ class VaeGanTrainer(Trainer):
36
+ """
37
+ Generative adversarial network training based on Goodfellow et al. 2014 https://arxiv.org/abs/1406.266,
38
+ inherits from ``Trainer`` and ``Workflow``.
39
+ Training Loop: for each batch of data size `m`
40
+ 1. Generate `m` fakes from random latent codes.
41
+ 2. Update discriminator with these fakes and current batch reals, repeated d_train_steps times.
42
+ 3. If g_update_latents, generate `m` fakes from new random latent codes.
43
+ 4. Update generator with these fakes using discriminator feedback.
44
+ Args:
45
+ device: an object representing the device on which to run.
46
+ max_epochs: the total epoch number for engine to run.
47
+ train_data_loader: Core ignite engines uses `DataLoader` for training loop batchdata.
48
+ g_network: generator (G) network architecture.
49
+ g_optimizer: G optimizer function.
50
+ g_loss_function: G loss function for optimizer.
51
+ d_network: discriminator (D) network architecture.
52
+ d_optimizer: D optimizer function.
53
+ d_loss_function: D loss function for optimizer.
54
+ epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
55
+ g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``.
56
+ d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``.
57
+ d_train_steps: number of times to update D with real data minibatch. Defaults to ``1``.
58
+ latent_shape: size of G input latent code. Defaults to ``64``.
59
+ non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
60
+ with respect to the host. For other cases, this argument has no effect.
61
+ d_prepare_batch: callback function to prepare batchdata for D inferer.
62
+ Defaults to return ``GanKeys.REALS`` in batchdata dict. for more details please refer to:
63
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
64
+ g_prepare_batch: callback function to create batch of latent input for G inferer.
65
+ Defaults to return random latents. for more details please refer to:
66
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
67
+ g_update_latents: Calculate G loss with new latent codes. Defaults to ``True``.
68
+ iteration_update: the callable function for every iteration, expect to accept `engine`
69
+ and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
70
+ if not provided, use `self._iteration()` instead. for more details please refer to:
71
+ https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
72
+ postprocessing: execute additional transformation for the model output data.
73
+ Typically, several Tensor based transforms composed by `Compose`.
74
+ key_train_metric: compute metric when every iteration completed, and save average value to
75
+ engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
76
+ checkpoint into files.
77
+ additional_metrics: more Ignite metrics that also attach to Ignite Engine.
78
+ metric_cmp_fn: function to compare current key metric with previous best key metric value,
79
+ it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
80
+ `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
81
+ train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
82
+ CheckpointHandler, StatsHandler, etc.
83
+ decollate: whether to decollate the batch-first data to a list of data after model computation,
84
+ recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
85
+ default to `True`.
86
+ optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
87
+ more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
88
+ to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
89
+ `device`, `non_blocking`.
90
+ amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
91
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ device: str | torch.device,
97
+ max_epochs: int,
98
+ train_data_loader: DataLoader,
99
+ g_network: torch.nn.Module,
100
+ g_optimizer: Optimizer,
101
+ g_loss_function: Callable,
102
+ d_network: torch.nn.Module,
103
+ d_optimizer: Optimizer,
104
+ d_loss_function: Callable,
105
+ epoch_length: int | None = None,
106
+ g_inferer: Inferer | None = None,
107
+ d_inferer: Inferer | None = None,
108
+ d_train_steps: int = 1,
109
+ latent_shape: int = 64,
110
+         non_blocking: bool = False,
+         d_prepare_batch: Callable = default_prepare_batch,
+         g_prepare_batch: Callable = default_prepare_batch,
+         g_update_latents: bool = True,
+         iteration_update: Callable[[Engine, Any], Any] | None = None,
+         postprocessing: Transform | None = None,
+         key_train_metric: dict[str, Metric] | None = None,
+         additional_metrics: dict[str, Metric] | None = None,
+         metric_cmp_fn: Callable = default_metric_cmp_fn,
+         train_handlers: Sequence | None = None,
+         decollate: bool = True,
+         optim_set_to_none: bool = False,
+         to_kwargs: dict | None = None,
+         amp_kwargs: dict | None = None,
+     ):
+         if not isinstance(train_data_loader, DataLoader):
+             raise ValueError("train_data_loader must be a PyTorch DataLoader.")
+
+         # set up Ignite engine and environments
+         super().__init__(
+             device=device,
+             max_epochs=max_epochs,
+             data_loader=train_data_loader,
+             epoch_length=epoch_length,
+             non_blocking=non_blocking,
+             prepare_batch=d_prepare_batch,
+             iteration_update=iteration_update,
+             key_metric=key_train_metric,
+             additional_metrics=additional_metrics,
+             metric_cmp_fn=metric_cmp_fn,
+             handlers=train_handlers,
+             postprocessing=postprocessing,
+             decollate=decollate,
+             to_kwargs=to_kwargs,
+             amp_kwargs=amp_kwargs,
+         )
+         self.g_network = g_network
+         self.g_optimizer = g_optimizer
+         self.g_loss_function = g_loss_function
+         self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer
+         self.d_network = d_network
+         self.d_optimizer = d_optimizer
+         self.d_loss_function = d_loss_function
+         self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer
+         self.d_train_steps = d_train_steps
+         self.latent_shape = latent_shape
+         self.g_prepare_batch = g_prepare_batch
+         self.g_update_latents = g_update_latents
+         self.optim_set_to_none = optim_set_to_none
+
+     def _iteration(
+         self, engine: VaeGanTrainer, batchdata: dict | Sequence
+     ) -> dict[str, torch.Tensor | int | float | bool]:
+         """
+         Callback function for the adversarial training logic of one iteration in Ignite Engine.
+
+         Args:
+             engine: `VaeGanTrainer` to execute operation for an iteration.
+             batchdata: input data for this iteration, usually a dictionary or tuple of Tensor data.
+
+         Raises:
+             ValueError: must provide batch data for current iteration.
+
+         """
+         if batchdata is None:
+             raise ValueError("must provide batch data for current iteration.")
+
+         d_input = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)[0]
+         g_input = d_input
+         g_output, z_mu, z_sigma = engine.g_inferer(g_input, engine.g_network)
+
+         # train the discriminator, optionally for several steps per generator step
+         d_total_loss = torch.zeros(1)
+         for _ in range(engine.d_train_steps):
+             engine.d_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
+             dloss = engine.d_loss_function(g_output, d_input)
+             dloss.backward()
+             engine.d_optimizer.step()
+             d_total_loss += dloss.item()
+
+         # train the generator
+         engine.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
+         g_loss = engine.g_loss_function(g_output, g_input, z_mu, z_sigma)
+         g_loss.backward()
+         engine.g_optimizer.step()
+
+         return {
+             GanKeys.REALS: d_input,
+             GanKeys.FAKES: g_output,
+             GanKeys.LATENTS: g_input,
+             GanKeys.GLOSS: g_loss.item(),
+             GanKeys.DLOSS: d_total_loss.item(),
+         }
+
+
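The `_iteration` above alternates several discriminator updates (on detached fakes) with one generator update whose loss includes an adversarial term. A minimal self-contained sketch of that pattern, using toy linear networks and plain least-squares GAN losses in place of the bundle's VAE-GAN components (all names here are illustrative, not the bundle's API):

```python
import torch

# toy stand-ins for the generator (VAE) and discriminator
g_net = torch.nn.Linear(4, 4)
d_net = torch.nn.Linear(4, 1)
g_opt = torch.optim.SGD(g_net.parameters(), lr=0.1)
d_opt = torch.optim.SGD(d_net.parameters(), lr=0.1)

real = torch.randn(8, 4)
fake = g_net(real)  # the "reconstruction" plays the role of g_output

# discriminator steps: fakes are detached so only d_net is updated
d_train_steps = 2
d_total_loss = torch.zeros(1)
for _ in range(d_train_steps):
    d_opt.zero_grad(set_to_none=False)
    d_loss = torch.nn.functional.mse_loss(
        d_net(fake.detach()), torch.zeros(8, 1)
    ) + torch.nn.functional.mse_loss(d_net(real), torch.ones(8, 1))
    d_loss.backward()
    d_opt.step()
    d_total_loss += d_loss.item()

# generator step: reconstruction loss plus an adversarial term that asks the
# discriminator to label the fakes as real
g_opt.zero_grad(set_to_none=False)
g_loss = torch.nn.functional.l1_loss(fake, real) + torch.nn.functional.mse_loss(
    d_net(fake), torch.ones(8, 1)
)
g_loss.backward()
g_opt.step()
```

Because only detached fakes pass through the discriminator losses, the generator's graph survives the discriminator loop and the final `g_loss.backward()` still works.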
+ class LDMTrainer(SupervisedTrainer):
+     """
+     Supervised trainer for a latent diffusion model, inherits from ``SupervisedTrainer`` and ``Workflow``.
+
+     Args:
+         device: an object representing the device on which to run.
+         max_epochs: the total epoch number for trainer to run.
+         train_data_loader: Ignite engine uses data_loader to run, must be Iterable or torch.DataLoader.
+         network: diffusion network to train in the trainer, should be a regular PyTorch `torch.nn.Module`.
+         autoencoder_model: trained autoencoder used to map images into the latent space.
+         optimizer: the optimizer associated to the network, should be a regular PyTorch optimizer from `torch.optim`
+             or its subclass.
+         loss_function: the loss function associated to the optimizer, should be a regular PyTorch loss,
+             which inherits from `torch.nn.modules.loss`.
+         latent_shape: shape of the latent-space noise, excluding the batch dimension.
+         inferer: inference method that executes the model forward pass on the input data.
+         epoch_length: number of iterations for one epoch, defaults to `len(train_data_loader)`.
+         non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
+             with respect to the host. For other cases, this argument has no effect.
+         prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
+             from `engine.state.batch` for every iteration, for more details please refer to:
+             https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
+         iteration_update: the callable function for every iteration, expected to accept `engine`
+             and `engine.state.batch` as inputs, the returned data will be stored in `engine.state.output`.
+             if not provided, use `self._iteration()` instead. for more details please refer to:
+             https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
+         postprocessing: execute additional transformation for the model output data.
+             Typically, several Tensor based transforms composed by `Compose`.
+         key_train_metric: compute metric when every iteration completed, and save average value to
+             engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
+             checkpoint into files.
+         additional_metrics: more Ignite metrics that also attach to Ignite Engine.
+         metric_cmp_fn: function to compare current key metric with previous best key metric value,
+             it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
+             `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
+         train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
+             CheckpointHandler, StatsHandler, etc.
+         amp: whether to enable auto-mixed-precision training, default is False.
+         event_names: additional custom ignite events that will register to the engine.
+             new events can be a list of str or `ignite.engine.events.EventEnum`.
+         event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
+             for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
+             #ignite.engine.engine.Engine.register_events.
+         decollate: whether to decollate the batch-first data to a list of data after model computation,
+             recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
+             default to `True`.
+         optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
+             more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
+         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
+             `device`, `non_blocking`.
+         amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
+             https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+     """
+
+     def __init__(
+         self,
+         device: str | torch.device,
+         max_epochs: int,
+         train_data_loader: Iterable | DataLoader,
+         network: torch.nn.Module,
+         autoencoder_model: torch.nn.Module,
+         optimizer: Optimizer,
+         loss_function: Callable,
+         latent_shape: Sequence,
+         inferer: Inferer,
+         epoch_length: int | None = None,
+         non_blocking: bool = False,
+         prepare_batch: Callable = default_prepare_batch,
+         iteration_update: Callable[[Engine, Any], Any] | None = None,
+         postprocessing: Transform | None = None,
+         key_train_metric: dict[str, Metric] | None = None,
+         additional_metrics: dict[str, Metric] | None = None,
+         metric_cmp_fn: Callable = default_metric_cmp_fn,
+         train_handlers: Sequence | None = None,
+         amp: bool = False,
+         event_names: list[str | EventEnum | type[EventEnum]] | None = None,
+         event_to_attr: dict | None = None,
+         decollate: bool = True,
+         optim_set_to_none: bool = False,
+         to_kwargs: dict | None = None,
+         amp_kwargs: dict | None = None,
+     ) -> None:
+         super().__init__(
+             device=device,
+             max_epochs=max_epochs,
+             train_data_loader=train_data_loader,
+             network=network,
+             optimizer=optimizer,
+             loss_function=loss_function,
+             inferer=inferer,
+             optim_set_to_none=optim_set_to_none,
+             epoch_length=epoch_length,
+             non_blocking=non_blocking,
+             prepare_batch=prepare_batch,
+             iteration_update=iteration_update,
+             postprocessing=postprocessing,
+             key_train_metric=key_train_metric,
+             additional_metrics=additional_metrics,
+             metric_cmp_fn=metric_cmp_fn,
+             train_handlers=train_handlers,
+             amp=amp,
+             event_names=event_names,
+             event_to_attr=event_to_attr,
+             decollate=decollate,
+             to_kwargs=to_kwargs,
+             amp_kwargs=amp_kwargs,
+         )
+
+         self.latent_shape = latent_shape
+         self.autoencoder_model = autoencoder_model
+
+     def _iteration(self, engine: LDMTrainer, batchdata: dict[str, torch.Tensor]) -> dict:
+         """
+         Callback function for the supervised training logic of one iteration in Ignite Engine.
+         Returns the below items in a dictionary:
+             - IMAGE: image Tensor data for model input, already moved to device.
+             - noise: the sampled latent-space noise the diffusion model is trained to predict.
+             - PRED: prediction result of the model.
+             - LOSS: loss value computed by the loss function.
+
+         Args:
+             engine: `LDMTrainer` to execute operation for an iteration.
+             batchdata: input data for this iteration, usually a dictionary or tuple of Tensor data.
+
+         Raises:
+             ValueError: When ``batchdata`` is None.
+
+         """
+         if batchdata is None:
+             raise ValueError("Must provide batch data for current iteration.")
+         batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
+         if len(batch) == 2:
+             images, labels = batch
+             args: tuple = ()
+             kwargs: dict = {}
+         else:
+             images, labels, args, kwargs = batch
+         # put iteration outputs into engine.state
+         engine.state.output = {CommonKeys.IMAGE: images}
+
+         # generate noise in the latent space
+         noise_shape = [images.shape[0]] + list(self.latent_shape)
+         noise = torch.randn(noise_shape, dtype=images.dtype, device=images.device)
+         engine.state.output["noise"] = noise
+
+         # create random timesteps
+         timesteps = torch.randint(
+             0, engine.inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
+         ).long()
+
+         def _compute_pred_loss():
+             # predicted noise
+             engine.state.output[CommonKeys.PRED] = engine.inferer(
+                 inputs=images,
+                 autoencoder_model=self.autoencoder_model,
+                 diffusion_model=engine.network,
+                 noise=noise,
+                 timesteps=timesteps,
+             )
+             engine.fire_event(IterationEvents.FORWARD_COMPLETED)
+             # compute the loss between predicted and sampled noise
+             engine.state.output[CommonKeys.LOSS] = engine.loss_function(
+                 engine.state.output[CommonKeys.PRED], noise
+             ).mean()
+             engine.fire_event(IterationEvents.LOSS_COMPLETED)
+
+         engine.network.train()
+         engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
+
+         if engine.amp and engine.scaler is not None:
+             with torch.cuda.amp.autocast(**engine.amp_kwargs):
+                 _compute_pred_loss()
+             engine.scaler.scale(engine.state.output[CommonKeys.LOSS]).backward()
+             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
+             engine.scaler.step(engine.optimizer)
+             engine.scaler.update()
+         else:
+             _compute_pred_loss()
+             engine.state.output[CommonKeys.LOSS].backward()
+             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
+             engine.optimizer.step()
+         engine.fire_event(IterationEvents.MODEL_COMPLETED)
+
+         return engine.state.output
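The core of the iteration above is standard noise-prediction training: sample Gaussian noise and a random timestep, noise the latents according to the schedule, and regress the model output onto the sampled noise. A stripped-down sketch of that step with a toy linear "model" and a crude linear noising schedule (the model, the schedule, and all names are illustrative stand-ins, not the bundle's inferer or scheduler):

```python
import torch

# hypothetical toy "diffusion model": predicts noise from a noised latent + timestep
model = torch.nn.Linear(5, 4)

latents = torch.randn(8, 4)        # stands in for autoencoder-encoded images
noise = torch.randn_like(latents)  # the target the network learns to predict
num_train_timesteps = 1000
timesteps = torch.randint(0, num_train_timesteps, (latents.shape[0],))

# crude linear schedule, for illustration only
alpha = 1.0 - timesteps.float().unsqueeze(1) / num_train_timesteps
noisy = alpha.sqrt() * latents + (1 - alpha).sqrt() * noise

# condition on the (normalized) timestep by concatenation, then regress onto the noise
inp = torch.cat([noisy, timesteps.float().unsqueeze(1) / num_train_timesteps], dim=1)
pred = model(inp)
loss = torch.nn.functional.mse_loss(pred, noise)
loss.backward()
```

In the real trainer the schedule and the noising live inside `engine.inferer` and its scheduler; only the loss-on-noise structure is shown here.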
scripts/losses.py ADDED
@@ -0,0 +1,50 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ from monai.losses.adversarial_loss import PatchAdversarialLoss
+
+ intensity_loss = torch.nn.L1Loss()
+ adv_loss = PatchAdversarialLoss(criterion="least_squares")
+
+ adv_weight = 0.1
+ perceptual_weight = 0.1
+ # kl_weight is an important hyper-parameter:
+ # if too large, the decoder cannot reconstruct good results from the latent space;
+ # if too small, the latent space will not be regularized enough for the diffusion model.
+ kl_weight = 1e-7
+
+
+ def compute_kl_loss(z_mu, z_sigma):
+     # KL divergence between N(z_mu, z_sigma^2) and the standard normal, summed over spatial dims
+     kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3, 4])
+     return torch.sum(kl_loss) / kl_loss.shape[0]
+
+
+ def generator_loss(gen_images, real_images, z_mu, z_sigma, disc_net, loss_perceptual):
+     recons_loss = intensity_loss(gen_images, real_images)
+     kl_loss = compute_kl_loss(z_mu, z_sigma)
+     p_loss = loss_perceptual(gen_images.float(), real_images.float())
+     loss_g = recons_loss + kl_weight * kl_loss + perceptual_weight * p_loss
+
+     # adversarial term: the generator tries to make the discriminator label fakes as real
+     logits_fake = disc_net(gen_images)[-1]
+     adv_g_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
+     loss_g = loss_g + adv_weight * adv_g_loss
+
+     return loss_g
+
+
+ def discriminator_loss(gen_images, real_images, disc_net):
+     # detach generated images so discriminator updates do not backpropagate into the generator
+     logits_fake = disc_net(gen_images.contiguous().detach())[-1]
+     loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
+     logits_real = disc_net(real_images.contiguous().detach())[-1]
+     loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True)
+     loss_d = adv_weight * (loss_d_fake + loss_d_real) * 0.5
+     return loss_d
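A quick sanity check on the KL term used above: for a posterior that already matches the standard normal (mean 0, standard deviation 1), the per-element term `z_mu^2 + z_sigma^2 - log(z_sigma^2) - 1` vanishes, and any deviation makes it positive. The formula is repeated here so the check is self-contained:

```python
import torch


def compute_kl_loss(z_mu, z_sigma):
    # same formula as in scripts/losses.py, repeated for a self-contained check
    kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3, 4])
    return torch.sum(kl_loss) / kl_loss.shape[0]


# posterior equal to the standard normal: KL is exactly zero
z_mu = torch.zeros(2, 1, 4, 4, 4)
z_sigma = torch.ones(2, 1, 4, 4, 4)
zero_kl = compute_kl_loss(z_mu, z_sigma)

# shifting the mean away from zero makes the penalty strictly positive
pos_kl = compute_kl_loss(z_mu + 0.5, z_sigma)
```

This is why `kl_weight` can be tiny (1e-7) and still regularize: the sum runs over every latent voxel, so the unweighted KL value is large.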
scripts/utils.py ADDED
@@ -0,0 +1,20 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import monai
+ import torch
+
+
+ def compute_scale_factor(autoencoder, train_loader, device):
+     # estimate the latent scaling factor from the first training batch, so that
+     # the scaled latents have approximately unit standard deviation
+     with torch.no_grad():
+         check_data = monai.utils.first(train_loader)
+         z = autoencoder.encode_stage_2_inputs(check_data["image"].to(device))
+         scale_factor = 1 / torch.std(z)
+     return scale_factor.item()
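The effect of the statistic computed by `compute_scale_factor` can be checked with synthetic latents: scaling by `1 / std(z)` brings the latent distribution to unit standard deviation, which is the scale the diffusion model's noise schedule assumes. The tensor below is a made-up stand-in for autoencoder outputs:

```python
import torch

torch.manual_seed(0)
z = 3.0 * torch.randn(4, 8, 16, 16, 16)  # synthetic latents with std ~ 3

# same statistic as compute_scale_factor, applied directly to the latents
scale_factor = (1 / torch.std(z)).item()
scaled_std = torch.std(z * scale_factor).item()  # ~ 1.0 after scaling
```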