FLAMeS / app.py
FrancescoLR's picture
First commit
7801205
raw
history blame
1.74 kB
import gradio as gr
import subprocess
import os
import shutil
# Filesystem layout shared with the nnUNetv2_predict CLI.
INPUT_DIR = "/tmp/input"    # the uploaded scan is staged here before inference
OUTPUT_DIR = "/tmp/output"  # the predicted mask is read back from here
# Folder holding the trained model; used as nnUNet_raw, nnUNet_preprocessed
# and nnUNet_results.
MODEL_DIR = "./model"
def run_nnunet_predict(nifti_file):
    """Segment MS lesions in one FLAIR volume with the bundled nnUNetv2 model.

    The upload is staged in ``INPUT_DIR`` using nnUNetv2's input naming
    convention. ``nnUNetv2_predict`` is then run on that folder, and the path
    of the predicted mask in ``OUTPUT_DIR`` is returned for Gradio to serve.

    Args:
        nifti_file: The uploaded ``.nii.gz`` file as delivered by ``gr.File``.
            This is either a filepath string (Gradio 4 default) or a
            tempfile-like object exposing ``.name`` (Gradio 3).

    Returns:
        Path of the segmentation mask, ``OUTPUT_DIR/image.nii.gz``.

    Raises:
        gr.Error: If no file was uploaded, if ``nnUNetv2_predict`` cannot be
            launched or exits with a non-zero status, or if it finishes
            without writing the expected mask. Gradio shows the message to
            the user instead of trying to serve an error string as a file.
    """
    if nifti_file is None:
        raise gr.Error("Please upload a FLAIR image (.nii.gz).")
    # Gradio 4 passes a path string; Gradio 3 passes a tempfile wrapper.
    source_path = getattr(nifti_file, "name", nifti_file)

    # Start from empty folders. nnUNetv2_predict processes every case in the
    # input folder, and a stale mask left in the output folder could be
    # returned if a later run failed to overwrite it.
    # NOTE(review): the folders are shared between requests, so this assumes
    # Gradio handles one request at a time (its default queue concurrency).
    # Confirm this before raising concurrency.
    for directory in (INPUT_DIR, OUTPUT_DIR):
        shutil.rmtree(directory, ignore_errors=True)
        os.makedirs(directory, exist_ok=True)

    # nnUNetv2 only picks up inputs named <case>_<XXXX>.nii.gz, where XXXX is
    # the channel index. The single FLAIR channel is 0000. The prediction is
    # written as <case>.nii.gz.
    case_id = "image"
    shutil.copy(source_path, os.path.join(INPUT_DIR, f"{case_id}_0000.nii.gz"))
    output_file = os.path.join(OUTPUT_DIR, f"{case_id}.nii.gz")

    # Hand nnUNet's folder configuration to the child process only, rather
    # than mutating this process's environment on every request.
    env = {
        **os.environ,
        "nnUNet_raw": MODEL_DIR,
        "nnUNet_preprocessed": MODEL_DIR,
        "nnUNet_results": MODEL_DIR,
    }

    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",  # Dataset ID
        "-c", "3d_fullres",  # Configuration
        "-tr", "nnUNetTrainer_8000epochs",  # Trainer name
    ]

    try:
        subprocess.run(command, check=True, env=env)
    except FileNotFoundError as e:
        # The CLI entry point is not on PATH.
        raise gr.Error(
            "nnUNetv2_predict was not found; is nnUNetv2 installed?"
        ) from e
    except subprocess.CalledProcessError as e:
        raise gr.Error(f"nnUNet prediction failed: {e}") from e

    # A zero exit status does not guarantee a mask was written (e.g. when no
    # input case was recognised), so check before handing the path to Gradio.
    if not os.path.isfile(output_file):
        raise gr.Error(
            "nnUNetv2_predict finished but produced no segmentation mask."
        )
    return output_file
# Web front end: a single NIfTI upload goes in, and the segmentation file
# produced by run_nnunet_predict comes back out.
interface = gr.Interface(
    title="FLAMeS: FLAIR Lesion Analysis in Multiple Sclerosis",
    description=(
        "Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary segmentation of MS lesions. "
        "This model uses nnUNetv2 for inference with ensemble predictions."
    ),
    fn=run_nnunet_predict,
    inputs=gr.File(label="Upload FLAIR Image (.nii.gz)"),
    outputs=gr.File(label="Download Segmentation Mask"),
)
# Serve the Gradio UI only when this file is executed directly, not on import.
if __name__ == "__main__":
    interface.launch()