Jose committed · Commit 1b052a1 · Parent(s): b90f95b

new inference utilities

Files changed:
- README.md +110 -2
- notebooks/0_preprocess.ipynb +22 -8
- setup.py +33 -0
- vascx_models/cli.py +198 -0
- vascx_models/inference.py +269 -0
- vascx_models/utils.py +160 -0
README.md
CHANGED
````diff
@@ -6,7 +6,7 @@ tags:
   - biology
 ---
 
-
+# VascX models
 
 This repository contains the instructions for using the VascX models from the paper [VascX Models: Model Ensembles for Retinal Vascular Analysis from Color Fundus Images](https://arxiv.org/abs/2409.16016).
 
@@ -18,7 +18,7 @@ The model weights are in [huggingface](https://huggingface.co/Eyened/vascx).
 
 <img src="imgs/HRF_04_g_rgb.png" width="240" height="240" style="display:inline"><img src="imgs/HRF_04_g.png" width="240" height="240" style="display:inline">
 
-
+## Installation
 
 To install the entire fundus analysis pipeline, including fundus preprocessing, model inference code, and vascular biomarker extraction:
 
@@ -26,8 +26,116 @@ To install the entire fundus analysis pipeline including fundus preprocessing, m
 
 2. Install the [rtnls_inference package](https://github.com/Eyened/retinalysis-inference).
 
+## `vascx run` Command
+
+The `run` command provides a comprehensive pipeline for processing fundus images, performing various analyses, and creating visualizations.
+
 ### Usage
 
+```bash
+vascx run DATA_PATH OUTPUT_PATH [OPTIONS]
+```
+
+### Arguments
+
+- `DATA_PATH`: Path to the input data. Can be either:
+  - a directory containing fundus images, or
+  - a CSV file with a 'path' column containing paths to images.
+
+- `OUTPUT_PATH`: Directory where processed results will be stored.
+
+### Options
+
+| Option | Default | Description |
+|--------|---------|-------------|
+| `--preprocess/--no-preprocess` | `--preprocess` | Run preprocessing to standardize images for model input |
+| `--vessels/--no-vessels` | `--vessels` | Run vessel segmentation and artery-vein classification |
+| `--disc/--no-disc` | `--disc` | Run optic disc segmentation |
+| `--quality/--no-quality` | `--quality` | Run image quality assessment |
+| `--fovea/--no-fovea` | `--fovea` | Run fovea detection |
+| `--overlay/--no-overlay` | `--overlay` | Create visualization overlays combining all results |
+| `--n_jobs` | `4` | Number of preprocessing workers for parallel processing |
+
+### Output Structure
+
+When run with default options, the command creates the following structure in `OUTPUT_PATH`:
+
+```
+OUTPUT_PATH/
+├── preprocessed_rgb/   # Standardized fundus images
+├── vessels/            # Vessel segmentation results
+├── artery_vein/        # Artery-vein classification
+├── disc/               # Optic disc segmentation
+├── overlays/           # Visualization images
+├── bounds.csv          # Image boundary information
+├── quality.csv         # Image quality scores
+└── fovea.csv           # Fovea coordinates
+```
+
+### Processing Stages
+
+1. **Preprocessing**:
+   - Standardizes input images for consistent analysis
+   - Outputs preprocessed images and boundary information
+
+2. **Quality Assessment**:
+   - Evaluates image quality with three quality metrics (q1, q2, q3)
+   - Higher scores indicate better image quality
+
+3. **Vessel Segmentation and Artery-Vein Classification**:
+   - Identifies blood vessels in the retina
+   - Classifies vessels as artery (1), vein (2), or intersection (3)
+
+4. **Optic Disc Segmentation**:
+   - Identifies the optic disc location and boundaries
+
+5. **Fovea Detection**:
+   - Determines the coordinates of the fovea (center of vision)
+
+6. **Visualization Overlays**:
+   - Creates color-coded images showing:
+     - Arteries in red
+     - Veins in blue
+     - Optic disc in white
+     - Fovea marked with a yellow X
+
+### Examples
+
+**Process a directory of images with all analyses:**
+```bash
+vascx run /path/to/images /path/to/output
+```
+
+**Process specific images listed in a CSV:**
+```bash
+vascx run /path/to/image_list.csv /path/to/output
+```
+
+**Only run preprocessing and vessel segmentation:**
+```bash
+vascx run /path/to/images /path/to/output --no-disc --no-quality --no-fovea --no-overlay
+```
+
+**Skip preprocessing on already preprocessed images:**
+```bash
+vascx run /path/to/preprocessed/images /path/to/output --no-preprocess
+```
+
+**Increase the number of parallel preprocessing workers:**
+```bash
+vascx run /path/to/images /path/to/output --n_jobs 8
+```
+
+### Notes
+
+- The CSV input must contain a 'path' column with image file paths
+- If the CSV includes an 'id' column, these IDs are used instead of filenames
+- When `--no-preprocess` is used, input images must already be in the proper format
+- The overlay visualization requires at least one analysis component to be enabled
+
+### Running steps separately
+
 To speed up re-execution of vascx we recommend running the preprocessing and segmentation steps separately:
 
 1. Preprocessing. See [this notebook](./notebooks/0_preprocess.ipynb). This step is CPU-heavy and benefits from parallelization (see notebook).
````
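For reference, a minimal sketch of the CSV input convention described in the notes above. The paths and IDs here are hypothetical; only the 'path' column is required, and 'id' optionally overrides the filename-derived IDs:

```python
import pandas as pd

# Hypothetical image paths and IDs; 'path' is required,
# 'id' optionally overrides the filename-derived IDs.
df = pd.DataFrame(
    {
        "path": ["/data/fundus/patient_001.jpg", "/data/fundus/patient_002.jpg"],
        "id": ["patient_001", "patient_002"],
    }
)
df.to_csv("image_list.csv", index=False)
# Then: vascx run image_list.csv /path/to/output
```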
notebooks/0_preprocess.ipynb
CHANGED
```diff
@@ -10,7 +10,7 @@
     "\n",
     "import pandas as pd\n",
     "\n",
-    "from rtnls_fundusprep.
+    "from rtnls_fundusprep.preprocessor import parallel_preprocess"
    ]
   },
   {
@@ -58,16 +58,30 @@
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s][Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.\n",
-     "6it [00:00,
-
-
-
-
+     "6it [00:00, 154.80it/s]\n"
+    ]
+   },
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     "Error with image ../samples/fundus/original/HRF_07_dr.jpg\n",
+     "Error with image ../samples/fundus/original/HRF_04_g.jpg\n"
+    ]
+   },
+   {
+    "name": "stderr",
+    "output_type": "stream",
+    "text": [
+     "[Parallel(n_jobs=4)]: Done 2 out of 6 | elapsed: 0.9s remaining: 1.8s\n",
+     "[Parallel(n_jobs=4)]: Done 3 out of 6 | elapsed: 1.5s remaining: 1.5s\n",
+     "[Parallel(n_jobs=4)]: Done 4 out of 6 | elapsed: 1.5s remaining: 0.8s\n",
+     "[Parallel(n_jobs=4)]: Done 6 out of 6 | elapsed: 1.6s finished\n"
     ]
    }
   ],
   "source": [
+    "bounds = parallel_preprocess(\n",
     "    files,  # List of image files\n",
     "    rgb_path=ds_path / \"rgb\",  # Output path for RGB images\n",
     "    ce_path=ds_path / \"ce\",  # Output path for Contrast Enhanced images\n",
@@ -102,7 +116,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "
+   "display_name": "retinalysis",
    "language": "python",
    "name": "python3"
   },
```
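Pieced together from the notebook source above, the preprocessing call looks roughly as follows. This is a sketch: the notebook's argument list is cut off in the diff, so only the arguments visible above are shown, and the dataset paths are hypothetical:

```python
from pathlib import Path

from rtnls_fundusprep.preprocessor import parallel_preprocess

ds_path = Path("../samples/fundus")  # hypothetical dataset root
files = sorted((ds_path / "original").glob("*.jpg"))  # list of image files

# Returns per-image boundary information ('bounds' in the notebook).
bounds = parallel_preprocess(
    files,
    rgb_path=ds_path / "rgb",  # output path for RGB images
    ce_path=ds_path / "ce",  # output path for contrast-enhanced images
)
```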
setup.py
ADDED
```python
from setuptools import find_packages, setup

with open("README.md", "r") as fh:
    long_description = fh.read()

setup(
    name="vascx_models",
    # using versioneer for versioning using git tags
    # https://github.com/python-versioneer/python-versioneer/blob/master/INSTALL.md
    # version=versioneer.get_version(),
    # cmdclass=versioneer.get_cmdclass(),
    author="Jose Vargas",
    author_email="[email protected]",
    description="Retinal analysis toolbox for Python",
    long_description=long_description,
    long_description_content_type="text/markdown",
    packages=find_packages(),
    include_package_data=True,
    zip_safe=False,
    entry_points={
        "console_scripts": [
            "vascx = vascx_models.cli:cli",
        ]
    },
    install_requires=[
        "numpy == 1.*",
        "pandas == 2.*",
        "tqdm == 4.*",
        "Pillow == 9.*",
        "click==8.*",
    ],
    python_requires=">=3.10, <3.11",
)
```
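The `console_scripts` entry point wires the installed `vascx` executable to the click group defined in `vascx_models/cli.py`; running the script is equivalent to this small driver:

```python
# Equivalent of the generated `vascx` console script.
import sys

from vascx_models.cli import cli

if __name__ == "__main__":
    sys.exit(cli())
```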
vascx_models/cli.py
ADDED
```python
from pathlib import Path

import click
import pandas as pd
import torch

from rtnls_fundusprep.cli import _run_preprocessing

from .inference import (
    run_fovea_detection,
    run_quality_estimation,
    run_segmentation_disc,
    run_segmentation_vessels_and_av,
)
from .utils import batch_create_overlays


@click.group(name="vascx")
def cli():
    pass


@cli.command()
@click.argument("data_path", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@click.option(
    "--preprocess/--no-preprocess",
    default=True,
    help="Run preprocessing or use preprocessed images",
)
@click.option(
    "--vessels/--no-vessels", default=True, help="Run vessels and AV segmentation"
)
@click.option("--disc/--no-disc", default=True, help="Run optic disc segmentation")
@click.option(
    "--quality/--no-quality", default=True, help="Run image quality estimation"
)
@click.option("--fovea/--no-fovea", default=True, help="Run fovea detection")
@click.option(
    "--overlay/--no-overlay", default=True, help="Create visualization overlays"
)
@click.option("--n_jobs", type=int, default=4, help="Number of preprocessing workers")
def run(
    data_path, output_path, preprocess, vessels, disc, quality, fovea, overlay, n_jobs
):
    """Run the complete inference pipeline on fundus images.

    DATA_PATH is either a directory containing images or a CSV file with 'path' column.
    OUTPUT_PATH is the directory where results will be stored.
    """

    output_path = Path(output_path)
    output_path.mkdir(exist_ok=True, parents=True)

    # Setup output directories
    preprocess_rgb_path = output_path / "preprocessed_rgb"
    vessels_path = output_path / "vessels"
    av_path = output_path / "artery_vein"
    disc_path = output_path / "disc"
    overlay_path = output_path / "overlays"

    # Create required directories
    if preprocess:
        preprocess_rgb_path.mkdir(exist_ok=True, parents=True)
    if vessels:
        av_path.mkdir(exist_ok=True, parents=True)
        vessels_path.mkdir(exist_ok=True, parents=True)
    if disc:
        disc_path.mkdir(exist_ok=True, parents=True)
    if overlay:
        overlay_path.mkdir(exist_ok=True, parents=True)

    bounds_path = output_path / "bounds.csv" if preprocess else None
    quality_path = output_path / "quality.csv" if quality else None
    fovea_path = output_path / "fovea.csv" if fovea else None

    # Determine if input is a folder or CSV file
    data_path = Path(data_path)
    is_csv = data_path.suffix.lower() == ".csv"

    # Get files to process
    files = []
    ids = None

    if is_csv:
        click.echo(f"Reading file paths from CSV: {data_path}")
        try:
            df = pd.read_csv(data_path)
            if "path" not in df.columns:
                click.echo("Error: CSV must contain a 'path' column")
                return

            # Get file paths and convert to Path objects
            files = [Path(p) for p in df["path"]]

            if "id" in df.columns:
                ids = df["id"].tolist()
                click.echo("Using IDs from CSV 'id' column")

        except Exception as e:
            click.echo(f"Error reading CSV file: {e}")
            return
    else:
        click.echo(f"Finding files in directory: {data_path}")
        files = list(data_path.glob("*"))
        ids = [f.stem for f in files]

    if not files:
        click.echo("No files found to process")
        return

    click.echo(f"Found {len(files)} files to process")

    # Step 1: Preprocess images if requested
    if preprocess:
        click.echo("Running preprocessing...")
        _run_preprocessing(
            files=files,
            ids=ids,
            rgb_path=preprocess_rgb_path,
            bounds_path=bounds_path,
            n_jobs=n_jobs,
        )
        # Use the preprocessed images for subsequent steps
        preprocessed_files = list(preprocess_rgb_path.glob("*.png"))
    else:
        # Use the input files directly
        preprocessed_files = files
        ids = [f.stem for f in preprocessed_files]

    # Set up GPU device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    click.echo(f"Using device: {device}")

    # Step 2: Run quality estimation if requested
    if quality:
        click.echo("Running quality estimation...")
        df_quality = run_quality_estimation(
            fpaths=preprocessed_files, ids=ids, device=device
        )
        df_quality.to_csv(quality_path)
        click.echo(f"Quality results saved to {quality_path}")

    # Step 3: Run vessels and AV segmentation if requested
    if vessels:
        click.echo("Running vessels and AV segmentation...")
        run_segmentation_vessels_and_av(
            rgb_paths=preprocessed_files,
            ids=ids,
            av_path=av_path,
            vessels_path=vessels_path,
            device=device,
        )
        click.echo(f"Vessel segmentation saved to {vessels_path}")
        click.echo(f"AV segmentation saved to {av_path}")

    # Step 4: Run optic disc segmentation if requested
    if disc:
        click.echo("Running optic disc segmentation...")
        run_segmentation_disc(
            rgb_paths=preprocessed_files, ids=ids, output_path=disc_path, device=device
        )
        click.echo(f"Disc segmentation saved to {disc_path}")

    # Step 5: Run fovea detection if requested
    df_fovea = None
    if fovea:
        click.echo("Running fovea detection...")
        df_fovea = run_fovea_detection(
            rgb_paths=preprocessed_files, ids=ids, device=device
        )
        df_fovea.to_csv(fovea_path)
        click.echo(f"Fovea detection results saved to {fovea_path}")

    # Step 6: Create overlays if requested
    if overlay:
        click.echo("Creating visualization overlays...")

        # Prepare fovea data if available
        fovea_data = None
        if df_fovea is not None:
            fovea_data = {
                idx: (row["x_fovea"], row["y_fovea"])
                for idx, row in df_fovea.iterrows()
            }

        # Create visualization overlays
        batch_create_overlays(
            rgb_dir=preprocess_rgb_path if preprocess else data_path,
            output_dir=overlay_path,
            av_dir=av_path,
            disc_dir=disc_path,
            fovea_data=fovea_data,
        )

        click.echo(f"Visualization overlays saved to {overlay_path}")

    click.echo(f"All requested processing complete. Results saved to {output_path}")
```
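Besides the shell, the command can be exercised from Python with click's test runner, which is handy for smoke tests. A minimal sketch with hypothetical paths:

```python
from click.testing import CliRunner

from vascx_models.cli import cli

runner = CliRunner()
# Mirrors: vascx run /path/to/images /path/to/output --no-overlay
result = runner.invoke(
    cli, ["run", "/path/to/images", "/path/to/output", "--no-overlay"]
)
print(result.exit_code, result.output)
```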
vascx_models/inference.py
ADDED
```python
import os
from pathlib import Path
from typing import List, Optional

import numpy as np
import pandas as pd
import torch
from PIL import Image
from tqdm import tqdm

from rtnls_inference.ensembles.ensemble_classification import ClassificationEnsemble
from rtnls_inference.ensembles.ensemble_heatmap_regression import (
    HeatmapRegressionEnsemble,
)
from rtnls_inference.ensembles.ensemble_segmentation import SegmentationEnsemble
from rtnls_inference.utils import decollate_batch, extract_keypoints_from_heatmaps


def run_quality_estimation(fpaths, ids, device: torch.device) -> pd.DataFrame:
    ensemble_quality = ClassificationEnsemble.from_release("quality.pt").to(device)
    dataloader = ensemble_quality._make_inference_dataloader(
        fpaths,
        ids=ids,
        num_workers=8,
        preprocess=False,
        batch_size=16,
    )

    output_ids, outputs = [], []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            if len(batch) == 0:
                continue

            im = batch["image"].to(device)

            # QUALITY
            quality = ensemble_quality.predict_step(im)
            quality = torch.mean(quality, dim=0)

            items = {"id": batch["id"], "quality": quality}
            items = decollate_batch(items)

            for item in items:
                output_ids.append(item["id"])
                outputs.append(item["quality"].tolist())

    return pd.DataFrame(
        outputs,
        index=output_ids,
        columns=["q1", "q2", "q3"],
    )


def run_segmentation_vessels_and_av(
    rgb_paths: List[Path],
    ce_paths: Optional[List[Path]] = None,
    ids: Optional[List[str]] = None,
    av_path: Optional[Path] = None,
    vessels_path: Optional[Path] = None,
    device: torch.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu"
    ),
) -> None:
    """
    Run AV and vessel segmentation on the provided images.

    Args:
        rgb_paths: List of paths to RGB fundus images
        ce_paths: Optional list of paths to contrast enhanced images
        ids: Optional list of ids to pass to _make_inference_dataloader
        av_path: Folder where to store output AV segmentations
        vessels_path: Folder where to store output vessel segmentations
        device: Device to run inference on
    """
    # Create output directories if they don't exist
    if av_path is not None:
        av_path.mkdir(exist_ok=True, parents=True)
    if vessels_path is not None:
        vessels_path.mkdir(exist_ok=True, parents=True)

    # Load models
    ensemble_av = SegmentationEnsemble.from_release("av_july24.pt").to(device).eval()
    ensemble_vessels = (
        SegmentationEnsemble.from_release("vessels_july24.pt").to(device).eval()
    )

    # Prepare input paths
    if ce_paths is None:
        # If CE paths are not provided, use RGB paths for both inputs
        fpaths = rgb_paths
    else:
        # If CE paths are provided, pair them with RGB paths
        if len(rgb_paths) != len(ce_paths):
            raise ValueError("rgb_paths and ce_paths must have the same length")
        fpaths = list(zip(rgb_paths, ce_paths))

    # Create dataloader
    dataloader = ensemble_av._make_inference_dataloader(
        fpaths,
        ids=ids,
        num_workers=8,
        preprocess=False,
        batch_size=8,
    )

    # Run inference
    with torch.no_grad():
        for batch in tqdm(dataloader):
            # AV segmentation
            if av_path is not None:
                with torch.autocast(device_type=device.type):
                    proba = ensemble_av.forward(batch["image"].to(device))
                    proba = torch.mean(proba, dim=0)  # average over models
                    proba = torch.permute(proba, (0, 2, 3, 1))  # NCHW -> NHWC
                    proba = torch.nn.functional.softmax(proba, dim=-1)

                items = {
                    "id": batch["id"],
                    "image": proba,
                }

                items = decollate_batch(items)
                for i, item in enumerate(items):
                    fpath = os.path.join(av_path, f"{item['id']}.png")
                    mask = np.argmax(item["image"], -1)
                    Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)

            # Vessel segmentation
            if vessels_path is not None:
                with torch.autocast(device_type=device.type):
                    proba = ensemble_vessels.forward(batch["image"].to(device))
                    proba = torch.mean(proba, dim=0)  # average over models
                    proba = torch.permute(proba, (0, 2, 3, 1))  # NCHW -> NHWC
                    proba = torch.nn.functional.softmax(proba, dim=-1)

                items = {
                    "id": batch["id"],
                    "image": proba,
                }

                items = decollate_batch(items)
                for i, item in enumerate(items):
                    fpath = os.path.join(vessels_path, f"{item['id']}.png")
                    mask = np.argmax(item["image"], -1)
                    Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)


def run_segmentation_disc(
    rgb_paths: List[Path],
    ce_paths: Optional[List[Path]] = None,
    ids: Optional[List[str]] = None,
    output_path: Optional[Path] = None,
    device: torch.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu"
    ),
) -> None:
    ensemble_disc = (
        SegmentationEnsemble.from_release("disc_july24.pt").to(device).eval()
    )

    # Prepare input paths
    if ce_paths is None:
        # If CE paths are not provided, use RGB paths for both inputs
        fpaths = rgb_paths
    else:
        # If CE paths are provided, pair them with RGB paths
        if len(rgb_paths) != len(ce_paths):
            raise ValueError("rgb_paths and ce_paths must have the same length")
        fpaths = list(zip(rgb_paths, ce_paths))

    dataloader = ensemble_disc._make_inference_dataloader(
        fpaths,
        ids=ids,
        num_workers=8,
        preprocess=False,
        batch_size=8,
    )

    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Disc segmentation
            with torch.autocast(device_type=device.type):
                proba = ensemble_disc.forward(batch["image"].to(device))
                proba = torch.mean(proba, dim=0)  # average over models
                proba = torch.permute(proba, (0, 2, 3, 1))  # NCHW -> NHWC
                proba = torch.nn.functional.softmax(proba, dim=-1)

            items = {
                "id": batch["id"],
                "image": proba,
            }

            items = decollate_batch(items)
            items = [dataloader.dataset.transform.undo_item(item) for item in items]
            for i, item in enumerate(items):
                fpath = os.path.join(output_path, f"{item['id']}.png")

                mask = np.argmax(item["image"], -1)
                Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)


def run_fovea_detection(
    rgb_paths: List[Path],
    ce_paths: Optional[List[Path]] = None,
    ids: Optional[List[str]] = None,
    device: torch.device = torch.device(
        "cuda:0" if torch.cuda.is_available() else "cpu"
    ),
) -> pd.DataFrame:
    ensemble_fovea = HeatmapRegressionEnsemble.from_release("fovea_july24.pt").to(
        device
    )

    # Prepare input paths
    if ce_paths is None:
        # If CE paths are not provided, use RGB paths for both inputs
        fpaths = rgb_paths
    else:
        # If CE paths are provided, pair them with RGB paths
        if len(rgb_paths) != len(ce_paths):
            raise ValueError("rgb_paths and ce_paths must have the same length")
        fpaths = list(zip(rgb_paths, ce_paths))

    dataloader = ensemble_fovea._make_inference_dataloader(
        fpaths,
        ids=ids,
        num_workers=8,
        preprocess=False,
        batch_size=8,
    )

    output_ids, outputs = [], []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            if len(batch) == 0:
                continue

            im = batch["image"].to(device)

            # FOVEA DETECTION
            with torch.autocast(device_type=device.type):
                heatmap = ensemble_fovea.forward(im)
                keypoints = extract_keypoints_from_heatmaps(heatmap)

            kp_fovea = torch.mean(keypoints, dim=0)  # average over models

            items = {
                "id": batch["id"],
                "keypoints": kp_fovea,
                "metadata": batch["metadata"],
            }
            items = decollate_batch(items)

            items = [dataloader.dataset.transform.undo_item(item) for item in items]

            for item in items:
                output_ids.append(item["id"])
                outputs.append(
                    [
                        *item["keypoints"][0].tolist(),
                    ]
                )
    return pd.DataFrame(
        outputs,
        index=output_ids,
        columns=["x_fovea", "y_fovea"],
    )
```
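The helpers above can also be called directly from Python. A minimal sketch for quality estimation (hypothetical paths; assumes the released `quality.pt` weights are available to rtnls_inference):

```python
from pathlib import Path

import torch

from vascx_models.inference import run_quality_estimation

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hypothetical directory of preprocessed RGB fundus images.
fpaths = sorted(Path("/path/to/preprocessed_rgb").glob("*.png"))
ids = [p.stem for p in fpaths]

# DataFrame indexed by id with quality scores q1, q2, q3.
df_quality = run_quality_estimation(fpaths=fpaths, ids=ids, device=device)
df_quality.to_csv("quality.csv")
```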
vascx_models/utils.py
ADDED
```python
from pathlib import Path
from typing import Dict, Optional, Tuple

import numpy as np
from PIL import Image, ImageDraw


def create_fundus_overlay(
    rgb_path: str,
    av_path: Optional[str] = None,
    disc_path: Optional[str] = None,
    fovea_location: Optional[Tuple[int, int]] = None,
    output_path: Optional[str] = None,
) -> np.ndarray:
    """
    Create a visualization of a fundus image with overlaid segmentations and markers.

    Args:
        rgb_path: Path to the RGB fundus image
        av_path: Optional path to artery-vein segmentation (1=artery, 2=vein, 3=intersection)
        disc_path: Optional path to binary disc segmentation
        fovea_location: Optional (x, y) tuple indicating the location of the fovea
        output_path: Optional path to save the visualization image

    Returns:
        Numpy array containing the visualization image
    """
    # Load RGB image
    rgb_img = np.array(Image.open(rgb_path))

    # Create output image starting with the RGB image
    output_img = rgb_img.copy()

    # Load and overlay AV segmentation if provided
    if av_path:
        av_mask = np.array(Image.open(av_path))

        # Create masks for arteries (1), veins (2) and intersections (3)
        artery_mask = av_mask == 1
        vein_mask = av_mask == 2
        intersection_mask = av_mask == 3

        # Combine artery and intersection for visualization
        artery_combined = np.logical_or(artery_mask, intersection_mask)
        vein_combined = np.logical_or(vein_mask, intersection_mask)

        # Apply colors: set artery pixels to pure red
        output_img[artery_combined, 0] = 255
        output_img[artery_combined, 1] = 0
        output_img[artery_combined, 2] = 0

        # Set vein pixels to pure blue
        output_img[vein_combined, 0] = 0
        output_img[vein_combined, 1] = 0
        output_img[vein_combined, 2] = 255

    # Load and overlay optic disc segmentation if provided
    if disc_path:
        disc_mask = np.array(Image.open(disc_path)) > 0

        # Apply white color for disc
        output_img[disc_mask, :] = [255, 255, 255]  # White

    # Convert to PIL image for drawing the fovea marker
    pil_img = Image.fromarray(output_img)

    # Add fovea marker if provided
    if fovea_location:
        draw = ImageDraw.Draw(pil_img)
        x, y = fovea_location
        marker_size = (
            min(pil_img.width, pil_img.height) // 50
        )  # Scale marker with image

        # Draw yellow X at fovea location
        draw.line(
            [(x - marker_size, y - marker_size), (x + marker_size, y + marker_size)],
            fill=(255, 255, 0),
            width=2,
        )
        draw.line(
            [(x - marker_size, y + marker_size), (x + marker_size, y - marker_size)],
            fill=(255, 255, 0),
            width=2,
        )

    # Convert back to numpy array
    output_img = np.array(pil_img)

    # Save output if path provided
    if output_path:
        Image.fromarray(output_img).save(output_path)

    return output_img


def batch_create_overlays(
    rgb_dir: Path,
    output_dir: Path,
    av_dir: Optional[Path] = None,
    disc_dir: Optional[Path] = None,
    fovea_data: Optional[Dict[str, Tuple[int, int]]] = None,
) -> None:
    """
    Create visualization overlays for a batch of images.

    Args:
        rgb_dir: Directory containing RGB fundus images
        output_dir: Directory to save visualization images
        av_dir: Optional directory containing AV segmentations
        disc_dir: Optional directory containing disc segmentations
        fovea_data: Optional dictionary mapping image IDs to fovea coordinates
    """
    # Create output directory if it doesn't exist
    output_dir.mkdir(exist_ok=True, parents=True)

    # Get all RGB images
    rgb_files = list(rgb_dir.glob("*.png"))
    if not rgb_files:
        return

    # Process each image
    for rgb_file in rgb_files:
        image_id = rgb_file.stem

        # Check for corresponding AV segmentation
        av_file = None
        if av_dir:
            av_file_path = av_dir / f"{image_id}.png"
            if av_file_path.exists():
                av_file = str(av_file_path)

        # Check for corresponding disc segmentation
        disc_file = None
        if disc_dir:
            disc_file_path = disc_dir / f"{image_id}.png"
            if disc_file_path.exists():
                disc_file = str(disc_file_path)

        # Get fovea location if available
        fovea_location = None
        if fovea_data and image_id in fovea_data:
            fovea_location = fovea_data[image_id]

        # Create output path
        output_file = output_dir / f"{image_id}.png"

        # Create and save overlay
        create_fundus_overlay(
            rgb_path=str(rgb_file),
            av_path=av_file,
            disc_path=disc_file,
            fovea_location=fovea_location,
            output_path=str(output_file),
        )
```
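A minimal sketch of the single-image overlay helper (the paths and fovea coordinates here are hypothetical):

```python
from vascx_models.utils import create_fundus_overlay

# Draws arteries in red, veins in blue, the disc in white,
# and a yellow X at the fovea.
overlay = create_fundus_overlay(
    rgb_path="output/preprocessed_rgb/img001.png",
    av_path="output/artery_vein/img001.png",
    disc_path="output/disc/img001.png",
    fovea_location=(512, 540),  # (x, y) in preprocessed-image coordinates
    output_path="output/overlays/img001.png",
)
```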