import torch.nn as nn import torch.nn.functional as F import torch from huggingface_hub import PyTorchModelHubMixin class conv3d(nn.Module, PyTorchModelHubMixin): def __init__(self, in_channels, out_channels): """ + Instantiate modules: conv-relu-norm + Assign them as member variables """ super(conv3d, self).__init__() self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=5, padding=2) self.relu = nn.LeakyReLU(0.2) # with learnable parameters self.norm = nn.InstanceNorm3d(out_channels, affine=True) def forward(self, x): return self.relu(self.norm(self.conv(x))) class conv3d_x3(nn.Module, PyTorchModelHubMixin): """Three serial convs with a residual connection. Structure: inputs --> ① --> ② --> ③ --> outputs ↓ --> add--> ↑ """ def __init__(self, in_channels, out_channels): super(conv3d_x3, self).__init__() self.conv_1 = conv3d(in_channels, out_channels) self.conv_2 = conv3d(out_channels, out_channels) self.conv_3 = conv3d(out_channels, out_channels) self.skip_connection=nn.Conv3d(in_channels,out_channels,1) def forward(self, x): z_1 = self.conv_1(x) z_3 = self.conv_3(self.conv_2(z_1)) return z_3 + self.skip_connection(x) class conv3d_x2(nn.Module, PyTorchModelHubMixin): """Three serial convs with a residual connection. Structure: inputs --> ① --> ② --> ③ --> outputs ↓ --> add--> ↑ """ def __init__(self, in_channels, out_channels): super(conv3d_x2, self).__init__() self.conv_1 = conv3d(in_channels, out_channels) self.conv_2 = conv3d(out_channels, out_channels) self.skip_connection=nn.Conv3d(in_channels,out_channels,1) def forward(self, x): z_1 = self.conv_1(x) z_2 = self.conv_2(z_1) return z_2 + self.skip_connection(x) class conv3d_x1(nn.Module, PyTorchModelHubMixin): """Three serial convs with a residual connection. Structure: inputs --> ① --> ② --> ③ --> outputs ↓ --> add--> ↑ """ def __init__(self, in_channels, out_channels): super(conv3d_x1, self).__init__() self.conv_1 = conv3d(in_channels, out_channels) self.skip_connection=nn.Conv3d(in_channels,out_channels,1) def forward(self, x): z_1 = self.conv_1(x) return z_1 + self.skip_connection(x) class deconv3d_x3(nn.Module, PyTorchModelHubMixin): def __init__(self, in_channels, out_channels): super(deconv3d_x3, self).__init__() self.up = deconv3d_as_up(in_channels, out_channels, 2, 2) self.lhs_conv = conv3d(out_channels // 2, out_channels) self.conv_x3 = nn.Sequential( nn.Conv3d(2*out_channels, out_channels,5,1,2), nn.LeakyReLU(0.1), nn.Conv3d(out_channels, out_channels,5,1,2), nn.LeakyReLU(0.1), nn.Conv3d(out_channels, out_channels,5,1,2), nn.LeakyReLU(0.1), ) def forward(self, lhs, rhs): rhs_up = self.up(rhs) lhs_conv = self.lhs_conv(lhs) rhs_add = torch.cat((rhs_up, lhs_conv),dim=1) return self.conv_x3(rhs_add)+ rhs_up class deconv3d_x2(nn.Module, PyTorchModelHubMixin): def __init__(self, in_channels, out_channels): super(deconv3d_x2, self).__init__() self.up = deconv3d_as_up(in_channels, out_channels, 2, 2) self.lhs_conv = conv3d(out_channels // 2, out_channels) self.conv_x2= nn.Sequential( nn.Conv3d(2*out_channels, out_channels,5,1,2), nn.LeakyReLU(0.1), nn.Conv3d(out_channels, out_channels,5,1,2), nn.LeakyReLU(0.1), ) def forward(self, lhs, rhs): rhs_up = self.up(rhs) lhs_conv = self.lhs_conv(lhs) rhs_add = torch.cat((rhs_up, lhs_conv),dim=1) return self.conv_x2(rhs_add)+ rhs_up class deconv3d_x1(nn.Module, PyTorchModelHubMixin): def __init__(self, in_channels, out_channels): super(deconv3d_x1, self).__init__() self.up = deconv3d_as_up(in_channels, out_channels, 2, 2) self.lhs_conv = conv3d(out_channels // 2, out_channels) self.conv_x1 = nn.Sequential( nn.Conv3d(2*out_channels, out_channels,5,1,2), nn.LeakyReLU(0.2), ) def forward(self, lhs, rhs): rhs_up = self.up(rhs) lhs_conv = self.lhs_conv(lhs) rhs_add = torch.cat((rhs_up, lhs_conv),dim=1) return self.conv_x1(rhs_add)+ rhs_up def conv3d_as_pool(in_channels, out_channels, kernel_size=2, stride=2): return nn.Sequential( nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding=0), nn.LeakyReLU(0.2)) def deconv3d_as_up(in_channels, out_channels, kernel_size=2, stride=2): return nn.Sequential( nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride), nn.PReLU() ) class softmax_out(nn.Module, PyTorchModelHubMixin): def __init__(self, in_channels, out_channels): super(softmax_out, self).__init__() self.conv_1 = nn.Conv3d(in_channels, out_channels, kernel_size=5, padding=2) self.conv_2 = nn.Conv3d(out_channels, out_channels, kernel_size=1, padding=0) def forward(self, x): """Output with shape [batch_size, 1, depth, height, width].""" # Do NOT add normalize layer, or its values vanish. y_conv = self.conv_2(self.conv_1(x)) return y_conv class VNet(nn.Module, PyTorchModelHubMixin): def __init__(self): super(VNet, self).__init__() self.conv_1 = conv3d_x1(1, 16) self.pool_1 = conv3d_as_pool(16, 32) self.conv_2 = conv3d_x2(32, 32) self.pool_2 = conv3d_as_pool(32, 64) self.conv_3 = conv3d_x3(64, 64) self.pool_3 = conv3d_as_pool(64, 128) self.conv_4 = conv3d_x3(128, 128) self.pool_4 = conv3d_as_pool(128, 256) self.bottom = conv3d_x3(256, 256) self.deconv_4 = deconv3d_x3(256, 256) self.deconv_3 = deconv3d_x3(256, 128) self.deconv_2 = deconv3d_x2(128, 64) self.deconv_1 = deconv3d_x1(64, 32) self.out = softmax_out(32, 1) def forward(self, x): conv_1 = self.conv_1(x) pool = self.pool_1(conv_1) conv_2 = self.conv_2(pool) pool = self.pool_2(conv_2) conv_3 = self.conv_3(pool) pool = self.pool_3(conv_3) conv_4 = self.conv_4(pool) pool = self.pool_4(conv_4) bottom = self.bottom(pool) deconv = self.deconv_4(conv_4, bottom) deconv = self.deconv_3(conv_3, deconv) deconv = self.deconv_2(conv_2, deconv) deconv = self.deconv_1(conv_1, deconv) return self.out(deconv)