# Source: leaky_vnet/model/leaky_vnet.py
# Commit d26715f ("init commit"), uploaded by doggywastaken.
import torch.nn as nn
import torch.nn.functional as F
import torch
from huggingface_hub import PyTorchModelHubMixin
class conv3d(nn.Module, PyTorchModelHubMixin):
    """Basic 3-D unit: 5x5x5 convolution, instance normalisation, LeakyReLU.

    The convolution uses ``kernel_size=5`` with ``padding=2`` and the default
    stride, so the spatial size of the input is preserved and only the number
    of channels changes.
    """

    def __init__(self, in_channels, out_channels):
        """Create the layers and register them as sub-modules.

        Args:
            in_channels: Number of channels of the incoming tensor.
            out_channels: Number of channels produced by the convolution.
        """
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=5,
            padding=2,
        )
        # Note: the activation is registered before the norm, but ``forward``
        # applies the norm first.
        self.relu = nn.LeakyReLU(negative_slope=0.2)
        # affine=True gives the normalisation a learnable scale and shift.
        self.norm = nn.InstanceNorm3d(out_channels, affine=True)

    def forward(self, x):
        """Apply convolution, then normalisation, then the activation."""
        convolved = self.conv(x)
        normalised = self.norm(convolved)
        return self.relu(normalised)
class conv3d_x3(nn.Module, PyTorchModelHubMixin):
    """Residual block built from three consecutive ``conv3d`` units.

    Structure::

        x --> conv_1 --> conv_2 --> conv_3 --> (+) --> output
        |                                       ^
        +---------- skip_connection ------------+

    The shortcut is a 1x1x1 convolution that maps ``in_channels`` to
    ``out_channels``; it is always applied, even when both counts are equal.
    """

    def __init__(self, in_channels, out_channels):
        """Create the three conv units and the 1x1x1 shortcut projection.

        Args:
            in_channels: Number of channels of the incoming tensor.
            out_channels: Number of channels of the block output.
        """
        super().__init__()
        self.conv_1 = conv3d(in_channels, out_channels)
        self.conv_2 = conv3d(out_channels, out_channels)
        self.conv_3 = conv3d(out_channels, out_channels)
        self.skip_connection = nn.Conv3d(
            in_channels, out_channels, kernel_size=1
        )

    def forward(self, x):
        """Return the output of the conv chain plus the projected input."""
        features = x
        for unit in (self.conv_1, self.conv_2, self.conv_3):
            features = unit(features)
        shortcut = self.skip_connection(x)
        return features + shortcut
class conv3d_x2(nn.Module, PyTorchModelHubMixin):
    """Residual block built from two consecutive ``conv3d`` units.

    Structure::

        x --> conv_1 --> conv_2 --> (+) --> output
        |                            ^
        +------ skip_connection -----+

    The shortcut is a 1x1x1 convolution that maps ``in_channels`` to
    ``out_channels``; it is always applied, even when both counts are equal.
    """

    def __init__(self, in_channels, out_channels):
        """Create the two conv units and the 1x1x1 shortcut projection.

        Args:
            in_channels: Number of channels of the incoming tensor.
            out_channels: Number of channels of the block output.
        """
        super().__init__()
        self.conv_1 = conv3d(in_channels, out_channels)
        self.conv_2 = conv3d(out_channels, out_channels)
        self.skip_connection = nn.Conv3d(
            in_channels, out_channels, kernel_size=1
        )

    def forward(self, x):
        """Return the output of the conv chain plus the projected input."""
        features = x
        for unit in (self.conv_1, self.conv_2):
            features = unit(features)
        shortcut = self.skip_connection(x)
        return features + shortcut
