import gradio as gr
import subprocess
import os
import shutil
from huggingface_hub import hf_hub_download
import torch
import nibabel as nib
import matplotlib as mpl
import matplotlib.pyplot as plt
import spaces # Import spaces for GPU decoration
import numpy as np
from scipy.ndimage import center_of_mass, zoom, label, generate_binary_structure
# --- Filesystem layout --------------------------------------------------------
# The downloaded model lives under MODEL_DIR; prediction points nnU-Net at it
# through the nnUNet_results environment variable.
MODEL_DIR = "./model"
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")
# Scratch folders holding the case fed to nnU-Net and the resulting prediction.
INPUT_DIR = "/tmp/input"
OUTPUT_DIR = "/tmp/output"

# --- Hugging Face Hub ---------------------------------------------------------
# Repository that hosts Dataset004_WML.zip (the trained FLAMeS model).
REPO_ID = "FrancescoLR/FLAMeS-model"
import os
import subprocess
def setup_hd_bet(repo_dir="./HD-BET"):
    """
    Clone the HD-BET repository (if missing) and install it in editable mode.

    Parameters:
        repo_dir (str): Directory where HD-BET will be cloned and installed.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` or ``pip install`` fails.
    """
    import sys

    if os.path.exists(repo_dir):
        print("HD-BET repository already exists.")
    else:
        print("Cloning HD-BET repository...")
        # NOTE(review): clones the default branch unpinned; consider pinning a tag
        # or commit for reproducible deployments.
        subprocess.run(["git", "clone", "https://github.com/MIC-DKFZ/HD-BET", repo_dir], check=True)
    # Install the HD-BET package from source. Use the pip that belongs to the
    # running interpreter: a bare "pip" on PATH may target a different Python,
    # leaving the package invisible to this process.
    print("Installing HD-BET using pip...")
    subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], cwd=repo_dir, check=True)
# Function to download the Dataset004_WML folder
def download_model():
    """
    Download and unpack the Dataset004_WML nnU-Net model folder if it is absent.

    ``Dataset004_WML.zip`` is fetched from the Hugging Face repo ``REPO_ID``
    (cached under ``MODEL_DIR``) and extracted into ``MODEL_DIR``. Nothing is
    downloaded when ``DATASET_DIR`` already exists and is non-empty.

    Raises:
        zipfile.BadZipFile: If the downloaded archive is corrupt.
        OSError: If the archive cannot be read or extracted.
    """
    import zipfile

    # An empty directory (e.g. left behind by an interrupted earlier attempt)
    # must not count as "already downloaded", or the model is never fetched.
    if os.path.isdir(DATASET_DIR) and os.listdir(DATASET_DIR):
        return
    print("Downloading Dataset004_WML.zip...")
    zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)
    # Extract with the standard library so failures raise instead of being
    # silently ignored (the previous `unzip` call did not check its exit code).
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(MODEL_DIR)
    print("Dataset004_WML downloaded and extracted.")
def resample_to_isotropic(data, affine, target_spacing=1.0):
    """
    Resample a NIfTI image volume to isotropic voxel size.

    The first three array axes are treated as spatial; any further axes
    (e.g. a singleton 4th dimension) are left at their original size.

    Parameters:
        data (numpy.ndarray): The input image data (at least 3D).
        affine (numpy.ndarray): The 4x4 voxel-to-world affine matrix.
        target_spacing (float): Desired isotropic voxel spacing (in mm).

    Returns:
        resampled_data (numpy.ndarray): Resampled image data (linear interpolation).
        resampled_affine (numpy.ndarray): Affine with each voxel axis rescaled.
            The output shape is rounded by ``zoom``, so the realised spacing is
            only approximately ``target_spacing``.
    """
    affine = np.asarray(affine, dtype=float)
    # Voxel size along array axis j is the Euclidean norm of affine column j.
    current_spacing = np.sqrt((affine[:3, :3] ** 2).sum(axis=0))
    # Compute the scaling factors for resampling
    scaling_factors = current_spacing / target_spacing
    # Zoom only the spatial axes; trailing axes keep factor 1.
    zoom_factors = tuple(scaling_factors) + (1.0,) * (np.ndim(data) - 3)
    resampled_data = zoom(data, zoom=zoom_factors, order=1)  # Linear interpolation
    # Column j is the world-space step of array axis j, so each COLUMN is divided
    # by that axis's zoom factor. (Dividing rows, as before, only coincides with
    # this for diagonal, axis-aligned affines.)
    resampled_affine = affine.copy()
    resampled_affine[:3, :3] = affine[:3, :3] / scaling_factors[np.newaxis, :]
    return resampled_data, resampled_affine
