Image Segmentation
medical
biology
Jose commited on
Commit
1b052a1
·
1 Parent(s): b90f95b

new inference utilities

Browse files
README.md CHANGED
@@ -6,7 +6,7 @@ tags:
6
  - biology
7
  ---
8
 
9
- ## VascX models
10
 
11
  This repository contains the instructions for using the VascX models from the paper [VascX Models: Model Ensembles for Retinal Vascular Analysis from Color Fundus Images](https://arxiv.org/abs/2409.16016).
12
 
@@ -18,7 +18,7 @@ The model weights are in [huggingface](https://huggingface.co/Eyened/vascx).
18
 
19
  <img src="imgs/HRF_04_g_rgb.png" width="240" height="240" style="display:inline"><img src="imgs/HRF_04_g.png" width="240" height="240" style="display:inline">
20
 
21
- ### Installation
22
 
23
  To install the entire fundus analysis pipeline including fundus preprocessing, model inference code and vascular biomarker extraction:
24
 
@@ -26,8 +26,116 @@ To install the entire fundus analysis pipeline including fundus preprocessing, m
26
 
27
  2. Install the [rtnls_inference package](https://github.com/Eyened/retinalysis-inference).
28
 
 
 
 
 
 
29
  ### Usage
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  To speed up re-execution of vascx we recommend to run the preprocessing and segmentation steps separately:
32
 
33
  1. Preprocessing. See [this notebook](./notebooks/0_preprocess.ipynb). This step is CPU-heavy and benefits from parallelization (see notebook).
 
6
  - biology
7
  ---
8
 
9
+ # VascX models
10
 
11
  This repository contains the instructions for using the VascX models from the paper [VascX Models: Model Ensembles for Retinal Vascular Analysis from Color Fundus Images](https://arxiv.org/abs/2409.16016).
12
 
 
18
 
19
  <img src="imgs/HRF_04_g_rgb.png" width="240" height="240" style="display:inline"><img src="imgs/HRF_04_g.png" width="240" height="240" style="display:inline">
20
 
21
+ ## Installation
22
 
23
  To install the entire fundus analysis pipeline including fundus preprocessing, model inference code and vascular biomarker extraction:
24
 
 
26
 
27
  2. Install the [rtnls_inference package](https://github.com/Eyened/retinalysis-inference).
28
 
29
+
30
+ ## `vascx run` Command
31
+
32
+ The `run` command provides a comprehensive pipeline for processing fundus images, performing various analyses, and creating visualizations.
33
+
34
  ### Usage
35
 
36
+ ```bash
37
+ vascx run DATA_PATH OUTPUT_PATH [OPTIONS]
38
+ ```
39
+
40
+ ### Arguments
41
+
42
+ - `DATA_PATH`: Path to input data. Can be either:
43
+ - A directory containing fundus images
44
+ - A CSV file with a 'path' column containing paths to images
45
+
46
+ - `OUTPUT_PATH`: Directory where processed results will be stored
47
+
48
+ ### Options
49
+
50
+ | Option | Default | Description |
51
+ |--------|---------|-------------|
52
+ | `--preprocess/--no-preprocess` | `--preprocess` | Run preprocessing to standardize images for model input |
53
+ | `--vessels/--no-vessels` | `--vessels` | Run vessel segmentation and artery-vein classification |
54
+ | `--disc/--no-disc` | `--disc` | Run optic disc segmentation |
55
+ | `--quality/--no-quality` | `--quality` | Run image quality assessment |
56
+ | `--fovea/--no-fovea` | `--fovea` | Run fovea detection |
57
+ | `--overlay/--no-overlay` | `--overlay` | Create visualization overlays combining all results |
58
+ | `--n_jobs` | `4` | Number of preprocessing workers for parallel processing |
59
+
60
+ ### Output Structure
61
+
62
+ When run with default options, the command creates the following structure in `OUTPUT_PATH`:
63
+
64
+ ```
65
+ OUTPUT_PATH/
66
+ ├── preprocessed_rgb/ # Standardized fundus images
67
+ ├── vessels/ # Vessel segmentation results
68
+ ├── artery_vein/ # Artery-vein classification
69
+ ├── disc/ # Optic disc segmentation
70
+ ├── overlays/ # Visualization images
71
+ ├── bounds.csv # Image boundary information
72
+ ├── quality.csv # Image quality scores
73
+ └── fovea.csv # Fovea coordinates
74
+ ```
75
+
76
+ ### Processing Stages
77
+
78
+ 1. **Preprocessing**:
79
+ - Standardizes input images for consistent analysis
80
+ - Outputs preprocessed images and boundary information
81
+
82
+ 2. **Quality Assessment**:
83
+ - Evaluates image quality with three quality metrics (q1, q2, q3)
84
+ - Higher scores indicate better image quality
85
+
86
+ 3. **Vessel Segmentation and Artery-Vein Classification**:
87
+ - Identifies blood vessels in the retina
88
+ - Classifies vessels as arteries (1) or veins (2) with intersections (3)
89
+
90
+ 4. **Optic Disc Segmentation**:
91
+ - Identifies the optic disc location and boundaries
92
+
93
+ 5. **Fovea Detection**:
94
+ - Determines the coordinates of the fovea (center of vision)
95
+
96
+ 6. **Visualization Overlays**:
97
+ - Creates color-coded images showing:
98
+ - Arteries in red
99
+ - Veins in blue
100
+ - Optic disc in white
101
+ - Fovea marked with yellow X
102
+
103
+ ### Examples
104
+
105
+ **Process a directory of images with all analyses:**
106
+ ```bash
107
+ vascx run /path/to/images /path/to/output
108
+ ```
109
+
110
+ **Process specific images listed in a CSV:**
111
+ ```bash
112
+ vascx run /path/to/image_list.csv /path/to/output
113
+ ```
114
+
115
+ **Only run preprocessing and vessel segmentation:**
116
+ ```bash
117
+ vascx run /path/to/images /path/to/output --no-disc --no-quality --no-fovea --no-overlay
118
+ ```
119
+
120
+ **Skip preprocessing on already preprocessed images:**
121
+ ```bash
122
+ vascx run /path/to/preprocessed/images /path/to/output --no-preprocess
123
+ ```
124
+
125
+ **Increase parallel processing workers:**
126
+ ```bash
127
+ vascx run /path/to/images /path/to/output --n_jobs 8
128
+ ```
129
+
130
+ ### Notes
131
+
132
+ - The CSV input must contain a 'path' column with image file paths
133
+ - If the CSV includes an 'id' column, these IDs will be used instead of filenames
134
+ - When `--no-preprocess` is used, input images must already be in the proper format
135
+ - The overlay visualization requires at least one analysis component to be enabled
136
+
137
+ ###
138
+
139
  To speed up re-execution of vascx we recommend to run the preprocessing and segmentation steps separately:
140
 
141
  1. Preprocessing. See [this notebook](./notebooks/0_preprocess.ipynb). This step is CPU-heavy and benefits from parallelization (see notebook).
notebooks/0_preprocess.ipynb CHANGED
@@ -10,7 +10,7 @@
10
  "\n",
11
  "import pandas as pd\n",
12
  "\n",
13
- "from rtnls_fundusprep.utils import preprocess_for_inference"
14
  ]
15
  },
16
  {
@@ -58,16 +58,30 @@
58
  "output_type": "stream",
59
  "text": [
60
  "0it [00:00, ?it/s][Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.\n",
61
- "6it [00:00, 143.58it/s]\n",
62
- "[Parallel(n_jobs=4)]: Done 2 out of 6 | elapsed: 2.1s remaining: 4.2s\n",
63
- "[Parallel(n_jobs=4)]: Done 3 out of 6 | elapsed: 2.1s remaining: 2.1s\n",
64
- "[Parallel(n_jobs=4)]: Done 4 out of 6 | elapsed: 2.9s remaining: 1.4s\n",
65
- "[Parallel(n_jobs=4)]: Done 6 out of 6 | elapsed: 4.3s finished\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  ]
67
  }
68
  ],
69
  "source": [
70
- "bounds = preprocess_for_inference(\n",
71
  " files, # List of image files\n",
72
  " rgb_path=ds_path / \"rgb\", # Output path for RGB images\n",
73
  " ce_path=ds_path / \"ce\", # Output path for Contrast Enhanced images\n",
@@ -102,7 +116,7 @@
102
  ],
103
  "metadata": {
104
  "kernelspec": {
105
- "display_name": "base",
106
  "language": "python",
107
  "name": "python3"
108
  },
 
10
  "\n",
11
  "import pandas as pd\n",
12
  "\n",
13
+ "from rtnls_fundusprep.preprocessor import parallel_preprocess"
14
  ]
15
  },
16
  {
 
58
  "output_type": "stream",
59
  "text": [
60
  "0it [00:00, ?it/s][Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.\n",
61
+ "6it [00:00, 154.80it/s]\n"
62
+ ]
63
+ },
64
+ {
65
+ "name": "stdout",
66
+ "output_type": "stream",
67
+ "text": [
68
+ "Error with image ../samples/fundus/original/HRF_07_dr.jpg\n",
69
+ "Error with image ../samples/fundus/original/HRF_04_g.jpg\n"
70
+ ]
71
+ },
72
+ {
73
+ "name": "stderr",
74
+ "output_type": "stream",
75
+ "text": [
76
+ "[Parallel(n_jobs=4)]: Done 2 out of 6 | elapsed: 0.9s remaining: 1.8s\n",
77
+ "[Parallel(n_jobs=4)]: Done 3 out of 6 | elapsed: 1.5s remaining: 1.5s\n",
78
+ "[Parallel(n_jobs=4)]: Done 4 out of 6 | elapsed: 1.5s remaining: 0.8s\n",
79
+ "[Parallel(n_jobs=4)]: Done 6 out of 6 | elapsed: 1.6s finished\n"
80
  ]
81
  }
82
  ],
83
  "source": [
84
+ "bounds = parallel_preprocess(\n",
85
  " files, # List of image files\n",
86
  " rgb_path=ds_path / \"rgb\", # Output path for RGB images\n",
87
  " ce_path=ds_path / \"ce\", # Output path for Contrast Enhanced images\n",
 
116
  ],
117
  "metadata": {
118
  "kernelspec": {
119
+ "display_name": "retinalysis",
120
  "language": "python",
121
  "name": "python3"
122
  },
setup.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import find_packages, setup
2
+
3
+ with open("README.md", "r") as fh:
4
+ long_description = fh.read()
5
+
6
+ setup(
7
+ name="vascx_models",
8
+ # using versioneer for versioning using git tags
9
+ # https://github.com/python-versioneer/python-versioneer/blob/master/INSTALL.md
10
+ # version=versioneer.get_version(),
11
+ # cmdclass=versioneer.get_cmdclass(),
12
+ author="Jose Vargas",
13
+ author_email="[email protected]",
14
+ description="Retinal analysis toolbox for Python",
15
+ long_description=long_description,
16
+ long_description_content_type="text/markdown",
17
+ packages=find_packages(),
18
+ include_package_data=True,
19
+ zip_safe=False,
20
+ entry_points={
21
+ "console_scripts": [
22
+ "vascx = vascx_models.cli:cli",
23
+ ]
24
+ },
25
+ install_requires=[
26
+ "numpy == 1.*",
27
+ "pandas == 2.*",
28
+ "tqdm == 4.*",
29
+ "Pillow == 9.*",
30
+ "click==8.*",
31
+ ],
32
+ python_requires=">=3.10, <3.11",
33
+ )
vascx_models/cli.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import click
4
+ import pandas as pd
5
+ import torch
6
+
7
+ from rtnls_fundusprep.cli import _run_preprocessing
8
+
9
+ from .inference import (
10
+ run_fovea_detection,
11
+ run_quality_estimation,
12
+ run_segmentation_disc,
13
+ run_segmentation_vessels_and_av,
14
+ )
15
+ from .utils import batch_create_overlays
16
+
17
+
18
+ @click.group(name="vascx")
19
+ def cli():
20
+ pass
21
+
22
+
23
+ @cli.command()
24
+ @click.argument("data_path", type=click.Path(exists=True))
25
+ @click.argument("output_path", type=click.Path())
26
+ @click.option(
27
+ "--preprocess/--no-preprocess",
28
+ default=True,
29
+ help="Run preprocessing or use preprocessed images",
30
+ )
31
+ @click.option(
32
+ "--vessels/--no-vessels", default=True, help="Run vessels and AV segmentation"
33
+ )
34
+ @click.option("--disc/--no-disc", default=True, help="Run optic disc segmentation")
35
+ @click.option(
36
+ "--quality/--no-quality", default=True, help="Run image quality estimation"
37
+ )
38
+ @click.option("--fovea/--no-fovea", default=True, help="Run fovea detection")
39
+ @click.option(
40
+ "--overlay/--no-overlay", default=True, help="Create visualization overlays"
41
+ )
42
+ @click.option("--n_jobs", type=int, default=4, help="Number of preprocessing workers")
43
+ def run(
44
+ data_path, output_path, preprocess, vessels, disc, quality, fovea, overlay, n_jobs
45
+ ):
46
+ """Run the complete inference pipeline on fundus images.
47
+
48
+ DATA_PATH is either a directory containing images or a CSV file with 'path' column.
49
+ OUTPUT_PATH is the directory where results will be stored.
50
+ """
51
+
52
+ output_path = Path(output_path)
53
+ output_path.mkdir(exist_ok=True, parents=True)
54
+
55
+ # Setup output directories
56
+ preprocess_rgb_path = output_path / "preprocessed_rgb"
57
+ vessels_path = output_path / "vessels"
58
+ av_path = output_path / "artery_vein"
59
+ disc_path = output_path / "disc"
60
+ overlay_path = output_path / "overlays"
61
+
62
+ # Create required directories
63
+ if preprocess:
64
+ preprocess_rgb_path.mkdir(exist_ok=True, parents=True)
65
+ if vessels:
66
+ av_path.mkdir(exist_ok=True, parents=True)
67
+ vessels_path.mkdir(exist_ok=True, parents=True)
68
+ if disc:
69
+ disc_path.mkdir(exist_ok=True, parents=True)
70
+ if overlay:
71
+ overlay_path.mkdir(exist_ok=True, parents=True)
72
+
73
+ bounds_path = output_path / "bounds.csv" if preprocess else None
74
+ quality_path = output_path / "quality.csv" if quality else None
75
+ fovea_path = output_path / "fovea.csv" if fovea else None
76
+
77
+ # Determine if input is a folder or CSV file
78
+ data_path = Path(data_path)
79
+ is_csv = data_path.suffix.lower() == ".csv"
80
+
81
+ # Get files to process
82
+ files = []
83
+ ids = None
84
+
85
+ if is_csv:
86
+ click.echo(f"Reading file paths from CSV: {data_path}")
87
+ try:
88
+ df = pd.read_csv(data_path)
89
+ if "path" not in df.columns:
90
+ click.echo("Error: CSV must contain a 'path' column")
91
+ return
92
+
93
+ # Get file paths and convert to Path objects
94
+ files = [Path(p) for p in df["path"]]
95
+
96
+ if "id" in df.columns:
97
+ ids = df["id"].tolist()
98
+ click.echo("Using IDs from CSV 'id' column")
99
+
100
+ except Exception as e:
101
+ click.echo(f"Error reading CSV file: {e}")
102
+ return
103
+ else:
104
+ click.echo(f"Finding files in directory: {data_path}")
105
+ files = list(data_path.glob("*"))
106
+ ids = [f.stem for f in files]
107
+
108
+ if not files:
109
+ click.echo("No files found to process")
110
+ return
111
+
112
+ click.echo(f"Found {len(files)} files to process")
113
+
114
+ # Step 1: Preprocess images if requested
115
+ if preprocess:
116
+ click.echo("Running preprocessing...")
117
+ _run_preprocessing(
118
+ files=files,
119
+ ids=ids,
120
+ rgb_path=preprocess_rgb_path,
121
+ bounds_path=bounds_path,
122
+ n_jobs=n_jobs,
123
+ )
124
+ # Use the preprocessed images for subsequent steps
125
+ preprocessed_files = list(preprocess_rgb_path.glob("*.png"))
126
+ else:
127
+ # Use the input files directly
128
+ preprocessed_files = files
129
+ ids = [f.stem for f in preprocessed_files]
130
+
131
+ # Set up GPU device
132
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
133
+ click.echo(f"Using device: {device}")
134
+
135
+ # Step 2: Run quality estimation if requested
136
+ if quality:
137
+ click.echo("Running quality estimation...")
138
+ df_quality = run_quality_estimation(
139
+ fpaths=preprocessed_files, ids=ids, device=device
140
+ )
141
+ df_quality.to_csv(quality_path)
142
+ click.echo(f"Quality results saved to {quality_path}")
143
+
144
+ # Step 3: Run vessels and AV segmentation if requested
145
+ if vessels:
146
+ click.echo("Running vessels and AV segmentation...")
147
+ run_segmentation_vessels_and_av(
148
+ rgb_paths=preprocessed_files,
149
+ ids=ids,
150
+ av_path=av_path,
151
+ vessels_path=vessels_path,
152
+ device=device,
153
+ )
154
+ click.echo(f"Vessel segmentation saved to {vessels_path}")
155
+ click.echo(f"AV segmentation saved to {av_path}")
156
+
157
+ # Step 4: Run optic disc segmentation if requested
158
+ if disc:
159
+ click.echo("Running optic disc segmentation...")
160
+ run_segmentation_disc(
161
+ rgb_paths=preprocessed_files, ids=ids, output_path=disc_path, device=device
162
+ )
163
+ click.echo(f"Disc segmentation saved to {disc_path}")
164
+
165
+ # Step 5: Run fovea detection if requested
166
+ df_fovea = None
167
+ if fovea:
168
+ click.echo("Running fovea detection...")
169
+ df_fovea = run_fovea_detection(
170
+ rgb_paths=preprocessed_files, ids=ids, device=device
171
+ )
172
+ df_fovea.to_csv(fovea_path)
173
+ click.echo(f"Fovea detection results saved to {fovea_path}")
174
+
175
+ # Step 6: Create overlays if requested
176
+ if overlay:
177
+ click.echo("Creating visualization overlays...")
178
+
179
+ # Prepare fovea data if available
180
+ fovea_data = None
181
+ if df_fovea is not None:
182
+ fovea_data = {
183
+ idx: (row["x_fovea"], row["y_fovea"])
184
+ for idx, row in df_fovea.iterrows()
185
+ }
186
+
187
+ # Create visualization overlays
188
+ batch_create_overlays(
189
+ rgb_dir=preprocess_rgb_path if preprocess else data_path,
190
+ output_dir=overlay_path,
191
+ av_dir=av_path,
192
+ disc_dir=disc_path,
193
+ fovea_data=fovea_data,
194
+ )
195
+
196
+ click.echo(f"Visualization overlays saved to {overlay_path}")
197
+
198
+ click.echo(f"All requested processing complete. Results saved to {output_path}")
vascx_models/inference.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Optional
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+
11
+ from rtnls_inference.ensembles.ensemble_classification import ClassificationEnsemble
12
+ from rtnls_inference.ensembles.ensemble_heatmap_regression import (
13
+ HeatmapRegressionEnsemble,
14
+ )
15
+ from rtnls_inference.ensembles.ensemble_segmentation import SegmentationEnsemble
16
+ from rtnls_inference.utils import decollate_batch, extract_keypoints_from_heatmaps
17
+
18
+
19
+ def run_quality_estimation(fpaths, ids, device: torch.device):
20
+ ensemble_quality = ClassificationEnsemble.from_release("quality.pt").to(device)
21
+ dataloader = ensemble_quality._make_inference_dataloader(
22
+ fpaths,
23
+ ids=ids,
24
+ num_workers=8,
25
+ preprocess=False,
26
+ batch_size=16,
27
+ )
28
+
29
+ output_ids, outputs = [], []
30
+ with torch.no_grad():
31
+ for batch in tqdm(dataloader):
32
+ if len(batch) == 0:
33
+ continue
34
+
35
+ im = batch["image"].to(device)
36
+
37
+ # QUALITY
38
+ quality = ensemble_quality.predict_step(im)
39
+ quality = torch.mean(quality, dim=0)
40
+
41
+ items = {"id": batch["id"], "quality": quality}
42
+ items = decollate_batch(items)
43
+
44
+ for item in items:
45
+ output_ids.append(item["id"])
46
+ outputs.append(item["quality"].tolist())
47
+
48
+ return pd.DataFrame(
49
+ outputs,
50
+ index=output_ids,
51
+ columns=["q1", "q2", "q3"],
52
+ )
53
+
54
+
55
+ def run_segmentation_vessels_and_av(
56
+ rgb_paths: List[Path],
57
+ ce_paths: Optional[List[Path]] = None,
58
+ ids: Optional[List[str]] = None,
59
+ av_path: Optional[Path] = None,
60
+ vessels_path: Optional[Path] = None,
61
+ device: torch.device = torch.device(
62
+ "cuda:0" if torch.cuda.is_available() else "cpu"
63
+ ),
64
+ ) -> None:
65
+ """
66
+ Run AV and vessel segmentation on the provided images.
67
+
68
+ Args:
69
+ rgb_paths: List of paths to RGB fundus images
70
+ ce_paths: Optional list of paths to contrast enhanced images
71
+ ids: Optional list of ids to pass to _make_inference_dataloader
72
+ av_path: Folder where to store output AV segmentations
73
+ vessels_path: Folder where to store output vessel segmentations
74
+ device: Device to run inference on
75
+ """
76
+ # Create output directories if they don't exist
77
+ if av_path is not None:
78
+ av_path.mkdir(exist_ok=True, parents=True)
79
+ if vessels_path is not None:
80
+ vessels_path.mkdir(exist_ok=True, parents=True)
81
+
82
+ # Load models
83
+ ensemble_av = SegmentationEnsemble.from_release("av_july24.pt").to(device).eval()
84
+ ensemble_vessels = (
85
+ SegmentationEnsemble.from_release("vessels_july24.pt").to(device).eval()
86
+ )
87
+
88
+ # Prepare input paths
89
+ if ce_paths is None:
90
+ # If CE paths are not provided, use RGB paths for both inputs
91
+ fpaths = rgb_paths
92
+ else:
93
+ # If CE paths are provided, pair them with RGB paths
94
+ if len(rgb_paths) != len(ce_paths):
95
+ raise ValueError("rgb_paths and ce_paths must have the same length")
96
+ fpaths = list(zip(rgb_paths, ce_paths))
97
+
98
+ # Create dataloader
99
+ dataloader = ensemble_av._make_inference_dataloader(
100
+ fpaths,
101
+ ids=ids,
102
+ num_workers=8,
103
+ preprocess=False,
104
+ batch_size=8,
105
+ )
106
+
107
+ # Run inference
108
+ with torch.no_grad():
109
+ for batch in tqdm(dataloader):
110
+ # AV segmentation
111
+ if av_path is not None:
112
+ with torch.autocast(device_type=device.type):
113
+ proba = ensemble_av.forward(batch["image"].to(device))
114
+ proba = torch.mean(proba, dim=0) # average over models
115
+ proba = torch.permute(proba, (0, 2, 3, 1)) # NCHW -> NHWC
116
+ proba = torch.nn.functional.softmax(proba, dim=-1)
117
+
118
+ items = {
119
+ "id": batch["id"],
120
+ "image": proba,
121
+ }
122
+
123
+ items = decollate_batch(items)
124
+ for i, item in enumerate(items):
125
+ fpath = os.path.join(av_path, f"{item['id']}.png")
126
+ mask = np.argmax(item["image"], -1)
127
+ Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)
128
+
129
+ # Vessel segmentation
130
+ if vessels_path is not None:
131
+ with torch.autocast(device_type=device.type):
132
+ proba = ensemble_vessels.forward(batch["image"].to(device))
133
+ proba = torch.mean(proba, dim=0) # average over models
134
+ proba = torch.permute(proba, (0, 2, 3, 1)) # NCHW -> NHWC
135
+ proba = torch.nn.functional.softmax(proba, dim=-1)
136
+
137
+ items = {
138
+ "id": batch["id"],
139
+ "image": proba,
140
+ }
141
+
142
+ items = decollate_batch(items)
143
+ for i, item in enumerate(items):
144
+ fpath = os.path.join(vessels_path, f"{item['id']}.png")
145
+ mask = np.argmax(item["image"], -1)
146
+ Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)
147
+
148
+
149
+ def run_segmentation_disc(
150
+ rgb_paths: List[Path],
151
+ ce_paths: Optional[List[Path]] = None,
152
+ ids: Optional[List[str]] = None,
153
+ output_path: Optional[Path] = None,
154
+ device: torch.device = torch.device(
155
+ "cuda:0" if torch.cuda.is_available() else "cpu"
156
+ ),
157
+ ) -> None:
158
+ ensemble_disc = (
159
+ SegmentationEnsemble.from_release("disc_july24.pt").to(device).eval()
160
+ )
161
+
162
+ # Prepare input paths
163
+ if ce_paths is None:
164
+ # If CE paths are not provided, use RGB paths for both inputs
165
+ fpaths = rgb_paths
166
+ else:
167
+ # If CE paths are provided, pair them with RGB paths
168
+ if len(rgb_paths) != len(ce_paths):
169
+ raise ValueError("rgb_paths and ce_paths must have the same length")
170
+ fpaths = list(zip(rgb_paths, ce_paths))
171
+
172
+ dataloader = ensemble_disc._make_inference_dataloader(
173
+ fpaths,
174
+ ids=ids,
175
+ num_workers=8,
176
+ preprocess=False,
177
+ batch_size=8,
178
+ )
179
+
180
+ with torch.no_grad():
181
+ for batch in tqdm(dataloader):
182
+ # AV
183
+ with torch.autocast(device_type=device.type):
184
+ proba = ensemble_disc.forward(batch["image"].to(device))
185
+ proba = torch.mean(proba, dim=0) # average over models
186
+ proba = torch.permute(proba, (0, 2, 3, 1)) # NCHW -> NHWC
187
+ proba = torch.nn.functional.softmax(proba, dim=-1)
188
+
189
+ items = {
190
+ "id": batch["id"],
191
+ "image": proba,
192
+ }
193
+
194
+ items = decollate_batch(items)
195
+ items = [dataloader.dataset.transform.undo_item(item) for item in items]
196
+ for i, item in enumerate(items):
197
+ fpath = os.path.join(output_path, f"{item['id']}.png")
198
+
199
+ mask = np.argmax(item["image"], -1)
200
+ Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)
201
+
202
+
203
+ def run_fovea_detection(
204
+ rgb_paths: List[Path],
205
+ ce_paths: Optional[List[Path]] = None,
206
+ ids: Optional[List[str]] = None,
207
+ device: torch.device = torch.device(
208
+ "cuda:0" if torch.cuda.is_available() else "cpu"
209
+ ),
210
+ ) -> None:
211
+ # def run_fovea_detection(fpaths, ids, device: torch.device):
212
+ ensemble_fovea = HeatmapRegressionEnsemble.from_release("fovea_july24.pt").to(
213
+ device
214
+ )
215
+
216
+ # Prepare input paths
217
+ if ce_paths is None:
218
+ # If CE paths are not provided, use RGB paths for both inputs
219
+ fpaths = rgb_paths
220
+ else:
221
+ # If CE paths are provided, pair them with RGB paths
222
+ if len(rgb_paths) != len(ce_paths):
223
+ raise ValueError("rgb_paths and ce_paths must have the same length")
224
+ fpaths = list(zip(rgb_paths, ce_paths))
225
+
226
+ dataloader = ensemble_fovea._make_inference_dataloader(
227
+ fpaths,
228
+ ids=ids,
229
+ num_workers=8,
230
+ preprocess=False,
231
+ batch_size=8,
232
+ )
233
+
234
+ output_ids, outputs = [], []
235
+ with torch.no_grad():
236
+ for batch in tqdm(dataloader):
237
+ if len(batch) == 0:
238
+ continue
239
+
240
+ im = batch["image"].to(device)
241
+
242
+ # FOVEA DETECTION
243
+ with torch.autocast(device_type=device.type):
244
+ heatmap = ensemble_fovea.forward(im)
245
+ keypoints = extract_keypoints_from_heatmaps(heatmap)
246
+
247
+ kp_fovea = torch.mean(keypoints, dim=0) # average over models
248
+
249
+ items = {
250
+ "id": batch["id"],
251
+ "keypoints": kp_fovea,
252
+ "metadata": batch["metadata"],
253
+ }
254
+ items = decollate_batch(items)
255
+
256
+ items = [dataloader.dataset.transform.undo_item(item) for item in items]
257
+
258
+ for item in items:
259
+ output_ids.append(item["id"])
260
+ outputs.append(
261
+ [
262
+ *item["keypoints"][0].tolist(),
263
+ ]
264
+ )
265
+ return pd.DataFrame(
266
+ outputs,
267
+ index=output_ids,
268
+ columns=["x_fovea", "y_fovea"],
269
+ )
vascx_models/utils.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Dict, Optional, Tuple
3
+
4
+ import numpy as np
5
+ from PIL import Image, ImageDraw
6
+
7
+
8
+ def create_fundus_overlay(
9
+ rgb_path: str,
10
+ av_path: Optional[str] = None,
11
+ disc_path: Optional[str] = None,
12
+ fovea_location: Optional[Tuple[int, int]] = None,
13
+ output_path: Optional[str] = None,
14
+ ) -> np.ndarray:
15
+ """
16
+ Create a visualization of a fundus image with overlaid segmentations and markers.
17
+
18
+ Args:
19
+ rgb_path: Path to the RGB fundus image
20
+ av_path: Optional path to artery-vein segmentation (1=artery, 2=vein, 3=intersection)
21
+ disc_path: Optional path to binary disc segmentation
22
+ fovea_location: Optional (x,y) tuple indicating the location of the fovea
23
+ output_path: Optional path to save the visualization image
24
+
25
+ Returns:
26
+ Numpy array containing the visualization image
27
+ """
28
+ print(rgb_path, av_path, disc_path, fovea_location, output_path)
29
+ # Load RGB image
30
+ rgb_img = np.array(Image.open(rgb_path))
31
+
32
+ # Create output image starting with the RGB image
33
+ output_img = rgb_img.copy()
34
+
35
+ # Load and overlay AV segmentation if provided
36
+ if av_path:
37
+ av_mask = np.array(Image.open(av_path))
38
+
39
+ # Create masks for arteries (1), veins (2) and intersections (3)
40
+ artery_mask = av_mask == 1
41
+ vein_mask = av_mask == 2
42
+ intersection_mask = av_mask == 3
43
+
44
+ # Combine artery and intersection for visualization
45
+ artery_combined = np.logical_or(artery_mask, intersection_mask)
46
+ vein_combined = np.logical_or(vein_mask, intersection_mask)
47
+
48
+ # Apply colors: red for arteries, blue for veins
49
+ # Red channel - increase for arteries
50
+ output_img[artery_combined, 0] = 255
51
+ output_img[artery_combined, 1] = 0
52
+ output_img[artery_combined, 2] = 0
53
+
54
+ # Blue channel - increase for veins
55
+ output_img[vein_combined, 0] = 0
56
+ output_img[vein_combined, 1] = 0
57
+ output_img[vein_combined, 2] = 255
58
+
59
+ # Load and overlay optic disc segmentation if provided
60
+ if disc_path:
61
+ disc_mask = np.array(Image.open(disc_path)) > 0
62
+
63
+ # Apply white color for disc
64
+ output_img[disc_mask, :] = [255, 255, 255] # White
65
+
66
+ # Convert to PIL image for drawing the fovea marker
67
+ pil_img = Image.fromarray(output_img)
68
+
69
+ # Add fovea marker if provided
70
+ if fovea_location:
71
+ draw = ImageDraw.Draw(pil_img)
72
+ x, y = fovea_location
73
+ marker_size = (
74
+ min(pil_img.width, pil_img.height) // 50
75
+ ) # Scale marker with image
76
+
77
+ # Draw yellow X at fovea location
78
+ draw.line(
79
+ [(x - marker_size, y - marker_size), (x + marker_size, y + marker_size)],
80
+ fill=(255, 255, 0),
81
+ width=2,
82
+ )
83
+ draw.line(
84
+ [(x - marker_size, y + marker_size), (x + marker_size, y - marker_size)],
85
+ fill=(255, 255, 0),
86
+ width=2,
87
+ )
88
+
89
+ # Convert back to numpy array
90
+ output_img = np.array(pil_img)
91
+
92
+ # Save output if path provided
93
+ if output_path:
94
+ Image.fromarray(output_img).save(output_path)
95
+
96
+ return output_img
97
+
98
+
99
+ def batch_create_overlays(
100
+ rgb_dir: Path,
101
+ output_dir: Path,
102
+ av_dir: Optional[Path] = None,
103
+ disc_dir: Optional[Path] = None,
104
+ fovea_data: Optional[Dict[str, Tuple[int, int]]] = None,
105
+ ) -> None:
106
+ """
107
+ Create visualization overlays for a batch of images.
108
+
109
+ Args:
110
+ rgb_dir: Directory containing RGB fundus images
111
+ output_dir: Directory to save visualization images
112
+ av_dir: Optional directory containing AV segmentations
113
+ disc_dir: Optional directory containing disc segmentations
114
+ fovea_data: Optional dictionary mapping image IDs to fovea coordinates
115
+
116
+ Returns:
117
+ List of paths to created visualization images
118
+ """
119
+ # Create output directory if it doesn't exist
120
+ output_dir.mkdir(exist_ok=True, parents=True)
121
+
122
+ # Get all RGB images
123
+ rgb_files = list(rgb_dir.glob("*.png"))
124
+ if not rgb_files:
125
+ return []
126
+
127
+ # Process each image
128
+ for rgb_file in rgb_files:
129
+ image_id = rgb_file.stem
130
+
131
+ # Check for corresponding AV segmentation
132
+ av_file = None
133
+ if av_dir:
134
+ av_file_path = av_dir / f"{image_id}.png"
135
+ if av_file_path.exists():
136
+ av_file = str(av_file_path)
137
+
138
+ # Check for corresponding disc segmentation
139
+ disc_file = None
140
+ if disc_dir:
141
+ disc_file_path = disc_dir / f"{image_id}.png"
142
+ if disc_file_path.exists():
143
+ disc_file = str(disc_file_path)
144
+
145
+ # Get fovea location if available
146
+ fovea_location = None
147
+ if fovea_data and image_id in fovea_data:
148
+ fovea_location = fovea_data[image_id]
149
+
150
+ # Create output path
151
+ output_file = output_dir / f"{image_id}.png"
152
+
153
+ # Create and save overlay
154
+ create_fundus_overlay(
155
+ rgb_path=str(rgb_file),
156
+ av_path=av_file,
157
+ disc_path=disc_file,
158
+ fovea_location=fovea_location,
159
+ output_path=str(output_file),
160
+ )