File size: 2,458 Bytes
6fc43ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn as nn
from typing import Any, Type
# Annotation alias for a tensor *instance*. ``Type[torch.Tensor]`` would
# describe the class object itself, which is not what the methods in this
# module receive or return.
Tensor = torch.Tensor

from .resnet3d import r3d_18


class CNNResNet3D(nn.Module):
    """3D-ResNet encoder per imaging input, plus one binary head per target.

    Every source modality gets its own ``r3d_18`` backbone followed by
    dropout (p=0.5). Every target gets an ``nn.Linear(256, 1)`` head.

    NOTE(review): the heads assume ``r3d_18`` emits 256-dimensional
    features -- confirm against ``resnet3d``.
    """

    def __init__(self,
        src_modalities: dict[str, dict[str, Any]],
        tgt_modalities: dict[str, dict[str, Any]]
    ) -> None:
        """Build the embedding backbones and the classification heads.

        Args:
            src_modalities: Maps a source name to its description. Only
                entries with ``'type' == 'imaging'`` and a 4-element
                ``'img_shape'`` are accepted.
            tgt_modalities: Maps a target name to its description. Only
                entries with ``'type' == 'categorical'`` and
                ``'num_categories' == 2`` are accepted.

        Raises:
            ValueError: If any source or target modality is not of a
                supported kind.
        """
        super().__init__()

        # one 3D ResNet (+ dropout) per imaging source
        self.modules_emb_src = nn.ModuleDict()
        for k, info in src_modalities.items():
            if info['type'] == 'imaging' and len(info['img_shape']) == 4:
                self.modules_emb_src[k] = nn.Sequential(
                    r3d_18(),
                    nn.Dropout(0.5)
                )
            else:
                # unrecognized
                raise ValueError('{} is an unrecognized data modality'.format(k))

        # one single-output linear head per binary target
        self.modules_cls = nn.ModuleDict()
        for k, info in tgt_modalities.items():
            if info['type'] == 'categorical' and info['num_categories'] == 2:
                self.modules_cls[k] = nn.Linear(256, 1)
            else:
                # unrecognized: name the offending target in the message,
                # as the source branch does
                raise ValueError(
                    '{} is an unrecognized target modality; only binary '
                    'categorical targets are supported'.format(k)
                )

    def forward(self,
        x: dict[str, Tensor],
    ) -> dict[str, Tensor]:
        """Embed the inputs and classify from the first source's embedding.

        All registered source modalities are embedded, but only the
        embedding of the first one (in registration order) is passed to
        the classification heads.

        Args:
            x: Maps each registered source name to its input tensor.

        Returns:
            Maps each target name to the raw output of its linear head
            (no sigmoid is applied), with dimension 1 squeezed out.

        Raises:
            IndexError: If no source modality was registered.
        """
        out_emb = self.forward_emb(x)
        # keep only the first modality's embedding
        first_emb = list(out_emb.values())[0]
        return self.forward_cls(first_emb)

    def forward_emb(self,
        x: dict[str, Tensor],
    ) -> dict[str, Tensor]:
        """Run every source backbone on its matching input.

        Args:
            x: Maps each registered source name to its input tensor.

        Returns:
            Maps each registered source name to its embedding.

        Raises:
            KeyError: If ``x`` lacks an entry for a registered source.
        """
        return {
            k: backbone(x[k])
            for k, backbone in self.modules_emb_src.items()
        }

    def forward_cls(self,
        out_emb: Tensor
    ) -> dict[str, Tensor]:
        """Apply every classification head to one shared embedding.

        Args:
            out_emb: A single embedding tensor (not a dict); the same
                tensor is fed to every head.
                Presumably shaped (batch, 256) -- TODO confirm.

        Returns:
            Maps each target name to its head's output with dimension 1
            squeezed out.
        """
        return {
            k: head(out_emb).squeeze(1)
            for k, head in self.modules_cls.items()
        }
    

# smoke test only: one MRI input, two binary targets, a batch of two volumes
if __name__ == '__main__':
    img_shape = [1, 182, 218, 182]
    src_modalities = {
        'img_MRI_T1': {'type': 'imaging', 'img_shape': img_shape}
    }
    tgt_modalities = {
        name: {'type': 'categorical', 'num_categories': 2}
        for name in ('AD', 'PD')
    }

    net = CNNResNet3D(src_modalities, tgt_modalities)
    # evaluation mode, so dropout is inactive for this check
    net.eval()

    # all-zero dummy batch: 2 samples of the declared image shape
    x = {'img_MRI_T1': torch.zeros(2, *img_shape)}
    print(net(x))