Image Segmentation
medical
biology
File size: 6,487 Bytes
1b052a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
from pathlib import Path

import click
import pandas as pd
import torch

from rtnls_fundusprep.cli import _run_preprocessing

from .inference import (
    run_fovea_detection,
    run_quality_estimation,
    run_segmentation_disc,
    run_segmentation_vessels_and_av,
)
from .utils import batch_create_overlays


@click.group(name="vascx")
def cli():
    # Root `vascx` command group. It has no behaviour of its own; subcommands
    # (e.g. `run` below) register themselves through `@cli.command()`.
    return None


def _collect_input_files(data_path):
    """Resolve DATA_PATH into the image files to process and their ids.

    Args:
        data_path: ``Path`` to either a directory of images or a CSV file with a
            ``path`` column (and an optional ``id`` column).

    Returns:
        A ``(files, ids)`` tuple. ``files`` is a list of ``Path`` objects.
        ``ids`` is the CSV ``id`` column as a list, the file stems for a
        directory, or ``None`` when the CSV has no ``id`` column.

    Raises:
        click.ClickException: If the CSV cannot be read, lacks a ``path`` column,
            or has empty ``path`` entries; click then exits with a non-zero status.
    """
    if data_path.suffix.lower() == ".csv":
        click.echo(f"Reading file paths from CSV: {data_path}")
        try:
            df = pd.read_csv(data_path)
        except Exception as e:  # pandas raises many types (I/O, parsing, encoding)
            raise click.ClickException(
                f"Could not read CSV file {data_path}: {e}"
            ) from e

        if "path" not in df.columns:
            raise click.ClickException("CSV must contain a 'path' column")
        if df["path"].isna().any():
            raise click.ClickException("CSV 'path' column contains empty entries")

        files = [Path(p) for p in df["path"]]
        ids = None
        if "id" in df.columns:
            ids = df["id"].tolist()
            click.echo("Using IDs from CSV 'id' column")
        return files, ids

    click.echo(f"Finding files in directory: {data_path}")
    # Only regular files: subdirectories (e.g. an OUTPUT_PATH created inside
    # DATA_PATH before this runs) must not be treated as images. Sorted so the
    # processing order, and the row order of result CSVs, is deterministic.
    files = sorted(p for p in data_path.glob("*") if p.is_file())
    return files, [f.stem for f in files]


