project-monai committed · verified
Commit 38fd365 · 1 Parent(s): a4a37b1

Upload classification_template version 0.0.3

LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 MONAI Consortium
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
configs/evaluate.yaml ADDED
@@ -0,0 +1,38 @@
+ # This implements the workflow for evaluating the network with metrics; it is applied as overrides on top of train.yaml.
+
+ # these transforms are applied to the network output and label before metrics are computed
+ transforms:
+ - _target_: AsDiscreted
+   keys: ['@pred', '@label']
+   argmax: [true, false]
+   to_onehot: '@num_classes'
+ - _target_: ToTensord
+   keys: ['@pred', '@label']
+   device: '@device'
+
+ postprocessing:
+   _target_: Compose
+   transforms: $@transforms
+
+ # inference handlers to load checkpoint, gather statistics
+ val_handlers:
+ - _target_: CheckpointLoader
+   _disabled_: $not os.path.exists(@ckpt_path)
+   load_path: '@ckpt_path'
+   load_dict:
+     model: '@network'
+ - _target_: StatsHandler
+   name: null # use engine.logger as the Logger object to log to
+   output_transform: '$lambda x: None'
+ - _target_: MetricsSaver
+   save_dir: '@output_dir'
+   metrics: ['val_accuracy']
+   metric_details: ['val_accuracy']
+   batch_transform: "$lambda x: [xx['image'].meta for xx in x]"
+   summary_ops: "*"
+
+ initialize:
+ - "$monai.utils.set_determinism(seed=123)"
+ - "$setattr(torch.backends.cudnn, 'benchmark', True)"
+ run:
+ - [email protected]()
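
Like the other workflow files, `evaluate.yaml` is not run on its own: it supplies overrides (postprocessing, handlers, the `run` expression) that are layered on top of `train.yaml`. The invocation below is the same one given in the README further down, assuming the working directory is the bundle root and a trained checkpoint exists at `models/model.pt`:

```
python -m monai.bundle run --config_file "['configs/train.yaml','configs/evaluate.yaml']"
```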
configs/inference.yaml ADDED
@@ -0,0 +1,115 @@
+ # This implements the workflow for applying the network to a set of test images listed in a JSON file and measuring network performance with metrics.
+
+ imports:
+ - $import os
+ - $import json
+ - $import torch
+ - $import glob
+
+ # pull out some constants from MONAI
+ image: $monai.utils.CommonKeys.IMAGE
+ label: $monai.utils.CommonKeys.LABEL
+ pred: $monai.utils.CommonKeys.PRED
+
+ # hyperparameters for you to modify on the command line
+ batch_size: 1 # number of images per batch
+ num_workers: 0 # number of workers to generate batches with
+ num_classes: 4 # number of classes in the training data which the network should predict
+ device: $torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+ # define various paths
+ bundle_root: . # root directory of the bundle
+ ckpt_path: $@bundle_root + '/models/model.pt' # checkpoint to load before starting
+ dataset_dir: $@bundle_root + '/data/test_data' # where data is coming from
+
+ # network definition, this could be parameterised by pre-defined values or on the command line
+ network_def:
+   _target_: DenseNet121
+   spatial_dims: 2
+   in_channels: 1
+   out_channels: '@num_classes'
+ network: $@network_def.to(@device)
+
+ # JSON file listing the test images and their labels
+ test_json: "$@bundle_root+'/data/test_samples.json'"
+ test_fp: "$open(@test_json,'r', encoding='utf8')"
+ # load json file
+ test_dict: "$json.load(@test_fp)"
+
+ # these transforms are used for inference to load and regularise inputs
+ transforms:
+ - _target_: LoadImaged
+   keys: '@image'
+ - _target_: EnsureChannelFirstd
+   keys: '@image'
+ - _target_: ScaleIntensityd
+   keys: '@image'
+
+ preprocessing:
+   _target_: Compose
+   transforms: $@transforms
+
+ dataset:
+   _target_: Dataset
+   data: '@test_dict'
+   transform: '@preprocessing'
+
+ dataloader:
+   _target_: ThreadDataLoader # generate data asynchronously from inference
+   dataset: '@dataset'
+   batch_size: '@batch_size'
+   num_workers: '@num_workers'
+
+ # should be replaced with other inferer types if the inference process is different for your network
+ inferer:
+   _target_: SimpleInferer
+
+ # transform to apply to data from network to be suitable for validation
+ postprocessing:
+   _target_: Compose
+   transforms:
+   - _target_: Activationsd
+     keys: '@pred'
+     softmax: true
+   - _target_: AsDiscreted
+     keys: ['@pred', '@label']
+     argmax: [true, false]
+     to_onehot: '@num_classes'
+   - _target_: ToTensord
+     keys: ['@pred', '@label']
+     device: '@device'
+
+ # inference handlers to load checkpoint, gather statistics
+ val_handlers:
+ - _target_: CheckpointLoader
+   _disabled_: $not os.path.exists(@ckpt_path)
+   load_path: '@ckpt_path'
+   load_dict:
+     model: '@network'
+ - _target_: StatsHandler
+   name: null # use engine.logger as the Logger object to log to
+   output_transform: '$lambda x: None'
+
+ # engine for running inference, ties together objects defined above and has metric definitions
+ evaluator:
+   _target_: SupervisedEvaluator
+   device: '@device'
+   val_data_loader: '@dataloader'
+   network: '@network'
+   inferer: '@inferer'
+   postprocessing: '@postprocessing'
+   key_val_metric:
+     val_accuracy:
+       _target_: ignite.metrics.Accuracy
+       output_transform: $monai.handlers.from_engine([@pred, @label])
+   additional_metrics:
+     val_f1: # can have other metrics
+       _target_: ConfusionMatrix
+       metric_name: 'f1 score'
+       output_transform: $monai.handlers.from_engine([@pred, @label])
+   val_handlers: '@val_handlers'
+
+ initialize:
+ - "$setattr(torch.backends.cudnn, 'benchmark', True)"
+ run:
+ - [email protected]()
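
The `@` references and `$` expressions used throughout these files are resolved by MONAI's `ConfigParser` when the bundle is run. As a minimal sketch of how individual components can be pulled out programmatically (assuming the working directory is the bundle root; the full workflow is normally launched with `python -m monai.bundle run` as shown in the README):

```
from monai.bundle import ConfigParser

parser = ConfigParser()
parser.read_config("configs/inference.yaml")

# '@' references and '$' expressions are evaluated when content is requested
network = parser.get_parsed_content("network")              # DenseNet121 instance moved to the configured device
preprocessing = parser.get_parsed_content("preprocessing")  # Compose of LoadImaged, EnsureChannelFirstd, ScaleIntensityd
print(type(network).__name__, len(preprocessing.transforms))
```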
configs/logging.conf ADDED
@@ -0,0 +1,21 @@
+ [loggers]
+ keys=root
+
+ [handlers]
+ keys=consoleHandler
+
+ [formatters]
+ keys=fullFormatter
+
+ [logger_root]
+ level=INFO
+ handlers=consoleHandler
+
+ [handler_consoleHandler]
+ class=StreamHandler
+ level=INFO
+ formatter=fullFormatter
+ args=(sys.stdout,)
+
+ [formatter_fullFormatter]
+ format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
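
This is a standard `logging.config` file configuring a console handler and formatter for the root logger. A sketch of pointing a bundle command at it, assuming the `--logging_file` option of `monai.bundle run` and the bundle root as the working directory:

```
python -m monai.bundle run --config_file configs/train.yaml --logging_file configs/logging.conf
```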
configs/metadata.json ADDED
@@ -0,0 +1,67 @@
+ {
+     "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
+     "version": "0.0.3",
+     "changelog": {
+         "0.0.3": "update to huggingface hosting",
+         "0.0.2": "update large file yml",
+         "0.0.1": "Initial version"
+     },
+     "monai_version": "1.4.0",
+     "pytorch_version": "2.4.0",
+     "numpy_version": "1.24.4",
+     "required_packages_version": {
+         "pytorch-ignite": "0.4.11",
+         "pyyaml": "6.0.2"
+     },
+     "supported_apps": {},
+     "name": "Classification Template",
+     "task": "Classification Template in 2D images",
+     "description": "This is a template bundle for classification in 2D; take this as a basis for your own bundles.",
+     "authors": "Yun Liu",
+     "copyright": "Copyright (c) 2023 MONAI Consortium",
+     "network_data_format": {
+         "inputs": {
+             "image": {
+                 "type": "image",
+                 "format": "magnitude",
+                 "modality": "none",
+                 "num_channels": 1,
+                 "spatial_shape": [
+                     128,
+                     128
+                 ],
+                 "dtype": "float32",
+                 "value_range": [],
+                 "is_patch_data": false,
+                 "channel_def": {
+                     "0": "image"
+                 }
+             }
+         },
+         "outputs": {
+             "pred": {
+                 "type": "probabilities",
+                 "format": "classes",
+                 "num_channels": 4,
+                 "spatial_shape": [
+                     1,
+                     4
+                 ],
+                 "dtype": "float32",
+                 "value_range": [
+                     0,
+                     1,
+                     2,
+                     3
+                 ],
+                 "is_patch_data": false,
+                 "channel_def": {
+                     "0": "background",
+                     "1": "circle",
+                     "2": "triangle",
+                     "3": "rectangle"
+                 }
+             }
+         }
+     }
+ }
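
A quick sanity check that the metadata above agrees with the configs (for example, that the number of output channels matches `num_classes`, 4) needs nothing beyond the standard library; a small sketch, assuming it is run from the bundle root:

```
import json

with open("configs/metadata.json", encoding="utf8") as f:
    meta = json.load(f)

pred = meta["network_data_format"]["outputs"]["pred"]
print(meta["version"])       # 0.0.3
print(pred["num_channels"])  # 4, should match num_classes in the configs
print(pred["channel_def"])   # {'0': 'background', '1': 'circle', '2': 'triangle', '3': 'rectangle'}
```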
configs/multi_gpu_train.yaml ADDED
@@ -0,0 +1,37 @@
+ # This file contains the changes to implement DDP training with the train.yaml config.
+
+ device: "$torch.device('cuda:' + os.environ['LOCAL_RANK'])" # assumes GPU # matches rank #
+
+ # wrap the network in a DistributedDataParallel instance, moving it to the chosen device for this process
+ network:
+   _target_: torch.nn.parallel.DistributedDataParallel
+   module: $@network_def.to(@device)
+   device_ids: ['@device']
+   find_unused_parameters: true
+
+ train_sampler:
+   _target_: DistributedSampler
+   dataset: '@train_dataset'
+   even_divisible: true
+   shuffle: true
+
+ train_dataloader#sampler: '@train_sampler'
+ train_dataloader#shuffle: false
+
+ val_sampler:
+   _target_: DistributedSampler
+   dataset: '@val_dataset'
+   even_divisible: false
+   shuffle: false
+
+ val_dataloader#sampler: '@val_sampler'
+
+ initialize:
+ - $import torch.distributed as dist
+ - $dist.init_process_group(backend='nccl')
+ - $torch.cuda.set_device(@device)
+ - $monai.utils.set_determinism(seed=123) # may want to choose a different seed or not do this here
+ run:
+ - [email protected]()
+ finalize:
+ - '$dist.is_initialized() and dist.destroy_process_group()'
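
These definitions only take effect when the file is layered over `train.yaml` in a distributed launch, so that each spawned process reads its own `LOCAL_RANK`; the README below launches it with `torchrun`:

```
torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run --config_file "['configs/train.yaml','configs/multi_gpu_train.yaml']"
```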
configs/train.yaml ADDED
@@ -0,0 +1,233 @@
+ # This config file implements the training workflow. It can be combined with multi_gpu_train.yaml to use DDP for
+ # multi-GPU runs.
+
+ imports:
+ - $import os
+ - $import json
+ - $import datetime
+ - $import torch
+ - $import glob
+
+ # pull out some constants from MONAI
+ image: $monai.utils.CommonKeys.IMAGE
+ label: $monai.utils.CommonKeys.LABEL
+ pred: $monai.utils.CommonKeys.PRED
+
+ # multi-gpu values, `rank` will be replaced in a separate script implementing multi-gpu changes
+ rank: 0 # without multi-gpu support consider the process as rank 0 anyway
+ is_not_rank0: '$@rank > 0' # true if not main process, used to disable handlers for other ranks
+
+ # hyperparameters for you to modify on the command line
+ val_interval: 1 # how often to perform validation after an epoch
+ ckpt_interval: 1 # how often to save a checkpoint after an epoch
+ rand_prob: 0.5 # probability a random transform is applied
+ batch_size: 5 # number of images per batch
+ num_epochs: 10 # number of epochs to train for
+ num_substeps: 1 # how many times to repeatedly train with the same batch
+ num_workers: 4 # number of workers to generate batches with
+ learning_rate: 0.001 # initial learning rate
+ num_classes: 4 # number of classes in the training data which the network should predict
+ device: $torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+ # define various paths
+ bundle_root: . # root directory of the bundle
+ ckpt_path: $@bundle_root + '/models/model.pt' # checkpoint to load before starting
+ dataset_dir: $@bundle_root + '/data/train_data' # where data is coming from
+ results_dir: $@bundle_root + '/results' # where results are stored
+ # a new output directory is chosen using a timestamp for every invocation
+ output_dir: '$datetime.datetime.now().strftime(@results_dir + ''/output_%y%m%d_%H%M%S'')'
+
+ # network definition, this could be parameterised by pre-defined values or on the command line
+ network_def:
+   _target_: DenseNet121
+   spatial_dims: 2
+   in_channels: 1
+   out_channels: '@num_classes'
+ network: $@network_def.to(@device)
+
+ # dataset values, this assumes a JSON file listing img##.nii.gz files and their labels
+ data_json: $@bundle_root + '/data/train_samples.json' # JSON file listing training images and labels
+ data_fp: "$open(@data_json,'r', encoding='utf8')"
+ data_dict: "$json.load(@data_fp)"
+ partitions: '$monai.data.partition_dataset(@data_dict, (4, 1), shuffle=True, seed=0)'
+ train_sub: '$@partitions[0]' # train partition
+ val_sub: '$@partitions[1]' # validation partition
+
+ # these transforms are used for training and validation transform sequences
+ base_transforms:
+ - _target_: LoadImaged
+   keys: '@image'
+ - _target_: EnsureChannelFirstd
+   keys: '@image'
+
+ # these are the random and regularising transforms used only for training
+ train_transforms:
+ - _target_: RandAxisFlipd
+   keys: '@image'
+   prob: '@rand_prob'
+ - _target_: RandRotate90d
+   keys: '@image'
+   prob: '@rand_prob'
+ - _target_: RandGaussianNoised
+   keys: '@image'
+   prob: '@rand_prob'
+   std: 0.05
+ - _target_: ScaleIntensityd
+   keys: '@image'
+
+ # these are used for validation data so no randomness
+ val_transforms:
+ - _target_: ScaleIntensityd
+   keys: '@image'
+
+ # define the Compose objects for training and validation
+ preprocessing:
+   _target_: Compose
+   transforms: $@base_transforms + @train_transforms
+
+ val_preprocessing:
+   _target_: Compose
+   transforms: $@base_transforms + @val_transforms
+
+ # define the datasets for training and validation
+ train_dataset:
+   _target_: Dataset
+   data: '@train_sub'
+   transform: '@preprocessing'
+
+ val_dataset:
+   _target_: Dataset
+   data: '@val_sub'
+   transform: '@val_preprocessing'
+
+ # define the dataloaders for training and validation
+ train_dataloader:
+   _target_: ThreadDataLoader # generate data asynchronously from training
+   dataset: '@train_dataset'
+   batch_size: '@batch_size'
+   repeats: '@num_substeps'
+   num_workers: '@num_workers'
+
+ val_dataloader:
+   _target_: DataLoader # faster transforms probably won't benefit from threading
+   dataset: '@val_dataset'
+   batch_size: '@batch_size'
+   num_workers: '@num_workers'
+
+ # Simple CrossEntropy loss configured for multi-class classification
+ lossfn:
+   _target_: torch.nn.CrossEntropyLoss
+   reduction: sum
+
+ # hyperparameters could be added for other arguments of this class
+ optimizer:
+   _target_: torch.optim.Adam
+   params: [email protected]()
+   lr: '@learning_rate'
+
+ # should be replaced with other inferer types if the training process is different for your network
+ inferer:
+   _target_: SimpleInferer
+
+ # transform to apply to data from network to be suitable for validation
+ postprocessing:
+   _target_: Compose
+   transforms:
+   - _target_: Activationsd
+     keys: '@pred'
+     softmax: true
+   - _target_: AsDiscreted
+     keys: ['@pred', '@label']
+     argmax: [true, false]
+     to_onehot: '@num_classes'
+   - _target_: ToTensord
+     keys: ['@pred', '@label']
+     device: '@device'
+
+ # validation handlers to gather statistics, log these to a file, and save best checkpoint
+ val_handlers:
+ - _target_: StatsHandler
+   name: null # use engine.logger as the Logger object to log to
+   output_transform: '$lambda x: None'
+ - _target_: LogfileHandler # log outputs from the validation engine
+   output_dir: '@output_dir'
+ - _target_: CheckpointSaver
+   _disabled_: '@is_not_rank0' # only need rank 0 to save
+   save_dir: '@output_dir'
+   save_dict:
+     model: '@network'
+   save_interval: 0 # don't save iterations, just when the metric improves
+   save_final: false
+   epoch_level: false
+   save_key_metric: true
+   key_metric_name: val_accuracy # save the checkpoint when this value improves
+
+ # engine for running validation, ties together objects defined above and has metric definitions
+ evaluator:
+   _target_: SupervisedEvaluator
+   device: '@device'
+   val_data_loader: '@val_dataloader'
+   network: '@network'
+   postprocessing: '@postprocessing'
+   key_val_metric:
+     val_accuracy:
+       _target_: ignite.metrics.Accuracy
+       output_transform: $monai.handlers.from_engine([@pred, @label])
+   additional_metrics:
+     val_f1: # can have other metrics
+       _target_: ConfusionMatrix
+       metric_name: 'f1 score'
+       output_transform: $monai.handlers.from_engine([@pred, @label])
+   val_handlers: '@val_handlers'
+
+ # gathers the loss and validation values for each iteration, referred to by CheckpointSaver so defined separately
+ metriclogger:
+   _target_: MetricLogger
+   evaluator: '@evaluator'
+
+ handlers:
+ - '@metriclogger'
+ - _target_: CheckpointLoader
+   _disabled_: $not os.path.exists(@ckpt_path)
+   load_path: '@ckpt_path'
+   load_dict:
+     model: '@network'
+ - _target_: ValidationHandler # run validation at the set interval, bridge between trainer and evaluator objects
+   validator: '@evaluator'
+   epoch_level: true
+   interval: '@val_interval'
+ - _target_: CheckpointSaver
+   _disabled_: '@is_not_rank0' # only need rank 0 to save
+   save_dir: '@output_dir'
+   save_dict: # every epoch checkpoint saves the network and the metric logger in a dictionary
+     model: '@network'
+     logger: '@metriclogger'
+   save_interval: '@ckpt_interval'
+   save_final: true
+   epoch_level: true
+ - _target_: StatsHandler
+   name: null # use engine.logger as the Logger object to log to
+   tag_name: train_loss
+   output_transform: $monai.handlers.from_engine(['loss'], first=True) # log loss value
+ - _target_: LogfileHandler # log outputs from the training engine
+   output_dir: '@output_dir'
+
+ # engine for training, ties values defined above together into the main engine for the training process
+ trainer:
+   _target_: SupervisedTrainer
+   max_epochs: '@num_epochs'
+   device: '@device'
+   train_data_loader: '@train_dataloader'
+   network: '@network'
+   inferer: '@inferer' # unnecessary since SimpleInferer is the default if this isn't provided
+   loss_function: '@lossfn'
+   optimizer: '@optimizer'
+   # postprocessing: '@postprocessing' # uncomment if you have train metrics that need post-processing
+   key_train_metric: null
+   train_handlers: '@handlers'
+
+ initialize:
+ - "$monai.utils.set_determinism(seed=123)"
+ - "$setattr(torch.backends.cudnn, 'benchmark', True)"
+ run:
+ - [email protected]()
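
Every top-level id in this config can be overridden when the bundle is launched, so hyperparameters such as `learning_rate`, `num_epochs`, or `bundle_root` need not be edited in place. A sketch of such an invocation from the bundle root (the override values here are only examples):

```
python -m monai.bundle run --config_file configs/train.yaml --learning_rate 0.0001 --num_epochs 20
```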
docs/README.md ADDED
@@ -0,0 +1,55 @@
+ # Template Classification Bundle
+
+ This bundle is meant to be an example of classification in 2D which you can copy and modify to create your own bundle.
+ It is only roughly trained for the synthetic data you can generate with [this notebook](./generate_data.ipynb)
+ so it doesn't do anything useful on its own. The purpose is to demonstrate the baseline for classification network bundles.
+
+ To use this bundle, copy the contents of the whole directory and change the definitions for network, data, transforms,
+ or whatever else you want for your own new classification bundle.
+
+ ## Generating Demo Data
+
+ Run all the cells of [this notebook](./generate_data.ipynb) to generate training and test data. These will be 2D
+ NIfTI files, each containing a randomly generated circle, triangle, or rectangle (or nothing). The classification task
+ is very easy so your network will train in minutes with the default configuration of values. A test
+ data directory will also be created separately, since the inference config is set up to apply the network to
+ the test images listed in the generated `test_samples.json` file.
+
+ ## Training
+
+ To train a new network the `train.yaml` config can be used alone with no other arguments (run the commands below from
+ the root directory of the bundle):
+
+ ```
+ python -m monai.bundle run --config_file configs/train.yaml
+ ```
+
+ The training config includes a number of hyperparameters like `learning_rate` and `num_workers`. These control aspects
+ of how training operates in terms of how many processes to use, when to perform validation, when to save checkpoints,
+ and other things. Other aspects of the config can also be modified on the command line, so these aren't exhaustive but are a
+ guide to the kind of parameterisation that makes sense for a bundle.
+
+ ## Override the `train` config to execute multi-GPU training:
+
+ ```
+ torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run --config_file "['configs/train.yaml','configs/multi_gpu_train.yaml']"
+ ```
+
+ Please note that the distributed training-related options depend on the actual running environment; thus, users may need to remove `--standalone`, modify `--nnodes`, or make other necessary changes according to the machine used. For more details, please refer to [PyTorch's official tutorial](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).
+
+ ## Override the `train` config to execute evaluation with the trained model:
+
+ ```
+ python -m monai.bundle run --config_file "['configs/train.yaml','configs/evaluate.yaml']"
+ ```
+
+ ## Execute inference:
+
+ ```
+ python -m monai.bundle run --config_file configs/inference.yaml
+ ```
+
+ ## Other Considerations
+
+ This bundle has no `scripts` directory containing a Python module to be imported in your configs. This wasn't necessary
+ for this bundle, but if you want to include custom code in your own please follow the bundle tutorials on how to do this.
docs/generate_data.ipynb ADDED
@@ -0,0 +1,236 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "id": "b1c9de9d-6777-4a1d-bb7c-c2413d01bd7d",
+    "metadata": {},
+    "source": [
+     "# Generate Data\n",
+     "\n",
+     "This bundle uses simple synthetic data for training and testing. We'll create 2D images, each either empty or containing a randomly placed circle, triangle, or rectangle, with a class label for each image. The network will be able to train very quickly on this of course, but it's for demonstration purposes and your specialised bundle will be modified for your data and its layout.\n",
+     "\n",
+     "Assuming this notebook is being run from the `docs` directory, it will create two new directories under the bundle's `data` directory, `train_data` and `test_data`.\n",
+     "\n",
+     "First imports:"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 7,
+    "id": "1e7cb4a8-f91a-4f15-a8aa-3136c2b954d6",
+    "metadata": {
+     "tags": []
+    },
+    "outputs": [],
+    "source": [
+     "import os\n",
+     "import json\n",
+     "import random\n",
+     "\n",
+     "import matplotlib.pyplot as plt\n",
+     "import nibabel as nib\n",
+     "import numpy as np\n",
+     "\n",
+     "plt.rcParams[\"image.interpolation\"] = \"none\""
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "2b2c3de5-01e5-4578-832b-b24a75d095d5",
+    "metadata": {},
+    "source": [
+     "As shown here, the images are 2D arrays containing simple shapes, each with an associated class label:"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 8,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "def generate_images(image_size=128, border=20, shape_probabilities=None, shape_sizes=None):\n",
+     "    image = np.zeros((image_size, image_size))\n",
+     "\n",
+     "    if shape_probabilities is None:\n",
+     "        shape_probabilities = [0.25, 0.2, 0.3, 0.25]  # Default probabilities for empty, circle, triangle, rectangle\n",
+     "\n",
+     "    if shape_sizes is None:\n",
+     "        shape_sizes = [(10, 30), (20, 40), (20, 40)]  # Default size ranges for circle, triangle, rectangle\n",
+     "\n",
+     "    def draw_zero(image):\n",
+     "        return image\n",
+     "\n",
+     "    def draw_circle(image):\n",
+     "        center_x, center_y = np.random.randint(border, image_size - border), np.random.randint(border, image_size - border)\n",
+     "        radius = np.random.randint(*shape_sizes[0])\n",
+     "        y, x = np.ogrid[-center_x:image_size-center_x, -center_y:image_size-center_y]\n",
+     "        mask = x ** 2 + y ** 2 <= radius ** 2\n",
+     "        image[mask] = 1\n",
+     "        return image\n",
+     "\n",
+     "    def draw_triangle(image):\n",
+     "        size = np.random.randint(*shape_sizes[1])\n",
+     "        x1, y1 = np.random.randint(border, image_size - border), np.random.randint(border, image_size - border)\n",
+     "        x2, y2 = x1 + size, y1\n",
+     "        x3, y3 = x1 + size // 2, y1 - int(size * np.sqrt(3) / 2)\n",
+     "        triangle = np.array([[x1, x2, x3], [y1, y2, y3]])\n",
+     "        mask = plt.matplotlib.path.Path(np.transpose(triangle)).contains_points(\n",
+     "            np.array([(i, j) for i in range(image_size) for j in range(image_size)])\n",
+     "        )\n",
+     "        image[mask.reshape(image_size, image_size)] = 1\n",
+     "        return image\n",
+     "\n",
+     "    def draw_rectangle(image):\n",
+     "        x1, y1 = np.random.randint(border, image_size - border), np.random.randint(border, image_size - border)\n",
+     "        x2, y2 = x1 + np.random.randint(*shape_sizes[2]), y1 + np.random.randint(*shape_sizes[2])\n",
+     "        image[x1:x2, y1:y2] = 1\n",
+     "        return image\n",
+     "\n",
+     "    label, shape = random.choices([(0, draw_zero), (1, draw_circle), (2, draw_triangle), (3, draw_rectangle)], weights=shape_probabilities)[0]\n",
+     "    image = shape(image)\n",
+     "\n",
+     "    return image, label"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 9,
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "image/png": "iVBORw0KGgoAAAANSUhEUgAA... (base64 PNG of the 3x3 grid of example shape images, omitted here)",
+       "text/plain": [
+        "<Figure size 500x500 with 9 Axes>"
+       ]
+      },
+      "metadata": {},
+      "output_type": "display_data"
+     }
+    ],
+    "source": [
+     "# draw a 3x3 grid of randomly generated example images\n",
+     "fig, axes = plt.subplots(3, 3, figsize=(5, 5))\n",
+     "for i, ax in enumerate(axes.flatten()):\n",
+     "    images, label = generate_images(128)\n",
+     "    ax.imshow(images, cmap='gray')\n",
+     "    ax.axis('off')\n",
+     "plt.tight_layout()\n",
+     "plt.show()"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "8e08c4a1-6630-4ab3-832b-e53face81e35",
+    "metadata": {},
+    "source": [
+     "50 image/label pairs are now generated into the directory `../data/train_data`; assuming this notebook is run from the `docs` directory, this resolves to `data/train_data` under the bundle root:"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 10,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "num_images = 50\n",
+     "out_dir = os.path.abspath(\"../data/train_data\")\n",
+     "os.makedirs(out_dir, exist_ok=True)\n",
+     "\n",
+     "train_data = []\n",
+     "for i in range(num_images):\n",
+     "    data = {}\n",
+     "    img, lbl = generate_images(128)\n",
+     "    n = nib.Nifti1Image(img, np.eye(4))\n",
+     "    train_file_path = os.path.join(out_dir, f\"img{i:02}.nii.gz\")\n",
+     "    nib.save(n, train_file_path)\n",
+     "\n",
+     "    data[\"image\"] = train_file_path\n",
+     "    data[\"label\"] = lbl\n",
+     "    train_data.append(data)\n",
+     "\n",
+     "with open(os.path.abspath(\"../data/train_samples.json\"), \"w\") as f:\n",
+     "    json.dump(train_data, f, indent=2)"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "7fe344f7-d01d-49d5-adca-a7071939ca53",
+    "metadata": {},
+    "source": [
+     "We'll also generate some test data in a separate folder:"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 11,
+    "id": "c3b8d8f3-8d73-4657-98f3-5605d4b1bad9",
+    "metadata": {
+     "tags": []
+    },
+    "outputs": [],
+    "source": [
+     "num_images = 10\n",
+     "out_dir = os.path.abspath(\"../data/test_data\")\n",
+     "os.makedirs(out_dir, exist_ok=True)\n",
+     "\n",
+     "train_data = []\n",
+     "for i in range(num_images):\n",
+     "    data = {}\n",
+     "    img, lbl = generate_images(128)\n",
+     "    n = nib.Nifti1Image(img, np.eye(4))\n",
+     "    train_file_path = os.path.join(out_dir, f\"img{i:02}.nii.gz\")\n",
+     "    nib.save(n, train_file_path)\n",
+     "\n",
+     "    data[\"image\"] = train_file_path\n",
+     "    data[\"label\"] = lbl\n",
+     "    train_data.append(data)\n",
+     "\n",
+     "with open(os.path.abspath(\"../data/test_samples.json\"), \"w\") as f:\n",
+     "    json.dump(train_data, f, indent=2)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 12,
+    "id": "599cff25-4894-481b-aec3-6aedda327a09",
+    "metadata": {
+     "tags": []
+    },
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "img00.nii.gz img02.nii.gz img04.nii.gz img06.nii.gz\timg08.nii.gz\n",
+       "img01.nii.gz img03.nii.gz img05.nii.gz img07.nii.gz\timg09.nii.gz\n"
+      ]
+     }
+    ],
+    "source": [
+     "!ls {out_dir}"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.10.12"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
models/model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6f5a81ff92b35a13e8ee60edc4499f4e441dcc7a988652718ccf3f94c049fef4
+ size 28434292
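
As the pointer above shows, `model.pt` is stored with Git LFS rather than committed directly; after cloning the repository, the actual checkpoint (about 28 MB) can be fetched with Git LFS, assuming it is installed:

```
git lfs install
git lfs pull
```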