from pathlib import Path

import click
import pandas as pd
import torch
from rtnls_fundusprep.cli import _run_preprocessing

from .inference import (
    run_fovea_detection,
    run_quality_estimation,
    run_segmentation_disc,
    run_segmentation_vessels_and_av,
)
from .utils import batch_create_overlays


@click.group(name="vascx")
def cli():
    """Root command group for the ``vascx`` command-line interface."""


def _collect_input_files(data_path):
    """Resolve DATA_PATH into a list of input files and optional ids.

    Args:
        data_path: ``Path`` to either a CSV file (detected by its ``.csv``
            suffix) or a directory of images.

    Returns:
        A ``(files, ids)`` tuple. For CSV input, ``files`` comes from the
        'path' column and ``ids`` from the 'id' column (``None`` when that
        column is absent). For directory input, ``files`` are the regular
        files directly inside the directory and ``ids`` are their stems.

    Raises:
        click.ClickException: If the CSV cannot be read, lacks a 'path'
            column, or holds a value that cannot be turned into a path.
    """
    if data_path.suffix.lower() == ".csv":
        click.echo(f"Reading file paths from CSV: {data_path}")
        try:
            df = pd.read_csv(data_path)
        except (OSError, ValueError) as e:
            # pandas parser / empty-data / decoding errors all derive from
            # ValueError; I/O problems derive from OSError.
            raise click.ClickException(f"Could not read CSV file: {e}") from e

        if "path" not in df.columns:
            raise click.ClickException("CSV must contain a 'path' column")

        try:
            files = [Path(p) for p in df["path"]]
        except TypeError as e:
            # e.g. an empty cell, which pandas loads as a float NaN.
            raise click.ClickException(
                f"Invalid entry in CSV 'path' column: {e}"
            ) from e

        ids = None
        if "id" in df.columns:
            ids = df["id"].tolist()
            click.echo("Using IDs from CSV 'id' column")
        return files, ids

    click.echo(f"Finding files in directory: {data_path}")
    # Keep regular files only: a bare glob("*") also yields subdirectories
    # (possibly the output directory itself), which are not images.
    # Sorting makes the processing order deterministic.
    files = sorted(p for p in data_path.glob("*") if p.is_file())
    return files, [f.stem for f in files]