@cli.command()
@click.argument("data_path", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@click.option(
    "--preprocess/--no-preprocess",
    default=True,
    help="Run preprocessing or use preprocessed images",
)
@click.option(
    "--vessels/--no-vessels", default=True, help="Run vessels and AV segmentation"
)
@click.option("--disc/--no-disc", default=True, help="Run optic disc segmentation")
@click.option(
    "--quality/--no-quality", default=True, help="Run image quality estimation"
)
@click.option("--fovea/--no-fovea", default=True, help="Run fovea detection")
@click.option(
    "--overlay/--no-overlay", default=True, help="Create visualization overlays"
)
@click.option("--n_jobs", type=int, default=4, help="Number of preprocessing workers")
def run(
    data_path, output_path, preprocess, vessels, disc, quality, fovea, overlay, n_jobs
):
    """Run the complete inference pipeline on fundus images.

    DATA_PATH is either a directory containing images or a CSV file with a 'path'
    column (and an optional 'id' column).
    OUTPUT_PATH is the directory where results will be stored.
    """

    output_path = Path(output_path)
    output_path.mkdir(exist_ok=True, parents=True)

    # Setup output directories
    preprocess_rgb_path = output_path / "preprocessed_rgb"
    vessels_path = output_path / "vessels"
    av_path = output_path / "artery_vein"
    disc_path = output_path / "disc"
    overlay_path = output_path / "overlays"

    # Create only the directories that the enabled steps write into
    if preprocess:
        preprocess_rgb_path.mkdir(exist_ok=True, parents=True)
    if vessels:
        av_path.mkdir(exist_ok=True, parents=True)
        vessels_path.mkdir(exist_ok=True, parents=True)
    if disc:
        disc_path.mkdir(exist_ok=True, parents=True)
    if overlay:
        overlay_path.mkdir(exist_ok=True, parents=True)

    bounds_path = output_path / "bounds.csv" if preprocess else None
    quality_path = output_path / "quality.csv" if quality else None
    fovea_path = output_path / "fovea.csv" if fovea else None

    data_path = Path(data_path)
    is_csv = data_path.suffix.lower() == ".csv"
    files, ids = _collect_input_files(data_path)

    if not files:
        click.echo("No files found to process")
        return

    click.echo(f"Found {len(files)} files to process")

    # Step 1: Preprocess images if requested
    if preprocess:
        click.echo("Running preprocessing...")
        _run_preprocessing(
            files=files,
            ids=ids,
            rgb_path=preprocess_rgb_path,
            bounds_path=bounds_path,
            n_jobs=n_jobs,
        )
        # Subsequent steps read the preprocessed PNGs and use their stems as ids.
        # NOTE(review): PNGs left in preprocessed_rgb by an earlier run are
        # picked up as well.
        preprocessed_files = sorted(preprocess_rgb_path.glob("*.png"))
        ids = [f.stem for f in preprocessed_files]
    else:
        # Use the input files directly, keeping ids from the CSV 'id' column
        # when provided (previously these were silently replaced by stems).
        preprocessed_files = files
        if ids is None:
            ids = [f.stem for f in preprocessed_files]

    # Set up GPU device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    click.echo(f"Using device: {device}")

    # Step 2: Run quality estimation if requested
    if quality:
        click.echo("Running quality estimation...")
        df_quality = run_quality_estimation(
            fpaths=preprocessed_files, ids=ids, device=device
        )
        df_quality.to_csv(quality_path)
        click.echo(f"Quality results saved to {quality_path}")

    # Step 3: Run vessels and AV segmentation if requested
    if vessels:
        click.echo("Running vessels and AV segmentation...")
        run_segmentation_vessels_and_av(
            rgb_paths=preprocessed_files,
            ids=ids,
            av_path=av_path,
            vessels_path=vessels_path,
            device=device,
        )
        click.echo(f"Vessel segmentation saved to {vessels_path}")
        click.echo(f"AV segmentation saved to {av_path}")

    # Step 4: Run optic disc segmentation if requested
    if disc:
        click.echo("Running optic disc segmentation...")
        run_segmentation_disc(
            rgb_paths=preprocessed_files, ids=ids, output_path=disc_path, device=device
        )
        click.echo(f"Disc segmentation saved to {disc_path}")

    # Step 5: Run fovea detection if requested
    df_fovea = None
    if fovea:
        click.echo("Running fovea detection...")
        df_fovea = run_fovea_detection(
            rgb_paths=preprocessed_files, ids=ids, device=device
        )
        df_fovea.to_csv(fovea_path)
        click.echo(f"Fovea detection results saved to {fovea_path}")

    # Step 6: Create overlays if requested
    if overlay:
        if is_csv and not preprocess:
            # Overlays are built from a directory of RGB images; with CSV input
            # and no preprocessing there is no such directory (DATA_PATH is the
            # CSV file itself), so passing it as rgb_dir would be wrong.
            click.echo(
                "Skipping overlays: they require --preprocess when DATA_PATH "
                "is a CSV file",
                err=True,
            )
        else:
            click.echo("Creating visualization overlays...")

            # Map each fovea-result row index to its (x, y) coordinates
            fovea_data = None
            if df_fovea is not None:
                fovea_data = {
                    idx: (row["x_fovea"], row["y_fovea"])
                    for idx, row in df_fovea.iterrows()
                }

            # NOTE(review): av_dir/disc_dir are passed even when --no-vessels or
            # --no-disc left them uncreated; confirm batch_create_overlays
            # tolerates missing directories.
            batch_create_overlays(
                rgb_dir=preprocess_rgb_path if preprocess else data_path,
                output_dir=overlay_path,
                av_dir=av_path,
                disc_dir=disc_path,
                fovea_data=fovea_data,
            )

            click.echo(f"Visualization overlays saved to {overlay_path}")

    click.echo(f"All requested processing complete. Results saved to {output_path}")