File size: 2,588 Bytes
7801205
 
 
 
f01441d
c02fad1
7801205
2cac7f2
 
 
 
7801205
 
 
2cac7f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7801205
 
 
 
 
 
f01441d
2cac7f2
 
7801205
 
f01441d
7801205
2cac7f2
7801205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cac7f2
 
7801205
 
2cac7f2
 
 
7801205
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
import subprocess
import os
import shutil
import os
from huggingface_hub import hf_hub_download


# Hugging Face model repository that hosts the trained nnUNet results archive.
REPO_ID = "FrancescoLR/FLAMeS-model"

# Local root where the model archive is downloaded and unpacked.
MODEL_DIR = "./model"
# Unpacked nnUNet results for dataset 004 live here.
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")

# Fixed scratch directories for the uploaded image and nnUNet's prediction.
INPUT_DIR = "/tmp/input"
OUTPUT_DIR = "/tmp/output"

# Function to download the model files from Hugging Face Model Hub
def download_model():
    """Download and unpack the Dataset004_WML nnUNet results if not already present.

    Fetches ``Dataset004_WML.zip`` from the ``REPO_ID`` model repository and
    extracts it into ``DATASET_DIR``. If ``DATASET_DIR`` already exists the
    function returns immediately without touching the network.

    Raises:
        Any exception from ``hf_hub_download`` (network/repository errors) or
        ``zipfile.BadZipFile`` for a corrupt archive. In those cases the
        partially created ``DATASET_DIR`` is removed first, so the next call
        retries the download instead of silently skipping it.
    """
    import zipfile

    if os.path.exists(DATASET_DIR):
        return

    print("Downloading Dataset004_WML...")
    os.makedirs(DATASET_DIR, exist_ok=True)
    try:
        # With cache_dir, hf_hub_download places the file inside the hub cache
        # layout (not at MODEL_DIR/Dataset004_WML.zip) and returns its real
        # location, so use the returned path rather than guessing it.
        zip_path = hf_hub_download(
            repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR
        )
        # Extract in-process instead of shelling out to `unzip`: no dependency
        # on an external binary, and a bad archive raises instead of being
        # ignored (the old subprocess call did not check its exit status).
        # NOTE(review): assumes the archive's top level is the content that
        # belongs directly under DATASET_DIR — confirm against the zip layout.
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(DATASET_DIR)
    except BaseException:
        # Also covers Ctrl-C: never leave an empty/partial DATASET_DIR behind,
        # otherwise the existence check above would skip the download forever.
        shutil.rmtree(DATASET_DIR, ignore_errors=True)
        raise

    # Best-effort disk-space cleanup. The returned path may be a symlink into
    # the hub's blob store, so remove both the link and its resolved target.
    for path in {zip_path, os.path.realpath(zip_path)}:
        try:
            os.remove(path)
        except OSError:
            pass  # deliberate: failing to delete the archive is harmless
    print("Dataset004_WML downloaded and extracted.")

# Function to run nnUNet inference
def _save_upload(nifti_file, dest_path):
    """Copy the uploaded file to *dest_path*, whatever form Gradio delivered it in."""
    # Gradio 4's gr.File passes a filepath string (possibly a str subclass).
    if isinstance(nifti_file, (str, os.PathLike)):
        shutil.copyfile(nifti_file, dest_path)
        return
    # Older Gradio passes a tempfile-like wrapper whose `.name` is an on-disk path.
    source_name = getattr(nifti_file, "name", None)
    if isinstance(source_name, str) and os.path.isfile(source_name):
        shutil.copyfile(source_name, dest_path)
        return
    # Fallback: a readable binary file object.
    with open(dest_path, "wb") as f:
        f.write(nifti_file.read())


def run_nnunet_predict(nifti_file):
    """Segment MS lesions in an uploaded FLAIR volume with nnUNetv2.

    Args:
        nifti_file: The value Gradio supplies for the ``gr.File`` input. It may
            be a file path, an object exposing an on-disk ``.name``, or a
            readable binary file object. ``None`` means nothing was uploaded.

    Returns:
        Path of the predicted segmentation, ``OUTPUT_DIR/image.nii.gz``.

    Raises:
        gr.Error: if no file was uploaded, ``nnUNetv2_predict`` is missing or
            exits with an error, or no output file was produced. Gradio shows
            the message to the user, whereas returning an error string to a
            ``gr.File`` output would make Gradio try to serve it as a file.
    """
    if nifti_file is None:
        raise gr.Error("Please upload a FLAIR image (.nii.gz).")

    # Prepare directories
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # nnUNet input naming is <case>_<channel>; channel 0000 is the single FLAIR input.
    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    _save_upload(nifti_file, input_path)

    # nnUNet names the prediction after the case identifier ("image").
    output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
    # Drop a previous run's result so a failed run can never return stale output.
    if os.path.exists(output_file):
        os.remove(output_file)

    # Set environment variables for nnUNet
    os.environ["nnUNet_results"] = MODEL_DIR

    # Construct and run the nnUNetv2_predict command
    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",                  # Dataset ID
        "-c", "3d_fullres",           # Configuration
        "-tr", "nnUNetTrainer_8000epochs",  # Trainer name
    ]
    try:
        subprocess.run(command, check=True)
    except FileNotFoundError as e:
        raise gr.Error("Error: nnUNetv2_predict not found; is nnunetv2 installed?") from e
    except subprocess.CalledProcessError as e:
        raise gr.Error(f"Error: {e}") from e

    if not os.path.isfile(output_file):
        raise gr.Error("Error: nnUNet finished but produced no segmentation file.")
    return output_file

# Gradio UI: one NIfTI upload in, one segmentation-mask file out.
_flair_upload = gr.File(label="Upload FLAIR Image (.nii.gz)")
_mask_download = gr.File(label="Download Segmentation Mask")

interface = gr.Interface(
    fn=run_nnunet_predict,
    inputs=_flair_upload,
    outputs=_mask_download,
    title="FLAMeS: Multiple Sclerosis Lesion Segmentation",
    description="Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary segmentation of MS lesions.",
)

# Make sure the nnUNet results are on disk at module load, before any request.
download_model()

# Serve the UI only when executed as a script, not on plain import.
if __name__ == "__main__":
    interface.launch()