@cli.command()
@click.argument("data_path", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@click.option(
    "--preprocess/--no-preprocess",
    default=True,
    help="Run preprocessing or use preprocessed images",
)
@click.option(
    "--vessels/--no-vessels", default=True, help="Run vessels and AV segmentation"
)
@click.option("--disc/--no-disc", default=True, help="Run optic disc segmentation")
@click.option(
    "--quality/--no-quality", default=True, help="Run image quality estimation"
)
@click.option("--fovea/--no-fovea", default=True, help="Run fovea detection")
@click.option(
    "--overlay/--no-overlay", default=True, help="Create visualization overlays"
)
@click.option("--n_jobs", type=int, default=4, help="Number of preprocessing workers")
def run(
    data_path, output_path, preprocess, vessels, disc, quality, fovea, overlay, n_jobs
):
    """Run the complete inference pipeline on fundus images.

    DATA_PATH is either a directory containing images or a CSV file with 'path' column.
    OUTPUT_PATH is the directory where results will be stored.
    """
    output_path = Path(output_path)
    output_path.mkdir(exist_ok=True, parents=True)

    # Setup output directories
    preprocess_rgb_path = output_path / "preprocessed_rgb"
    vessels_path = output_path / "vessels"
    av_path = output_path / "artery_vein"
    disc_path = output_path / "disc"
    overlay_path = output_path / "overlays"

    # Create only the directories needed by the enabled steps
    if preprocess:
        preprocess_rgb_path.mkdir(exist_ok=True, parents=True)
    if vessels:
        av_path.mkdir(exist_ok=True, parents=True)
        vessels_path.mkdir(exist_ok=True, parents=True)
    if disc:
        disc_path.mkdir(exist_ok=True, parents=True)
    if overlay:
        overlay_path.mkdir(exist_ok=True, parents=True)

    bounds_path = output_path / "bounds.csv" if preprocess else None
    quality_path = output_path / "quality.csv" if quality else None
    fovea_path = output_path / "fovea.csv" if fovea else None

    # Resolve the inputs (directory listing or CSV); CSV problems abort with
    # a non-zero exit status via click.ClickException.
    data_path = Path(data_path)
    files, ids = _collect_input_files(data_path)

    if not files:
        click.echo("No files found to process")
        return

    click.echo(f"Found {len(files)} files to process")

    # Step 1: Preprocess images if requested
    if preprocess:
        click.echo("Running preprocessing...")
        _run_preprocessing(
            files=files,
            ids=ids,
            rgb_path=preprocess_rgb_path,
            bounds_path=bounds_path,
            n_jobs=n_jobs,
        )
        # Use the preprocessed images for subsequent steps
        preprocessed_files = sorted(preprocess_rgb_path.glob("*.png"))
    else:
        # Use the input files directly
        preprocessed_files = files

    # From here on, ids are always the stems of the images actually fed to
    # the models, so they stay aligned with `preprocessed_files`.
    # NOTE: ids read from a CSV 'id' column are therefore only passed to the
    # preprocessing step above.
    ids = [f.stem for f in preprocessed_files]

    # Set up GPU device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    click.echo(f"Using device: {device}")

    # Step 2: Run quality estimation if requested
    if quality:
        click.echo("Running quality estimation...")
        df_quality = run_quality_estimation(
            fpaths=preprocessed_files, ids=ids, device=device
        )
        df_quality.to_csv(quality_path)
        click.echo(f"Quality results saved to {quality_path}")

    # Step 3: Run vessels and AV segmentation if requested
    if vessels:
        click.echo("Running vessels and AV segmentation...")
        run_segmentation_vessels_and_av(
            rgb_paths=preprocessed_files,
            ids=ids,
            av_path=av_path,
            vessels_path=vessels_path,
            device=device,
        )
        click.echo(f"Vessel segmentation saved to {vessels_path}")
        click.echo(f"AV segmentation saved to {av_path}")

    # Step 4: Run optic disc segmentation if requested
    if disc:
        click.echo("Running optic disc segmentation...")
        run_segmentation_disc(
            rgb_paths=preprocessed_files, ids=ids, output_path=disc_path, device=device
        )
        click.echo(f"Disc segmentation saved to {disc_path}")

    # Step 5: Run fovea detection if requested
    df_fovea = None
    if fovea:
        click.echo("Running fovea detection...")
        df_fovea = run_fovea_detection(
            rgb_paths=preprocessed_files, ids=ids, device=device
        )
        df_fovea.to_csv(fovea_path)
        click.echo(f"Fovea detection results saved to {fovea_path}")

    # Step 6: Create overlays if requested
    if overlay:
        rgb_dir = preprocess_rgb_path if preprocess else data_path
        if rgb_dir.is_dir():
            click.echo("Creating visualization overlays...")

            # Prepare fovea data if available, keyed by the DataFrame index
            fovea_data = None
            if df_fovea is not None:
                fovea_data = {
                    idx: (row["x_fovea"], row["y_fovea"])
                    for idx, row in df_fovea.iterrows()
                }

            # Create visualization overlays
            batch_create_overlays(
                rgb_dir=rgb_dir,
                output_dir=overlay_path,
                av_dir=av_path,
                disc_dir=disc_path,
                fovea_data=fovea_data,
            )
            click.echo(f"Visualization overlays saved to {overlay_path}")
        else:
            # With --no-preprocess and CSV input there is no image directory
            # to hand over: DATA_PATH is the CSV file itself.
            click.echo(
                f"Warning: {rgb_dir} is not a directory; skipping overlays",
                err=True,
            )

    click.echo(f"All requested processing complete. Results saved to {output_path}")