"""Gradio app that estimates brain age from an MRI scan with a MedNeXt ensemble.

Five MedNeXt encoder + regression-head checkpoints are downloaded from the
Hugging Face Hub. Each one predicts an age for the preprocessed volume, and the
median of the five predictions is reported alongside the brain age difference
(predicted minus chronological age).
"""

import functools
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchio
from huggingface_hub import hf_hub_download
from monai.data import DataLoader, Dataset
from monai.transforms import (
    CenterSpatialCropd,
    Compose,
    CropForegroundd,
    LoadImaged,
    SpatialPadd,
    Spacingd,
)
from nnunet_mednext import create_mednext_encoder_v1

# Device selection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Model definition
class MedNeXtEncReg(nn.Module):
    """MedNeXt-B encoder followed by global average pooling and an MLP regressor.

    The regression head expects the pooled encoder output to have 512 features.
    """

    def __init__(self):
        super().__init__()
        self.mednextv1 = create_mednext_encoder_v1(
            num_input_channels=1,
            num_classes=1,
            model_id='B',
            kernel_size=3,
            deep_supervision=True,
        )
        self.global_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.regression_fc = nn.Sequential(
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        """Return the age estimate(s) for the input volume batch ``x``.

        ``squeeze()`` drops all singleton dims, so a batch of one yields a 0-d tensor.
        """
        x = self.mednextv1(x)
        x = self.global_avg_pool(x)
        x = torch.flatten(x, start_dim=1)
        age_estimate = self.regression_fc(x)
        return age_estimate.squeeze()


# Download the model from Hugging Face Hub
@functools.lru_cache(maxsize=1)
def initialize_model():
    """Download and load the five ensemble checkpoints, once per process.

    The result is cached so repeated predictions reuse the same loaded models
    instead of re-downloading and re-loading all checkpoints on every request.

    Returns:
        list of ``MedNeXtEncReg`` instances on ``device``, in eval mode.
    """
    model_paths = [
        hf_hub_download(repo_id="FrancescoLR/BrainAgeNeXt", filename=f"BrainAge_{i}.pth")
        for i in range(1, 6)
    ]
    models = []
    for model_path in model_paths:
        model = MedNeXtEncReg().to(device)
        # NOTE(review): torch.load unpickles the checkpoint; this is only safe
        # because the files come from a fixed, trusted repo.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        models.append(model)
    return models


# Define preprocessing transforms
def prepare_transforms():
    """Build the preprocessing pipeline applied to the ``"image"`` key.

    Steps: load with a leading channel dim, resample to 1 mm isotropic spacing,
    crop to the foreground, pad and then center-crop to 160x192x160, and finally
    z-normalize using only voxels with intensity > 0.
    """
    return Compose([
        LoadImaged(keys=["image"], ensure_channel_first=True),
        Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0)),
        CropForegroundd(keys=["image"], allow_smaller=True, source_key="image"),
        SpatialPadd(keys=["image"], spatial_size=(160, 192, 160)),
        CenterSpatialCropd(keys=["image"], roi_size=(160, 192, 160)),
        # torchio selects dict entries via ``include``; ``keys`` is deprecated.
        torchio.transforms.ZNormalization(masking_method=lambda x: x > 0, include=["image"]),
    ])


# Process uploaded MRI scan
def preprocess_mri(mri_path):
    """Preprocess the MRI at ``mri_path``.

    Returns a batched (batch size 1) image tensor on ``device``.
    """
    transforms = prepare_transforms()
    data_dict = {"image": mri_path}
    dataset = Dataset([data_dict], transform=transforms)
    # MONAI's DataLoader collates MetaTensors correctly (and was previously
    # used here without being imported).
    dataloader = DataLoader(dataset, batch_size=1, num_workers=0)
    return next(iter(dataloader))["image"].to(device)


def _resolve_upload_path(mri_file):
    """Return a filesystem path for a Gradio file upload, or None if there is none."""
    if mri_file is None:
        return None
    if isinstance(mri_file, (str, os.PathLike)):
        return os.fspath(mri_file)
    # Some Gradio versions pass a file-like wrapper whose ``name`` holds the path.
    return getattr(mri_file, "name", None)


# Predict brain age
def predict_brain_age(mri_path, actual_age, sex):
    """Predict brain age for an uploaded MRI and compute the brain age difference.

    Args:
        mri_path: the value from the ``gr.File`` input. This is either a path or
            a file-like object exposing ``.name``; it is None when nothing was uploaded.
        actual_age: chronological age in years from the ``gr.Number`` input.
        sex: selected sex; collected by the UI but not used in the computation.

    Returns:
        A 2-tuple of strings, one per output Textbox. On invalid input, the
        first element is an error message and the second is empty.
    """
    path = _resolve_upload_path(mri_path)
    if path is None or not os.path.exists(path):
        # Two values are always returned because the interface has two outputs.
        return "Error: MRI file not found", ""
    if actual_age is None:
        return "Error: please enter an age", ""
    actual_age = float(actual_age)

    # Loaded once and cached by initialize_model.
    models = initialize_model()

    image = preprocess_mri(path)

    with torch.no_grad():
        predictions = [model(image).cpu().numpy() for model in models]

    # Median across the ensemble members.
    predicted_brain_age = float(np.median(np.stack(predictions)))

    # Linear age-dependent correction applied only when actual_age > 18.
    # NOTE(review): coefficients kept as-is; confirm their derivation.
    if actual_age > 18:
        predicted_brain_age_corrected = predicted_brain_age + (actual_age * 0.062) - 2.96
    else:
        predicted_brain_age_corrected = predicted_brain_age

    brain_age_difference = predicted_brain_age_corrected - actual_age

    return (
        f"Predicted Brain Age: {predicted_brain_age_corrected:.2f} years",
        f"Brain Age Difference (BAD): {brain_age_difference:.2f} years",
    )


# Gradio UI
iface = gr.Interface(
    fn=predict_brain_age,
    inputs=[
        gr.File(label="Upload MRI (NIfTI .nii.gz)"),
        gr.Number(label="Enter Age"),
        gr.Radio(["Male", "Female"], label="Select Sex"),
    ],
    outputs=[
        gr.Textbox(label="Predicted Brain Age"),
        gr.Textbox(label="Brain Age Difference (BAD)"),
    ],
    title="Brain Age Prediction with MedNeXt",
    description="Upload an MRI scan (.nii.gz), enter your age and sex, and get a brain age prediction.",
    theme="default",
)

# Launch the Gradio app
if __name__ == "__main__":
    iface.launch(share=True)