File size: 2,921 Bytes
7801205
 
 
 
f01441d
c02fad1
18f25b9
7801205
2cac7f2
 
 
 
7801205
 
 
2cac7f2
 
 
60723f6
2cac7f2
 
 
e388d15
 
 
 
2cac7f2
 
7801205
 
 
 
 
 
7a2ca4b
4d847b9
7a2ca4b
 
 
 
7801205
 
ed9fa70
7801205
2cac7f2
7801205
 
 
 
 
 
c2107b1
7801205
c2107b1
 
7801205
 
 
 
 
 
 
 
4d847b9
7801205
 
 
 
 
2cac7f2
 
7801205
 
18f25b9
 
 
 
 
 
 
 
2cac7f2
 
 
7801205
 
db7f58a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gradio as gr
import subprocess
import os
import shutil
import os
from huggingface_hub import hf_hub_download
import torch


# --- Filesystem layout -------------------------------------------------------
# Root folder for the downloaded weights; it is also exported to nnU-Net as
# ``nnUNet_results`` at prediction time.
MODEL_DIR = "./model"
# Extracted nnU-Net results folder for dataset 004; its presence marks the
# model as already downloaded.
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")
# Scratch folders handed to nnUNetv2_predict via -i / -o.
INPUT_DIR, OUTPUT_DIR = "/tmp/input", "/tmp/output"

# --- Hugging Face Hub -------------------------------------------------------
# Repository that hosts ``Dataset004_WML.zip``.
REPO_ID = "FrancescoLR/FLAMeS-model"

def download_model():
    """Download and unpack the ``Dataset004_WML`` nnU-Net results folder.

    Fetches ``Dataset004_WML.zip`` from the Hugging Face repo ``REPO_ID``
    (cached under ``MODEL_DIR``) and extracts it into ``MODEL_DIR`` with the
    external ``unzip`` tool. Does nothing if ``DATASET_DIR`` already exists.

    ``DATASET_DIR`` is deliberately NOT created up front. Its existence is
    the "already downloaded" marker, so pre-creating it would make any failed
    download/extraction permanent across restarts.

    Raises:
        subprocess.CalledProcessError: if ``unzip`` exits with a non-zero status.
        OSError: if ``unzip`` cannot be executed (e.g. not installed).
        Any exception raised by ``hf_hub_download`` (network/auth errors).
    """
    if os.path.exists(DATASET_DIR):
        return  # extracted on a previous start

    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading Dataset004_WML.zip...")
    zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)
    try:
        # -o: overwrite existing files without prompting; -d: target directory.
        subprocess.run(["unzip", "-o", zip_path, "-d", MODEL_DIR], check=True)
    except (subprocess.CalledProcessError, OSError):
        # Drop a half-extracted folder so the next start retries instead of
        # mistaking it for a completed download.
        shutil.rmtree(DATASET_DIR, ignore_errors=True)
        raise

    if not os.path.isdir(DATASET_DIR):
        # NOTE(review): assumes the archive's top-level folder is Dataset004_WML.
        print(f"Warning: {DATASET_DIR} not found after extraction; check the archive layout.")
    print("Dataset004_WML downloaded and extracted.")

def run_nnunet_predict(nifti_file):
    """Segment MS lesions in one FLAIR volume with nnUNetv2 and return the mask path.

    Args:
        nifti_file: The ``gr.File`` value. It may be a file-path string (the
            Gradio 4 default) or a tempfile-like object exposing ``.name``
            (older Gradio).

    Returns:
        Path to the predicted mask, ``OUTPUT_DIR/image.nii.gz``.

    Raises:
        gr.Error: If no file was uploaded, if ``nnUNetv2_predict`` fails or
            cannot be started, or if it finishes without writing the expected
            output file.
    """
    if nifti_file is None:
        raise gr.Error("Please upload a FLAIR image (.nii.gz).")
    src_path = getattr(nifti_file, "name", nifti_file)

    # Prepare directories
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # nnU-Net input naming is "<case>_<channel>.nii.gz"; channel 0000 is FLAIR.
    # The output for case "image" is written as "image.nii.gz".
    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")

    # Copy rather than os.rename: rename fails across filesystems (EXDEV)
    # and would also move Gradio's own temp file out from under it.
    shutil.copyfile(src_path, input_path)
    # Remove any mask from a previous request so a stale result is never returned.
    if os.path.exists(output_file):
        os.remove(output_file)

    # Debugging: List files in the /tmp/input directory
    print("Files in /tmp/input:")
    print(os.listdir(INPUT_DIR))

    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",                  # Dataset ID
        "-c", "3d_fullres",           # Configuration
        "-tr", "nnUNetTrainer_8000epochs",
    ]
    # Point nnU-Net at the downloaded results folder, for the child process only.
    env = {**os.environ, "nnUNet_results": MODEL_DIR}
    try:
        subprocess.run(command, check=True, env=env)
    except (subprocess.CalledProcessError, OSError) as e:
        # Raising gr.Error shows the message in the UI. Returning a string
        # would make Gradio try to serve it as a file path.
        raise gr.Error(f"nnUNetv2_predict failed: {e}") from e

    # Debugging: list outputs after the prediction has actually run.
    print("Files in /tmp/output:")
    print(os.listdir(OUTPUT_DIR))

    if not os.path.exists(output_file):
        raise gr.Error("nnUNetv2_predict finished but did not produce image.nii.gz.")
    return output_file


# Gradio UI: a single NIfTI upload goes in, the predicted mask file comes out.
_flair_upload = gr.File(label="Upload FLAIR Image (.nii.gz)")
_mask_download = gr.File(label="Download Segmentation Mask")

interface = gr.Interface(
    fn=run_nnunet_predict,
    inputs=_flair_upload,
    outputs=_mask_download,
    title="FLAMeS: Multiple Sclerosis Lesion Segmentation",
    description=(
        "Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary "
        "segmentation of MS lesions."
    ),
)

# Initialize CUDA in this process at startup, or report that we're on CPU.
# NOTE(review): inference itself runs in a separate `nnUNetv2_predict`
# subprocess, which creates its own CUDA context, so this warm-up may not
# speed up predictions — confirm whether it is still needed.
if not torch.cuda.is_available():
    print("No GPU available. Falling back to CPU.")
else:
    print("CUDA is available. Initializing GPU...")
    device = torch.device("cuda:0")
    _warmup = torch.tensor([1.0])
    _warmup.to(device)  # the first transfer to the device triggers CUDA init
    del _warmup

# Startup sequence: make sure the nnU-Net weights are on disk first (this runs
# on import too), then start the blocking Gradio server only when this file is
# executed as a script.
download_model()

if __name__ == "__main__":
    interface.launch()