nmed2024 / adrd /nn /c3d.py
xf3227's picture
ok
6fc43ab
# From https://github.com/xmuyzz/3D-CNN-PyTorch/blob/master/models/C3DNet.py
import torch
import torch.nn as nn
import sys
# from icecream import ic
import math
class C3D(torch.nn.Module):
def __init__(self, tgt_modalities, in_channels=1, load_from_ckpt=None):
super(C3D, self).__init__()
self.conv_group1 = nn.Sequential(
nn.Conv3d(in_channels, 64, kernel_size=3, padding=1),
nn.BatchNorm3d(64),
nn.ReLU(),
nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 2, 2)))
self.conv_group2 = nn.Sequential(
nn.Conv3d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm3d(128),
nn.ReLU(),
nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)))
self.conv_group3 = nn.Sequential(
nn.Conv3d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm3d(256),
nn.ReLU(),
nn.Conv3d(256, 256, kernel_size=3, padding=1),
nn.BatchNorm3d(256),
nn.ReLU(),
nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
)
self.conv_group4 = nn.Sequential(
nn.Conv3d(256, 512, kernel_size=3, padding=1),
nn.BatchNorm3d(512),
nn.ReLU(),
nn.Conv3d(512, 512, kernel_size=3, padding=1),
nn.BatchNorm3d(512),
nn.ReLU(),
nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1))
)
# last_duration = int(math.floor(128 / 16))
# last_size = int(math.ceil(128 / 32))
self.fc1 = nn.Sequential(
nn.Linear((512 * 15 * 9 * 9) , 512),
nn.ReLU(),
nn.Dropout(0.5))
self.fc2 = nn.Sequential(
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.5))
# self.fc = nn.Sequential(
# nn.Linear(4096, num_classes))
self.fc = torch.nn.ModuleDict()
for k in tgt_modalities:
self.fc[k] = torch.nn.Linear(256, 1)
def forward(self, x):
# for k in x.keys():
# x[k] = x[k].to(torch.float32)
# x = torch.stack([o for o in x.values()], dim=0)[0]
# print(x.shape)
out = self.conv_group1(x)
out = self.conv_group2(out)
out = self.conv_group3(out)
out = self.conv_group4(out)
out = out.view(out.size(0), -1)
# print(out.shape)
out = self.fc1(out)
out = self.fc2(out)
# out = self.fc(out)
tgt_iter = self.fc.keys()
out_tgt = {k: self.fc[k](out).squeeze(1) for k in tgt_iter}
return out_tgt
if __name__ == "__main__":
model = C3D(tgt_modalities=['NC', 'MCI', 'DE'])
print(model)
x = torch.rand((1, 1, 128, 128, 128))
# layers = list(model.features.named_children())
# features = nn.Sequential(*list(model.features.children()))(x)
# print(features.shape)
print(sum(p.numel() for p in model.parameters()))
# layer_found = False
# features = None
# desired_layer_name = 'transition3'
# for name, layer in layers:
# if name == desired_layer_name:
# x = layer(x)
# print(x)
# model(x)
# print(features)