File size: 6,252 Bytes
7801205
 
 
 
c02fad1
18f25b9
2fa3177
 
5dd0ea6
d21c058
0b88f39
2cac7f2
 
 
 
7801205
 
 
2cac7f2
 
 
60723f6
2cac7f2
 
 
e388d15
 
 
 
2cac7f2
a74b6ce
2fa3177
a74b6ce
 
2fa3177
0496106
2fa3177
 
79fd983
a74b6ce
 
 
79fd983
a74b6ce
 
44da582
0496106
 
a74b6ce
0496106
 
44da582
0496106
 
 
 
 
 
 
 
 
 
 
79fd983
 
a74b6ce
79fd983
 
a74b6ce
79fd983
a74b6ce
79fd983
a74b6ce
79fd983
a74b6ce
79fd983
a74b6ce
79fd983
a74b6ce
79fd983
 
 
2fa3177
 
1d7bf30
2cac7f2
5dd0ea6
7801205
 
 
 
 
6c748cb
 
9c5c250
6c748cb
7801205
7a2ca4b
4d847b9
7a2ca4b
 
 
 
7801205
 
ed9fa70
7801205
2cac7f2
7801205
 
 
 
 
 
5dd0ea6
6c748cb
7801205
c5b67f9
 
7801205
 
6c748cb
 
7801205
6c748cb
5dd0ea6
6c748cb
2fa3177
 
 
 
d9fd063
 
2fa3177
c5b67f9
9c5c250
5dd0ea6
 
7801205
 
 
c5b67f9
6670942
7801205
 
2fa3177
1d7bf30
2fa3177
1d7bf30
2fa3177
2cac7f2
 
7801205
 
812aa8d
18f25b9
812aa8d
18f25b9
 
812aa8d
18f25b9
2cac7f2
 
 
7801205
 
94ba9ed
5dd0ea6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import gradio as gr
import subprocess
import os
import shutil
from huggingface_hub import hf_hub_download
import torch
import nibabel as nib
import matplotlib.pyplot as plt
import spaces  # Import spaces for GPU decoration
import numpy as np
from scipy.ndimage import center_of_mass

# Filesystem layout.
# Model weights are unpacked under MODEL_DIR; the nnUNet dataset folder
# "Dataset004_WML" is expected directly inside it.
MODEL_DIR = "./model"
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")

# Scratch directories handed to nnUNetv2_predict as its -i / -o arguments.
INPUT_DIR = "/tmp/input"
OUTPUT_DIR = "/tmp/output"

# Hugging Face Hub repository that hosts Dataset004_WML.zip.
REPO_ID = "FrancescoLR/FLAMeS-model"

# Function to download the Dataset004_WML folder
def download_model():
    """Download and unpack the nnUNet model folder Dataset004_WML if it is missing.

    Fetches ``Dataset004_WML.zip`` from the Hugging Face repository REPO_ID
    (cached under MODEL_DIR) and extracts it into MODEL_DIR with the system
    ``unzip`` tool. The download is skipped only when DATASET_DIR already
    exists and is non-empty. An empty folder left behind by an interrupted
    earlier attempt therefore triggers a fresh download instead of being
    mistaken for a complete model.

    Raises:
        subprocess.CalledProcessError: If ``unzip`` exits with a non-zero status.
        Errors raised by ``hf_hub_download`` propagate unchanged.
    """
    if os.path.isdir(DATASET_DIR) and os.listdir(DATASET_DIR):
        return

    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading Dataset004_WML.zip...")
    zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)
    # check=True: a failed extraction must not be reported as success.
    subprocess.run(["unzip", "-o", zip_path, "-d", MODEL_DIR], check=True)
    print("Dataset004_WML downloaded and extracted.")

def _nonzero_center(data):
    """Return the rounded center of mass of the voxels of *data* that are > 0.

    Falls back to the geometric center of the array when no voxel is > 0
    (e.g. a segmentation mask without lesions). In that case
    ``center_of_mass`` would return NaNs, which cannot be used as indices.
    """
    mask = data > 0
    if not mask.any():
        return np.array(data.shape) // 2
    return np.round(center_of_mass(mask)).astype(int)


def _crop_plane(data, center, axis, half_size):
    """Return the 2-D plane of *data* at index ``center[axis]`` along *axis*.

    Every other dimension is cropped to the window
    ``[center - half_size, center + half_size)``, clipped to the array bounds.
    All indexing happens in one step on the original array, so each window is
    applied to the dimension it was computed for.
    """
    index = [slice(None)] * data.ndim
    for dim in range(data.ndim):
        if dim == axis:
            index[dim] = int(center[dim])
        else:
            start = max(int(center[dim]) - half_size, 0)
            stop = min(int(center[dim]) + half_size, data.shape[dim])
            index[dim] = slice(start, stop)
    return data[tuple(index)]


