|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import segmentation_models_pytorch as smp |
|
|
|
|
|
class SegformerBranch(nn.Module):
    """Thin ``nn.Module`` wrapper around ``smp.Segformer``.

    The wrapped model is created with ``encoder_name="mobilenet_v2"`` and
    ``encoder_weights=None``.

    Args:
        in_channels: Passed through to ``smp.Segformer`` as ``in_channels``.
        classes: Passed through to ``smp.Segformer`` as ``classes``.
    """

    def __init__(self, in_channels=4, classes=4):
        super().__init__()
        # Collect the constructor arguments first so the model configuration
        # can be read in one place.
        model_kwargs = dict(
            encoder_name="mobilenet_v2",
            encoder_weights=None,
            in_channels=in_channels,
            classes=classes,
        )
        self.segformer = smp.Segformer(**model_kwargs)

    def forward(self, x):
        """Return the wrapped model's output for ``x`` without post-processing."""
        prediction = self.segformer(x)
        return prediction
|
|
|
class UNetBranch(nn.Module):
    """Thin ``nn.Module`` wrapper around ``smp.Unet``.

    The wrapped model is created with ``encoder_name="mobilenet_v2"`` and
    ``encoder_weights=None``.

    Args:
        in_channels: Passed through to ``smp.Unet`` as ``in_channels``.
        classes: Passed through to ``smp.Unet`` as ``classes``.
        benchmark: When truthy, ``forward`` applies ``torch.sigmoid`` to the
            model output; otherwise the raw output is returned.
    """

    def __init__(self, in_channels=4, classes=4, benchmark=False):
        super().__init__()
        unet_kwargs = {
            "encoder_name": "mobilenet_v2",
            "encoder_weights": None,
            "in_channels": in_channels,
            "classes": classes,
        }
        self.unet = smp.Unet(**unet_kwargs)
        self.benchmark = benchmark

    def forward(self, x):
        """Run the U-Net on ``x``; sigmoid the output if ``benchmark`` is set."""
        raw = self.unet(x)
        return torch.sigmoid(raw) if self.benchmark else raw
|
|
|
class UNetPlusPlusBranch(nn.Module):
    """Thin ``nn.Module`` wrapper around ``smp.UnetPlusPlus``.

    The wrapped model is created with ``encoder_name="mobilenet_v2"`` and
    ``encoder_weights=None``.

    Args:
        in_channels: Passed through to ``smp.UnetPlusPlus`` as ``in_channels``.
        classes: Passed through to ``smp.UnetPlusPlus`` as ``classes``.
        benchmark: When truthy, ``forward`` applies ``torch.sigmoid`` to the
            model output; otherwise the raw output is returned.
    """

    def __init__(self, in_channels=4, classes=4, benchmark=False):
        super().__init__()
        self.unet_pp = smp.UnetPlusPlus(
            classes=classes,
            in_channels=in_channels,
            encoder_weights=None,
            encoder_name="mobilenet_v2",
        )
        self.benchmark = benchmark

    def forward(self, x):
        """Run U-Net++ on ``x``; sigmoid the output if ``benchmark`` is set."""
        raw = self.unet_pp(x)
        if not self.benchmark:
            # Default path: hand back the model output untouched.
            return raw
        return torch.sigmoid(raw)
|
|
|
class DeepLabV3Branch(nn.Module):
    """Thin ``nn.Module`` wrapper around ``smp.DeepLabV3``.

    The wrapped model is created with ``encoder_name="mobilenet_v2"`` and
    ``encoder_weights=None``.

    Args:
        in_channels: Passed through to ``smp.DeepLabV3`` as ``in_channels``.
        classes: Passed through to ``smp.DeepLabV3`` as ``classes``.
    """

    def __init__(self, in_channels=4, classes=4):
        super().__init__()
        deeplab_kwargs = {
            "encoder_name": "mobilenet_v2",
            "encoder_weights": None,
            "in_channels": in_channels,
            "classes": classes,
        }
        self.deeplabv3 = smp.DeepLabV3(**deeplab_kwargs)

    def forward(self, x):
        """Return the wrapped model's output for ``x`` without post-processing."""
        prediction = self.deeplabv3(x)
        return prediction
