# Author: mattiagatti
# Commit: 0d7318b ("Initial commit")
import torch
import torch.nn as nn
from model.mit import mit_b4
class GLPDepth(nn.Module):
    """Monocular depth network: MiT-B4 encoder, ``Decoder``, and a conv head.

    The head produces a single channel that is squashed with a sigmoid and
    scaled by ``max_depth``, so every predicted value lies in (0, max_depth).

    Args:
        max_depth: Upper bound of the predicted depth range. Its unit is
            whatever the training targets use (not visible from this file).
    """

    def __init__(self, max_depth=10.0):
        super().__init__()
        self.max_depth = max_depth
        self.encoder = mit_b4()
        # Channel widths of the three deepest encoder stages, deepest first.
        # The shallowest stage is fed to the decoder without a projection, so
        # it must already be `width` channels wide — presumably true of
        # mit_b4's first stage; TODO confirm.
        width = 64
        self.decoder = Decoder([512, 320, 128], width)
        # Prediction head: 3x3 conv + ReLU, then a 3x3 conv down to 1 channel.
        self.last_layer_depth = nn.Sequential(
            nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=False),
            nn.Conv2d(width, 1, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        """Predict a depth map for ``x``.

        Args:
            x: Input passed straight to the encoder; presumably an image batch
                of shape (N, 3, H, W) — TODO confirm against mit_b4.

        Returns:
            A single-channel tensor with values in (0, max_depth). Its spatial
            size is 4x that of the shallowest encoder feature map.
        """
        # The encoder must yield exactly four feature maps, shallowest first.
        f1, f2, f3, f4 = self.encoder(x)
        logits = self.last_layer_depth(self.decoder(f1, f2, f3, f4))
        return torch.sigmoid(logits) * self.max_depth
class Decoder(nn.Module):
    """Top-down decoder that merges four encoder feature maps into one map.

    Starting from the deepest map, it repeatedly doubles the resolution
    (bilinear upsampling) and merges in the next shallower map through a
    SelectiveFeatureFusion gate. After the last fusion it upsamples twice
    more.

    Args:
        in_channels: Channel counts of (x_4, x_3, x_2), deepest first. Each of
            these maps is projected to ``out_channels`` by a 1x1 conv. x_1 is
            used as-is, so it must already have ``out_channels`` channels.
        out_channels: Working width of the decoder and of its output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def project(width):
            # 1x1 conv mapping one encoder stage to the decoder width.
            return nn.Conv2d(
                in_channels=width, out_channels=out_channels, kernel_size=1)

        self.bot_conv = project(in_channels[0])
        self.skip_conv1 = project(in_channels[1])
        self.skip_conv2 = project(in_channels[2])
        # One shared, parameter-free x2 bilinear upsampler reused at every step.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.fusion1 = SelectiveFeatureFusion(out_channels)
        self.fusion2 = SelectiveFeatureFusion(out_channels)
        self.fusion3 = SelectiveFeatureFusion(out_channels)

    def forward(self, x_1, x_2, x_3, x_4):
        """Decode the feature pyramid (x_1 shallowest ... x_4 deepest).

        Spatial sizes must line up for the channel concatenation inside each
        fusion: x_3 must be 2x the size of x_4, x_2 must be 2x x_3, and x_1
        must be 2x x_2.

        Returns:
            A map with ``out_channels`` channels at 4x the spatial size of x_1.
        """
        y = self.up(self.bot_conv(x_4))
        y = self.up(self.fusion1(self.skip_conv1(x_3), y))
        y = self.up(self.fusion2(self.skip_conv2(x_2), y))
        y = self.fusion3(x_1, y)
        # Two final x2 upsamplings (x4 overall) after the shallowest fusion.
        return self.up(self.up(y))
class SelectiveFeatureFusion(nn.Module):
    """Gated merge of a local and a global feature map.

    A small conv stack looks at both inputs together and predicts two
    per-pixel gates in (0, 1). The output is
    ``x_local * gate_local + x_global * gate_global``. The two gates come from
    independent sigmoids (not a softmax), so they need not sum to 1.

    Args:
        in_channel: Channel count C of each input. The gate stack narrows
            2C -> C -> C/2 (truncated) -> 2.
    """

    def __init__(self, in_channel=64):
        super().__init__()
        doubled = int(in_channel * 2)
        halved = int(in_channel / 2)
        self.conv1 = nn.Sequential(
            nn.Conv2d(doubled, in_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel, halved, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(halved),
            nn.ReLU(),
        )
        # Two output channels: gate for x_local (0) and gate for x_global (1).
        self.conv3 = nn.Conv2d(halved, 2, kernel_size=3, stride=1, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_local, x_global):
        """Fuse two (N, C, H, W) maps of identical shape into one (N, C, H, W) map."""
        joint = torch.cat([x_local, x_global], dim=1)
        gates = self.sigmoid(self.conv3(self.conv2(self.conv1(joint))))
        # Slicing with a length-1 range keeps the channel axis, so each gate
        # is (N, 1, H, W) and broadcasts over the C channels.
        gate_local = gates[:, 0:1]
        gate_global = gates[:, 1:2]
        return x_local * gate_local + x_global * gate_global