Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,588 Bytes
7801205 f01441d c02fad1 7801205 2cac7f2 7801205 2cac7f2 7801205 f01441d 2cac7f2 7801205 f01441d 7801205 2cac7f2 7801205 2cac7f2 7801205 2cac7f2 7801205 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import gradio as gr
import subprocess
import os
import shutil
import os
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository that hosts the model archive
# (swap in your own repository ID if you host the weights elsewhere).
REPO_ID = "FrancescoLR/FLAMeS-model"

# Filesystem layout
# Local directory in which the downloaded model is stored.
MODEL_DIR = "./model"
# Location of the extracted Dataset004_WML folder inside MODEL_DIR.
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")
# Scratch folders handed to nnUNetv2_predict as its input (-i) and output (-o).
INPUT_DIR, OUTPUT_DIR = "/tmp/input", "/tmp/output"
# Function to download the model files from Hugging Face Model Hub
def download_model():
    """Download and extract the Dataset004_WML archive if it is not present.

    ``Dataset004_WML.zip`` is fetched from ``REPO_ID`` into ``MODEL_DIR``,
    unpacked into ``DATASET_DIR`` and then deleted. Nothing happens when
    ``DATASET_DIR`` already contains files.

    Raises:
        Any exception raised by ``hf_hub_download`` when the download fails,
        or by ``shutil.unpack_archive`` when the archive cannot be extracted.
    """
    # An existing but empty directory is what an interrupted setup leaves
    # behind, so only a populated directory counts as "already installed".
    if os.path.isdir(DATASET_DIR) and os.listdir(DATASET_DIR):
        return

    print("Downloading Dataset004_WML...")
    # `local_dir` places the file directly under MODEL_DIR (with `cache_dir`
    # it ends up inside a nested cache layout). The returned path is used
    # rather than a hard-coded guess of where the archive landed.
    zip_path = hf_hub_download(
        repo_id=REPO_ID,
        filename="Dataset004_WML.zip",
        local_dir=MODEL_DIR,
    )

    # Unzip the dataset into the correct location
    os.makedirs(DATASET_DIR, exist_ok=True)
    try:
        # Pure-Python extraction: raises on a corrupt archive and does not
        # depend on an external `unzip` binary being installed.
        shutil.unpack_archive(zip_path, DATASET_DIR, format="zip")
    except Exception:
        # Do not leave a half-extracted folder that the next start would
        # mistake for a complete installation.
        shutil.rmtree(DATASET_DIR, ignore_errors=True)
        raise
    finally:
        # The archive is not needed after extraction (or after a failed one).
        if os.path.lexists(zip_path):
            os.remove(zip_path)
    print("Dataset004_WML downloaded and extracted.")
# Function to run nnUNet inference
def run_nnunet_predict(nifti_file):
    """Run ``nnUNetv2_predict`` on one uploaded image and return the mask path.

    Args:
        nifti_file: The upload as delivered by the ``gr.File`` input. Accepted
            forms are a filesystem path (``str`` / ``os.PathLike``), an object
            exposing the path as ``.name``, or a readable binary file object.

    Returns:
        Path of the predicted segmentation, ``OUTPUT_DIR/image.nii.gz``.

    Raises:
        gr.Error: If no file was supplied, if the nnUNet command could not be
            started or exited with a non-zero status, or if it finished
            without writing the expected output file.
    """
    if nifti_file is None:
        raise gr.Error("Please upload a FLAIR image (.nii.gz) first.")

    # Prepare directories
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Save the uploaded file to the input directory. The "_0000" suffix is the
    # channel identifier in the input name; the prediction is read back below
    # as "image.nii.gz".
    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    if isinstance(nifti_file, (str, os.PathLike)):
        shutil.copyfile(os.fspath(nifti_file), input_path)
    elif isinstance(getattr(nifti_file, "name", None), str):
        shutil.copyfile(nifti_file.name, input_path)
    else:
        with open(input_path, "wb") as f:
            f.write(nifti_file.read())

    # Drop any result left over from an earlier request so that a stale mask
    # can never be returned for this upload.
    output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")
    if os.path.exists(output_file):
        os.remove(output_file)

    # Set environment variables for nnUNet (inherited by the subprocess)
    os.environ["nnUNet_results"] = MODEL_DIR

    # Construct and run the nnUNetv2_predict command
    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",  # Dataset ID
        "-c", "3d_fullres",  # Configuration
        "-tr", "nnUNetTrainer_8000epochs",  # Trainer name
    ]
    try:
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # The output component is a file widget, so an error string returned
        # here would be treated as a file path. Raise a Gradio error instead
        # so the message is shown to the user.
        raise gr.Error(f"Error: {e}") from e

    if not os.path.exists(output_file):
        raise gr.Error("Error: nnUNetv2_predict did not produce an output file.")
    return output_file
# Gradio Interface
# One file upload is passed to run_nnunet_predict; the file it returns is
# offered for download.
interface = gr.Interface(
    run_nnunet_predict,
    gr.File(label="Upload FLAIR Image (.nii.gz)"),
    gr.File(label="Download Segmentation Mask"),
    description="Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary segmentation of MS lesions.",
    title="FLAMeS: Multiple Sclerosis Lesion Segmentation",
)
# Fetch the model files at import time, so they are in place before the app
# is launched.
download_model()

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()
|