def extract_middle_slices(nifti_path, output_image_path, slice_size=180):
    """
    Save axial, coronal and sagittal views through the center of a NIfTI volume as one PNG.

    The views pass through the center of mass of the voxels with value > 0, or
    through the volume center if there are none. Each view is cropped to at
    most ``slice_size`` voxels per side around that point.

    Args:
        nifti_path: Path to the NIfTI file. A 3-D volume is assumed.
        output_image_path: Destination path of the PNG with the three panels.
        slice_size: Edge length, in voxels, of the crop window around the center.
    """
    data = nib.load(nifti_path).get_fdata()
    center = _nonzero_center(data)
    half_size = slice_size // 2

    # (voxel axis, panel title). Axis 2 is treated as axial, 1 as coronal and
    # 0 as sagittal, which presumes x/y/z voxel ordering (e.g. RAS-like); confirm per dataset.
    planes = ((2, "Axial"), (1, "Coronal"), (0, "Sagittal"))

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, (axis, title) in zip(axes, planes):
            ax.imshow(_crop_plane(data, center, axis, half_size), cmap="gray", origin="lower")
            ax.axis("off")
            ax.set_title(title)
        fig.tight_layout()
        fig.savefig(output_image_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Close this specific figure, even on error, so a long-running server
        # neither leaks figures nor closes a figure created by another request.
        plt.close(fig)

def _strip_nifti_suffix(filename):
    """Return *filename* without a trailing ``.nii.gz`` or ``.nii`` extension."""
    for suffix in (".nii.gz", ".nii"):
        if filename.endswith(suffix):
            return filename[: -len(suffix)]
    return filename


# Function to run nnUNet inference
@spaces.GPU  # Decorate the function to allocate GPU for its execution
def run_nnunet_predict(nifti_file):
    """Segment MS lesions in an uploaded FLAIR volume with nnUNetv2.

    Steps:
    1. Move the upload into INPUT_DIR as ``image_0000.nii.gz``.
    2. Run ``nnUNetv2_predict`` with the Dataset004_WML model from MODEL_DIR.
    3. Rename the predicted mask after the uploaded file.
    4. Render PNG previews of the input and the mask into OUTPUT_DIR.

    Args:
        nifti_file: The Gradio upload. Depending on the Gradio version / File
            settings this is either a path string or a tempfile-like object
            exposing the path as ``.name``. Both forms are accepted.

    Returns:
        Tuple ``(mask_path, input_preview_png, mask_preview_png)``, one value
        per output component of the interface.

    Raises:
        gr.Error: If no file was uploaded, nnUNet exits with an error, or it
            produced no output file. Gradio shows the message to the user.
    """
    if nifti_file is None:
        raise gr.Error("Please upload a FLAIR image (.nii.gz).")

    # Prepare directories
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    upload_path = getattr(nifti_file, "name", nifti_file)
    base_filename = _strip_nifti_suffix(os.path.basename(upload_path))

    # Place the upload under the file name passed to nnUNet.
    # shutil.move (unlike os.rename) also works when the upload lives on a
    # different filesystem than INPUT_DIR.
    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    shutil.move(upload_path, input_path)

    # Debugging: List files in the /tmp/input directory
    print("Files in /tmp/input:")
    print(os.listdir(INPUT_DIR))

    # nnUNet locates the trained model through this environment variable.
    os.environ["nnUNet_results"] = MODEL_DIR

    # Use the GPU when PyTorch sees one (the decorator above is meant to
    # provide it). Otherwise run on CPU instead of failing outright.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",                  # Dataset ID
        "-c", "3d_fullres",           # Configuration
        "-tr", "nnUNetTrainer_8000epochs",
        "-device", device,
    ]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        # Returning a bare string to a 3-output interface would itself fail;
        # gr.Error surfaces the message in the UI instead.
        raise gr.Error(f"Error: {e}") from e

    print("Files in /tmp/output:")
    print(os.listdir(OUTPUT_DIR))

    # nnUNet names the prediction after the input case ("image").
    output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
    if not os.path.exists(output_file):
        raise gr.Error("Error: Output file not found.")

    # Rename the output file to match the original input filename
    new_output_file = os.path.join(OUTPUT_DIR, f"{base_filename}_LesionMask.nii.gz")
    os.replace(output_file, new_output_file)

    # Extract and save 2D slices
    input_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_input_slice.png")
    output_slice_path = os.path.join(OUTPUT_DIR, f"{base_filename}_output_slice.png")
    extract_middle_slices(input_path, input_slice_path)
    extract_middle_slices(new_output_file, output_slice_path)

    # Return paths for the Gradio interface
    return new_output_file, input_slice_path, output_slice_path

# Gradio interface: one NIfTI upload in; the lesion mask plus PNG previews of
# the input and of the mask out, in the order run_nnunet_predict returns them.
interface = gr.Interface(
    fn=run_nnunet_predict,
    inputs=gr.File(label="Upload FLAIR Image (.nii.gz)"),
    outputs=[
        gr.File(label="Download Segmentation Mask"),
        gr.Image(label="Input Middle Slice"),
        gr.Image(label="Output Middle Slice")
    ],
    title="FLAMeS: Multiple Sclerosis Lesion Segmentation",
    description="Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary segmentation of MS lesions."
)

def _report_gpu_environment():
    """Print whether PyTorch sees a CUDA device; run nvidia-smi when it does not."""
    if not torch.cuda.is_available():
        print("No GPU available. Falling back to CPU.")
        os.system("nvidia-smi")  # probe whether NVIDIA tooling is present at all
        return
    print(f"GPU is available: {torch.cuda.get_device_name(0)}")


# Debugging GPU environment
_report_gpu_environment()

# Fetch the model weights at import time so they are in place before the app starts serving.
download_model()

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch(share=True)