import torch
import torch.nn as nn

from model.mit import mit_b4


class GLPDepth(nn.Module):
    """Depth-prediction network: ``mit_b4`` encoder, fusion decoder, conv head.

    The encoder output is unpacked into four feature maps, which the decoder
    merges into a single ``channels_out``-channel map. A two-layer 3x3
    convolutional head reduces that map to one channel, and a sigmoid scaled
    by ``max_depth`` bounds the prediction.

    Args:
        max_depth: Factor multiplied onto the sigmoid output. For a positive
            value the prediction lies in the range ``[0, max_depth]``.
    """

    def __init__(self, max_depth: float = 10.0) -> None:
        super().__init__()
        self.max_depth = max_depth

        self.encoder = mit_b4()

        # Input widths of the decoder's three 1x1 projections; index 0 is
        # applied to the fourth encoder output, index 2 to the second.
        # NOTE(review): presumably these match the encoder's stage widths --
        # confirm against model.mit.
        channels_in = [512, 320, 128]
        channels_out = 64

        self.decoder = Decoder(channels_in, channels_out)

        # Head: keep the width through one 3x3 conv + ReLU, then collapse to
        # a single channel. padding=1 keeps the spatial size unchanged.
        self.last_layer_depth = nn.Sequential(
            nn.Conv2d(channels_out, channels_out,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=False),
            nn.Conv2d(channels_out, 1, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the bounded single-channel prediction for input ``x``.

        Args:
            x: Input passed unchanged to the encoder.

        Returns:
            ``sigmoid(head(decoder(encoder(x)))) * max_depth``.
        """
        conv1, conv2, conv3, conv4 = self.encoder(x)
        fused = self.decoder(conv1, conv2, conv3, conv4)
        logits = self.last_layer_depth(fused)

        # Sigmoid squashes to [0, 1]; scaling maps that onto [0, max_depth].
        return torch.sigmoid(logits) * self.max_depth


class Decoder(nn.Module):
    """Merge four feature maps, deepest first, into one upsampled map.

    ``x_4``, ``x_3`` and ``x_2`` are projected to ``out_channels`` with 1x1
    convolutions. ``x_1`` is fused as-is, so it must already carry
    ``out_channels`` channels. After each fusion the running result is
    bilinearly upsampled by 2; two extra upsampling steps follow the last
    fusion, so the output is 4x the spatial size of ``x_1``.

    Args:
        in_channels: Indexable with at least three entries: the channel
            counts of ``x_4``, ``x_3`` and ``x_2`` (in that order).
        out_channels: Channel count of every projected/fused map and of the
            returned tensor.
    """

    def __init__(self, in_channels, out_channels: int) -> None:
        super().__init__()

        # 1x1 projections bringing each deeper feature map to a common width.
        self.bot_conv = nn.Conv2d(
            in_channels=in_channels[0], out_channels=out_channels,
            kernel_size=1)
        self.skip_conv1 = nn.Conv2d(
            in_channels=in_channels[1], out_channels=out_channels,
            kernel_size=1)
        self.skip_conv2 = nn.Conv2d(
            in_channels=in_channels[2], out_channels=out_channels,
            kernel_size=1)

        # Parameter-free, so a single instance is shared by every stage.
        self.up = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=False)

        self.fusion1 = SelectiveFeatureFusion(out_channels)
        self.fusion2 = SelectiveFeatureFusion(out_channels)
        self.fusion3 = SelectiveFeatureFusion(out_channels)

    def forward(self, x_1, x_2, x_3, x_4):
        """Fuse the four feature maps from ``x_4`` up to ``x_1``.

        Each skip input must have twice the spatial size of the one fused
        before it, because the running result is upsampled by 2 before it is
        combined with the next skip input.

        Returns:
            Tensor with ``out_channels`` channels and 4x the spatial size of
            ``x_1``.
        """
        # Deepest level: project, then upsample to the size of x_3.
        out = self.up(self.bot_conv(x_4))

        # Fuse with projected x_3, upsample to the size of x_2.
        out = self.fusion1(self.skip_conv1(x_3), out)
        out = self.up(out)

        # Fuse with projected x_2, upsample to the size of x_1.
        out = self.fusion2(self.skip_conv2(x_2), out)
        out = self.up(out)

        # x_1 is used without a projection; two 2x steps give 4x overall.
        out = self.fusion3(x_1, out)
        out = self.up(out)
        out = self.up(out)

        return out


class SelectiveFeatureFusion(nn.Module):
    """Blend two same-shaped feature maps with per-pixel sigmoid gates.

    The inputs are concatenated along the channel axis and passed through
    three 3x3 convolutions (the first two followed by batch norm and ReLU)
    that reduce them to two channels. After a sigmoid, channel 0 weights
    ``x_local`` and channel 1 weights ``x_global``. The two gates are
    independent sigmoids, so they are not constrained to sum to one.

    Args:
        in_channel: Channel count of each input (and of the output).
    """

    def __init__(self, in_channel: int = 64) -> None:
        super().__init__()

        half_channel = in_channel // 2

        # 2C -> C: consumes the channel-wise concatenation of both inputs.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channel * 2, out_channels=in_channel,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
        )

        # C -> C/2
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=half_channel,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(half_channel),
            nn.ReLU(),
        )

        # C/2 -> 2: one gate map per input.
        self.conv3 = nn.Conv2d(in_channels=half_channel, out_channels=2,
                               kernel_size=3, stride=1, padding=1)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x_local, x_global):
        """Return ``x_local * gate_0 + x_global * gate_1``.

        Args:
            x_local: Tensor with ``in_channel`` channels on dim 1.
            x_global: Tensor of the same shape as ``x_local``.

        Returns:
            Tensor of the same shape as the inputs.
        """
        stacked = torch.cat((x_local, x_global), dim=1)
        gates = self.sigmoid(self.conv3(self.conv2(self.conv1(stacked))))

        # Length-1 slices keep the channel axis so each gate broadcasts
        # across all channels of its input.
        local_gate = gates[:, 0:1]
        global_gate = gates[:, 1:2]

        return x_local * local_gate + x_global * global_gate
|