#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Predict brain age for the subjects listed in a CSV file.

The CSV must contain a 'Path' column (image file per subject) and an 'Age'
column. Five checkpoints named ``BrainAge_1.pth`` ... ``BrainAge_5.pth`` are
loaded from the directory of this script, each one predicts every subject,
and the per-subject median of the five predictions is written to a new CSV
next to the input.

@author: Francesco La Rosa
"""
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchio
from monai.data import CacheDataset
from monai.transforms import (
    CenterSpatialCropd,
    Compose,
    CropForegroundd,
    LoadImaged,
    ScaleIntensityd,
    SpatialPadd,
    Spacingd,
)
from nnunet_mednext import create_mednext_v1, create_mednext_encoder_v1
from torch.utils.data import DataLoader

# Device used for both checkpoint loading and inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Linear correction applied in main() to predictions of subjects whose
# chronological age is above AGE_CORRECTION_MIN_AGE.
# NOTE(review): the origin of these coefficients is not visible in this
# file -- confirm against the model's documentation before changing them.
AGE_CORRECTION_MIN_AGE = 18
AGE_CORRECTION_SLOPE = 0.062
AGE_CORRECTION_OFFSET = 2.96

# Checkpoints BrainAge_1.pth .. BrainAge_5.pth are combined by median.
NUM_MODELS = 5


class MedNeXtEncReg(nn.Module):
    """MedNeXt encoder followed by global pooling and a small regression head."""

    def __init__(self, *args, **kwargs):
        # *args / **kwargs are accepted for signature compatibility but unused.
        super().__init__()
        self.mednextv1 = create_mednext_encoder_v1(
            num_input_channels=1,
            num_classes=1,
            model_id='B',
            kernel_size=3,
            deep_supervision=True,
        )
        self.global_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        # The head expects 512 pooled features per sample.
        self.regression_fc = nn.Sequential(
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        """Return the age estimate for each sample in ``x``.

        The final ``squeeze()`` removes every size-1 dimension, so a batch
        containing a single sample yields a 0-d tensor rather than shape (1,).
        """
        features = self.mednextv1(x)
        pooled = self.global_avg_pool(features)
        flat = torch.flatten(pooled, start_dim=1)
        age_estimate = self.regression_fc(flat)
        return age_estimate.squeeze()


def prepare_transforms():
    """Build the preprocessing pipeline applied to every image.

    Steps: load with channel first, resample to spacing (1, 1, 1), crop to
    the foreground, pad to at least 160 x 192 x 160, center-crop to exactly
    that size, then z-normalize using only voxels greater than zero.
    """
    x, y, z = (160, 192, 160)
    p = 1.0
    monai_transforms = [
        LoadImaged(keys=["image"], ensure_channel_first=True),
        Spacingd(keys=["image"], pixdim=(p, p, p)),
        CropForegroundd(keys=["image"], allow_smaller=True, source_key="image"),
        SpatialPadd(keys=["image"], spatial_size=(x, y, z)),
        CenterSpatialCropd(keys=["image"], roi_size=(x, y, z)),
    ]
    val_torchio_transforms = torchio.transforms.Compose(
        [
            torchio.transforms.ZNormalization(
                masking_method=lambda x: x > 0,
                keys=["image"],
                include=['image'],
            )
        ]
    )
    return Compose(monai_transforms + [val_torchio_transforms])


def load_data(csv_file):
    """Read ``csv_file`` and build the per-subject input dictionaries.

    Rows with a missing 'Path' or 'Age' are dropped.

    Returns:
        A ``(df, data_dicts)`` tuple: the filtered DataFrame and a list of
        ``{'image': path, 'label': age}`` dicts in the same row order.
    """
    df = pd.read_csv(csv_file)
    df.dropna(subset=['Path', 'Age'], inplace=True)
    data_dicts = [
        {'image': path, 'label': age}
        for path, age in zip(df['Path'], df['Age'])
    ]
    return df, data_dicts


def create_dataloader(data_dicts, transforms):
    """Wrap ``data_dicts`` in a partially cached dataset and a DataLoader.

    The loader uses ``batch_size=1`` and ``shuffle=False``, so batches come
    out in the order of ``data_dicts``.
    """
    dataset = CacheDataset(
        data=data_dicts,
        transform=transforms,
        cache_rate=0.2,
        num_workers=4,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=4,
        shuffle=False,
        pin_memory=torch.cuda.is_available(),
    )
    return dataloader


def initialize_model():
    """Create a fresh ``MedNeXtEncReg`` on ``device``."""
    torch.cuda.empty_cache()
    return MedNeXtEncReg().to(device)


def run_predictions(model_path, dataloader):
    """Run one checkpoint over ``dataloader`` and return its predictions.

    Args:
        model_path: Path to a state-dict checkpoint for ``MedNeXtEncReg``.
        dataloader: Iterable of batches that contain an 'image' tensor.

    Returns:
        A 1-D NumPy array with one prediction per sample, in loader order
        (empty if the loader yields no batches).
    """
    model = initialize_model()
    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    batch_predictions = []
    with torch.no_grad():
        for batch_data in dataloader:
            images = batch_data['image'].to(device)
            pred = model(images)
            # forward() returns a 0-d tensor for a single-sample batch;
            # promote it so every batch contributes a 1-D array.
            batch_predictions.append(np.atleast_1d(pred.cpu().numpy()))

    # Free the model before the next checkpoint is loaded.
    del model
    torch.cuda.empty_cache()

    if not batch_predictions:
        return np.array([])
    return np.concatenate(batch_predictions)


def main(csv_file):
    """Predict brain age for every subject in ``csv_file`` and save a new CSV.

    The output gets 'Predicted_Brain_Age' and 'Brain_Age_Difference' columns
    and is written next to the input with ``_with_predictions`` appended to
    the file name, so the input file is never overwritten.
    """
    df, data_dicts = load_data(csv_file)
    transforms = prepare_transforms()
    dataloader = create_dataloader(data_dicts, transforms)

    model_dir = os.path.dirname(__file__)
    model_paths = [
        os.path.join(model_dir, f'BrainAge_{i}.pth')
        for i in range(1, NUM_MODELS + 1)
    ]
    predictions_list = [
        run_predictions(model_path, dataloader) for model_path in model_paths
    ]
    # Combine the models with the per-subject median.
    average_predictions = np.median(np.stack(predictions_list), axis=0)

    CA = df['Age'].values
    BA = average_predictions.flatten()
    BA_corr = np.where(
        CA > AGE_CORRECTION_MIN_AGE,
        BA + (CA * AGE_CORRECTION_SLOPE) - AGE_CORRECTION_OFFSET,
        BA,
    )
    BAD_corr = BA_corr - CA

    df['Predicted_Brain_Age'] = BA_corr
    df['Brain_Age_Difference'] = BAD_corr

    # Insert the suffix before the extension only. str.replace('.csv', ...)
    # left the name unchanged for inputs not ending in '.csv' (overwriting
    # the input) and also rewrote any '.csv' inside directory names.
    root, ext = os.path.splitext(csv_file)
    output_file = f"{root}_with_predictions{ext or '.csv'}"
    df.to_csv(output_file, index=False)
    print('Updated CSV file saved.')


if __name__ == '__main__':
    if len(sys.argv) != 2:
        print("Error: No .csv file provided.")
        print("Usage: python script.py <path_to_csv_file>")
        print("Please provide the path to a .csv file as the argument. This file should contain columns for 'Path' and 'Age' for all subjects.")
        sys.exit(1)

    csv_file = sys.argv[1]
    main(csv_file)