Upload lung_nodule_ct_detection version 0.6.9
Browse files- .gitattributes +1 -0
- LICENSE +201 -0
- configs/evaluate.json +50 -0
- configs/inference.json +221 -0
- configs/inference_trt.json +13 -0
- configs/logging.conf +21 -0
- configs/metadata.json +107 -0
- configs/train.json +452 -0
- docs/README.md +163 -0
- docs/data_license.txt +11 -0
- models/model.pt +3 -0
- models/model.ts +3 -0
- scripts/__init__.py +14 -0
- scripts/cocometric_ignite.py +111 -0
- scripts/detection_inferer.py +66 -0
- scripts/detection_saver.py +126 -0
- scripts/evaluator.py +173 -0
- scripts/trainer.py +228 -0
- scripts/utils.py +26 -0
- scripts/warmup_scheduler.py +89 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
models/model.ts filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
configs/evaluate.json
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"test_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, is_segmentation=True, data_list_key='validation', base_dir=@dataset_dir)",
|
3 |
+
"validate#dataset": {
|
4 |
+
"_target_": "Dataset",
|
5 |
+
"data": "$@test_datalist",
|
6 |
+
"transform": "@validate#preprocessing"
|
7 |
+
},
|
8 |
+
"validate#key_metric": {
|
9 |
+
"val_coco": {
|
10 |
+
"_target_": "scripts.cocometric_ignite.IgniteCocoMetric",
|
11 |
+
"coco_metric_monai": "$monai.apps.detection.metrics.coco.COCOMetric(classes=['nodule'], iou_list=[0.1], max_detection=[100])",
|
12 |
+
"output_transform": "$monai.handlers.from_engine(['pred', 'label'])",
|
13 |
+
"box_key": "box",
|
14 |
+
"label_key": "label",
|
15 |
+
"pred_score_key": "label_scores",
|
16 |
+
"reduce_scalar": false
|
17 |
+
}
|
18 |
+
},
|
19 |
+
"validate#handlers": [
|
20 |
+
{
|
21 |
+
"_target_": "CheckpointLoader",
|
22 |
+
"load_path": "$@ckpt_dir + '/model.pt'",
|
23 |
+
"load_dict": {
|
24 |
+
"model": "@network"
|
25 |
+
}
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"_target_": "StatsHandler",
|
29 |
+
"iteration_log": false
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"_target_": "MetricsSaver",
|
33 |
+
"save_dir": "@output_dir",
|
34 |
+
"metrics": [
|
35 |
+
"val_coco"
|
36 |
+
],
|
37 |
+
"metric_details": [
|
38 |
+
"val_coco"
|
39 |
+
],
|
40 |
+
"batch_transform": "$lambda x: [xx['image'].meta for xx in x]",
|
41 |
+
"summary_ops": "*"
|
42 |
+
}
|
43 |
+
],
|
44 |
+
"initialize": [
|
45 |
+
"$setattr(torch.backends.cudnn, 'benchmark', True)"
|
46 |
+
],
|
47 |
+
"run": [
|
48 |
+
"$@validate#evaluator.run()"
|
49 |
+
]
|
50 |
+
}
|
configs/inference.json
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"whether_raw_luna16": false,
|
3 |
+
"whether_resampled_luna16": "$(not @whether_raw_luna16)",
|
4 |
+
"imports": [
|
5 |
+
"$import glob",
|
6 |
+
"$import numpy",
|
7 |
+
"$import os"
|
8 |
+
],
|
9 |
+
"bundle_root": ".",
|
10 |
+
"image_key": "image",
|
11 |
+
"ckpt_dir": "$@bundle_root + '/models'",
|
12 |
+
"output_dir": "$@bundle_root + '/eval'",
|
13 |
+
"output_filename": "result_luna16_fold0.json",
|
14 |
+
"data_list_file_path": "$@bundle_root + '/LUNA16_datasplit/dataset_fold0.json'",
|
15 |
+
"dataset_dir": "/datasets/LUNA16_Images_resample",
|
16 |
+
"test_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, is_segmentation=True, data_list_key='validation', base_dir=@dataset_dir)",
|
17 |
+
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
|
18 |
+
"amp": true,
|
19 |
+
"load_pretrain": true,
|
20 |
+
"spatial_dims": 3,
|
21 |
+
"num_classes": 1,
|
22 |
+
"force_sliding_window": false,
|
23 |
+
"size_divisible": [
|
24 |
+
16,
|
25 |
+
16,
|
26 |
+
8
|
27 |
+
],
|
28 |
+
"infer_patch_size": [
|
29 |
+
512,
|
30 |
+
512,
|
31 |
+
192
|
32 |
+
],
|
33 |
+
"anchor_generator": {
|
34 |
+
"_target_": "monai.apps.detection.utils.anchor_utils.AnchorGeneratorWithAnchorShape",
|
35 |
+
"feature_map_scales": [
|
36 |
+
1,
|
37 |
+
2,
|
38 |
+
4
|
39 |
+
],
|
40 |
+
"base_anchor_shapes": [
|
41 |
+
[
|
42 |
+
6,
|
43 |
+
8,
|
44 |
+
4
|
45 |
+
],
|
46 |
+
[
|
47 |
+
8,
|
48 |
+
6,
|
49 |
+
5
|
50 |
+
],
|
51 |
+
[
|
52 |
+
10,
|
53 |
+
10,
|
54 |
+
6
|
55 |
+
]
|
56 |
+
]
|
57 |
+
},
|
58 |
+
"backbone": "$monai.networks.nets.resnet.resnet50(spatial_dims=3,n_input_channels=1,conv1_t_stride=[2,2,1],conv1_t_size=[7,7,7])",
|
59 |
+
"feature_extractor": "$monai.apps.detection.networks.retinanet_network.resnet_fpn_feature_extractor(@backbone,3,False,[1,2],None)",
|
60 |
+
"network_def": {
|
61 |
+
"_target_": "RetinaNet",
|
62 |
+
"spatial_dims": "@spatial_dims",
|
63 |
+
"num_classes": "@num_classes",
|
64 |
+
"num_anchors": 3,
|
65 |
+
"feature_extractor": "@feature_extractor",
|
66 |
+
"size_divisible": "@size_divisible",
|
67 |
+
"use_list_output": false
|
68 |
+
},
|
69 |
+
"network": "$@network_def.to(@device)",
|
70 |
+
"detector": {
|
71 |
+
"_target_": "RetinaNetDetector",
|
72 |
+
"network": "@network",
|
73 |
+
"anchor_generator": "@anchor_generator",
|
74 |
+
"debug": false,
|
75 |
+
"spatial_dims": "@spatial_dims",
|
76 |
+
"num_classes": "@num_classes",
|
77 |
+
"size_divisible": "@size_divisible"
|
78 |
+
},
|
79 |
+
"detector_ops": [
|
80 |
+
"[email protected]_target_keys(box_key='box', label_key='label')",
|
81 |
+
"[email protected]_box_selector_parameters(score_thresh=0.02,topk_candidates_per_level=1000,nms_thresh=0.22,detections_per_img=300)",
|
82 |
+
"[email protected]_sliding_window_inferer(roi_size=@infer_patch_size,overlap=0.25,sw_batch_size=1,mode='constant',device='cpu')"
|
83 |
+
],
|
84 |
+
"preprocessing": {
|
85 |
+
"_target_": "Compose",
|
86 |
+
"transforms": [
|
87 |
+
{
|
88 |
+
"_target_": "LoadImaged",
|
89 |
+
"keys": "@image_key",
|
90 |
+
"_disabled_": "@whether_raw_luna16"
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"_target_": "LoadImaged",
|
94 |
+
"keys": "@image_key",
|
95 |
+
"reader": "itkreader",
|
96 |
+
"affine_lps_to_ras": true,
|
97 |
+
"_disabled_": "@whether_resampled_luna16"
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"_target_": "EnsureChannelFirstd",
|
101 |
+
"keys": "@image_key"
|
102 |
+
},
|
103 |
+
{
|
104 |
+
"_target_": "Orientationd",
|
105 |
+
"keys": "@image_key",
|
106 |
+
"axcodes": "RAS"
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"_target_": "Spacingd",
|
110 |
+
"keys": "@image_key",
|
111 |
+
"pixdim": [
|
112 |
+
0.703125,
|
113 |
+
0.703125,
|
114 |
+
1.25
|
115 |
+
],
|
116 |
+
"_disabled_": "@whether_resampled_luna16"
|
117 |
+
},
|
118 |
+
{
|
119 |
+
"_target_": "ScaleIntensityRanged",
|
120 |
+
"keys": "@image_key",
|
121 |
+
"a_min": -1024.0,
|
122 |
+
"a_max": 300.0,
|
123 |
+
"b_min": 0.0,
|
124 |
+
"b_max": 1.0,
|
125 |
+
"clip": true
|
126 |
+
},
|
127 |
+
{
|
128 |
+
"_target_": "EnsureTyped",
|
129 |
+
"keys": "@image_key"
|
130 |
+
}
|
131 |
+
]
|
132 |
+
},
|
133 |
+
"dataset": {
|
134 |
+
"_target_": "Dataset",
|
135 |
+
"data": "$@test_datalist",
|
136 |
+
"transform": "@preprocessing"
|
137 |
+
},
|
138 |
+
"dataloader": {
|
139 |
+
"_target_": "DataLoader",
|
140 |
+
"dataset": "@dataset",
|
141 |
+
"batch_size": 1,
|
142 |
+
"shuffle": false,
|
143 |
+
"num_workers": 4,
|
144 |
+
"collate_fn": "$monai.data.utils.no_collation"
|
145 |
+
},
|
146 |
+
"inferer": {
|
147 |
+
"_target_": "scripts.detection_inferer.RetinaNetInferer",
|
148 |
+
"detector": "@detector",
|
149 |
+
"force_sliding_window": "@force_sliding_window"
|
150 |
+
},
|
151 |
+
"postprocessing": {
|
152 |
+
"_target_": "Compose",
|
153 |
+
"transforms": [
|
154 |
+
{
|
155 |
+
"_target_": "ClipBoxToImaged",
|
156 |
+
"box_keys": "box",
|
157 |
+
"label_keys": "label",
|
158 |
+
"box_ref_image_keys": "@image_key",
|
159 |
+
"remove_empty": true
|
160 |
+
},
|
161 |
+
{
|
162 |
+
"_target_": "AffineBoxToWorldCoordinated",
|
163 |
+
"box_keys": "box",
|
164 |
+
"box_ref_image_keys": "@image_key",
|
165 |
+
"affine_lps_to_ras": true
|
166 |
+
},
|
167 |
+
{
|
168 |
+
"_target_": "ConvertBoxModed",
|
169 |
+
"box_keys": "box",
|
170 |
+
"src_mode": "xyzxyz",
|
171 |
+
"dst_mode": "cccwhd"
|
172 |
+
},
|
173 |
+
{
|
174 |
+
"_target_": "DeleteItemsd",
|
175 |
+
"keys": [
|
176 |
+
"@image_key"
|
177 |
+
]
|
178 |
+
}
|
179 |
+
]
|
180 |
+
},
|
181 |
+
"handlers": [
|
182 |
+
{
|
183 |
+
"_target_": "StatsHandler",
|
184 |
+
"iteration_log": false
|
185 |
+
},
|
186 |
+
{
|
187 |
+
"_target_": "scripts.detection_saver.DetectionSaver",
|
188 |
+
"output_dir": "@output_dir",
|
189 |
+
"filename": "@output_filename",
|
190 |
+
"batch_transform": "$lambda x: [xx['image'].meta for xx in x]",
|
191 |
+
"output_transform": "$lambda x: [@postprocessing({**xx['pred'],'image':xx['image']}) for xx in x]",
|
192 |
+
"pred_box_key": "box",
|
193 |
+
"pred_label_key": "label",
|
194 |
+
"pred_score_key": "label_scores"
|
195 |
+
}
|
196 |
+
],
|
197 |
+
"evaluator": {
|
198 |
+
"_target_": "scripts.evaluator.DetectionEvaluator",
|
199 |
+
"_requires_": "@detector_ops",
|
200 |
+
"device": "@device",
|
201 |
+
"val_data_loader": "@dataloader",
|
202 |
+
"network": "@network",
|
203 |
+
"inferer": "@inferer",
|
204 |
+
"val_handlers": "@handlers",
|
205 |
+
"amp": "@amp"
|
206 |
+
},
|
207 |
+
"checkpointloader": {
|
208 |
+
"_target_": "CheckpointLoader",
|
209 |
+
"load_path": "$@bundle_root + '/models/model.pt'",
|
210 |
+
"load_dict": {
|
211 |
+
"model": "@network"
|
212 |
+
}
|
213 |
+
},
|
214 |
+
"initialize": [
|
215 |
+
"$setattr(torch.backends.cudnn, 'benchmark', True)",
|
216 |
+
"$@checkpointloader(@evaluator) if @load_pretrain else None"
|
217 |
+
],
|
218 |
+
"run": [
|
219 |
+
"[email protected]()"
|
220 |
+
]
|
221 |
+
}
|
configs/inference_trt.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"imports": [
|
3 |
+
"$import glob",
|
4 |
+
"$import os",
|
5 |
+
"$import torch_tensorrt"
|
6 |
+
],
|
7 |
+
"force_sliding_window": true,
|
8 |
+
"network_def": "$torch.jit.load(@bundle_root + '/models/model_trt.ts')",
|
9 |
+
"evaluator#amp": false,
|
10 |
+
"initialize": [
|
11 |
+
"$setattr(torch.backends.cudnn, 'benchmark', True)"
|
12 |
+
]
|
13 |
+
}
|
configs/logging.conf
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[loggers]
|
2 |
+
keys=root
|
3 |
+
|
4 |
+
[handlers]
|
5 |
+
keys=consoleHandler
|
6 |
+
|
7 |
+
[formatters]
|
8 |
+
keys=fullFormatter
|
9 |
+
|
10 |
+
[logger_root]
|
11 |
+
level=INFO
|
12 |
+
handlers=consoleHandler
|
13 |
+
|
14 |
+
[handler_consoleHandler]
|
15 |
+
class=StreamHandler
|
16 |
+
level=INFO
|
17 |
+
formatter=fullFormatter
|
18 |
+
args=(sys.stdout,)
|
19 |
+
|
20 |
+
[formatter_fullFormatter]
|
21 |
+
format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
|
configs/metadata.json
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
|
3 |
+
"version": "0.6.9",
|
4 |
+
"changelog": {
|
5 |
+
"0.6.9": "update to huggingface hosting and fix missing dependencies",
|
6 |
+
"0.6.8": "update issue for IgniteInfo",
|
7 |
+
"0.6.7": "use monai 1.4 and update large files",
|
8 |
+
"0.6.6": "update to use monai 1.3.1",
|
9 |
+
"0.6.5": "remove notes for trt_export in readme",
|
10 |
+
"0.6.4": "add notes for trt_export in readme",
|
11 |
+
"0.6.3": "add load_pretrain flag for infer",
|
12 |
+
"0.6.2": "add checkpoint loader for infer",
|
13 |
+
"0.6.1": "fix format error",
|
14 |
+
"0.6.0": "remove meta_dict usage",
|
15 |
+
"0.5.9": "use monai 1.2.0",
|
16 |
+
"0.5.8": "update TRT memory requirement in readme",
|
17 |
+
"0.5.7": "add dataset dir example",
|
18 |
+
"0.5.6": "add the ONNX-TensorRT way of model conversion",
|
19 |
+
"0.5.5": "update retrained validation results and training curve",
|
20 |
+
"0.5.4": "add non-deterministic note",
|
21 |
+
"0.5.3": "adapt to BundleWorkflow interface",
|
22 |
+
"0.5.2": "black autofix format and add name tag",
|
23 |
+
"0.5.1": "modify dataset key name",
|
24 |
+
"0.5.0": "use detection inferer",
|
25 |
+
"0.4.5": "fixed some small changes with formatting in readme",
|
26 |
+
"0.4.4": "add data resource to readme",
|
27 |
+
"0.4.3": "update val patch size to avoid warning in monai 1.0.1",
|
28 |
+
"0.4.2": "update to use monai 1.0.1",
|
29 |
+
"0.4.1": "fix license Copyright error",
|
30 |
+
"0.4.0": "add support for raw images",
|
31 |
+
"0.3.0": "update license files",
|
32 |
+
"0.2.0": "unify naming",
|
33 |
+
"0.1.1": "add reference for LIDC dataset",
|
34 |
+
"0.1.0": "complete the model package"
|
35 |
+
},
|
36 |
+
"monai_version": "1.4.0",
|
37 |
+
"pytorch_version": "2.4.0",
|
38 |
+
"numpy_version": "1.24.4",
|
39 |
+
"required_packages_version": {
|
40 |
+
"nibabel": "5.2.1",
|
41 |
+
"pytorch-ignite": "0.4.11",
|
42 |
+
"torchvision": "0.19.0",
|
43 |
+
"tensorboard": "2.17.0"
|
44 |
+
},
|
45 |
+
"supported_apps": {},
|
46 |
+
"name": "Lung nodule CT detection",
|
47 |
+
"task": "CT lung nodule detection",
|
48 |
+
"description": "A pre-trained model for volumetric (3D) detection of the lung lesion from CT image on LUNA16 dataset",
|
49 |
+
"authors": "MONAI team",
|
50 |
+
"copyright": "Copyright (c) MONAI Consortium",
|
51 |
+
"data_source": "https://luna16.grand-challenge.org/Home/",
|
52 |
+
"data_type": "nibabel",
|
53 |
+
"image_classes": "1 channel data, CT at 0.703125 x 0.703125 x 1.25 mm",
|
54 |
+
"label_classes": "dict data, containing Nx6 box and Nx1 classification labels.",
|
55 |
+
"pred_classes": "dict data, containing Nx6 box, Nx1 classification labels, Nx1 classification scores.",
|
56 |
+
"eval_metrics": {
|
57 |
+
"mAP_IoU_0.10_0.50_0.05_MaxDet_100": 0.852,
|
58 |
+
"AP_IoU_0.10_MaxDet_100": 0.858,
|
59 |
+
"mAR_IoU_0.10_0.50_0.05_MaxDet_100": 0.998,
|
60 |
+
"AR_IoU_0.10_MaxDet_100": 1.0
|
61 |
+
},
|
62 |
+
"intended_use": "This is an example, not to be used for diagnostic purposes",
|
63 |
+
"references": [
|
64 |
+
"Lin, Tsung-Yi, et al. 'Focal loss for dense object detection. ICCV 2017"
|
65 |
+
],
|
66 |
+
"network_data_format": {
|
67 |
+
"inputs": {
|
68 |
+
"image": {
|
69 |
+
"type": "image",
|
70 |
+
"format": "magnitude",
|
71 |
+
"modality": "CT",
|
72 |
+
"num_channels": 1,
|
73 |
+
"spatial_shape": [
|
74 |
+
"16*n",
|
75 |
+
"16*n",
|
76 |
+
"8*n"
|
77 |
+
],
|
78 |
+
"dtype": "float16",
|
79 |
+
"value_range": [
|
80 |
+
0,
|
81 |
+
1
|
82 |
+
],
|
83 |
+
"is_patch_data": true,
|
84 |
+
"channel_def": {
|
85 |
+
"0": "image"
|
86 |
+
}
|
87 |
+
}
|
88 |
+
},
|
89 |
+
"outputs": {
|
90 |
+
"pred": {
|
91 |
+
"type": "object",
|
92 |
+
"format": "dict",
|
93 |
+
"dtype": "float16",
|
94 |
+
"num_channels": 1,
|
95 |
+
"spatial_shape": [
|
96 |
+
"n",
|
97 |
+
"n",
|
98 |
+
"n"
|
99 |
+
],
|
100 |
+
"value_range": [
|
101 |
+
-10000,
|
102 |
+
10000
|
103 |
+
]
|
104 |
+
}
|
105 |
+
}
|
106 |
+
}
|
107 |
+
}
|
configs/train.json
ADDED
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"imports": [
|
3 |
+
"$import glob",
|
4 |
+
"$import os"
|
5 |
+
],
|
6 |
+
"bundle_root": ".",
|
7 |
+
"ckpt_dir": "$@bundle_root + '/models'",
|
8 |
+
"output_dir": "$@bundle_root + '/eval'",
|
9 |
+
"data_list_file_path": "$@bundle_root + '/LUNA16_datasplit/dataset_fold0.json'",
|
10 |
+
"dataset_dir": "/datasets/LUNA16_Images_resample",
|
11 |
+
"train_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, is_segmentation=True, data_list_key='training', base_dir=@dataset_dir)",
|
12 |
+
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
|
13 |
+
"epochs": 300,
|
14 |
+
"val_interval": 5,
|
15 |
+
"learning_rate": 0.01,
|
16 |
+
"amp": true,
|
17 |
+
"batch_size": 4,
|
18 |
+
"patch_size": [
|
19 |
+
192,
|
20 |
+
192,
|
21 |
+
80
|
22 |
+
],
|
23 |
+
"val_patch_size": [
|
24 |
+
512,
|
25 |
+
512,
|
26 |
+
192
|
27 |
+
],
|
28 |
+
"anchor_generator": {
|
29 |
+
"_target_": "monai.apps.detection.utils.anchor_utils.AnchorGeneratorWithAnchorShape",
|
30 |
+
"feature_map_scales": [
|
31 |
+
1,
|
32 |
+
2,
|
33 |
+
4
|
34 |
+
],
|
35 |
+
"base_anchor_shapes": [
|
36 |
+
[
|
37 |
+
6,
|
38 |
+
8,
|
39 |
+
4
|
40 |
+
],
|
41 |
+
[
|
42 |
+
8,
|
43 |
+
6,
|
44 |
+
5
|
45 |
+
],
|
46 |
+
[
|
47 |
+
10,
|
48 |
+
10,
|
49 |
+
6
|
50 |
+
]
|
51 |
+
]
|
52 |
+
},
|
53 |
+
"backbone": "$monai.networks.nets.resnet.resnet50(spatial_dims=3,n_input_channels=1,conv1_t_stride=[2,2,1],conv1_t_size=[7,7,7])",
|
54 |
+
"feature_extractor": "$monai.apps.detection.networks.retinanet_network.resnet_fpn_feature_extractor(@backbone,3,False,[1,2],None)",
|
55 |
+
"network_def": {
|
56 |
+
"_target_": "RetinaNet",
|
57 |
+
"spatial_dims": 3,
|
58 |
+
"num_classes": 1,
|
59 |
+
"num_anchors": 3,
|
60 |
+
"feature_extractor": "@feature_extractor",
|
61 |
+
"size_divisible": [
|
62 |
+
16,
|
63 |
+
16,
|
64 |
+
8
|
65 |
+
]
|
66 |
+
},
|
67 |
+
"network": "$@network_def.to(@device)",
|
68 |
+
"detector": {
|
69 |
+
"_target_": "RetinaNetDetector",
|
70 |
+
"network": "@network",
|
71 |
+
"anchor_generator": "@anchor_generator",
|
72 |
+
"debug": false
|
73 |
+
},
|
74 |
+
"detector_ops": [
|
75 |
+
"[email protected]_atss_matcher(num_candidates=4, center_in_gt=False)",
|
76 |
+
"[email protected]_hard_negative_sampler(batch_size_per_image=64,positive_fraction=0.3,pool_size=20,min_neg=16)",
|
77 |
+
"[email protected]_target_keys(box_key='box', label_key='label')",
|
78 |
+
"[email protected]_box_selector_parameters(score_thresh=0.02,topk_candidates_per_level=1000,nms_thresh=0.22,detections_per_img=300)",
|
79 |
+
"[email protected]_sliding_window_inferer(roi_size=@val_patch_size,overlap=0.25,sw_batch_size=1,mode='constant',device='cpu')"
|
80 |
+
],
|
81 |
+
"optimizer": {
|
82 |
+
"_target_": "torch.optim.SGD",
|
83 |
+
"params": "[email protected]()",
|
84 |
+
"lr": "@learning_rate",
|
85 |
+
"momentum": 0.9,
|
86 |
+
"weight_decay": 3e-05,
|
87 |
+
"nesterov": true
|
88 |
+
},
|
89 |
+
"after_scheduler": {
|
90 |
+
"_target_": "torch.optim.lr_scheduler.StepLR",
|
91 |
+
"optimizer": "@optimizer",
|
92 |
+
"step_size": 160,
|
93 |
+
"gamma": 0.1
|
94 |
+
},
|
95 |
+
"lr_scheduler": {
|
96 |
+
"_target_": "scripts.warmup_scheduler.GradualWarmupScheduler",
|
97 |
+
"optimizer": "@optimizer",
|
98 |
+
"multiplier": 1,
|
99 |
+
"total_epoch": 10,
|
100 |
+
"after_scheduler": "@after_scheduler"
|
101 |
+
},
|
102 |
+
"train": {
|
103 |
+
"preprocessing_transforms": [
|
104 |
+
{
|
105 |
+
"_target_": "LoadImaged",
|
106 |
+
"keys": "image"
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"_target_": "EnsureChannelFirstd",
|
110 |
+
"keys": "image"
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"_target_": "EnsureTyped",
|
114 |
+
"keys": [
|
115 |
+
"image",
|
116 |
+
"box"
|
117 |
+
]
|
118 |
+
},
|
119 |
+
{
|
120 |
+
"_target_": "EnsureTyped",
|
121 |
+
"keys": "label",
|
122 |
+
"dtype": "$torch.long"
|
123 |
+
},
|
124 |
+
{
|
125 |
+
"_target_": "Orientationd",
|
126 |
+
"keys": "image",
|
127 |
+
"axcodes": "RAS"
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"_target_": "ScaleIntensityRanged",
|
131 |
+
"keys": "image",
|
132 |
+
"a_min": -1024.0,
|
133 |
+
"a_max": 300.0,
|
134 |
+
"b_min": 0.0,
|
135 |
+
"b_max": 1.0,
|
136 |
+
"clip": true
|
137 |
+
},
|
138 |
+
{
|
139 |
+
"_target_": "ConvertBoxToStandardModed",
|
140 |
+
"box_keys": "box",
|
141 |
+
"mode": "cccwhd"
|
142 |
+
},
|
143 |
+
{
|
144 |
+
"_target_": "AffineBoxToImageCoordinated",
|
145 |
+
"box_keys": "box",
|
146 |
+
"box_ref_image_keys": "image",
|
147 |
+
"affine_lps_to_ras": true
|
148 |
+
}
|
149 |
+
],
|
150 |
+
"random_transforms": [
|
151 |
+
{
|
152 |
+
"_target_": "RandCropBoxByPosNegLabeld",
|
153 |
+
"image_keys": "image",
|
154 |
+
"box_keys": "box",
|
155 |
+
"label_keys": "label",
|
156 |
+
"spatial_size": "@patch_size",
|
157 |
+
"whole_box": true,
|
158 |
+
"num_samples": "@batch_size",
|
159 |
+
"pos": 1,
|
160 |
+
"neg": 1
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"_target_": "RandZoomBoxd",
|
164 |
+
"image_keys": "image",
|
165 |
+
"box_keys": "box",
|
166 |
+
"label_keys": "label",
|
167 |
+
"box_ref_image_keys": "image",
|
168 |
+
"prob": 0.2,
|
169 |
+
"min_zoom": 0.7,
|
170 |
+
"max_zoom": 1.4,
|
171 |
+
"padding_mode": "constant",
|
172 |
+
"keep_size": true
|
173 |
+
},
|
174 |
+
{
|
175 |
+
"_target_": "ClipBoxToImaged",
|
176 |
+
"box_keys": "box",
|
177 |
+
"label_keys": "label",
|
178 |
+
"box_ref_image_keys": "image",
|
179 |
+
"remove_empty": true
|
180 |
+
},
|
181 |
+
{
|
182 |
+
"_target_": "RandFlipBoxd",
|
183 |
+
"image_keys": "image",
|
184 |
+
"box_keys": "box",
|
185 |
+
"box_ref_image_keys": "image",
|
186 |
+
"prob": 0.5,
|
187 |
+
"spatial_axis": 0
|
188 |
+
},
|
189 |
+
{
|
190 |
+
"_target_": "RandFlipBoxd",
|
191 |
+
"image_keys": "image",
|
192 |
+
"box_keys": "box",
|
193 |
+
"box_ref_image_keys": "image",
|
194 |
+
"prob": 0.5,
|
195 |
+
"spatial_axis": 1
|
196 |
+
},
|
197 |
+
{
|
198 |
+
"_target_": "RandFlipBoxd",
|
199 |
+
"image_keys": "image",
|
200 |
+
"box_keys": "box",
|
201 |
+
"box_ref_image_keys": "image",
|
202 |
+
"prob": 0.5,
|
203 |
+
"spatial_axis": 2
|
204 |
+
},
|
205 |
+
{
|
206 |
+
"_target_": "RandRotateBox90d",
|
207 |
+
"image_keys": "image",
|
208 |
+
"box_keys": "box",
|
209 |
+
"box_ref_image_keys": "image",
|
210 |
+
"prob": 0.75,
|
211 |
+
"max_k": 3,
|
212 |
+
"spatial_axes": [
|
213 |
+
0,
|
214 |
+
1
|
215 |
+
]
|
216 |
+
},
|
217 |
+
{
|
218 |
+
"_target_": "BoxToMaskd",
|
219 |
+
"box_keys": "box",
|
220 |
+
"label_keys": "label",
|
221 |
+
"box_mask_keys": "box_mask",
|
222 |
+
"box_ref_image_keys": "image",
|
223 |
+
"min_fg_label": 0,
|
224 |
+
"ellipse_mask": true
|
225 |
+
},
|
226 |
+
{
|
227 |
+
"_target_": "RandRotated",
|
228 |
+
"keys": [
|
229 |
+
"image",
|
230 |
+
"box_mask"
|
231 |
+
],
|
232 |
+
"mode": [
|
233 |
+
"nearest",
|
234 |
+
"nearest"
|
235 |
+
],
|
236 |
+
"prob": 0.2,
|
237 |
+
"range_x": 0.5236,
|
238 |
+
"range_y": 0.5236,
|
239 |
+
"range_z": 0.5236,
|
240 |
+
"keep_size": true,
|
241 |
+
"padding_mode": "zeros"
|
242 |
+
},
|
243 |
+
{
|
244 |
+
"_target_": "MaskToBoxd",
|
245 |
+
"box_keys": [
|
246 |
+
"box"
|
247 |
+
],
|
248 |
+
"label_keys": [
|
249 |
+
"label"
|
250 |
+
],
|
251 |
+
"box_mask_keys": [
|
252 |
+
"box_mask"
|
253 |
+
],
|
254 |
+
"min_fg_label": 0
|
255 |
+
},
|
256 |
+
{
|
257 |
+
"_target_": "DeleteItemsd",
|
258 |
+
"keys": "box_mask"
|
259 |
+
},
|
260 |
+
{
|
261 |
+
"_target_": "RandGaussianNoised",
|
262 |
+
"keys": "image",
|
263 |
+
"prob": 0.1,
|
264 |
+
"mean": 0.0,
|
265 |
+
"std": 0.1
|
266 |
+
},
|
267 |
+
{
|
268 |
+
"_target_": "RandGaussianSmoothd",
|
269 |
+
"keys": "image",
|
270 |
+
"prob": 0.1,
|
271 |
+
"sigma_x": [
|
272 |
+
0.5,
|
273 |
+
1.0
|
274 |
+
],
|
275 |
+
"sigma_y": [
|
276 |
+
0.5,
|
277 |
+
1.0
|
278 |
+
],
|
279 |
+
"sigma_z": [
|
280 |
+
0.5,
|
281 |
+
1.0
|
282 |
+
]
|
283 |
+
},
|
284 |
+
{
|
285 |
+
"_target_": "RandScaleIntensityd",
|
286 |
+
"keys": "image",
|
287 |
+
"factors": 0.25,
|
288 |
+
"prob": 0.15
|
289 |
+
},
|
290 |
+
{
|
291 |
+
"_target_": "RandShiftIntensityd",
|
292 |
+
"keys": "image",
|
293 |
+
"offsets": 0.1,
|
294 |
+
"prob": 0.15
|
295 |
+
},
|
296 |
+
{
|
297 |
+
"_target_": "RandAdjustContrastd",
|
298 |
+
"keys": "image",
|
299 |
+
"prob": 0.3,
|
300 |
+
"gamma": [
|
301 |
+
0.7,
|
302 |
+
1.5
|
303 |
+
]
|
304 |
+
}
|
305 |
+
],
|
306 |
+
"final_transforms": [
|
307 |
+
{
|
308 |
+
"_target_": "EnsureTyped",
|
309 |
+
"keys": [
|
310 |
+
"image",
|
311 |
+
"box"
|
312 |
+
]
|
313 |
+
},
|
314 |
+
{
|
315 |
+
"_target_": "EnsureTyped",
|
316 |
+
"keys": "label",
|
317 |
+
"dtype": "$torch.long"
|
318 |
+
},
|
319 |
+
{
|
320 |
+
"_target_": "ToTensord",
|
321 |
+
"keys": [
|
322 |
+
"image",
|
323 |
+
"box",
|
324 |
+
"label"
|
325 |
+
]
|
326 |
+
}
|
327 |
+
],
|
328 |
+
"preprocessing": {
|
329 |
+
"_target_": "Compose",
|
330 |
+
"transforms": "$@train#preprocessing_transforms + @train#random_transforms + @train#final_transforms"
|
331 |
+
},
|
332 |
+
"dataset": {
|
333 |
+
"_target_": "Dataset",
|
334 |
+
"data": "$@train_datalist[: int(0.95 * len(@train_datalist))]",
|
335 |
+
"transform": "@train#preprocessing"
|
336 |
+
},
|
337 |
+
"dataloader": {
|
338 |
+
"_target_": "DataLoader",
|
339 |
+
"dataset": "@train#dataset",
|
340 |
+
"batch_size": 1,
|
341 |
+
"shuffle": true,
|
342 |
+
"num_workers": 4,
|
343 |
+
"collate_fn": "$monai.data.utils.no_collation"
|
344 |
+
},
|
345 |
+
"handlers": [
|
346 |
+
{
|
347 |
+
"_target_": "LrScheduleHandler",
|
348 |
+
"lr_scheduler": "@lr_scheduler",
|
349 |
+
"print_lr": true
|
350 |
+
},
|
351 |
+
{
|
352 |
+
"_target_": "ValidationHandler",
|
353 |
+
"validator": "@validate#evaluator",
|
354 |
+
"epoch_level": true,
|
355 |
+
"interval": "@val_interval"
|
356 |
+
},
|
357 |
+
{
|
358 |
+
"_target_": "StatsHandler",
|
359 |
+
"tag_name": "train_loss",
|
360 |
+
"output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)[0]"
|
361 |
+
},
|
362 |
+
{
|
363 |
+
"_target_": "TensorBoardStatsHandler",
|
364 |
+
"log_dir": "@output_dir",
|
365 |
+
"tag_name": "train_loss",
|
366 |
+
"output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)[0]"
|
367 |
+
}
|
368 |
+
],
|
369 |
+
"trainer": {
|
370 |
+
"_target_": "scripts.trainer.DetectionTrainer",
|
371 |
+
"_requires_": "@detector_ops",
|
372 |
+
"max_epochs": "@epochs",
|
373 |
+
"device": "@device",
|
374 |
+
"train_data_loader": "@train#dataloader",
|
375 |
+
"detector": "@detector",
|
376 |
+
"optimizer": "@optimizer",
|
377 |
+
"train_handlers": "@train#handlers",
|
378 |
+
"amp": "@amp"
|
379 |
+
}
|
380 |
+
},
|
381 |
+
"validate": {
|
382 |
+
"preprocessing": {
|
383 |
+
"_target_": "Compose",
|
384 |
+
"transforms": "$@train#preprocessing_transforms + @train#final_transforms"
|
385 |
+
},
|
386 |
+
"dataset": {
|
387 |
+
"_target_": "Dataset",
|
388 |
+
"data": "$@train_datalist[int(0.95 * len(@train_datalist)): ]",
|
389 |
+
"transform": "@validate#preprocessing"
|
390 |
+
},
|
391 |
+
"dataloader": {
|
392 |
+
"_target_": "DataLoader",
|
393 |
+
"dataset": "@validate#dataset",
|
394 |
+
"batch_size": 1,
|
395 |
+
"shuffle": false,
|
396 |
+
"num_workers": 2,
|
397 |
+
"collate_fn": "$monai.data.utils.no_collation"
|
398 |
+
},
|
399 |
+
"inferer": {
|
400 |
+
"_target_": "scripts.detection_inferer.RetinaNetInferer",
|
401 |
+
"detector": "@detector"
|
402 |
+
},
|
403 |
+
"handlers": [
|
404 |
+
{
|
405 |
+
"_target_": "StatsHandler",
|
406 |
+
"iteration_log": false
|
407 |
+
},
|
408 |
+
{
|
409 |
+
"_target_": "TensorBoardStatsHandler",
|
410 |
+
"log_dir": "@output_dir",
|
411 |
+
"iteration_log": false
|
412 |
+
},
|
413 |
+
{
|
414 |
+
"_target_": "CheckpointSaver",
|
415 |
+
"save_dir": "@ckpt_dir",
|
416 |
+
"save_dict": {
|
417 |
+
"model": "@network"
|
418 |
+
},
|
419 |
+
"save_key_metric": true,
|
420 |
+
"key_metric_filename": "model.pt"
|
421 |
+
}
|
422 |
+
],
|
423 |
+
"key_metric": {
|
424 |
+
"val_coco": {
|
425 |
+
"_target_": "scripts.cocometric_ignite.IgniteCocoMetric",
|
426 |
+
"coco_metric_monai": "$monai.apps.detection.metrics.coco.COCOMetric(classes=['nodule'], iou_list=[0.1], max_detection=[100])",
|
427 |
+
"output_transform": "$monai.handlers.from_engine(['pred', 'label'])",
|
428 |
+
"box_key": "box",
|
429 |
+
"label_key": "label",
|
430 |
+
"pred_score_key": "label_scores",
|
431 |
+
"reduce_scalar": true
|
432 |
+
}
|
433 |
+
},
|
434 |
+
"evaluator": {
|
435 |
+
"_target_": "scripts.evaluator.DetectionEvaluator",
|
436 |
+
"_requires_": "@detector_ops",
|
437 |
+
"device": "@device",
|
438 |
+
"val_data_loader": "@validate#dataloader",
|
439 |
+
"network": "@network",
|
440 |
+
"inferer": "@validate#inferer",
|
441 |
+
"key_val_metric": "@validate#key_metric",
|
442 |
+
"val_handlers": "@validate#handlers",
|
443 |
+
"amp": "@amp"
|
444 |
+
}
|
445 |
+
},
|
446 |
+
"initialize": [
|
447 |
+
"$monai.utils.set_determinism(seed=0)"
|
448 |
+
],
|
449 |
+
"run": [
|
450 |
+
"$@train#trainer.run()"
|
451 |
+
]
|
452 |
+
}
|
docs/README.md
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Model Overview
|
2 |
+
A pre-trained model for volumetric (3D) detection of the lung nodule from CT image.
|
3 |
+
|
4 |
+
This model is trained on LUNA16 dataset (https://luna16.grand-challenge.org/Home/), using the RetinaNet (Lin, Tsung-Yi, et al. "Focal loss for dense object detection." ICCV 2017. https://arxiv.org/abs/1708.02002).
|
5 |
+
|
6 |
+

|
7 |
+
|
8 |
+
## Data
|
9 |
+
The dataset we are experimenting in this example is LUNA16 (https://luna16.grand-challenge.org/Home/), which is based on [LIDC-IDRI database](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI) [3,4,5].
|
10 |
+
|
11 |
+
LUNA16 is a public dataset of CT lung nodule detection. Using raw CT scans, the goal is to identify locations of possible nodules, and to assign a probability for being a nodule to each location.
|
12 |
+
|
13 |
+
Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the dataset! We acknowledge the National Cancer Institute and the Foundation for the National Institutes of Health, and their critical role in the creation of the free publicly available LIDC/IDRI Database used in this study.
|
14 |
+
|
15 |
+
### 10-fold data splitting
|
16 |
+
We follow the official 10-fold data splitting from LUNA16 challenge and generate data split json files using the script from [nnDetection](https://github.com/MIC-DKFZ/nnDetection/blob/main/projects/Task016_Luna/scripts/prepare.py).
|
17 |
+
|
18 |
+
Please download the resulted json files from https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/LUNA16_datasplit-20220615T233840Z-001.zip.
|
19 |
+
|
20 |
+
In these files, the values of "box" are the ground truth boxes in world coordinate.
|
21 |
+
|
22 |
+
### Data resampling
|
23 |
+
The raw CT images in LUNA16 have various of voxel sizes. The first step is to resample them to the same voxel size.
|
24 |
+
In this model, we resampled them into 0.703125 x 0.703125 x 1.25 mm.
|
25 |
+
|
26 |
+
Please following the instruction in Section 3.1 of https://github.com/Project-MONAI/tutorials/tree/main/detection to do the resampling.
|
27 |
+
|
28 |
+
### Data download
|
29 |
+
The mhd/raw original data can be downloaded from [LUNA16](https://luna16.grand-challenge.org/Home/). The DICOM original data can be downloaded from [LIDC-IDRI database](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI) [3,4,5]. You will need to resample the original data to start training.
|
30 |
+
|
31 |
+
Alternatively, we provide [resampled nifti images](https://drive.google.com/drive/folders/1JozrufA1VIZWJIc5A1EMV3J4CNCYovKK?usp=share_link) and a copy of [original mhd/raw images](https://drive.google.com/drive/folders/1-enN4eNEnKmjltevKg3W2V-Aj0nriQWE?usp=share_link) from [LUNA16](https://luna16.grand-challenge.org/Home/) for users to download.
|
32 |
+
|
33 |
+
## Training configuration
|
34 |
+
The training was performed with the following:
|
35 |
+
|
36 |
+
- GPU: at least 16GB GPU memory, requires 32G when exporting TRT model
|
37 |
+
- Actual Model Input: 192 x 192 x 80
|
38 |
+
- AMP: True
|
39 |
+
- Optimizer: Adam
|
40 |
+
- Learning Rate: 1e-2
|
41 |
+
- Loss: BCE loss and L1 loss
|
42 |
+
|
43 |
+
### Input
|
44 |
+
1 channel
|
45 |
+
- List of 3D CT patches
|
46 |
+
|
47 |
+
### Output
|
48 |
+
In Training Mode: A dictionary of classification and box regression loss.
|
49 |
+
|
50 |
+
In Evaluation Mode: A list of dictionaries of predicted box, classification label, and classification score.
|
51 |
+
|
52 |
+
## Performance
|
53 |
+
Coco metric is used for evaluating the performance of the model. The pre-trained model was trained and validated on data fold 0. This model achieves a mAP=0.852, mAR=0.998, AP(IoU=0.1)=0.858, AR(IoU=0.1)=1.0.
|
54 |
+
|
55 |
+
Please note that this bundle is non-deterministic because of the max pooling layer used in the network. Therefore, reproducing the training process may not get exactly the same performance.
|
56 |
+
Please refer to https://pytorch.org/docs/stable/notes/randomness.html#reproducibility for more details about reproducibility.
|
57 |
+
|
58 |
+
#### Training Loss
|
59 |
+

|
60 |
+
|
61 |
+
#### Validation Accuracy
|
62 |
+
The validation accuracy in this curve is the mean of mAP, mAR, AP(IoU=0.1), and AR(IoU=0.1) in Coco metric.
|
63 |
+
|
64 |
+

|
65 |
+
|
66 |
+
#### TensorRT speedup
|
67 |
+
The `lung_nodule_ct_detection` bundle supports acceleration with TensorRT through the ONNX-TensorRT method. The table below displays the speedup ratios observed on an A100 80G GPU. Please note that when using the TensorRT model for inference, the `force_sliding_window` parameter in the `inference.json` file must be set to `true`. This ensures that the bundle uses the `SlidingWindowInferer` during inference and maintains the input spatial size of the network. Otherwise, if given an input with spatial size less than the `infer_patch_size`, the input spatial size of the network would be changed.
|
68 |
+
|
69 |
+
| method | torch_fp32(ms) | torch_amp(ms) | trt_fp32(ms) | trt_fp16(ms) | speedup amp | speedup fp32 | speedup fp16 | amp vs fp16|
|
70 |
+
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
|
71 |
+
| model computation | 7449.84 | 996.08 | 976.67 | 626.90 | 7.63 | 7.63 | 11.88 | 1.56 |
|
72 |
+
| end2end | 36458.26 | 7259.35 | 6420.60 | 4698.34 | 5.02 | 5.68 | 7.76 | 1.55 |
|
73 |
+
|
74 |
+
Where:
|
75 |
+
- `model computation` means the speedup ratio of model's inference with a random input without preprocessing and postprocessing
|
76 |
+
- `end2end` means run the bundle end-to-end with the TensorRT based model.
|
77 |
+
- `torch_fp32` and `torch_amp` are for the PyTorch models with or without `amp` mode.
|
78 |
+
- `trt_fp32` and `trt_fp16` are for the TensorRT based models converted in corresponding precision.
|
79 |
+
- `speedup amp`, `speedup fp32` and `speedup fp16` are the speedup ratios of corresponding models versus the PyTorch float32 model
|
80 |
+
- `amp vs fp16` is the speedup ratio between the PyTorch amp model and the TensorRT float16 based model.
|
81 |
+
|
82 |
+
Currently, the only available method to accelerate this model is through ONNX-TensorRT. However, the Torch-TensorRT method is under development and will be available in the near future.
|
83 |
+
|
84 |
+
This result is benchmarked under:
|
85 |
+
- TensorRT: 8.5.3+cuda11.8
|
86 |
+
- Torch-TensorRT Version: 1.4.0
|
87 |
+
- CPU Architecture: x86-64
|
88 |
+
- OS: ubuntu 20.04
|
89 |
+
- Python version:3.8.10
|
90 |
+
- CUDA version: 12.0
|
91 |
+
- GPU models and configuration: A100 80G
|
92 |
+
|
93 |
+
## MONAI Bundle Commands
|
94 |
+
In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle. The CLI supports flexible use cases, such as overriding configs at runtime and predefining arguments in a file.
|
95 |
+
|
96 |
+
For more details usage instructions, visit the [MONAI Bundle Configuration Page](https://docs.monai.io/en/latest/config_syntax.html).
|
97 |
+
|
98 |
+
#### Execute training:
|
99 |
+
|
100 |
+
```
|
101 |
+
python -m monai.bundle run --config_file configs/train.json
|
102 |
+
```
|
103 |
+
|
104 |
+
Please note that if the default dataset path is not modified with the actual path in the bundle config files, you can also override it by using `--dataset_dir`:
|
105 |
+
|
106 |
+
```
|
107 |
+
python -m monai.bundle run --config_file configs/train.json --dataset_dir <actual dataset path>
|
108 |
+
```
|
109 |
+
|
110 |
+
#### Override the `train` config to execute evaluation with the trained model:
|
111 |
+
|
112 |
+
```
|
113 |
+
python -m monai.bundle run --config_file "['configs/train.json','configs/evaluate.json']"
|
114 |
+
```
|
115 |
+
|
116 |
+
#### Execute inference on resampled LUNA16 images by setting `"whether_raw_luna16": false` in `inference.json`:
|
117 |
+
|
118 |
+
```
|
119 |
+
python -m monai.bundle run --config_file configs/inference.json
|
120 |
+
```
|
121 |
+
With the same command, we can execute inference on original LUNA16 images by setting `"whether_raw_luna16": true` in `inference.json`. Remember to also set `"data_list_file_path": "$@bundle_root + '/LUNA16_datasplit/mhd_original/dataset_fold0.json'"` and change `"dataset_dir"`.
|
122 |
+
|
123 |
+
Note that in inference.json, the transform "LoadImaged" in "preprocessing" and "AffineBoxToWorldCoordinated" in "postprocessing" has `"affine_lps_to_ras": true`.
|
124 |
+
This depends on the input images. LUNA16 needs `"affine_lps_to_ras": true`.
|
125 |
+
It is possible that your inference dataset should set `"affine_lps_to_ras": false`.
|
126 |
+
|
127 |
+
#### Export checkpoint to TensorRT based models with fp32 or fp16 precision
|
128 |
+
|
129 |
+
```bash
|
130 |
+
python -m monai.bundle trt_export --net_id network_def --filepath models/model_trt.ts --ckpt_file models/model.pt --meta_file configs/metadata.json --config_file configs/inference.json --precision <fp32/fp16> --input_shape "[1, 1, 512, 512, 192]" --use_onnx "True" --use_trace "True" --onnx_output_names "['output_0', 'output_1', 'output_2', 'output_3', 'output_4', 'output_5']" --network_def#use_list_output "True"
|
131 |
+
```
|
132 |
+
|
133 |
+
#### Execute inference with the TensorRT model
|
134 |
+
|
135 |
+
```
|
136 |
+
python -m monai.bundle run --config_file "['configs/inference.json', 'configs/inference_trt.json']"
|
137 |
+
```
|
138 |
+
|
139 |
+
# References
|
140 |
+
[1] Lin, Tsung-Yi, et al. "Focal loss for dense object detection." ICCV 2017. https://arxiv.org/abs/1708.02002)
|
141 |
+
|
142 |
+
[2] Baumgartner and Jaeger et al. "nnDetection: A self-configuring method for medical object detection." MICCAI 2021. https://arxiv.org/pdf/2106.00817.pdf
|
143 |
+
|
144 |
+
[3] Armato III, S. G., McLennan, G., Bidaut, L., McNitt-Gray, M. F., Meyer, C. R., Reeves, A. P., Zhao, B., Aberle, D. R., Henschke, C. I., Hoffman, E. A., Kazerooni, E. A., MacMahon, H., Van Beek, E. J. R., Yankelevitz, D., Biancardi, A. M., Bland, P. H., Brown, M. S., Engelmann, R. M., Laderach, G. E., Max, D., Pais, R. C. , Qing, D. P. Y. , Roberts, R. Y., Smith, A. R., Starkey, A., Batra, P., Caligiuri, P., Farooqi, A., Gladish, G. W., Jude, C. M., Munden, R. F., Petkovska, I., Quint, L. E., Schwartz, L. H., Sundaram, B., Dodd, L. E., Fenimore, C., Gur, D., Petrick, N., Freymann, J., Kirby, J., Hughes, B., Casteele, A. V., Gupte, S., Sallam, M., Heath, M. D., Kuhn, M. H., Dharaiya, E., Burns, R., Fryd, D. S., Salganicoff, M., Anand, V., Shreter, U., Vastagh, S., Croft, B. Y., Clarke, L. P. (2015). Data From LIDC-IDRI [Data set]. The Cancer Imaging Archive. https://doi.org/10.7937/K9/TCIA.2015.LO9QL9SX
|
145 |
+
|
146 |
+
[4] Armato SG 3rd, McLennan G, Bidaut L, McNitt-Gray MF, Meyer CR, Reeves AP, Zhao B, Aberle DR, Henschke CI, Hoffman EA, Kazerooni EA, MacMahon H, Van Beeke EJ, Yankelevitz D, Biancardi AM, Bland PH, Brown MS, Engelmann RM, Laderach GE, Max D, Pais RC, Qing DP, Roberts RY, Smith AR, Starkey A, Batrah P, Caligiuri P, Farooqi A, Gladish GW, Jude CM, Munden RF, Petkovska I, Quint LE, Schwartz LH, Sundaram B, Dodd LE, Fenimore C, Gur D, Petrick N, Freymann J, Kirby J, Hughes B, Casteele AV, Gupte S, Sallamm M, Heath MD, Kuhn MH, Dharaiya E, Burns R, Fryd DS, Salganicoff M, Anand V, Shreter U, Vastagh S, Croft BY. The Lung Image Database Consortium (LIDC) and Image Database Resource Initiative (IDRI): A completed reference database of lung nodules on CT scans. Medical Physics, 38: 915--931, 2011. DOI: https://doi.org/10.1118/1.3528204
|
147 |
+
|
148 |
+
[5] Clark, K., Vendt, B., Smith, K., Freymann, J., Kirby, J., Koppel, P., Moore, S., Phillips, S., Maffitt, D., Pringle, M., Tarbox, L., & Prior, F. (2013). The Cancer Imaging Archive (TCIA): Maintaining and Operating a Public Information Repository. Journal of Digital Imaging, 26(6), 1045–1057. https://doi.org/10.1007/s10278-013-9622-7
|
149 |
+
|
150 |
+
# License
|
151 |
+
Copyright (c) MONAI Consortium
|
152 |
+
|
153 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
154 |
+
you may not use this file except in compliance with the License.
|
155 |
+
You may obtain a copy of the License at
|
156 |
+
|
157 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
158 |
+
|
159 |
+
Unless required by applicable law or agreed to in writing, software
|
160 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
161 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
162 |
+
See the License for the specific language governing permissions and
|
163 |
+
limitations under the License.
|
docs/data_license.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Third Party Licenses
|
2 |
+
-----------------------------------------------------------------------
|
3 |
+
|
4 |
+
/*********************************************************************/
|
5 |
+
i. LUng Nodule Analysis 2016
|
6 |
+
https://luna16.grand-challenge.org/Home/
|
7 |
+
https://creativecommons.org/licenses/by/4.0/
|
8 |
+
|
9 |
+
ii. Lung Image Database Consortium image collection (LIDC-IDRI)
|
10 |
+
https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI
|
11 |
+
https://creativecommons.org/licenses/by/3.0/
|
models/model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b5e79231466adae93a6fe8e8594029e9add142914e223b879aa0343bb2402d01
|
3 |
+
size 83709381
|
models/model.ts
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:68afd1ed4be8d01196d575d13931dab24cc50d46a74528a47d54496ba29e2583
|
3 |
+
size 83784539
|
scripts/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) MONAI Consortium
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6 |
+
# Unless required by applicable law or agreed to in writing, software
|
7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9 |
+
# See the License for the specific language governing permissions and
|
10 |
+
# limitations under the License.
|
11 |
+
|
12 |
+
# from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator
|
13 |
+
# from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer
|
14 |
+
from .trainer import DetectionTrainer
|
scripts/cocometric_ignite.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Dict, Sequence, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
|
5 |
+
from monai.apps.detection.metrics.coco import COCOMetric
|
6 |
+
from monai.apps.detection.metrics.matching import matching_batch
|
7 |
+
from monai.data import box_utils
|
8 |
+
|
9 |
+
from .utils import detach_to_numpy
|
10 |
+
|
11 |
+
|
12 |
+
class IgniteCocoMetric(Metric):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
coco_metric_monai: Union[None, COCOMetric] = None,
|
16 |
+
box_key="box",
|
17 |
+
label_key="label",
|
18 |
+
pred_score_key="label_scores",
|
19 |
+
output_transform: Callable = lambda x: x,
|
20 |
+
device: Union[str, torch.device, None] = None,
|
21 |
+
reduce_scalar: bool = True,
|
22 |
+
):
|
23 |
+
r"""
|
24 |
+
Computes coco detection metric in Ignite.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
coco_metric_monai: the coco metric in monai.
|
28 |
+
If not given, will asume COCOMetric(classes=[0], iou_list=[0.1], max_detection=[100])
|
29 |
+
box_key: box key in the ground truth target dict and prediction dict.
|
30 |
+
label_key: classification label key in the ground truth target dict and prediction dict.
|
31 |
+
pred_score_key: classification score key in the prediction dict.
|
32 |
+
output_transform: A callable that is used to transform the Engine’s
|
33 |
+
process_function’s output into the form expected by the metric.
|
34 |
+
device: specifies which device updates are accumulated on.
|
35 |
+
Setting the metric’s device to be the same as your update arguments ensures
|
36 |
+
the update method is non-blocking. By default, CPU.
|
37 |
+
reduce_scalar: if True, will return the average value of coc metric values;
|
38 |
+
if False, will return an dictionary of coc metric.
|
39 |
+
|
40 |
+
Examples:
|
41 |
+
To use with ``Engine`` and ``process_function``,
|
42 |
+
simply attach the metric instance to the engine.
|
43 |
+
The output of the engine's ``process_function`` needs to be in format of
|
44 |
+
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.
|
45 |
+
For more information on how metric works with :class:`~ignite.engine.engine.Engine`,
|
46 |
+
visit :ref:`attach-engine`.
|
47 |
+
.. include:: defaults.rst
|
48 |
+
:start-after: :orphan:
|
49 |
+
.. testcode::
|
50 |
+
coco = IgniteCocoMetric()
|
51 |
+
coco.attach(default_evaluator, 'coco')
|
52 |
+
preds = [
|
53 |
+
{
|
54 |
+
'box': torch.Tensor([[1,1,1,2,2,2]]),
|
55 |
+
'label':torch.Tensor([0]),
|
56 |
+
'label_scores':torch.Tensor([0.8])
|
57 |
+
}
|
58 |
+
]
|
59 |
+
target = [{'box': torch.Tensor([[1,1,1,2,2,2]]), 'label':torch.Tensor([0])}]
|
60 |
+
state = default_evaluator.run([[preds, target]])
|
61 |
+
print(state.metrics['coco'])
|
62 |
+
.. testoutput::
|
63 |
+
1.0...
|
64 |
+
.. versionadded:: 0.4.3
|
65 |
+
"""
|
66 |
+
self.box_key = box_key
|
67 |
+
self.label_key = label_key
|
68 |
+
self.pred_score_key = pred_score_key
|
69 |
+
if coco_metric_monai is None:
|
70 |
+
self.coco_metric = COCOMetric(classes=[0], iou_list=[0.1], max_detection=[100])
|
71 |
+
else:
|
72 |
+
self.coco_metric = coco_metric_monai
|
73 |
+
self.reduce_scalar = reduce_scalar
|
74 |
+
|
75 |
+
if device is None:
|
76 |
+
device = torch.device("cpu")
|
77 |
+
super(IgniteCocoMetric, self).__init__(output_transform=output_transform, device=device)
|
78 |
+
|
79 |
+
@reinit__is_reduced
|
80 |
+
def reset(self) -> None:
|
81 |
+
self.val_targets_all = []
|
82 |
+
self.val_outputs_all = []
|
83 |
+
|
84 |
+
@reinit__is_reduced
|
85 |
+
def update(self, output: Sequence[Dict]) -> None:
|
86 |
+
y_pred, y = output[0], output[1]
|
87 |
+
self.val_outputs_all += y_pred
|
88 |
+
self.val_targets_all += y
|
89 |
+
|
90 |
+
@sync_all_reduce("val_targets_all", "val_outputs_all")
|
91 |
+
def compute(self) -> float:
|
92 |
+
self.val_outputs_all = detach_to_numpy(self.val_outputs_all)
|
93 |
+
self.val_targets_all = detach_to_numpy(self.val_targets_all)
|
94 |
+
|
95 |
+
results_metric = matching_batch(
|
96 |
+
iou_fn=box_utils.box_iou,
|
97 |
+
iou_thresholds=self.coco_metric.iou_thresholds,
|
98 |
+
pred_boxes=[val_data_i[self.box_key] for val_data_i in self.val_outputs_all],
|
99 |
+
pred_classes=[val_data_i[self.label_key] for val_data_i in self.val_outputs_all],
|
100 |
+
pred_scores=[val_data_i[self.pred_score_key] for val_data_i in self.val_outputs_all],
|
101 |
+
gt_boxes=[val_data_i[self.box_key] for val_data_i in self.val_targets_all],
|
102 |
+
gt_classes=[val_data_i[self.label_key] for val_data_i in self.val_targets_all],
|
103 |
+
)
|
104 |
+
val_epoch_metric_dict = self.coco_metric(results_metric)[0]
|
105 |
+
|
106 |
+
if self.reduce_scalar:
|
107 |
+
val_epoch_metric = val_epoch_metric_dict.values()
|
108 |
+
val_epoch_metric = sum(val_epoch_metric) / len(val_epoch_metric)
|
109 |
+
return val_epoch_metric
|
110 |
+
else:
|
111 |
+
return val_epoch_metric_dict
|
scripts/detection_inferer.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) MONAI Consortium
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6 |
+
# Unless required by applicable law or agreed to in writing, software
|
7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9 |
+
# See the License for the specific language governing permissions and
|
10 |
+
# limitations under the License.
|
11 |
+
|
12 |
+
from typing import Any, List, Union
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
|
17 |
+
from monai.inferers.inferer import Inferer
|
18 |
+
from torch import Tensor
|
19 |
+
|
20 |
+
|
21 |
+
class RetinaNetInferer(Inferer):
|
22 |
+
"""
|
23 |
+
RetinaNet Inferer takes RetinaNet as input
|
24 |
+
|
25 |
+
Args:
|
26 |
+
detector: the RetinaNetDetector that converts network output BxCxMxN or BxCxMxNxP
|
27 |
+
map into boxes and classification scores.
|
28 |
+
force_sliding_window: whether to force using a SlidingWindowInferer to do the inference.
|
29 |
+
If False, will check the input spatial size to decide whether to simply
|
30 |
+
forward the network or using SlidingWindowInferer.
|
31 |
+
If True, will force using SlidingWindowInferer to do the inference.
|
32 |
+
args: other optional args to be passed to detector.
|
33 |
+
kwargs: other optional keyword args to be passed to detector.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self, detector: RetinaNetDetector, force_sliding_window: bool = False) -> None:
|
37 |
+
Inferer.__init__(self)
|
38 |
+
self.detector = detector
|
39 |
+
self.sliding_window_size = None
|
40 |
+
self.force_sliding_window = force_sliding_window
|
41 |
+
if self.detector.inferer is not None:
|
42 |
+
if hasattr(self.detector.inferer, "roi_size"):
|
43 |
+
self.sliding_window_size = np.prod(self.detector.inferer.roi_size)
|
44 |
+
|
45 |
+
def __call__(self, inputs: Union[List[Tensor], Tensor], network: torch.nn.Module, *args: Any, **kwargs: Any):
|
46 |
+
"""Unified callable function API of Inferers.
|
47 |
+
Args:
|
48 |
+
inputs: model input data for inference.
|
49 |
+
network: target detection network to execute inference.
|
50 |
+
supports callable that fullfilles requirements of network in
|
51 |
+
monai.apps.detection.networks.retinanet_detector.RetinaNetDetector``
|
52 |
+
args: optional args to be passed to ``network``.
|
53 |
+
kwargs: optional keyword args to be passed to ``network``.
|
54 |
+
"""
|
55 |
+
self.detector.network = network
|
56 |
+
self.detector.training = self.detector.network.training
|
57 |
+
|
58 |
+
# if image smaller than sliding window roi size, no need to use sliding window inferer
|
59 |
+
# use sliding window inferer only when image is large
|
60 |
+
use_inferer = (
|
61 |
+
self.force_sliding_window
|
62 |
+
or self.sliding_window_size is not None
|
63 |
+
and not all([data_i[0, ...].numel() < self.sliding_window_size for data_i in inputs])
|
64 |
+
)
|
65 |
+
|
66 |
+
return self.detector(inputs, *args, use_inferer=use_inferer, **kwargs)
|
scripts/detection_saver.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) MONAI Consortium
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6 |
+
# Unless required by applicable law or agreed to in writing, software
|
7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9 |
+
# See the License for the specific language governing permissions and
|
10 |
+
# limitations under the License.
|
11 |
+
|
12 |
+
import json
|
13 |
+
import os
|
14 |
+
import warnings
|
15 |
+
from typing import TYPE_CHECKING, Callable, Optional
|
16 |
+
|
17 |
+
from monai.handlers.classification_saver import ClassificationSaver
|
18 |
+
from monai.utils import IgniteInfo, evenly_divisible_all_gather, min_version, optional_import, string_list_all_gather
|
19 |
+
|
20 |
+
from .utils import detach_to_numpy
|
21 |
+
|
22 |
+
idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")
|
23 |
+
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
|
24 |
+
if TYPE_CHECKING:
|
25 |
+
from ignite.engine import Engine
|
26 |
+
else:
|
27 |
+
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
|
28 |
+
|
29 |
+
|
30 |
+
class DetectionSaver(ClassificationSaver):
|
31 |
+
"""
|
32 |
+
Event handler triggered on completing every iteration to save the classification predictions as json file.
|
33 |
+
If running in distributed data parallel, only saves json file in the specified rank.
|
34 |
+
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
output_dir: str = "./",
|
40 |
+
filename: str = "predictions.json",
|
41 |
+
overwrite: bool = True,
|
42 |
+
batch_transform: Callable = lambda x: x,
|
43 |
+
output_transform: Callable = lambda x: x,
|
44 |
+
name: Optional[str] = None,
|
45 |
+
save_rank: int = 0,
|
46 |
+
pred_box_key: str = "box",
|
47 |
+
pred_label_key: str = "label",
|
48 |
+
pred_score_key: str = "label_scores",
|
49 |
+
) -> None:
|
50 |
+
"""
|
51 |
+
Args:
|
52 |
+
output_dir: if `saver=None`, output json file directory.
|
53 |
+
filename: if `saver=None`, name of the saved json file name.
|
54 |
+
overwrite: if `saver=None`, whether to overwriting existing file content, if True,
|
55 |
+
will clear the file before saving. otherwise, will append new content to the file.
|
56 |
+
batch_transform: a callable that is used to extract the `meta_data` dictionary of
|
57 |
+
the input images from `ignite.engine.state.batch`. the purpose is to get the input
|
58 |
+
filenames from the `meta_data` and store with classification results together.
|
59 |
+
`engine.state` and `batch_transform` inherit from the ignite concept:
|
60 |
+
https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
|
61 |
+
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
|
62 |
+
output_transform: a callable that is used to extract the model prediction data from
|
63 |
+
`ignite.engine.state.output`. the first dimension of its output will be treated as
|
64 |
+
the batch dimension. each item in the batch will be saved individually.
|
65 |
+
`engine.state` and `output_transform` inherit from the ignite concept:
|
66 |
+
https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
|
67 |
+
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
|
68 |
+
name: identifier of logging.logger to use, defaulting to `engine.logger`.
|
69 |
+
save_rank: only the handler on specified rank will save to json file in multi-gpus validation,
|
70 |
+
default to 0.
|
71 |
+
pred_box_key: box key in the prediction dict.
|
72 |
+
pred_label_key: classification label key in the prediction dict.
|
73 |
+
pred_score_key: classification score key in the prediction dict.
|
74 |
+
|
75 |
+
"""
|
76 |
+
super().__init__(
|
77 |
+
output_dir=output_dir,
|
78 |
+
filename=filename,
|
79 |
+
overwrite=overwrite,
|
80 |
+
batch_transform=batch_transform,
|
81 |
+
output_transform=output_transform,
|
82 |
+
name=name,
|
83 |
+
save_rank=save_rank,
|
84 |
+
saver=None,
|
85 |
+
)
|
86 |
+
self.pred_box_key = pred_box_key
|
87 |
+
self.pred_label_key = pred_label_key
|
88 |
+
self.pred_score_key = pred_score_key
|
89 |
+
|
90 |
+
def _finalize(self, _engine: Engine) -> None:
|
91 |
+
"""
|
92 |
+
All gather classification results from ranks and save to json file.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
_engine: Ignite Engine, unused argument.
|
96 |
+
"""
|
97 |
+
ws = idist.get_world_size()
|
98 |
+
if self.save_rank >= ws:
|
99 |
+
raise ValueError("target save rank is greater than the distributed group size.")
|
100 |
+
|
101 |
+
# self._outputs is supposed to be a list of dict
|
102 |
+
# self._outputs[i] should be have at least three keys: pred_box_key, pred_label_key, pred_score_key
|
103 |
+
# self._filenames is supposed to be a list of str
|
104 |
+
outputs = self._outputs
|
105 |
+
filenames = self._filenames
|
106 |
+
if ws > 1:
|
107 |
+
outputs = evenly_divisible_all_gather(outputs, concat=False)
|
108 |
+
filenames = string_list_all_gather(filenames)
|
109 |
+
|
110 |
+
if len(filenames) != len(outputs):
|
111 |
+
warnings.warn(f"filenames length: {len(filenames)} doesn't match outputs length: {len(outputs)}.")
|
112 |
+
|
113 |
+
# save to json file only in the expected rank
|
114 |
+
if idist.get_rank() == self.save_rank:
|
115 |
+
results = [
|
116 |
+
{
|
117 |
+
self.pred_box_key: detach_to_numpy(o[self.pred_box_key]).tolist(),
|
118 |
+
self.pred_label_key: detach_to_numpy(o[self.pred_label_key]).tolist(),
|
119 |
+
self.pred_score_key: detach_to_numpy(o[self.pred_score_key]).tolist(),
|
120 |
+
"image": f,
|
121 |
+
}
|
122 |
+
for o, f in zip(outputs, filenames)
|
123 |
+
]
|
124 |
+
|
125 |
+
with open(os.path.join(self.output_dir, self.filename), "w") as outfile:
|
126 |
+
json.dump(results, outfile, indent=4)
|
scripts/evaluator.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) MONAI Consortium
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6 |
+
# Unless required by applicable law or agreed to in writing, software
|
7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9 |
+
# See the License for the specific language governing permissions and
|
10 |
+
# limitations under the License.
|
11 |
+
|
12 |
+
from __future__ import annotations
|
13 |
+
|
14 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
|
18 |
+
from monai.engines.evaluator import SupervisedEvaluator
|
19 |
+
from monai.engines.utils import IterationEvents, default_metric_cmp_fn
|
20 |
+
from monai.transforms import Transform
|
21 |
+
from monai.utils import ForwardMode, IgniteInfo, min_version, optional_import
|
22 |
+
from monai.utils.enums import CommonKeys as Keys
|
23 |
+
from torch.utils.data import DataLoader
|
24 |
+
|
25 |
+
from .detection_inferer import RetinaNetInferer
|
26 |
+
|
27 |
+
if TYPE_CHECKING:
|
28 |
+
from ignite.engine import Engine, EventEnum
|
29 |
+
from ignite.metrics import Metric
|
30 |
+
else:
|
31 |
+
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
|
32 |
+
Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
|
33 |
+
EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
|
34 |
+
|
35 |
+
__all__ = ["DetectionEvaluator"]
|
36 |
+
|
37 |
+
|
38 |
+
def detection_prepare_val_batch(
|
39 |
+
batchdata: List[Dict[str, torch.Tensor]],
|
40 |
+
device: Optional[Union[str, torch.device]] = None,
|
41 |
+
non_blocking: bool = False,
|
42 |
+
**kwargs,
|
43 |
+
) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
|
44 |
+
"""
|
45 |
+
Default function to prepare the data for current iteration.
|
46 |
+
Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
|
47 |
+
https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
|
48 |
+
`kwargs` supports other args for `Tensor.to()` API.
|
49 |
+
Returns:
|
50 |
+
image, label(optional).
|
51 |
+
"""
|
52 |
+
inputs = [
|
53 |
+
batch_data_i["image"].to(device=device, non_blocking=non_blocking, **kwargs) for batch_data_i in batchdata
|
54 |
+
]
|
55 |
+
|
56 |
+
if isinstance(batchdata[0].get(Keys.LABEL), torch.Tensor):
|
57 |
+
targets = [
|
58 |
+
dict(
|
59 |
+
label=batch_data_i["label"].to(device=device, non_blocking=non_blocking, **kwargs),
|
60 |
+
box=batch_data_i["box"].to(device=device, non_blocking=non_blocking, **kwargs),
|
61 |
+
)
|
62 |
+
for batch_data_i in batchdata
|
63 |
+
]
|
64 |
+
return (inputs, targets)
|
65 |
+
return inputs, None
|
66 |
+
|
67 |
+
|
68 |
+
class DetectionEvaluator(SupervisedEvaluator):
|
69 |
+
"""
|
70 |
+
Supervised detection evaluation method with image and label, inherits from ``SupervisedEvaluator`` and ``Workflow``.
|
71 |
+
Args:
|
72 |
+
device: an object representing the device on which to run.
|
73 |
+
val_data_loader: Ignite engine use data_loader to run, must be Iterable, typically be torch.DataLoader.
|
74 |
+
network: detector to evaluate in the evaluator, should be regular PyTorch `torch.nn.Module`.
|
75 |
+
epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
|
76 |
+
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
|
77 |
+
with respect to the host. For other cases, this argument has no effect.
|
78 |
+
prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
|
79 |
+
from `engine.state.batch` for every iteration, for more details please refer to:
|
80 |
+
https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
|
81 |
+
iteration_update: the callable function for every iteration, expect to accept `engine`
|
82 |
+
and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
|
83 |
+
if not provided, use `self._iteration()` instead. for more details please refer to:
|
84 |
+
https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
|
85 |
+
inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
|
86 |
+
postprocessing: execute additional transformation for the model output data.
|
87 |
+
Typically, several Tensor based transforms composed by `Compose`.
|
88 |
+
key_val_metric: compute metric when every iteration completed, and save average value to
|
89 |
+
engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
|
90 |
+
checkpoint into files.
|
91 |
+
additional_metrics: more Ignite metrics that also attach to Ignite Engine.
|
92 |
+
metric_cmp_fn: function to compare current key metric with previous best key metric value,
|
93 |
+
it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
|
94 |
+
`best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
|
95 |
+
val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
|
96 |
+
CheckpointHandler, StatsHandler, etc.
|
97 |
+
amp: whether to enable auto-mixed-precision evaluation, default is False.
|
98 |
+
mode: model forward mode during evaluation, should be 'eval' or 'train',
|
99 |
+
which maps to `model.eval()` or `model.train()`, default to 'eval'.
|
100 |
+
event_names: additional custom ignite events that will register to the engine.
|
101 |
+
new events can be a list of str or `ignite.engine.events.EventEnum`.
|
102 |
+
event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
|
103 |
+
for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
|
104 |
+
#ignite.engine.engine.Engine.register_events.
|
105 |
+
decollate: whether to decollate the batch-first data to a list of data after model computation,
|
106 |
+
recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
|
107 |
+
default to `True`.
|
108 |
+
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
|
109 |
+
`device`, `non_blocking`.
|
110 |
+
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
|
111 |
+
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
|
112 |
+
"""
|
113 |
+
|
114 |
+
def __init__(
|
115 |
+
self,
|
116 |
+
device: torch.device,
|
117 |
+
val_data_loader: Iterable | DataLoader,
|
118 |
+
network: RetinaNetDetector,
|
119 |
+
epoch_length: int | None = None,
|
120 |
+
non_blocking: bool = False,
|
121 |
+
prepare_batch: Callable = detection_prepare_val_batch,
|
122 |
+
iteration_update: Callable[[Engine, Any], Any] | None = None,
|
123 |
+
inferer: RetinaNetInferer | None = None,
|
124 |
+
postprocessing: Transform | None = None,
|
125 |
+
key_val_metric: dict[str, Metric] | None = None,
|
126 |
+
additional_metrics: dict[str, Metric] | None = None,
|
127 |
+
metric_cmp_fn: Callable = default_metric_cmp_fn,
|
128 |
+
val_handlers: Sequence | None = None,
|
129 |
+
amp: bool = False,
|
130 |
+
mode: ForwardMode | str = ForwardMode.EVAL,
|
131 |
+
event_names: list[str | EventEnum] | None = None,
|
132 |
+
event_to_attr: dict | None = None,
|
133 |
+
decollate: bool = True,
|
134 |
+
to_kwargs: dict | None = None,
|
135 |
+
amp_kwargs: dict | None = None,
|
136 |
+
) -> None:
|
137 |
+
super().__init__(
|
138 |
+
device=device,
|
139 |
+
val_data_loader=val_data_loader,
|
140 |
+
network=network,
|
141 |
+
epoch_length=epoch_length,
|
142 |
+
non_blocking=non_blocking,
|
143 |
+
prepare_batch=prepare_batch,
|
144 |
+
iteration_update=iteration_update,
|
145 |
+
inferer=inferer,
|
146 |
+
postprocessing=postprocessing,
|
147 |
+
key_val_metric=key_val_metric,
|
148 |
+
additional_metrics=additional_metrics,
|
149 |
+
metric_cmp_fn=metric_cmp_fn,
|
150 |
+
val_handlers=val_handlers,
|
151 |
+
amp=amp,
|
152 |
+
mode=mode,
|
153 |
+
event_names=event_names,
|
154 |
+
event_to_attr=event_to_attr,
|
155 |
+
decollate=decollate,
|
156 |
+
to_kwargs=to_kwargs,
|
157 |
+
amp_kwargs=amp_kwargs,
|
158 |
+
)
|
159 |
+
|
160 |
+
def _register_decollate(self):
|
161 |
+
"""
|
162 |
+
Register the decollate operation for batch data, will execute after model forward and loss forward.
|
163 |
+
"""
|
164 |
+
|
165 |
+
@self.on(IterationEvents.MODEL_COMPLETED)
|
166 |
+
def _decollate_data(engine: Engine) -> None:
|
167 |
+
output_list = []
|
168 |
+
for i in range(len(engine.state.output[Keys.IMAGE])):
|
169 |
+
output_list.append({})
|
170 |
+
for k in engine.state.output.keys():
|
171 |
+
if engine.state.output[k] is not None:
|
172 |
+
output_list[i][k] = engine.state.output[k][i]
|
173 |
+
engine.state.output = output_list
|
scripts/trainer.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) MONAI Consortium
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6 |
+
# Unless required by applicable law or agreed to in writing, software
|
7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9 |
+
# See the License for the specific language governing permissions and
|
10 |
+
# limitations under the License.
|
11 |
+
|
12 |
+
from __future__ import annotations
|
13 |
+
|
14 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from monai.engines.trainer import Trainer
|
18 |
+
from monai.engines.utils import IterationEvents, default_metric_cmp_fn
|
19 |
+
from monai.inferers import Inferer
|
20 |
+
from monai.transforms import Transform
|
21 |
+
from monai.utils import IgniteInfo, min_version, optional_import
|
22 |
+
from monai.utils.enums import CommonKeys as Keys
|
23 |
+
from torch.optim.optimizer import Optimizer
|
24 |
+
from torch.utils.data import DataLoader
|
25 |
+
|
26 |
+
if TYPE_CHECKING:
|
27 |
+
from ignite.engine import Engine, EventEnum
|
28 |
+
from ignite.metrics import Metric
|
29 |
+
else:
|
30 |
+
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
|
31 |
+
Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
|
32 |
+
EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
|
33 |
+
|
34 |
+
__all__ = ["DetectionTrainer"]
|
35 |
+
|
36 |
+
|
37 |
+
def detection_prepare_batch(
|
38 |
+
batchdata: List[Dict[str, torch.Tensor]],
|
39 |
+
device: Optional[Union[str, torch.device]] = None,
|
40 |
+
non_blocking: bool = False,
|
41 |
+
**kwargs,
|
42 |
+
) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
|
43 |
+
"""
|
44 |
+
Default function to prepare the data for current iteration.
|
45 |
+
Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
|
46 |
+
https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
|
47 |
+
`kwargs` supports other args for `Tensor.to()` API.
|
48 |
+
Returns:
|
49 |
+
image, label(optional).
|
50 |
+
"""
|
51 |
+
inputs = [
|
52 |
+
batch_data_ii["image"].to(device=device, non_blocking=non_blocking, **kwargs)
|
53 |
+
for batch_data_i in batchdata
|
54 |
+
for batch_data_ii in batch_data_i
|
55 |
+
]
|
56 |
+
|
57 |
+
if isinstance(batchdata[0][0].get(Keys.LABEL), torch.Tensor):
|
58 |
+
targets = [
|
59 |
+
dict(
|
60 |
+
label=batch_data_ii["label"].to(device=device, non_blocking=non_blocking, **kwargs),
|
61 |
+
box=batch_data_ii["box"].to(device=device, non_blocking=non_blocking, **kwargs),
|
62 |
+
)
|
63 |
+
for batch_data_i in batchdata
|
64 |
+
for batch_data_ii in batch_data_i
|
65 |
+
]
|
66 |
+
return (inputs, targets)
|
67 |
+
return inputs, None
|
68 |
+
|
69 |
+
|
70 |
+
class DetectionTrainer(Trainer):
|
71 |
+
"""
|
72 |
+
Supervised detection training method with image and label, inherits from ``Trainer`` and ``Workflow``.
|
73 |
+
Args:
|
74 |
+
device: an object representing the device on which to run.
|
75 |
+
max_epochs: the total epoch number for trainer to run.
|
76 |
+
train_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader.
|
77 |
+
detector: detector to train in the trainer, should be regular PyTorch `torch.nn.Module`.
|
78 |
+
optimizer: the optimizer associated to the detector, should be regular PyTorch optimizer from `torch.optim`
|
79 |
+
or its subclass.
|
80 |
+
epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
|
81 |
+
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
|
82 |
+
with respect to the host. For other cases, this argument has no effect.
|
83 |
+
prepare_batch: function to parse expected data (usually `image`,`box`, `label` and other detector args)
|
84 |
+
from `engine.state.batch` for every iteration, for more details please refer to:
|
85 |
+
https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
|
86 |
+
iteration_update: the callable function for every iteration, expect to accept `engine`
|
87 |
+
and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
|
88 |
+
if not provided, use `self._iteration()` instead. for more details please refer to:
|
89 |
+
https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
|
90 |
+
inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
|
91 |
+
postprocessing: execute additional transformation for the model output data.
|
92 |
+
Typically, several Tensor based transforms composed by `Compose`.
|
93 |
+
key_train_metric: compute metric when every iteration completed, and save average value to
|
94 |
+
engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
|
95 |
+
checkpoint into files.
|
96 |
+
additional_metrics: more Ignite metrics that also attach to Ignite Engine.
|
97 |
+
metric_cmp_fn: function to compare current key metric with previous best key metric value,
|
98 |
+
it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
|
99 |
+
`best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
|
100 |
+
train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
|
101 |
+
CheckpointHandler, StatsHandler, etc.
|
102 |
+
amp: whether to enable auto-mixed-precision training, default is False.
|
103 |
+
event_names: additional custom ignite events that will register to the engine.
|
104 |
+
new events can be a list of str or `ignite.engine.events.EventEnum`.
|
105 |
+
event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
|
106 |
+
for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
|
107 |
+
#ignite.engine.engine.Engine.register_events.
|
108 |
+
decollate: whether to decollate the batch-first data to a list of data after model computation,
|
109 |
+
recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
|
110 |
+
default to `True`.
|
111 |
+
optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
|
112 |
+
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
|
113 |
+
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
|
114 |
+
`device`, `non_blocking`.
|
115 |
+
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
|
116 |
+
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
|
117 |
+
"""
|
118 |
+
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
device: torch.device,
|
122 |
+
max_epochs: int,
|
123 |
+
train_data_loader: Iterable | DataLoader,
|
124 |
+
detector: torch.nn.Module,
|
125 |
+
optimizer: Optimizer,
|
126 |
+
epoch_length: int | None = None,
|
127 |
+
non_blocking: bool = False,
|
128 |
+
prepare_batch: Callable = detection_prepare_batch,
|
129 |
+
iteration_update: Callable[[Engine, Any], Any] | None = None,
|
130 |
+
inferer: Inferer | None = None,
|
131 |
+
postprocessing: Transform | None = None,
|
132 |
+
key_train_metric: dict[str, Metric] | None = None,
|
133 |
+
additional_metrics: dict[str, Metric] | None = None,
|
134 |
+
metric_cmp_fn: Callable = default_metric_cmp_fn,
|
135 |
+
train_handlers: Sequence | None = None,
|
136 |
+
amp: bool = False,
|
137 |
+
event_names: list[str | EventEnum] | None = None,
|
138 |
+
event_to_attr: dict | None = None,
|
139 |
+
decollate: bool = True,
|
140 |
+
optim_set_to_none: bool = False,
|
141 |
+
to_kwargs: dict | None = None,
|
142 |
+
amp_kwargs: dict | None = None,
|
143 |
+
) -> None:
|
144 |
+
super().__init__(
|
145 |
+
device=device,
|
146 |
+
max_epochs=max_epochs,
|
147 |
+
data_loader=train_data_loader,
|
148 |
+
epoch_length=epoch_length,
|
149 |
+
non_blocking=non_blocking,
|
150 |
+
prepare_batch=prepare_batch,
|
151 |
+
iteration_update=iteration_update,
|
152 |
+
postprocessing=postprocessing,
|
153 |
+
key_metric=key_train_metric,
|
154 |
+
additional_metrics=additional_metrics,
|
155 |
+
metric_cmp_fn=metric_cmp_fn,
|
156 |
+
handlers=train_handlers,
|
157 |
+
amp=amp,
|
158 |
+
event_names=event_names,
|
159 |
+
event_to_attr=event_to_attr,
|
160 |
+
decollate=decollate,
|
161 |
+
to_kwargs=to_kwargs,
|
162 |
+
amp_kwargs=amp_kwargs,
|
163 |
+
)
|
164 |
+
|
165 |
+
self.detector = detector
|
166 |
+
self.optimizer = optimizer
|
167 |
+
self.optim_set_to_none = optim_set_to_none
|
168 |
+
|
169 |
+
def _iteration(self, engine, batchdata: dict[str, torch.Tensor]):
|
170 |
+
"""
|
171 |
+
Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
|
172 |
+
Return below items in a dictionary:
|
173 |
+
- IMAGE: image Tensor data for model input, already moved to device.
|
174 |
+
- BOX: box regression loss corresponding to the image, already moved to device.
|
175 |
+
- LABEL: classification loss corresponding to the image, already moved to device.
|
176 |
+
- LOSS: weighted sum of loss values computed by loss function.
|
177 |
+
Args:
|
178 |
+
engine: `DetectionTrainer` to execute operation for an iteration.
|
179 |
+
batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
|
180 |
+
Raises:
|
181 |
+
ValueError: When ``batchdata`` is None.
|
182 |
+
"""
|
183 |
+
|
184 |
+
if batchdata is None:
|
185 |
+
raise ValueError("Must provide batch data for current iteration.")
|
186 |
+
|
187 |
+
batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
|
188 |
+
if len(batch) == 2:
|
189 |
+
inputs, targets = batch
|
190 |
+
args: tuple = ()
|
191 |
+
kwargs: dict = {}
|
192 |
+
else:
|
193 |
+
inputs, targets, args, kwargs = batch
|
194 |
+
# put iteration outputs into engine.state
|
195 |
+
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
|
196 |
+
|
197 |
+
def _compute_pred_loss(w_cls: float = 1.0, w_box_reg: float = 1.0):
|
198 |
+
"""
|
199 |
+
Args:
|
200 |
+
w_cls: weight of classification loss
|
201 |
+
w_box_reg: weight of box regression loss
|
202 |
+
"""
|
203 |
+
outputs = engine.detector(inputs, targets)
|
204 |
+
engine.state.output[engine.detector.cls_key] = outputs[engine.detector.cls_key]
|
205 |
+
engine.state.output[engine.detector.box_reg_key] = outputs[engine.detector.box_reg_key]
|
206 |
+
engine.state.output[Keys.LOSS] = (
|
207 |
+
w_cls * outputs[engine.detector.cls_key] + w_box_reg * outputs[engine.detector.box_reg_key]
|
208 |
+
)
|
209 |
+
engine.fire_event(IterationEvents.LOSS_COMPLETED)
|
210 |
+
|
211 |
+
engine.detector.train()
|
212 |
+
engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
|
213 |
+
|
214 |
+
if engine.amp and engine.scaler is not None:
|
215 |
+
with torch.cuda.amp.autocast(**engine.amp_kwargs):
|
216 |
+
inputs = [img.to(torch.float16) for img in inputs]
|
217 |
+
_compute_pred_loss()
|
218 |
+
engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
|
219 |
+
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
|
220 |
+
engine.scaler.step(engine.optimizer)
|
221 |
+
engine.scaler.update()
|
222 |
+
else:
|
223 |
+
_compute_pred_loss()
|
224 |
+
engine.state.output[Keys.LOSS].backward()
|
225 |
+
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
|
226 |
+
engine.optimizer.step()
|
227 |
+
|
228 |
+
return engine.state.output
|
scripts/utils.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
def detach_to_numpy(data: Union[List, Dict, torch.Tensor]) -> Union[List, Dict, torch.Tensor]:
|
8 |
+
"""
|
9 |
+
Recursively detach elements in data
|
10 |
+
"""
|
11 |
+
if isinstance(data, torch.Tensor):
|
12 |
+
return data.cpu().detach().numpy() # pytype: disable=attribute-error
|
13 |
+
|
14 |
+
elif isinstance(data, np.ndarray):
|
15 |
+
return data
|
16 |
+
|
17 |
+
elif isinstance(data, list):
|
18 |
+
return [detach_to_numpy(d) for d in data]
|
19 |
+
|
20 |
+
elif isinstance(data, dict):
|
21 |
+
for k in data.keys():
|
22 |
+
data[k] = detach_to_numpy(data[k])
|
23 |
+
return data
|
24 |
+
|
25 |
+
else:
|
26 |
+
raise ValueError("data should be tensor, numpy array, dict, or list.")
|
scripts/warmup_scheduler.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) MONAI Consortium
|
2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License.
|
4 |
+
# You may obtain a copy of the License at
|
5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6 |
+
# Unless required by applicable law or agreed to in writing, software
|
7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9 |
+
# See the License for the specific language governing permissions and
|
10 |
+
# limitations under the License.
|
11 |
+
|
12 |
+
"""
|
13 |
+
This script is adapted from
|
14 |
+
https://github.com/ildoonet/pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py
|
15 |
+
"""
|
16 |
+
|
17 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
|
18 |
+
|
19 |
+
|
20 |
+
class GradualWarmupScheduler(_LRScheduler):
|
21 |
+
"""Gradually warm-up(increasing) learning rate in optimizer.
|
22 |
+
Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
optimizer (Optimizer): Wrapped optimizer.
|
26 |
+
multiplier: target learning rate = base lr * multiplier if multiplier > 1.0.
|
27 |
+
if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
|
28 |
+
total_epoch: target learning rate is reached at total_epoch, gradually
|
29 |
+
after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
|
33 |
+
self.multiplier = multiplier
|
34 |
+
if self.multiplier < 1.0:
|
35 |
+
raise ValueError("multiplier should be greater thant or equal to 1.")
|
36 |
+
self.total_epoch = total_epoch
|
37 |
+
self.after_scheduler = after_scheduler
|
38 |
+
self.finished = False
|
39 |
+
super(GradualWarmupScheduler, self).__init__(optimizer)
|
40 |
+
|
41 |
+
def get_lr(self):
|
42 |
+
self.last_epoch = max(1, self.last_epoch) # to avoid epoch=0 thus lr=0
|
43 |
+
if self.last_epoch > self.total_epoch:
|
44 |
+
if self.after_scheduler:
|
45 |
+
if not self.finished:
|
46 |
+
self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
|
47 |
+
self.finished = True
|
48 |
+
return self.after_scheduler.get_last_lr()
|
49 |
+
return [base_lr * self.multiplier for base_lr in self.base_lrs]
|
50 |
+
|
51 |
+
if self.multiplier == 1.0:
|
52 |
+
return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
|
53 |
+
else:
|
54 |
+
return [
|
55 |
+
base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
|
56 |
+
for base_lr in self.base_lrs
|
57 |
+
]
|
58 |
+
|
59 |
+
def step_reduce_lr_on_plateau(self, metrics, epoch=None):
|
60 |
+
if epoch is None:
|
61 |
+
epoch = self.last_epoch + 1
|
62 |
+
self.last_epoch = (
|
63 |
+
epoch if epoch != 0 else 1
|
64 |
+
) # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
|
65 |
+
if self.last_epoch <= self.total_epoch:
|
66 |
+
warmup_lr = [
|
67 |
+
base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
|
68 |
+
for base_lr in self.base_lrs
|
69 |
+
]
|
70 |
+
for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
|
71 |
+
param_group["lr"] = lr
|
72 |
+
else:
|
73 |
+
if epoch is None:
|
74 |
+
self.after_scheduler.step(metrics, None)
|
75 |
+
else:
|
76 |
+
self.after_scheduler.step(metrics, epoch - self.total_epoch)
|
77 |
+
|
78 |
+
def step(self, epoch=None, metrics=None):
|
79 |
+
if not isinstance(self.after_scheduler, ReduceLROnPlateau):
|
80 |
+
if self.finished and self.after_scheduler:
|
81 |
+
if epoch is None:
|
82 |
+
self.after_scheduler.step(None)
|
83 |
+
else:
|
84 |
+
self.after_scheduler.step(epoch - self.total_epoch)
|
85 |
+
self._last_lr = self.after_scheduler.get_last_lr()
|
86 |
+
else:
|
87 |
+
return super(GradualWarmupScheduler, self).step(epoch)
|
88 |
+
else:
|
89 |
+
self.step_reduce_lr_on_plateau(metrics, epoch)
|