# Preview rendering: `extract_middle_slices` below saves axial/coronal/sagittal
# PNG previews of a NIfTI volume.
def _resolve_center(volume, center):
    """Return an in-bounds integer (3,) voxel index to centre the preview panels on."""
    spatial_shape = np.array(volume.shape[:3])
    if center is None:
        foreground = volume > 0
        if foreground.any():
            center = np.round(center_of_mass(foreground)).astype(int)
        else:
            # center_of_mass of an empty mask is NaN; use the grid centre instead.
            center = spatial_shape // 2
    # Clip so a centre from a slightly different grid cannot raise IndexError.
    return np.clip(np.asarray(center, dtype=int), 0, spatial_shape - 1)


def _extract_centered_plane(volume, center, axis, size):
    """Return a (size, size) slice of `volume` at center[axis], with the other
    two axes windowed so `center` lands at index size // 2 (zero-filled outside)."""
    plane = np.take(volume, center[axis], axis=axis)
    in_plane_center = [center[i] for i in range(3) if i != axis]
    half_size = size // 2
    out = np.zeros((size, size), dtype=plane.dtype)
    src, dst = [], []
    for c, extent in zip(in_plane_center, plane.shape):
        start = c - half_size
        lo, hi = max(start, 0), min(start + size, extent)
        src.append(slice(lo, hi))
        # Place the cropped data at its true offset inside the window instead of
        # padding symmetrically, which shifted the centre near volume borders.
        dst.append(slice(lo - start, hi - start))
    out[tuple(dst)] = plane[tuple(src)]
    return out


def _component_colormap(seed=42):
    """Build a colormap: black for background, shuffled pastel colours for labels."""
    pastel = plt.cm.Pastel1(np.linspace(0, 1, 256))
    colors = pastel[1:].copy()
    # Local generator: reproducible without resetting NumPy's global RNG state.
    np.random.default_rng(seed).shuffle(colors)
    return mpl.colors.ListedColormap(np.vstack([np.array([0, 0, 0, 1]), colors]))


def extract_middle_slices(nifti_path, output_image_path, slice_size=180, center=None, label_components=False):
    """
    Save axial, coronal and sagittal slices of a 3D NIfTI image as one PNG.

    The volume is resampled to 1 mm isotropic voxels (linear interpolation)
    before slicing; each panel is a `slice_size` x `slice_size` window whose
    middle is `center`.

    Parameters:
        nifti_path (str): Path to the NIfTI file to render.
        output_image_path (str): Destination path of the 1x3 PNG figure.
        slice_size (int): Side length, in resampled voxels, of each square panel.
        center (sequence of 3 ints or None): Voxel index in the resampled grid
            to centre all panels on. If None, the centre of mass of non-zero
            voxels is used (grid centre for an all-zero volume). Out-of-range
            indices are clipped into the volume.
        label_components (bool): If True, label 26-connected components of the
            non-zero voxels and draw each component in its own colour.

    Returns:
        numpy.ndarray: With label_components=True, the integer component labels
        computed on the ORIGINAL (non-resampled) voxel grid; otherwise the
        resampled intensity volume.
    """
    img = nib.load(nifti_path)
    data = img.get_fdata()
    resampled_data, _ = resample_to_isotropic(data, img.affine, target_spacing=1.0)

    if label_components:
        structure = generate_binary_structure(3, 3)  # 3D, 26-connectivity
        # Native-grid labels are returned (the caller can save them with the
        # native affine); resampled-grid labels are only used for display.
        labeled_data, _ = label(data > 0, structure=structure)
        display_volume, num_features = label(resampled_data > 0, structure=structure)
        cmap, vmin, vmax = _component_colormap(), 0, num_features
    else:
        labeled_data = resampled_data
        display_volume = resampled_data
        cmap, vmin, vmax = "gray", None, None

    center = _resolve_center(display_volume, center)

    # Every panel gets the same net rotation (the previous rot90 k=1 followed by
    # k=2 on the coronal and sagittal panels equals k=-1).
    panels = [
        np.rot90(_extract_centered_plane(display_volume, center, axis, slice_size), k=-1)
        for axis in (2, 1, 0)  # axial, coronal, sagittal
    ]

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, panel in zip(axes, panels):
            ax.imshow(panel, cmap=cmap, origin="lower", vmin=vmin, vmax=vmax)
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
    return labeled_data
