# Upload metadata: author csaybar, commit 039daa1 ("Upload 9 files", verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp
class SegformerBranch(nn.Module):
    """Segformer model (``smp.Segformer``) with a MobileNetV2 encoder.

    The encoder is built without pretrained weights (``encoder_weights=None``).
    The ``benchmark`` flag matches the one on ``UNetBranch`` and
    ``UNetPlusPlusBranch``. When it is true, ``forward`` returns the
    element-wise sigmoid of the model output instead of the raw output.

    Args:
        in_channels: Number of input channels forwarded to ``smp.Segformer``.
        classes: Number of output channels forwarded to ``smp.Segformer``.
        benchmark: If True, apply ``torch.sigmoid`` in ``forward``. Defaults
            to False, which keeps the previous raw-output behaviour.
    """

    def __init__(self, in_channels=4, classes=4, benchmark=False):
        super().__init__()
        # The attribute name ``segformer`` prefixes the state_dict keys; keep it
        # stable so existing checkpoints still load.
        self.segformer = smp.Segformer(
            encoder_name="mobilenet_v2",
            encoder_weights=None,
            in_channels=in_channels,
            classes=classes,
        )
        self.benchmark = benchmark

    def forward(self, x):
        """Run the Segformer and optionally squash its output with a sigmoid.

        Args:
            x: Input batch; presumably (N, in_channels, H, W) as smp models
                expect — confirm against callers.

        Returns:
            The raw model output, or its element-wise sigmoid when
            ``benchmark`` is True.
        """
        results = self.segformer(x)
        # getattr: instances unpickled from the earlier version of this class
        # have no ``benchmark`` attribute; treat them as benchmark=False.
        if getattr(self, "benchmark", False):
            results = torch.sigmoid(results)
        return results
class UNetBranch(nn.Module):
    """U-Net (``smp.Unet``) with a MobileNetV2 encoder and no pretrained weights.

    Args:
        in_channels: Input channel count forwarded to ``smp.Unet``.
        classes: Output channel count forwarded to ``smp.Unet``.
        benchmark: When True, ``forward`` returns the element-wise sigmoid of
            the network output rather than the raw output.
    """

    def __init__(self, in_channels=4, classes=4, benchmark=False):
        super().__init__()
        unet_kwargs = dict(
            encoder_name="mobilenet_v2",
            encoder_weights=None,
            in_channels=in_channels,
            classes=classes,
        )
        # ``unet`` is the state_dict prefix for every U-Net parameter.
        self.unet = smp.Unet(**unet_kwargs)
        self.benchmark = benchmark

    def forward(self, x):
        """Return the U-Net output, passed through a sigmoid if ``benchmark``."""
        raw = self.unet(x)
        return torch.sigmoid(raw) if self.benchmark else raw
class UNetPlusPlusBranch(nn.Module):
    """UNet++ (``smp.UnetPlusPlus``) with an untrained MobileNetV2 encoder.

    Args:
        in_channels: Input channel count forwarded to ``smp.UnetPlusPlus``.
        classes: Output channel count forwarded to ``smp.UnetPlusPlus``.
        benchmark: If True, ``forward`` applies an element-wise sigmoid to
            the network output before returning it.
    """

    def __init__(self, in_channels=4, classes=4, benchmark=False):
        super().__init__()
        # The attribute name ``unet_pp`` is part of the state_dict key layout.
        self.unet_pp = smp.UnetPlusPlus(
            in_channels=in_channels,
            classes=classes,
            encoder_name="mobilenet_v2",
            encoder_weights=None,
        )
        self.benchmark = benchmark

    def forward(self, x):
        """Run UNet++; return raw output, or its sigmoid in benchmark mode."""
        prediction = self.unet_pp(x)
        if not self.benchmark:
            return prediction
        return torch.sigmoid(prediction)
class DeepLabV3Branch(nn.Module):
    """DeepLabV3 model (``smp.DeepLabV3``) with a MobileNetV2 encoder.

    The encoder is built without pretrained weights (``encoder_weights=None``).
    The ``benchmark`` flag matches the one on ``UNetBranch`` and
    ``UNetPlusPlusBranch``. When it is true, ``forward`` returns the
    element-wise sigmoid of the model output instead of the raw output.

    Args:
        in_channels: Number of input channels forwarded to ``smp.DeepLabV3``.
        classes: Number of output channels forwarded to ``smp.DeepLabV3``.
        benchmark: If True, apply ``torch.sigmoid`` in ``forward``. Defaults
            to False, which keeps the previous raw-output behaviour.
    """

    def __init__(self, in_channels=4, classes=4, benchmark=False):
        super().__init__()
        # The attribute name ``deeplabv3`` prefixes the state_dict keys; keep it
        # stable so existing checkpoints still load.
        self.deeplabv3 = smp.DeepLabV3(
            encoder_name="mobilenet_v2",
            encoder_weights=None,
            in_channels=in_channels,
            classes=classes,
        )
        self.benchmark = benchmark

    def forward(self, x):
        """Run DeepLabV3 and optionally squash its output with a sigmoid.

        Args:
            x: Input batch; presumably (N, in_channels, H, W) as smp models
                expect — confirm against callers.

        Returns:
            The raw model output, or its element-wise sigmoid when
            ``benchmark`` is True.
        """
        results = self.deeplabv3(x)
        # getattr: instances unpickled from the earlier version of this class
        # have no ``benchmark`` attribute; treat them as benchmark=False.
        if getattr(self, "benchmark", False):
            results = torch.sigmoid(results)
        return results
