# nmed2024/adrd/nn/resnet_img_model.py
# (Header captured from a web file view: author xf3227, commit 6fc43ab "ok",
#  raw / history / blame, 2.06 kB.)
import torch
import torch.nn as nn
import sys
from icecream import ic
# sys.path.append('/home/skowshik/ADRD_repo/adrd_tool/adrd/')
from .net_resnet3d import r3d_18
# from dev.data.dataset_csv import CSVDataset
class ResNetModel(nn.Module):
    '''Image-only model: an ``r3d_18`` backbone followed by one linear head per target.

    Only binary categorical targets are supported. Each head is
    ``nn.Linear(64, 1)``, so the backbone output is expected to carry 64
    features; each head produces a single un-activated score per sample.
    '''
    def __init__(
        self,
        tgt_modalities,
        mri_feature = 'img_MRI_T1',
    ):
        '''Build the image backbone and the per-target classifier heads.

        Args:
            tgt_modalities: mapping of target name -> info dict. Every entry
                must satisfy ``info['type'] == 'categorical'`` and
                ``info['num_categories'] == 2``.
            mri_feature: key of the input dict whose value is passed to the
                image backbone in :meth:`forward`.

        Raises:
            ValueError: if any target is not a binary categorical target.
        '''
        super().__init__()
        # Validate every target up front so an unsupported one fails fast,
        # before the backbone is instantiated.
        for k, info in tgt_modalities.items():
            if not (info['type'] == 'categorical' and info['num_categories'] == 2):
                raise ValueError(
                    f"Unsupported target {k!r}: only binary categorical targets "
                    f"(type='categorical', num_categories=2) are supported, got {info!r}"
                )
        self.mri_feature = mri_feature
        # The backbone is created (and registered) before the heads, as before, so
        # parameter / state_dict order and RNG consumption at init are unchanged.
        self.img_net_ = r3d_18()
        # classifiers (binary only), one per target, in tgt_modalities order.
        # NOTE(review): 64 is assumed to equal the feature width produced by
        # r3d_18 -- TODO confirm against net_resnet3d.
        self.modules_cls = nn.ModuleDict(
            {k: nn.Linear(64, 1) for k in tgt_modalities}
        )

    def forward(self, x):
        '''Run the backbone on ``x[self.mri_feature]`` and apply every head.

        Args:
            x: mapping of input features; only ``x[self.mri_feature]`` is read.

        Returns:
            dict mapping each target name (in ``modules_cls`` order) to that
            head's output with dim 1 squeezed -- presumably shape ``(batch,)``
            when the backbone returns ``(batch, 64)``. No activation is applied.
        '''
        img_out = self.img_net_(x[self.mri_feature])
        # Build the dict directly rather than stacking to (batch, n_targets) and
        # slicing back out: same values, no extra copy, and it also works when
        # there are zero targets (torch.stack raises on an empty list).
        return {k: head(img_out).squeeze(1) for k, head in self.modules_cls.items()}
if __name__ == '__main__':
    # Manual smoke test only -- running this module directly does nothing.
    # To try it, uncomment the lines below; they need the CSVDataset import
    # (commented out near the top of this file) and a CUDA device.
    #
    # import torch
    # import numpy as np
    # seed = 0
    # print('Loading training dataset ... ')
    # dat_trn = CSVDataset(mode=0, split=[1, 700], seed=seed)
    # print(len(dat_trn))
    # tgt_modalities = dat_trn.label_modalities
    # net = ResNetModel(tgt_modalities).to('cuda')
    # x = dat_trn.features
    # x = {k: torch.as_tensor(np.array([x[i][k] for i in range(len(x))])).to('cuda') for k in x[0]}
    # ic(x)
    # # print(net(x).shape)
    # print(net(x))
    pass