|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import sys |
|
|
|
import math |
|
|
|
class C3D(torch.nn.Module): |
|
|
|
def __init__(self, tgt_modalities, in_channels=1, load_from_ckpt=None): |
|
|
|
super(C3D, self).__init__() |
|
self.conv_group1 = nn.Sequential( |
|
nn.Conv3d(in_channels, 64, kernel_size=3, padding=1), |
|
nn.BatchNorm3d(64), |
|
nn.ReLU(), |
|
nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 2, 2))) |
|
self.conv_group2 = nn.Sequential( |
|
nn.Conv3d(64, 128, kernel_size=3, padding=1), |
|
nn.BatchNorm3d(128), |
|
nn.ReLU(), |
|
nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))) |
|
self.conv_group3 = nn.Sequential( |
|
nn.Conv3d(128, 256, kernel_size=3, padding=1), |
|
nn.BatchNorm3d(256), |
|
nn.ReLU(), |
|
nn.Conv3d(256, 256, kernel_size=3, padding=1), |
|
nn.BatchNorm3d(256), |
|
nn.ReLU(), |
|
nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) |
|
) |
|
self.conv_group4 = nn.Sequential( |
|
nn.Conv3d(256, 512, kernel_size=3, padding=1), |
|
nn.BatchNorm3d(512), |
|
nn.ReLU(), |
|
nn.Conv3d(512, 512, kernel_size=3, padding=1), |
|
nn.BatchNorm3d(512), |
|
nn.ReLU(), |
|
nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1)) |
|
) |
|
|
|
|
|
|
|
self.fc1 = nn.Sequential( |
|
nn.Linear((512 * 15 * 9 * 9) , 512), |
|
nn.ReLU(), |
|
nn.Dropout(0.5)) |
|
self.fc2 = nn.Sequential( |
|
nn.Linear(512, 256), |
|
nn.ReLU(), |
|
nn.Dropout(0.5)) |
|
|
|
|
|
|
|
self.fc = torch.nn.ModuleDict() |
|
for k in tgt_modalities: |
|
self.fc[k] = torch.nn.Linear(256, 1) |
|
|
|
def forward(self, x): |
|
|
|
|
|
|
|
|
|
|
|
|
|
out = self.conv_group1(x) |
|
out = self.conv_group2(out) |
|
out = self.conv_group3(out) |
|
out = self.conv_group4(out) |
|
out = out.view(out.size(0), -1) |
|
|
|
out = self.fc1(out) |
|
out = self.fc2(out) |
|
|
|
|
|
tgt_iter = self.fc.keys() |
|
out_tgt = {k: self.fc[k](out).squeeze(1) for k in tgt_iter} |
|
return out_tgt |
|
|
|
|
|
if __name__ == "__main__": |
|
model = C3D(tgt_modalities=['NC', 'MCI', 'DE']) |
|
print(model) |
|
x = torch.rand((1, 1, 128, 128, 128)) |
|
|
|
|
|
|
|
print(sum(p.numel() for p in model.parameters())) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|