File size: 4,241 Bytes
6760104 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: Francesco La Rosa
"""
import sys
import pandas as pd
import torch
from torch.utils.data import DataLoader
from monai.transforms import Compose, LoadImaged, ScaleIntensityd, Spacingd, CropForegroundd, SpatialPadd, CenterSpatialCropd
from monai.data import CacheDataset
import numpy as np
import os
import torchio
import torch.nn as nn
import matplotlib.pyplot as plt
from nnunet_mednext import create_mednext_v1, create_mednext_encoder_v1
class MedNeXtEncReg(nn.Module):
    """Brain-age regressor: MedNeXt encoder + global average pooling + MLP head.

    The backbone is built by ``create_mednext_encoder_v1`` (one input channel,
    model size 'B', kernel size 3). Its output is average-pooled to a single
    vector per sample and mapped by a 512 -> 64 -> 1 MLP to one scalar.

    ``*args`` and ``**kwargs`` are accepted for call compatibility but ignored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # NOTE: attribute names and the layer order in ``regression_fc`` define
        # the state_dict keys that the (strict) ``load_state_dict`` call used
        # for the BrainAge_*.pth checkpoints must match -- keep them as-is.
        self.mednextv1 = create_mednext_encoder_v1(
            num_input_channels=1,
            num_classes=1,
            model_id='B',
            kernel_size=3,
            deep_supervision=True,
        )
        self.global_avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        # Dropout(0.0) is an identity op, but it occupies index 2 of the
        # Sequential so the final Linear keeps index 3 in the state_dict.
        self.regression_fc = nn.Sequential(
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        """Return the regressed value(s) for the input batch ``x``.

        The encoder output must have 512 channels (the head's first Linear
        expects 512 features after pooling and flattening). Assumes a 5-D
        (batch, channels, D, H, W) volume tensor -- TODO confirm against the
        encoder's output. The trailing ``squeeze()`` removes every size-1
        dimension, so a batch of one yields a 0-d tensor.
        """
        pooled = self.global_avg_pool(self.mednextv1(x))
        return self.regression_fc(pooled.flatten(start_dim=1)).squeeze()
def prepare_transforms():
    """Build the preprocessing pipeline applied to every input volume.

    MONAI steps, all on the "image" key: load with the channel axis first,
    resample to 1.0 spacing on every axis, crop to the foreground bounding
    box, then pad and center-crop to exactly (160, 192, 160). A torchio
    ``ZNormalization`` follows, using voxels with intensity > 0 as its mask.

    Returns:
        A MONAI ``Compose`` wrapping the MONAI steps plus the torchio step.
    """
    x, y, z = (160, 192, 160)
    p = 1.0
    monai_transforms = [
        LoadImaged(keys=["image"], ensure_channel_first=True),
        Spacingd(keys=["image"], pixdim=(p, p, p)),
        CropForegroundd(keys=["image"], allow_smaller=True, source_key="image"),
        # Pad first so the subsequent center crop always yields the exact size.
        SpatialPadd(keys=["image"], spatial_size=(x, y, z)),
        CenterSpatialCropd(keys=["image"], roi_size=(x, y, z)),
    ]
    # Only ``include`` is passed: torchio's ``keys`` argument is deprecated
    # (it merely warns and is aliased to ``include``), and the original passed
    # the same value to both. The lambda parameter is named ``tensor`` so it
    # no longer shadows the ``x`` dimension above.
    val_torchio_transforms = torchio.transforms.Compose(
        [torchio.transforms.ZNormalization(masking_method=lambda tensor: tensor > 0, include=['image'])]
    )
    return Compose(monai_transforms + [val_torchio_transforms])
def load_data(csv_file):
    """Read the subject table and build MONAI-style data dictionaries.

    Rows whose 'Path' or 'Age' value is missing are discarded.

    Returns:
        ``(df, data_dicts)``: the filtered DataFrame (original index labels
        kept) and, in the same row order, one ``{'image': Path, 'label': Age}``
        dict per remaining subject.
    """
    table = pd.read_csv(csv_file).dropna(subset=['Path', 'Age'])
    records = [
        {'image': path, 'label': age}
        for path, age in zip(table['Path'], table['Age'])
    ]
    return table, records
def create_dataloader(data_dicts, transforms):
    """Wrap ``data_dicts`` in a MONAI CacheDataset and an ordered DataLoader.

    ``cache_rate=0.2`` asks MONAI to cache roughly a fifth of the items; the
    loader yields one sample per batch without shuffling, so outputs stay in
    the same order as ``data_dicts``. Host memory is pinned only when CUDA
    is available.
    """
    cached = CacheDataset(
        data=data_dicts,
        transform=transforms,
        cache_rate=0.2,
        num_workers=4,
    )
    return DataLoader(
        cached,
        batch_size=1,
        shuffle=False,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
    )
# Module-wide compute device: the default CUDA device when available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def initialize_model():
    """Construct a fresh ``MedNeXtEncReg`` and move it to the module-level ``device``.

    ``torch.cuda.empty_cache()`` runs first -- presumably to release memory
    cached from a previously deleted model before allocating a new one.
    """
    torch.cuda.empty_cache()
    model = MedNeXtEncReg()
    model = model.to(device)
    return model
def run_predictions(model_path, dataloader):
    """Run one trained checkpoint over every batch of ``dataloader``.

    A fresh model is built with ``initialize_model``, its ``state_dict`` is
    loaded from ``model_path`` and it is evaluated under ``torch.no_grad``.
    The model is deleted and the CUDA cache emptied afterwards, even if an
    error occurs.

    Args:
        model_path: Path to a saved ``state_dict`` for ``MedNeXtEncReg``.
        dataloader: Iterable yielding dicts with an 'image' tensor per batch.

    Returns:
        1-D NumPy array with one prediction per sample, in loader order
        (an empty array if the loader yields nothing).
    """
    model = initialize_model()
    chunks = []
    try:
        # map_location lets checkpoints saved on a GPU load on a CPU-only
        # machine; without it torch.load fails when CUDA is unavailable.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
        model.eval()
        with torch.no_grad():
            for batch_data in dataloader:
                images = batch_data['image'].to(device)
                pred = model(images)
                # The model squeezes its output (0-d for a batch of one);
                # reshape(-1) makes every batch 1-D so batches of any size,
                # including a short final batch, concatenate cleanly.
                chunks.append(pred.cpu().numpy().reshape(-1))
    finally:
        del model
        torch.cuda.empty_cache()
    return np.concatenate(chunks) if chunks else np.array([])
def main(csv_file):
    """Predict brain age for every subject in ``csv_file`` and save the results.

    The CSV must contain 'Path' and 'Age' columns; rows missing either are
    dropped. Five checkpoints ``BrainAge_1.pth`` ... ``BrainAge_5.pth`` located
    next to this script are each run over the data, and the per-subject median
    of their outputs is the raw brain age. For subjects with Age > 18 a linear
    correction ``BA + 0.062 * Age - 2.96`` is applied (constants are not
    explained here -- presumably an age-bias correction; TODO confirm source).

    The table, with 'Predicted_Brain_Age' (corrected brain age) and
    'Brain_Age_Difference' (corrected brain age minus Age) added, is written
    next to the input as ``<name>_with_predictions<ext>``.
    """
    df, data_dicts = load_data(csv_file)
    transforms = prepare_transforms()
    dataloader = create_dataloader(data_dicts, transforms)
    script_dir = os.path.dirname(__file__)
    model_paths = [os.path.join(script_dir, f'BrainAge_{i}.pth') for i in range(1, 6)]
    predictions_list = [run_predictions(model_path, dataloader) for model_path in model_paths]
    # Ensemble by per-subject median across the models.
    brain_age = np.median(np.stack(predictions_list), axis=0).flatten()
    chronological_age = df['Age'].values
    brain_age_corr = np.where(
        chronological_age > 18,
        brain_age + (chronological_age * 0.062) - 2.96,
        brain_age,
    )
    df['Predicted_Brain_Age'] = brain_age_corr
    df['Brain_Age_Difference'] = brain_age_corr - chronological_age
    # Build the output name from the extension only: str.replace('.csv', ...)
    # also rewrote '.csv' inside directory names, and for any other extension
    # (e.g. '.CSV') returned the input path unchanged, overwriting the input.
    root, ext = os.path.splitext(csv_file)
    output_file = f'{root}_with_predictions{ext or ".csv"}'
    df.to_csv(output_file, index=False)
    print('Updated CSV file saved.')
if __name__ == '__main__':
    # Exactly one command-line argument (the CSV path) is required.
    if len(sys.argv) != 2:
        # Report the actual problem: missing argument vs. too many arguments.
        if len(sys.argv) < 2:
            print("Error: No .csv file provided.", file=sys.stderr)
        else:
            print(f"Error: Expected exactly one argument, got {len(sys.argv) - 1}.", file=sys.stderr)
        print(f"Usage: python {os.path.basename(sys.argv[0])} <path_to_csv_file>", file=sys.stderr)
        print("Please provide the path to a .csv file as the argument. This file should contain columns for 'Path' and 'Age' for all subjects.", file=sys.stderr)
        sys.exit(1)
    csv_file = sys.argv[1]
    main(csv_file)
|