"""FLAMeS Gradio app: MS lesion segmentation of FLAIR MRI with a pretrained nnUNet model.

At start-up the model archive is downloaded from the Hugging Face Hub. Each
request runs ``nnUNetv2_predict`` on the uploaded NIfTI image. It returns the
predicted lesion mask plus PNG previews of the input and of the mask.
"""

import gradio as gr
import subprocess
import os
import shutil
from huggingface_hub import hf_hub_download
import torch
import nibabel as nib
import matplotlib.pyplot as plt
import spaces  # Import spaces for GPU decoration
import numpy as np
from scipy.ndimage import center_of_mass, zoom

# Define paths
MODEL_DIR = "./model"  # Local directory to store the downloaded model
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")  # Directory for Dataset004_WML
INPUT_DIR = "/tmp/input"
OUTPUT_DIR = "/tmp/output"

# Hugging Face Model Repository
REPO_ID = "FrancescoLR/FLAMeS-model"  # Replace with your actual model repository ID


# Function to download the Dataset004_WML folder
def download_model():
    """Download and extract the Dataset004_WML model folder if it is missing.

    ``Dataset004_WML.zip`` is fetched from ``REPO_ID``, using a Hugging Face
    cache under ``MODEL_DIR``. It is then unzipped into ``MODEL_DIR``. Nothing
    happens when ``DATASET_DIR`` already exists.

    Raises:
        subprocess.CalledProcessError: if ``unzip`` exits with an error.
    """
    if os.path.exists(DATASET_DIR):
        return
    # Create only the parent here. Creating DATASET_DIR before the download
    # would make a failed download look complete, and it would never be retried.
    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading Dataset004_WML.zip...")
    zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)
    subprocess.run(["unzip", "-o", zip_path, "-d", MODEL_DIR], check=True)
    print("Dataset004_WML downloaded and extracted.")


def resample_to_isotropic(data, affine, target_spacing=1.0):
    """
    Resamples a 3D NIfTI image to isotropic voxel size.

    Parameters:
        data (numpy.ndarray): The input 3D image data.
        affine (numpy.ndarray): The 4x4 affine transformation matrix.
        target_spacing (float): Desired isotropic voxel spacing (in mm).

    Returns:
        resampled_data (numpy.ndarray): Resampled image data (linear interpolation).
        resampled_affine (numpy.ndarray): Affine whose voxel columns are rescaled
            to ``target_spacing``.
    """
    # Column j of the affine's 3x3 block maps array axis j into world space.
    # Its norm is therefore the voxel spacing along that axis.
    current_spacing = np.sqrt((affine[:3, :3] ** 2).sum(axis=0))

    # Compute the scaling factors for resampling
    scaling_factors = current_spacing / target_spacing

    # Resample the data using zoom
    resampled_data = zoom(data, zoom=scaling_factors, order=1)  # Linear interpolation

    # Scale each voxel-axis column by its own factor (columns, not rows), so the
    # new column norms equal target_spacing for any orientation.
    resampled_affine = np.array(affine, dtype=float, copy=True)
    resampled_affine[:3, :3] /= scaling_factors[np.newaxis, :]

    return resampled_data, resampled_affine


def _load_resampled(nifti_path, target_spacing=1.0):
    """Load a NIfTI file and return its data resampled to ``target_spacing`` mm."""
    img = nib.load(nifti_path)
    resampled_data, _ = resample_to_isotropic(img.get_fdata(), img.affine, target_spacing=target_spacing)
    return resampled_data


def _nonzero_center(data):
    """Return the rounded integer centre of mass of the voxels of ``data`` that are > 0.

    If no voxel is > 0, return the centre of the array instead.
    ``center_of_mass`` would return NaNs there.
    """
    mask = data > 0
    if not mask.any():
        return np.asarray(data.shape) // 2
    return np.round(center_of_mass(mask)).astype(int)


