"""Gradio app for FLAMeS: multiple sclerosis lesion segmentation with nnU-Net v2.

At import time the Dataset004_WML model archive is fetched from the Hugging
Face Hub and extracted under MODEL_DIR; each request then runs
``nnUNetv2_predict`` on the uploaded FLAIR volume and returns the mask file.
"""

import os
import shutil
import subprocess
import tempfile
import zipfile

import gradio as gr
from huggingface_hub import hf_hub_download

# Define paths
MODEL_DIR = "./model"  # Local directory to store the downloaded model
DATASET_DIR = os.path.join(MODEL_DIR, "Dataset004_WML")  # Directory for Dataset004_WML
INPUT_DIR = "/tmp/input"
OUTPUT_DIR = "/tmp/output"

# Hugging Face Model Repository
REPO_ID = "FrancescoLR/FLAMeS-model"  # Replace with your actual model repository ID

# nnU-Net v2 expects input files named <CASE>_<CHANNEL><ending>; the single
# FLAIR input is channel 0000, and the prediction is written as <CASE><ending>.
CASE_ID = "image"
FILE_ENDING = ".nii.gz"


def download_model():
    """Download Dataset004_WML.zip from REPO_ID and extract it into MODEL_DIR.

    Skipped when DATASET_DIR already exists and is non-empty. The directory is
    not pre-created, so an interrupted download is retried on the next start
    instead of being mistaken for a completed one.
    NOTE(review): assumes the archive contains a top-level Dataset004_WML/
    folder (the original skip check implies this) — confirm.

    Raises:
        zipfile.BadZipFile: if the downloaded archive is corrupt.
    """
    if os.path.isdir(DATASET_DIR) and os.listdir(DATASET_DIR):
        return
    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading Dataset004_WML.zip...")
    zip_path = hf_hub_download(
        repo_id=REPO_ID, filename="Dataset004_WML.zip", cache_dir=MODEL_DIR
    )
    # zipfile instead of the external `unzip` binary: no system dependency and
    # extraction errors raise instead of being silently ignored.
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(MODEL_DIR)
    print("Dataset004_WML downloaded and extracted.")


def _uploaded_path(nifti_file):
    """Return the filesystem path of a Gradio upload.

    Gradio 4+ passes a path string by default; older versions pass a
    tempfile-like object whose ``.name`` is the path.
    """
    if isinstance(nifti_file, (str, os.PathLike)):
        return os.fspath(nifti_file)
    return nifti_file.name


def run_nnunet_predict(nifti_file):
    """Segment MS lesions in an uploaded FLAIR volume with ``nnUNetv2_predict``.

    Args:
        nifti_file: the Gradio upload (path string or tempfile-like object).

    Returns:
        Path to the predicted segmentation mask (.nii.gz).

    Raises:
        gr.Error: if no file was uploaded, the file is not .nii.gz, or
            inference fails. Gradio displays the message to the user.
    """
    if nifti_file is None:
        raise gr.Error("Please upload a FLAIR image (.nii.gz).")
    source_path = _uploaded_path(nifti_file)
    if not source_path.endswith(FILE_ENDING):
        raise gr.Error("The uploaded file must be a compressed NIfTI image (.nii.gz).")

    # Per-request directories: nnUNet predicts every file in its input folder,
    # so a shared folder would mix in leftovers from earlier/concurrent requests.
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    input_dir = tempfile.mkdtemp(dir=INPUT_DIR)
    output_dir = tempfile.mkdtemp(dir=OUTPUT_DIR)

    # Copy rather than rename: os.rename fails across filesystems, and moving
    # the upload out of Gradio's cache breaks resubmitting the same file.
    input_path = os.path.join(input_dir, f"{CASE_ID}_0000{FILE_ENDING}")
    shutil.copyfile(source_path, input_path)

    # nnUNet locates its data through these variables; pass them to the child
    # process only instead of mutating this process's global environment.
    env = dict(
        os.environ,
        nnUNet_raw=DATASET_DIR,
        nnUNet_preprocessed=DATASET_DIR,
        nnUNet_results=DATASET_DIR,
    )

    # Construct and run the nnUNetv2_predict command
    command = [
        "nnUNetv2_predict",
        "-i", input_dir,
        "-o", output_dir,
        "-d", "004",  # Dataset ID
        "-c", "3d_fullres",  # Configuration
        "-tr", "nnUNetTrainer_8000epochs",  # Trainer name
    ]
    try:
        subprocess.run(command, check=True, env=env)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # Returning an error string to a gr.File output would be treated as a
        # file path; raising gr.Error shows the message in the UI instead.
        raise gr.Error(f"Error: {e}") from e
    finally:
        shutil.rmtree(input_dir, ignore_errors=True)

    # The output directory is kept: Gradio serves the returned file after
    # this function has returned.
    output_file = os.path.join(output_dir, f"{CASE_ID}{FILE_ENDING}")
    if not os.path.isfile(output_file):
        raise gr.Error("Error: nnUNet finished but produced no segmentation file.")
    return output_file


# Gradio Interface
interface = gr.Interface(
    fn=run_nnunet_predict,
    inputs=gr.File(label="Upload FLAIR Image (.nii.gz)"),
    outputs=gr.File(label="Download Segmentation Mask"),
    title="FLAMeS: Multiple Sclerosis Lesion Segmentation",
    description="Upload a skull-stripped FLAIR image (.nii.gz) to generate a binary segmentation of MS lesions.",
)

# Download model files before launching the app
download_model()

# Launch the app
if __name__ == "__main__":
    interface.launch(share=True)