nmed2024 / adrd /nn /cnn_resnet3d.py
xf3227's picture
ok
6fc43ab
raw
history blame
2.46 kB
import torch
import torch.nn as nn
from typing import Any, Type
Tensor = Type[torch.Tensor]
from .resnet3d import r3d_18
class CNNResNet3D(nn.Module):
def __init__(self,
src_modalities: dict[str, dict[str, Any]],
tgt_modalities: dict[str, dict[str, Any]]
) -> None:
""" ... """
super().__init__()
# resnet
# embedding modules for source
self.modules_emb_src = nn.ModuleDict()
for k, info in src_modalities.items():
if info['type'] == 'imaging' and len(info['img_shape']) == 4:
self.modules_emb_src[k] = nn.Sequential(
r3d_18(),
nn.Dropout(0.5)
)
else:
# unrecognized
raise ValueError('{} is an unrecognized data modality'.format(k))
# classifiers (binary only)
self.modules_cls = nn.ModuleDict()
for k, info in tgt_modalities.items():
if info['type'] == 'categorical' and info['num_categories'] == 2:
# categorical
self.modules_cls[k] = nn.Linear(256, 1)
else:
# unrecognized
raise ValueError
def forward(self,
x: dict[str, Tensor],
) -> dict[str, Tensor]:
""" ... """
out_emb = self.forward_emb(x)
out_emb = out_emb[list(out_emb.keys())[0]]
out_cls = self.forward_cls(out_emb)
return out_cls
def forward_emb(self,
x: dict[str, Tensor],
) -> dict[str, Tensor]:
""" ... """
out_emb = dict()
for k in self.modules_emb_src.keys():
out_emb[k] = self.modules_emb_src[k](x[k])
return out_emb
def forward_cls(self,
out_emb: dict[str, Tensor]
) -> dict[str, Tensor]:
""" ... """
out_cls = dict()
for k in self.modules_cls.keys():
out_cls[k] = self.modules_cls[k](out_emb).squeeze(1)
return out_cls
# for testing purpose only
if __name__ == '__main__':
src_modalities = {
'img_MRI_T1': {'type': 'imaging', 'img_shape': [1, 182, 218, 182]}
}
tgt_modalities = {
'AD': {'type': 'categorical', 'num_categories': 2},
'PD': {'type': 'categorical', 'num_categories': 2}
}
net = CNNResNet3D(src_modalities, tgt_modalities)
net.eval()
x = {'img_MRI_T1': torch.zeros(2, 1, 182, 218, 182)}
print(net(x))