def _extract_2d_slice(data, center, axis, slice_size):
    """Return a ``slice_size`` x ``slice_size`` plane of the 3D array ``data`` through ``center``.

    The plane perpendicular to ``axis`` at ``center[axis]`` is cropped to a
    window around ``center`` in the two remaining axes. Any part of the window
    outside the volume is zero-padded on the side where it is missing.
    ``center`` therefore always lands at the same output pixel.
    """
    # Clamp the centre into the volume so the plane index is valid and the
    # crop window always overlaps the data.
    center = np.clip(np.asarray(center, dtype=int), 0, np.asarray(data.shape) - 1)
    half_size = slice_size // 2

    index = [slice(None)] * 3
    index[axis] = center[axis]  # Fix the axis to extract a single plane
    plane = data[tuple(index)]

    window, padding = [], []
    for dim, ax in enumerate(i for i in range(3) if i != axis):
        start = int(center[ax]) - half_size
        stop = start + slice_size
        lo, hi = max(start, 0), min(stop, plane.shape[dim])
        window.append(slice(lo, hi))
        # Pad exactly what was cut off on each side: total = stop - start = slice_size.
        padding.append((lo - start, stop - hi))

    return np.pad(plane[tuple(window)], padding, mode="constant", constant_values=0)


def extract_middle_slices(nifti_path, output_image_path, slice_size=180, center=None):
    """
    Save axial, coronal and sagittal slices of a 3D NIfTI image as one PNG.

    The image is first resampled to 1 mm isotropic. If ``center`` is given (as
    voxel indices in the resampled grid), the slices pass through it. Otherwise
    they pass through the centre of mass of the voxels > 0.

    Parameters:
        nifti_path (str): Path of the NIfTI image to render.
        output_image_path (str): Destination of the PNG figure.
        slice_size (int): Side length (in resampled voxels) of each square slice.
        center (array-like of 3 ints or None): Slice centre in the resampled grid.
    """
    resampled_data = _load_resampled(nifti_path, target_spacing=1.0)
    if center is None:
        center = _nonzero_center(resampled_data)

    # Axial (z-axis), coronal (y-axis), sagittal (x-axis), in subplot order.
    views = [_extract_2d_slice(resampled_data, center, axis, slice_size) for axis in (2, 1, 0)]
    # Rotate every view 90 degrees clockwise before plotting with origin="lower".
    views = [np.rot90(view, k=-1) for view in views]

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, view in zip(axes, views):
            ax.imshow(view, cmap="gray", origin="lower")
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Close this specific figure, even if saving fails.
        plt.close(fig)


def _strip_nifti_extension(filename):
    """Return ``filename`` without a trailing ``.nii.gz`` or ``.nii`` extension."""
    for extension in (".nii.gz", ".nii"):
        if filename.endswith(extension):
            return filename[: -len(extension)]
    return filename


# Function to run nnUNet inference
@spaces.GPU(duration=70)  # Decorate the function to allocate GPU for its execution
def run_nnunet_predict(nifti_file):
    """Segment MS lesions in an uploaded FLAIR image with ``nnUNetv2_predict``.

    Parameters:
        nifti_file: Value of the ``gr.File`` input. This is either a
            tempfile-like object exposing ``.name`` or a plain path string.

    Returns:
        tuple: ``(mask_path, input_preview_png, output_preview_png)``, matching
        the three outputs wired in the Gradio interface.

    Raises:
        gr.Error: if no file was uploaded, nnUNet fails, or no prediction was
            written. Gradio shows the message to the user.
    """
    if nifti_file is None:
        raise gr.Error("Please upload a FLAIR image (.nii.gz) before submitting.")
    uploaded_path = getattr(nifti_file, "name", nifti_file)

    # Start every request from empty directories. Otherwise nnUNet would also
    # segment leftover inputs, and earlier uploads and masks would stay on disk.
    for directory in (INPUT_DIR, OUTPUT_DIR):
        shutil.rmtree(directory, ignore_errors=True)
        os.makedirs(directory, exist_ok=True)

    # Extract the original filename without the extension
    base_filename = _strip_nifti_extension(os.path.basename(uploaded_path))

    # nnUNet names its prediction after the case id, so "image_0000.nii.gz"
    # yields "image.nii.gz" (read back below). The file is copied, not renamed:
    # os.rename fails across filesystems and would delete Gradio's cached upload.
    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    shutil.copyfile(uploaded_path, input_path)

    # Debugging: List files in the /tmp/input directory
    print("Files in /tmp/input:")
    print(os.listdir(INPUT_DIR))

    # Set environment variables for nnUNet
    os.environ["nnUNet_results"] = MODEL_DIR

    # Construct and run the nnUNetv2_predict command
    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",  # Dataset ID
        "-c", "3d_fullres",  # Configuration
        "-tr", "nnUNetTrainer_8000epochs",
        "-device", "cuda",  # Explicitly use GPU
    ]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        raise gr.Error(f"Error: {e}") from e

    print("Files in /tmp/output:")
    print(os.listdir(OUTPUT_DIR))

    # Rename the output file to match the original input filename
    output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
    if not os.path.exists(output_file):
        raise gr.Error("Error: Output file not found.")
    new_output_file = os.path.join(OUTPUT_DIR, f"{base_filename}_LesionMask.nii.gz")
    os.rename(output_file, new_output_file)  # Same directory, so a plain rename is safe

    # Both previews use the input's centre of mass, so they show the same planes.
    center = _nonzero_center(_load_resampled(input_path, target_spacing=1.0))

    # Extract and save 2D slices
    input_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_input_slice.png")
    output_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_output_slice.png")
    extract_middle_slices(input_path, input_slice_path, center=center)
    extract_middle_slices(new_output_file, output_slice_path, center=center)

    # Return paths for the Gradio interface
    return new_output_file, input_slice_path, output_slice_path


