# Spaces: Running on Zero
import os
import shutil
import subprocess

import gradio as gr
from huggingface_hub import hf_hub_download
# Filesystem locations used by the app
MODEL_DIR = "./model"  # Local directory to store the downloaded model
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")  # Directory for Dataset004_WML
INPUT_DIR = "/tmp/input"  # Folder handed to nnUNetv2_predict via -i
OUTPUT_DIR = "/tmp/output"  # Folder handed to nnUNetv2_predict via -o

# Hugging Face Hub repository that hosts Dataset004_WML.zip
REPO_ID = "FrancescoLR/FLAMeS-model"
# Function to download the Dataset004_WML folder
def download_model():
    """Download and extract ``Dataset004_WML.zip`` unless it is already present.

    The existence of ``DATASET_DIR`` is what marks the model as installed, so
    the directory is removed again when the download or the extraction fails.
    Without that cleanup an empty folder would make every later start skip
    the download.

    Raises:
        Exception: Whatever ``hf_hub_download`` or ``shutil.unpack_archive``
            raised; it is re-raised unchanged after the cleanup.
    """
    if os.path.exists(DATASET_DIR):
        return

    os.makedirs(DATASET_DIR, exist_ok=True)
    try:
        print("Downloading Dataset004_WML.zip...")
        zip_path = hf_hub_download(
            repo_id=REPO_ID,
            filename="Dataset004_WML.zip",
            cache_dir=MODEL_DIR,
        )
        # Extract with the standard library: no dependency on an external
        # `unzip` binary, and a corrupt archive raises instead of being
        # silently ignored. Existing files are overwritten, like `unzip -o`.
        shutil.unpack_archive(zip_path, MODEL_DIR, format="zip")
    except BaseException:
        # Also covers Ctrl-C during the download; the error is re-raised.
        shutil.rmtree(DATASET_DIR, ignore_errors=True)
        raise
    print("Dataset004_WML downloaded and extracted.")
# Function to run nnUNet inference
def run_nnunet_predict(nifti_file):
    """Segment an uploaded FLAIR image with ``nnUNetv2_predict``.

    Args:
        nifti_file: The upload passed in by ``gr.File``: either an object
            exposing the temporary file path as ``.name`` or the path itself.

    Returns:
        Path of the segmentation file written to ``OUTPUT_DIR``.

    Raises:
        gr.Error: If nothing was uploaded, the nnUNet command exits with a
            non-zero status, or no output file was produced.
    """
    if nifti_file is None:
        raise gr.Error("Please upload a FLAIR image (.nii.gz) first.")

    # Prepare directories
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Accept both a file wrapper with `.name` and a plain path string.
    uploaded_path = getattr(nifti_file, "name", nifti_file)

    # nnU-Net expects inputs named <case>_<4-digit channel><ending>; the
    # prediction for case "image" is then written as "image.nii.gz".
    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    # shutil.move also works when the upload lives on another filesystem,
    # where os.rename would fail.
    shutil.move(uploaded_path, input_path)

    # Drop any result left over from an earlier request so that a stale
    # mask can never be returned for this upload.
    output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
    if os.path.exists(output_file):
        os.remove(output_file)

    # Set environment variables for nnUNet
    os.environ["nnUNet_raw"] = DATASET_DIR
    os.environ["nnUNet_preprocessed"] = DATASET_DIR
    os.environ["nnUNet_results"] = DATASET_DIR

    # Construct and run the nnUNetv2_predict command
    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",  # Dataset ID
        "-c", "3d_fullres",  # Configuration
        "-tr", "nnUNetTrainer_8000epochs",  # Trainer name
    ]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        # The output component is a gr.File, so an error string returned
        # here would be treated as a file path; raise so the UI shows it.
        raise gr.Error(f"Error: {e}") from e

    if not os.path.isfile(output_file):
        raise gr.Error("Error: nnUNetv2_predict did not produce a segmentation file.")
    return output_file
# Gradio Interface: one file upload in, one downloadable segmentation mask out
interface = gr.Interface(
    fn=run_nnunet_predict,
    inputs=gr.File(label="Upload FLAIR Image (.nii.gz)"),
    outputs=gr.File(label="Download Segmentation Mask"),
    title="FLAMeS: Multiple Sclerosis Lesion Segmentation",
    description="Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary segmentation of MS lesions.",
)
# Download model files before launching the app (runs at import time as well)
download_model()

# Launch the app only when this file is executed as a script
if __name__ == "__main__":
    interface.launch(share=True)