class PixelWiseNet(nn.Module):
    """Small network built only from 1x1 convolutions.

    Layout: conv1x1 -> BatchNorm -> ReLU -> conv1x1 -> BatchNorm -> ReLU
    -> conv1x1. Because every kernel is 1x1, no spatial neighbourhood is
    mixed. In training mode, BatchNorm still couples pixels through batch
    statistics. All convolutions are bias-free, and the final layer has no
    normalisation or activation.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the final 1x1 convolution.
        base_channels: Width of the two hidden layers.
    """

    def __init__(self, in_channels=4, out_channels=4, base_channels=32):
        super().__init__()
        hidden = base_channels
        # Attribute names form the state_dict keys, and creation order fixes the
        # RNG draws used for weight init. Keep both stable.
        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden)
        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden)
        self.conv3 = nn.Conv2d(hidden, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Apply two conv-BN-ReLU stages, then the bias-free output projection."""
        out = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = F.relu(norm(conv(out)))
        return self.conv3(out)
class CombinedNet(nn.Module):
    """Segformer branch plus per-pixel branch, fused by a 1x1 convolution.

    Both branches receive the same input. Their outputs are summed
    element-wise and passed through a bias-free 1x1 ``fusion_conv`` that maps
    ``classes`` channels to ``classes`` channels.

    Args:
        in_channels: Channels of the input, forwarded to both branches.
        classes: Output channels of both branches and of ``fusion_conv``.
        base_channels: Hidden width of the ``PixelWiseNet`` branch.
        benchmark: When True, ``forward`` returns the element-wise sigmoid of
            the fused output.
    """

    def __init__(self, in_channels=4, classes=4, base_channels=32, benchmark=False):
        super().__init__()
        # Submodules are built in a fixed order under fixed attribute names;
        # both determine seeded initialisation and state_dict keys.
        self.seg_branch = SegformerBranch(in_channels=in_channels, classes=classes)
        self.pixel_branch = PixelWiseNet(
            in_channels=in_channels,
            out_channels=classes,
            base_channels=base_channels,
        )
        self.fusion_conv = nn.Conv2d(classes, classes, kernel_size=1, bias=False)
        self.benchmark = benchmark

    def forward(self, x):
        """Sum both branch outputs, fuse them, and optionally apply a sigmoid.

        The branch outputs are added, so their shapes must be
        broadcast-compatible (presumably both (N, classes, H, W) — confirm).
        """
        fused = self.seg_branch(x) + self.pixel_branch(x)
        mixed = self.fusion_conv(fused)
        return torch.sigmoid(mixed) if self.benchmark else mixed
class CombinedNet3(nn.Module):
    """UNet++ branch plus per-pixel branch, fused by a bias-free 1x1 conv.

    The ``UNetPlusPlusBranch`` is built with its default ``benchmark=False``,
    so it contributes its raw output. That output is added element-wise to
    the ``PixelWiseNet`` output, and the sum goes through ``fusion_conv``.

    Args:
        in_channels: Channels of the input, forwarded to both branches.
        classes: Output channels of both branches and of ``fusion_conv``.
        base_channels: Hidden width of the ``PixelWiseNet`` branch.
        benchmark: When True, ``forward`` returns the element-wise sigmoid of
            the fused output.
    """

    def __init__(self, in_channels=4, classes=4, base_channels=32, benchmark=False):
        super().__init__()
        # Keep construction order and attribute names: they fix seeded init
        # and the state_dict key layout.
        self.seg_branch = UNetPlusPlusBranch(in_channels=in_channels, classes=classes)
        self.pixel_branch = PixelWiseNet(
            in_channels=in_channels,
            out_channels=classes,
            base_channels=base_channels,
        )
        self.fusion_conv = nn.Conv2d(classes, classes, kernel_size=1, bias=False)
        self.benchmark = benchmark

    def forward(self, x):
        """Return fused branch output, optionally passed through a sigmoid.

        Branch outputs are summed, so their shapes must be broadcast-compatible
        (presumably both (N, classes, H, W) — confirm).
        """
        seg_logits = self.seg_branch(x)
        pix_logits = self.pixel_branch(x)
        result = self.fusion_conv(seg_logits + pix_logits)
        if not self.benchmark:
            return result
        return torch.sigmoid(result)
class CombinedNet4(nn.Module):
    """DeepLabV3 branch plus per-pixel branch, fused by a bias-free 1x1 conv.

    Both branches see the same input. Their outputs are summed element-wise,
    mixed by ``fusion_conv`` (``classes`` -> ``classes`` channels), and
    optionally passed through a sigmoid.

    Args:
        in_channels: Channels of the input, forwarded to both branches.
        classes: Output channels of both branches and of ``fusion_conv``.
        base_channels: Hidden width of the ``PixelWiseNet`` branch.
        benchmark: When True, ``forward`` returns the element-wise sigmoid of
            the fused output.
    """

    def __init__(self, in_channels=4, classes=4, base_channels=32, benchmark=False):
        super().__init__()
        # Construction order and attribute names are kept fixed: they set the
        # seeded initialisation and the state_dict key layout.
        self.seg_branch = DeepLabV3Branch(in_channels=in_channels, classes=classes)
        self.pixel_branch = PixelWiseNet(
            in_channels=in_channels,
            out_channels=classes,
            base_channels=base_channels,
        )
        self.fusion_conv = nn.Conv2d(classes, classes, kernel_size=1, bias=False)
        self.benchmark = benchmark

    def forward(self, x):
        """Fuse both branch outputs; apply a sigmoid when ``benchmark`` is set.

        Branch outputs are added, so their shapes must be broadcast-compatible
        (presumably both (N, classes, H, W) — confirm).
        """
        out = self.fusion_conv(self.seg_branch(x) + self.pixel_branch(x))
        if self.benchmark:
            out = torch.sigmoid(out)
        return out