project-monai committed
Commit b8597df · verified · 1 Parent(s): 1e093c4

Upload retinalOCT_RPD_segmentation version 0.0.1
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+docs/Figure1.jpg filter=lfs diff=lfs merge=lfs -text
+sample_data/37721c93df11e35a8caa9b15616841ae985a985fb301988fce780dcaad37e71a_oct_026.png filter=lfs diff=lfs merge=lfs -text
+sample_data/37721c93df11e35a8caa9b15616841ae985a985fb301988fce780dcaad37e71a_oct_051.png filter=lfs diff=lfs merge=lfs -text
+sample_data/37721c93df11e35a8caa9b15616841ae985a985fb301988fce780dcaad37e71a_oct_060.png filter=lfs diff=lfs merge=lfs -text
+sample_data/8c85a17e87eef485a975566dab6b54cafbabf1e4c558ab7b7637b88d962920af_oct_027.png filter=lfs diff=lfs merge=lfs -text
+sample_data/8c85a17e87eef485a975566dab6b54cafbabf1e4c558ab7b7637b88d962920af_oct_033.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
BSD 2-Clause License

Copyright (c) 2022, uw-biomedical-ml
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
configs/inference.yaml ADDED
imports:
- $import scripts
- $import scripts.inference

args:
  run_extract : False
  input_dir : "/path/to/data"
  extracted_dir : "/path/to/extracted/data"
  input_format : "dicom"
  create_dataset : True
  dataset_name : "my_dataset_name"

  output_dir : "/path/to/model/output"
  run_inference : True
  create_tables : True

  # create visuals
  binary_mask : False
  binary_mask_overlay : True
  instance_mask_overlay : False

inference:
- $scripts.inference.main(@args)
configs/metadata.json ADDED
{
    "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
    "version": "0.0.1",
    "changelog": {
        "0.0.1": "Initial version"
    },
    "monai_version": "1.5.0",
    "pytorch_version": "2.6.0",
    "numpy_version": "1.26.4",
    "optional_packages_version": {},
    "required_packages_version": {
        "setuptools": "75.8.0",
        "opencv-python-headless": "4.11.0.86",
        "pandas": "2.3.0",
        "seaborn": "0.13.2",
        "scikit-learn": "1.6.1",
        "progressbar": "2.5",
        "pydicom": "3.0.1",
        "fire": "0.7.0",
        "torchvision": "0.21.0",
        "detectron2": "0.6",
        "lxml": "5.4.0",
        "pillow": "11.2.1"
    },
    "name": "retinalOCT_RPD_segmentation",
    "task": "Reticular Pseudodrusen (RPD) instance segmentation.",
    "description": "This network detects and segments Reticular Pseudodrusen (RPD) instances in Optical Coherence Tomography (OCT) B-scans, which can be provided in VOL or DICOM format.",
    "authors": "Yelena Bagdasarova, Scott Song",
    "copyright": "Copyright (c) 2022, uw-biomedical-ml",
    "network_data_format": {
        "inputs": {
            "image": {
                "type": "image",
                "format": "magnitude",
                "modality": "OCT",
                "num_channels": 1,
                "spatial_shape": [496, 1024],
                "dtype": "int16",
                "value_range": [0, 256],
                "is_patch_data": false,
                "channel_def": {"0": "image"}
            }
        },
        "preprocessed_data_sources": {
            "vol_file": {
                "type": "image",
                "format": "magnitude",
                "modality": "OCT",
                "num_channels": 1,
                "spatial_shape": [496, 1024, "D"],
                "dtype": "int16",
                "value_range": [0, 256],
                "description": "The pixel array of each OCT slice is extracted with volreader and the PNG files saved to <extracted_dir>/<some>/<file>/<name>/<some_file_name>_oct_<DDD>.png on disk, where <DDD> is the slice number and a nested hierarchy of folders is created using the underscores in the original filename."
            },
            "dicom_series": {
                "type": "image",
                "format": "magnitude",
                "modality": "OCT",
                "SOP class UID": "1.2.840.10008.5.1.4.1.1.77.1.5.4",
                "num_channels": 1,
                "spatial_shape": [496, 1024, "D"],
                "dtype": "int16",
                "value_range": [0, 256],
                "description": "The pixel array of each OCT slice is extracted with pydicom and the PNG files saved to <extracted_dir>/<SOPInstanceUID>/<SOPInstanceUID>_oct_<DDD>.png on disk, where <DDD> is the slice number."
            }
        },
        "outputs": {
            "pred": {
                "dtype": "dictionary",
                "type": "dictionary",
                "format": "COCO",
                "modality": "n/a",
                "value_range": [0, 1],
                "num_channels": 1,
                "spatial_shape": [496, 1024],
                "channel_def": {"0": "RPD"},
                "description": "This output is a JSON file in COCO Instance Segmentation format, containing bounding boxes, segmentation masks, and output probabilities for detected instances."
            }
        },
        "post_processed_outputs": {
            "binary segmentation": {
                "type": "image",
                "format": "TIFF",
                "modality": "OCT",
                "num_channels": 3,
                "spatial_shape": [496, 1024],
                "description": "This output is a multi-page TIFF file. Each page of the TIFF image corresponds to a binary segmentation mask for a single OCT slice from the input volume. The segmentation masks are stacked in the same order as the original OCT slices."
            },
            "binary segmentation overlay": {
                "type": "image",
                "format": "TIFF",
                "modality": "OCT",
                "num_channels": 3,
                "spatial_shape": [496, 1024],
                "description": "This output is a multi-page TIFF file. Each page of the TIFF image corresponds to a single OCT slice from the input volume overlaid with the detected binary segmentation mask."
            },
            "instance segmentation overlay": {
                "type": "image",
                "format": "TIFF",
                "modality": "OCT",
                "num_channels": 3,
                "spatial_shape": [496, 1024],
                "description": "This output is a multi-page TIFF file. Each page of the TIFF image corresponds to a single OCT slice from the input volume overlaid with the detected instance segmentation masks."
            }
        }
    }
}
docs/Figure1.jpg ADDED

Git LFS Details

  • SHA256: cbc11cb99a99519b47ca390ba7734436ecc8c5377d8d4cb54a740285b8eca36d
  • Pointer size: 131 Bytes
  • Size of remote file: 193 kB
docs/README.md ADDED

# RPD OCT Segmentation
### **Authors**
Himeesh Kumar, Yelena Bagdasarova, Scott Song, Doron G. Hickey, Amy C. Cohn, Mali Okada, Robert P. Finger, Jan H. Terheyden, Ruth E. Hogg, Pierre-Henry Gabrielle, Louis Arnould, Maxime Jannaud, Xavier Hadoux, Peter van Wijngaarden, Carla J. Abbott, Lauren A.B. Hodgson, Roy Schwartz, Adnan Tufail, Emily Y. Chew, Cecilia S. Lee, Erica L. Fletcher, Melanie Bahlo, Brendan R.E. Ansell, Alice Pébay, Robyn H. Guymer, Aaron Y. Lee, Zhichao Wu

### **Tags**
Reticular Pseudodrusen, AMD, OCT, Segmentation

## **Model Description**
This model detects and segments Reticular Pseudodrusen (RPD) instances in Optical Coherence Tomography (OCT) B-scans. The instance segmentation model uses a Mask R-CNN [1] head with a ResNeXt-101-32x8d-FPN [2] backbone (pretrained on ImageNet), implemented via the Detectron2 framework [3]. The model produces bounding boxes and segmentation masks that delineate the coordinates and pixels of each detected instance, each of which is assigned an output probability. A tuneable probability threshold can then be applied to finalise the binary detection of an RPD instance.

Five segmentation models were trained on these RPD instance labels using five-fold cross-validation, and were combined into a final ensemble model using soft voting (see the supplementary material of the paper for more information on model training).
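
As an illustration of the soft-voting step, the sketch below averages instance scores across folds once overlapping detections have been grouped; it is a minimal, hypothetical example in the spirit of `scripts/Ensembler.py`, not the bundle's actual implementation.
```
# Minimal sketch of mean-score soft voting for one group of overlapping
# detections collected from the five fold models.
def ensemble_score(scores, n_detectors=5):
    # Folds that missed the object contribute an implicit score of 0,
    # so divide by at least the number of detectors.
    return sum(scores) / max(n_detectors, len(scores))

print(ensemble_score([0.9, 0.8, 0.7]))  # 0.48: three of five folds agree
```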

## **Data**
The model was trained using the prospectively collected baseline OCT scans (prior to any treatments) of individuals enrolled in the LEAD study [4], imaged using the Heidelberg Spectralis HRA+OCT. OCT B-scans from 200 eyes of 100 individuals in the LEAD study were randomly selected to undergo manual annotation of RPD by a single grader (HK) at the pixel level, following training from two senior investigators (RHG and ZW). Only definite RPD lesions, defined as subretinal hyperreflective accumulations that altered the contour of, or broke through, the overlying photoreceptor ellipsoid zone on the OCT B-scans, were annotated.

The model was then internally tested on a different set of OCT scans from 125 eyes of 92 individuals from the LEAD study, and externally tested on five independent datasets: the MACUSTAR study [5], the Northern Ireland Cohort for Longitudinal Study of Ageing (NICOLA) study [6], the Montrachet study [7], AMD observational studies at the University of Bonn, Germany (UB), and a routine clinical care cohort seen at the University of Washington (UW). The presence of RPD was graded either as part of each study (MACUSTAR and UB datasets) or by one of the study investigators (HK; in the NICOLA, UW, and Montrachet datasets). All these studies defined RPD based on the presence of five or more definite lesions on more than one OCT B-scan that corresponded to hyporeflective lesions seen on near-infrared reflectance imaging.

#### **Preprocessing**
Scans were kept at native resolution (1024 x 496 pixels).

## **Performance**
In the external test datasets, the overall performance for detecting RPD in a volume scan was an AUC of 0·94 (95% CI = 0·92–0·97). In the internal test dataset, the Dice similarity coefficient (DSC) between the model and manual annotations by retinal specialists was calculated for each B-scan, and the average over the dataset is listed in the table below. Note that the DSC was assigned a value of 1·0 for all pairwise comparisons where no pixels on a B-scan were labelled as having RPD.

![](Table2.gif)

![](Figure1.jpg)

For more details regarding evaluation results, please see the Results section of the paper.
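
For reference, the per-scan DSC with the convention above (read here as both masks being empty) can be written as the short sketch below; this is an illustrative reconstruction, not code shipped with the bundle.
```
import numpy as np

def dice(gt: np.ndarray, dt: np.ndarray) -> float:
    # gt, dt: boolean RPD masks for one B-scan
    if not gt.any() and not dt.any():
        return 1.0  # convention: B-scans with no labelled RPD pixels score 1.0
    return 2.0 * np.logical_and(gt, dt).sum() / (gt.sum() + dt.sum())
```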

<!-- ## INSTALLATION
This bundle can be installed using docker by navigating to the RPDBundle directory and running
```
docker build -t <image_name>:<tag> .
``` -->
## INSTALL
This bundle has been installed and tested using Python 3.10. From the bundle directory, install the required packages using
```
pip install -r ./docs/requirements.txt
```
Install detectron2 using
```
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
```
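
As an optional sanity check (not part of the bundle), confirm that detectron2 and torch import cleanly before running inference:
```
# Both packages should import without error; CUDA availability is optional.
import detectron2
import torch

print(detectron2.__version__, torch.__version__, torch.cuda.is_available())
```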

## USAGE
The expected image data is in PNG format at the scan level, VOL format at the volume level, or DICOM format at the volume level. To run inference, modify the parameters of the inference.yaml config file in the configs folder, which looks like this:

```
imports:
- $import scripts
- $import scripts.inference

args:
  run_extract : False
  input_dir : "/path/to/data"
  extracted_dir : "/path/to/extracted/data"
  input_format : "dicom"
  create_dataset : True
  dataset_name : "my_dataset_name"

  output_dir : "/path/to/model/output"
  run_inference : True
  create_tables : True

  # create visuals
  binary_mask : False
  binary_mask_overlay : True
  instance_mask_overlay : False

inference:
- $scripts.inference.main(@args)
```
Then in your bash shell run
```
BUNDLE="/path/to/bundle/RPDBundle"

python -m monai.bundle run inference \
    --bundle_root "$BUNDLE" \
    --config_file "$BUNDLE/configs/inference.yaml" \
    --meta_file "$BUNDLE/configs/metadata.json"
```
### VOL/DICOM EXTRACTION
If extracting DICOM or VOL files:
* set `run_extract` to `True`
* specify `input_dir`, the path to the directory that contains the VOL or DICOM files
* specify `extracted_dir`, the path to the directory where extracted images will be stored
* set `input_format` to "dicom" or "vol"

The VOL or DICOM files can be in a nested hierarchy of folders, and all files in that directory with a VOL or DICOM extension will be extracted.

For DICOM files, each OCT slice will be saved as a PNG file to `<extracted_dir>/<SOPInstanceUID>/<SOPInstanceUID>_oct_<DDD>.png` on disk, where `<DDD>` is the slice number.

For VOL files, each OCT slice will be saved as a PNG file to `<extracted_dir>/<some>/<file>/<name>/<some_file_name>_oct_<DDD>.png` on disk, where `<DDD>` is the slice number and a nested hierarchy of folders is created using the underscores in the original filename.
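
For reference, the DICOM naming scheme corresponds to a loop like the minimal sketch below. This is an illustration using pydicom directly, not the bundle's extraction code, and it assumes a multi-frame OCT SOP instance whose frames fit in 8 bits:
```
import os

import numpy as np
import pydicom
from PIL import Image

ds = pydicom.dcmread("/path/to/data/example.dcm")  # hypothetical input file
frames = ds.pixel_array  # shape (num_slices, 496, 1024) for a multi-frame OCT
outdir = os.path.join("/path/to/extracted/data", str(ds.SOPInstanceUID))
os.makedirs(outdir, exist_ok=True)
for i, frame in enumerate(frames):
    png = Image.fromarray(frame.astype(np.uint8))
    png.save(os.path.join(outdir, f"{ds.SOPInstanceUID}_oct_{i:03d}.png"))
```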

### DATASET PACKAGING
Once you have the scans in PNG format, you can create a "dataset" in Detectron2 dictionary format for model consumption:
* specify `extracted_dir`, the path to the directory where the PNG files are stored
* set `create_dataset` to `True`
* set `dataset_name` to the chosen name of your dataset

The summary tables and visual output are organized around OCT volumes, so please make sure that the basename of the PNG files looks like `<volumeid>_<sliceid>`. The dataset dictionary will be saved as a pickle file in `/<path>/<to>/<bundle>/RPDBundle/datasets/<your_dataset_name>.pk`.
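
The pickle can be inspected directly to verify its contents before running inference; a quick check, using the standard Detectron2 dataset-dict fields:
```
import pickle

with open("datasets/my_dataset_name.pk", "rb") as f:
    dataset = pickle.load(f)

print(f"{len(dataset)} images in dataset")
# Each record is a Detectron2 dataset dict with fields such as
# "file_name" and "image_id".
print(dataset[0]["file_name"], dataset[0]["image_id"])
```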

### INFERENCE
To run inference on your dataset:
* set `dataset_name` to the name of the dataset you created in the previous step, which resides in `/<path>/<to>/<bundle>/RPDBundle/datasets/<your_dataset_name>.pk`
* set `output_dir`, the path to the directory where model predictions and other data will be stored
* set `run_inference` to `True`

The final ensembled predictions will be saved in COCO Instance Segmentation format in `coco_instances_results.json` in the output directory. The output directory will also be populated with five folders with the prefix 'fold', which contain predictions from the individual models of the ensemble.
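
The predictions file can be consumed with plain JSON tooling. For example, this sketch counts detections per B-scan above a probability threshold (the `score` and `image_id` fields follow the standard COCO results format):
```
import json
from collections import Counter

with open("/path/to/model/output/coco_instances_results.json") as f:
    predictions = json.load(f)

kept = [p for p in predictions if p["score"] > 0.5]  # same default threshold as the tables
per_scan = Counter(p["image_id"] for p in kept)
for image_id, n in per_scan.most_common(5):
    print(image_id, n)
```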

### SUMMARY TABLES and VISUAL OUTPUT
Tables and images can be created from the predictions and written to the output directory. A confidence threshold of 0.5 is applied to the scored predictions by default. To change the threshold, set `prob_thresh` to a value between 0.0 and 1.0.

The tables can be created by setting `create_tables` to `True`:
* an HTML table called `dfimg_<dataset_name>.html`, indexed by OCT B-scan, with columns listing the detected number of RPD <em>instances</em> (dt_instances), <em>pixels</em> (dt_pixels), and <em>horizontal pixels</em> (dt_xpxs) in that B-scan.
* an HTML table called `dfvol_<dataset_name>.html`, indexed by OCT volume, with columns listing the detected number of RPD <em>instances</em> (dt_instances), <em>pixels</em> (dt_pixels), and <em>horizontal pixels</em> (dt_xpxs) in that volume.

The predicted segmentations can be output as multi-page TIFFs, where each TIFF file corresponds to an input volume of the dataset, and each page to an OCT slice from the volume in original order. The output images can be binary masks, binary masks overlaid on the original B-scan, and instance masks overlaid on the original B-scan. Set the `binary_mask`, `binary_mask_overlay` and `instance_mask_overlay` flags in the yaml file to `True` accordingly.
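
A multi-page TIFF can be paged through with Pillow, for instance as below (the filename is hypothetical; use whichever TIFF the bundle wrote to your output directory):
```
from PIL import Image, ImageSequence

tif = Image.open("/path/to/model/output/example_volume.tiff")  # hypothetical filename
for i, page in enumerate(ImageSequence.Iterator(tif)):
    page.convert("RGB").save(f"slice_{i:03d}.png")  # one PNG per OCT slice
```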
120
+
121
+ ### SAMPLE DATA
122
+ As a reference, sample OCT-B scans are provided in PNG format under the sample_data directory. Set `extracted_dir` in `inference.yaml` to `sample_data` to run inference on these few set of images.
123
+
124
+ ## **System Configuration**
125
+ Inference on one Nvidia A100 gpu takes about 0.041 s/batch of 14 images, about 3G of gpu memory, and 6G of RAM.
126
+
127
+ ## **Limitations**
128
+ This model has not been tested for robustness of performance on OCTs imaged with other devices and with different scan parameters.
129
+
130
+ ## **Citation Info**
131
+
132
+ ```
133
+ @article {Kumar2024.09.11.24312817,
134
+ author = {Kumar, Himeesh and Bagdasarova, Yelena and Song, Scott and Hickey, Doron G. and Cohn, Amy C. and Okada, Mali and Finger, Robert P. and Terheyden, Jan H. and Hogg, Ruth E. and Gabrielle, Pierre-Henry and Arnould, Louis and Jannaud, Maxime and Hadoux, Xavier and van Wijngaarden, Peter and Abbott, Carla J. and Hodgson, Lauren A.B. and Schwartz, Roy and Tufail, Adnan and Chew, Emily Y. and Lee, Cecilia S. and Fletcher, Erica L. and Bahlo, Melanie and Ansell, Brendan R.E. and P{\'e}bay, Alice and Guymer, Robyn H. and Lee, Aaron Y. and Wu, Zhichao},
135
+ title = {Deep Learning-Based Detection of Reticular Pseudodrusen in Age-Related Macular Degeneration on Optical Coherence Tomography},
136
+ elocation-id = {2024.09.11.24312817},
137
+ year = {2024},
138
+ doi = {10.1101/2024.09.11.24312817},
139
+ publisher = {Cold Spring Harbor Laboratory Press},
140
+ URL = {https://www.medrxiv.org/content/early/2024/09/12/2024.09.11.24312817},
141
+ eprint = {https://www.medrxiv.org/content/early/2024/09/12/2024.09.11.24312817.full.pdf},
142
+ journal = {medRxiv}
143
+ }
144
+ ```
145
+
146
+ ## **References**
147
+ [1]: He, Kaiming, Georgia Gkioxari, Piotr Dollár, and Ross Girshick. "Mask R-CNN." In Proceedings of the IEEE international conference on computer vision (ICCV), pp. 2961-2969. 2017.
148
+
149
+ [2]: Xie, Saining, Ross Girshick, Piotr Dollár, Zhuowen Tu, and Kaiming He. "Aggregated residual transformations for deep neural networks." In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 1492-1500. 2017.
150
+
151
+ [3]: Wu, Yuxin, Alexander Kirillov, Francisco Massa, Wan-Yen Lo, and Ross Girshick. "Detectron2." arXiv preprint arXiv:1902.09615 (2019).
152
+
153
+ [4]: Liu X, Faes L, Kale AU, et al. A comparison of deep learning performance against health-care professionals in detecting diseases from medical imaging: a systematic review and meta-analysis. The Lancet Digital Health. 2019;1(6):e271–e97.
154
+
155
+ [5]: Finger RP, Schmitz-Valckenberg S, Schmid M, et al. MACUSTAR: Development and Clinical Validation of Functional, Structural, and Patient-Reported Endpoints in Intermediate Age-Related Macular Degeneration. Ophthalmologica. 2019;241(2):61–72.
156
+
157
+ [6]: Hogg RE, Wright DM, Quinn NB, et al. Prevalence and risk factors for age-related macular degeneration in a population-based cohort study of older adults in Northern Ireland using multimodal imaging: NICOLA Study. Br J Ophthalmol. 2022:bjophthalmol-2021-320469.
158
+
159
+ [7]: Gabrielle P-H, Seydou A, Arnould L, et al. Subretinal Drusenoid Deposits in the Elderly in a Population-Based Study (the Montrachet Study). Invest Ophthalmol Vis Sci. 2019;60(14):4838–48.
docs/Table2.gif ADDED
docs/requirements.txt ADDED
setuptools==75.8.0
monai==1.5.0
torch==2.6.0
numpy==1.26.4
opencv-python-headless==4.11.0.86
pandas==2.3.0
seaborn==0.13.2
scikit-learn==1.6.1
progressbar==2.5
pydicom==3.0.1
fire==0.7.0
torchvision==0.21.0
lxml==5.4.0
pillow==11.2.1
models/fold1_model_final.pth ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:1819f5200d72764e768377de19724c19689b087d98f77597c59e7fecf737943d
size 856220125
models/fold2_model_final.pth ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:342ae47daf092232606a0fae7246d712071d73178623103ebe1ca72d2b9d26d5
size 856220125
models/fold3_model_final.pth ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:c1820347e67097bd90459bef7b9d790d935f56ca787b61234ff69a49a4d134ea
size 856220125
models/fold4_model_final.pth ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:48cd826885972cb9eec8e058b2c28872c096477edae78ac9757a7a19e535aafc
size 856220125
models/fold5_model_final.pth ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:29c2fa54889382704115fa6f67eac30e5f3c05d789f7c29e67f7290bd3acd975
size 856220125
sample_data/37721c93df11e35a8caa9b15616841ae985a985fb301988fce780dcaad37e71a_oct_026.png ADDED

Git LFS Details

  • SHA256: 796c78bce2b580f5b2e095dde7be3cf18c79897cc9b82be7f4eba0f23d019c44
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
sample_data/37721c93df11e35a8caa9b15616841ae985a985fb301988fce780dcaad37e71a_oct_051.png ADDED

Git LFS Details

  • SHA256: a5971d87830fdd151c9569afafa036da270d67b0b6876e374734201ff5d821ca
  • Pointer size: 131 Bytes
  • Size of remote file: 199 kB
sample_data/37721c93df11e35a8caa9b15616841ae985a985fb301988fce780dcaad37e71a_oct_060.png ADDED

Git LFS Details

  • SHA256: c342090939967c12f23681aab91fe654654e543b11ad4b7a1a0b7ad29c3dd6a6
  • Pointer size: 131 Bytes
  • Size of remote file: 199 kB
sample_data/8c85a17e87eef485a975566dab6b54cafbabf1e4c558ab7b7637b88d962920af_oct_027.png ADDED

Git LFS Details

  • SHA256: b55cfb4d95b0cc144f9504a4589a13df83a35af99d67ac29f98f2725ebd21cc1
  • Pointer size: 131 Bytes
  • Size of remote file: 218 kB
sample_data/8c85a17e87eef485a975566dab6b54cafbabf1e4c558ab7b7637b88d962920af_oct_033.png ADDED

Git LFS Details

  • SHA256: ac75667d6278eb0ed130ba5c84018c843b808681030a49e051057f678639109a
  • Pointer size: 131 Bytes
  • Size of remote file: 201 kB
scripts/Base-RCNN-FPN.yaml ADDED
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  BACKBONE:
    NAME: "build_resnet_fpn_backbone"
  RESNETS:
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
  FPN:
    IN_FEATURES: ["res2", "res3", "res4", "res5"]
  ANCHOR_GENERATOR:
    SIZES: [[32], [64], [128], [256], [512]]  # One size for each in feature map
    ASPECT_RATIOS: [[0.5, 1.0, 2.0]]  # Three aspect ratios (same for all in feature maps)
  RPN:
    IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
    PRE_NMS_TOPK_TRAIN: 2000  # Per FPN level
    PRE_NMS_TOPK_TEST: 1000  # Per FPN level
    # Detectron1 uses 2000 proposals per-batch,
    # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
    # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
    POST_NMS_TOPK_TRAIN: 1000
    POST_NMS_TOPK_TEST: 1000
  ROI_HEADS:
    NAME: "StandardROIHeads"
    IN_FEATURES: ["p2", "p3", "p4", "p5"]
  ROI_BOX_HEAD:
    NAME: "FastRCNNConvFCHead"
    NUM_FC: 2
    POOLER_RESOLUTION: 7
  ROI_MASK_HEAD:
    NAME: "MaskRCNNConvUpsampleHead"
    NUM_CONV: 4
    POOLER_RESOLUTION: 14
SOLVER:
  IMS_PER_BATCH: 14
  BASE_LR: 0.02
  STEPS: (60000, 80000)
  MAX_ITER: 90000
INPUT:
  MIN_SIZE_TRAIN: (496,)
VERSION: 2
DATALOADER:
  FILTER_EMPTY_ANNOTATIONS: False
scripts/Ensembler.py ADDED
import json
import os

import numpy as np
import pandas as pd
import torch
from pycocotools.coco import COCO
from torchvision.ops.boxes import box_convert, box_iou
from tqdm import tqdm


class NpEncoder(json.JSONEncoder):
    """Custom JSON encoder for NumPy data types.

    This encoder handles NumPy-specific types that are not serializable by
    the default JSON library by converting them into standard Python types.
    """

    def default(self, obj):
        """Converts NumPy objects to their native Python equivalents.

        Args:
            obj (any): The object to encode.

        Returns:
            any: The JSON-serializable representation of the object.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return super(NpEncoder, self).default(obj)


class Ensembler:
    """A class to ensemble predictions from multiple object detection models.

    This class loads ground truth data and predictions from several models,
    performs non-maximum suppression (NMS) to merge overlapping detections,
    and saves the final ensembled results in COCO format.
    """

    def __init__(
        self, output_dir, dataset_name, grplist, iou_thresh, coco_gt_path=None, coco_instances_results_fname=None
    ):
        """Initializes the Ensembler.

        Args:
            output_dir (str): The base directory where model outputs and
                ensembled results are stored.
            dataset_name (str): The name of the dataset being evaluated.
            grplist (list[str]): A list of subdirectory names, where each
                subdirectory contains the prediction file from one model.
            iou_thresh (float): The IoU threshold for considering two bounding
                boxes as overlapping during NMS.
            coco_gt_path (str, optional): The full path to the ground truth
                COCO JSON file. If None, it's assumed to be in `output_dir`.
                Defaults to None.
            coco_instances_results_fname (str, optional): The filename for the
                COCO prediction files within each model's subdirectory.
                Defaults to "coco_instances_results.json".
        """
        self.output_dir = output_dir
        self.dataset_name = dataset_name
        self.grplist = grplist
        self.iou_thresh = iou_thresh
        self.n_detectors = len(grplist)

        if coco_gt_path is None:
            fname_gt = os.path.join(output_dir, dataset_name + "_coco_format.json")
        else:
            fname_gt = coco_gt_path

        if coco_instances_results_fname is None:
            fname_dt = "coco_instances_results.json"
        else:
            fname_dt = coco_instances_results_fname

        # load in ground truth (form image lists)
        coco_gt = COCO(fname_gt)
        # populate detector truths
        dtlist = []
        for grp in grplist:
            fname = os.path.join(output_dir, grp, fname_dt)
            dtlist.append(coco_gt.loadRes(fname))
            print("Successfully loaded {} into memory. {} instances detected.\n".format(fname, len(dtlist[-1].anns)))

        self.coco_gt = coco_gt
        self.cats = [cat["id"] for cat in self.coco_gt.dataset["categories"]]
        self.dtlist = dtlist
        self.results = []

        print(
            "Working with {} models, {} categories, and {} images.".format(
                self.n_detectors, len(self.cats), len(self.coco_gt.imgs.keys())
            )
        )

    def mean_score_nms(self):
        """Performs non-maximum suppression by merging overlapping boxes.

        This method iterates through all images and categories, merging sets of
        overlapping bounding boxes from different detectors based on the IoU
        threshold. For each merged set, it calculates a mean score and selects
        the single box with the highest original score as the representative
        detection for the ensembled output.

        Returns:
            Ensembler: The instance itself, with the `self.results` attribute
                populated with the ensembled predictions.
        """

        def nik_merge(lsts):
            """Niklas B. https://github.com/rikpg/IntersectionMerge/blob/master/core.py"""
            sets = [set(lst) for lst in lsts if lst]
            merged = 1
            while merged:
                merged = 0
                results = []
                while sets:
                    common, rest = sets[0], sets[1:]
                    sets = []
                    for x in rest:
                        if x.isdisjoint(common):
                            sets.append(x)
                        else:
                            merged = 1
                            common |= x
                    results.append(common)
                sets = results
            return sets

        winning_list = []
        print(
            "Computing mean score non-max suppression ensembling for {} images.".format(len(self.coco_gt.imgs.keys()))
        )
        for img in tqdm(self.coco_gt.imgs.keys()):
            # print(img)
            dflist = []  # a dataframe of detections
            obj_set = set()  # a set of objects (frozensets)
            for i, coco_dt in enumerate(self.dtlist):  # for each detector append predictions to df
                dflist.append(pd.DataFrame(coco_dt.imgToAnns[img]).assign(det=i))
            df = pd.concat(dflist, ignore_index=True)
            if not df.empty:
                for cat in self.cats:  # for each category
                    dfcat = df[df["category_id"] == cat]
                    ts = box_convert(
                        torch.tensor(dfcat["bbox"]), in_fmt="xywh", out_fmt="xyxy"
                    )  # list of tensor boxes for category
                    iou_bool = np.array((box_iou(ts, ts) > self.iou_thresh))  # compute IoU matrix and threshold
                    for i in range(len(dfcat)):  # for each detection in that category
                        fset = frozenset(dfcat.index[iou_bool[i]])
                        obj_set.add(fset)  # compute set of sets representing objects
                    # find overlapping sets

                    # for fs in obj_set: #for existing sets
                    #     if fs&fset: #check for
                    #         fsnew = fs.union(fset)
                    #         obj_set.remove(fs)
                    #         obj_set.add(fsnew)
                    obj_set = nik_merge(obj_set)
                    for s in obj_set:  # for each detected object, find winning box and assign score as mean of scores
                        dfset = dfcat.loc[list(s)]
                        mean_score = dfset["score"].sum() / max(
                            self.n_detectors, len(s)
                        )  # allows for more detections than detectors
                        winning_box = dfset.iloc[dfset["score"].argmax()].to_dict()
                        winning_box["score"] = mean_score
                        winning_list.append(winning_box)
        print("{} resulting instances from NMS".format(len(winning_list)))
        self.results = winning_list
        return self

    def save_coco_instances(self, fname="coco_instances_results.json"):
        """Saves the ensembled prediction results to a JSON file.

        The output file follows the COCO instance format and can be used for
        further evaluation.

        Args:
            fname (str, optional): The filename for the output JSON file.
                Defaults to "coco_instances_results.json".
        """
        if self.results:
            with open(os.path.join(self.output_dir, fname), "w") as f:
                f.write(json.dumps(self.results, cls=NpEncoder))
                f.flush()


if __name__ == "__main__":
    # Example usage:
    # This assumes an 'output' directory with subdirectories 'fold1', 'fold2', etc.,
    # each containing a 'coco_instances_results.json' file, for a dataset named 'dev'.
    ens = Ensembler("output", "dev", ["fold1", "fold2", "fold3", "fold4", "fold5"], 0.2)
    ens.mean_score_nms()
scripts/__init__.py ADDED
from .inference import main  # Import main from inference.py
scripts/analysis_lib.py ADDED
@@ -0,0 +1,1243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utiltites for analyizing and visualizing model segmentations on dataset.
3
+ Yelena Bagdasarova, Scott Song
4
+ """
5
+
6
+ import json
7
+ import os
8
+ import pickle
9
+ import sys
10
+ import warnings
11
+
12
+ import cv2
13
+ import detectron2
14
+ import detectron2.utils.comm as comm
15
+ import matplotlib.pyplot as plt
16
+ import numpy as np
17
+ import pandas as pd
18
+ import seaborn as sns
19
+ import torch
20
+ from detectron2.data import DatasetCatalog, MetadataCatalog
21
+ from detectron2.engine import DefaultPredictor
22
+ from detectron2.evaluation import COCOEvaluator
23
+ from detectron2.utils.visualizer import Visualizer
24
+ from matplotlib.backends.backend_pdf import PdfPages
25
+ from PIL import Image
26
+ from pycocotools.coco import COCO
27
+ from pycocotools.cocoeval import COCOeval
28
+ from pycocotools.mask import decode
29
+ from sklearn.metrics import average_precision_score, precision_recall_curve
30
+ from tqdm import tqdm
31
+
32
+ # current_directory = os.getcwd()
33
+ # print(current_directory)
34
+ plt.style.use("./scripts/ybpres.mplstyle")
35
+
36
+
37
+ def grab_dataset(name):
38
+ """Creates a function to load a pickled dataset by name.
39
+
40
+ This function returns another function that, when called, loads a dataset
41
+ from a pickle file located in the "datasets/" directory.
42
+
43
+ Args:
44
+ name (str): The base name of the dataset file (without extension).
45
+
46
+ Returns:
47
+ function: A zero-argument function that loads and returns the dataset.
48
+ """
49
+
50
+ def f():
51
+ return pickle.load(open("datasets/" + name + ".pk", "rb"))
52
+
53
+ return f
54
+
55
+
56
+ class OutputVis:
57
+ """A class to visualize model outputs and ground truth annotations."""
58
+
59
+ def __init__(
60
+ self,
61
+ dataset_name,
62
+ cfg=None,
63
+ prob_thresh=0.5,
64
+ pred_mode="model",
65
+ pred_file=None,
66
+ has_annotations=True,
67
+ draw_mode="default",
68
+ ):
69
+ """Initializes the OutputVis class.
70
+
71
+ Args:
72
+ dataset_name (str): The name of the registered Detectron2 dataset.
73
+ cfg (CfgNode, optional): The Detectron2 configuration object.
74
+ Required if `pred_mode` is "model". Defaults to None.
75
+ prob_thresh (float, optional): The probability threshold to apply
76
+ to model predictions for visualization. Defaults to 0.5.
77
+ pred_mode (str, optional): The mode for getting predictions. Must be
78
+ either "model" (to use a live predictor) or "file" (to load
79
+ from a COCO results file). Defaults to "model".
80
+ pred_file (str, optional): The path to the COCO JSON results file.
81
+ Required if `pred_mode` is "file". Defaults to None.
82
+ has_annotations (bool, optional): Whether the dataset has ground
83
+ truth annotations to visualize. Defaults to True.
84
+ draw_mode (str, optional): The drawing style for visualizations.
85
+ Can be "default" (color) or "bw" (monochrome). Defaults to "default".
86
+ """
87
+ self.dataset_name = dataset_name
88
+ self.cfg = cfg
89
+ self.prob_thresh = prob_thresh
90
+ self.data = DatasetCatalog.get(dataset_name)
91
+ if pred_mode == "model":
92
+ self.predictor = DefaultPredictor(cfg)
93
+ self._mode = "model"
94
+ elif pred_mode == "file":
95
+ with open(pred_file, "r") as f:
96
+ self.pred_instances = json.load(f)
97
+ self.instance_img_list = [p["image_id"] for p in self.pred_instances]
98
+ self._mode = "file"
99
+ else:
100
+ sys.exit('Invalid mode. Only "model" or "file" permitted.')
101
+ self.has_annotations = has_annotations
102
+ self.permitted_draw_modes = ["default", "bw"]
103
+ self.set_draw_mode(draw_mode)
104
+ self.font_size = 16 # 28 for ARVO
105
+ self.annotation_color = "r"
106
+ self.scale = 3.0
107
+
108
+ def set_draw_mode(self, draw_mode):
109
+ """Sets the drawing mode for visualizations.
110
+
111
+ Args:
112
+ draw_mode (str): The drawing style. Must be one of the permitted
113
+ modes (e.g., "default", "bw").
114
+ """
115
+ if draw_mode not in self.permitted_draw_modes:
116
+ sys.exit("draw_mode must be one of the following: {}".format(self.permitted_draw_modes))
117
+ self.draw_mode = draw_mode
118
+
119
+ def get_ori_image(self, imgid):
120
+ """Retrieves the original image for a given image ID.
121
+
122
+ The image is scaled up by a factor of 3 for better visualization.
123
+
124
+ Args:
125
+ imgid (str): The 'image_id' from the dataset dictionary.
126
+
127
+ Returns:
128
+ PIL.Image: The original image.
129
+ """
130
+ dat = self.get_gt_image_data(imgid) # gt
131
+ im = cv2.imread(dat["file_name"]) # input to model
132
+ v_gt = Visualizer(im, MetadataCatalog.get(self.dataset_name), scale=self.scale)
133
+ result_image = v_gt.output.get_image() # get original image
134
+ img = Image.fromarray(result_image)
135
+ return img
136
+
137
+ def get_gt_image_data(self, imgid):
138
+ """Returns the ground truth data dictionary for a given image ID.
139
+
140
+ Args:
141
+ imgid (str): The 'image_id' from the dataset dictionary.
142
+
143
+ Returns:
144
+ dict: The dataset dictionary for the specified image.
145
+ """
146
+ gt_data = next(item for item in self.data if (item["image_id"] == imgid))
147
+ return gt_data
148
+
149
+ def produce_gt_image(self, dat, im):
150
+ """Creates an image with ground truth annotations overlaid.
151
+
152
+ The visualization can be in color or monochrome depending on the draw mode.
153
+
154
+ Args:
155
+ dat (dict): The dataset dictionary containing ground truth annotations.
156
+ im (np.ndarray): The input image in RGB format (H, W, C) as a NumPy array.
157
+
158
+ Returns:
159
+ PIL.Image: The image with ground truth instances overlaid.
160
+ """
161
+ v_gt = Visualizer(im, MetadataCatalog.get(self.dataset_name), scale=self.scale)
162
+ if self.has_annotations: # ground truth boxes and masks
163
+ segs = [ddict["segmentation"] for ddict in dat["annotations"]]
164
+ if self.draw_mode == "bw":
165
+ _bboxes = None
166
+ assigned_colors = [self.annotation_color] * len(segs)
167
+ else: # default behavior
168
+ bboxes = [ddict["bbox"] for ddict in dat["annotations"]]
169
+ _bboxes = detectron2.structures.Boxes(bboxes)
170
+ _bboxes = detectron2.structures.BoxMode.convert(
171
+ _bboxes.tensor, from_mode=1, to_mode=0
172
+ ) # 0= XYXY, 1 = XYWH
173
+ assigned_colors = None
174
+
175
+ result_image = v_gt.overlay_instances(
176
+ boxes=_bboxes, masks=segs, assigned_colors=assigned_colors, alpha=1.0
177
+ ).get_image()
178
+ else:
179
+ result_image = v_gt.output.get_image() # get original image if no annotations
180
+ img = Image.fromarray(result_image)
181
+ return img
182
+
183
+ def produce_model_image(self, imgid, dat, im):
184
+ """Creates an image with model-predicted instances overlaid.
185
+
186
+ Predictions are either generated by the model or loaded from a file,
187
+ based on the configured `pred_mode`.
188
+
189
+ Args:
190
+ imgid (str): The 'image_id' from the dataset dictionary.
191
+ dat (dict): The dataset dictionary for the image (used for height/width).
192
+ im (np.ndarray): The input image in RGB format (H, W, C) as a NumPy array.
193
+
194
+ Returns:
195
+ PIL.Image: The image with model-predicted instances overlaid.
196
+ """
197
+ v_dt = Visualizer(im, MetadataCatalog.get(self.dataset_name), scale=self.scale)
198
+ v_dt._default_font_size = self.font_size
199
+
200
+ # get predictions from model or file
201
+ if self._mode == "model":
202
+ outputs = self.predictor(im)["instances"].to("cpu")
203
+ elif self._mode == "file":
204
+ outputs = self.get_outputs_from_file(imgid, (dat["height"], dat["width"]))
205
+ outputs = outputs[outputs.scores > self.prob_thresh] # apply probability threshold to instances
206
+ if self.draw_mode == "bw":
207
+ result_model = v_dt.overlay_instances(
208
+ masks=outputs.pred_masks, assigned_colors=[self.annotation_color] * len(outputs), alpha=1.0
209
+ ).get_image()
210
+ else: # default behavior
211
+ result_model = v_dt.draw_instance_predictions(outputs).get_image()
212
+ img_model = Image.fromarray(result_model)
213
+ return img_model
214
+
215
+ def get_image(self, imgid):
216
+ """Generates both ground truth and model prediction overlay images.
217
+
218
+ Args:
219
+ imgid (str): The 'image_id' from the dataset dictionary.
220
+
221
+ Returns:
222
+ tuple[PIL.Image, PIL.Image]: A tuple containing the ground truth
223
+ image and the model prediction image.
224
+ """
225
+ dat = self.get_gt_image_data(imgid) # gt
226
+ im = cv2.imread(dat["file_name"]) # input to model
227
+ img = self.produce_gt_image(dat, im)
228
+ img_model = self.produce_model_image(imgid, dat, im)
229
+ return img, img_model
230
+
231
+ def get_outputs_from_file(self, imgid, imgsize):
232
+ """Loads and formats model predictions from a COCO results file.
233
+
234
+ Converts COCO-formatted instances into a Detectron2 `Instances` object
235
+ suitable for the visualizer.
236
+
237
+ Args:
238
+ imgid (str): The 'image_id' of the desired image.
239
+ imgsize (tuple[int, int]): The (height, width) of the image.
240
+
241
+ Returns:
242
+ detectron2.structures.Instances: An `Instances` object containing
243
+ the predictions.
244
+ """
245
+
246
+ pred_boxes = []
247
+ scores = []
248
+ pred_classes = []
249
+ pred_masks = []
250
+ for i, img in enumerate(self.instance_img_list):
251
+ if img == imgid:
252
+ pred_boxes.append(self.pred_instances[i]["bbox"])
253
+ scores.append(self.pred_instances[i]["score"])
254
+ pred_classes.append(int(self.pred_instances[i]["category_id"]))
255
+ # pred_masks_rle.append(self.pred_instances[i]['segmentation'])
256
+ pred_masks.append(decode(self.pred_instances[i]["segmentation"]))
257
+ _bboxes = detectron2.structures.Boxes(pred_boxes)
258
+ pred_boxes = detectron2.structures.BoxMode.convert(_bboxes.tensor, from_mode=1, to_mode=0) # 0= XYXY, 1 = XYWH
259
+ inst_dict = dict(
260
+ pred_boxes=pred_boxes,
261
+ scores=torch.tensor(np.array(scores)),
262
+ pred_classes=torch.tensor(np.array(pred_classes)),
263
+ pred_masks=torch.tensor(np.array(pred_masks)).to(torch.bool),
264
+ ) # pred_masks_rle=pred_masks_rle)
265
+ outputs = detectron2.structures.Instances(imgsize, **inst_dict)
266
+ return outputs
267
+
268
+ @staticmethod
269
+ def height_crop_range(im, height_target=256):
270
+ """Calculates a vertical crop range centered on the brightest part of an image.
271
+
272
+ Args:
273
+ im (np.ndarray): The input image as a NumPy array (H, W, C).
274
+ height_target (int, optional): The desired height of the crop.
275
+ Defaults to 256.
276
+
277
+ Returns:
278
+ range: A range object representing the start and end pixel rows for the crop.
279
+ """
280
+ yhist = im.sum(axis=1) # integrate over width of image
281
+ mu = np.average(np.arange(yhist.shape[0]), weights=yhist)
282
+ h1 = int(np.floor(mu - height_target / 2)) # inclusive
283
+ h2 = int(np.ceil(mu + height_target / 2)) # exclusive
284
+ if h1 < 0:
285
+ h1 = 0
286
+ h2 = height_target
287
+ if h2 > yhist.shape[0]:
288
+ h2 = yhist.shape[0]
289
+ h1 = h2 - height_target
290
+ return range(h1, h2)
291
+
292
+ def output_to_pdf(self, imgids, outname, dfimg=None):
293
+ """Exports visualizations of ground truth and model predictions to a PDF file.
294
+
295
+ Each page of the PDF contains the ground truth and model prediction for one image.
296
+
297
+ Args:
298
+ imgids (list[str]): A list of 'image_id' values to include in the PDF.
299
+ outname (str): The path and filename for the output PDF.
300
+ dfimg (pd.DataFrame, optional): A DataFrame with image statistics
301
+ to display on each page. Index should be `imgid`. Defaults to None.
302
+ """
303
+
304
+ gtstr = ""
305
+ dtstr = ""
306
+
307
+ if dfimg is not None:
308
+ gtcols = dfimg.columns[["gt_" in col for col in dfimg.columns]]
309
+ dtcols = dfimg.columns[["dt_" in col for col in dfimg.columns]]
310
+
311
+ with PdfPages(outname) as pdf:
312
+ for imgid in tqdm(imgids):
313
+ img, img_model = self.get_image(imgid)
314
+ # pdb.set_trace()
315
+ crop_range = self.height_crop_range(np.array(img.convert("L")), height_target=256 * self.scale)
316
+ img = np.array(img)[crop_range]
317
+ img_model = np.array(img_model)[crop_range]
318
+
319
+ fig, ax = plt.subplots(2, 1, figsize=[22, 10], dpi=200)
320
+ ax[0].imshow(img)
321
+ ax[0].set_title(imgid + " Ground Truth")
322
+ ax[0].set_axis_off()
323
+ ax[1].imshow(img_model)
324
+ ax[1].set_title(imgid + " Model Prediction")
325
+ ax[1].set_axis_off()
326
+ if dfimg is not None: # annotate with provided stats
327
+ gtstr = ["{:s}={:.2f}".format(col, dfimg.loc[imgid, col]) for col in gtcols]
328
+ ax[0].text(0, 0.05 * (ax[0].get_ylim()[0]), gtstr, color="white", fontsize=14)
329
+ dtstr = ["{:s}={:.2f}".format(col, dfimg.loc[imgid, col]) for col in dtcols]
330
+ ax[1].text(0, 0.05 * (ax[1].get_ylim()[0]), dtstr, color="white", fontsize=14)
331
+ pdf.savefig(fig)
332
+ plt.close(fig)
333
+
334
+ def save_imgarr_to_tiff(self, imgs, outname):
335
+ """Saves a list of PIL images to a multi-page TIFF file.
336
+
337
+ Args:
338
+ imgs (list[PIL.Image]): A list of images to save.
339
+ outname (str): The path and filename for the output TIFF.
340
+ """
341
+ if len(imgs) > 1:
342
+ imgs[0].save(outname, dpi=(400, 400), tags="", compression=None, save_all=True, append_images=imgs[1:])
343
+ else:
344
+ imgs[0].save(outname)
345
+
346
+ def output_ori_to_tiff(self, imgids, outname):
347
+ """Saves the original images for a list of IDs to a multi-page TIFF.
348
+
349
+ Args:
350
+ imgids (list[str]): A list of 'image_id' values.
351
+ outname (str): The path and filename for the output TIFF.
352
+ """
353
+ imgs = []
354
+ for imgid in tqdm(imgids):
355
+ img_ori = self.get_ori_image(imgid) # PIL Image
356
+ imgs.append(img_ori)
357
+ self.save_imgarr_to_tiff(imgs, outname)
358
+
359
+ def output_pred_to_tiff(self, imgids, outname, pred_only=False):
360
+ """Saves model prediction overlays for a list of IDs to a multi-page TIFF.
361
+
362
+ Args:
363
+ imgids (list[str]): A list of 'image_id' values.
364
+ outname (str): The path and filename for the output TIFF.
365
+ pred_only (bool, optional): If True, overlays predictions on a
366
+ black background instead of the original image. Defaults to False.
367
+ """
368
+ imgs = self.output_pred_to_list(imgids, pred_only)
369
+ self.save_imgarr_to_tiff(imgs, outname)
370
+
371
+ def output_pred_to_list(self, imgids, pred_only=False):
372
+ """Generates a list of images with model predictions overlaid.
373
+
374
+ Args:
375
+ imgids (list[str]): A list of 'image_id' values.
376
+ pred_only (bool, optional): If True, overlays predictions on a
377
+ black background. Defaults to False.
378
+
379
+ Returns:
380
+ list[PIL.Image]: A list of the generated visualization images.
381
+ """
382
+ imgs = []
383
+ for imgid in tqdm(imgids):
384
+ dat = self.get_gt_image_data(imgid) # gt
385
+ if pred_only:
386
+ im = np.zeros((dat["height"], dat["width"], 3)) # blank image for overlay
387
+ assert (
388
+ self._mode == "file"
389
+ ), 'pred_mode must be "file" when pred_only flage is set to True.' # fix this later
390
+ else:
391
+ im = cv2.imread(dat["file_name"]) # input to model
392
+ img_dt = self.produce_model_image(imgid, dat, im)
393
+ imgs.append(img_dt)
394
+ return imgs
395
+
396
+ def output_all_to_tiff(self, imgids, outname):
397
+ """Saves a combined visualization (original, GT, prediction) to a TIFF.
398
+
399
+ For each image ID, it creates a single composite image by concatenating
400
+ the original, ground truth overlay, and model prediction overlay, then
401
+ saves them to a multi-page TIFF.
402
+
403
+ Args:
404
+ imgids (list[str]): A list of 'image_id' values.
405
+ outname (str): The path and filename for the output TIFF.
406
+ """
407
+ imgs = []
408
+ for imgid in tqdm(imgids):
409
+ img_gt, img_dt = self.get_image(imgid)
410
+ img_ori = self.get_ori_image(imgid)
411
+ hcrange = list(self.height_crop_range(np.array(img_ori.convert("L")), height_target=256 * self.scale))
412
+ img_result = Image.fromarray(
413
+ np.concatenate(
414
+ (
415
+ np.array(img_ori.convert("RGB"))[hcrange, :],
416
+ np.array(img_gt)[hcrange, :],
417
+ np.array(img_dt)[hcrange],
418
+ )
419
+ )
420
+ )
421
+ imgs.append(img_result)
422
+ self.save_imgarr_to_tiff(imgs, outname)
423
+
424
+ def get_enface_dt(self, grp, scan_height, scan_width, scan_spacing):
425
+ """Generates an en-face view of model predictions for a scan volume.
426
+
427
+ Args:
428
+ grp (pd.DataFrame): DataFrame for a single scan volume, indexed by imgid.
429
+ scan_height (int): The height of a single scan image in pixels.
430
+ scan_width (int): The width of a single scan image in pixels.
431
+ scan_spacing (float): The spacing between scan centers in pixels.
432
+
433
+ Returns:
434
+ np.ndarray: An en-face image of the model predictions.
435
+ """
436
+ grp = grp.sort_index()
437
+ nscans = len(grp)
438
+ enface_height = int(np.ceil((nscans - 1) * scan_spacing))
439
+ enface = np.zeros((enface_height, scan_width, 3), dtype=int)
440
+ for i, imgid in enumerate(grp.index):
441
+ pos = int(np.clip(np.floor(scan_spacing * i), 0, scan_width - 1)) # vertical enface position
442
+
443
+ outputs = self.get_outputs_from_file(imgid, (scan_height, scan_width))
444
+ outputs = outputs[outputs.scores > self.prob_thresh]
445
+ instances = outputs.pred_boxes[:, (0, 2)].round().clip(0, scan_width - 1).to(np.int)
446
+
447
+ for inst in instances:
448
+ try:
449
+ enface[max(pos - 4, 0) : min(pos + 4, scan_width - 1), inst[0] : inst[1]] = np.array(
450
+ [255, 255, 255]
451
+ ) # random_color(rgb = True)
452
+ except IndexError:
453
+ print(pos, inst[0], inst[1])
454
+ return enface
455
+
456
+ def get_enface_gt(self, grp, scan_height, scan_width, scan_spacing):
457
+ """Generates an en-face view of ground truth annotations for a scan volume.
458
+
459
+ Args:
460
+ grp (pd.DataFrame): DataFrame for a single scan volume, indexed by imgid.
461
+ scan_height (int): The height of a single scan image in pixels.
462
+ scan_width (int): The width of a single scan image in pixels.
463
+ scan_spacing (float): The spacing between scan centers in pixels.
464
+
465
+ Returns:
466
+ np.ndarray: An en-face image of the ground truth annotations.
467
+ """
468
+ grp = grp.sort_index()
469
+ nscans = len(grp)
470
+ enface_height = int(np.ceil((nscans - 1) * scan_spacing))
471
+ enface = np.zeros((enface_height, scan_width, 3), dtype=int)
472
+ if not self.has_annotations:
473
+ enface[:, :] = np.array([100, 100, 100])
474
+
475
+ else:
476
+ # minx = scan_width
477
+ for i, imgid in enumerate(grp.index):
478
+ pos = int(np.clip(np.floor(scan_spacing * i), 0, scan_width - 1))
479
+ instances = self.get_gt_image_data(imgid)["annotations"]
480
+ for inst in instances:
481
+ x1 = inst["bbox"][0]
482
+ # minx = min(minx,x1)
483
+ x2 = x1 + inst["bbox"][2]
484
+ try:
485
+ enface[max(pos - 4, 0) : min(pos + 4, scan_width - 1), x1:x2] = np.array(
486
+ [255, 255, 255]
487
+ ) # random_color(rgb = True)
488
+ except IndexError:
489
+ print(pos, x1, x2)
490
+ return enface
491
+
492
+ def compare_enface(self, grp, name, scan_height, scan_width, scan_spacing):
493
+ """Creates a figure comparing the en-face views of predictions and ground truth.
494
+
495
+ Args:
496
+ grp (pd.DataFrame): DataFrame for a single scan volume, indexed by imgid.
497
+ name (str): The name/ID of the scan volume for the plot title.
498
+ scan_height (int): The height of a single scan image in pixels.
499
+ scan_width (int): The width of a single scan image in pixels.
500
+ scan_spacing (float): The spacing between scan centers in pixels.
501
+
502
+ Returns:
503
+ tuple[plt.Figure, np.ndarray]: A tuple containing the figure and axes objects.
504
+ """
505
+ fig, ax = plt.subplots(1, 2, figsize=[18, 9], dpi=120)
506
+
507
+ enface = self.get_enface_dt(grp, scan_height, scan_width, scan_spacing)
508
+ ax[0].imshow(enface)
509
+ ax[0].set_title(str(name) + " DT")
510
+ ax[0].set_aspect("equal")
511
+
512
+ enface = self.get_enface_gt(grp, scan_height, scan_width, scan_spacing)
513
+ ax[1].imshow(enface)
514
+ ax[1].set_title(str(name) + " GT")
515
+ ax[1].set_aspect("equal")
516
+ return fig, ax
517
+
518
+
519
+ def wilson_ci(p, n, z):
520
+ """Calculates the Wilson score interval for a binomial proportion.
521
+
522
+ Args:
523
+ p (float): The observed proportion of successes.
524
+ n (int): The total number of trials.
525
+ z (float): The z-score for the desired confidence level (e.g., 1.96 for 95%).
526
+
527
+ Returns:
528
+ tuple[float, float]: A tuple containing the lower and upper bounds of the confidence interval.
529
+ """
530
+ if p < 0 or p > 1 or n == 0:
531
+ if p < 0 or p > 1:
532
+ warnings.warn(f"The value of proportion {p} must be in the range [0,1]. Returning identity for CIs.")
533
+ else:
534
+ warnings.warn(f"The number of counts {n} must be above zero. Returning identity for CIs.")
535
+ return (p, p)
536
+ sym = z * (p * (1 - p) / n + z * z / 4 / n / n) ** 0.5
537
+ asym = p + z * z / 2 / n
538
+ fact = 1 / (1 + z * z / n)
539
+ upper = fact * (asym + sym)
540
+ lower = fact * (asym - sym)
541
+ return (lower, upper)
542
+
543
+
544
+ class EvaluateClass(COCOEvaluator):
545
+ """A custom evaluation class extending COCOEvaluator for detailed analysis."""
546
+
547
+ def __init__(self, dataset_name, output_dir, prob_thresh=0.5, iou_thresh=0.1, evalsuper=True):
548
+ """Initializes the custom evaluator.
549
+
550
+ Args:
551
+ dataset_name (str): The name of the registered Detectron2 dataset.
552
+ output_dir (str): Directory to store temporary evaluation files.
553
+ prob_thresh (float, optional): Probability threshold for calculating
554
+ precision, recall, and FPR. Defaults to 0.5.
555
+ iou_thresh (float, optional): IoU threshold for defining a true positive.
556
+ Defaults to 0.1.
557
+ evalsuper (bool, optional): If True, run the parent COCOEvaluator's
558
+ evaluate method to generate standard COCO metrics. Defaults to True.
559
+ """
560
+ super().__init__(dataset_name, tasks={"bbox", "segm"}, output_dir=output_dir)
561
+ self.dataset_name = dataset_name
562
+ self.mycoco = None # pycocotools.cocoEval instance
563
+ self.cocoDt = None
564
+ self.cocoGt = None
565
+ self.evalsuper = evalsuper # if True, run COCOEvaluator.evaluate() when self.evaluate is run
566
+ self.prob_thresh = prob_thresh  # instance probability threshold for scalar metrics (precision, recall, FPR per scan)
567
+ self.iou_thresh = iou_thresh  # IoU threshold defining a true positive for precision/recall
568
+ self.pr = None
569
+ self.rc = None
570
+ self.fpr = None
571
+
572
+ def reset(self):
573
+ """Resets the evaluator's state for a new evaluation run."""
574
+ super().reset()
575
+ self.mycoco = None
576
+
577
+ def process(self, inputs, outputs):
578
+ """Processes a batch of inputs and outputs from the model.
579
+
580
+ This method is called by the evaluation loop for each batch.
581
+
582
+ Args:
583
+ inputs (list[dict]): A list of dataset dictionaries.
584
+ outputs (list[dict]): A list of model output dictionaries.
585
+ """
586
+ super().process(inputs, outputs)
587
+
588
+ def evaluate(self):
589
+ """Runs the evaluation and calculates detailed performance metrics.
590
+
591
+ This method orchestrates the COCO evaluation, calculates precision-recall
592
+ curves, and other custom metrics.
593
+
594
+ Returns:
595
+ tuple[float, float]: The precision and recall at the specified
596
+ `prob_thresh` and `iou_thresh`.
597
+ """
598
+ if self.evalsuper:
599
+ _ = super().evaluate() # this call populates coco_instances_results.json
600
+ comm.synchronize()
601
+ if not comm.is_main_process():
602
+ return ()
603
+ self.cocoGt = COCO(
604
+ os.path.join(self._output_dir, self.dataset_name + "_coco_format.json")
605
+ ) # produced when super is initialized
606
+ self.cocoDt = self.cocoGt.loadRes(
607
+ os.path.join(self._output_dir, "coco_instances_results.json")
608
+ ) # load detector results
609
+ self.mycoco = COCOeval(self.cocoGt, self.cocoDt, iouType="segm")
610
+ self.num_images = len(self.mycoco.params.imgIds)
611
+ print("Calculated metrics for {} images".format(self.num_images))
612
+ self.mycoco.params.iouThrs = np.arange(0.10, 0.6, 0.1)
613
+ self.mycoco.params.maxDets = [100]
614
+ self.mycoco.params.areaRng = [[0, 10000000000.0]]
615
+
616
+ self.mycoco.evaluate()
617
+ self.mycoco.accumulate()
618
+
619
+ self.pr = self.mycoco.eval["precision"][
620
+ :, :, 0, 0, 0  # dims: IoU threshold, recall level, category, area range, max detections
621
+ ]
622
+ self.rc = self.mycoco.params.recThrs
623
+ self.iou = self.mycoco.params.iouThrs
624
+ self.scores = self.mycoco.eval["scores"][:, :, 0, 0, 0] # unreliable if GT has no instances
625
+ p, r = self.get_precision_recall()
626
+ return p, r
627
+
628
+ def plot_pr_curve(self, ax=None):
629
+ """Plots precision-recall curves for various IoU thresholds.
630
+
631
+ Args:
632
+ ax (plt.Axes, optional): A matplotlib axes object to plot on. If None,
633
+ a new figure and axes are created.
634
+ """
635
+ if ax is None:
636
+ fig, ax = plt.subplots(1, 1)
637
+ for i in range(len(self.iou)):
638
+ ax.plot(self.rc, self.pr[i], label="{:.2}".format(self.iou[i]))
639
+ ax.set_xlabel("Recall")
640
+ ax.set_ylabel("Precision")
641
+ ax.set_title("")
642
+ ax.legend(title="IoU")
643
+
644
+ def plot_recall_vs_prob(self):
645
+ """Plots model score thresholds versus recall for various IoU thresholds."""
646
+ plt.figure()
647
+ for i in range(len(self.iou)):
648
+ plt.plot(self.rc, self.scores[i], label="{:.2}".format(self.iou[i]))
649
+ plt.ylabel("Model probability")
650
+ plt.xlabel("Recall")
651
+ plt.legend(title="IoU")
652
+
653
+ def get_precision_recall(self):
654
+ """Gets the precision and recall for the configured IoU and probability thresholds.
655
+
656
+ Returns:
657
+ tuple[float, float]: The calculated precision and recall.
658
+ """
659
+ iou_idx, rc_idx = self._find_iou_rc_inds()
660
+ precision = self.pr[iou_idx, rc_idx]
661
+ recall = self.rc[rc_idx]
662
+ return precision, recall
663
+
664
+ def _calculate_fpr_matrix(self):
665
+ """(Private) Calculates the false positive rate matrix across all IoU and recall thresholds."""
666
+
667
+ # FP rate: a scan with no GT instances but at least one detection counts as a false positive
668
+ if (self.scores.min() == -1) and (self.scores.max() == -1):
669
+ print(
670
+ "WARNING: Scores for all iou thresholds and all recall levels are not defined. "
671
+ "This can arise if ground truth annotations contain no instances. Leaving fpr matrix as None"
672
+ )
673
+ self.fpr = None
674
+ return
675
+
676
+ fpr = np.zeros((len(self.iou), len(self.rc)))
677
+ for i in range(len(self.iou)):
678
+ for j, s in enumerate(self.scores[i]): # j -> recall level, s -> corresponding score
679
+ ng = 0 # number of negative images
680
+ fp = 0  # number of false-positive images
681
+ for el in self.mycoco.evalImgs:
682
+ if el is None: # no predictions, no gts
683
+ ng = ng + 1
684
+ elif len(el["gtIds"]) == 0: # some predictions and no gts
685
+ ng = ng + 1
686
+ if (
687
+ np.array(el["dtScores"]) > s
688
+ ).sum() > 0: # if at least one score over threshold for recall level
689
+ fp = fp + 1 # count as FP
690
+ else:
691
+ continue
692
+ fpr[i, j] = fp / ng
693
+ self.fpr = fpr
694
+
695
+ def _calculate_fpr(self):
696
+ """(Private) Calculates FPR for a single probability threshold.
697
+
698
+ This is an alternate calculation used when the main FPR matrix cannot
699
+ be computed (e.g., no positive ground truth instances).
700
+
701
+ Returns:
702
+ float: The calculated false positive rate.
703
+ """
704
+ print("Using alternate calculation for fpr at instance score threshold of {}".format(self.prob_thresh))
705
+ ng = 0 # number of negative images
706
+ fp = 0  # number of false-positive images
707
+ for el in self.mycoco.evalImgs:
708
+ if el is None: # no predictions, no gts
709
+ ng = ng + 1
710
+ elif len(el["gtIds"]) == 0: # some predictions and no gts
711
+ ng = ng + 1
712
+ if (
713
+ np.array(el["dtScores"]) > self.prob_thresh
714
+ ).sum() > 0: # if at least one score over threshold for recall level
715
+ fp = fp + 1 # count as FP
716
+ else: # gt has instances
717
+ continue
718
+ return fp / (ng + 1e-5)
719
+
720
+ def _find_iou_rc_inds(self):
721
+ """(Private) Finds the indices corresponding to the configured IoU and probability thresholds.
722
+
723
+ Returns:
724
+ tuple[int, int]: The index for the IoU threshold and the index for the recall level.
725
+ """
726
+ try:
727
+ iou_idx = np.argwhere(self.iou == self.iou_thresh)[0][0]  # first index matching iou_thresh
728
+ except IndexError:
729
+ print(
730
+ "iou threshold {} not found in mycoco.params.iouThrs {}".format(
731
+ self.iou_thresh, self.mycoco.params.iouThrs
732
+ )
733
+ )
734
+ raise SystemExit(1)
735
+ # find the last recall level whose score still meets prob_thresh
736
+ inds = np.argwhere(self.scores[iou_idx] >= self.prob_thresh)
737
+ if len(inds) > 0:
738
+ rc_idx = inds[-1][0] # get recall index corresponding to prob_thresh
739
+ else:
740
+ rc_idx = 0
741
+ return iou_idx, rc_idx
742
+
743
+ def get_fpr(self):
744
+ """Gets the false positive rate for the configured thresholds.
745
+
746
+ Returns:
747
+ float: The calculated false positive rate. Returns -1 if it cannot be computed.
748
+ """
749
+ if self.fpr is None:
750
+ self._calculate_fpr_matrix()
751
+
752
+ if self.fpr is not None:
753
+ iou_idx, rc_idx = self._find_iou_rc_inds()
754
+ fpr = self.fpr[iou_idx, rc_idx]
755
+ elif len(self.mycoco.cocoGt.anns) == 0:
756
+ fpr = self._calculate_fpr()
757
+ else:
758
+ fpr = -1
759
+ return fpr
760
+
761
+ def summarize_scalars(self): # for pretty printing
762
+ """Generates a dictionary summarizing key performance metrics with confidence intervals.
763
+
764
+ Returns:
765
+ dict: A dictionary containing precision, recall, F1-score, FPR,
766
+ and their confidence intervals.
767
+ """
768
+ p, r = self.get_precision_recall()
769
+ f1 = 2 * (p * r) / (p + r)
770
+ fpr = self.get_fpr()
771
+
772
+ # Confidence intervals
773
+ z = 1.96 # 95% Gaussian
774
+ # instance count
775
+ inst_cnt = self.count_instances()
776
+ n_r = inst_cnt["gt_instances"]
777
+ n_p = inst_cnt["dt_instances"]
778
+ n_fpr = inst_cnt["gt_neg_scans"]
779
+
780
+ def stat_ci(p, n, z):
781
+ return z * np.sqrt(p * (1 - p) / n)
782
+
783
+ r_ci = wilson_ci(r, n_r, z)
784
+ p_ci = wilson_ci(p, n_p, z)
785
+ fpr_ci = wilson_ci(fpr, n_fpr, z)
786
+
787
+ # propagate errors to F1 via first-order partials: dF1/dr = F1*(1/r - 1/(p+r)), dF1/dp = F1*(1/p - 1/(p+r))
788
+ int_r = stat_ci(r, n_r, z)
789
+ int_p = stat_ci(p, n_p, z)
790
+ int_f1 = (f1) * np.sqrt(int_r**2 * (1 / r - 1 / (p + r)) ** 2 + int_p**2 * (1 / p - 1 / (p + r)) ** 2)
791
+ f1_ci = (f1 - int_f1, f1 + int_f1)
792
+
793
+ dd = dict(
794
+ dataset=self.dataset_name,
795
+ precision=float(p),
796
+ precision_ci=p_ci,
797
+ recall=float(r),
798
+ recall_ci=r_ci,
799
+ f1=float(f1),
800
+ f1_ci=f1_ci,
801
+ fpr=float(fpr),
802
+ fpr_ci=fpr_ci,
803
+ iou=self.iou_thresh,
804
+ probability=self.prob_thresh,
805
+ )
806
+ return dd
807
+
808
+ def count_instances(self):
809
+ """Counts ground truth and detected instances across the dataset.
810
+
811
+ Returns:
812
+ dict: A dictionary with counts for 'gt_instances', 'dt_instances',
813
+ and 'gt_neg_scans' (images with no GT instances).
814
+ """
815
+ gt_inst = 0
816
+ dt_inst = 0
817
+ gt_neg_scans = 0
818
+ for _, val in self.cocoGt.imgs.items():
819
+ imgid = val["id"]
820
+ # Gt instances
821
+ annids_gt = self.cocoGt.getAnnIds([imgid])
822
+ anns_gt = self.cocoGt.loadAnns(annids_gt)
823
+ gt_inst += len(anns_gt)
824
+ if len(anns_gt) == 0:
825
+ gt_neg_scans += 1
826
+
827
+ # Dt instances
828
+ annids_dt = self.cocoDt.getAnnIds([imgid])
829
+ anns_dt = self.cocoDt.loadAnns(annids_dt)
830
+ anns_dt = [ann for ann in anns_dt if ann["score"] > self.prob_thresh]
831
+ dt_inst += len(anns_dt)
832
+
833
+ return dict(gt_instances=gt_inst, dt_instances=dt_inst, gt_neg_scans=gt_neg_scans)
834
+
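+ # Hypothetical end-to-end sketch for this evaluator (dataset name and paths are
+ # placeholders; assumes the dataset is registered with Detectron2 and that
+ # coco_instances_results.json already exists under the output directory):
+ # evaluator = EvaluateClass("my_dataset", "/path/to/output", prob_thresh=0.5, iou_thresh=0.1, evalsuper=False)
+ # precision, recall = evaluator.evaluate()
+ # print(evaluator.summarize_scalars())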
835
+
836
+ class CreatePlotsRPD:
837
+ """A class to create various plots for analyzing RPD (Reticular Pseudodrusen) data."""
838
+
839
+ def __init__(self, dfimg):
840
+ """Initializes the plotting class with image-level data.
841
+
842
+ Args:
843
+ dfimg (pd.DataFrame): A DataFrame where each row corresponds to an
844
+ image, containing counts for ground truth and detected instances
845
+ and pixels. Must include a 'volID' column.
846
+ """
847
+ self.dfimg = dfimg
848
+ self.dfvol = self.dfimg.groupby(["volID"])[
849
+ ["gt_instances", "gt_pxs", "gt_xpxs", "dt_instances", "dt_pxs", "dt_xpxs"]
850
+ ].sum()
851
+
852
+ @classmethod
853
+ def initfromcoco(cls, mycoco, prob_thresh):
854
+ """Initializes the class from a COCOeval object.
855
+
856
+ Args:
857
+ mycoco (COCOeval): An evaluated COCOeval object.
858
+ prob_thresh (float): The probability threshold to apply to detections.
859
+
860
+ Returns:
861
+ CreatePlotsRPD: An instance of the class.
862
+ """
863
+ df = pd.DataFrame(
864
+ index=mycoco.cocoGt.imgs.keys(),
865
+ columns=["gt_instances", "gt_pxs", "gt_xpxs", "dt_instances", "dt_pxs", "dt_xpxs"],
866
+ dtype=np.uint64,
867
+ )
868
+
869
+ for key, val in mycoco.cocoGt.imgs.items():
870
+ imgid = val["id"]
871
+ # Gt instances
872
+ annids_gt = mycoco.cocoGt.getAnnIds([imgid])
873
+ anns_gt = mycoco.cocoGt.loadAnns(annids_gt)
874
+ inst_gt = [mycoco.cocoGt.annToMask(ann).sum() for ann in anns_gt]
875
+ xproj_gt = [(mycoco.cocoGt.annToMask(ann).sum(axis=0) > 0).astype("uint8").sum() for ann in anns_gt]
876
+ # Dt instances
877
+ annids_dt = mycoco.cocoDt.getAnnIds([imgid])
878
+ anns_dt = mycoco.cocoDt.loadAnns(annids_dt)
879
+ anns_dt = [ann for ann in anns_dt if ann["score"] > prob_thresh]
880
+ inst_dt = [mycoco.cocoDt.annToMask(ann).sum() for ann in anns_dt]
881
+ xproj_dt = [(mycoco.cocoDt.annToMask(ann).sum(axis=0) > 0).astype("uint8").sum() for ann in anns_dt]
882
+
883
+ dat = [
884
+ len(inst_gt),
885
+ np.array(inst_gt).sum(),
886
+ np.array(xproj_gt).sum(),
887
+ len(inst_dt),
888
+ np.array(inst_dt).sum(),
889
+ np.array(xproj_dt).sum(),
890
+ ]
891
+ df.loc[key] = dat
892
+
893
+ newdf = pd.DataFrame(
894
+ [idx.rsplit(".", 1)[0].rsplit("_", 1) for idx in df.index], columns=["volID", "scan"], index=df.index
895
+ )
896
+ df = df.merge(newdf, how="inner", left_index=True, right_index=True)
897
+ return cls(df)
898
+
899
+ @classmethod
900
+ def initfromcsv(cls, fname):
901
+ """Initializes the class from a CSV file.
902
+
903
+ Args:
904
+ fname (str): The path to the CSV file.
905
+
906
+ Returns:
907
+ CreatePlotsRPD: An instance of the class.
908
+ """
909
+ df = pd.read_csv(fname)
910
+ return cls(df)
911
+
912
+ def get_max_limits(self, df):
913
+ """Calculates the maximum values for plotting limits.
914
+
915
+ Args:
916
+ df (pd.DataFrame): The DataFrame to analyze.
917
+
918
+ Returns:
919
+ tuple[int, int, int]: Max values for instances, x-pixels, and total pixels.
920
+ """
921
+ max_inst = np.max([df.gt_instances.max(), df.dt_instances.max()])
922
+ max_xpxs = np.max([df.gt_xpxs.max(), df.dt_xpxs.max()])
923
+ max_pxs = np.max([df.gt_pxs.max(), df.dt_pxs.max()])
924
+ # print('Max instances:',max_inst)
925
+ # print('Max xpxs:',max_xpxs)
926
+ # print('Max pxs:',max_pxs)
927
+ return max_inst, max_xpxs, max_pxs
928
+
929
+ def vol_level_prc(self, df, gt_thresh=5, ax=None):
930
+ """Plots a volume-level precision-recall curve.
931
+
932
+ Args:
933
+ df (pd.DataFrame): DataFrame with volume-level statistics.
934
+ gt_thresh (int, optional): The minimum number of ground truth
935
+ instances for a volume to be considered positive. Defaults to 5.
936
+ ax (plt.Axes, optional): Axes to plot on. Defaults to None.
937
+
938
+ Returns:
939
+ tuple[float, tuple]: The average precision and the PR curve data.
940
+ """
941
+ prc = precision_recall_curve(df.gt_instances >= gt_thresh, df.dt_instances)
942
+ if ax is None:
943
+ fig, ax = plt.subplots(1, 1)
944
+ ax.plot(prc[1], prc[0])
945
+ ax.set_xlabel("RPD Volume Recall")
946
+ ax.set_ylabel("RPD Volume Precision")
947
+
948
+ ap = average_precision_score(df.gt_instances >= gt_thresh, df.dt_instances)
949
+ return ap, prc
950
+
951
+ def plot_img_level_instance_thresholding(self, df, inst):
952
+ """Plots P/R/FPR as a function of the instance count threshold.
953
+
954
+ Args:
955
+ df (pd.DataFrame): DataFrame with image-level statistics.
956
+ inst (list[int]): A list of instance count thresholds to evaluate.
957
+
958
+ Returns:
959
+ tuple[np.ndarray, np.ndarray, np.ndarray]: Arrays for precision,
960
+ recall, and FPR at each threshold.
961
+ """
962
+ rc = np.zeros((len(inst),))
963
+ pr = np.zeros((len(inst),))
964
+ fpr = np.zeros((len(inst),))
965
+
966
+ fig, ax = plt.subplots(1, 3, figsize=[15, 5])
967
+ for i, dt_thresh in enumerate(inst):
968
+ gt = df.gt_instances > dt_thresh
969
+ dt = df.dt_instances > dt_thresh
970
+ rc[i] = (gt & dt).sum() / gt.sum()
971
+ pr[i] = (gt & dt).sum() / dt.sum()
972
+ fpr[i] = ((~gt) & (dt)).sum() / ((~gt).sum())
973
+
974
+ ax[1].plot(inst, pr)
975
+ ax[1].set_ylim(0.45, 1.01)
976
+ ax[1].set_xlabel("instance threshold")
977
+ ax[1].set_ylabel("Precision")
978
+
979
+ ax[0].plot(inst, rc)
980
+ ax[0].set_ylim(0.45, 1.01)
981
+ ax[0].set_ylabel("Recall")
982
+ ax[0].set_xlabel("instance threshold")
983
+
984
+ ax[2].plot(inst, fpr)
985
+ ax[2].set_ylim(0, 0.80)
986
+ ax[2].set_xlabel("instance threshold")
987
+ ax[2].set_ylabel("FPR")
988
+
989
+ plt.tight_layout()
990
+ return pr, rc, fpr
991
+
992
+ def plot_img_level_instance_thresholding2(self, df, inst, gt_thresh, plot=True):
993
+ """Plots P/R/FPR vs. instance threshold with confidence intervals.
994
+
995
+ Args:
996
+ df (pd.DataFrame): DataFrame with image-level statistics.
997
+ inst (list[int]): A list of instance count thresholds to evaluate.
998
+ gt_thresh (int): The ground truth instance threshold.
999
+ plot (bool, optional): Whether to generate a plot. Defaults to True.
1000
+
1001
+ Returns:
1002
+ dict: A dictionary containing arrays for P/R/FPR and their CIs.
1003
+ """
1004
+
1005
+ rc = np.zeros((len(inst),))
1006
+ pr = np.zeros((len(inst),))
1007
+ fpr = np.zeros((len(inst),))
1008
+ rc_ci = np.zeros((len(inst), 2))
1009
+ pr_ci = np.zeros((len(inst), 2))
1010
+ fpr_ci = np.zeros((len(inst), 2))
1011
+
1012
+ for i, dt_thresh in enumerate(inst):
1013
+ gt = df.gt_instances >= gt_thresh
1014
+ dt = df.dt_instances >= dt_thresh
1015
+ rc[i] = (gt & dt).sum() / gt.sum()
1016
+ pr[i] = (gt & dt).sum() / dt.sum()
1017
+ fpr[i] = ((~gt) & (dt)).sum() / ((~gt).sum())
1018
+ rc_ci[i, :] = wilson_ci(rc[i], gt.sum(), 1.96)
1019
+ pr_ci[i, :] = wilson_ci(pr[i], dt.sum(), 1.96)
1020
+ fpr_ci[i, :] = wilson_ci(fpr[i], ((~gt).sum()), 1.96)
1021
+
1022
+ if plot:
1023
+ fig, ax = plt.subplots(1, 3, figsize=[15, 5])
1024
+ # ax[0].plot(rc,pr)
1025
+ # ax[0].set_xlabel('Recall')
1026
+ # ax[0].set_ylabel('Precision')
1027
+
1028
+ ax[1].plot(inst, pr)
1029
+ ax[1].fill_between(inst, pr_ci[:, 0], pr_ci[:, 1], alpha=0.25)
1030
+ # ax[1].set_ylim(0.45,1.01)
1031
+ ax[1].set_xlabel("instance threshold")
1032
+ ax[1].set_ylabel("Precision")
1033
+
1034
+ ax[0].plot(inst, rc)
1035
+ ax[0].fill_between(inst, rc_ci[:, 0], rc_ci[:, 1], alpha=0.25)
1036
+ # ax[0].set_ylim(0.45,1.01)
1037
+ ax[0].set_ylabel("Recall")
1038
+ ax[0].set_xlabel("instance threshold")
1039
+
1040
+ ax[2].plot(inst, fpr)
1041
+ ax[2].fill_between(inst, fpr_ci[:, 0], fpr_ci[:, 1], alpha=0.25)
1042
+ # ax[2].set_ylim(0,0.80)
1043
+ ax[2].set_xlabel("instance threshold")
1044
+ ax[2].set_ylabel("FPR")
1045
+
1046
+ plt.tight_layout()
1047
+ return dict(precision=pr, precision_ci=pr_ci, recall=rc, recall_ci=rc_ci, fpr=fpr, fpr_ci=fpr_ci)
1048
+
1049
+ def gt_vs_dt_instances(self, ax=None):
1050
+ """Plots mean detected instances vs. ground truth instances with error bars.
1051
+
1052
+ Args:
1053
+ ax (plt.Axes, optional): Axes to plot on. Defaults to None.
1054
+
1055
+ Returns:
1056
+ plt.Axes: The axes object with the plot.
1057
+ """
1058
+ df = self.dfimg
1059
+ max_inst, max_xpxs, max_pxs = self.get_max_limits(df)
1060
+ idx = (df.gt_instances > 0) & (df.dt_instances > 0)
1061
+
1062
+ if ax is None:
1063
+ fig = plt.figure(dpi=100)
1064
+ ax = fig.add_subplot(111)
1065
+
1066
+ y = df[idx].groupby("gt_instances")["dt_instances"].mean()
1067
+ yerr = df[idx].groupby("gt_instances")["dt_instances"].std()
1068
+ ax.errorbar(y.index, y.values, yerr.values, fmt="*")
1069
+ plt.plot([0, max_inst], [0, max_inst], alpha=0.5)
1070
+ plt.xlim(0, max_inst + 1)
1071
+ plt.ylim(0, max_inst + 1)
1072
+ ax.set_aspect(1)
1073
+ plt.xlabel("gt_instances")
1074
+ plt.ylabel("dt_instances")
1075
+ plt.tight_layout()
1076
+ return ax
1077
+
1078
+ def gt_vs_dt_instances_boxplot(self, ax=None):
1079
+ """Creates a boxplot of detected instances for each ground truth instance count.
1080
+
1081
+ Args:
1082
+ ax (plt.Axes, optional): Axes to plot on. Defaults to None.
1083
+
1084
+ Returns:
1085
+ plt.Axes: The axes object with the plot.
1086
+ """
1087
+ df = self.dfimg
1088
+ max_inst, max_xpxs, max_pxs = self.get_max_limits(df)
1089
+ max_inst = int(max_inst)
1090
+ if ax is None:
1091
+ fig = plt.figure(dpi=100)
1092
+ ax = fig.add_subplot(111)
1093
+
1094
+ ax.plot([0, max_inst + 1], [0, max_inst + 1], alpha=0.5)
1095
+ x = df["gt_instances"].values.astype(int)
1096
+ y = df["dt_instances"].values.astype(int)
1097
+ sns.boxplot(x=x, y=y, ax=ax, width=0.5)  # keyword arguments required by recent seaborn versions
1098
+ ax.set_xbound(0, max_inst + 1)
1099
+ ax.set_ybound(0, max_inst + 1)
1100
+ ax.set_aspect("equal")
1101
+
1102
+ ax.set_title("")
1103
+ ax.set_xlabel("gt_instances")
1104
+ ax.set_ylabel("dt_instances")
1105
+
1106
+ import matplotlib.ticker as pltticker
1107
+
1108
+ loc = pltticker.MultipleLocator(base=2.0)
1109
+ ax.xaxis.set_major_locator(loc)
1110
+ ax.yaxis.set_major_locator(loc)
1111
+
1112
+ return ax
1113
+
1114
+ def gt_vs_dt_xpxs(self):
1115
+ """Creates scatter plots comparing ground truth and detected x-pixels.
1116
+
1117
+ Returns:
1118
+ tuple[plt.Figure, plt.Figure, plt.Figure]: Figure handles for the three generated plots.
1119
+ """
1120
+ df = self.dfimg
1121
+ max_inst, max_xpxs, max_pxs = self.get_max_limits(df)
1122
+ idx = (df.gt_instances > 0) & (df.dt_instances > 0)
1123
+ dfsub = df[idx]
1124
+
1125
+ fig1 = plt.figure(figsize=[10, 10], dpi=100)
1126
+ ax = fig1.add_subplot(111)
1127
+ sc = ax.scatter(dfsub["gt_xpxs"], dfsub["dt_xpxs"], c=dfsub["gt_instances"], cmap="viridis")
1128
+ ax.set_aspect(1)
1129
+ # ax = dfsub.plot(kind = 'scatter',x=,y=,c='gt_instances')
1130
+ plt.plot([0, max_xpxs], [0, max_xpxs], alpha=0.5)
1131
+ plt.xlim(0, max_xpxs)
1132
+ plt.ylim(0, max_xpxs)
1133
+ plt.xlabel("gt_xpxs")
1134
+ plt.ylabel("dt_xpxs")
1135
+ cbar = plt.colorbar(sc)
1136
+ cbar.ax.set_ylabel("gt_instances")
1137
+ plt.tight_layout()
1138
+
1139
+ fig2 = plt.figure(figsize=[10, 10], dpi=100)
1140
+ ax = fig2.add_subplot(111)
1141
+ sc = ax.scatter(dfsub["gt_xpxs"], dfsub["gt_xpxs"] - dfsub["dt_xpxs"], c=dfsub["gt_instances"], cmap="viridis")
1142
+ # ax = dfsub.plot(kind = 'scatter',x=,y=,c='gt_instances')
1143
+ plt.plot([0, max_xpxs], [0, 0], alpha=0.5)
1144
+ plt.xlabel("gt_xpxs")
1145
+ plt.ylabel("gt_xpxs-dt_xpxs")
1146
+ cbar = plt.colorbar(sc)
1147
+ cbar.ax.set_ylabel("gt_instances")
1148
+ plt.tight_layout()
1149
+
1150
+ fig3 = plt.figure(dpi=100)
1151
+ plt.hist(dfsub["gt_xpxs"] - dfsub["dt_xpxs"])
1152
+ plt.xlabel("gt_xpxs - dt_xpxs")
1153
+ plt.ylabel("B-scans")
1154
+
1155
+ return fig1, fig2, fig3
1156
+
1157
+ def gt_vs_dt_xpxs_mu(self):
1158
+ """Plots binned means of detected vs. ground truth x-pixels.
1159
+
1160
+ Returns:
1161
+ plt.Figure: The figure handle for the plot.
1162
+ """
1163
+ df = self.dfimg
1164
+ max_inst, max_xpxs, max_pxs = self.get_max_limits(df)
1165
+ idx = (df.gt_instances > 0) & (df.dt_instances > 0)
1166
+ dfsub = df[idx]
1167
+
1168
+ from scipy import stats
1169
+
1170
+ mu_dt, bins, bnum = stats.binned_statistic(dfsub["gt_xpxs"], dfsub["dt_xpxs"], statistic="mean", bins=10)
1171
+ std_dt, _, _ = stats.binned_statistic(dfsub["gt_xpxs"], dfsub["dt_xpxs"], statistic="std", bins=bins)
1172
+ mu_gt, _, _ = stats.binned_statistic(dfsub["gt_xpxs"], dfsub["gt_xpxs"], statistic="mean", bins=bins)
1173
+ std_gt, _, _ = stats.binned_statistic(dfsub["gt_xpxs"], dfsub["gt_xpxs"], statistic="std", bins=bins)
1174
+ fig = plt.figure(dpi=100)
1175
+ plt.errorbar(mu_gt, mu_dt, yerr=std_dt, xerr=std_gt, fmt="*")
1176
+ plt.xlabel("gt_xpxs")
1177
+ plt.ylabel("dt_xpxs")
1178
+ plt.plot([0, max_xpxs], [0, max_xpxs], alpha=0.5)
1179
+ plt.xlim(0, max_xpxs)
1180
+ plt.ylim(0, max_xpxs)
1181
+ plt.gca().set_aspect(1)
1182
+ plt.tight_layout()
1183
+ return fig
1184
+
1185
+ def gt_dt_fp_fn_count(self):
1186
+ """Plots histograms of false positive and false negative instance counts.
1187
+
1188
+ Returns:
1189
+ plt.Figure: The figure handle for the plot.
1190
+ """
1191
+ df = self.dfimg
1192
+ fig, ax = plt.subplots(1, 2, figsize=[10, 5])
1193
+
1194
+ idx = (df.gt_instances == 0) & (df.dt_instances > 0)
1195
+ ax[0].hist(df[idx]["dt_instances"], bins=range(1, 10))
1196
+ ax[0].set_xlabel("dt instances")
1197
+ ax[0].set_ylabel("B-scans")
1198
+ ax[0].set_title("FP dt instance count per B-scan")
1199
+
1200
+ idx = (df.gt_instances > 0) & (df.dt_instances == 0)
1201
+ ax[1].hist(df[idx]["gt_instances"], bins=range(1, 10))
1202
+ ax[1].set_xlabel("gt instances")
1203
+ ax[1].set_ylabel("B-scans")
1204
+ ax[1].set_title("FN gt instance count per B-scan")
1205
+
1206
+ plt.tight_layout()
1207
+ return fig
1208
+
1209
+ def avg_inst_size(self):
1210
+ """Plots histograms of the average instance size in pixels.
1211
+
1212
+ Compares the average size (in both total pixels and x-axis projection)
1213
+ between ground truth and detected instances.
1214
+
1215
+ Returns:
1216
+ plt.Figure: The figure handle for the plot.
1217
+ """
1218
+ df = self.dfimg
1219
+ max_inst, max_xpxs, max_pxs = self.get_max_limits(df)
1220
+ idx = (df.gt_instances > 0) & (df.dt_instances > 0)
1221
+ dfsub = df[idx]
1222
+
1223
+ fig = plt.figure(figsize=[10, 5])
1224
+ plt.subplot(121)
1225
+ bins = np.arange(0, 120, 10)
1226
+ ax = (dfsub.gt_xpxs / dfsub.gt_instances).hist(bins=bins, alpha=0.5, label="gt")
1227
+ ax = (dfsub.dt_xpxs / dfsub.dt_instances).hist(bins=bins, alpha=0.5, label="dt")
1228
+ ax.set_xlabel("xpxs")
1229
+ ax.set_ylabel("B-scans")
1230
+ ax.set_title("Average size of instance")
1231
+ ax.legend()
1232
+
1233
+ plt.subplot(122)
1234
+ bins = np.arange(0, 600, 40)
1235
+ ax = (dfsub.gt_pxs / dfsub.gt_instances).hist(bins=bins, alpha=0.5, label="gt")
1236
+ ax = (dfsub.dt_pxs / dfsub.dt_instances).hist(bins=bins, alpha=0.5, label="dt")
1237
+ ax.set_xlabel("pxs")
1238
+ ax.set_ylabel("B-scans")
1239
+ ax.set_title("Average size of instance")
1240
+ ax.legend()
1241
+
1242
+ plt.tight_layout()
1243
+ return fig
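+ 
+ 
+ # Hypothetical usage sketch (assumes `myeval` is an EvaluateClass instance on
+ # which evaluate() has already been run):
+ # plots = CreatePlotsRPD.initfromcoco(myeval.mycoco, myeval.prob_thresh)
+ # ap, prc = plots.vol_level_prc(plots.dfvol, gt_thresh=5)
+ # fig = plots.gt_dt_fp_fn_count()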
scripts/datasets/__init__.py ADDED
File without changes
scripts/datasets/data.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
+ import glob
3
+ import os
4
+ import shutil
5
+
6
+ import cv2
7
+ import pandas as pd
8
+ from PIL import Image
9
+ from pydicom import dcmread
10
+ from pydicom.fileset import FileSet
11
+ from tqdm import tqdm
12
+
13
+ from .volReader import VolFile
14
+
15
+ script_dir = os.path.dirname(__file__)
16
+
17
+
18
+ class Error(Exception):
19
+ """Base class for exceptions in this module."""
20
+
21
+ pass
22
+
23
+
24
+ def extract_files(dirtoextract, extracted_path, input_format):
25
+ """Extracts individual image frames from .vol or DICOM files.
26
+
27
+ This function scans a directory for source files of a specified format
28
+ and extracts them into a structured output directory as PNG images.
29
+ It handles both .vol files and standard DICOM files. If the
30
+ output directory already contains files, it will prompt the user
31
+ before proceeding to overwrite them.
32
+
33
+ Args:
34
+ dirtoextract (str): The root directory to search for source files.
35
+ extracted_path (str): The destination directory where the extracted
36
+ PNG images will be saved.
37
+ input_format (str): The format of the input files. Must be either
38
+ "vol" or "dicom".
39
+ """
40
+ assert input_format in ["vol", "dicom"], 'Error: input_format must be "vol" or "dicom".'
41
+ proceed = True
42
+ if (os.path.isdir(extracted_path)) and (len(os.listdir(extracted_path)) != 0):
43
+ val = input(
44
+ f"{extracted_path} exists and is not empty. Files may be overwritten. Proceed with extraction? (Y/N)"
45
+ )
46
+ proceed = val.strip().lower() in ("y", "yes", "t", "true", "on", "1")  # distutils was removed in Python 3.12
47
+ if proceed:
48
+ print(f"Extracting files from {dirtoextract} into {extracted_path}...")
49
+ if input_format == "vol":
50
+ files_to_extract = glob.glob(os.path.join(dirtoextract, "**/*.vol"), recursive=True)
51
+ for line in tqdm(files_to_extract):
52
+ fpath = line.strip("\n")
53
+ vol = VolFile(fpath)
54
+ fpath = fpath.replace("\\", "/")
55
+ path, scan_str = fpath.strip(".vol").rsplit("/", 1)
56
+ extractpath = os.path.join(extracted_path, scan_str.replace("_", "/"))
57
+ os.makedirs(extractpath, exist_ok=True)
58
+ preffix = os.path.join(extractpath, scan_str + "_oct")
59
+ vol.render_oct_scans(prefix)
60
+ elif input_format == "dicom":
61
+ keywords = ["SOPInstanceUID", "PatientID", "ImageLaterality", "SeriesDate"]
62
+ list_of_dicts = []
63
+ dirgen = glob.iglob(os.path.join(dirtoextract, "**/DICOMDIR"), recursive=True)
64
+
65
+ for dsstr in dirgen:
66
+ fs = FileSet(dcmread(dsstr))
67
+ fsgenopt = gen_opt_fs(fs)
68
+ for fi in tqdm(fsgenopt):
69
+ dd = dict()
70
+ # top level keywords
71
+ for key in keywords:
72
+ dd[key] = fi.get(key)
73
+
74
+ volpath = os.path.join(extracted_path, f"{fi.SOPInstanceUID}")
75
+ shutil.rmtree(volpath, ignore_errors=True)
76
+ os.mkdir(volpath)
77
+ n = fi.NumberOfFrames
78
+ for i in range(n):
79
+ fname = os.path.join(volpath, f"{fi.SOPInstanceUID}_oct_{i:03d}.png")
80
+ Image.fromarray(fi.pixel_array[i]).save(fname)
81
+ list_of_dicts.append(dd.copy())
82
+ dfoct = pd.DataFrame(list_of_dicts, columns=keywords)
83
+ dfoct.to_csv(os.path.join(extracted_path, "basic_meta.csv"))
84
+ else:
85
+ pass
86
+
87
+
88
+ def rpd_data(extracted_path):
89
+ """Generates a dataset list from a directory of extracted image files.
90
+
91
+ Scans a directory recursively for PNG images and creates a list of
92
+ dictionaries, one for each image. This format is designed to be compatible
93
+ with Detectron2's `DatasetCatalog` and can be adapted to hold ground truth instances for evaluation.
94
+
95
+ Args:
96
+ extracted_path (str): The root directory containing the extracted
97
+ .png image files to be included in the dataset.
98
+
99
+ Returns:
100
+ list[dict]: A list where each dictionary represents an image and
101
+ contains its file path, dimensions, and a unique ID.
102
+ """
103
+ dataset = []
104
+ extracted_files = glob.glob(os.path.join(extracted_path, "**/*.[Pp][Nn][Gg]"), recursive=True)
105
+ print("Generating dataset of images...")
106
+ for fn in tqdm(extracted_files):
107
+ fn_adjusted = fn.replace("\\", "/")
108
+ imageid = fn_adjusted.split("/")[-1]
109
+ im = cv2.imread(fn)
110
+ dat = dict(file_name=fn_adjusted, height=im.shape[0], width=im.shape[1], image_id=imageid)
111
+ dataset.append(dat)
112
+ print(f"Found {len(dataset)} images")
113
+ return dataset
114
+
115
+
116
+ def gen_opt_fs(fs):
117
+ """A generator for finding and loading OPT modality DICOM datasets.
118
+
119
+ This function filters a pydicom `FileSet` object for instances that have
120
+ the modality set to "OPT" (Ophthalmic Tomography) and yields each one
121
+ as a fully loaded pydicom dataset.
122
+
123
+ Args:
124
+ fs (pydicom.fileset.FileSet): The pydicom FileSet to search through.
125
+
126
+ Yields:
127
+ pydicom.dataset.FileDataset: A loaded DICOM dataset for each instance
128
+ with the "OPT" modality found in the FileSet.
129
+ """
130
+ for instance in fs.find(Modality="OPT"):
131
+ ds = instance.load()
132
+ yield ds
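+ 
+ 
+ # Hypothetical follow-on sketch: turn extracted PNGs into a Detectron2-style
+ # dataset list (path is a placeholder):
+ # dataset = rpd_data("/path/to/extracted")
+ # print(dataset[0]["file_name"], dataset[0]["height"], dataset[0]["width"])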
scripts/datasets/volReader.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Aaron Y. Lee MD MSCI (University of Washington) Copyright 2019
2
+ #
3
+ # Code ported from Markus Mayer's excellent work (https://www5.cs.fau.de/research/software/octseg/)
4
+ #
5
+ # Also thanks to those who contributed to the original openVol.m in Markus's project
6
+ # Radim Kolar, Brno University, Czech Republic
7
+ # Kris Sheets, Retinal Cell Biology Lab, Neuroscience Center of Excellence, LSU Health Sciences Center, New Orleans
8
+
9
+
10
+ import array
11
+ import codecs
12
+ import datetime
13
+ import struct
14
+ from collections import OrderedDict
15
+
16
+ import numpy as np
17
+
18
+
19
+ class VolFile:
20
+ def __init__(self, filename):
21
+ """
22
+ Parses Heyex Spectralis *.vol files.
23
+
24
+ Args:
25
+ filename (str): Path to vol file
26
+
27
+ Returns:
28
+ VolFile instance
29
+
30
+ """
31
+ self.__parse_volfile(filename)
32
+
33
+ @property
34
+ def oct(self):
35
+ """
36
+ Retrieve OCT volume as a 3D numpy array.
37
+
38
+ Returns:
39
+ 3D numpy array with OCT intensities as 'uint8' array
40
+
41
+ """
42
+ return self.wholefile["cScan"]
43
+
44
+ @property
45
+ def irslo(self):
46
+ """
47
+ Retrieve IR SLO image as 2D numpy array
48
+
49
+ Returns:
50
+ 2D numpy array with IR reflectance SLO image as 'uint8' array.
51
+
52
+ """
53
+ return self.wholefile["sloImage"]
54
+
55
+ @property
56
+ def grid(self):
57
+ """
58
+ Retrieve the IR SLO pixel coordinates for the B scan OCT slices
59
+
60
+ Returns:
61
+ 2D numpy array with the number of b scan images in the first dimension
62
+ and x_0, y_0, x_1, y_1 defining the line of the B scan on the pixel
63
+ coordinates of the IR SLO image.
64
+
65
+ """
66
+ wf = self.wholefile
67
+ grid = []
68
+ for bi in range(len(wf["slice-headers"])):
69
+ bscan_head = wf["slice-headers"][bi]
70
+ x_0 = int(bscan_head["startX"] / wf["header"]["scaleXSlo"])
71
+ x_1 = int(bscan_head["endX"] / wf["header"]["scaleXSlo"])
72
+ y_0 = int(bscan_head["startY"] / wf["header"]["scaleYSlo"])
73
+ y_1 = int(bscan_head["endY"] / wf["header"]["scaleYSlo"])
74
+ grid.append([x_0, y_0, x_1, y_1])
75
+ return grid
76
+
77
+ def render_ir_slo(self, filename, render_grid=False):
78
+ """
79
+ Renders IR SLO image as a PNG file and optionally overlays grid of B scans
80
+
81
+ Args:
82
+ filename (str): filename to save IR SLO image
83
+ render_grid (bool): True will render red lines for the location of the B scans.
84
+
85
+ Returns:
86
+ None
87
+
88
+ """
89
+ from PIL import Image, ImageDraw
90
+
91
+ wf = self.wholefile
92
+ a = np.copy(wf["sloImage"])
93
+ if render_grid:
94
+ a = np.stack((a,) * 3, axis=-1)
95
+ a = Image.fromarray(a)
96
+ draw = ImageDraw.Draw(a)
97
+ grid = self.grid
98
+ for x_0, y_0, x_1, y_1 in grid:
99
+ draw.line((x_0, y_0, x_1, y_1), fill=(255, 0, 0), width=3)
100
+ a.save(filename)
101
+ else:
102
+ Image.fromarray(a).save(filename)
103
+
104
+ def render_oct_scans(self, filepre="oct", render_seg=False):
105
+ """
106
+ Renders OCT images as PNG files and optionally overlays segmentation lines.
107
+ (A CSV of vol-file features, vols.csv, is also written when the file is parsed.)
108
+
109
+ Args:
110
+ filepre (str): filename prefix. OCT images will be named "<filepre>_000.png", "<filepre>_001.png", etc.
111
+ render_seg (bool): True will render colored lines for the segmentation of the RPE, ILM, and NFL on the B scans.
112
+
113
+ Returns:
114
+ None
115
+
116
+ """
117
+ from PIL import Image
118
+
119
+ wf = self.wholefile
120
+ for i in range(wf["cScan"].shape[0]):
121
+ a = np.copy(wf["cScan"][i])
122
+ if render_seg:
123
+ a = np.stack((a,) * 3, axis=-1)
124
+ for li in range(wf["segmentations"].shape[0]):
125
+ for x in range(wf["segmentations"].shape[2]):
126
+ a[int(wf["segmentations"][li, i, x]), x, li] = 255
127
+
128
+ Image.fromarray(a).save("%s_%03d.png" % (filepre, i))
129
+
130
+ def __parse_volfile(self, fn, parse_seg=False):
131
+ print(fn)
132
+ wholefile = OrderedDict()
133
+ decode_hex = codecs.getdecoder("hex_codec")
134
+ with open(fn, "rb") as fin:
135
+ header = OrderedDict()
136
+ header["version"] = fin.read(12)
137
+ header["octSizeX"] = struct.unpack("I", fin.read(4))[0] # lateral resolution
138
+ header["numBscan"] = struct.unpack("I", fin.read(4))[0]
139
+ header["octSizeZ"] = struct.unpack("I", fin.read(4))[0] # OCT depth
140
+ header["scaleX"] = struct.unpack("d", fin.read(8))[0]
141
+ header["distance"] = struct.unpack("d", fin.read(8))[0]
142
+ header["scaleZ"] = struct.unpack("d", fin.read(8))[0]
143
+ header["sizeXSlo"] = struct.unpack("I", fin.read(4))[0]
144
+ header["sizeYSlo"] = struct.unpack("I", fin.read(4))[0]
145
+ header["scaleXSlo"] = struct.unpack("d", fin.read(8))[0]
146
+ header["scaleYSlo"] = struct.unpack("d", fin.read(8))[0]
147
+ header["fieldSizeSlo"] = struct.unpack("I", fin.read(4))[0] # FOV in degrees
148
+ header["scanFocus"] = struct.unpack("d", fin.read(8))[0]
149
+ header["scanPos"] = fin.read(4)
150
+ header["examTime"] = struct.unpack("=q", fin.read(8))[0] / 1e7
151
+ header["examTime"] = datetime.datetime.utcfromtimestamp(
152
+ header["examTime"] - (369 * 365.25 + 4) * 24 * 60 * 60
153
+ ) # needs to be checked
154
+ header["scanPattern"] = struct.unpack("I", fin.read(4))[0]
155
+ header["BscanHdrSize"] = struct.unpack("I", fin.read(4))[0]
156
+ header["ID"] = fin.read(16)
157
+ header["ReferenceID"] = fin.read(16)
158
+ header["PID"] = struct.unpack("I", fin.read(4))[0]
159
+ header["PatientID"] = fin.read(21)
160
+ header["unknown2"] = fin.read(3)
161
+ header["DOB"] = struct.unpack("d", fin.read(8))[0] - 25569
162
+ header["DOB"] = datetime.datetime.utcfromtimestamp(0) + datetime.timedelta(
163
+ seconds=header["DOB"] * 24 * 60 * 60
164
+ ) # needs to be checked
165
+ header["VID"] = struct.unpack("I", fin.read(4))[0]
166
+ header["VisitID"] = fin.read(24)
167
+ header["VisitDate"] = struct.unpack("d", fin.read(8))[0] - 25569
168
+ header["VisitDate"] = datetime.datetime.utcfromtimestamp(0) + datetime.timedelta(
169
+ seconds=header["VisitDate"] * 24 * 60 * 60
170
+ ) # needs to be checked
171
+ header["GridType"] = struct.unpack("I", fin.read(4))[0]
172
+ header["GridOffset"] = struct.unpack("I", fin.read(4))[0]
173
+
174
+ wholefile["header"] = header
175
+ fin.seek(2048)
176
+ u = array.array("B")
177
+ u.frombytes(fin.read(header["sizeXSlo"] * header["sizeYSlo"]))
178
+ u = np.array(u).astype("uint8").reshape((header["sizeXSlo"], header["sizeYSlo"]))
179
+ wholefile["sloImage"] = u
180
+
181
+ slo_offset = 2048 + header["sizeXSlo"] * header["sizeYSlo"]
182
+ oct_offset = header["BscanHdrSize"] + header["octSizeX"] * header["octSizeZ"] * 4
183
+ bscans = []
184
+ bscanheaders = []
185
+ bscanqualities = []
186
+ if parse_seg:
187
+ segmentations = None
188
+ for i in range(header["numBscan"]):
189
+ fin.seek(16 + slo_offset + i * oct_offset)
190
+ bscan_head = OrderedDict()
191
+ bscan_head["startX"] = struct.unpack("d", fin.read(8))[0]
192
+ bscan_head["startY"] = struct.unpack("d", fin.read(8))[0]
193
+ bscan_head["endX"] = struct.unpack("d", fin.read(8))[0]
194
+ bscan_head["endY"] = struct.unpack("d", fin.read(8))[0]
195
+ bscan_head["numSeg"] = struct.unpack("I", fin.read(4))[0]
196
+ bscan_head["offSeg"] = struct.unpack("I", fin.read(4))[0]
197
+ bscan_head["quality"] = struct.unpack("f", fin.read(4))[0]
198
+ bscan_head["shift"] = struct.unpack("I", fin.read(4))[0]
199
+ bscanheaders.append(bscan_head)
200
+ bscanqualities.append(bscan_head["quality"])
201
+
202
+ # extract OCT B scan data
203
+ fin.seek(header["BscanHdrSize"] + slo_offset + i * oct_offset)
204
+ u = array.array("f")
205
+ u.frombytes(fin.read(4 * header["octSizeX"] * header["octSizeZ"]))
206
+ u = np.array(u).reshape((header["octSizeZ"], header["octSizeX"]))
207
+ # remove out of boundary
208
+ v = struct.unpack("f", decode_hex("FFFF7F7F")[0])
209
+ u[u == v] = 0
210
+ # log normalize
211
+ u = np.log(10000 * u + 1)
212
+ u = (255.0 * (np.clip(u, 0, np.max(u)) / np.max(u))).astype("uint8")
213
+ bscans.append(u)
214
+ if parse_seg:
215
+ # extract OCT segmentations data
216
+ fin.seek(256 + slo_offset + i * oct_offset)
217
+ u = array.array("f")
218
+ u.frombytes(fin.read(4 * header["octSizeX"] * bscan_head["numSeg"]))
219
+ u = np.array(u)
220
+ print(u.shape)
221
+ u[u == v] = 0.0
222
+ if segmentations is None:
223
+ segmentations = []
224
+ for _ in range(bscan_head["numSeg"]):
225
+ segmentations.append([])
226
+
227
+ for j in range(bscan_head["numSeg"]):
228
+ segmentations[j].append(u[j * header["octSizeX"] : (j + 1) * header["octSizeX"]].tolist())
229
+ wholefile["cScan"] = np.array(bscans)
230
+ if parse_seg:
231
+ wholefile["segmentations"] = np.array(segmentations)
232
+ wholefile["slice-headers"] = bscanheaders
233
+ wholefile["average-quality"] = np.mean(bscanqualities)
234
+ self.wholefile = wholefile
235
+ import csv
236
+ from pathlib import Path, PurePath
237
+
238
+ vol_features = [
239
+ PurePath(fn).name,
240
+ wholefile["header"]["version"].decode("utf-8").rstrip("\x00"),
241
+ wholefile["header"]["numBscan"],
242
+ wholefile["header"]["octSizeX"],
243
+ wholefile["header"]["octSizeZ"],
244
+ wholefile["header"]["distance"],
245
+ wholefile["header"]["scaleX"],
246
+ wholefile["header"]["scaleZ"],
247
+ wholefile["header"]["sizeXSlo"],
248
+ wholefile["header"]["sizeYSlo"],
249
+ wholefile["header"]["scaleXSlo"],
250
+ wholefile["header"]["scaleYSlo"],
251
+ wholefile["header"]["fieldSizeSlo"],
252
+ wholefile["header"]["scanFocus"],
253
+ wholefile["header"]["scanPos"].decode("utf-8").rstrip("\x00"),
254
+ wholefile["header"]["examTime"],
255
+ wholefile["header"]["scanPattern"],
256
+ wholefile["header"]["BscanHdrSize"],
257
+ wholefile["header"]["ID"].decode("utf-8").rstrip("\x00"),
258
+ wholefile["header"]["ReferenceID"].decode("utf-8").rstrip("\x00"),
259
+ wholefile["header"]["PID"],
260
+ wholefile["header"]["PatientID"].decode("utf-8").rstrip("\x00"),
261
+ wholefile["header"]["DOB"],
262
+ wholefile["header"]["VID"],
263
+ wholefile["header"]["VisitID"].decode("utf-8").rstrip("\x00"),
264
+ wholefile["header"]["VisitDate"],
265
+ wholefile["header"]["GridType"],
266
+ wholefile["header"]["GridOffset"],
267
+ wholefile["average-quality"],
268
+ ]
269
+ output_dir = PurePath(fn).parent
270
+ output_csv = output_dir.joinpath("vols.csv")
271
+ if not Path(output_csv).exists():
272
+ print("Creating vols.csv as it does not exist.")
273
+ with open(output_csv, "w", newline="") as file:
274
+ writer = csv.writer(file)
275
+ writer.writerow(
276
+ [
277
+ "filename",
278
+ "version",
279
+ "numBscan",
280
+ "octSizeX",
281
+ "octSizeZ",
282
+ "distance",
283
+ "scaleX",
284
+ "scaleZ",
285
+ "sizeXSlo",
286
+ "sizeYSlo",
287
+ "scaleXSlo",
288
+ "scaleYSlo",
289
+ "fieldSizeSlo",
290
+ "scanFocus",
291
+ "scanPos",
292
+ "examTime",
293
+ "scanPattern",
294
+ "BscanHdrSize",
295
+ "ID",
296
+ "ReferenceID",
297
+ "PID",
298
+ "PatientID",
299
+ "DOB",
300
+ "VID",
301
+ "VisitID",
302
+ "VisitDate",
303
+ "GridType",
304
+ "GridOffset",
305
+ "Average Quality",
306
+ ]
307
+ )
308
+ with open(output_csv, "r", newline="") as file:
309
+ existing_vols = csv.reader(file)
310
+ for vol in existing_vols:
311
+ if vol[0] == PurePath(fn).name:
312
+ print("Skipping,", PurePath(fn).name, "already present in vols.csv.")
313
+ return
314
+ with open(output_csv, "a", newline="") as file:
315
+ print("Adding", PurePath(fn).name, "to vols.csv.")
316
+ writer = csv.writer(file)
317
+ writer.writerow(vol_features)
318
+
319
+ @property
320
+ def file_header(self):
321
+ """
322
+ Retrieve vol header fields
323
+
324
+ Returns:
325
+ Dictionary with the following keys
326
+ - version: version number of vol file definition
327
+ - numBscan: number of B scan images in the volume
328
+ - octSizeX: number of pixels in the width of the OCT B scan
329
+ - octSizeZ: number of pixels in the height of the OCT B scan
330
+ - distance: unknown
331
+ - scaleX: resolution scaling factor of the width of the OCT B scan
332
+ - scaleZ: resolution scaling factor of the height of the OCT B scan
333
+ - sizeXSlo: number of pixels in the width of the IR SLO image
334
+ - sizeYSlo: number of pixels in the height of the IR SLO image
335
+ - scaleXSlo: resolution scaling factor of the width of the IR SLO image
336
+ - scaleYSlo: resolution scaling factor of the height of the IR SLO image
337
+ - fieldSizeSlo: field of view (FOV) of the retina in degrees
338
+ - scanFocus: unknown
339
+ - scanPos: Left or Right eye scanned
340
+ - examTime: Datetime of the scan (needs to be checked)
341
+ - scanPattern: unknown
342
+ - BscanHdrSize: size of B scan header in bytes
343
+ - ID: unknown
344
+ - ReferenceID
345
+ - PID: unknown
346
+ - PatientID: Patient ID string
347
+ - DOB: Date of birth
348
+ - VID: unknown
349
+ - VisitID: Visit ID string
350
+ - VisitDate: Datetime of visit (needs to be checked)
351
+ - GridType: unknown
352
+ - GridOffset: unknown
353
+
354
+ """
355
+ return self.wholefile["header"]
356
+
357
+ def bscan_header(self, slicei):
358
+ """
359
+ Retrieve the B Scan header information per slice.
360
+
361
+ Args:
362
+ slicei (int): index of B scan
363
+
364
+ Returns:
365
+ Dictionary with the following keys
366
+ - startX: x-coordinate for B scan on IR. (see getGrid)
367
+ - startY: y-coordinate for B scan on IR. (see getGrid)
368
+ - endX: x-coordinate for B scan on IR. (see getGrid)
369
+ - endY: y-coordinate for B scan on IR. (see getGrid)
370
+ - numSeg: 2 or 3 segmentation lines for the B scan
371
+ - quality: OCT signal quality
372
+ - shift: unknown
373
+
374
+ """
375
+ return self.wholefile["slice-headers"][slicei]
376
+
377
+ def save_grid(self, outfn):
378
+ """
379
+ Saves the grid coordinates mapping OCT Bscans to the IR SLO image to a text file. The text file
380
+ will be a tab-delimited file with 5 columns: The bscan number, x_0, y_0, x_1, y_1 in pixel space
381
+ scaled to the resolution of the IR SLO image.
382
+
383
+ Args:
384
+ outfn (str): location of where to output the file
385
+
386
+ Returns:
387
+ None
388
+
389
+ """
390
+ grid = self.grid
391
+ with open(outfn, "w") as fout:
392
+ fout.write("bscan\tx_0\ty_0\tx_1\ty_1\n")
393
+ ri = 0
394
+ for r in grid:
395
+ r = [ri] + r
396
+ fout.write("%s\n" % "\t".join(map(str, r)))
397
+ ri += 1
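+ 
+ 
+ # Minimal usage sketch for VolFile (path is a placeholder):
+ # vol = VolFile("/path/to/scan.vol")
+ # print(vol.file_header["numBscan"], vol.oct.shape)
+ # vol.render_ir_slo("slo.png", render_grid=True)  # IR SLO with B-scan locations
+ # vol.render_oct_scans("scan_oct")  # writes scan_oct_000.png, scan_oct_001.png, ...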
scripts/inference.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import pickle
5
+
6
+ import pandas as pd
7
+ import progressbar
8
+ from detectron2.checkpoint import DetectionCheckpointer
9
+ from detectron2.config import get_cfg
10
+ from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_test_loader
11
+ from detectron2.evaluation import COCOEvaluator, inference_on_dataset
12
+ from detectron2.modeling import build_model
13
+
14
+ from .analysis_lib import CreatePlotsRPD, EvaluateClass, OutputVis, grab_dataset
15
+ from .datasets import data
16
+ from .Ensembler import Ensembler
17
+ from .table_styles import styles
18
+
19
+ # Change directory to the script's location to ensure relative paths work correctly.
20
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
21
+
22
+
23
+ logging.basicConfig(level=logging.INFO)
24
+
25
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
26
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
27
+
28
+ dpi = 120
29
+
30
+
31
+ class MyProgressBar:
32
+ # https://stackoverflow.com/a/53643011/3826929
33
+ # George C
34
+ def __init__(self):
35
+ self.pbar = None
36
+
37
+ def __call__(self, block_num, block_size, total_size):
38
+ if not self.pbar:
39
+ self.pbar = progressbar.ProgressBar(maxval=total_size)
40
+ self.pbar.start()
41
+
42
+ downloaded = block_num * block_size
43
+ if downloaded < total_size:
44
+ self.pbar.update(downloaded)
45
+ else:
46
+ self.pbar.finish()
47
+
48
+
49
+ def create_dataset(dataset_name, extracted_path):
50
+ """Creates a pickled dataset file from a directory of extracted images.
51
+
52
+ This function scans the `extracted_path` for images, formats them into a
53
+ list of dictionaries compatible with Detectron2, and saves the list as a
54
+ pickle file.
55
+
56
+ Args:
57
+ dataset_name (str): The name for the dataset, used for the output .pk file.
58
+ extracted_path (str): The directory containing the extracted image files.
59
+ """
60
+ stored_data = data.rpd_data(extracted_path)
61
+ with open(os.path.join(data.script_dir, f"{dataset_name}.pk"), "wb") as f:
+ pickle.dump(stored_data, f)
62
+
63
+
64
+ def configure_model():
65
+ """Loads and returns the model configuration from a YAML file.
66
+
67
+ It reads a 'working.yaml' file located in the same directory as the script
68
+ to set up the Detectron2 configuration.
69
+
70
+ Returns:
71
+ detectron2.config.CfgNode: The configuration object for the model.
72
+ """
73
+ cfg = get_cfg()
74
+ moddir = os.path.dirname(os.path.realpath(__file__))
75
+ name = "working.yaml"
76
+ cfg_path = os.path.join(moddir, name)
77
+ cfg.merge_from_file(cfg_path)
78
+ return cfg
79
+
80
+
81
+ def register_dataset(dataset_name):
82
+ """Registers a dataset with Detectron2's DatasetCatalog.
83
+
84
+ This makes the dataset available to be loaded by Detectron2's data loaders.
85
+ It sets the class metadata to 'rpd'.
86
+
87
+ Args:
88
+ dataset_name (str): The name under which to register the dataset.
89
+ """
90
+ for name in [dataset_name]:
91
+ try:
92
+ DatasetCatalog.register(name, grab_dataset(name))
93
+ except AssertionError as e:
94
+ print(f"Assertion failed: {e}. Already registered.")
95
+ MetadataCatalog.get(name).thing_classes = ["rpd"]
96
+
97
+
98
+ def run_prediction(cfg, dataset_name, output_path):
99
+ """Runs inference on a dataset using a cross-validation ensemble of models.
100
+
101
+ It loads five different model weight files (fold1 to fold5), runs inference
102
+ for each model on the specified dataset, and saves the predictions in
103
+ separate subdirectories within `output_path`.
104
+
105
+ Args:
106
+ cfg (CfgNode): The model configuration object.
107
+ dataset_name (str): The name of the registered dataset to run inference on.
108
+ output_path (str): The base directory to save prediction outputs.
109
+ """
110
+ model = build_model(cfg) # returns a torch.nn.Module
111
+ myloader = build_detection_test_loader(cfg, dataset_name)
112
+ myeval = COCOEvaluator(
113
+ dataset_name, tasks={"bbox", "segm"}, output_dir=output_path
114
+ ) # produces _coco_format.json when initialized
115
+ for mdl in ("fold1", "fold2", "fold3", "fold4", "fold5"):
116
+ extract_directory = "../models"
117
+ file_name = mdl + "_model_final.pth"
118
+ model_weights_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), extract_directory, file_name)
119
+ print(model_weights_path)
120
+ DetectionCheckpointer(model).load(model_weights_path) # load a file, usually from cfg.MODEL.WEIGHTS
121
+ model.eval() # set model in evaluation mode
122
+ myeval.reset()
123
+ output_dir = os.path.join(output_path, mdl)
124
+ myeval._output_dir = output_dir
125
+ print("Running inference with model ", mdl)
126
+ _ = inference_on_dataset(
127
+ model, myloader, myeval
128
+ )  # produces coco_instances_results.json when myeval.evaluate is called
129
+ print("Done with predictions!")
130
+
131
+
132
+ def run_ensemble(dataset_name, output_path, iou_thresh=0.2):
133
+ """Ensembles predictions from multiple models using NMS.
134
+
135
+ It initializes an `Ensembler`, runs the non-maximum suppression logic, and
136
+ saves the final combined predictions to a single COCO results file.
137
+
138
+ Args:
139
+ dataset_name (str): The name of the dataset.
140
+ output_path (str): The base directory containing the individual model
141
+ prediction subdirectories.
142
+ iou_thresh (float, optional): The IoU threshold for ensembling. Defaults to 0.2.
143
+
144
+ Returns:
145
+ Ensembler: The ensembler instance after running NMS.
146
+ """
147
+ ens = Ensembler(output_path, dataset_name, ["fold1", "fold2", "fold3", "fold4", "fold5"], iou_thresh=iou_thresh)
148
+ ens.mean_score_nms()
149
+ ens.save_coco_instances()
150
+ return ens
151
+
152
+
153
+ def evaluate_dataset(dataset_name, output_path, iou_thresh=0.2, prob_thresh=0.5):
154
+ """Evaluates the final ensembled predictions against ground truth.
155
+
156
+ It uses the custom `EvaluateClass` to calculate performance metrics and saves
157
+ a summary to a JSON file.
158
+
159
+ Args:
160
+ dataset_name (str): The name of the dataset.
161
+ output_path (str): The directory containing the ensembled predictions file.
162
+ iou_thresh (float, optional): The IoU threshold for evaluation. Defaults to 0.2.
163
+ prob_thresh (float, optional): The probability threshold for evaluation. Defaults to 0.5.
164
+
165
+ Returns:
166
+ EvaluateClass: The evaluation object containing detailed metrics.
167
+ """
168
+ myeval = EvaluateClass(dataset_name, output_path, iou_thresh=iou_thresh, prob_thresh=prob_thresh, evalsuper=False)
169
+ myeval.evaluate()
170
+ with open(os.path.join(output_path, "scalar_dict.json"), "w") as outfile:
171
+ json.dump(obj=myeval.summarize_scalars(), fp=outfile)
172
+ return myeval
173
+
174
+
175
+ def create_table(myeval):
176
+ """Creates a DataFrame of per-image statistics from evaluation results.
177
+
178
+ Args:
179
+ myeval (EvaluateClass): The evaluation object containing COCO results.
180
+
181
+ Returns:
182
+ CreatePlotsRPD: An object containing DataFrames for image and volume stats.
183
+ """
184
+ dataset_table = CreatePlotsRPD.initfromcoco(myeval.mycoco, myeval.prob_thresh)
185
+ dataset_table.dfimg.sort_index(inplace=True)
186
+ return dataset_table
187
+ # dataset_table.dfimg['scan'] = dataset_table.dfimg['scan'].astype('int') #depends on what we want scan field to be
188
+
189
+
190
+ def output_vol_predictions(dataset_table, vis, volid, output_path, output_mode="pred_overlay"):
191
+ """Generates and saves visualization TIFFs for a single scan volume.
192
+
193
+ Args:
194
+ dataset_table (CreatePlotsRPD): Object containing the image/volume stats.
195
+ vis (OutputVis): The visualization object.
196
+ volid (str): The ID of the volume to visualize.
197
+ output_path (str): The directory to save the output TIFF file.
198
+ output_mode (str, optional): The type of visualization to create.
199
+ Options: "pred_overlay", "pred_only", "originals", "all".
200
+ Defaults to "pred_overlay".
201
+ """
202
+ dfimg = dataset_table.dfimg
203
+ imgids = dfimg[dfimg["volID"] == volid].sort_index().index.values
204
+ outname = os.path.join(output_path, f"{volid}.tiff")
205
+ if output_mode == "pred_overlay":
206
+ vis.output_pred_to_tiff(imgids, outname, pred_only=False)
207
+ elif output_mode == "pred_only":
208
+ vis.output_pred_to_tiff(imgids, outname, pred_only=True)
209
+ elif output_mode == "originals":
210
+ vis.output_ori_to_tiff(imgids, outname)
211
+ elif output_mode == "all":
212
+ vis.output_all_to_tiff(imgids, outname)
213
+ else:
214
+ print(f"Invalid mode {output_mode} for function output_vol_predictions.")
215
+
216
+
217
+ def output_dataset_predictions(dataset_table, vis, output_path, output_mode="pred_overlay", draw_mode="default"):
218
+ """Generates and saves visualization TIFFs for all volumes in a dataset.
219
+
220
+ Args:
221
+ dataset_table (CreatePlotsRPD): Object containing the image/volume stats.
222
+ vis (OutputVis): The visualization object.
223
+ output_path (str): The base directory to save the output TIFF files.
224
+ output_mode (str, optional): The type of visualization to create.
225
+ Defaults to "pred_overlay".
226
+ draw_mode (str, optional): The drawing style ("default" or "bw").
227
+ Defaults to "default".
228
+ """
229
+ vis.set_draw_mode(draw_mode)
230
+ os.makedirs(output_path, exist_ok=True)
231
+ for volid in dataset_table.dfvol.index:
232
+ output_vol_predictions(dataset_table, vis, volid, output_path, output_mode)
233
+
234
+
235
+ def create_dfvol(dataset_name, output_path, dataset_table):
236
+ """Creates and saves a styled HTML table of volume-level statistics.
237
+
238
+ Args:
239
+ dataset_name (str): The name of the dataset.
240
+ output_path (str): The directory to save the HTML file.
241
+ dataset_table (CreatePlotsRPD): Object containing the volume DataFrame.
242
+ """
243
+ dfvol = dataset_table.dfvol.sort_values(by=["dt_instances"], ascending=False)
244
+ with pd.option_context("styler.render.max_elements", int(dfvol.size) + 1):
245
+ html_str = dfvol.style.format("{:.0f}").set_table_styles(styles).to_html()
246
+ html_file = open(os.path.join(output_path, "dfvol_" + dataset_name + ".html"), "w")
247
+ html_file.write(html_str)
249
+
250
+
251
+ def create_dfimg(dataset_name, output_path, dataset_table):
252
+ """Creates and saves a styled HTML table of image-level statistics.
253
+
254
+ Args:
255
+ dataset_name (str): The name of the dataset.
256
+ output_path (str): The directory to save the HTML file.
257
+ dataset_table (CreatePlotsRPD): Object containing the image DataFrame.
258
+ """
259
+ dfimg = dataset_table.dfimg.sort_index()
260
+ with pd.option_context("styler.render.max_elements", int(dfimg.size) + 1):
261
+ html_str = dfimg.style.set_table_styles(styles).to_html()
262
+ html_file = open(os.path.join(output_path, "dfimg_" + dataset_name + ".html"), "w")
263
+ html_file.write(html_str)
264
+ html_file.close()
265
+
266
+
267
+ def main(args):
268
+ """Main function to orchestrate the end-to-end analysis pipeline.
269
+
270
+ This function controls the flow from data extraction to evaluation and
271
+ visualization based on the provided arguments.
272
+
273
+ Args:
274
+ args (dict): A dictionary of command-line arguments and flags
275
+ controlling the pipeline execution.
276
+ """
277
+ print(f"Received arguments: {args}")
278
+
279
+ # Unpack arguments from the dictionary with default values
280
+ dataset_name = args.get("dataset_name")
281
+ input_dir = args.get("input_dir")
282
+ extracted_dir = args.get("extracted_dir")
283
+ input_format = args.get("input_format")
284
+ output_dir = args.get("output_dir")
285
+ run_extract = args.get("run_extract", True)
286
+ make_dataset = args.get("create_dataset", True)
287
+ run_inference = args.get("run_inference", True)
288
+ prob_thresh = args.get("prob_thresh", 0.5)
289
+ iou_thresh = args.get("iou_thresh", 0.2)
290
+ create_tables = args.get("create_tables", True)
291
+
292
+ # Visualization flags
293
+ bm = args.get("binary_mask", False)
294
+ bmo = args.get("binary_mask_overlay", False)
295
+ imo = args.get("instance_mask_overlay", False)
296
+ make_visuals = bm or bmo or imo
297
+
298
+ # --- Pipeline Steps ---
299
+ if run_extract:
300
+ os.makedirs(extracted_dir, exist_ok=True)
301
+ print("Starting file extraction...")
302
+ data.extract_files(input_dir, extracted_dir, input_format)
303
+ print("Image extraction complete!")
304
+ if make_dataset:
305
+ print("Creating dataset from extracted images...")
306
+ create_dataset(dataset_name, extracted_dir)
307
+ if run_inference:
308
+ print("Configuring model...")
309
+ cfg = configure_model()
310
+ print("Registering dataset...")
311
+ register_dataset(dataset_name)
312
+ os.makedirs(output_dir, exist_ok=True)
313
+ print("Running inference...")
314
+ run_prediction(cfg, dataset_name, output_dir)
315
+ print("Inference complete, running ensemble...")
316
+ run_ensemble(dataset_name, output_dir, iou_thresh)
317
+ print("Ensemble complete!")
318
+ if create_tables or make_visuals:
319
+ print("Registering dataset for evaluation...")
320
+ register_dataset(dataset_name)
321
+ print("Evaluating dataset...")
322
+ eval_obj = evaluate_dataset(dataset_name, output_dir, iou_thresh, prob_thresh)
323
+ print("Creating dataset table...")
324
+ table = create_table(eval_obj)
325
+ if create_tables:
326
+ create_dfvol(dataset_name, output_dir, table)
327
+ create_dfimg(dataset_name, output_dir, table)
328
+ print("Dataset HTML tables complete!")
329
+ if make_visuals:
330
+ print("Initializing visualizer...")
331
+ vis = OutputVis(
332
+ dataset_name,
333
+ prob_thresh=eval_obj.prob_thresh,
334
+ pred_mode="file",
335
+ pred_file=os.path.join(output_dir, "coco_instances_results.json"),
336
+ has_annotations=False, # Assuming we are visualizing on test data without GT
337
+ )
338
+ vis.scale = 1.0 # Use original scale for output visuals
339
+ if bm:
340
+ print("Creating binary masks TIFF (no overlay)...")
341
+ vis.annotation_color = "w"
342
+ output_dataset_predictions(
343
+ table, vis, os.path.join(output_dir, "predicted_binary_masks"), "pred_only", "bw"
344
+ )
345
+ if bmo:
346
+ print("Creating binary masks TIFF (with overlay)...")
347
+ output_dataset_predictions(
348
+ table, vis, os.path.join(output_dir, "predicted_binary_overlays"), "pred_overlay", "bw"
349
+ )
350
+ if imo:
351
+ print("Creating instance masks TIFF (with overlay)...")
352
+ output_dataset_predictions(
353
+ table, vis, os.path.join(output_dir, "predicted_instance_overlays"), "pred_overlay", "default"
354
+ )
355
+ print("Visualizations complete!")
scripts/mask_rcnn_X_101_32x8d_FPN_1x.yaml ADDED
@@ -0,0 +1,56 @@
+ _BASE_: "Base-RCNN-FPN.yaml"
+ MODEL:
+   WEIGHTS: "detectron2://ImageNetPretrained/FAIR/X-101-32x8d.pkl"
+   PIXEL_STD: [57.375, 57.120, 58.395]
+   MASK_ON: True
+   RESNETS:
+     STRIDE_IN_1X1: False  # this is a C2 model
+     NUM_GROUPS: 32
+     WIDTH_PER_GROUP: 8
+     DEPTH: 101
+   ROI_HEADS:
+     NUM_CLASSES: 1
+     SCORE_THRESH_TEST: 0.001
+     NMS_THRESH_TEST: 0.01
+ INPUT:
+   MIN_SIZE_TRAIN: (496,)
+   MIN_SIZE_TEST: 496
+ SOLVER:
+   BASE_LR: 0.02
+   # GAMMA: 0.05
+   # STEPS: (3000, 7000, 11000, 15000)
+   # MAX_ITER: 18000
+   GAMMA: 0.1
+   STEPS: (3000, 4500)
+   MAX_ITER: 6000
+   CHECKPOINT_PERIOD: 300
+   IMS_PER_BATCH: 14
+ TEST:
+   DETECTIONS_PER_IMAGE: 30  # LVIS allows up to 300
+   EVAL_PERIOD: 300
+ DATALOADER:
+   SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
+   REPEAT_THRESHOLD: 0.001
+   NUM_WORKERS: 4
+ # Cross-validation fold configurations, kept for reference:
+ # DATASETS:
+ #   TRAIN: ("fold1","fold2","fold3","fold4",)
+ #   TEST: ("fold5",)
+ # OUTPUT_DIR: "./output_valid_fold5"
+ # DATASETS:
+ #   TRAIN: ("fold2","fold3","fold4","fold5",)
+ #   TEST: ("fold1",)
+ # OUTPUT_DIR: "./output_valid_fold1"
+ # DATASETS:
+ #   TRAIN: ("fold3","fold4","fold5","fold1",)
+ #   TEST: ("fold2",)
+ # OUTPUT_DIR: "./output_valid_fold2"
+ # DATASETS:
+ #   TRAIN: ("fold4","fold5","fold1","fold2",)
+ #   TEST: ("fold3",)
+ # OUTPUT_DIR: "./output_valid_fold3"
+ # DATASETS:
+ #   TRAIN: ("fold5","fold1","fold2","fold3",)
+ #   TEST: ("fold4",)
+ # OUTPUT_DIR: "./output_valid_fold4"
+
+ # modifying to commit again
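
A short sketch of how a config like this could be loaded with detectron2's config API. It assumes detectron2 is installed, the path is relative to the repo root, and Base-RCNN-FPN.yaml sits next to this file (merge_from_file resolves _BASE_ relative to the config's location):

    # Hypothetical config loading; not part of the bundle's own scripts.
    from detectron2.config import get_cfg

    cfg = get_cfg()
    cfg.merge_from_file("scripts/mask_rcnn_X_101_32x8d_FPN_1x.yaml")
    print(cfg.MODEL.RESNETS.DEPTH)          # 101
    print(cfg.MODEL.ROI_HEADS.NUM_CLASSES)  # 1
    # Values can still be overridden programmatically before building the model
    cfg.TEST.DETECTIONS_PER_IMAGE = 50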
scripts/table_styles.py ADDED
@@ -0,0 +1,32 @@
+ def hover(hover_color="#add8e6"):
+     return dict(selector="tbody tr:hover", props=[("background-color", hover_color)])
+
+
+ styles = [
+     # table properties
+     dict(
+         selector=" ",
+         props=[
+             ("margin", "0"),
+             ("font-family", '"Helvetica", "Arial", sans-serif'),
+             ("border-collapse", "collapse"),
+             ("border", "none"),
+             ("border", "2px solid #ccf"),
+         ],
+     ),
+     # # header color - optional
+     # dict(selector="thead",
+     #      props=[("background-color", "#cc8484")
+     #      ]),
+     # background shading
+     dict(selector="tbody tr:nth-child(even)", props=[("background-color", "#fff")]),
+     dict(selector="tbody tr:nth-child(odd)", props=[("background-color", "#eee")]),
+     # cell spacing
+     dict(selector="td", props=[("padding", ".5em"), ("text-align", "center")]),
+     # header cell properties
+     dict(selector="th", props=[("font-size", "125%"), ("text-align", "center")]),
+     # caption placement
+     dict(selector="caption", props=[("caption-side", "bottom")]),
+     # render hover last to override background-color
+     hover(),
+ ]
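
A small usage sketch for the styles list, matching the set_table_styles call the inference script makes; the DataFrame here is a made-up placeholder:

    # Hypothetical example; assumes scripts.table_styles is importable.
    import pandas as pd
    from scripts.table_styles import styles

    df = pd.DataFrame({"dt_instances": [12, 7]}, index=["vol_a", "vol_b"])
    html = df.style.set_table_styles(styles).to_html()
    with open("example_table.html", "w") as f:
        f.write(html)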
scripts/working.yaml ADDED
@@ -0,0 +1,56 @@
+ _BASE_: "Base-RCNN-FPN.yaml"
+ MODEL:
+   WEIGHTS: "detectron2://ImageNetPretrained/FAIR/X-101-32x8d.pkl"
+   PIXEL_STD: [57.375, 57.120, 58.395]
+   MASK_ON: True
+   RESNETS:
+     STRIDE_IN_1X1: False  # this is a C2 model
+     NUM_GROUPS: 32
+     WIDTH_PER_GROUP: 8
+     DEPTH: 101
+   ROI_HEADS:
+     NUM_CLASSES: 1
+     SCORE_THRESH_TEST: 0.001
+     NMS_THRESH_TEST: 0.01
+ INPUT:
+   MIN_SIZE_TRAIN: (496,)
+   MIN_SIZE_TEST: 496
+ SOLVER:
+   BASE_LR: 0.02
+   # GAMMA: 0.05
+   # STEPS: (3000, 7000, 11000, 15000)
+   # MAX_ITER: 18000
+   GAMMA: 0.1
+   STEPS: (3000, 4500)
+   MAX_ITER: 6000
+   CHECKPOINT_PERIOD: 300
+   IMS_PER_BATCH: 14
+ TEST:
+   DETECTIONS_PER_IMAGE: 30  # LVIS allows up to 300
+   EVAL_PERIOD: 300
+ DATALOADER:
+   SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
+   REPEAT_THRESHOLD: 0.001
+   NUM_WORKERS: 4
+ # Cross-validation fold configurations, kept for reference:
+ # DATASETS:
+ #   TRAIN: ("fold1","fold2","fold3","fold4",)
+ #   TEST: ("fold5",)
+ # OUTPUT_DIR: "./output_valid_fold5"
+ # DATASETS:
+ #   TRAIN: ("fold2","fold3","fold4","fold5",)
+ #   TEST: ("fold1",)
+ # OUTPUT_DIR: "./output_valid_fold1"
+ # DATASETS:
+ #   TRAIN: ("fold3","fold4","fold5","fold1",)
+ #   TEST: ("fold2",)
+ # OUTPUT_DIR: "./output_valid_fold2"
+ # DATASETS:
+ #   TRAIN: ("fold4","fold5","fold1","fold2",)
+ #   TEST: ("fold3",)
+ # OUTPUT_DIR: "./output_valid_fold3"
+ # DATASETS:
+ #   TRAIN: ("fold5","fold1","fold2","fold3",)
+ #   TEST: ("fold4",)
+ # OUTPUT_DIR: "./output_valid_fold4"
+
+ # modifying to commit again
scripts/ybpres.mplstyle ADDED
@@ -0,0 +1,6 @@
+ axes.titlesize : 16
+ axes.labelsize : 16
+ lines.linewidth : 2
+ lines.markersize : 6
+ xtick.labelsize : 15
+ ytick.labelsize : 15
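
A quick sketch of applying this style sheet with matplotlib; the relative path assumes the script runs from the repo root:

    # Hypothetical example; plt.style.use accepts a path to an .mplstyle file.
    import matplotlib.pyplot as plt

    plt.style.use("scripts/ybpres.mplstyle")  # larger fonts and markers for presentation figures
    fig, ax = plt.subplots()
    ax.plot([0, 1, 2], [0, 1, 4], marker="o")
    ax.set_title("Style check")
    fig.savefig("style_check.png")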