Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,252 Bytes
7801205 c02fad1 18f25b9 2fa3177 5dd0ea6 d21c058 0b88f39 2cac7f2 7801205 2cac7f2 60723f6 2cac7f2 e388d15 2cac7f2 a74b6ce 2fa3177 a74b6ce 2fa3177 0496106 2fa3177 79fd983 a74b6ce 79fd983 a74b6ce 44da582 0496106 a74b6ce 0496106 44da582 0496106 79fd983 a74b6ce 79fd983 a74b6ce 79fd983 a74b6ce 79fd983 a74b6ce 79fd983 a74b6ce 79fd983 a74b6ce 79fd983 a74b6ce 79fd983 2fa3177 1d7bf30 2cac7f2 5dd0ea6 7801205 6c748cb 9c5c250 6c748cb 7801205 7a2ca4b 4d847b9 7a2ca4b 7801205 ed9fa70 7801205 2cac7f2 7801205 5dd0ea6 6c748cb 7801205 c5b67f9 7801205 6c748cb 7801205 6c748cb 5dd0ea6 6c748cb 2fa3177 d9fd063 2fa3177 c5b67f9 9c5c250 5dd0ea6 7801205 c5b67f9 6670942 7801205 2fa3177 1d7bf30 2fa3177 1d7bf30 2fa3177 2cac7f2 7801205 812aa8d 18f25b9 812aa8d 18f25b9 812aa8d 18f25b9 2cac7f2 7801205 94ba9ed 5dd0ea6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import gradio as gr
import subprocess
import os
import shutil
from huggingface_hub import hf_hub_download
import torch
import nibabel as nib
import matplotlib.pyplot as plt
import spaces # Import spaces for GPU decoration
import numpy as np
from scipy.ndimage import center_of_mass
# Hugging Face Hub repository hosting the Dataset004_WML.zip model archive.
REPO_ID = "FrancescoLR/FLAMeS-model"

# Local folder that receives the downloaded archive and its extracted contents.
MODEL_DIR = "./model"
# Folder produced by extracting Dataset004_WML.zip; its presence means the model is ready.
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")

# Scratch locations for nnU-Net's input images and predicted masks.
INPUT_DIR = "/tmp/input"
OUTPUT_DIR = "/tmp/output"
# Function to download the Dataset004_WML folder
def download_model():
    """Download and extract the FLAMeS model archive unless it is already present.

    Fetches ``Dataset004_WML.zip`` from the Hub repository ``REPO_ID`` (cached
    under ``MODEL_DIR``) and extracts it into ``MODEL_DIR``. Nothing is done
    when ``DATASET_DIR`` already exists.

    Raises:
        zipfile.BadZipFile: if the downloaded archive is not a valid zip file.
        Any exception raised by ``hf_hub_download`` (e.g. network errors)
        propagates unchanged.
    """
    import zipfile  # function-scope import keeps the file-level import block untouched

    if os.path.isdir(DATASET_DIR):
        return

    # Only the parent folder is created up front. DATASET_DIR must come from the
    # extraction itself; creating it beforehand would make a failed download or
    # extraction look like a successful one and suppress every later retry.
    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading Dataset004_WML.zip...")
    zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)

    # Extract with the standard library rather than the external `unzip` binary,
    # whose absence or failure was silently ignored. extractall overwrites
    # existing files, like `unzip -o`.
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(MODEL_DIR)
    print("Dataset004_WML downloaded and extracted.")