# Function to run nnUNet inference
@spaces.GPU(duration=90)  # Decorate the function to allocate GPU for its execution
def run_nnunet_predict(nifti_file, hd_bet=False):
    """
    Segment MS lesions in an uploaded FLAIR volume with the FLAMeS nnU-Net model.

    Steps:
    1. Move the upload to INPUT_DIR as ``image_0000.nii.gz``.
    2. Optionally skull-strip it with HD-BET.
    3. Run ``nnUNetv2_predict`` (dataset 004, 3d_fullres, nnUNetTrainer_8000epochs) on CUDA.
    4. Rename the prediction to ``<name>_LesionMask.nii.gz``.
    5. Render PNG previews of the input and the mask.
    6. Save a connected-component labelling of the mask.

    Parameters:
        nifti_file: The Gradio upload. Depending on the Gradio version this is a
            path string or a file wrapper exposing the path as ``.name``.
        hd_bet (bool): Whether to skull-strip the image with HD-BET first.

    Returns:
        tuple[str, str, str, str]: Paths to the lesion mask NIfTI, the input
        preview PNG, the mask preview PNG and the labelled-clusters NIfTI.

    Raises:
        gr.Error: When nothing was uploaded, HD-BET or nnU-Net fails, or the
            prediction file is missing; Gradio shows the message to the user.
            (Returning a single error string to a 4-output event cannot be
            displayed and makes Gradio raise instead.)
    """
    if nifti_file is None:
        raise gr.Error("Please upload a FLAIR image (.nii.gz) first.")

    # Prepare directories
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    upload_path = getattr(nifti_file, "name", nifti_file)
    original_filename = os.path.basename(upload_path)
    suffix = ".nii.gz"
    base_filename = original_filename[:-len(suffix)] if original_filename.endswith(suffix) else original_filename

    # nnU-Net expects <case>_<channel>.nii.gz; shutil.move also works when the
    # upload sits on another filesystem, where os.rename raises OSError.
    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    shutil.move(upload_path, input_path)

    if hd_bet:
        # Output path equals input path, so the skull-stripped image replaces the upload.
        try:
            subprocess.run(
                ["hd-bet", "-i", input_path, "-o", input_path, "-device", "cuda", "--disable_tta"],
                check=True,
            )
        except (subprocess.CalledProcessError, FileNotFoundError) as e:
            raise gr.Error(f"HD-BET Error: {e}") from e
        print("Skull-stripping completed.")

    # Debugging: List files in the /tmp/input directory
    print("Files in /tmp/input:")
    print(os.listdir(INPUT_DIR))

    # nnU-Net locates the trained model through this environment variable.
    os.environ["nnUNet_results"] = MODEL_DIR
    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",  # Dataset ID
        "-c", "3d_fullres",  # Configuration
        "-tr", "nnUNetTrainer_8000epochs",
        "-device", "cuda",  # Explicitly use GPU
    ]
    print("Files in /tmp/output:")
    print(os.listdir(OUTPUT_DIR))
    try:
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        raise gr.Error(f"Error: {e}") from e

    # nnU-Net names the prediction after the case id ("image" for image_0000).
    output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
    if not os.path.exists(output_file):
        raise gr.Error("Error: Output file not found.")
    new_output_file = os.path.join(OUTPUT_DIR, f"{base_filename}_LesionMask.nii.gz")
    os.replace(output_file, new_output_file)

    # One shared centre (on the 1 mm resampled grid) keeps both previews aligned.
    img = nib.load(input_path)
    resampled_data, _ = resample_to_isotropic(img.get_fdata(), img.affine, target_spacing=1.0)
    foreground = resampled_data > 0
    if foreground.any():
        center = np.round(center_of_mass(foreground)).astype(int)
    else:
        # center_of_mass of an empty mask is NaN, which cannot index an array.
        center = np.array(resampled_data.shape[:3]) // 2

    input_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_input_slice.png")
    output_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_output_slice.png")
    extract_middle_slices(input_path, input_slice_path, center=center)
    labeled_mask = extract_middle_slices(new_output_file, output_slice_path, center=center, label_components=True)

    # The labels are returned on the mask's native grid, so reuse the mask's affine.
    mask_affine = nib.load(new_output_file).affine
    labeled_mask_path = os.path.join(OUTPUT_DIR, f"{base_filename}_LabeledClusters.nii.gz")
    nib.save(nib.Nifti1Image(labeled_mask.astype(np.int16), mask_affine), labeled_mask_path)

    # Order matches the outputs wired to the Submit button.
    return new_output_file, input_slice_path, output_slice_path, labeled_mask_path