|
|
|
class PixelWiseNet(nn.Module):
    """Stack of three bias-free 1x1 convolutions applied per pixel.

    Layout: ``conv1 -> bn1 -> ReLU -> conv2 -> bn2 -> ReLU -> conv3``.
    Every convolution uses ``kernel_size=1``, so each output pixel depends
    only on the input pixel at the same location and the spatial size is
    left unchanged.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count of the returned tensor.
        base_channels: Width of the two hidden layers.
    """

    def __init__(self, in_channels=4, out_channels=4, base_channels=32):
        super().__init__()

        def _pointwise(c_in, c_out):
            # All three convolutions share the same 1x1, bias-free config.
            return nn.Conv2d(c_in, c_out, kernel_size=1, bias=False)

        # Layers are created in the same order as they are applied.
        self.conv1 = _pointwise(in_channels, base_channels)
        self.bn1 = nn.BatchNorm2d(base_channels)
        self.conv2 = _pointwise(base_channels, base_channels)
        self.bn2 = nn.BatchNorm2d(base_channels)
        self.conv3 = _pointwise(base_channels, out_channels)

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to (N, out_channels, H, W)."""
        hidden = x
        hidden_stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in hidden_stages:
            hidden = F.relu(norm(conv(hidden)))
        # The final projection has no normalisation or activation.
        return self.conv3(hidden)
|
|
|
class CombinedNet(nn.Module):
    """Sum a Segformer branch and a pixel-wise branch, then mix with a 1x1 conv.

    ``forward`` adds the outputs of ``seg_branch`` (a ``SegformerBranch``)
    and ``pixel_branch`` (a ``PixelWiseNet``) element-wise, so both branches
    must yield tensors that can be added together. The sum is passed
    through ``fusion_conv``, a bias-free 1x1 convolution from ``classes``
    to ``classes`` channels.

    Args:
        in_channels: Input channel count given to both branches.
        classes: Output channel count of both branches and of the fusion conv.
        base_channels: Hidden width given to the pixel-wise branch.
        benchmark: When truthy, ``forward`` applies ``torch.sigmoid`` to the
            fused output; otherwise the raw output is returned.
    """

    def __init__(self, in_channels=4, classes=4, base_channels=32, benchmark=False):
        super().__init__()
        self.seg_branch = SegformerBranch(in_channels=in_channels, classes=classes)
        self.pixel_branch = PixelWiseNet(
            in_channels=in_channels,
            out_channels=classes,
            base_channels=base_channels,
        )
        self.fusion_conv = nn.Conv2d(classes, classes, kernel_size=1, bias=False)
        self.benchmark = benchmark

    def forward(self, x):
        """Return the fused prediction for ``x`` (sigmoided if ``benchmark``)."""
        # The segmentation branch is evaluated before the pixel-wise branch.
        combined = self.seg_branch(x) + self.pixel_branch(x)
        logits = self.fusion_conv(combined)
        return torch.sigmoid(logits) if self.benchmark else logits
|
|
|
class CombinedNet3(nn.Module):
    """Sum a U-Net++ branch and a pixel-wise branch, then mix with a 1x1 conv.

    ``forward`` adds the outputs of ``seg_branch`` (a ``UNetPlusPlusBranch``)
    and ``pixel_branch`` (a ``PixelWiseNet``) element-wise, so both branches
    must yield tensors that can be added together. The sum is passed
    through ``fusion_conv``, a bias-free 1x1 convolution from ``classes``
    to ``classes`` channels.

    The U-Net++ branch is constructed without a ``benchmark`` argument, so
    it keeps its own default for that flag; the ``benchmark`` flag here only
    affects the fused output.

    Args:
        in_channels: Input channel count given to both branches.
        classes: Output channel count of both branches and of the fusion conv.
        base_channels: Hidden width given to the pixel-wise branch.
        benchmark: When truthy, ``forward`` applies ``torch.sigmoid`` to the
            fused output; otherwise the raw output is returned.
    """

    def __init__(self, in_channels=4, classes=4, base_channels=32, benchmark=False):
        super().__init__()
        self.seg_branch = UNetPlusPlusBranch(in_channels=in_channels, classes=classes)
        pixel_kwargs = dict(
            in_channels=in_channels,
            out_channels=classes,
            base_channels=base_channels,
        )
        self.pixel_branch = PixelWiseNet(**pixel_kwargs)
        self.fusion_conv = nn.Conv2d(classes, classes, kernel_size=1, bias=False)
        self.benchmark = benchmark

    def forward(self, x):
        """Return the fused prediction for ``x`` (sigmoided if ``benchmark``)."""
        seg_pred = self.seg_branch(x)
        pix_pred = self.pixel_branch(x)
        mixed = self.fusion_conv(seg_pred + pix_pred)
        if not self.benchmark:
            return mixed
        return torch.sigmoid(mixed)
|
|
|
class CombinedNet4(nn.Module):
    """Sum a DeepLabV3 branch and a pixel-wise branch, then mix with a 1x1 conv.

    ``forward`` adds the outputs of ``seg_branch`` (a ``DeepLabV3Branch``)
    and ``pixel_branch`` (a ``PixelWiseNet``) element-wise, so both branches
    must yield tensors that can be added together. The sum is passed
    through ``fusion_conv``, a bias-free 1x1 convolution from ``classes``
    to ``classes`` channels.

    Args:
        in_channels: Input channel count given to both branches.
        classes: Output channel count of both branches and of the fusion conv.
        base_channels: Hidden width given to the pixel-wise branch.
        benchmark: When truthy, ``forward`` applies ``torch.sigmoid`` to the
            fused output; otherwise the raw output is returned.
    """

    def __init__(self, in_channels=4, classes=4, base_channels=32, benchmark=False):
        super().__init__()
        self.seg_branch = DeepLabV3Branch(in_channels=in_channels, classes=classes)
        self.pixel_branch = PixelWiseNet(
            base_channels=base_channels,
            out_channels=classes,
            in_channels=in_channels,
        )
        self.fusion_conv = nn.Conv2d(classes, classes, kernel_size=1, bias=False)
        self.benchmark = benchmark

    def forward(self, x):
        """Return the fused prediction for ``x`` (sigmoided if ``benchmark``)."""
        # Segmentation branch first, then the pixel-wise branch.
        branch_sum = self.seg_branch(x) + self.pixel_branch(x)
        result = self.fusion_conv(branch_sum)
        if self.benchmark:
            result = torch.sigmoid(result)
        return result
|
|