import gradio as gr
import subprocess
import os
import shutil
from huggingface_hub import hf_hub_download
import torch
import nibabel as nib
import matplotlib.pyplot as plt
import spaces  # Import spaces for GPU decoration
import numpy as np
from scipy.ndimage import center_of_mass

# Define paths
MODEL_DIR = "./model"  # Local directory to store the downloaded model
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")  # Directory for Dataset004_WML
INPUT_DIR = "/tmp/input"
OUTPUT_DIR = "/tmp/output"

# Hugging Face Model Repository
REPO_ID = "FrancescoLR/FLAMeS-model"  # Replace with your actual model repository ID

# Function to download and extract the Dataset004_WML folder
def download_model():
    if not os.path.exists(DATASET_DIR):
        os.makedirs(DATASET_DIR, exist_ok=True)
        print("Downloading Dataset004_WML.zip...")
        zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)
        subprocess.run(["unzip", "-o", zip_path, "-d", MODEL_DIR])
        print("Dataset004_WML downloaded and extracted.")


def extract_middle_slices(nifti_path, output_image_path, slice_size=180):
    """
    Extracts slices centered around the center of mass of non-zero voxels in a 3D NIfTI image.
    The slices are taken along the axial, coronal, and sagittal planes and saved as a single PNG.
    """
    # Load the NIfTI image and get its data
    img = nib.load(nifti_path)
    data = img.get_fdata()

    # Compute the center of mass of non-zero voxels
    com = center_of_mass(data > 0)
    center = np.round(com).astype(int)

    # Define half the slice size to extract regions around the center of mass
    half_size = slice_size // 2

    # Safely extract and crop a 2D slice orthogonal to the given axis
    def extract_2d_slice(data, center, axis):
        slices = [slice(None)] * 3
        slices[axis] = center[axis]
        extracted_slice = data[tuple(slices)]

        # Crop around the center for the remaining dimensions
        remaining_axes = [i for i in range(3) if i != axis]
        for dim in remaining_axes:
            # After slicing, axes past the sliced axis shift down by one in the 2D slice
            cropped_axis = dim - (dim > axis)
            start = max(center[dim] - half_size, 0)
            end = min(center[dim] + half_size, extracted_slice.shape[cropped_axis])
            extracted_slice = extracted_slice.take(range(start, end), axis=cropped_axis)
        return extracted_slice

    axial_slice = extract_2d_slice(data, center, axis=2)     # Axial (z-axis)
    coronal_slice = extract_2d_slice(data, center, axis=1)   # Coronal (y-axis)
    sagittal_slice = extract_2d_slice(data, center, axis=0)  # Sagittal (x-axis)

    # Create subplots
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    # Plot each slice
    axes[0].imshow(axial_slice, cmap="gray", origin="lower")
    axes[0].axis("off")
    axes[0].set_title("Axial")

    axes[1].imshow(coronal_slice, cmap="gray", origin="lower")
    axes[1].axis("off")
    axes[1].set_title("Coronal")

    axes[2].imshow(sagittal_slice, cmap="gray", origin="lower")
    axes[2].axis("off")
    axes[2].set_title("Sagittal")

    # Save the figure
    plt.tight_layout()
    plt.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
    plt.close()


# Function to run nnUNet inference
@spaces.GPU  # Decorate the function to allocate a GPU for its execution
def run_nnunet_predict(nifti_file):
    # Prepare directories
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Extract the original filename without the extension
    original_filename = os.path.basename(nifti_file.name)
    base_filename = original_filename.replace(".nii.gz", "")

    # Move the uploaded file to the expected input location;
    # nnU-Net expects input filenames to carry the "_0000" channel suffix
    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    shutil.move(nifti_file.name, input_path)  # shutil.move also works across filesystems, unlike os.rename

    # Debugging: list files in the input directory
    print("Files in /tmp/input:")
    print(os.listdir(INPUT_DIR))
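    # A small, optional sanity check (added here as a sketch, not part of the original pipeline):
    # it assumes download_model() has already unpacked the trained weights into DATASET_DIR,
    # and it only prints a warning, leaving the original control flow unchanged.
    if not os.path.isdir(DATASET_DIR):
        print(f"Warning: expected model directory {DATASET_DIR} was not found; "
              "nnUNetv2_predict will fail without the trained weights.")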
    # Set environment variables for nnUNet
    os.environ["nnUNet_results"] = MODEL_DIR

    # Construct the nnUNetv2_predict command
    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",          # Dataset ID
        "-c", "3d_fullres",   # Configuration
        "-tr", "nnUNetTrainer_8000epochs",
        "-device", "cuda"     # Explicitly use the GPU
    ]

    try:
        subprocess.run(command, check=True)

        # Debugging: list files in the output directory
        print("Files in /tmp/output:")
        print(os.listdir(OUTPUT_DIR))

        # Rename the output file to match the original input filename
        output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
        new_output_file = os.path.join(OUTPUT_DIR, f"{base_filename}_LesionMask.nii.gz")
        if os.path.exists(output_file):
            os.rename(output_file, new_output_file)

            # Extract and save 2D slices of the input image and the predicted mask
            input_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_input_slice.png")
            output_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_output_slice.png")
            extract_middle_slices(input_path, input_slice_path)
            extract_middle_slices(new_output_file, output_slice_path)

            # Return paths for the Gradio interface
            return new_output_file, input_slice_path, output_slice_path
        else:
            return "Error: Output file not found."
    except subprocess.CalledProcessError as e:
        return f"Error: {e}"


# Gradio Interface
interface = gr.Interface(
    fn=run_nnunet_predict,
    inputs=gr.File(label="Upload FLAIR Image (.nii.gz)"),
    outputs=[
        gr.File(label="Download Segmentation Mask"),
        gr.Image(label="Input Middle Slice"),
        gr.Image(label="Output Middle Slice")
    ],
    title="FLAMeS: Multiple Sclerosis Lesion Segmentation",
    description="Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary segmentation of MS lesions."
)

# Debugging the GPU environment
if torch.cuda.is_available():
    print(f"GPU is available: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU available. Falling back to CPU.")
os.system("nvidia-smi")  # Check if NVIDIA tools are available

# Download model files before launching the app
download_model()

# Launch the app
if __name__ == "__main__":
    interface.launch(share=True)
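# For reference, a minimal sketch of the equivalent shell invocation issued above via
# subprocess, assuming the weights have been extracted into ./model and the uploaded
# FLAIR volume has been saved as /tmp/input/image_0000.nii.gz:
#
#   nnUNet_results=./model nnUNetv2_predict -i /tmp/input -o /tmp/output \
#       -d 004 -c 3d_fullres -tr nnUNetTrainer_8000epochs -device cuda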