# Gradio interface with adjusted layout.
# Superscripts ¹/² in the intro refer to the numbered citations below the outputs.
with gr.Blocks() as demo:
    gr.Markdown("""
# 🔥 FLAMeS: FLAIR Lesion Segmentation for Multiple Sclerosis

Upload a FLAIR brain MRI in NIfTI format (.nii.gz) to generate a binary segmentation of multiple sclerosis lesions.
FLAMeS is based on the nnUNet framework² and was trained on 668 MRI scans acquired using Siemens, GE, and Philips 1.5T and 3T scanners¹.

We suggest skull-stripping the image in advance using [SynthStrip](https://surfer.nmr.mgh.harvard.edu/docs/synthstrip/) with the `--no-csf` flag for optimal results. If that's not feasible, you can still upload your image as-is and enable the "Apply skull-stripping" option below.

Inference takes approximately 1 minute per MRI, with processing limited to one scan at a time due to Hugging Face's zero-GPU usage constraints. To process multiple cases simultaneously, install the [nnUNet v2](https://github.com/MIC-DKFZ/nnUNet), download [FLAMeS's model](https://huggingface.co/FrancescoLR/FLAMeS-model) and run it locally using your own GPU or CPU setup.

**Disclaimer:** Uploaded data is stored temporarily, no one has access to it, and it is deleted when the app is closed. For details, see [Gradio's file access guide](https://www.gradio.app/main/guides/file-access). Human subjects data should only be uploaded for processing if permitted by your institution's human subjects protection office.

This is a research tool and is not intended for clinical use. Clinical decisions should not be based on the outputs of this tool.
""")
    with gr.Row():
        with gr.Column(scale=1):
            flair_input = gr.File(label="Upload a FLAIR Image (.nii.gz)")
            hd_bet = gr.Checkbox(label="Apply skull-stripping", value=False)
            submit_button = gr.Button("Submit")
        with gr.Column(scale=2):
            seg_output = gr.File(label="Download the Lesion Segmentation Mask")
            clusters_output = gr.File(label="Download the Labeled Lesion Segmentation Mask")
            input_img = gr.Image(label="Input: FLAIR image")
            output_img = gr.Image(label="Output: Binary Lesion Mask")
    gr.Markdown("""
**If you find this tool useful, please consider citing:**

1. FLAMeS: A Robust Deep Learning Model for Automated Multiple Sclerosis Lesion Segmentation
Dereskewicz, E., La Rosa, F., dos Santos Silva, J., Sizer, E., Kohli, A., Wynen, M., ... & Beck, E. S.
*medRxiv* (2025)
DOI: [10.1101/2025.05.19.25327707](https://doi.org/10.1101/2025.05.19.25327707)

2. nnU-Net: A Self-Configuring Method for Deep Learning-Based Biomedical Image Segmentation
F. Isensee, P. F. Jaeger, S. A. Kohl, J. Petersen, & K. H. Maier-Hein.
*Nature Methods.* 2021;18(2):203-211.
DOI: [10.1038/s41592-020-01008-z](https://www.nature.com/articles/s41592-020-01008-z)
""")
    # Output order must match the tuple returned by run_nnunet_predict.
    submit_button.click(
        fn=run_nnunet_predict,
        inputs=[flair_input, hd_bet],
        outputs=[seg_output, input_img, output_img, clusters_output],
    )
# Startup diagnostics: report which accelerator this process can see.
print(
    f"GPU is available: {torch.cuda.get_device_name(0)}"
    if torch.cuda.is_available()
    else "No GPU available. Falling back to CPU."
)
os.system("nvidia-smi")

# One-time setup performed at import: install HD-BET, then fetch the model.
setup_hd_bet()
download_model()

if __name__ == "__main__":
    demo.launch(share=True)