|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch |
|
from huggingface_hub import PyTorchModelHubMixin |
|
|
|
class conv3d(nn.Module, PyTorchModelHubMixin):
    """Basic 3D block: convolution -> instance normalisation -> LeakyReLU.

    The convolution uses ``kernel_size=5`` with ``padding=2`` and the default
    stride of 1, so the spatial size of the input is preserved and only the
    number of channels changes.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Instantiate the sub-modules and assign them as member variables.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels produced by the convolution.
        """
        super().__init__()
        # Attribute names and creation order are part of the checkpoint
        # layout (state_dict keys), so they are kept exactly as they were.
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=5, padding=2)
        self.relu = nn.LeakyReLU(0.2)
        self.norm = nn.InstanceNorm3d(out_channels, affine=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution, then the norm, then the activation to ``x``."""
        return self.relu(self.norm(self.conv(x)))
|
|
|
|
|
class conv3d_x3(nn.Module, PyTorchModelHubMixin):
    """Three serial ``conv3d`` blocks with a projected residual connection.

    Structure::

        inputs --> conv_1 --> conv_2 --> conv_3 --> add --> outputs
           |                                         ^
           +----------- 1x1x1 conv (skip) -----------+

    The skip path maps the input to ``out_channels`` with a 1x1x1
    convolution, so both branches have the same shape when they are added.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Build the three conv blocks and the 1x1x1 skip projection.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
        """
        super().__init__()
        self.conv_1 = conv3d(in_channels, out_channels)
        self.conv_2 = conv3d(out_channels, out_channels)
        self.conv_3 = conv3d(out_channels, out_channels)
        self.skip_connection = nn.Conv3d(in_channels, out_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``conv_3(conv_2(conv_1(x))) + skip_connection(x)``."""
        main_branch = self.conv_3(self.conv_2(self.conv_1(x)))
        return main_branch + self.skip_connection(x)
|
|
|
class conv3d_x2(nn.Module, PyTorchModelHubMixin):
    """Two serial ``conv3d`` blocks with a projected residual connection.

    Structure::

        inputs --> conv_1 --> conv_2 --> add --> outputs
           |                              ^
           +------ 1x1x1 conv (skip) -----+

    The skip path maps the input to ``out_channels`` with a 1x1x1
    convolution, so both branches have the same shape when they are added.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Build the two conv blocks and the 1x1x1 skip projection.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
        """
        super().__init__()
        self.conv_1 = conv3d(in_channels, out_channels)
        self.conv_2 = conv3d(out_channels, out_channels)
        self.skip_connection = nn.Conv3d(in_channels, out_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``conv_2(conv_1(x)) + skip_connection(x)``."""
        main_branch = self.conv_2(self.conv_1(x))
        return main_branch + self.skip_connection(x)
|
|
|
|
|
class conv3d_x1(nn.Module, PyTorchModelHubMixin):
    """A single ``conv3d`` block with a projected residual connection.

    Structure::

        inputs --> conv_1 --> add --> outputs
           |                   ^
           +-- 1x1x1 conv -----+

    The skip path maps the input to ``out_channels`` with a 1x1x1
    convolution, so both branches have the same shape when they are added.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Build the conv block and the 1x1x1 skip projection.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
        """
        super().__init__()
        self.conv_1 = conv3d(in_channels, out_channels)
        self.skip_connection = nn.Conv3d(in_channels, out_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``conv_1(x) + skip_connection(x)``."""
        return self.conv_1(x) + self.skip_connection(x)
|
|
|
class deconv3d_x3(nn.Module, PyTorchModelHubMixin):
    """Decoder stage: upsample, merge with a skip feature map, three convs.

    ``rhs`` is upsampled to ``out_channels`` channels by a transposed
    convolution (kernel 2, stride 2, which doubles each spatial dimension).
    ``lhs`` must have ``out_channels // 2`` channels; a ``conv3d`` block lifts
    it to ``out_channels``. Both tensors are concatenated along the channel
    axis, refined by three 5x5x5 convolutions with LeakyReLU(0.1), and the
    upsampled ``rhs`` is added back as a residual.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Build the upsampling, skip and refinement sub-modules.

        Args:
            in_channels: Number of channels of ``rhs`` (the tensor to upsample).
            out_channels: Number of channels of the stage output.
        """
        super().__init__()
        self.up = deconv3d_as_up(in_channels, out_channels, 2, 2)
        self.lhs_conv = conv3d(out_channels // 2, out_channels)
        # The first conv consumes the concatenation of both branches, hence
        # 2 * out_channels input channels; padding=2 keeps the spatial size.
        self.conv_x3 = nn.Sequential(
            nn.Conv3d(2 * out_channels, out_channels, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(0.1),
            nn.Conv3d(out_channels, out_channels, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(0.1),
            nn.Conv3d(out_channels, out_channels, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(0.1),
        )

    def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
        """Fuse the skip tensor ``lhs`` with the upsampled ``rhs``.

        Args:
            lhs: Skip feature map with ``out_channels // 2`` channels and the
                same spatial size as the upsampled ``rhs``.
            rhs: Lower-resolution feature map with ``in_channels`` channels.

        Returns:
            Tensor with ``out_channels`` channels at the resolution of ``lhs``.
        """
        rhs_up = self.up(rhs)
        lhs_conv = self.lhs_conv(lhs)
        merged = torch.cat((rhs_up, lhs_conv), dim=1)
        return self.conv_x3(merged) + rhs_up
|
|
|
class deconv3d_x2(nn.Module, PyTorchModelHubMixin):
    """Decoder stage: upsample, merge with a skip feature map, two convs.

    ``rhs`` is upsampled to ``out_channels`` channels by a transposed
    convolution (kernel 2, stride 2, which doubles each spatial dimension).
    ``lhs`` must have ``out_channels // 2`` channels; a ``conv3d`` block lifts
    it to ``out_channels``. Both tensors are concatenated along the channel
    axis, refined by two 5x5x5 convolutions with LeakyReLU(0.1), and the
    upsampled ``rhs`` is added back as a residual.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Build the upsampling, skip and refinement sub-modules.

        Args:
            in_channels: Number of channels of ``rhs`` (the tensor to upsample).
            out_channels: Number of channels of the stage output.
        """
        super().__init__()
        self.up = deconv3d_as_up(in_channels, out_channels, 2, 2)
        self.lhs_conv = conv3d(out_channels // 2, out_channels)
        # The first conv consumes the concatenation of both branches, hence
        # 2 * out_channels input channels; padding=2 keeps the spatial size.
        self.conv_x2 = nn.Sequential(
            nn.Conv3d(2 * out_channels, out_channels, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(0.1),
            nn.Conv3d(out_channels, out_channels, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(0.1),
        )

    def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
        """Fuse the skip tensor ``lhs`` with the upsampled ``rhs``.

        Args:
            lhs: Skip feature map with ``out_channels // 2`` channels and the
                same spatial size as the upsampled ``rhs``.
            rhs: Lower-resolution feature map with ``in_channels`` channels.

        Returns:
            Tensor with ``out_channels`` channels at the resolution of ``lhs``.
        """
        rhs_up = self.up(rhs)
        lhs_conv = self.lhs_conv(lhs)
        merged = torch.cat((rhs_up, lhs_conv), dim=1)
        return self.conv_x2(merged) + rhs_up
|
|
|
class deconv3d_x1(nn.Module, PyTorchModelHubMixin):
    """Decoder stage: upsample, merge with a skip feature map, one conv.

    ``rhs`` is upsampled to ``out_channels`` channels by a transposed
    convolution (kernel 2, stride 2, which doubles each spatial dimension).
    ``lhs`` must have ``out_channels // 2`` channels; a ``conv3d`` block lifts
    it to ``out_channels``. Both tensors are concatenated along the channel
    axis, refined by one 5x5x5 convolution with LeakyReLU(0.2), and the
    upsampled ``rhs`` is added back as a residual.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Build the upsampling, skip and refinement sub-modules.

        Args:
            in_channels: Number of channels of ``rhs`` (the tensor to upsample).
            out_channels: Number of channels of the stage output.
        """
        super().__init__()
        self.up = deconv3d_as_up(in_channels, out_channels, 2, 2)
        self.lhs_conv = conv3d(out_channels // 2, out_channels)
        # NOTE: the slope here is 0.2 whereas deconv3d_x2 / deconv3d_x3 use
        # 0.1. It is kept as-is because changing it would alter the outputs
        # of already trained weights.
        self.conv_x1 = nn.Sequential(
            nn.Conv3d(2 * out_channels, out_channels, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(0.2),
        )

    def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
        """Fuse the skip tensor ``lhs`` with the upsampled ``rhs``.

        Args:
            lhs: Skip feature map with ``out_channels // 2`` channels and the
                same spatial size as the upsampled ``rhs``.
            rhs: Lower-resolution feature map with ``in_channels`` channels.

        Returns:
            Tensor with ``out_channels`` channels at the resolution of ``lhs``.
        """
        rhs_up = self.up(rhs)
        lhs_conv = self.lhs_conv(lhs)
        merged = torch.cat((rhs_up, lhs_conv), dim=1)
        return self.conv_x1(merged) + rhs_up
|
|
|
|
|
def conv3d_as_pool(in_channels, out_channels, kernel_size=2, stride=2) -> nn.Sequential:
    """Build a learnable downsampling block: strided Conv3d + LeakyReLU(0.2).

    The convolution is unpadded, so with the default ``kernel_size=2`` and
    ``stride=2`` every spatial dimension is halved (rounded down).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size forwarded to ``nn.Conv3d``.
        stride: Stride forwarded to ``nn.Conv3d``.

    Returns:
        An ``nn.Sequential`` holding the convolution and the activation.
    """
    return nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding=0),
        nn.LeakyReLU(0.2),
    )
|
|
|
|
|
def deconv3d_as_up(in_channels, out_channels, kernel_size=2, stride=2) -> nn.Sequential:
    """Build a learnable upsampling block: ConvTranspose3d + PReLU.

    With the default ``kernel_size=2`` and ``stride=2`` every spatial
    dimension of the input is doubled.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the transposed conv.
        kernel_size: Kernel size forwarded to ``nn.ConvTranspose3d``.
        stride: Stride forwarded to ``nn.ConvTranspose3d``.

    Returns:
        An ``nn.Sequential`` holding the transposed conv and the activation.
    """
    return nn.Sequential(
        nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride),
        nn.PReLU(),
    )
|
|
|
|
|
class softmax_out(nn.Module, PyTorchModelHubMixin):
    """Output head: a 5x5x5 convolution followed by a 1x1x1 convolution.

    Despite the class name, no softmax (or any other activation) is applied
    here; ``forward`` returns the raw output of the second convolution.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Build the two output convolutions.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the returned tensor.
        """
        super().__init__()
        self.conv_1 = nn.Conv3d(in_channels, out_channels, kernel_size=5, padding=2)
        self.conv_2 = nn.Conv3d(out_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both convolutions to ``x``.

        For a batched input the output has shape
        [batch_size, out_channels, depth, height, width]; the spatial size is
        unchanged because the 5x5x5 conv is padded by 2 and the 1x1x1 conv
        needs no padding.
        """
        return self.conv_2(self.conv_1(x))
|
|
|
|
|
class VNet(nn.Module, PyTorchModelHubMixin):
    """V-Net style 3D encoder-decoder built from the blocks in this module.

    Encoder: four residual conv stages with 16, 32, 64 and 128 channels, each
    followed by a strided-convolution downsampling that doubles the channel
    count, ending in a 256-channel bottom stage. Decoder: four stages that
    upsample and merge with the matching encoder output, followed by the
    ``softmax_out`` head (which returns raw convolution outputs).

    Each downsampling halves every spatial dimension and each decoder stage
    doubles it, so every spatial dimension of the input should be divisible
    by 16; otherwise the channel concatenation in the decoder fails with a
    shape mismatch.
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 1) -> None:
        """Build all encoder, decoder and output sub-modules.

        Args:
            in_channels: Number of channels of the input volume. Defaults to 1,
                the value that was previously hard-coded.
            out_channels: Number of channels of the network output. Defaults
                to 1, the value that was previously hard-coded.
        """
        super().__init__()
        # Encoder path.
        self.conv_1 = conv3d_x1(in_channels, 16)
        self.pool_1 = conv3d_as_pool(16, 32)
        self.conv_2 = conv3d_x2(32, 32)
        self.pool_2 = conv3d_as_pool(32, 64)
        self.conv_3 = conv3d_x3(64, 64)
        self.pool_3 = conv3d_as_pool(64, 128)
        self.conv_4 = conv3d_x3(128, 128)
        self.pool_4 = conv3d_as_pool(128, 256)

        self.bottom = conv3d_x3(256, 256)

        # Decoder path; each stage expects a skip tensor with half as many
        # channels as the stage output (128, 64, 32 and 16 respectively).
        self.deconv_4 = deconv3d_x3(256, 256)
        self.deconv_3 = deconv3d_x3(256, 128)
        self.deconv_2 = deconv3d_x2(128, 64)
        self.deconv_1 = deconv3d_x1(64, 32)

        self.out = softmax_out(32, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder, the bottom stage, the decoder and the output head.

        Args:
            x: Input volume with ``in_channels`` channels.

        Returns:
            Tensor with ``out_channels`` channels and the spatial size of ``x``.
        """
        # Encoder: keep every stage output for the decoder skip connections.
        enc_1 = self.conv_1(x)
        enc_2 = self.conv_2(self.pool_1(enc_1))
        enc_3 = self.conv_3(self.pool_2(enc_2))
        enc_4 = self.conv_4(self.pool_3(enc_3))
        bottom = self.bottom(self.pool_4(enc_4))

        # Decoder: walk back up, merging with the encoder output of each level.
        dec = self.deconv_4(enc_4, bottom)
        dec = self.deconv_3(enc_3, dec)
        dec = self.deconv_2(enc_2, dec)
        dec = self.deconv_1(enc_1, dec)
        return self.out(dec)
|
|