class conv3d_x1(nn.Module, PyTorchModelHubMixin):
    """Residual block built from a single ``conv3d`` unit.

    Structure::

        x --> conv_1 --> (+) --> output
        |                 ^
        +- skip_connection+

    The shortcut is a 1x1x1 convolution that maps ``in_channels`` to
    ``out_channels``; it is always applied, even when both counts are equal.
    """

    def __init__(self, in_channels, out_channels):
        """Create the conv unit and the 1x1x1 shortcut projection.

        Args:
            in_channels: Number of channels of the incoming tensor.
            out_channels: Number of channels of the block output.
        """
        super().__init__()
        self.conv_1 = conv3d(in_channels, out_channels)
        self.skip_connection = nn.Conv3d(
            in_channels, out_channels, kernel_size=1
        )

    def forward(self, x):
        """Return the conv unit output plus the projected input."""
        features = self.conv_1(x)
        shortcut = self.skip_connection(x)
        return features + shortcut
class deconv3d_x3(nn.Module, PyTorchModelHubMixin):
    """Decoder stage: upsample, merge with a skip feature map, refine with 3 convs.

    ``rhs`` (the deeper feature map) is upsampled by ``up`` from
    ``in_channels`` to ``out_channels`` channels.  ``lhs`` (the skip feature
    map) must carry ``out_channels // 2`` channels, which ``lhs_conv`` lifts
    to ``out_channels``.  Both are concatenated along the channel axis and
    refined by three 5x5x5 convolutions, each followed by LeakyReLU(0.1).
    The upsampled tensor is added back to the refined result.
    """

    def __init__(self, in_channels, out_channels):
        """Create the upsampler, the skip-path conv and the refinement stack.

        Args:
            in_channels: Number of channels of ``rhs``.
            out_channels: Number of channels of the stage output.
        """
        super().__init__()
        self.up = deconv3d_as_up(
            in_channels, out_channels, kernel_size=2, stride=2
        )
        self.lhs_conv = conv3d(out_channels // 2, out_channels)

        # The first conv sees the concatenation (2 * out_channels); every
        # later conv maps out_channels -> out_channels.
        stack = []
        width = 2 * out_channels
        for _ in range(3):
            stack.append(
                nn.Conv3d(
                    width, out_channels, kernel_size=5, stride=1, padding=2
                )
            )
            stack.append(nn.LeakyReLU(negative_slope=0.1))
            width = out_channels
        self.conv_x3 = nn.Sequential(*stack)

    def forward(self, lhs, rhs):
        """Fuse skip features ``lhs`` with deeper features ``rhs``.

        Args:
            lhs: Skip feature map with ``out_channels // 2`` channels.
            rhs: Deeper feature map with ``in_channels`` channels.

        Returns:
            The refined concatenation plus the upsampled ``rhs``.
        """
        upsampled = self.up(rhs)
        skip = self.lhs_conv(lhs)
        merged = torch.cat((upsampled, skip), dim=1)
        refined = self.conv_x3(merged)
        return refined + upsampled
class deconv3d_x2(nn.Module, PyTorchModelHubMixin):
    """Decoder stage: upsample, merge with a skip feature map, refine with 2 convs.

    ``rhs`` (the deeper feature map) is upsampled by ``up`` from
    ``in_channels`` to ``out_channels`` channels.  ``lhs`` (the skip feature
    map) must carry ``out_channels // 2`` channels, which ``lhs_conv`` lifts
    to ``out_channels``.  Both are concatenated along the channel axis and
    refined by two 5x5x5 convolutions, each followed by LeakyReLU(0.1).
    The upsampled tensor is added back to the refined result.
    """

    def __init__(self, in_channels, out_channels):
        """Create the upsampler, the skip-path conv and the refinement stack.

        Args:
            in_channels: Number of channels of ``rhs``.
            out_channels: Number of channels of the stage output.
        """
        super().__init__()
        self.up = deconv3d_as_up(
            in_channels, out_channels, kernel_size=2, stride=2
        )
        self.lhs_conv = conv3d(out_channels // 2, out_channels)

        # The first conv sees the concatenation (2 * out_channels); the
        # second maps out_channels -> out_channels.
        stack = []
        width = 2 * out_channels
        for _ in range(2):
            stack.append(
                nn.Conv3d(
                    width, out_channels, kernel_size=5, stride=1, padding=2
                )
            )
            stack.append(nn.LeakyReLU(negative_slope=0.1))
            width = out_channels
        self.conv_x2 = nn.Sequential(*stack)

    def forward(self, lhs, rhs):
        """Fuse skip features ``lhs`` with deeper features ``rhs``.

        Args:
            lhs: Skip feature map with ``out_channels // 2`` channels.
            rhs: Deeper feature map with ``in_channels`` channels.

        Returns:
            The refined concatenation plus the upsampled ``rhs``.
        """
        upsampled = self.up(rhs)
        skip = self.lhs_conv(lhs)
        merged = torch.cat((upsampled, skip), dim=1)
        refined = self.conv_x2(merged)
        return refined + upsampled
class deconv3d_x1(nn.Module, PyTorchModelHubMixin):
    """Decoder stage: upsample, merge with a skip feature map, refine with 1 conv.

    ``rhs`` (the deeper feature map) is upsampled by ``up`` from
    ``in_channels`` to ``out_channels`` channels.  ``lhs`` (the skip feature
    map) must carry ``out_channels // 2`` channels, which ``lhs_conv`` lifts
    to ``out_channels``.  Both are concatenated along the channel axis and
    refined by one 5x5x5 convolution followed by LeakyReLU(0.2).
    The upsampled tensor is added back to the refined result.
    """

    def __init__(self, in_channels, out_channels):
        """Create the upsampler, the skip-path conv and the refinement conv.

        Args:
            in_channels: Number of channels of ``rhs``.
            out_channels: Number of channels of the stage output.
        """
        super().__init__()
        self.up = deconv3d_as_up(
            in_channels, out_channels, kernel_size=2, stride=2
        )
        self.lhs_conv = conv3d(out_channels // 2, out_channels)
        # Input width is 2 * out_channels because of the concatenation.
        refine = nn.Conv3d(
            2 * out_channels, out_channels, kernel_size=5, stride=1, padding=2
        )
        self.conv_x1 = nn.Sequential(refine, nn.LeakyReLU(negative_slope=0.2))

    def forward(self, lhs, rhs):
        """Fuse skip features ``lhs`` with deeper features ``rhs``.

        Args:
            lhs: Skip feature map with ``out_channels // 2`` channels.
            rhs: Deeper feature map with ``in_channels`` channels.

        Returns:
            The refined concatenation plus the upsampled ``rhs``.
        """
        upsampled = self.up(rhs)
        skip = self.lhs_conv(lhs)
        merged = torch.cat((upsampled, skip), dim=1)
        refined = self.conv_x1(merged)
        return refined + upsampled
def conv3d_as_pool(in_channels, out_channels, kernel_size=2, stride=2):
    """Build a strided-convolution downsampling block.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size (default 2).
        stride: Convolution stride (default 2).

    Returns:
        An ``nn.Sequential`` of an unpadded ``nn.Conv3d`` followed by
        ``nn.LeakyReLU(0.2)``.
    """
    downsample = nn.Conv3d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
    )
    activation = nn.LeakyReLU(negative_slope=0.2)
    return nn.Sequential(downsample, activation)
def deconv3d_as_up(in_channels, out_channels, kernel_size=2, stride=2):
    """Build a transposed-convolution upsampling block.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by the transposed conv.
        kernel_size: Transposed-convolution kernel size (default 2).
        stride: Transposed-convolution stride (default 2).

    Returns:
        An ``nn.Sequential`` of ``nn.ConvTranspose3d`` followed by a
        default-constructed ``nn.PReLU``.
    """
    upsample = nn.ConvTranspose3d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
    )
    activation = nn.PReLU()
    return nn.Sequential(upsample, activation)
class softmax_out(nn.Module, PyTorchModelHubMixin):
    """Output head: a 5x5x5 convolution followed by a 1x1x1 convolution.

    Despite the class name, no softmax (nor any other activation or
    normalisation) is applied here; ``forward`` returns the raw output of
    the second convolution.
    """

    def __init__(self, in_channels, out_channels):
        """Create the two output convolutions.

        Args:
            in_channels: Number of channels of the incoming tensor.
            out_channels: Number of channels of the head output.
        """
        super().__init__()
        self.conv_1 = nn.Conv3d(
            in_channels, out_channels, kernel_size=5, padding=2
        )
        self.conv_2 = nn.Conv3d(
            out_channels, out_channels, kernel_size=1, padding=0
        )

    def forward(self, x):
        """Return the un-normalised head output.

        For a batched input the result has shape
        [batch_size, out_channels, depth, height, width].
        """
        # Do NOT add a normalisation layer here, or the values vanish.
        hidden = self.conv_1(x)
        return self.conv_2(hidden)
class VNet(nn.Module, PyTorchModelHubMixin):
    """V-Net style 3-D encoder/decoder with residual blocks and skip links.

    The encoder has four stride-2 downsampling steps (16 -> 32 -> 64 -> 128
    -> 256 channels); the decoder mirrors them with four stride-2 transposed
    convolutions and concatenates the matching encoder feature map at every
    level.  Because each downsampling halves the spatial size (rounding
    down) and each upsampling doubles it, every spatial dimension of the
    input must be divisible by 16, otherwise the skip concatenation fails
    with a size mismatch.

    The network returns the raw output of the final convolution; no softmax
    or sigmoid is applied.
    """

    def __init__(self, in_channels=1, out_channels=1):
        """Build the encoder, bottleneck, decoder and output head.

        Args:
            in_channels: Number of channels of the input volume.  Defaults
                to 1, the value that used to be hard-coded.
            out_channels: Number of channels of the output volume.  Defaults
                to 1, the value that used to be hard-coded.
        """
        super().__init__()
        # Encoder: residual block, then strided-conv downsampling that also
        # doubles the channel count.
        self.conv_1 = conv3d_x1(in_channels, 16)
        self.pool_1 = conv3d_as_pool(16, 32)
        self.conv_2 = conv3d_x2(32, 32)
        self.pool_2 = conv3d_as_pool(32, 64)
        self.conv_3 = conv3d_x3(64, 64)
        self.pool_3 = conv3d_as_pool(64, 128)
        self.conv_4 = conv3d_x3(128, 128)
        self.pool_4 = conv3d_as_pool(128, 256)

        # Bottleneck at 1/16 of the input resolution.
        self.bottom = conv3d_x3(256, 256)

        # Decoder: each stage expects a skip tensor with half as many
        # channels as the stage's output (128, 64, 32 and 16 respectively).
        self.deconv_4 = deconv3d_x3(256, 256)
        self.deconv_3 = deconv3d_x3(256, 128)
        self.deconv_2 = deconv3d_x2(128, 64)
        self.deconv_1 = deconv3d_x1(64, 32)

        self.out = softmax_out(32, out_channels)

    def forward(self, x):
        """Run the full encoder/decoder pass.

        Args:
            x: Input tensor of shape
                [batch_size, in_channels, depth, height, width] with depth,
                height and width each divisible by 16.

        Returns:
            Tensor of shape [batch_size, out_channels, depth, height, width]
            holding the un-normalised output of the head.
        """
        # Encoder path; the conv_* outputs are kept for the skip links.
        conv_1 = self.conv_1(x)
        conv_2 = self.conv_2(self.pool_1(conv_1))
        conv_3 = self.conv_3(self.pool_2(conv_2))
        conv_4 = self.conv_4(self.pool_3(conv_3))
        bottom = self.bottom(self.pool_4(conv_4))

        # Decoder path, deepest level first.
        deconv = self.deconv_4(conv_4, bottom)
        deconv = self.deconv_3(conv_3, deconv)
        deconv = self.deconv_2(conv_2, deconv)
        deconv = self.deconv_1(conv_1, deconv)
        return self.out(deconv)