def _nonzero_center(data):
    """Return the rounded integer center of mass of the voxels of ``data`` that are > 0.

    When no voxel is > 0 (e.g. a lesion mask with no detected lesions),
    ``center_of_mass`` would return NaNs. In that case the geometric center of
    the array is returned instead.
    """
    mask = data > 0
    if not mask.any():
        return np.array([size // 2 for size in data.shape])
    return np.round(center_of_mass(mask)).astype(int)


def _crop_center_slice(data, center, axis, half_size):
    """Return the 2D plane of 3D ``data`` at index ``center[axis]``, cropped around ``center``.

    Each of the two remaining dimensions is limited to the range
    ``[center - half_size, center + half_size)``, clipped to the array bounds.
    """
    index = [slice(None)] * 3
    index[axis] = center[axis]
    for dim in range(3):
        if dim != axis:
            start = max(center[dim] - half_size, 0)
            stop = min(center[dim] + half_size, data.shape[dim])
            index[dim] = slice(start, stop)
    # Plane selection and cropping happen in one indexing step, so each crop
    # range is applied to its own 3D axis. Cropping the 2D result with the 3D
    # axis number fails for the coronal/sagittal planes: volume axis 2 is only
    # axis 1 of the extracted slice.
    return data[tuple(index)]


def extract_middle_slices(nifti_path, output_image_path, slice_size=180):
    """
    Save three orthogonal slices of a NIfTI volume side by side as a single PNG.

    The slices pass through the center of mass of the voxels with value > 0
    (or through the array center if there are none). Each in-plane dimension
    is cropped to at most ``slice_size`` voxels around that point. Panel titles
    assume voxel axes 0/1/2 roughly correspond to sagittal/coronal/axial; the
    image is not reoriented — NOTE(review): confirm for non-RAS-like inputs.

    Args:
        nifti_path: path of the NIfTI image to render.
        output_image_path: destination path of the PNG figure.
        slice_size: crop width, in voxels, of each in-plane dimension.

    Raises:
        ValueError: if the image is not 3D. A trailing 4th dimension of size 1
            is accepted and dropped.
    """
    data = nib.load(nifti_path).get_fdata()
    if data.ndim == 4 and data.shape[3] == 1:
        data = data[..., 0]
    if data.ndim != 3:
        raise ValueError(f"Expected a 3D image, got shape {data.shape} from {nifti_path}")

    center = _nonzero_center(data)
    half_size = slice_size // 2
    # (panel title, volume axis held fixed for that panel)
    views = [("Axial", 2), ("Coronal", 1), ("Sagittal", 0)]

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, (title, axis) in zip(axes, views):
            ax.imshow(_crop_center_slice(data, center, axis, half_size), cmap="gray", origin="lower")
            ax.axis("off")
            ax.set_title(title)
        fig.tight_layout()
        fig.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Close this specific figure (not merely the "current" one), even if
        # saving failed, so repeated calls do not accumulate open figures.
        plt.close(fig)
def _nifti_stem(filename):
    """Split an uploaded NIfTI filename into ``(stem, extension)``.

    ``extension`` is ``".nii.gz"`` or ``".nii"`` (matched case-insensitively).
    Only the trailing extension is removed; ``str.replace`` would also strip it
    from the middle of a name.

    Raises:
        gr.Error: if the name ends in neither ``.nii.gz`` nor ``.nii``.
    """
    lowered = filename.lower()
    for extension in (".nii.gz", ".nii"):
        if lowered.endswith(extension):
            return filename[: -len(extension)], extension
    raise gr.Error(f"Unsupported file '{filename}': please upload a .nii.gz (or .nii) image.")


# Function to run nnUNet inference
@spaces.GPU  # Decorate the function to allocate GPU for its execution
def run_nnunet_predict(nifti_file):
    """Segment MS lesions in an uploaded FLAIR volume with nnU-Net.

    Args:
        nifti_file: the ``gr.File`` upload — a file path string (Gradio 4+) or
            a tempfile-like object exposing ``.name`` (older Gradio).

    Returns:
        ``(mask_path, input_png_path, output_png_path)``: the predicted mask
        saved as ``<upload name>_LesionMask.nii.gz``, and PNG previews of the
        input image and of the mask. One value per interface output.

    Raises:
        gr.Error: if nothing was uploaded, the file is not NIfTI,
            ``nnUNetv2_predict`` fails or is missing, or no prediction is
            written. Gradio displays the message to the user; the previous
            single-string return did not match the three declared outputs.
    """
    import tempfile  # function-scope import keeps the file-level import block untouched

    if nifti_file is None:
        raise gr.Error("Please upload a FLAIR image (.nii.gz).")
    # Gradio 4+ passes a path string; older releases pass an object with `.name`.
    upload_path = getattr(nifti_file, "name", nifti_file)
    base_filename, extension = _nifti_stem(os.path.basename(upload_path))

    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # Fresh per-run folders instead of the shared INPUT_DIR/OUTPUT_DIR, so files
    # left behind by earlier runs are never picked up or mixed into results.
    run_input_dir = tempfile.mkdtemp(dir=INPUT_DIR)
    run_output_dir = tempfile.mkdtemp(dir=OUTPUT_DIR)
    try:
        # nnU-Net reads the channel-suffixed input "image_0000" and names the
        # prediction after the case id ("image").
        input_path = os.path.join(run_input_dir, "image_0000.nii.gz")
        if extension == ".nii.gz":
            # Copy rather than os.rename: works across filesystems and leaves
            # Gradio's uploaded file in place.
            shutil.copyfile(upload_path, input_path)
        else:
            # Re-save uncompressed uploads so the content matches the .nii.gz name.
            nib.save(nib.load(upload_path), input_path)
        print(f"Files in {run_input_dir}:", os.listdir(run_input_dir))

        command = [
            "nnUNetv2_predict",
            "-i", run_input_dir,
            "-o", run_output_dir,
            "-d", "004",  # Dataset ID
            "-c", "3d_fullres",  # Configuration
            "-tr", "nnUNetTrainer_8000epochs",
            "-device", "cuda",  # Explicitly use GPU
        ]
        # Give the model location to the child process only, rather than
        # mutating this process's os.environ.
        env = dict(os.environ, nnUNet_results=os.path.abspath(MODEL_DIR))
        try:
            subprocess.run(command, check=True, env=env)
        except (subprocess.CalledProcessError, FileNotFoundError) as e:
            # FileNotFoundError here means the nnUNetv2_predict executable is not on PATH.
            raise gr.Error(f"Error: {e}") from e
        print(f"Files in {run_output_dir}:", os.listdir(run_output_dir))

        output_file = os.path.join(run_output_dir, "image.nii.gz")
        if not os.path.exists(output_file):
            raise gr.Error("Error: Output file not found.")
        # Rename the output file to match the original input filename
        new_output_file = os.path.join(run_output_dir, f"{base_filename}_LesionMask.nii.gz")
        os.replace(output_file, new_output_file)

        # Extract and save 2D slices
        input_slice_path = os.path.join(run_output_dir, f"{base_filename}_input_slice.png")
        output_slice_path = os.path.join(run_output_dir, f"{base_filename}_output_slice.png")
        extract_middle_slices(input_path, input_slice_path)
        extract_middle_slices(new_output_file, output_slice_path)

        return new_output_file, input_slice_path, output_slice_path
    finally:
        # The staged input is no longer needed. Outputs stay on disk because
        # Gradio serves the returned paths after this function exits.
        shutil.rmtree(run_input_dir, ignore_errors=True)
# Gradio UI: one FLAIR upload in; the lesion mask file plus input/output previews out.
# The three output components match the three values run_nnunet_predict returns.
_flair_upload = gr.File(label="Upload FLAIR Image (.nii.gz)")
_result_components = [
    gr.File(label="Download Segmentation Mask"),
    gr.Image(label="Input Middle Slice"),
    gr.Image(label="Output Middle Slice"),
]
interface = gr.Interface(
    fn=run_nnunet_predict,
    inputs=_flair_upload,
    outputs=_result_components,
    title="FLAMeS: Multiple Sclerosis Lesion Segmentation",
    description="Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary segmentation of MS lesions.",
)
def _log_gpu_environment():
    """Print whether torch sees a CUDA device, then run `nvidia-smi` for the logs."""
    if not torch.cuda.is_available():
        print("No GPU available. Falling back to CPU.")
    else:
        print(f"GPU is available: {torch.cuda.get_device_name(0)}")
    os.system("nvidia-smi")  # only prints driver/GPU info; the exit status is ignored


_log_gpu_environment()

# Fetch the model at import time, before the UI is launched.
download_model()

if __name__ == "__main__":
    interface.launch(share=True)
|