# Markdown shown above the interface. Kept at column 0 so Markdown does not
# treat indented lines as code blocks.
_INTRO_MD = """
# 🔥 FLAMeS: FLAIR Lesion Segmentation for Multiple Sclerosis

Upload a skull-stripped FLAIR brain MRI in NIfTI (.nii.gz) format to generate a binary segmentation of multiple sclerosis lesions.
FLAMeS is based on the nnUNet framework² and was trained on 668 MRI scans acquired using Siemens, GE, and Philips 1.5T and 3T scanners¹.
For skull-stripping, we suggest using [SynthStrip](https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/) with the `--no-csf` flag for optimal results.
Inference takes approximately 1 minute per MRI, with processing limited to one scan at a time due to Hugging Face's zero-GPU usage constraints.
To process multiple cases simultaneously, download [FLAMeS's model](https://huggingface.co/FrancescoLR/FLAMeS-model) and run it locally using your own GPU or CPU setup.

**Disclaimer:** Uploaded data is stored temporarily, no one has access to it, and it is deleted when the app is closed.
For details, see [Gradio's file access guide](https://www.gradio.app/main/guides/file-access).
Human subjects data should only be uploaded for processing if permitted by your institution's human subjects protection office.
This is a research tool and is not intended for clinical use. Clinical decisions should not be based on the outputs of this tool.
"""

# Citation Markdown shown below the interface.
_CITATION_MD = """
**If you find this tool useful, please consider citing:**

1. A Deep Learning-Based Pipeline for Longitudinal White Matter Lesion Segmentation Using Diverse FLAIR Images
   F. La Rosa, J. Dos Santos Silva, W. A. Mullins, H. Greenspan, J. F. Sumowski, D. S. Reich, & E. S. Beck.
   *ACTRIMS Forum 2023. Multiple Sclerosis Journal.* 2023;29(2_suppl):18-242.
   DOI: [10.1177/13524585231169437](https://doi.org/10.1177/13524585231169437)

2. nnU-Net: A Self-Configuring Method for Deep Learning-Based Biomedical Image Segmentation
   F. Isensee, P. F. Jaeger, S. A. Kohl, J. Petersen, & K. H. Maier-Hein.
   *Nature Methods.* 2021;18(2):203-211.
   DOI: [10.1038/s41592-020-01008-z](https://www.nature.com/articles/s41592-020-01008-z)
"""

# Gradio interface with adjusted layout
with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MD)

    with gr.Row():
        with gr.Column(scale=1):
            flair_input = gr.File(label="Upload a FLAIR Image (.nii.gz)")
            submit_button = gr.Button("Submit")

        with gr.Column(scale=2):
            seg_output = gr.File(label="Download the Lesion Segmentation Mask")
            input_img = gr.Image(label="Input: FLAIR image")
            output_img = gr.Image(label="Output: Lesion Mask")

    gr.Markdown(_CITATION_MD)

    submit_button.click(
        fn=run_nnunet_predict,
        inputs=[flair_input],
        outputs=[seg_output, input_img, output_img],
    )

# Debugging GPU environment
if torch.cuda.is_available():
    print(f"GPU is available: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU available. Falling back to CPU.")
os.system("nvidia-smi")

download_model()

if __name__ == "__main__":
    demo.launch(share=True)