project-monai committed on
Commit 7226a40 · verified · 1 Parent(s): 1414cb8

Upload lung_nodule_ct_detection version 0.6.9

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ models/model.ts filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
configs/evaluate.json ADDED
@@ -0,0 +1,50 @@
1
+ {
2
+ "test_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, is_segmentation=True, data_list_key='validation', base_dir=@dataset_dir)",
3
+ "validate#dataset": {
4
+ "_target_": "Dataset",
5
+ "data": "$@test_datalist",
6
+ "transform": "@validate#preprocessing"
7
+ },
8
+ "validate#key_metric": {
9
+ "val_coco": {
10
+ "_target_": "scripts.cocometric_ignite.IgniteCocoMetric",
11
+ "coco_metric_monai": "$monai.apps.detection.metrics.coco.COCOMetric(classes=['nodule'], iou_list=[0.1], max_detection=[100])",
12
+ "output_transform": "$monai.handlers.from_engine(['pred', 'label'])",
13
+ "box_key": "box",
14
+ "label_key": "label",
15
+ "pred_score_key": "label_scores",
16
+ "reduce_scalar": false
17
+ }
18
+ },
19
+ "validate#handlers": [
20
+ {
21
+ "_target_": "CheckpointLoader",
22
+ "load_path": "$@ckpt_dir + '/model.pt'",
23
+ "load_dict": {
24
+ "model": "@network"
25
+ }
26
+ },
27
+ {
28
+ "_target_": "StatsHandler",
29
+ "iteration_log": false
30
+ },
31
+ {
32
+ "_target_": "MetricsSaver",
33
+ "save_dir": "@output_dir",
34
+ "metrics": [
35
+ "val_coco"
36
+ ],
37
+ "metric_details": [
38
+ "val_coco"
39
+ ],
40
+ "batch_transform": "$lambda x: [xx['image'].meta for xx in x]",
41
+ "summary_ops": "*"
42
+ }
43
+ ],
44
+ "initialize": [
45
+ "$setattr(torch.backends.cudnn, 'benchmark', True)"
46
+ ],
47
+ "run": [
48
+ "$@validate#evaluator.run()"
49
+ ]
50
+ }
configs/inference.json ADDED
@@ -0,0 +1,221 @@
1
+ {
2
+ "whether_raw_luna16": false,
3
+ "whether_resampled_luna16": "$(not @whether_raw_luna16)",
4
+ "imports": [
5
+ "$import glob",
6
+ "$import numpy",
7
+ "$import os"
8
+ ],
9
+ "bundle_root": ".",
10
+ "image_key": "image",
11
+ "ckpt_dir": "$@bundle_root + '/models'",
12
+ "output_dir": "$@bundle_root + '/eval'",
13
+ "output_filename": "result_luna16_fold0.json",
14
+ "data_list_file_path": "$@bundle_root + '/LUNA16_datasplit/dataset_fold0.json'",
15
+ "dataset_dir": "/datasets/LUNA16_Images_resample",
16
+ "test_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, is_segmentation=True, data_list_key='validation', base_dir=@dataset_dir)",
17
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
18
+ "amp": true,
19
+ "load_pretrain": true,
20
+ "spatial_dims": 3,
21
+ "num_classes": 1,
22
+ "force_sliding_window": false,
23
+ "size_divisible": [
24
+ 16,
25
+ 16,
26
+ 8
27
+ ],
28
+ "infer_patch_size": [
29
+ 512,
30
+ 512,
31
+ 192
32
+ ],
33
+ "anchor_generator": {
34
+ "_target_": "monai.apps.detection.utils.anchor_utils.AnchorGeneratorWithAnchorShape",
35
+ "feature_map_scales": [
36
+ 1,
37
+ 2,
38
+ 4
39
+ ],
40
+ "base_anchor_shapes": [
41
+ [
42
+ 6,
43
+ 8,
44
+ 4
45
+ ],
46
+ [
47
+ 8,
48
+ 6,
49
+ 5
50
+ ],
51
+ [
52
+ 10,
53
+ 10,
54
+ 6
55
+ ]
56
+ ]
57
+ },
58
+ "backbone": "$monai.networks.nets.resnet.resnet50(spatial_dims=3,n_input_channels=1,conv1_t_stride=[2,2,1],conv1_t_size=[7,7,7])",
59
+ "feature_extractor": "$monai.apps.detection.networks.retinanet_network.resnet_fpn_feature_extractor(@backbone,3,False,[1,2],None)",
60
+ "network_def": {
61
+ "_target_": "RetinaNet",
62
+ "spatial_dims": "@spatial_dims",
63
+ "num_classes": "@num_classes",
64
+ "num_anchors": 3,
65
+ "feature_extractor": "@feature_extractor",
66
+ "size_divisible": "@size_divisible",
67
+ "use_list_output": false
68
+ },
69
+ "network": "$@network_def.to(@device)",
70
+ "detector": {
71
+ "_target_": "RetinaNetDetector",
72
+ "network": "@network",
73
+ "anchor_generator": "@anchor_generator",
74
+ "debug": false,
75
+ "spatial_dims": "@spatial_dims",
76
+ "num_classes": "@num_classes",
77
+ "size_divisible": "@size_divisible"
78
+ },
79
+ "detector_ops": [
80
+ "[email protected]_target_keys(box_key='box', label_key='label')",
81
+ "[email protected]_box_selector_parameters(score_thresh=0.02,topk_candidates_per_level=1000,nms_thresh=0.22,detections_per_img=300)",
82
+ "[email protected]_sliding_window_inferer(roi_size=@infer_patch_size,overlap=0.25,sw_batch_size=1,mode='constant',device='cpu')"
83
+ ],
84
+ "preprocessing": {
85
+ "_target_": "Compose",
86
+ "transforms": [
87
+ {
88
+ "_target_": "LoadImaged",
89
+ "keys": "@image_key",
90
+ "_disabled_": "@whether_raw_luna16"
91
+ },
92
+ {
93
+ "_target_": "LoadImaged",
94
+ "keys": "@image_key",
95
+ "reader": "itkreader",
96
+ "affine_lps_to_ras": true,
97
+ "_disabled_": "@whether_resampled_luna16"
98
+ },
99
+ {
100
+ "_target_": "EnsureChannelFirstd",
101
+ "keys": "@image_key"
102
+ },
103
+ {
104
+ "_target_": "Orientationd",
105
+ "keys": "@image_key",
106
+ "axcodes": "RAS"
107
+ },
108
+ {
109
+ "_target_": "Spacingd",
110
+ "keys": "@image_key",
111
+ "pixdim": [
112
+ 0.703125,
113
+ 0.703125,
114
+ 1.25
115
+ ],
116
+ "_disabled_": "@whether_resampled_luna16"
117
+ },
118
+ {
119
+ "_target_": "ScaleIntensityRanged",
120
+ "keys": "@image_key",
121
+ "a_min": -1024.0,
122
+ "a_max": 300.0,
123
+ "b_min": 0.0,
124
+ "b_max": 1.0,
125
+ "clip": true
126
+ },
127
+ {
128
+ "_target_": "EnsureTyped",
129
+ "keys": "@image_key"
130
+ }
131
+ ]
132
+ },
133
+ "dataset": {
134
+ "_target_": "Dataset",
135
+ "data": "$@test_datalist",
136
+ "transform": "@preprocessing"
137
+ },
138
+ "dataloader": {
139
+ "_target_": "DataLoader",
140
+ "dataset": "@dataset",
141
+ "batch_size": 1,
142
+ "shuffle": false,
143
+ "num_workers": 4,
144
+ "collate_fn": "$monai.data.utils.no_collation"
145
+ },
146
+ "inferer": {
147
+ "_target_": "scripts.detection_inferer.RetinaNetInferer",
148
+ "detector": "@detector",
149
+ "force_sliding_window": "@force_sliding_window"
150
+ },
151
+ "postprocessing": {
152
+ "_target_": "Compose",
153
+ "transforms": [
154
+ {
155
+ "_target_": "ClipBoxToImaged",
156
+ "box_keys": "box",
157
+ "label_keys": "label",
158
+ "box_ref_image_keys": "@image_key",
159
+ "remove_empty": true
160
+ },
161
+ {
162
+ "_target_": "AffineBoxToWorldCoordinated",
163
+ "box_keys": "box",
164
+ "box_ref_image_keys": "@image_key",
165
+ "affine_lps_to_ras": true
166
+ },
167
+ {
168
+ "_target_": "ConvertBoxModed",
169
+ "box_keys": "box",
170
+ "src_mode": "xyzxyz",
171
+ "dst_mode": "cccwhd"
172
+ },
173
+ {
174
+ "_target_": "DeleteItemsd",
175
+ "keys": [
176
+ "@image_key"
177
+ ]
178
+ }
179
+ ]
180
+ },
181
+ "handlers": [
182
+ {
183
+ "_target_": "StatsHandler",
184
+ "iteration_log": false
185
+ },
186
+ {
187
+ "_target_": "scripts.detection_saver.DetectionSaver",
188
+ "output_dir": "@output_dir",
189
+ "filename": "@output_filename",
190
+ "batch_transform": "$lambda x: [xx['image'].meta for xx in x]",
191
+ "output_transform": "$lambda x: [@postprocessing({**xx['pred'],'image':xx['image']}) for xx in x]",
192
+ "pred_box_key": "box",
193
+ "pred_label_key": "label",
194
+ "pred_score_key": "label_scores"
195
+ }
196
+ ],
197
+ "evaluator": {
198
+ "_target_": "scripts.evaluator.DetectionEvaluator",
199
+ "_requires_": "@detector_ops",
200
+ "device": "@device",
201
+ "val_data_loader": "@dataloader",
202
+ "network": "@network",
203
+ "inferer": "@inferer",
204
+ "val_handlers": "@handlers",
205
+ "amp": "@amp"
206
+ },
207
+ "checkpointloader": {
208
+ "_target_": "CheckpointLoader",
209
+ "load_path": "$@bundle_root + '/models/model.pt'",
210
+ "load_dict": {
211
+ "model": "@network"
212
+ }
213
+ },
214
+ "initialize": [
215
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
216
+ "$@checkpointloader(@evaluator) if @load_pretrain else None"
217
+ ],
218
+ "run": [
219
+ "$@evaluator.run()"
220
+ ]
221
+ }
configs/inference_trt.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "imports": [
3
+ "$import glob",
4
+ "$import os",
5
+ "$import torch_tensorrt"
6
+ ],
7
+ "force_sliding_window": true,
8
+ "network_def": "$torch.jit.load(@bundle_root + '/models/model_trt.ts')",
9
+ "evaluator#amp": false,
10
+ "initialize": [
11
+ "$setattr(torch.backends.cudnn, 'benchmark', True)"
12
+ ]
13
+ }
configs/logging.conf ADDED
@@ -0,0 +1,21 @@
1
+ [loggers]
2
+ keys=root
3
+
4
+ [handlers]
5
+ keys=consoleHandler
6
+
7
+ [formatters]
8
+ keys=fullFormatter
9
+
10
+ [logger_root]
11
+ level=INFO
12
+ handlers=consoleHandler
13
+
14
+ [handler_consoleHandler]
15
+ class=StreamHandler
16
+ level=INFO
17
+ formatter=fullFormatter
18
+ args=(sys.stdout,)
19
+
20
+ [formatter_fullFormatter]
21
+ format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
configs/metadata.json ADDED
@@ -0,0 +1,107 @@
1
+ {
2
+ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
3
+ "version": "0.6.9",
4
+ "changelog": {
5
+ "0.6.9": "update to huggingface hosting and fix missing dependencies",
6
+ "0.6.8": "update issue for IgniteInfo",
7
+ "0.6.7": "use monai 1.4 and update large files",
8
+ "0.6.6": "update to use monai 1.3.1",
9
+ "0.6.5": "remove notes for trt_export in readme",
10
+ "0.6.4": "add notes for trt_export in readme",
11
+ "0.6.3": "add load_pretrain flag for infer",
12
+ "0.6.2": "add checkpoint loader for infer",
13
+ "0.6.1": "fix format error",
14
+ "0.6.0": "remove meta_dict usage",
15
+ "0.5.9": "use monai 1.2.0",
16
+ "0.5.8": "update TRT memory requirement in readme",
17
+ "0.5.7": "add dataset dir example",
18
+ "0.5.6": "add the ONNX-TensorRT way of model conversion",
19
+ "0.5.5": "update retrained validation results and training curve",
20
+ "0.5.4": "add non-deterministic note",
21
+ "0.5.3": "adapt to BundleWorkflow interface",
22
+ "0.5.2": "black autofix format and add name tag",
23
+ "0.5.1": "modify dataset key name",
24
+ "0.5.0": "use detection inferer",
25
+ "0.4.5": "fixed some small changes with formatting in readme",
26
+ "0.4.4": "add data resource to readme",
27
+ "0.4.3": "update val patch size to avoid warning in monai 1.0.1",
28
+ "0.4.2": "update to use monai 1.0.1",
29
+ "0.4.1": "fix license Copyright error",
30
+ "0.4.0": "add support for raw images",
31
+ "0.3.0": "update license files",
32
+ "0.2.0": "unify naming",
33
+ "0.1.1": "add reference for LIDC dataset",
34
+ "0.1.0": "complete the model package"
35
+ },
36
+ "monai_version": "1.4.0",
37
+ "pytorch_version": "2.4.0",
38
+ "numpy_version": "1.24.4",
39
+ "required_packages_version": {
40
+ "nibabel": "5.2.1",
41
+ "pytorch-ignite": "0.4.11",
42
+ "torchvision": "0.19.0",
43
+ "tensorboard": "2.17.0"
44
+ },
45
+ "supported_apps": {},
46
+ "name": "Lung nodule CT detection",
47
+ "task": "CT lung nodule detection",
48
+ "description": "A pre-trained model for volumetric (3D) detection of the lung lesion from CT image on LUNA16 dataset",
49
+ "authors": "MONAI team",
50
+ "copyright": "Copyright (c) MONAI Consortium",
51
+ "data_source": "https://luna16.grand-challenge.org/Home/",
52
+ "data_type": "nibabel",
53
+ "image_classes": "1 channel data, CT at 0.703125 x 0.703125 x 1.25 mm",
54
+ "label_classes": "dict data, containing Nx6 box and Nx1 classification labels.",
55
+ "pred_classes": "dict data, containing Nx6 box, Nx1 classification labels, Nx1 classification scores.",
56
+ "eval_metrics": {
57
+ "mAP_IoU_0.10_0.50_0.05_MaxDet_100": 0.852,
58
+ "AP_IoU_0.10_MaxDet_100": 0.858,
59
+ "mAR_IoU_0.10_0.50_0.05_MaxDet_100": 0.998,
60
+ "AR_IoU_0.10_MaxDet_100": 1.0
61
+ },
62
+ "intended_use": "This is an example, not to be used for diagnostic purposes",
63
+ "references": [
64
+ "Lin, Tsung-Yi, et al. 'Focal loss for dense object detection.' ICCV 2017"
65
+ ],
66
+ "network_data_format": {
67
+ "inputs": {
68
+ "image": {
69
+ "type": "image",
70
+ "format": "magnitude",
71
+ "modality": "CT",
72
+ "num_channels": 1,
73
+ "spatial_shape": [
74
+ "16*n",
75
+ "16*n",
76
+ "8*n"
77
+ ],
78
+ "dtype": "float16",
79
+ "value_range": [
80
+ 0,
81
+ 1
82
+ ],
83
+ "is_patch_data": true,
84
+ "channel_def": {
85
+ "0": "image"
86
+ }
87
+ }
88
+ },
89
+ "outputs": {
90
+ "pred": {
91
+ "type": "object",
92
+ "format": "dict",
93
+ "dtype": "float16",
94
+ "num_channels": 1,
95
+ "spatial_shape": [
96
+ "n",
97
+ "n",
98
+ "n"
99
+ ],
100
+ "value_range": [
101
+ -10000,
102
+ 10000
103
+ ]
104
+ }
105
+ }
106
+ }
107
+ }
configs/train.json ADDED
@@ -0,0 +1,452 @@
1
+ {
2
+ "imports": [
3
+ "$import glob",
4
+ "$import os"
5
+ ],
6
+ "bundle_root": ".",
7
+ "ckpt_dir": "$@bundle_root + '/models'",
8
+ "output_dir": "$@bundle_root + '/eval'",
9
+ "data_list_file_path": "$@bundle_root + '/LUNA16_datasplit/dataset_fold0.json'",
10
+ "dataset_dir": "/datasets/LUNA16_Images_resample",
11
+ "train_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, is_segmentation=True, data_list_key='training', base_dir=@dataset_dir)",
12
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
13
+ "epochs": 300,
14
+ "val_interval": 5,
15
+ "learning_rate": 0.01,
16
+ "amp": true,
17
+ "batch_size": 4,
18
+ "patch_size": [
19
+ 192,
20
+ 192,
21
+ 80
22
+ ],
23
+ "val_patch_size": [
24
+ 512,
25
+ 512,
26
+ 192
27
+ ],
28
+ "anchor_generator": {
29
+ "_target_": "monai.apps.detection.utils.anchor_utils.AnchorGeneratorWithAnchorShape",
30
+ "feature_map_scales": [
31
+ 1,
32
+ 2,
33
+ 4
34
+ ],
35
+ "base_anchor_shapes": [
36
+ [
37
+ 6,
38
+ 8,
39
+ 4
40
+ ],
41
+ [
42
+ 8,
43
+ 6,
44
+ 5
45
+ ],
46
+ [
47
+ 10,
48
+ 10,
49
+ 6
50
+ ]
51
+ ]
52
+ },
53
+ "backbone": "$monai.networks.nets.resnet.resnet50(spatial_dims=3,n_input_channels=1,conv1_t_stride=[2,2,1],conv1_t_size=[7,7,7])",
54
+ "feature_extractor": "$monai.apps.detection.networks.retinanet_network.resnet_fpn_feature_extractor(@backbone,3,False,[1,2],None)",
55
+ "network_def": {
56
+ "_target_": "RetinaNet",
57
+ "spatial_dims": 3,
58
+ "num_classes": 1,
59
+ "num_anchors": 3,
60
+ "feature_extractor": "@feature_extractor",
61
+ "size_divisible": [
62
+ 16,
63
+ 16,
64
+ 8
65
+ ]
66
+ },
67
+ "network": "$@network_def.to(@device)",
68
+ "detector": {
69
+ "_target_": "RetinaNetDetector",
70
+ "network": "@network",
71
+ "anchor_generator": "@anchor_generator",
72
+ "debug": false
73
+ },
74
+ "detector_ops": [
75
+ "[email protected]_atss_matcher(num_candidates=4, center_in_gt=False)",
76
+ "[email protected]_hard_negative_sampler(batch_size_per_image=64,positive_fraction=0.3,pool_size=20,min_neg=16)",
77
+ "[email protected]_target_keys(box_key='box', label_key='label')",
78
+ "[email protected]_box_selector_parameters(score_thresh=0.02,topk_candidates_per_level=1000,nms_thresh=0.22,detections_per_img=300)",
79
+ "[email protected]_sliding_window_inferer(roi_size=@val_patch_size,overlap=0.25,sw_batch_size=1,mode='constant',device='cpu')"
80
+ ],
81
+ "optimizer": {
82
+ "_target_": "torch.optim.SGD",
83
+ "params": "[email protected]()",
84
+ "lr": "@learning_rate",
85
+ "momentum": 0.9,
86
+ "weight_decay": 3e-05,
87
+ "nesterov": true
88
+ },
89
+ "after_scheduler": {
90
+ "_target_": "torch.optim.lr_scheduler.StepLR",
91
+ "optimizer": "@optimizer",
92
+ "step_size": 160,
93
+ "gamma": 0.1
94
+ },
95
+ "lr_scheduler": {
96
+ "_target_": "scripts.warmup_scheduler.GradualWarmupScheduler",
97
+ "optimizer": "@optimizer",
98
+ "multiplier": 1,
99
+ "total_epoch": 10,
100
+ "after_scheduler": "@after_scheduler"
101
+ },
102
+ "train": {
103
+ "preprocessing_transforms": [
104
+ {
105
+ "_target_": "LoadImaged",
106
+ "keys": "image"
107
+ },
108
+ {
109
+ "_target_": "EnsureChannelFirstd",
110
+ "keys": "image"
111
+ },
112
+ {
113
+ "_target_": "EnsureTyped",
114
+ "keys": [
115
+ "image",
116
+ "box"
117
+ ]
118
+ },
119
+ {
120
+ "_target_": "EnsureTyped",
121
+ "keys": "label",
122
+ "dtype": "$torch.long"
123
+ },
124
+ {
125
+ "_target_": "Orientationd",
126
+ "keys": "image",
127
+ "axcodes": "RAS"
128
+ },
129
+ {
130
+ "_target_": "ScaleIntensityRanged",
131
+ "keys": "image",
132
+ "a_min": -1024.0,
133
+ "a_max": 300.0,
134
+ "b_min": 0.0,
135
+ "b_max": 1.0,
136
+ "clip": true
137
+ },
138
+ {
139
+ "_target_": "ConvertBoxToStandardModed",
140
+ "box_keys": "box",
141
+ "mode": "cccwhd"
142
+ },
143
+ {
144
+ "_target_": "AffineBoxToImageCoordinated",
145
+ "box_keys": "box",
146
+ "box_ref_image_keys": "image",
147
+ "affine_lps_to_ras": true
148
+ }
149
+ ],
150
+ "random_transforms": [
151
+ {
152
+ "_target_": "RandCropBoxByPosNegLabeld",
153
+ "image_keys": "image",
154
+ "box_keys": "box",
155
+ "label_keys": "label",
156
+ "spatial_size": "@patch_size",
157
+ "whole_box": true,
158
+ "num_samples": "@batch_size",
159
+ "pos": 1,
160
+ "neg": 1
161
+ },
162
+ {
163
+ "_target_": "RandZoomBoxd",
164
+ "image_keys": "image",
165
+ "box_keys": "box",
166
+ "label_keys": "label",
167
+ "box_ref_image_keys": "image",
168
+ "prob": 0.2,
169
+ "min_zoom": 0.7,
170
+ "max_zoom": 1.4,
171
+ "padding_mode": "constant",
172
+ "keep_size": true
173
+ },
174
+ {
175
+ "_target_": "ClipBoxToImaged",
176
+ "box_keys": "box",
177
+ "label_keys": "label",
178
+ "box_ref_image_keys": "image",
179
+ "remove_empty": true
180
+ },
181
+ {
182
+ "_target_": "RandFlipBoxd",
183
+ "image_keys": "image",
184
+ "box_keys": "box",
185
+ "box_ref_image_keys": "image",
186
+ "prob": 0.5,
187
+ "spatial_axis": 0
188
+ },
189
+ {
190
+ "_target_": "RandFlipBoxd",
191
+ "image_keys": "image",
192
+ "box_keys": "box",
193
+ "box_ref_image_keys": "image",
194
+ "prob": 0.5,
195
+ "spatial_axis": 1
196
+ },
197
+ {
198
+ "_target_": "RandFlipBoxd",
199
+ "image_keys": "image",
200
+ "box_keys": "box",
201
+ "box_ref_image_keys": "image",
202
+ "prob": 0.5,
203
+ "spatial_axis": 2
204
+ },
205
+ {
206
+ "_target_": "RandRotateBox90d",
207
+ "image_keys": "image",
208
+ "box_keys": "box",
209
+ "box_ref_image_keys": "image",
210
+ "prob": 0.75,
211
+ "max_k": 3,
212
+ "spatial_axes": [
213
+ 0,
214
+ 1
215
+ ]
216
+ },
217
+ {
218
+ "_target_": "BoxToMaskd",
219
+ "box_keys": "box",
220
+ "label_keys": "label",
221
+ "box_mask_keys": "box_mask",
222
+ "box_ref_image_keys": "image",
223
+ "min_fg_label": 0,
224
+ "ellipse_mask": true
225
+ },
226
+ {
227
+ "_target_": "RandRotated",
228
+ "keys": [
229
+ "image",
230
+ "box_mask"
231
+ ],
232
+ "mode": [
233
+ "nearest",
234
+ "nearest"
235
+ ],
236
+ "prob": 0.2,
237
+ "range_x": 0.5236,
238
+ "range_y": 0.5236,
239
+ "range_z": 0.5236,
240
+ "keep_size": true,
241
+ "padding_mode": "zeros"
242
+ },
243
+ {
244
+ "_target_": "MaskToBoxd",
245
+ "box_keys": [
246
+ "box"
247
+ ],
248
+ "label_keys": [
249
+ "label"
250
+ ],
251
+ "box_mask_keys": [
252
+ "box_mask"
253
+ ],
254
+ "min_fg_label": 0
255
+ },
256
+ {
257
+ "_target_": "DeleteItemsd",
258
+ "keys": "box_mask"
259
+ },
260
+ {
261
+ "_target_": "RandGaussianNoised",
262
+ "keys": "image",
263
+ "prob": 0.1,
264
+ "mean": 0.0,
265
+ "std": 0.1
266
+ },
267
+ {
268
+ "_target_": "RandGaussianSmoothd",
269
+ "keys": "image",
270
+ "prob": 0.1,
271
+ "sigma_x": [
272
+ 0.5,
273
+ 1.0
274
+ ],
275
+ "sigma_y": [
276
+ 0.5,
277
+ 1.0
278
+ ],
279
+ "sigma_z": [
280
+ 0.5,
281
+ 1.0
282
+ ]
283
+ },
284
+ {
285
+ "_target_": "RandScaleIntensityd",
286
+ "keys": "image",
287
+ "factors": 0.25,
288
+ "prob": 0.15
289
+ },
290
+ {
291
+ "_target_": "RandShiftIntensityd",
292
+ "keys": "image",
293
+ "offsets": 0.1,
294
+ "prob": 0.15
295
+ },
296
+ {
297
+ "_target_": "RandAdjustContrastd",
298
+ "keys": "image",
299
+ "prob": 0.3,
300
+ "gamma": [
301
+ 0.7,
302
+ 1.5
303
+ ]
304
+ }
305
+ ],
306
+ "final_transforms": [
307
+ {
308
+ "_target_": "EnsureTyped",
309
+ "keys": [
310
+ "image",
311
+ "box"
312
+ ]
313
+ },
314
+ {
315
+ "_target_": "EnsureTyped",
316
+ "keys": "label",
317
+ "dtype": "$torch.long"
318
+ },
319
+ {
320
+ "_target_": "ToTensord",
321
+ "keys": [
322
+ "image",
323
+ "box",
324
+ "label"
325
+ ]
326
+ }
327
+ ],
328
+ "preprocessing": {
329
+ "_target_": "Compose",
330
+ "transforms": "$@train#preprocessing_transforms + @train#random_transforms + @train#final_transforms"
331
+ },
332
+ "dataset": {
333
+ "_target_": "Dataset",
334
+ "data": "$@train_datalist[: int(0.95 * len(@train_datalist))]",
335
+ "transform": "@train#preprocessing"
336
+ },
337
+ "dataloader": {
338
+ "_target_": "DataLoader",
339
+ "dataset": "@train#dataset",
340
+ "batch_size": 1,
341
+ "shuffle": true,
342
+ "num_workers": 4,
343
+ "collate_fn": "$monai.data.utils.no_collation"
344
+ },
345
+ "handlers": [
346
+ {
347
+ "_target_": "LrScheduleHandler",
348
+ "lr_scheduler": "@lr_scheduler",
349
+ "print_lr": true
350
+ },
351
+ {
352
+ "_target_": "ValidationHandler",
353
+ "validator": "@validate#evaluator",
354
+ "epoch_level": true,
355
+ "interval": "@val_interval"
356
+ },
357
+ {
358
+ "_target_": "StatsHandler",
359
+ "tag_name": "train_loss",
360
+ "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)[0]"
361
+ },
362
+ {
363
+ "_target_": "TensorBoardStatsHandler",
364
+ "log_dir": "@output_dir",
365
+ "tag_name": "train_loss",
366
+ "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)[0]"
367
+ }
368
+ ],
369
+ "trainer": {
370
+ "_target_": "scripts.trainer.DetectionTrainer",
371
+ "_requires_": "@detector_ops",
372
+ "max_epochs": "@epochs",
373
+ "device": "@device",
374
+ "train_data_loader": "@train#dataloader",
375
+ "detector": "@detector",
376
+ "optimizer": "@optimizer",
377
+ "train_handlers": "@train#handlers",
378
+ "amp": "@amp"
379
+ }
380
+ },
381
+ "validate": {
382
+ "preprocessing": {
383
+ "_target_": "Compose",
384
+ "transforms": "$@train#preprocessing_transforms + @train#final_transforms"
385
+ },
386
+ "dataset": {
387
+ "_target_": "Dataset",
388
+ "data": "$@train_datalist[int(0.95 * len(@train_datalist)): ]",
389
+ "transform": "@validate#preprocessing"
390
+ },
391
+ "dataloader": {
392
+ "_target_": "DataLoader",
393
+ "dataset": "@validate#dataset",
394
+ "batch_size": 1,
395
+ "shuffle": false,
396
+ "num_workers": 2,
397
+ "collate_fn": "$monai.data.utils.no_collation"
398
+ },
399
+ "inferer": {
400
+ "_target_": "scripts.detection_inferer.RetinaNetInferer",
401
+ "detector": "@detector"
402
+ },
403
+ "handlers": [
404
+ {
405
+ "_target_": "StatsHandler",
406
+ "iteration_log": false
407
+ },
408
+ {
409
+ "_target_": "TensorBoardStatsHandler",
410
+ "log_dir": "@output_dir",
411
+ "iteration_log": false
412
+ },
413
+ {
414
+ "_target_": "CheckpointSaver",
415
+ "save_dir": "@ckpt_dir",
416
+ "save_dict": {
417
+ "model": "@network"
418
+ },
419
+ "save_key_metric": true,
420
+ "key_metric_filename": "model.pt"
421
+ }
422
+ ],
423
+ "key_metric": {
424
+ "val_coco": {
425
+ "_target_": "scripts.cocometric_ignite.IgniteCocoMetric",
426
+ "coco_metric_monai": "$monai.apps.detection.metrics.coco.COCOMetric(classes=['nodule'], iou_list=[0.1], max_detection=[100])",
427
+ "output_transform": "$monai.handlers.from_engine(['pred', 'label'])",
428
+ "box_key": "box",
429
+ "label_key": "label",
430
+ "pred_score_key": "label_scores",
431
+ "reduce_scalar": true
432
+ }
433
+ },
434
+ "evaluator": {
435
+ "_target_": "scripts.evaluator.DetectionEvaluator",
436
+ "_requires_": "@detector_ops",
437
+ "device": "@device",
438
+ "val_data_loader": "@validate#dataloader",
439
+ "network": "@network",
440
+ "inferer": "@validate#inferer",
441
+ "key_val_metric": "@validate#key_metric",
442
+ "val_handlers": "@validate#handlers",
443
+ "amp": "@amp"
444
+ }
445
+ },
446
+ "initialize": [
447
+ "$monai.utils.set_determinism(seed=0)"
448
+ ],
449
+ "run": [
450
+ "$@train#trainer.run()"
451
+ ]
452
+ }
docs/README.md ADDED
@@ -0,0 +1,163 @@
1
+ # Model Overview
2
+ A pre-trained model for volumetric (3D) detection of lung nodules from CT images.
3
+
4
+ This model is trained on the LUNA16 dataset (https://luna16.grand-challenge.org/Home/) using RetinaNet (Lin, Tsung-Yi, et al. "Focal loss for dense object detection." ICCV 2017. https://arxiv.org/abs/1708.02002).
5
+
6
+ ![model workflow](https://developer.download.nvidia.com/assets/Clara/Images/monai_retinanet_detection_workflow.png)
7
+
8
+ ## Data
9
+ The dataset used in this example is LUNA16 (https://luna16.grand-challenge.org/Home/), which is based on the [LIDC-IDRI database](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI) [3,4,5].
10
+
11
+ LUNA16 is a public dataset for CT lung nodule detection. Given raw CT scans, the goal is to identify the locations of possible nodules and to assign each location a probability of being a nodule.
12
+
13
+ Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the dataset! We acknowledge the National Cancer Institute and the Foundation for the National Institutes of Health, and their critical role in the creation of the free publicly available LIDC/IDRI Database used in this study.
14
+
15
+ ### 10-fold data splitting
16
+ We follow the official 10-fold data splitting from LUNA16 challenge and generate data split json files using the script from [nnDetection](https://github.com/MIC-DKFZ/nnDetection/blob/main/projects/Task016_Luna/scripts/prepare.py).
17
+
18
+ Please download the resulting json files from https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/LUNA16_datasplit-20220615T233840Z-001.zip.
19
+
20
+ In these files, the values of "box" are the ground truth boxes in world coordinates.
21
+
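For illustration, the split files can be read with the same MONAI call used in `configs/train.json`; the sketch below assumes the json files sit under `./LUNA16_datasplit` and the resampled images under `/datasets/LUNA16_Images_resample` (both placeholders):

```python
# Sketch: inspect one entry of the LUNA16 fold-0 data split (paths are placeholders).
from monai.data import load_decathlon_datalist

datalist = load_decathlon_datalist(
    "./LUNA16_datasplit/dataset_fold0.json",
    is_segmentation=True,
    data_list_key="training",
    base_dir="/datasets/LUNA16_Images_resample",
)
# Each entry carries an image path plus world-coordinate "box" and "label" arrays.
print(datalist[0]["image"], datalist[0]["box"], datalist[0]["label"])
```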
22
+ ### Data resampling
23
+ The raw CT images in LUNA16 have a variety of voxel sizes. The first step is to resample them to a common voxel size.
24
+ In this model, we resampled them to 0.703125 x 0.703125 x 1.25 mm.
25
+
26
+ Please follow the instructions in Section 3.1 of https://github.com/Project-MONAI/tutorials/tree/main/detection to do the resampling; a rough sketch is shown below.
27
+
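A minimal sketch of that resampling with MONAI transforms, assuming the target spacing above; the tutorial's official script remains the reference, and the file paths here are placeholders:

```python
# Sketch: resample one LUNA16 CT volume to the spacing used by this bundle.
from monai import transforms as mt

resample = mt.Compose(
    [
        mt.LoadImaged(keys="image"),
        mt.EnsureChannelFirstd(keys="image"),
        mt.Orientationd(keys="image", axcodes="RAS"),
        mt.Spacingd(keys="image", pixdim=(0.703125, 0.703125, 1.25), mode="bilinear"),
        mt.SaveImaged(keys="image", output_dir="./LUNA16_Images_resample", output_postfix="", resample=False),
    ]
)
resample({"image": "path/to/original_scan.mhd"})  # placeholder input path
```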
28
+ ### Data download
29
+ The original mhd/raw data can be downloaded from [LUNA16](https://luna16.grand-challenge.org/Home/). The original DICOM data can be downloaded from the [LIDC-IDRI database](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI) [3,4,5]. You will need to resample the original data before training.
30
+
31
+ Alternatively, we provide [resampled nifti images](https://drive.google.com/drive/folders/1JozrufA1VIZWJIc5A1EMV3J4CNCYovKK?usp=share_link) and a copy of [original mhd/raw images](https://drive.google.com/drive/folders/1-enN4eNEnKmjltevKg3W2V-Aj0nriQWE?usp=share_link) from [LUNA16](https://luna16.grand-challenge.org/Home/) for users to download.
32
+
33
+ ## Training configuration
34
+ The training was performed with the following:
35
+
36
+ - GPU: at least 16 GB of GPU memory; 32 GB is required when exporting the TRT model
37
+ - Actual Model Input: 192 x 192 x 80
38
+ - AMP: True
39
+ - Optimizer: SGD (as configured in `configs/train.json`)
40
+ - Learning Rate: 1e-2
41
+ - Loss: BCE loss and L1 loss
42
+
43
+ ### Input
44
+ 1 channel
45
+ - List of 3D CT patches
46
+
47
+ ### Output
48
+ In Training Mode: A dictionary of classification and box regression loss.
49
+
50
+ In Evaluation Mode: A list of dictionaries of predicted box, classification label, and classification score.
51
+
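Roughly, the two forms look like this (hypothetical values; the prediction keys follow `pred_box_key`, `pred_label_key` and `pred_score_key` used in the configs, and the loss keys assume MONAI's `RetinaNetDetector` defaults):

```python
# Training mode: a dictionary of losses (key names assumed from RetinaNetDetector defaults).
train_output = {"classification": 0.37, "box_regression": 0.12}

# Evaluation mode: one dictionary per input image, with Nx6 boxes and Nx1 labels/scores.
eval_output = [
    {
        "box": [[102.1, 88.4, 60.0, 110.5, 95.2, 66.3]],
        "label": [0],            # 0 = nodule
        "label_scores": [0.91],
    }
]
```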
52
+ ## Performance
53
+ The COCO metric is used to evaluate the performance of the model. The pre-trained model was trained and validated on data fold 0. It achieves mAP=0.852, mAR=0.998, AP(IoU=0.1)=0.858, and AR(IoU=0.1)=1.0.
54
+
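For reference, a sketch of how such COCO metrics are computed with MONAI, mirroring `scripts/cocometric_ignite.py`; the boxes, labels and scores below are made-up values:

```python
# Sketch: compute the COCO metric for one image with made-up boxes.
import numpy as np
from monai.apps.detection.metrics.coco import COCOMetric
from monai.apps.detection.metrics.matching import matching_batch
from monai.data import box_utils

coco_metric = COCOMetric(classes=["nodule"], iou_list=[0.1], max_detection=[100])
results = matching_batch(
    iou_fn=box_utils.box_iou,
    iou_thresholds=coco_metric.iou_thresholds,
    pred_boxes=[np.array([[10.0, 10.0, 10.0, 20.0, 20.0, 20.0]])],
    pred_classes=[np.array([0])],
    pred_scores=[np.array([0.9])],
    gt_boxes=[np.array([[11.0, 11.0, 11.0, 21.0, 21.0, 21.0]])],
    gt_classes=[np.array([0])],
)
metric_dict = coco_metric(results)[0]  # e.g. {"mAP_IoU_0.10_0.50_0.05_MaxDet_100": ..., ...}
print(metric_dict)
```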
55
+ Please note that this bundle is non-deterministic because of the max pooling layer used in the network. Therefore, reproducing the training process may not yield exactly the same performance.
56
+ Please refer to https://pytorch.org/docs/stable/notes/randomness.html#reproducibility for more details about reproducibility.
57
+
58
+ #### Training Loss
59
+ ![A graph showing the detection train loss](https://developer.download.nvidia.com/assets/Clara/Images/monai_retinanet_detection_train_loss_v2.png)
60
+
61
+ #### Validation Accuracy
62
+ The validation accuracy in this curve is the mean of mAP, mAR, AP(IoU=0.1), and AR(IoU=0.1) of the COCO metric.
63
+
64
+ ![A graph showing the detection val accuracy](https://developer.download.nvidia.com/assets/Clara/Images/monai_retinanet_detection_val_acc_v2.png)
65
+
66
+ #### TensorRT speedup
67
+ The `lung_nodule_ct_detection` bundle supports acceleration with TensorRT through the ONNX-TensorRT method. The table below displays the speedup ratios observed on an A100 80G GPU. Please note that when using the TensorRT model for inference, the `force_sliding_window` parameter in the `inference.json` file must be set to `true` (the override in `configs/inference_trt.json` already does this). This ensures that the bundle uses the `SlidingWindowInferer` during inference and keeps the network's input spatial size fixed. Otherwise, an input whose spatial size is smaller than `infer_patch_size` would change the network's input spatial size.
68
+
69
+ | method | torch_fp32(ms) | torch_amp(ms) | trt_fp32(ms) | trt_fp16(ms) | speedup amp | speedup fp32 | speedup fp16 | amp vs fp16|
70
+ | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
71
+ | model computation | 7449.84 | 996.08 | 976.67 | 626.90 | 7.63 | 7.63 | 11.88 | 1.56 |
72
+ | end2end | 36458.26 | 7259.35 | 6420.60 | 4698.34 | 5.02 | 5.68 | 7.76 | 1.55 |
73
+
74
+ Where:
75
+ - `model computation` refers to the model's inference on a random input, excluding preprocessing and postprocessing
76
+ - `end2end` means running the bundle end-to-end with the TensorRT-based model.
77
+ - `torch_fp32` and `torch_amp` are for the PyTorch models with or without `amp` mode.
78
+ - `trt_fp32` and `trt_fp16` are for the TensorRT based models converted in corresponding precision.
79
+ - `speedup amp`, `speedup fp32` and `speedup fp16` are the speedup ratios of corresponding models versus the PyTorch float32 model
80
+ - `amp vs fp16` is the speedup ratio between the PyTorch amp model and the TensorRT float16 based model.
81
+
82
+ Currently, the only available method to accelerate this model is through ONNX-TensorRT. However, the Torch-TensorRT method is under development and will be available in the near future.
83
+
84
+ This result is benchmarked under:
85
+ - TensorRT: 8.5.3+cuda11.8
86
+ - Torch-TensorRT Version: 1.4.0
87
+ - CPU Architecture: x86-64
88
+ - OS: ubuntu 20.04
89
+ - Python version: 3.8.10
90
+ - CUDA version: 12.0
91
+ - GPU models and configuration: A100 80G
92
+
93
+ ## MONAI Bundle Commands
94
+ In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle. The CLI supports flexible use cases, such as overriding configs at runtime and predefining arguments in a file.
95
+
96
+ For more detailed usage instructions, visit the [MONAI Bundle Configuration Page](https://docs.monai.io/en/latest/config_syntax.html).
97
+
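For the Pythonic route, a training run can be sketched with `monai.bundle.ConfigWorkflow` under the default bundle layout (an assumption here; the CLI commands below remain the reference):

```python
# Sketch: rough Python equivalent of `python -m monai.bundle run --config_file configs/train.json`.
from monai.bundle import ConfigWorkflow

workflow = ConfigWorkflow(
    workflow_type="train",
    config_file="configs/train.json",
    logging_file="configs/logging.conf",
    meta_file="configs/metadata.json",
)
workflow.initialize()
workflow.run()
workflow.finalize()
```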
98
+ #### Execute training:
99
+
100
+ ```
101
+ python -m monai.bundle run --config_file configs/train.json
102
+ ```
103
+
104
+ Please note that if the default dataset path in the bundle config files has not been modified to the actual path, you can override it at runtime using `--dataset_dir`:
105
+
106
+ ```
107
+ python -m monai.bundle run --config_file configs/train.json --dataset_dir <actual dataset path>
108
+ ```
109
+
110
+ #### Override the `train` config to execute evaluation with the trained model:
111
+
112
+ ```
113
+ python -m monai.bundle run --config_file "['configs/train.json','configs/evaluate.json']"
114
+ ```
115
+
116
+ #### Execute inference on resampled LUNA16 images by setting `"whether_raw_luna16": false` in `inference.json`:
117
+
118
+ ```
119
+ python -m monai.bundle run --config_file configs/inference.json
120
+ ```
121
+ With the same command, we can execute inference on original LUNA16 images by setting `"whether_raw_luna16": true` in `inference.json`. Remember to also set `"data_list_file_path": "$@bundle_root + '/LUNA16_datasplit/mhd_original/dataset_fold0.json'"` and change `"dataset_dir"`.
122
+
123
+ Note that in inference.json, the "LoadImaged" transform in "preprocessing" and the "AffineBoxToWorldCoordinated" transform in "postprocessing" both have `"affine_lps_to_ras": true`.
124
+ This setting depends on the input images; LUNA16 needs `"affine_lps_to_ras": true`.
125
+ Your own inference dataset may instead require `"affine_lps_to_ras": false`.
126
+
127
+ #### Export checkpoint to TensorRT based models with fp32 or fp16 precision
128
+
129
+ ```bash
130
+ python -m monai.bundle trt_export --net_id network_def --filepath models/model_trt.ts --ckpt_file models/model.pt --meta_file configs/metadata.json --config_file configs/inference.json --precision <fp32/fp16> --input_shape "[1, 1, 512, 512, 192]" --use_onnx "True" --use_trace "True" --onnx_output_names "['output_0', 'output_1', 'output_2', 'output_3', 'output_4', 'output_5']" --network_def#use_list_output "True"
131
+ ```
132
+
133
+ #### Execute inference with the TensorRT model
134
+
135
+ ```
136
+ python -m monai.bundle run --config_file "['configs/inference.json', 'configs/inference_trt.json']"
137
+ ```
138
+
139
+ # References
140
+ [1] Lin, Tsung-Yi, et al. "Focal loss for dense object detection." ICCV 2017. https://arxiv.org/abs/1708.02002
141
+
142
+ [2] Baumgartner and Jaeger et al. "nnDetection: A self-configuring method for medical object detection." MICCAI 2021. https://arxiv.org/pdf/2106.00817.pdf
143
+
144
+ [3] Armato III, S. G., McLennan, G., Bidaut, L., McNitt-Gray, M. F., Meyer, C. R., Reeves, A. P., Zhao, B., Aberle, D. R., Henschke, C. I., Hoffman, E. A., Kazerooni, E. A., MacMahon, H., Van Beek, E. J. R., Yankelevitz, D., Biancardi, A. M., Bland, P. H., Brown, M. S., Engelmann, R. M., Laderach, G. E., Max, D., Pais, R. C. , Qing, D. P. Y. , Roberts, R. Y., Smith, A. R., Starkey, A., Batra, P., Caligiuri, P., Farooqi, A., Gladish, G. W., Jude, C. M., Munden, R. F., Petkovska, I., Quint, L. E., Schwartz, L. H., Sundaram, B., Dodd, L. E., Fenimore, C., Gur, D., Petrick, N., Freymann, J., Kirby, J., Hughes, B., Casteele, A. V., Gupte, S., Sallam, M., Heath, M. D., Kuhn, M. H., Dharaiya, E., Burns, R., Fryd, D. S., Salganicoff, M., Anand, V., Shreter, U., Vastagh, S., Croft, B. Y., Clarke, L. P. (2015). Data From LIDC-IDRI [Data set]. The Cancer Imaging Archive. https://doi.org/10.7937/K9/TCIA.2015.LO9QL9SX
145
+
146
+ [4] Armato SG 3rd, McLennan G, Bidaut L, McNitt-Gray MF, Meyer CR, Reeves AP, Zhao B, Aberle DR, Henschke CI, Hoffman EA, Kazerooni EA, MacMahon H, Van Beeke EJ, Yankelevitz D, Biancardi AM, Bland PH, Brown MS, Engelmann RM, Laderach GE, Max D, Pais RC, Qing DP, Roberts RY, Smith AR, Starkey A, Batrah P, Caligiuri P, Farooqi A, Gladish GW, Jude CM, Munden RF, Petkovska I, Quint LE, Schwartz LH, Sundaram B, Dodd LE, Fenimore C, Gur D, Petrick N, Freymann J, Kirby J, Hughes B, Casteele AV, Gupte S, Sallamm M, Heath MD, Kuhn MH, Dharaiya E, Burns R, Fryd DS, Salganicoff M, Anand V, Shreter U, Vastagh S, Croft BY. The Lung Image Database Consortium (LIDC) and Image Database Resource Initiative (IDRI): A completed reference database of lung nodules on CT scans. Medical Physics, 38: 915--931, 2011. DOI: https://doi.org/10.1118/1.3528204
147
+
148
+ [5] Clark, K., Vendt, B., Smith, K., Freymann, J., Kirby, J., Koppel, P., Moore, S., Phillips, S., Maffitt, D., Pringle, M., Tarbox, L., & Prior, F. (2013). The Cancer Imaging Archive (TCIA): Maintaining and Operating a Public Information Repository. Journal of Digital Imaging, 26(6), 1045–1057. https://doi.org/10.1007/s10278-013-9622-7
149
+
150
+ # License
151
+ Copyright (c) MONAI Consortium
152
+
153
+ Licensed under the Apache License, Version 2.0 (the "License");
154
+ you may not use this file except in compliance with the License.
155
+ You may obtain a copy of the License at
156
+
157
+ http://www.apache.org/licenses/LICENSE-2.0
158
+
159
+ Unless required by applicable law or agreed to in writing, software
160
+ distributed under the License is distributed on an "AS IS" BASIS,
161
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
162
+ See the License for the specific language governing permissions and
163
+ limitations under the License.
docs/data_license.txt ADDED
@@ -0,0 +1,11 @@
1
+ Third Party Licenses
2
+ -----------------------------------------------------------------------
3
+
4
+ /*********************************************************************/
5
+ i. LUng Nodule Analysis 2016
6
+ https://luna16.grand-challenge.org/Home/
7
+ https://creativecommons.org/licenses/by/4.0/
8
+
9
+ ii. Lung Image Database Consortium image collection (LIDC-IDRI)
10
+ https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI
11
+ https://creativecommons.org/licenses/by/3.0/
models/model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5e79231466adae93a6fe8e8594029e9add142914e223b879aa0343bb2402d01
3
+ size 83709381
models/model.ts ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68afd1ed4be8d01196d575d13931dab24cc50d46a74528a47d54496ba29e2583
3
+ size 83784539
scripts/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ # from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator
13
+ # from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer
14
+ from .trainer import DetectionTrainer
scripts/cocometric_ignite.py ADDED
@@ -0,0 +1,111 @@
1
+ from typing import Callable, Dict, Sequence, Union
2
+
3
+ import torch
4
+ from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
5
+ from monai.apps.detection.metrics.coco import COCOMetric
6
+ from monai.apps.detection.metrics.matching import matching_batch
7
+ from monai.data import box_utils
8
+
9
+ from .utils import detach_to_numpy
10
+
11
+
12
+ class IgniteCocoMetric(Metric):
13
+ def __init__(
14
+ self,
15
+ coco_metric_monai: Union[None, COCOMetric] = None,
16
+ box_key="box",
17
+ label_key="label",
18
+ pred_score_key="label_scores",
19
+ output_transform: Callable = lambda x: x,
20
+ device: Union[str, torch.device, None] = None,
21
+ reduce_scalar: bool = True,
22
+ ):
23
+ r"""
24
+ Computes coco detection metric in Ignite.
25
+
26
+ Args:
27
+ coco_metric_monai: the coco metric in monai.
28
+ If not given, will assume COCOMetric(classes=[0], iou_list=[0.1], max_detection=[100])
29
+ box_key: box key in the ground truth target dict and prediction dict.
30
+ label_key: classification label key in the ground truth target dict and prediction dict.
31
+ pred_score_key: classification score key in the prediction dict.
32
+ output_transform: A callable that is used to transform the Engine’s
33
+ process_function’s output into the form expected by the metric.
34
+ device: specifies which device updates are accumulated on.
35
+ Setting the metric’s device to be the same as your update arguments ensures
36
+ the update method is non-blocking. By default, CPU.
37
+ reduce_scalar: if True, will return the average of the coco metric values;
38
+ if False, will return a dictionary of coco metric values.
39
+
40
+ Examples:
41
+ To use with ``Engine`` and ``process_function``,
42
+ simply attach the metric instance to the engine.
43
+ The output of the engine's ``process_function`` needs to be in format of
44
+ ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.
45
+ For more information on how metric works with :class:`~ignite.engine.engine.Engine`,
46
+ visit :ref:`attach-engine`.
47
+ .. include:: defaults.rst
48
+ :start-after: :orphan:
49
+ .. testcode::
50
+ coco = IgniteCocoMetric()
51
+ coco.attach(default_evaluator, 'coco')
52
+ preds = [
53
+ {
54
+ 'box': torch.Tensor([[1,1,1,2,2,2]]),
55
+ 'label':torch.Tensor([0]),
56
+ 'label_scores':torch.Tensor([0.8])
57
+ }
58
+ ]
59
+ target = [{'box': torch.Tensor([[1,1,1,2,2,2]]), 'label':torch.Tensor([0])}]
60
+ state = default_evaluator.run([[preds, target]])
61
+ print(state.metrics['coco'])
62
+ .. testoutput::
63
+ 1.0...
64
+ .. versionadded:: 0.4.3
65
+ """
66
+ self.box_key = box_key
67
+ self.label_key = label_key
68
+ self.pred_score_key = pred_score_key
69
+ if coco_metric_monai is None:
70
+ self.coco_metric = COCOMetric(classes=[0], iou_list=[0.1], max_detection=[100])
71
+ else:
72
+ self.coco_metric = coco_metric_monai
73
+ self.reduce_scalar = reduce_scalar
74
+
75
+ if device is None:
76
+ device = torch.device("cpu")
77
+ super(IgniteCocoMetric, self).__init__(output_transform=output_transform, device=device)
78
+
79
+ @reinit__is_reduced
80
+ def reset(self) -> None:
81
+ self.val_targets_all = []
82
+ self.val_outputs_all = []
83
+
84
+ @reinit__is_reduced
85
+ def update(self, output: Sequence[Dict]) -> None:
86
+ y_pred, y = output[0], output[1]
87
+ self.val_outputs_all += y_pred
88
+ self.val_targets_all += y
89
+
90
+ @sync_all_reduce("val_targets_all", "val_outputs_all")
91
+ def compute(self) -> float:
92
+ self.val_outputs_all = detach_to_numpy(self.val_outputs_all)
93
+ self.val_targets_all = detach_to_numpy(self.val_targets_all)
94
+
95
+ results_metric = matching_batch(
96
+ iou_fn=box_utils.box_iou,
97
+ iou_thresholds=self.coco_metric.iou_thresholds,
98
+ pred_boxes=[val_data_i[self.box_key] for val_data_i in self.val_outputs_all],
99
+ pred_classes=[val_data_i[self.label_key] for val_data_i in self.val_outputs_all],
100
+ pred_scores=[val_data_i[self.pred_score_key] for val_data_i in self.val_outputs_all],
101
+ gt_boxes=[val_data_i[self.box_key] for val_data_i in self.val_targets_all],
102
+ gt_classes=[val_data_i[self.label_key] for val_data_i in self.val_targets_all],
103
+ )
104
+ val_epoch_metric_dict = self.coco_metric(results_metric)[0]
105
+
106
+ if self.reduce_scalar:
107
+ val_epoch_metric = val_epoch_metric_dict.values()
108
+ val_epoch_metric = sum(val_epoch_metric) / len(val_epoch_metric)
109
+ return val_epoch_metric
110
+ else:
111
+ return val_epoch_metric_dict
scripts/detection_inferer.py ADDED
@@ -0,0 +1,66 @@
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from typing import Any, List, Union
13
+
14
+ import numpy as np
15
+ import torch
16
+ from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
17
+ from monai.inferers.inferer import Inferer
18
+ from torch import Tensor
19
+
20
+
21
+ class RetinaNetInferer(Inferer):
22
+ """
23
+ RetinaNet inferer that takes a RetinaNetDetector as input.
24
+
25
+ Args:
26
+ detector: the RetinaNetDetector that converts network output BxCxMxN or BxCxMxNxP
27
+ feature maps into boxes and classification scores.
28
+ force_sliding_window: whether to force using a SlidingWindowInferer to do the inference.
29
+ If False, will check the input spatial size to decide whether to simply
30
+ forward the network or to use SlidingWindowInferer.
31
+ If True, will force using SlidingWindowInferer to do the inference.
32
+ args: other optional args to be passed to detector.
33
+ kwargs: other optional keyword args to be passed to detector.
34
+ """
35
+
36
+ def __init__(self, detector: RetinaNetDetector, force_sliding_window: bool = False) -> None:
37
+ Inferer.__init__(self)
38
+ self.detector = detector
39
+ self.sliding_window_size = None
40
+ self.force_sliding_window = force_sliding_window
41
+ if self.detector.inferer is not None:
42
+ if hasattr(self.detector.inferer, "roi_size"):
43
+ self.sliding_window_size = np.prod(self.detector.inferer.roi_size)
44
+
45
+ def __call__(self, inputs: Union[List[Tensor], Tensor], network: torch.nn.Module, *args: Any, **kwargs: Any):
46
+ """Unified callable function API of Inferers.
47
+ Args:
48
+ inputs: model input data for inference.
49
+ network: target detection network to execute inference.
50
+ supports callable that fullfilles requirements of network in
51
+ monai.apps.detection.networks.retinanet_detector.RetinaNetDetector``
52
+ args: optional args to be passed to ``network``.
53
+ kwargs: optional keyword args to be passed to ``network``.
54
+ """
55
+ self.detector.network = network
56
+ self.detector.training = self.detector.network.training
57
+
58
+ # if image smaller than sliding window roi size, no need to use sliding window inferer
59
+ # use sliding window inferer only when image is large
60
+ use_inferer = (
61
+ self.force_sliding_window
62
+ or self.sliding_window_size is not None
63
+ and not all([data_i[0, ...].numel() < self.sliding_window_size for data_i in inputs])
64
+ )
65
+
66
+ return self.detector(inputs, *args, use_inferer=use_inferer, **kwargs)
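
A note on the `use_inferer` expression above: because Python's `and` binds more tightly than `or`, it is equivalent to the explicitly parenthesized form below, i.e. sliding-window inference is used either when forced, or when at least one input image contains more voxels than the configured ROI:

    use_inferer = self.force_sliding_window or (
        self.sliding_window_size is not None
        and not all(data_i[0, ...].numel() < self.sliding_window_size for data_i in inputs)
    )
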
scripts/detection_saver.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import json
13
+ import os
14
+ import warnings
15
+ from typing import TYPE_CHECKING, Callable, Optional
16
+
17
+ from monai.handlers.classification_saver import ClassificationSaver
18
+ from monai.utils import IgniteInfo, evenly_divisible_all_gather, min_version, optional_import, string_list_all_gather
19
+
20
+ from .utils import detach_to_numpy
21
+
22
+ idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")
23
+ Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
24
+ if TYPE_CHECKING:
25
+ from ignite.engine import Engine
26
+ else:
27
+ Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
28
+
29
+
30
+ class DetectionSaver(ClassificationSaver):
31
+ """
32
+ Event handler triggered on completing every iteration to save the detection predictions as a json file.
33
+ If running in distributed data parallel, only saves the json file on the specified rank.
34
+
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ output_dir: str = "./",
40
+ filename: str = "predictions.json",
41
+ overwrite: bool = True,
42
+ batch_transform: Callable = lambda x: x,
43
+ output_transform: Callable = lambda x: x,
44
+ name: Optional[str] = None,
45
+ save_rank: int = 0,
46
+ pred_box_key: str = "box",
47
+ pred_label_key: str = "label",
48
+ pred_score_key: str = "label_scores",
49
+ ) -> None:
50
+ """
51
+ Args:
52
+ output_dir: if `saver=None`, output json file directory.
53
+ filename: if `saver=None`, name of the saved json file.
54
+ overwrite: if `saver=None`, whether to overwrite existing file content; if True,
55
+ will clear the file before saving. Otherwise, will append new content to the file.
56
+ batch_transform: a callable that is used to extract the `meta_data` dictionary of
57
+ the input images from `ignite.engine.state.batch`. the purpose is to get the input
58
+ filenames from the `meta_data` and store with classification results together.
59
+ `engine.state` and `batch_transform` inherit from the ignite concept:
60
+ https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
61
+ https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
62
+ output_transform: a callable that is used to extract the model prediction data from
63
+ `ignite.engine.state.output`. the first dimension of its output will be treated as
64
+ the batch dimension. each item in the batch will be saved individually.
65
+ `engine.state` and `output_transform` inherit from the ignite concept:
66
+ https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
67
+ https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
68
+ name: identifier of logging.logger to use, defaulting to `engine.logger`.
69
+ save_rank: only the handler on specified rank will save to json file in multi-gpus validation,
70
+ default to 0.
71
+ pred_box_key: box key in the prediction dict.
72
+ pred_label_key: classification label key in the prediction dict.
73
+ pred_score_key: classification score key in the prediction dict.
74
+
75
+ """
76
+ super().__init__(
77
+ output_dir=output_dir,
78
+ filename=filename,
79
+ overwrite=overwrite,
80
+ batch_transform=batch_transform,
81
+ output_transform=output_transform,
82
+ name=name,
83
+ save_rank=save_rank,
84
+ saver=None,
85
+ )
86
+ self.pred_box_key = pred_box_key
87
+ self.pred_label_key = pred_label_key
88
+ self.pred_score_key = pred_score_key
89
+
90
+ def _finalize(self, _engine: Engine) -> None:
91
+ """
92
+ All-gather detection results from all ranks and save them to a json file.
93
+
94
+ Args:
95
+ _engine: Ignite Engine, unused argument.
96
+ """
97
+ ws = idist.get_world_size()
98
+ if self.save_rank >= ws:
99
+ raise ValueError("target save rank is greater than the distributed group size.")
100
+
101
+ # self._outputs is supposed to be a list of dict
102
+ # self._outputs[i] should have at least three keys: pred_box_key, pred_label_key, pred_score_key
103
+ # self._filenames is supposed to be a list of str
104
+ outputs = self._outputs
105
+ filenames = self._filenames
106
+ if ws > 1:
107
+ outputs = evenly_divisible_all_gather(outputs, concat=False)
108
+ filenames = string_list_all_gather(filenames)
109
+
110
+ if len(filenames) != len(outputs):
111
+ warnings.warn(f"filenames length: {len(filenames)} doesn't match outputs length: {len(outputs)}.")
112
+
113
+ # save to json file only in the expected rank
114
+ if idist.get_rank() == self.save_rank:
115
+ results = [
116
+ {
117
+ self.pred_box_key: detach_to_numpy(o[self.pred_box_key]).tolist(),
118
+ self.pred_label_key: detach_to_numpy(o[self.pred_label_key]).tolist(),
119
+ self.pred_score_key: detach_to_numpy(o[self.pred_score_key]).tolist(),
120
+ "image": f,
121
+ }
122
+ for o, f in zip(outputs, filenames)
123
+ ]
124
+
125
+ with open(os.path.join(self.output_dir, self.filename), "w") as outfile:
126
+ json.dump(results, outfile, indent=4)
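
With the default keys, the saved `predictions.json` is a list with one record per input image. An illustrative (made-up) entry, assuming a single predicted box:

    [
        {
            "box": [[1.0, 1.0, 1.0, 2.0, 2.0, 2.0]],
            "label": [0],
            "label_scores": [0.8],
            "image": "example_image.nii.gz"
        }
    ]
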
scripts/evaluator.py ADDED
@@ -0,0 +1,173 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
15
+
16
+ import torch
17
+ from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
18
+ from monai.engines.evaluator import SupervisedEvaluator
19
+ from monai.engines.utils import IterationEvents, default_metric_cmp_fn
20
+ from monai.transforms import Transform
21
+ from monai.utils import ForwardMode, IgniteInfo, min_version, optional_import
22
+ from monai.utils.enums import CommonKeys as Keys
23
+ from torch.utils.data import DataLoader
24
+
25
+ from .detection_inferer import RetinaNetInferer
26
+
27
+ if TYPE_CHECKING:
28
+ from ignite.engine import Engine, EventEnum
29
+ from ignite.metrics import Metric
30
+ else:
31
+ Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
32
+ Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
33
+ EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
34
+
35
+ __all__ = ["DetectionEvaluator"]
36
+
37
+
38
+ def detection_prepare_val_batch(
39
+ batchdata: List[Dict[str, torch.Tensor]],
40
+ device: Optional[Union[str, torch.device]] = None,
41
+ non_blocking: bool = False,
42
+ **kwargs,
43
+ ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
44
+ """
45
+ Default function to prepare the data for current iteration.
46
+ Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
47
+ https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
48
+ `kwargs` supports other args for `Tensor.to()` API.
49
+ Returns:
50
+ image, label(optional).
51
+ """
52
+ inputs = [
53
+ batch_data_i["image"].to(device=device, non_blocking=non_blocking, **kwargs) for batch_data_i in batchdata
54
+ ]
55
+
56
+ if isinstance(batchdata[0].get(Keys.LABEL), torch.Tensor):
57
+ targets = [
58
+ dict(
59
+ label=batch_data_i["label"].to(device=device, non_blocking=non_blocking, **kwargs),
60
+ box=batch_data_i["box"].to(device=device, non_blocking=non_blocking, **kwargs),
61
+ )
62
+ for batch_data_i in batchdata
63
+ ]
64
+ return (inputs, targets)
65
+ return inputs, None
66
+
67
+
68
+ class DetectionEvaluator(SupervisedEvaluator):
69
+ """
70
+ Supervised detection evaluation method with image and label, inherits from ``SupervisedEvaluator`` and ``Workflow``.
71
+ Args:
72
+ device: an object representing the device on which to run.
73
+ val_data_loader: Ignite engine uses data_loader to run; must be Iterable, typically a torch.DataLoader.
74
+ network: detector to evaluate in the evaluator, should be regular PyTorch `torch.nn.Module`.
75
+ epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
76
+ non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
77
+ with respect to the host. For other cases, this argument has no effect.
78
+ prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
79
+ from `engine.state.batch` for every iteration, for more details please refer to:
80
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
81
+ iteration_update: the callable function for every iteration, expect to accept `engine`
82
+ and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
83
+ if not provided, use `self._iteration()` instead. for more details please refer to:
84
+ https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
85
+ inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
86
+ postprocessing: execute additional transformation for the model output data.
87
+ Typically, several Tensor based transforms composed by `Compose`.
88
+ key_val_metric: compute metric when every iteration completed, and save average value to
89
+ engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
90
+ checkpoint into files.
91
+ additional_metrics: more Ignite metrics that also attach to Ignite Engine.
92
+ metric_cmp_fn: function to compare current key metric with previous best key metric value,
93
+ it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
94
+ `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
95
+ val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
96
+ CheckpointHandler, StatsHandler, etc.
97
+ amp: whether to enable auto-mixed-precision evaluation, default is False.
98
+ mode: model forward mode during evaluation, should be 'eval' or 'train',
99
+ which maps to `model.eval()` or `model.train()`, default to 'eval'.
100
+ event_names: additional custom ignite events that will register to the engine.
101
+ new events can be a list of str or `ignite.engine.events.EventEnum`.
102
+ event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
103
+ for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
104
+ #ignite.engine.engine.Engine.register_events.
105
+ decollate: whether to decollate the batch-first data to a list of data after model computation,
106
+ recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
107
+ default to `True`.
108
+ to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
109
+ `device`, `non_blocking`.
110
+ amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
111
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
112
+ """
113
+
114
+ def __init__(
115
+ self,
116
+ device: torch.device,
117
+ val_data_loader: Iterable | DataLoader,
118
+ network: RetinaNetDetector,
119
+ epoch_length: int | None = None,
120
+ non_blocking: bool = False,
121
+ prepare_batch: Callable = detection_prepare_val_batch,
122
+ iteration_update: Callable[[Engine, Any], Any] | None = None,
123
+ inferer: RetinaNetInferer | None = None,
124
+ postprocessing: Transform | None = None,
125
+ key_val_metric: dict[str, Metric] | None = None,
126
+ additional_metrics: dict[str, Metric] | None = None,
127
+ metric_cmp_fn: Callable = default_metric_cmp_fn,
128
+ val_handlers: Sequence | None = None,
129
+ amp: bool = False,
130
+ mode: ForwardMode | str = ForwardMode.EVAL,
131
+ event_names: list[str | EventEnum] | None = None,
132
+ event_to_attr: dict | None = None,
133
+ decollate: bool = True,
134
+ to_kwargs: dict | None = None,
135
+ amp_kwargs: dict | None = None,
136
+ ) -> None:
137
+ super().__init__(
138
+ device=device,
139
+ val_data_loader=val_data_loader,
140
+ network=network,
141
+ epoch_length=epoch_length,
142
+ non_blocking=non_blocking,
143
+ prepare_batch=prepare_batch,
144
+ iteration_update=iteration_update,
145
+ inferer=inferer,
146
+ postprocessing=postprocessing,
147
+ key_val_metric=key_val_metric,
148
+ additional_metrics=additional_metrics,
149
+ metric_cmp_fn=metric_cmp_fn,
150
+ val_handlers=val_handlers,
151
+ amp=amp,
152
+ mode=mode,
153
+ event_names=event_names,
154
+ event_to_attr=event_to_attr,
155
+ decollate=decollate,
156
+ to_kwargs=to_kwargs,
157
+ amp_kwargs=amp_kwargs,
158
+ )
159
+
160
+ def _register_decollate(self):
161
+ """
162
+ Register the decollate operation for batch data, will execute after model forward and loss forward.
163
+ """
164
+
165
+ @self.on(IterationEvents.MODEL_COMPLETED)
166
+ def _decollate_data(engine: Engine) -> None:
167
+ output_list = []
168
+ for i in range(len(engine.state.output[Keys.IMAGE])):
169
+ output_list.append({})
170
+ for k in engine.state.output.keys():
171
+ if engine.state.output[k] is not None:
172
+ output_list[i][k] = engine.state.output[k][i]
173
+ engine.state.output = output_list
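
A hedged wiring sketch of the evaluator together with the other components added in this commit; `detector` and `val_loader` are assumed to exist (normally they are built from the bundle configs), and the import paths assume `scripts` is importable as a package:

    import torch
    from scripts.detection_inferer import RetinaNetInferer
    from scripts.detection_saver import DetectionSaver
    from scripts.evaluator import DetectionEvaluator

    evaluator = DetectionEvaluator(
        device=torch.device("cuda:0"),
        val_data_loader=val_loader,    # assumed: yields dicts with "image", "box", "label"
        network=detector,              # assumed: a configured RetinaNetDetector
        inferer=RetinaNetInferer(detector),
        # IgniteCocoMetric is defined earlier in this commit; its output_transform should be
        # adapted to the engine output format used by the bundle configs.
        key_val_metric={"val_coco": IgniteCocoMetric()},
        val_handlers=[DetectionSaver(output_dir="./", filename="predictions.json")],
    )
    evaluator.run()
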
scripts/trainer.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
15
+
16
+ import torch
17
+ from monai.engines.trainer import Trainer
18
+ from monai.engines.utils import IterationEvents, default_metric_cmp_fn
19
+ from monai.inferers import Inferer
20
+ from monai.transforms import Transform
21
+ from monai.utils import IgniteInfo, min_version, optional_import
22
+ from monai.utils.enums import CommonKeys as Keys
23
+ from torch.optim.optimizer import Optimizer
24
+ from torch.utils.data import DataLoader
25
+
26
+ if TYPE_CHECKING:
27
+ from ignite.engine import Engine, EventEnum
28
+ from ignite.metrics import Metric
29
+ else:
30
+ Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
31
+ Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
32
+ EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
33
+
34
+ __all__ = ["DetectionTrainer"]
35
+
36
+
37
+ def detection_prepare_batch(
38
+ batchdata: List[Dict[str, torch.Tensor]],
39
+ device: Optional[Union[str, torch.device]] = None,
40
+ non_blocking: bool = False,
41
+ **kwargs,
42
+ ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
43
+ """
44
+ Default function to prepare the data for current iteration.
45
+ Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
46
+ https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
47
+ `kwargs` supports other args for `Tensor.to()` API.
48
+ Returns:
49
+ image, label(optional).
50
+ """
51
+ inputs = [
52
+ batch_data_ii["image"].to(device=device, non_blocking=non_blocking, **kwargs)
53
+ for batch_data_i in batchdata
54
+ for batch_data_ii in batch_data_i
55
+ ]
56
+
57
+ if isinstance(batchdata[0][0].get(Keys.LABEL), torch.Tensor):
58
+ targets = [
59
+ dict(
60
+ label=batch_data_ii["label"].to(device=device, non_blocking=non_blocking, **kwargs),
61
+ box=batch_data_ii["box"].to(device=device, non_blocking=non_blocking, **kwargs),
62
+ )
63
+ for batch_data_i in batchdata
64
+ for batch_data_ii in batch_data_i
65
+ ]
66
+ return (inputs, targets)
67
+ return inputs, None
68
+
69
+
70
+ class DetectionTrainer(Trainer):
71
+ """
72
+ Supervised detection training method with image and label, inherits from ``Trainer`` and ``Workflow``.
73
+ Args:
74
+ device: an object representing the device on which to run.
75
+ max_epochs: the total epoch number for trainer to run.
76
+ train_data_loader: Ignite engine uses data_loader to run; must be Iterable or a torch.DataLoader.
77
+ detector: detector to train in the trainer, should be regular PyTorch `torch.nn.Module`.
78
+ optimizer: the optimizer associated to the detector, should be regular PyTorch optimizer from `torch.optim`
79
+ or its subclass.
80
+ epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
81
+ non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
82
+ with respect to the host. For other cases, this argument has no effect.
83
+ prepare_batch: function to parse expected data (usually `image`,`box`, `label` and other detector args)
84
+ from `engine.state.batch` for every iteration, for more details please refer to:
85
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
86
+ iteration_update: the callable function for every iteration, expect to accept `engine`
87
+ and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
88
+ if not provided, use `self._iteration()` instead. for more details please refer to:
89
+ https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
90
+ inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
91
+ postprocessing: execute additional transformation for the model output data.
92
+ Typically, several Tensor based transforms composed by `Compose`.
93
+ key_train_metric: compute metric when every iteration completed, and save average value to
94
+ engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
95
+ checkpoint into files.
96
+ additional_metrics: more Ignite metrics that also attach to Ignite Engine.
97
+ metric_cmp_fn: function to compare current key metric with previous best key metric value,
98
+ it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
99
+ `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
100
+ train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
101
+ CheckpointHandler, StatsHandler, etc.
102
+ amp: whether to enable auto-mixed-precision training, default is False.
103
+ event_names: additional custom ignite events that will register to the engine.
104
+ new events can be a list of str or `ignite.engine.events.EventEnum`.
105
+ event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
106
+ for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
107
+ #ignite.engine.engine.Engine.register_events.
108
+ decollate: whether to decollate the batch-first data to a list of data after model computation,
109
+ recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
110
+ default to `True`.
111
+ optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
112
+ more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
113
+ to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
114
+ `device`, `non_blocking`.
115
+ amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
116
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
117
+ """
118
+
119
+ def __init__(
120
+ self,
121
+ device: torch.device,
122
+ max_epochs: int,
123
+ train_data_loader: Iterable | DataLoader,
124
+ detector: torch.nn.Module,
125
+ optimizer: Optimizer,
126
+ epoch_length: int | None = None,
127
+ non_blocking: bool = False,
128
+ prepare_batch: Callable = detection_prepare_batch,
129
+ iteration_update: Callable[[Engine, Any], Any] | None = None,
130
+ inferer: Inferer | None = None,
131
+ postprocessing: Transform | None = None,
132
+ key_train_metric: dict[str, Metric] | None = None,
133
+ additional_metrics: dict[str, Metric] | None = None,
134
+ metric_cmp_fn: Callable = default_metric_cmp_fn,
135
+ train_handlers: Sequence | None = None,
136
+ amp: bool = False,
137
+ event_names: list[str | EventEnum] | None = None,
138
+ event_to_attr: dict | None = None,
139
+ decollate: bool = True,
140
+ optim_set_to_none: bool = False,
141
+ to_kwargs: dict | None = None,
142
+ amp_kwargs: dict | None = None,
143
+ ) -> None:
144
+ super().__init__(
145
+ device=device,
146
+ max_epochs=max_epochs,
147
+ data_loader=train_data_loader,
148
+ epoch_length=epoch_length,
149
+ non_blocking=non_blocking,
150
+ prepare_batch=prepare_batch,
151
+ iteration_update=iteration_update,
152
+ postprocessing=postprocessing,
153
+ key_metric=key_train_metric,
154
+ additional_metrics=additional_metrics,
155
+ metric_cmp_fn=metric_cmp_fn,
156
+ handlers=train_handlers,
157
+ amp=amp,
158
+ event_names=event_names,
159
+ event_to_attr=event_to_attr,
160
+ decollate=decollate,
161
+ to_kwargs=to_kwargs,
162
+ amp_kwargs=amp_kwargs,
163
+ )
164
+
165
+ self.detector = detector
166
+ self.optimizer = optimizer
167
+ self.optim_set_to_none = optim_set_to_none
168
+
169
+ def _iteration(self, engine, batchdata: dict[str, torch.Tensor]):
170
+ """
171
+ Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
172
+ Return below items in a dictionary:
173
+ - IMAGE: image Tensor data for model input, already moved to device.
174
+ - LABEL: ground-truth targets (boxes and labels) corresponding to the images, already moved to device.
175
+ - the classification loss and box regression loss, stored under the detector's ``cls_key`` and ``box_reg_key``.
176
+ - LOSS: weighted sum of the classification and box regression losses.
177
+ Args:
178
+ engine: `DetectionTrainer` to execute operation for an iteration.
179
+ batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
180
+ Raises:
181
+ ValueError: When ``batchdata`` is None.
182
+ """
183
+
184
+ if batchdata is None:
185
+ raise ValueError("Must provide batch data for current iteration.")
186
+
187
+ batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
188
+ if len(batch) == 2:
189
+ inputs, targets = batch
190
+ args: tuple = ()
191
+ kwargs: dict = {}
192
+ else:
193
+ inputs, targets, args, kwargs = batch
194
+ # put iteration outputs into engine.state
195
+ engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
196
+
197
+ def _compute_pred_loss(w_cls: float = 1.0, w_box_reg: float = 1.0):
198
+ """
199
+ Args:
200
+ w_cls: weight of classification loss
201
+ w_box_reg: weight of box regression loss
202
+ """
203
+ outputs = engine.detector(inputs, targets)
204
+ engine.state.output[engine.detector.cls_key] = outputs[engine.detector.cls_key]
205
+ engine.state.output[engine.detector.box_reg_key] = outputs[engine.detector.box_reg_key]
206
+ engine.state.output[Keys.LOSS] = (
207
+ w_cls * outputs[engine.detector.cls_key] + w_box_reg * outputs[engine.detector.box_reg_key]
208
+ )
209
+ engine.fire_event(IterationEvents.LOSS_COMPLETED)
210
+
211
+ engine.detector.train()
212
+ engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
213
+
214
+ if engine.amp and engine.scaler is not None:
215
+ with torch.cuda.amp.autocast(**engine.amp_kwargs):
216
+ inputs = [img.to(torch.float16) for img in inputs]
217
+ _compute_pred_loss()
218
+ engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
219
+ engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
220
+ engine.scaler.step(engine.optimizer)
221
+ engine.scaler.update()
222
+ else:
223
+ _compute_pred_loss()
224
+ engine.state.output[Keys.LOSS].backward()
225
+ engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
226
+ engine.optimizer.step()
227
+
228
+ return engine.state.output
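
Note that `detection_prepare_batch` above expects a nested batch (one inner list of patch-level dicts per image, which it flattens), while the evaluator's `detection_prepare_val_batch` expects a flat list of per-image dicts. A small illustrative sketch with made-up shapes:

    import torch
    from scripts.evaluator import detection_prepare_val_batch
    from scripts.trainer import detection_prepare_batch

    # trainer: list (batch) of lists (patches per image) of dicts
    train_batch = [
        [{"image": torch.rand(1, 192, 192, 80), "box": torch.rand(2, 6), "label": torch.tensor([0, 0])}],
        [{"image": torch.rand(1, 192, 192, 80), "box": torch.rand(1, 6), "label": torch.tensor([0])}],
    ]
    train_inputs, train_targets = detection_prepare_batch(train_batch, device="cpu")

    # evaluator: flat list of per-image dicts
    val_batch = [{"image": torch.rand(1, 512, 512, 200), "box": torch.rand(3, 6), "label": torch.tensor([0, 0, 0])}]
    val_inputs, val_targets = detection_prepare_val_batch(val_batch, device="cpu")
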
scripts/utils.py ADDED
@@ -0,0 +1,26 @@
1
+ from typing import Dict, List, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ def detach_to_numpy(data: Union[List, Dict, torch.Tensor]) -> Union[List, Dict, torch.Tensor]:
8
+ """
9
+ Recursively detach tensors in ``data`` and convert them to numpy arrays.
10
+ """
11
+ if isinstance(data, torch.Tensor):
12
+ return data.cpu().detach().numpy() # pytype: disable=attribute-error
13
+
14
+ elif isinstance(data, np.ndarray):
15
+ return data
16
+
17
+ elif isinstance(data, list):
18
+ return [detach_to_numpy(d) for d in data]
19
+
20
+ elif isinstance(data, dict):
21
+ for k in data.keys():
22
+ data[k] = detach_to_numpy(data[k])
23
+ return data
24
+
25
+ else:
26
+ raise ValueError("data should be tensor, numpy array, dict, or list.")
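
A tiny usage sketch of the helper above:

    import torch

    out = detach_to_numpy({"box": torch.tensor([[1.0, 1.0, 1.0, 2.0, 2.0, 2.0]]), "scores": [torch.tensor(0.8)]})
    # out["box"] is now a numpy array of shape (1, 6); out["scores"] is a list containing a 0-d numpy array
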
scripts/warmup_scheduler.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ """
13
+ This script is adapted from
14
+ https://github.com/ildoonet/pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py
15
+ """
16
+
17
+ from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
18
+
19
+
20
+ class GradualWarmupScheduler(_LRScheduler):
21
+ """Gradually warms up (increases) the learning rate in the optimizer.
22
+ Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
23
+
24
+ Args:
25
+ optimizer (Optimizer): Wrapped optimizer.
26
+ multiplier: target learning rate = base lr * multiplier if multiplier > 1.0.
27
+ if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
28
+ total_epoch: the target learning rate is reached gradually at total_epoch.
29
+ after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau).
30
+ """
31
+
32
+ def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
33
+ self.multiplier = multiplier
34
+ if self.multiplier < 1.0:
35
+ raise ValueError("multiplier should be greater than or equal to 1.")
36
+ self.total_epoch = total_epoch
37
+ self.after_scheduler = after_scheduler
38
+ self.finished = False
39
+ super(GradualWarmupScheduler, self).__init__(optimizer)
40
+
41
+ def get_lr(self):
42
+ self.last_epoch = max(1, self.last_epoch) # to avoid epoch=0 thus lr=0
43
+ if self.last_epoch > self.total_epoch:
44
+ if self.after_scheduler:
45
+ if not self.finished:
46
+ self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
47
+ self.finished = True
48
+ return self.after_scheduler.get_last_lr()
49
+ return [base_lr * self.multiplier for base_lr in self.base_lrs]
50
+
51
+ if self.multiplier == 1.0:
52
+ return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
53
+ else:
54
+ return [
55
+ base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
56
+ for base_lr in self.base_lrs
57
+ ]
58
+
59
+ def step_reduce_lr_on_plateau(self, metrics, epoch=None):
60
+ if epoch is None:
61
+ epoch = self.last_epoch + 1
62
+ self.last_epoch = (
63
+ epoch if epoch != 0 else 1
64
+ ) # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
65
+ if self.last_epoch <= self.total_epoch:
66
+ warmup_lr = [
67
+ base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
68
+ for base_lr in self.base_lrs
69
+ ]
70
+ for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
71
+ param_group["lr"] = lr
72
+ else:
73
+ if epoch is None:
74
+ self.after_scheduler.step(metrics, None)
75
+ else:
76
+ self.after_scheduler.step(metrics, epoch - self.total_epoch)
77
+
78
+ def step(self, epoch=None, metrics=None):
79
+ if not isinstance(self.after_scheduler, ReduceLROnPlateau):
80
+ if self.finished and self.after_scheduler:
81
+ if epoch is None:
82
+ self.after_scheduler.step(None)
83
+ else:
84
+ self.after_scheduler.step(epoch - self.total_epoch)
85
+ self._last_lr = self.after_scheduler.get_last_lr()
86
+ else:
87
+ return super(GradualWarmupScheduler, self).step(epoch)
88
+ else:
89
+ self.step_reduce_lr_on_plateau(metrics, epoch)
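
A hedged usage sketch of the warm-up scheduler defined above; the model, optimizer, and epoch counts are illustrative only:

    import torch
    from torch.optim.lr_scheduler import StepLR

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    # ramp the lr linearly from ~0 up to the base lr (1e-2) over the first 10 epochs, then defer to StepLR
    scheduler = GradualWarmupScheduler(
        optimizer,
        multiplier=1.0,
        total_epoch=10,
        after_scheduler=StepLR(optimizer, step_size=30, gamma=0.1),
    )
    for epoch in range(40):
        # ... one epoch of training with `optimizer` ...
        scheduler.step()
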