Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,107 Bytes
7801205 c02fad1 18f25b9 5dd0ea6 2cac7f2 7801205 2cac7f2 60723f6 2cac7f2 e388d15 2cac7f2 5dd0ea6 7801205 7a2ca4b 4d847b9 7a2ca4b 7801205 ed9fa70 7801205 2cac7f2 7801205 5dd0ea6 2bce9bd 7801205 5dd0ea6 7801205 2cac7f2 7801205 812aa8d 18f25b9 812aa8d 18f25b9 812aa8d 18f25b9 2cac7f2 7801205 94ba9ed 5dd0ea6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import os
import shutil
import subprocess
import zipfile

import gradio as gr
import spaces  # Import spaces for GPU decoration
import torch
from huggingface_hub import hf_hub_download
# Filesystem layout used by the app.
MODEL_DIR = "./model"  # Local directory that holds the downloaded model
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")  # Dataset004_WML folder inside MODEL_DIR
INPUT_DIR, OUTPUT_DIR = "/tmp/input", "/tmp/output"  # Directories passed to nnUNetv2_predict via -i / -o

# Hugging Face repository the model archive is downloaded from.
REPO_ID = "FrancescoLR/FLAMeS-model"  # Replace with your actual model repository ID
# Function to download the Dataset004_WML folder
def download_model():
    """Download and unpack Dataset004_WML.zip unless DATASET_DIR already exists.

    When ``DATASET_DIR`` is missing, the archive ``Dataset004_WML.zip`` is
    fetched from ``REPO_ID`` (using ``MODEL_DIR`` as the download cache) and
    extracted into ``MODEL_DIR``.

    Raises:
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
        Exception: Anything raised by ``hf_hub_download`` or by the extraction
            is propagated after the partially created dataset folder has been
            removed.
    """
    if os.path.exists(DATASET_DIR):
        return

    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading Dataset004_WML.zip...")
    zip_path = hf_hub_download(repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR)

    try:
        # Extract in-process so a corrupt archive raises instead of being
        # silently ignored, and so no external `unzip` binary is required.
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(MODEL_DIR)
    except Exception:
        # DATASET_DIR doubles as the "already downloaded" marker; a half
        # extracted folder must not survive or later calls would skip the
        # download and run against incomplete model files.
        shutil.rmtree(DATASET_DIR, ignore_errors=True)
        raise

    # Create the marker only after a successful extraction.
    # NOTE(review): presumably the archive already contains a top-level
    # Dataset004_WML folder, making this a no-op - confirm.
    os.makedirs(DATASET_DIR, exist_ok=True)
    print("Dataset004_WML downloaded and extracted.")
# Function to run nnUNet inference
@spaces.GPU  # Decorate the function to allocate GPU for its execution
def run_nnunet_predict(nifti_file):
    """Run ``nnUNetv2_predict`` on one uploaded image and return the mask path.

    Args:
        nifti_file: The upload handed over by the Gradio ``File`` input.
            Either a path string or an object whose ``.name`` attribute is
            the path of the uploaded file.

    Returns:
        The path of ``image.nii.gz`` inside ``OUTPUT_DIR``.

    Raises:
        gr.Error: If no file was uploaded, if ``nnUNetv2_predict`` exits with
            a non-zero status, or if it finishes without producing the
            expected output file. (Returning an error string instead would be
            interpreted as a file path by the ``File`` output component.)
    """
    if nifti_file is None:
        raise gr.Error("Error: No input file provided.")

    # Accept both a plain path and a wrapper object exposing `.name`.
    upload_path = getattr(nifti_file, "name", nifti_file)

    # Prepare directories
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    input_path = os.path.join(INPUT_DIR, "image_0000.nii.gz")
    output_file = os.path.join(OUTPUT_DIR, "image.nii.gz")

    # Copy instead of move: the uploaded file stays in place, so submitting
    # the same upload again still works, and copying also succeeds when the
    # upload lives on a different filesystem than INPUT_DIR.
    shutil.copyfile(upload_path, input_path)

    # Remove a result left behind by an earlier request so it can never be
    # mistaken for the output of this run.
    if os.path.exists(output_file):
        os.remove(output_file)

    # Debugging: List files in the /tmp/input directory
    print("Files in /tmp/input:")
    print(os.listdir(INPUT_DIR))

    # Set environment variables for nnUNet (inherited by the subprocess)
    os.environ["nnUNet_results"] = MODEL_DIR

    # Construct and run the nnUNetv2_predict command
    command = [
        "nnUNetv2_predict",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "-d", "004",  # Dataset ID
        "-c", "3d_fullres",  # Configuration
        "-tr", "nnUNetTrainer_8000epochs",
        "-device", "cuda",  # Explicitly use GPU 0
    ]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        raise gr.Error(f"Error: {e}") from e

    if not os.path.exists(output_file):
        raise gr.Error("Error: Output file not found.")
    return output_file
# Gradio Interface: one file upload in, one downloadable file out,
# with run_nnunet_predict doing the work in between.
interface = gr.Interface(
    title="FLAMeS: Multiple Sclerosis Lesion Segmentation",
    description="Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary segmentation of MS lesions.",
    fn=run_nnunet_predict,
    inputs=gr.File(label="Upload FLAIR Image (.nii.gz)"),
    outputs=gr.File(label="Download Segmentation Mask"),
)
# Debugging GPU environment: report which device torch can see at startup.
# NOTE(review): indentation was lost in this copy of the file; the nvidia-smi
# call is assumed to belong to the no-GPU branch - confirm.
if not torch.cuda.is_available():
    print("No GPU available. Falling back to CPU.")
    os.system("nvidia-smi")  # Check if NVIDIA tools are available
else:
    print(f"GPU is available: {torch.cuda.get_device_name(0)}")

# Download model files before launching the app (runs on import as well).
download_model()

# Launch the app only when executed as a script.
if __name__ == "__main__":
    interface.launch(share=True)
|