import gradio as gr
import subprocess
import os
import shutil
from huggingface_hub import hf_hub_download
import torch
import nibabel as nib
import matplotlib.pyplot as plt
import spaces # Import spaces for GPU decoration
import numpy as np
# Define paths
MODEL_DIR = "./model" # Local directory to store the downloaded model
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML") # Directory for Dataset004_WML
INPUT_DIR = "/tmp/input"
OUTPUT_DIR = "/tmp/output"
# Hugging Face Model Repository
REPO_ID = "FrancescoLR/FLAMeS-model" # Replace with your actual model repository ID
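# The repo hosts Dataset004_WML.zip, the trained nnU-Net model; download_model() below
# unzips it into MODEL_DIR, which is later used as the nnUNet_results folder for inference.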
# Function to download the Dataset004_WML folder
def download_model():
    if not os.path.exists(DATASET_DIR):
        os.makedirs(DATASET_DIR, exist_ok=True)
        print("Downloading Dataset004_WML.zip...")
        zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)
        subprocess.run(["unzip", "-o", zip_path, "-d", MODEL_DIR])
        print("Dataset004_WML downloaded and extracted.")
def extract_middle_slice(nifti_path, output_image_path):
    """
    Extracts the middle axial slice from a 3D NIfTI image and saves it as a PNG file.
    The figure size is adjusted dynamically based on the slice's aspect ratio.
    """
    # Load the NIfTI image and get the voxel data
    img = nib.load(nifti_path)
    data = img.get_fdata()

    # Get the middle slice along the z-axis
    middle_slice_index = data.shape[2] // 2
    slice_data = data[:, :, middle_slice_index]

    # Rotate the slice 90 degrees clockwise for display
    slice_data = np.rot90(slice_data, k=-1)

    # Calculate the aspect ratio of the slice
    height, width = slice_data.shape
    aspect_ratio = width / height

    # Fix the figure height at 4 inches and scale the width by the aspect ratio
    plt.figure(figsize=(4 * aspect_ratio, 4))
    plt.imshow(slice_data, cmap="gray")
    plt.axis("off")
    plt.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
    plt.close()
# Function to run nnUNet inference
@spaces.GPU # Decorate the function to allocate GPU for its execution
def run_nnunet_predict(nifti_file):
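    """
    Run nnU-Net v2 inference on the uploaded skull-stripped FLAIR image and return
    the lesion mask plus PNG previews of the middle input and output slices.
    """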
    # Prepare the input and output directories
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Extract the original filename without the extension
    original_filename = os.path.basename(nifti_file.name)
    base_filename = original_filename.replace(".nii.gz", "")

    # Move the uploaded file to the expected input location; the _0000 suffix
    # marks the first (and only) input channel in nnU-Net's naming convention
    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    os.rename(nifti_file.name, input_path)

    # Debugging: list files in the input directory
    print("Files in /tmp/input:")
    print(os.listdir(INPUT_DIR))

    # Point nnU-Net at the downloaded trained model
    os.environ["nnUNet_results"] = MODEL_DIR

    # Construct the nnUNetv2_predict command
    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",  # Dataset ID
        "-c", "3d_fullres",  # Configuration
        "-tr", "nnUNetTrainer_8000epochs",
        "-device", "cuda"  # Explicitly use the GPU
    ]
print("Files in /tmp/output:")
print(os.listdir(OUTPUT_DIR))
try:
subprocess.run(command, check=True)
# Rename the output file to match the original input filename
output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
new_output_file = os.path.join(OUTPUT_DIR, f"{base_filename}_LesionMask.nii.gz")
if os.path.exists(output_file):
os.rename(output_file, new_output_file)
# Extract and save 2D slices
input_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_input_slice.png")
output_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_output_slice.png")
extract_middle_slice(input_path, input_slice_path)
extract_middle_slice(new_output_file, output_slice_path)
# Return paths for the Gradio interface
return new_output_file, input_slice_path, output_slice_path
else:
return "Error: Output file not found."
except subprocess.CalledProcessError as e:
return f"Error: {e}"
# Gradio Interface
interface = gr.Interface(
    fn=run_nnunet_predict,
    inputs=gr.File(label="Upload FLAIR Image (.nii.gz)"),
    outputs=[
        gr.File(label="Download Segmentation Mask"),
        gr.Image(label="Input Middle Slice"),
        gr.Image(label="Output Middle Slice")
    ],
    title="FLAMeS: Multiple Sclerosis Lesion Segmentation",
    description="Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary segmentation of MS lesions."
)
# Debugging GPU environment
if torch.cuda.is_available():
    print(f"GPU is available: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU available. Falling back to CPU.")
os.system("nvidia-smi") # Check if NVIDIA tools are available
# Download model files before launching the app
download_model()
# Launch the app
if __name__ == "__main__":
    interface.launch(share=True)