# Hugging Face Spaces status: Running on Zero (ZeroGPU)
import os
import shutil
import subprocess
import zipfile

import gradio as gr
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import spaces  # Import spaces for GPU decoration
import torch
from huggingface_hub import hf_hub_download
from scipy.ndimage import center_of_mass
# Define paths
# Local directory holding the downloaded model archive (HF cache) and the
# extracted nnU-Net results.
MODEL_DIR = "./model"
# Extracted nnU-Net dataset folder; its presence means the model is installed.
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")
# Scratch directories used as nnUNetv2_predict input (-i) and output (-o).
INPUT_DIR = "/tmp/input"
OUTPUT_DIR = "/tmp/output"
# Hugging Face Hub repository that hosts Dataset004_WML.zip.
REPO_ID = "FrancescoLR/FLAMeS-model"
# Function to download the Dataset004_WML folder
def download_model():
    """
    Download and extract the nnU-Net ``Dataset004_WML`` folder if it is missing.

    ``Dataset004_WML.zip`` is fetched from the Hub repository ``REPO_ID`` (using
    ``MODEL_DIR`` as the download cache) and extracted into ``MODEL_DIR``. Nothing
    happens when ``DATASET_DIR`` already exists.

    Raises:
        zipfile.BadZipFile: if the downloaded archive is not a valid zip file.
        FileNotFoundError: if extraction does not produce ``DATASET_DIR``.
    """
    if os.path.exists(DATASET_DIR):
        return
    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading Dataset004_WML.zip...")
    zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)
    # Extract with the standard library instead of shelling out to `unzip`, so a
    # failed extraction raises instead of being silently ignored.
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(MODEL_DIR)
    # DATASET_DIR is deliberately NOT pre-created: it doubles as the "already
    # installed" marker checked above, so it must only exist after a successful
    # extraction that actually contained it.
    if not os.path.isdir(DATASET_DIR):
        raise FileNotFoundError(
            f"{zip_path} did not contain the expected folder {DATASET_DIR}"
        )
    print("Dataset004_WML downloaded and extracted.")
