|
import torch |
|
import torch.nn as nn |
|
import sys |
|
from icecream import ic |
|
|
|
from .net_resnet3d import r3d_18 |
|
|
|
|
|
|
|
class ResNetModel(nn.Module):
    """3D-ResNet image model with one binary-classification head per target.

    An ``r3d_18`` backbone encodes a single image feature taken from the input
    dict, and every target gets its own ``nn.Linear`` head that maps the
    backbone output to a single unnormalized score (no activation is applied
    in this module).
    """

    # Input width of every classification head.
    # NOTE(review): assumes ``r3d_18()`` emits 64 features per sample — confirm
    # against net_resnet3d if the backbone configuration changes.
    _IMG_FEATURE_DIM = 64

    def __init__(
        self,
        tgt_modalities,
        mri_feature='img_MRI_T1',
    ):
        """Build the backbone and one linear head per target.

        Args:
            tgt_modalities: Mapping from target name to an info dict. Each
                entry must have ``info['type'] == 'categorical'`` and
                ``info['num_categories'] == 2``; only binary categorical
                targets are supported.
            mri_feature: Key under which ``forward`` looks up the image batch
                in its input dict.

        Raises:
            ValueError: If any target is not a binary categorical target.
        """
        super().__init__()

        self.mri_feature = mri_feature

        self.img_net_ = r3d_18()

        # One single-output linear head per target, keyed by target name.
        self.modules_cls = nn.ModuleDict()
        for k, info in tgt_modalities.items():
            if info['type'] == 'categorical' and info['num_categories'] == 2:
                self.modules_cls[k] = nn.Linear(self._IMG_FEATURE_DIM, 1)
            else:
                raise ValueError(
                    f"Unsupported target {k!r}: only binary categorical targets "
                    f"are supported, got type={info.get('type')!r}, "
                    f"num_categories={info.get('num_categories')!r}."
                )

    def forward(self, x):
        """Encode ``x[self.mri_feature]`` with the backbone and apply every head.

        Args:
            x: Mapping from feature name to input tensor; only the entry under
                ``self.mri_feature`` is read.

        Returns:
            dict mapping each target name (in ``modules_cls`` insertion order)
            to that head's output with dimension 1 squeezed away. Presumably
            this is a ``(batch,)`` tensor of unnormalized scores (TODO confirm
            the backbone output is ``(batch, features)``). The dict is empty
            when there are no targets.
        """
        img_out = self.img_net_(x[self.mri_feature])

        # Each head emits a trailing size-1 dim; squeeze it away. Results are
        # returned per target directly instead of being stacked into a single
        # tensor and re-sliced, which also keeps the no-target case valid.
        return {k: head(img_out).squeeze(1) for k, head in self.modules_cls.items()}
|
|
|
|
|
if '__main__' == __name__:
    # Entry point for ad-hoc manual testing when this module is run directly.
    # It is currently an intentional no-op.
    pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|