def extract_middle_slices(nifti_path, output_image_path, slice_size=180):
    """
    Render axial, coronal and sagittal slices of a 3D NIfTI volume into one PNG.

    The three orthogonal slices pass through the rounded center of mass of the
    volume's positive voxels (``data > 0``). Each slice is cropped to a window of
    at most ``slice_size`` voxels per in-plane axis around that point, clipped to
    the volume bounds. If no voxel is positive (e.g. an empty lesion mask), the
    geometric center of the volume is used instead.

    Args:
        nifti_path: Path of the NIfTI file to load with nibabel; assumed to hold
            a 3D volume.
        output_image_path: Destination path of the PNG figure.
        slice_size: Crop window size in voxels (``slice_size // 2`` on each side
            of the center).
    """
    img = nib.load(nifti_path)
    data = img.get_fdata()

    foreground = data > 0
    if foreground.any():
        com = center_of_mass(foreground)
    else:
        # center_of_mass of an all-False mask is NaN (0/0), and NaN cast to int
        # is not a valid index; use the middle of the volume instead.
        com = [(size - 1) / 2 for size in data.shape]
    center = np.round(com).astype(int)
    half_size = slice_size // 2

    def extract_2d_slice(volume, axis):
        """Return the plane through ``center`` normal to ``axis``, cropped around ``center``."""
        # Build a single index tuple instead of slicing first and cropping after:
        # once the integer index drops `axis`, the remaining axes are renumbered,
        # so cropping by the original axis numbers hits the wrong (or a
        # non-existent) dimension.
        index = []
        for dim in range(3):
            if dim == axis:
                index.append(center[dim])
            else:
                start = max(center[dim] - half_size, 0)
                end = min(center[dim] + half_size, volume.shape[dim])
                index.append(slice(start, end))
        return volume[tuple(index)]

    views = (
        ("Axial", extract_2d_slice(data, axis=2)),     # Axial (z-axis)
        ("Coronal", extract_2d_slice(data, axis=1)),   # Coronal (y-axis)
        ("Sagittal", extract_2d_slice(data, axis=0)),  # Sagittal (x-axis)
    )

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, (title, view) in zip(axes, views):
            # NOTE(review): slices are drawn untransposed with origin="lower", as
            # before; confirm this is the intended on-screen orientation.
            ax.imshow(view, cmap="gray", origin="lower")
            ax.axis("off")
            ax.set_title(title)
        fig.tight_layout()
        fig.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Close this specific figure (not pyplot's "current" one), even if saving
        # fails, so repeated calls from the web app do not accumulate figures.
        plt.close(fig)
# Function to run nnUNet inference
# Decorate the function to allocate GPU for its execution
# NOTE(review): uses the default @spaces.GPU time budget; 3d_fullres inference may
# need a larger `duration=` — TODO confirm on the Space. Also assumes the GPU
# attached by the decorator is visible to the nnUNetv2_predict child process.
@spaces.GPU
def run_nnunet_predict(nifti_file):
    """
    Segment MS lesions in an uploaded FLAIR volume with nnUNetv2_predict.

    The upload is copied to ``INPUT_DIR/image_0000.nii.gz``, ``nnUNetv2_predict``
    is run for dataset 004 (``3d_fullres``, ``nnUNetTrainer_8000epochs``), and the
    prediction is renamed to ``<original name>_LesionMask.nii.gz`` in
    ``OUTPUT_DIR``. PNG previews of the input and the mask are then rendered.

    Args:
        nifti_file: Value delivered by the ``gr.File`` input — an object with a
            ``.name`` path attribute or a plain path string.

    Returns:
        tuple: (mask path, input preview PNG path, mask preview PNG path), in the
        order of the Interface outputs.

    Raises:
        gr.Error: if no file was uploaded, prediction fails, or no mask is produced.
    """
    if nifti_file is None:
        raise gr.Error("Please upload a FLAIR image (.nii.gz).")
    # Depending on the Gradio version, gr.File passes either an object exposing a
    # `.name` path or a plain path string; accept both.
    upload_path = getattr(nifti_file, "name", nifti_file)

    # Prepare directories
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Original filename without its NIfTI extension, used to name the outputs.
    # Only a trailing suffix is stripped (not every occurrence in the name).
    base_filename = os.path.basename(upload_path)
    for suffix in (".nii.gz", ".nii"):
        if base_filename.endswith(suffix):
            base_filename = base_filename[: -len(suffix)]
            break

    # nnU-Net input files are named <case>_<channel>.nii.gz: one case ("image"),
    # one channel (0000).
    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    # Copy instead of os.rename: rename fails across filesystems and would also
    # remove the file from its upload location.
    shutil.copyfile(upload_path, input_path)
    print("Files in /tmp/input:")
    print(os.listdir(INPUT_DIR))

    # Set environment variables for nnUNet (location of the trained model).
    os.environ["nnUNet_results"] = MODEL_DIR
    # Use the GPU when torch can see one; otherwise fall back to CPU, as the
    # startup message announces.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",  # Dataset ID
        "-c", "3d_fullres",  # Configuration
        "-tr", "nnUNetTrainer_8000epochs",
        "-device", device,
    ]

    # The prediction is named after the case ("image"); remove any leftover so a
    # stale mask from an earlier run can never be returned.
    output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
    if os.path.exists(output_file):
        os.remove(output_file)

    try:
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing nnUNetv2_predict executable. gr.Error shows the
        # message in the UI instead of returning the wrong number of outputs.
        raise gr.Error(f"Error: {e}") from e
    print("Files in /tmp/output:")
    print(os.listdir(OUTPUT_DIR))

    if not os.path.exists(output_file):
        raise gr.Error("Error: Output file not found.")
    # Rename the output file to match the original input filename
    new_output_file = os.path.join(OUTPUT_DIR, f"{base_filename}_LesionMask.nii.gz")
    os.replace(output_file, new_output_file)

    # Extract and save 2D slices
    input_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_input_slice.png")
    output_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_output_slice.png")
    extract_middle_slices(input_path, input_slice_path)
    extract_middle_slices(new_output_file, output_slice_path)
    # Return paths for the Gradio interface
    return new_output_file, input_slice_path, output_slice_path
# Gradio Interface
# The three outputs must stay in the same order as the tuple returned by
# run_nnunet_predict: (mask file, input preview, mask preview).
interface = gr.Interface(
    fn=run_nnunet_predict,
    inputs=gr.File(label="Upload FLAIR Image (.nii.gz)"),
    outputs=[
        gr.File(label="Download Segmentation Mask"),
        gr.Image(label="Input Middle Slice"),
        gr.Image(label="Output Middle Slice"),
    ],
    title="FLAMeS: Multiple Sclerosis Lesion Segmentation",
    description="Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary segmentation of MS lesions.",
)
# Debugging GPU environment: report whether torch can see a CUDA device.
if torch.cuda.is_available():
    print(f"GPU is available: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU available. Falling back to CPU.")
# Best-effort diagnostic via the shell; if nvidia-smi is absent the shell prints
# an error and execution continues (the exit status is ignored).
os.system("nvidia-smi")

# Download model files before launching the app
download_model()

# Launch the app
if __name__ == "__main__":
    # NOTE(review): share=True requests a public gradio.live link when run
    # locally; presumably ignored (with a warning) when hosted on Spaces.
    interface.launch(share=True)